Source code for chemicalchecker.util.splitter.pairtraintest

"""Splitter on pairs."""
import h5py
import numpy as np

from chemicalchecker.util import logged


@logged
class PairTraintest(object):
    """Reader/writer for pair-based train/test splits stored in HDF5.

    The file holds two feature matrices (datasets ``x1`` and ``x2``) and,
    for each split, a dataset of index pairs (``p_<split>``) and a dataset
    of labels (``y_<split>``). Column 0 of a pair is used to index rows of
    ``x1`` and column 1 to index rows of ``x2``.
    """

    def __init__(self, hdf5_file, split, nr_neig=10, replace_nan=None):
        """Initialize a PairTraintest instance.

        We assume the file contains different splits,
        e.g. "p_train", "y_train", "p_test", ...

        Args:
            hdf5_file(str): path of the HDF5 file to read.
            split(str): name of the split to read. If None the un-suffixed
                datasets "p" and "y" are used and the name is not checked.
            nr_neig(int): stored as `self.nr_neig`; not used by this class.
            replace_nan(float): if not None, value written over NaNs in the
                arrays returned by the getters (the labels returned by
                `get_py` and `get_all_y` are left untouched).

        Raises:
            Exception: if `split` is not among the names returned by
                `get_split_names`.
        """
        self._file = hdf5_file
        self._f = None
        self.replace_nan = replace_nan
        self.nr_neig = nr_neig
        self.x1_name = "x1"
        self.x2_name = "x2"
        if split is None:
            self.p_name = "p"
            self.y_name = "y"
        else:
            self.p_name = "p_%s" % split
            self.y_name = "y_%s" % split
            available_splits = self.get_split_names()
            if split not in available_splits:
                raise Exception("Split '%s' not found in %s!" %
                                (split, str(available_splits)))

    def _is_open(self):
        """Return True if `self._f` is a currently open HDF5 handle."""
        # an h5py File object evaluates to False once it has been closed
        return self._f is not None and bool(self._f)

    def _ensure_open(self):
        """Open the file if needed; return True if this call opened it."""
        if self._is_open():
            return False
        self.open()
        return True

    def _read(self, name, beg_idx=None, end_idx=None):
        """Read rows [beg_idx:end_idx] of dataset `name`, replacing NaNs.

        The file must already be open. With both indices left to None the
        whole dataset is read.
        """
        data = self._f[name][beg_idx:end_idx]
        # handle NaNs
        if self.replace_nan is not None:
            data[np.isnan(data)] = self.replace_nan
        return data

    def get_pos_neg(self):
        """Return how many labels are equal to 1 and how many to 0.

        The file is opened (and closed again) if it is not already open.
        """
        opened_here = self._ensure_open()
        try:
            y = self.get_all_y()
        finally:
            if opened_here:
                self.close()
        return len(y[y == 1]), len(y[y == 0])

    def get_py_shapes(self):
        """Return the shapes of the pairs and of Y.

        The file is opened (and closed again) if it is not already open;
        a handle opened by the caller is left open.
        """
        opened_here = self._ensure_open()
        try:
            p_shape = self._f[self.p_name].shape
            y_shape = self._f[self.y_name].shape
        finally:
            if opened_here:
                self.close()
        return p_shape, y_shape

    def get_xy_shapes(self):
        """Return the shapes of X1, X2 and Y.

        The file is opened (and closed again) if it is not already open;
        a handle opened by the caller is left open.
        """
        opened_here = self._ensure_open()
        try:
            x1_shape = self._f[self.x1_name].shape
            x2_shape = self._f[self.x2_name].shape
            y_shape = self._f[self.y_name].shape
        finally:
            if opened_here:
                self.close()
        return x1_shape, x2_shape, y_shape

    def get_split_names(self):
        """Return the name of the splits.

        Names are read from the "split_names" dataset when the file has
        one, otherwise the default ['train', 'test'] is returned.
        """
        opened_here = self._ensure_open()
        try:
            if "split_names" in self._f:
                split_names = [a.decode() for a in self._f["split_names"]]
            else:
                split_names = ['train', 'test']
                self.__log.info("Using default split names %s", split_names)
        finally:
            if opened_here:
                self.close()
        return split_names

    def open(self):
        """Open the HDF5 read-only; does nothing if it is already open."""
        if self._is_open():
            # re-opening would drop the reference to the handle in use
            return
        self._f = h5py.File(self._file, 'r')
        self.__log.info("HDF5 open %s", self._file)

    def close(self):
        """Close the HDF5; log an error if it was never opened."""
        try:
            self._f.close()
            self.__log.info("HDF5 close %s", self._file)
        except AttributeError:
            self.__log.error('HDF5 file is not open yet.')

    def get_py(self, beg_idx, end_idx):
        """Get a batch of pairs and Y (the file must be open).

        NaN replacement is applied to the pairs only.
        """
        features = self._read(self.p_name, beg_idx, end_idx)
        labels = self._f[self.y_name][beg_idx: end_idx]
        return features, labels

    def get_p(self, beg_idx, end_idx):
        """Get a batch of pairs (the file must be open)."""
        return self._read(self.p_name, beg_idx, end_idx)

    def get_y(self, beg_idx, end_idx):
        """Get a batch of Y, with NaNs replaced (the file must be open)."""
        return self._read(self.y_name, beg_idx, end_idx)

    def get_all_x1(self):
        """Get full X1 (the file must be open)."""
        return self._read(self.x1_name)

    def get_all_x2(self):
        """Get full X2 (the file must be open)."""
        return self._read(self.x2_name)

    def get_all_p(self):
        """Get all pairs (the file must be open)."""
        return self._read(self.p_name)

    def get_all_y(self):
        """Get full Y, without NaN replacement (the file must be open)."""
        labels = self._f[self.y_name][:]
        return labels

    @staticmethod
    def get_split_indeces(rows, fractions):
        """Get random indeces for different splits.

        Args:
            rows(int): number of indices to distribute (0 .. rows - 1).
            fractions(list(float)): fraction of indices in each split.

        Returns:
            list(numpy.ndarray): one array of shuffled indices per fraction.

        Raises:
            Exception: if the fractions do not sum to 1.0.
        """
        # tolerate floating point error, e.g. sum([.7, .2, .1]) != 1.0
        if not np.isclose(sum(fractions), 1.0):
            raise Exception("Split fractions should sum to 1.0")
        # shuffle indeces
        idxs = list(range(rows))
        np.random.shuffle(idxs)
        # from fractions to cut points (the last cumulative value is the
        # end of the list and is not a cut point)
        splits = np.cumsum(fractions)[:-1] * len(idxs)
        # builtin int: the `np.int` alias no longer exists in recent NumPy
        splits = splits.round().astype(int)
        return np.split(idxs, splits)

    @staticmethod
    def create(X1, X2, pairs, split_names, out_file,
               mean_center_x=True, shuffle=True):
        """Create the HDF5 file with validation splits.

        Args:
            X1(numpy.ndarray): features saved as dataset 'x1'.
            X2(numpy.ndarray): features saved as dataset 'x2'.
            pairs(list(numpy.ndarray)): one 2D array per split; the first
                two columns are saved as 'p_<name>' and the last column as
                'y_<name>'.
            split_names(list(str)): names for the splits, one per element
                of `pairs`.
            out_file(str): path of the h5 file to write.
            mean_center_x(bool): currently unused.
            shuffle(bool): Shuffle the rows of each split before saving.

        Raises:
            Exception: if `split_names` and `pairs` differ in length.
        """
        PairTraintest.__log.debug(
            "{:<20} shape: {:>10}".format("input X1", str(X1.shape)))
        PairTraintest.__log.debug(
            "{:<20} shape: {:>10}".format("input X2", str(X2.shape)))
        # train test validation splits
        if len(split_names) != len(pairs):
            raise Exception("Split names and set of pairs must be same nr.")
        for name, Y in zip(split_names, pairs):
            PairTraintest.__log.debug(
                "{:<20} shape: {:>10}".format(name, str(Y.shape)))

        # create dataset
        PairTraintest.__log.info('Traintest saving to %s', out_file)
        with h5py.File(out_file, "w") as fh:
            fh.create_dataset('x1', data=X1)
            fh.create_dataset('x2', data=X2)
            # NOTE(review): no 'split_names' dataset is written, so readers
            # fall back to the default names of `get_split_names`.
            for name, PY in zip(split_names, pairs):
                # shuffling
                shuffle_idxs = np.arange(PY.shape[0])
                if shuffle:
                    np.random.shuffle(shuffle_idxs)
                # save to h5
                P = PY[:, :2]
                ds_name = "p_%s" % name
                PairTraintest.__log.info(
                    'writing %s %s %s', ds_name, name, P.shape)
                fh.create_dataset(ds_name, data=P[shuffle_idxs])
                Y = PY[:, -1]
                ds_name = "y_%s" % name
                PairTraintest.__log.info(
                    'writing %s %s', ds_name, Y.shape)
                fh.create_dataset(ds_name, data=Y[shuffle_idxs])
        PairTraintest.__log.info('PairTraintest saved to %s', out_file)

    @staticmethod
    def generate_splits(X1, X2, pairs):
        """Split `pairs` leaving out left, right or both elements.

        80% of the distinct values of a column go to train, the rest to
        test. For "both", train keeps the rows whose two columns are both
        in train and test the rows whose two columns are both in test.

        Args:
            X1: unused.
            X2: unused.
            pairs(numpy.ndarray): 2D array, column 0 and 1 identify the
                left and right element of each pair.

        Returns:
            tuple: three (train, test) tuples of rows of `pairs`, for
                leave-left-out, leave-right-out and leave-both-out.
        """
        # leave left out
        x1_set = list(set(pairs[:, 0]))
        x1_train_idxs = x1_set[:int(len(x1_set) * .8)]
        x1_train_mask = np.isin(pairs[:, 0], x1_train_idxs)
        x1_train, x1_test = pairs[x1_train_mask], pairs[~x1_train_mask]
        assert len(set(x1_train[:, 0]) & set(x1_test[:, 0])) == 0
        # leave right out
        x2_set = list(set(pairs[:, 1]))
        x2_train_idxs = x2_set[:int(len(x2_set) * .8)]
        x2_train_mask = np.isin(pairs[:, 1], x2_train_idxs)
        x2_train, x2_test = pairs[x2_train_mask], pairs[~x2_train_mask]
        assert len(set(x2_train[:, 1]) & set(x2_test[:, 1])) == 0
        # leave both out
        both_train_mask = np.logical_and(x1_train_mask, x2_train_mask)
        both_test_mask = np.logical_and(~x1_train_mask, ~x2_train_mask)
        both_train, both_test = pairs[both_train_mask], pairs[both_test_mask]
        assert len(set(both_train[:, 0]) & set(both_test[:, 0])) == 0
        assert len(set(both_train[:, 1]) & set(both_test[:, 1])) == 0
        return ((x1_train, x1_test),
                (x2_train, x2_test),
                (both_train, both_test))

    @staticmethod
    def generator_fn(file_name, split, batch_size=None, only_x=False,
                     replace_nan=None, mask_fn=None):
        """Return the generator function that we can query for batches.

        The HDF5 file is opened here and stays open, as the returned
        generator reads its batches of pairs from it.

        Args:
            file_name(str): The H5 generated via `create`
            split(str): name of the split to read, checked against
                `get_split_names`.
            batch_size(int): Size of a batch of data; if not set every
                batch holds all the pairs of the split.
            only_x(bool): Usually when predicting only X are useful.
            replace_nan(float): Value used for NaN replacement.
            mask_fn(callable): called as `mask_fn(x1, x2, y)` on each batch
                and expected to return the three (possibly modified)
                arrays; by default data is returned unchanged.

        Returns:
            tuple: `shapes`, `dtypes` and the generator function. The
                generator loops forever over the split, yielding the
                horizontally stacked X1 and X2 rows of each batch of
                pairs (followed by Y unless `only_x`).
        """
        # forward replace_nan, otherwise the argument has no effect
        reader = PairTraintest(file_name, split, replace_nan=replace_nan)
        reader.open()
        # read shapes
        x1_shape = reader._f[reader.x1_name].shape
        x2_shape = reader._f[reader.x2_name].shape
        y_shape = reader._f[reader.y_name].shape
        p_shape = reader._f[reader.p_name].shape
        # read data types
        x1_dtype = reader._f[reader.x1_name].dtype
        y_dtype = reader._f[reader.y_name].dtype
        # no batch size -> return everything
        if not batch_size:
            batch_size = p_shape[0]
        # keep X in memory for resolving pairs quickly
        PairTraintest.__log.debug('Loading Xs')
        X1 = reader.get_all_x1()
        X2 = reader.get_all_x2()

        # default mask is not mask
        if mask_fn is None:
            def mask_fn(*data):
                return data

        def example_generator_fn():
            # generator function yielding data
            beg_idx, end_idx = 0, batch_size
            total = p_shape[0]
            while True:
                # past the last pair: start a new pass over the split
                if beg_idx >= total:
                    beg_idx, end_idx = 0, batch_size
                pairs, y = reader.get_py(beg_idx, end_idx)
                x1 = X1[pairs[:, 0]]
                x2 = X2[pairs[:, 1]]
                x1, x2, y = mask_fn(x1, x2, y)
                if only_x:
                    yield np.hstack((x1, x2))
                else:
                    yield np.hstack((x1, x2)), y
                beg_idx, end_idx = beg_idx + batch_size, end_idx + batch_size

        shapes = (y_shape[0], x1_shape[1] + x2_shape[1]), y_shape
        dtypes = (x1_dtype, y_dtype)
        return shapes, dtypes, example_generator_fn