# Source code for chemicalchecker.util.splitter.ae_siam_traintest

"""Splitter for Siamese Autoencoder."""
import h5py
import numpy as np
from tqdm import tqdm

from chemicalchecker.util import logged


@logged
class AE_SiameseTraintest(object):
    """AE_SiameseTraintest class."""

    def __init__(self, hdf5_file, split, replace_nan=None):
        """Initialize an AE_SiameseTraintest instance.

        We assume the file contains different splits,
        e.g. "x_train", "y_train", "x_test", ...
        """
        self._file = hdf5_file
        self._f = None
        self.replace_nan = replace_nan
        if split is None:
            self.x_name_left = "x_left"
            self.y_name_left = "x_left"
            self.sw_name_left = "sw_left"
            self.x_name_right = "x_right"
            self.y_name_right = "x_right"
            self.sw_name_right = "sw_right"
        else:
            self.x_name_left = "x_left_%s" % split
            self.y_name_left = "x_left_%s" % split
            self.sw_name_left = "sw_left_%s" % split
            self.x_name_right = "x_right_%s" % split
            self.y_name_right = "x_right_%s" % split
            self.sw_name_right = "sw_right_%s" % split
        '''
        available_splits = self.get_split_names()
        if split not in available_splits:
            raise Exception("Split '%s' not found in %s!" % (
                split, str(available_splits)))
        '''
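    # The dataset names built above define the HDF5 layout this class
    # expects. For a file produced by `split_h5_blocks` below with the
    # default split names, the available keys would be:
    #
    #   x_left_train       x_right_train
    #   x_left_test        x_right_test
    #   x_left_validation  x_right_validation
    #   split_names        split_fractions
    #
    # The sw_* (sample weight) datasets are read by `get_sw` but are not
    # written by `split_h5_blocks`, so presumably they are added by
    # another step of the pipeline.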
    def get_x_shapes(self):
        """Return the shapes of X."""
        self.open()
        x_shape = self._f[self.x_name_left].shape
        self.close()
        return x_shape
    def get_xy_shapes(self):
        """Return the shapes of X and Y."""
        self.open()
        x_shape = self._f[self.x_name_left].shape
        y_shape = self._f[self.y_name_left].shape
        self.close()
        return x_shape, y_shape
    def get_split_names(self):
        """Return the names of the splits."""
        self.open()
        if "split_names" in self._f:
            split_names = [a.decode() for a in self._f["split_names"]]
        else:
            split_names = ['train', 'test', 'validation']
            self.__log.info("Using default split names %s" % split_names)
        self.close()
        return split_names
    def open(self):
        """Open the HDF5."""
        self._f = h5py.File(self._file, 'r')
        self.__log.info("HDF5 open %s", self._file)
    def close(self):
        """Close the HDF5."""
        try:
            self._f.close()
            self.__log.info("HDF5 close %s", self._file)
        except AttributeError:
            self.__log.error('HDF5 file is not open yet.')
    def get_sw(self, beg_idx, end_idx):
        """Get a batch of sample weights."""
        features_left = self._f[self.sw_name_left][beg_idx: end_idx]
        features_right = self._f[self.sw_name_right][beg_idx: end_idx]
        # handle NaNs
        if self.replace_nan is not None:
            features_left[np.isnan(features_left)] = self.replace_nan
            features_right[np.isnan(features_right)] = self.replace_nan
        return ([features_left, features_right],
                [features_left, features_right])
    def get_xy(self, beg_idx, end_idx, shuffle):
        """Get a batch of X and Y."""
        features_left = self._f[self.x_name_left][beg_idx: end_idx]
        features_right = self._f[self.x_name_right][beg_idx: end_idx]
        if shuffle:
            np.random.shuffle(features_left)
            np.random.shuffle(features_right)
        # handle NaNs
        if self.replace_nan is not None:
            features_left[np.isnan(features_left)] = self.replace_nan
            features_right[np.isnan(features_right)] = self.replace_nan
        return ([features_left, features_right],
                [features_left, features_right])
    def get_x(self, beg_idx, end_idx):
        """Get a batch of X."""
        features_left = self._f[self.x_name_left][beg_idx: end_idx]
        features_right = self._f[self.x_name_right][beg_idx: end_idx]
        # handle NaNs
        if self.replace_nan is not None:
            features_left[np.isnan(features_left)] = self.replace_nan
            features_right[np.isnan(features_right)] = self.replace_nan
        return [features_left, features_right]
    @staticmethod
    def get_split_indeces(rows, fractions):
        """Get random indices for the different splits."""
        if not np.isclose(sum(fractions), 1.0):
            raise Exception("Split fractions should sum to 1.0")
        # indices are kept in order; only the right-hand copies below
        # are shuffled
        idxs = list(range(rows))
        # from fractions to cumulative split points
        splits = np.cumsum(fractions)
        splits = splits[:-1]
        splits *= len(idxs)
        splits = splits.round().astype(int)
        split_left = np.split(idxs, splits)
        # the right-hand side of each pair is a shuffled copy of the left
        split_right = []
        for i in range(len(split_left)):
            split_right.append(np.copy(split_left[i]))
            np.random.shuffle(split_right[i])
        final_split = []
        for i in range(len(split_left)):
            final_split.append((split_left[i], split_right[i]))
        return final_split
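    # A minimal sketch of the structure `get_split_indeces` returns,
    # assuming 10 rows and fractions [.8, .2] (the right-hand arrays are
    # shuffled copies of the left-hand ones, so their order will vary):
    #
    #   >>> AE_SiameseTraintest.get_split_indeces(10, [.8, .2])
    #   [(array([0, 1, 2, 3, 4, 5, 6, 7]),
    #     array([3, 0, 7, 5, 1, 6, 2, 4])),
    #    (array([8, 9]), array([9, 8]))]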
    @staticmethod
    def split_h5_blocks(in_file, out_file,
                        split_names=['train', 'test', 'validation'],
                        split_fractions=[.8, .1, .1],
                        block_size=1000, input_datasets=None):
        """Create the HDF5 file with validation splits from an input file.

        Args:
            in_file(str): path of the h5 file to read from.
            out_file(str): path of the h5 file to write.
            split_names(list(str)): names for the splits of data.
            split_fractions(list(float)): fraction of data in each split.
            block_size(int): size of the blocks to be used.
            input_datasets(list): only split the given datasets and
                ignore others.
        """
        output_datasets = ['x']
        with h5py.File(in_file, 'r') as hf_in:
            # log input datasets and get shapes
            # (all datasets are assumed to share the number of rows)
            for k in hf_in.keys():
                AE_SiameseTraintest.__log.debug(
                    "{:<20} shape: {:>10}".format(k, str(hf_in[k].shape)))
                rows = hf_in[k].shape[0]
            # reduce block size if it is not adequate
            while rows / (float(block_size) * 10) <= 1:
                block_size = int(block_size / 10)
                AE_SiameseTraintest.__log.warning(
                    "Reducing block_size to: %s", block_size)
            # train test validation splits
            if len(split_names) != len(split_fractions):
                raise Exception(
                    "Split names and fractions should be the same amount.")
            split_names = [s.encode() for s in split_names]
            # get indices of blocks for each split
            split_block_idx = AE_SiameseTraintest.get_split_indeces(
                rows, split_fractions)
            if input_datasets is None:
                input_datasets = list(hf_in.keys())
            for dataset_name in input_datasets:
                if dataset_name not in hf_in.keys():
                    raise Exception(
                        "Dataset %s not found in source file." % dataset_name)
            if len(input_datasets) != len(output_datasets):
                raise Exception(
                    "Length of input and output datasets is not the same.")
            # save to output file
            AE_SiameseTraintest.__log.info('Traintest saving to %s', out_file)
            with h5py.File(out_file, "w") as hf_out:
                # create fixed datasets
                hf_out.create_dataset(
                    'split_names', data=np.array(split_names))
                hf_out.create_dataset(
                    'split_fractions', data=np.array(split_fractions))
                for name, blocks in zip(split_names, split_block_idx):
                    # for each original dataset
                    for k in range(len(input_datasets)):
                        # create all splits
                        ds_name_left = "%s_left_%s" % (
                            output_datasets[k], name.decode())
                        ds_name_right = "%s_right_%s" % (
                            output_datasets[k], name.decode())
                        # need total size and mapping of blocks
                        total_size = blocks[0].shape[0]
                        index_right = blocks[1]
                        # create block matrix
                        reshape = False
                        if len(hf_in[input_datasets[k]].shape) == 1:
                            cols = 1
                            reshape = True
                        else:
                            cols = hf_in[input_datasets[k]].shape[1]
                        hf_out.create_dataset(
                            ds_name_left, (total_size, cols),
                            dtype=hf_in[input_datasets[k]].dtype)
                        hf_out.create_dataset(
                            ds_name_right, (total_size, cols),
                            dtype=hf_in[input_datasets[k]].dtype)
                        for i in tqdm(range(0, total_size, block_size)):
                            chunk = slice(i, i + block_size)
                            dst_chunk = chunk
                            src_chunk_left = chunk
                            src_chunk_right = index_right[chunk]
                            src_data_right = np.array(
                                [hf_in[input_datasets[k]][j]
                                 for j in src_chunk_right])
                            # the last block may be smaller than block_size
                            if src_data_right.shape[0] != block_size:
                                dst_chunk = slice(
                                    i, i + src_data_right.shape[0])
                                src_chunk_left = dst_chunk
                            if reshape:
                                hf_out[ds_name_left][dst_chunk] = \
                                    np.expand_dims(
                                        hf_in[input_datasets[k]][
                                            src_chunk_left], 1)
                                hf_out[ds_name_right][dst_chunk] = \
                                    np.expand_dims(src_data_right, 1)
                            else:
                                hf_out[ds_name_left][dst_chunk] = hf_in[
                                    input_datasets[k]][src_chunk_left]
                                hf_out[ds_name_right][
                                    dst_chunk] = src_data_right
        AE_SiameseTraintest.__log.info('Traintest saved to %s', out_file)
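    # A minimal usage sketch, assuming a hypothetical source file
    # 'signatures.h5' holding a single 2D dataset named 'x':
    #
    #   AE_SiameseTraintest.split_h5_blocks(
    #       'signatures.h5', 'traintest.h5',
    #       split_names=['train', 'test', 'validation'],
    #       split_fractions=[.8, .1, .1],
    #       input_datasets=['x'])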
    @staticmethod
    def generator_fn(file_name, split, batch_size=None, only_x=False,
                     sample_weights=False, shuffle=True,
                     return_on_epoch=False):
        """Return the generator function that we can query for batches."""
        reader = AE_SiameseTraintest(file_name, split)
        reader.open()
        x_shape = reader._f[reader.x_name_left].shape
        y_shape = reader._f[reader.y_name_left].shape
        x_dtype = reader._f[reader.x_name_left].dtype
        y_dtype = reader._f[reader.y_name_left].dtype
        shapes = (x_shape, y_shape)
        dtypes = (x_dtype, y_dtype)
        # no batch size -> return everything
        if not batch_size:
            batch_size = x_shape[0]
        # precompute the (begin, end) index pair of each batch
        batch_beg_end = np.zeros((int(np.ceil(x_shape[0] / batch_size)), 2))
        last = 0
        for row in batch_beg_end:
            row[0] = last
            row[1] = last + batch_size
            last = row[1]
        batch_beg_end = batch_beg_end.astype(int)

        def example_generator_fn():
            # generator function yielding data
            epoch = 0
            batch_idx = 0
            while True:
                if batch_idx == len(batch_beg_end):
                    batch_idx = 0
                    epoch += 1
                    if shuffle:
                        np.random.shuffle(batch_beg_end)
                    if return_on_epoch:
                        return
                beg_idx, end_idx = batch_beg_end[batch_idx]
                if only_x:
                    if sample_weights:
                        yield reader.get_x(beg_idx, end_idx), \
                            reader.get_sw(beg_idx, end_idx)
                    else:
                        yield reader.get_x(beg_idx, end_idx)
                else:
                    if sample_weights:
                        yield reader.get_xy(beg_idx, end_idx, shuffle), \
                            reader.get_sw(beg_idx, end_idx)
                    else:
                        yield reader.get_xy(beg_idx, end_idx, shuffle)
                batch_idx += 1

        return shapes, dtypes, example_generator_fn
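
if __name__ == '__main__':
    # A minimal sketch of consuming the generator; 'traintest.h5' is a
    # hypothetical file produced by `split_h5_blocks` above, holding a
    # 'train' split.
    shapes, dtypes, gen_fn = AE_SiameseTraintest.generator_fn(
        'traintest.h5', 'train', batch_size=128, return_on_epoch=True)
    generator = gen_fn()
    # each batch is ([x_left, x_right], [y_left, y_right])
    (x_left, x_right), (y_left, y_right) = next(generator)
    print(x_left.shape, x_right.shape)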