Source code for chemicalchecker.util.splitter.neighbortriplet

"""Splitter on Neighbor triplets."""
import os
import h5py
import pickle
import itertools
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.spatial.distance import euclidean, pdist, squareform
from sklearn.preprocessing import RobustScaler

from chemicalchecker.util import logged
from chemicalchecker.util.remove_near_duplicates import RNDuplicates


@logged
class TripletIterator(object):
    """TripletIterator class."""

    def __init__(self, hdf5_file, split, replace_nan=None):
        """Initialize a TripletIterator instance.

        This allows iterating on train/test splits of triplets generated
        via a `TripletSampler` class. We assume the file contains different
        splits, e.g. "x_train", "y_train", "x_test", ...

        Args:
            hdf5_file(str): Path to a file generated via a TripletSampler
                class.
            split(str): The H5 typically contains 'train' or 'test' splits;
                the iterator will focus on that split.
            replace_nan(float): If None, nothing is replaced; otherwise NaNs
                are replaced by the specified value.
        """
        self._file = hdf5_file
        self._f = None
        self.replace_nan = replace_nan
        self.x_name = "x"
        if split is None:
            self.t_name = "t"
            self.y_name = "y"
        else:
            self.t_name = "t_%s" % split
            self.y_name = "y_%s" % split
            available_splits = self.get_split_names()
            if split not in available_splits:
                raise Exception("Split '%s' not found in %s!"
                                % (split, str(available_splits)))
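    # Usage sketch (hypothetical file path and split name; assumes the H5
    # was produced by one of the sampler classes defined below):
    #
    #   it = TripletIterator('triplets_eval.h5', 'train_train',
    #                        replace_nan=0.0)
    #   t_shape, y_shape = it.get_ty_shapes()
    #   it.open()
    #   t_batch = it.get_t(0, 128)
    #   y_batch = it.get_y(0, 128)
    #   it.close()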
    def get_ty_shapes(self):
        """Return the shapes of T and Y."""
        self.open()
        t_shape = self._f[self.t_name].shape
        y_shape = self._f[self.y_name].shape
        self.close()
        return t_shape, y_shape
    def get_xy_shapes(self):
        """Return the shape of X."""
        self.open()
        x_shape = self._f[self.x_name].shape
        self.close()
        return x_shape
    def get_split_names(self):
        """Return the names of the available split combinations."""
        self.open()
        if "split_names" in self._f:
            split_names = [a.decode() for a in self._f["split_names"]]
        else:
            split_names = ['train', 'test']
        self.close()
        combos = itertools.combinations_with_replacement(split_names, 2)
        return ['_'.join(x) for x in combos]
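    # For the default split names ['train', 'test'] this returns
    # ['train_train', 'train_test', 'test_test'].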
    def open(self):
        """Open the HDF5."""
        self._f = h5py.File(self._file, 'r')
        # self.__log.info("HDF5 open %s", self._file)
    def close(self):
        """Close the HDF5."""
        try:
            self._f.close()
            # self.__log.info("HDF5 close %s", self._file)
        except AttributeError:
            self.__log.error('HDF5 file is not open yet.')
    def get_t(self, beg_idx, end_idx):
        """Get a batch of triplets T."""
        features = self._f[self.t_name][beg_idx: end_idx]
        # handle NaNs
        if self.replace_nan is not None:
            features[np.where(np.isnan(features))] = self.replace_nan
        return features
    def get_y(self, beg_idx, end_idx):
        """Get a batch of Y."""
        features = self._f[self.y_name][beg_idx: end_idx]
        # handle NaNs
        if self.replace_nan is not None:
            features[np.where(np.isnan(features))] = self.replace_nan
        return features
    def get_all_x(self):
        """Get full X."""
        features = self._f[self.x_name][:]
        # handle NaNs
        if self.replace_nan is not None:
            features[np.where(np.isnan(features))] = self.replace_nan
        return features
    def get_x_columns(self, mask):
        """Get selected columns of X."""
        features = self._f[self.x_name][:, mask]
        # handle NaNs
        if self.replace_nan is not None:
            features[np.where(np.isnan(features))] = self.replace_nan
        return features
    def get_all_p(self):
        """Get all triplets T."""
        features = self._f[self.t_name][:]
        # handle NaNs
        if self.replace_nan is not None:
            features[np.where(np.isnan(features))] = self.replace_nan
        return features
    def get_all_y(self):
        """Get full Y."""
        labels = self._f[self.y_name][:]
        return labels
    @staticmethod
    def generator_fn(file_name, split, replace_nan=None, batch_size=None,
                     shuffle=True, train=True, augment_fn=None,
                     augment_kwargs={}, mask_fn=None, trim_mask=None,
                     sharedx=None, sharedx_trim=None, onlyself_notself=False,
                     p_self_decay=False):
        """Return the generator function that iterates on batches.

        A TripletIterator on the specified file and split is initialized;
        we allow for additional masking, a shared X matrix and more.

        Args:
            file_name(str): The H5 generated via a `TripletSampler` class.
            split(str): One of 'train_train', 'train_test' or 'test_test'.
            replace_nan(float): Value used for NaN replacement.
            batch_size(int): Size of a batch of data.
            shuffle(bool): Shuffle the order of batches.
            train(bool): At train time the augment function is applied.
            augment_fn(func): Function to augment data.
            augment_kwargs(dict): Parameters for the augment function.
            mask_fn(func): Function to mask data while iterating.
            trim_mask(array): Initial masking of data (spaces are excluded).
            sharedx(matrix): The preloaded X matrix.
            sharedx_trim(matrix): The preloaded and pre-trimmed X matrix.
            onlyself_notself(bool): When True the iterator returns a
                quintuplet that also includes only_self and not_self.
            p_self_decay(bool): When True the p_self probability decays
                across batches and restarts at each epoch.
        """

        def notself(idxs, x1_data):
            x1_data_transf = np.copy(x1_data)
            for idx in idxs:
                # set current space to nan
                col_slice = slice(idx * 128, (idx + 1) * 128)
                x1_data_transf[:, col_slice] = np.nan
            return x1_data_transf

        TripletIterator.__log.debug('Generator for %s' % split)
        reader = TripletIterator(file_name, split)
        reader.open()
        # read shapes
        t_shape = reader._f[reader.t_name].shape
        # read data types
        x_dtype = reader._f[reader.x_name].dtype
        # no batch size -> return everything
        if not batch_size:
            batch_size = t_shape[0]
        # keep X in memory for resolving triplets quickly
        if sharedx is not None:
            if trim_mask is None:
                X = sharedx
            else:
                if sharedx_trim is not None:
                    X = sharedx_trim
                else:
                    X = sharedx[:, np.argwhere(
                        np.repeat(trim_mask, 128)).ravel()]
        else:
            TripletIterator.__log.debug('Reading X in memory')
            if trim_mask is None:
                X = reader.get_all_x()
            else:
                if sharedx_trim is not None:
                    X = sharedx_trim
                else:
                    X = reader.get_x_columns(np.argwhere(
                        np.repeat(trim_mask, 128)).ravel())
        TripletIterator.__log.debug('X shape: %s' % str(X.shape))
        # default mask is not masking
        if mask_fn is None:
            def mask_fn(*data):
                return data
        # default augment is doing nothing
        if augment_fn is None:
            def augment_fn(*data, **kwargs):
                return data
        # this variable is going to be used to shuffle batches
        batch_beg_end = np.zeros((int(np.ceil(t_shape[0] / batch_size)), 2))
        last = 0
        for row in batch_beg_end:
            row[0] = last
            row[1] = last + batch_size
            last = row[1]
        batch_beg_end = batch_beg_end.astype(int)
        # handle arguments for additional Xs
        if onlyself_notself:
            only_args = augment_kwargs.copy()
            only_args['p_only_self'] = 1.0
        TripletIterator.__log.debug(
            'Generator ready, onlyself_notself %s' % onlyself_notself)

        def example_generator_fn():
            """Generator function yielding data in batches."""
            batch_kwargs = augment_kwargs.copy()
            if p_self_decay:
                # we leave a p_self > 0 for the first tenth of batches,
                # then it will linearly decrease
                nr_steps = int(len(batch_beg_end) / 10) + 1
                p_self_current = augment_kwargs.get("p_self", 0.1)
                decay_step = p_self_current / nr_steps
            epoch = 0
            batch_idx = 0
            while True:
                # here we handle what happens at the last batch
                if batch_idx == len(batch_beg_end):
                    batch_idx = 0
                    epoch += 1
                    if shuffle:
                        np.random.shuffle(batch_beg_end)
                    if p_self_decay:
                        p_self_current = augment_kwargs.get("p_self", 0.1)
                # select the batch start/end and fetch triplets
                beg_idx, end_idx = batch_beg_end[batch_idx]
                triplets = reader.get_t(beg_idx, end_idx)
                y = reader.get_y(beg_idx, end_idx)
                x1 = X[triplets[:, 0]]
                x2 = X[triplets[:, 1]]
                x3 = X[triplets[:, 2]]
                if train and onlyself_notself:
                    # at train time we want to apply subsampling
                    x1 = augment_fn(x1, **batch_kwargs)
                    x2 = augment_fn(x2, **batch_kwargs)
                    x3 = augment_fn(x3, **batch_kwargs)
                if onlyself_notself:
                    x4 = augment_fn(X[triplets[:, 0]], **only_args)
                    x5 = notself(augment_kwargs['dataset_idx'], x1)
                # apply the mask function
                x1, x2, x3 = mask_fn(x1, x2, x3)
                # replace NaNs with specified value
                if replace_nan is not None:
                    x1[np.where(np.isnan(x1))] = replace_nan
                    x2[np.where(np.isnan(x2))] = replace_nan
                    x3[np.where(np.isnan(x3))] = replace_nan
                    if onlyself_notself:
                        x4[np.where(np.isnan(x4))] = replace_nan
                        x5[np.where(np.isnan(x5))] = replace_nan
                # yield the triplets
                if onlyself_notself:
                    yield [x1, x2, x3, x4, x5], y
                else:
                    yield [x1, x2, x3], y
                # go to next batch
                batch_idx += 1
                # update subsampling parameters
                if p_self_decay:
                    p_self_current -= decay_step
                    batch_kwargs['p_self'] = max(0.0, p_self_current)

        # return shapes and dtypes along with the iterator
        triplet_shape = (t_shape[0], X.shape[1])
        shapes = (triplet_shape, triplet_shape, triplet_shape, triplet_shape)
        dtypes = (x_dtype, x_dtype, x_dtype, x_dtype)
        return shapes, dtypes, example_generator_fn
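    # Minimal consumption sketch for `generator_fn` (hypothetical file path
    # and batch size; the yielded lists can feed e.g. a tf.data pipeline):
    #
    #   shapes, dtypes, gen_fn = TripletIterator.generator_fn(
    #       'triplets_eval.h5', 'train_train', replace_nan=0.0,
    #       batch_size=128, shuffle=True, train=False)
    #   batches = gen_fn()
    #   (x1, x2, x3), y = next(batches)  # anchors, positives, negatives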
@logged
class BaseTripletSampler(object):
    """Base class for triplet samplers."""

    def __init__(self, triplet_signature, mol_signature, out_file,
                 save_kwargs={}):
        self.triplet_signature = triplet_signature
        self.mol_signature = mol_signature
        self.out_file = out_file
        def_save_kwargs = {
            'mean_center_x': True,
            'shuffle': True,
            'split_names': ['train', 'test'],
            'split_fractions': [.8, .2],
            'suffix': 'eval',
            'cpu': 1,
            'x_dtype': np.float32,
            'y_dtype': np.float32
        }
        def_save_kwargs.update(save_kwargs)
        self.save_kwargs = def_save_kwargs
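    # The `save_kwargs` dict overrides the defaults above. A hypothetical
    # override for one of the samplers below (sign0/sign2 stand for the
    # triplet and molecular signature objects, respectively):
    #
    #   sampler = AdriaTripletSampler(
    #       sign0, sign2, 'triplets_eval.h5',
    #       save_kwargs={'split_fractions': [.9, .1], 'suffix': 'custom'})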
    def get_split_indeces(self, rows, fractions):
        """Get random indexes for the different splits."""
        if not sum(fractions) == 1.0:
            raise Exception("Split fractions should sum to 1.0")
        # shuffle indexes
        idxs = list(range(rows))
        np.random.shuffle(idxs)
        # from frequencies to indices
        splits = np.cumsum(fractions)
        splits = splits[:-1]
        splits *= len(idxs)
        splits = splits.round().astype(int)
        return np.split(idxs, splits)
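    # Example: with rows=10 and fractions=[.8, .2] the cumulative cut point
    # falls at index 8, so the call returns two arrays of 8 and 2 shuffled
    # indexes, respectively:
    #
    #   train_idx, test_idx = sampler.get_split_indeces(10, [.8, .2])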
    def save_triplets(self, triplets, mean_center_x=True, shuffle=True,
                      split_names=['train', 'test'],
                      split_fractions=[.8, .2], suffix='eval', cpu=1,
                      x_dtype=np.float32, y_dtype=np.float32):
        """Save sampled triplets to file.

        This function saves triplets, performing the train-test split,
        shuffling and normalization.

        Args:
            triplets(array): Indexes of anchor, positive and negative for
                each triplet.
            mean_center_x(bool): Normalize data column-wise.
            shuffle(bool): Shuffle the order of triplets.
            split_names(list str): Names of the splits.
            split_fractions(list float): Fraction of each split.
            suffix(str): Suffix of the generated scaler.
        """
        ink_keys = self.mol_signature.keys
        _, X = self.mol_signature.get_vectors(ink_keys, dataset_name='x')
        self.__log.debug('X.shape %s', str(X.shape))
        self.__log.debug('triplets.shape %s', str(triplets.shape))
        # mean centering features
        if mean_center_x:
            scaler = RobustScaler()
            X = scaler.fit_transform(X)
            if suffix is None:
                scaler_file = os.path.join(os.path.split(self.out_file)[0],
                                           'scaler.pkl')
            else:
                scaler_file = os.path.join(os.path.split(self.out_file)[0],
                                           'scaler_%s.pkl' % suffix)
            pickle.dump(scaler, open(scaler_file, 'wb'))
        # shuffling
        shuffle_idxs = np.arange(triplets.shape[0])
        if shuffle:
            np.random.shuffle(shuffle_idxs)
        triplets = np.array(triplets)[shuffle_idxs]
        # do train-test split on keys
        split_idxs = self.get_split_indeces(X.shape[0], split_fractions)
        # do train-test split for triplets (np.unique of indexes)
        split_idxs = dict(zip(split_names, split_idxs))
        # find triplets having test-test, train-train and train-test keys
        combos = itertools.combinations_with_replacement(split_names, 2)
        # reverse split names to first write test keys
        split_names.reverse()
        # create output file
        self.__log.info('Saving Triplets to %s', self.out_file)
        with h5py.File(self.out_file, "w") as fh:
            for split_n in split_names:
                fh.create_dataset(
                    'keys_%s' % split_n,
                    data=np.array(ink_keys[split_idxs[split_n]],
                                  dtype=h5py.string_dtype()))
            if mean_center_x:
                fh.create_dataset(
                    'scaler',
                    data=np.array([scaler_file], dtype=h5py.string_dtype()))
            fh.create_dataset('x', data=X, dtype=x_dtype)
            fh.create_dataset('x_ink', data=np.array(
                ink_keys, dtype=h5py.string_dtype()))
            for split1, split2 in combos:
                split1_idxs = split_idxs[split1]
                split2_idxs = split_idxs[split2]
                if split1 != split2:
                    split1_mask = ~np.all(
                        np.isin(triplets, split1_idxs), axis=1)
                    split2_mask = ~np.all(
                        np.isin(triplets, split2_idxs), axis=1)
                    combo_mask = np.logical_and(split1_mask, split2_mask)
                else:
                    combo_mask = np.all(
                        np.isin(triplets, split1_idxs), axis=1)
                self.__log.debug('t_%s_%s %s' % (
                    split1, split2, str(triplets[combo_mask].shape)))
                fh.create_dataset('t_%s_%s' % (split1, split2),
                                  data=triplets[combo_mask])
                fh.create_dataset('y_%s_%s' % (split1, split2),
                                  data=np.zeros((len(triplets[combo_mask]),)))
        self.__log.info('Triplets saved to %s', self.out_file)
@logged
class PrecomputedTripletSampler(BaseTripletSampler):
    """The triplets are not sampled but pre-computed."""

    def generate_triplets(self, X, ink_keys, triplets, out_file,
                          mean_center_x=True, shuffle=True,
                          split_names=['train', 'test'],
                          split_fractions=[.8, .2], suffix='eval', cpu=1,
                          x_dtype=np.float32, y_dtype=np.float32):
        try:
            from chemicalchecker.core.signature_data import DataSignature
        except ImportError as err:
            raise err
        # mean centering columns
        if mean_center_x:
            scaler = RobustScaler()
            X = scaler.fit_transform(X)
            if suffix is None:
                scaler_file = os.path.join(os.path.split(out_file)[0],
                                           'scaler.pkl')
            else:
                scaler_file = os.path.join(os.path.split(out_file)[0],
                                           'scaler_%s.pkl' % suffix)
            pickle.dump(scaler, open(scaler_file, 'wb'))
        # shuffling
        shuffle_idxs = np.arange(triplets.shape[0])
        if shuffle:
            np.random.shuffle(shuffle_idxs)
        triplets = np.array(triplets)[shuffle_idxs]
        # do train-test split on keys
        split_idxs = self.get_split_indeces(X.shape[0], split_fractions)
        # do train-test split for triplets (np.unique of indexes)
        split_idxs = dict(zip(split_names, split_idxs))
        # find triplets having test-test, train-train and train-test keys
        combos = itertools.combinations_with_replacement(split_names, 2)
        # reverse split names to first write test keys
        split_names.reverse()
        # create dataset
        self.__log.info('Saving Triplets to %s', out_file)
        with h5py.File(out_file, "w") as fh:
            for split_n in split_names:
                fh.create_dataset(
                    'keys_%s' % split_n,
                    data=np.array(ink_keys[split_idxs[split_n]],
                                  dtype=h5py.string_dtype()))
            if mean_center_x:
                fh.create_dataset(
                    'scaler',
                    data=np.array([scaler_file], dtype=h5py.string_dtype()))
            fh.create_dataset('x', data=X, dtype=x_dtype)
            fh.create_dataset('x_ink', data=np.array(
                ink_keys, dtype=h5py.string_dtype()))
            for split1, split2 in combos:
                split1_idxs = split_idxs[split1]
                split2_idxs = split_idxs[split2]
                if split1 != split2:
                    split1_mask = ~np.all(
                        np.isin(triplets, split1_idxs), axis=1)
                    split2_mask = ~np.all(
                        np.isin(triplets, split2_idxs), axis=1)
                    combo_mask = np.logical_and(split1_mask, split2_mask)
                else:
                    combo_mask = np.all(
                        np.isin(triplets, split1_idxs), axis=1)
                self.__log.debug('t_%s_%s %s' % (
                    split1, split2, str(triplets[combo_mask].shape)))
                fh.create_dataset('t_%s_%s' % (split1, split2),
                                  data=triplets[combo_mask])
                fh.create_dataset('y_%s_%s' % (split1, split2),
                                  data=np.zeros((len(triplets[combo_mask]),)))
        self.__log.info('Triplets saved to %s', out_file)
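    # Sketch (hypothetical inputs): X is the dense matrix to save, ink_keys
    # the corresponding InChIKeys, and triplets an (N, 3) array of
    # pre-computed [anchor, positive, negative] row indexes into X. Since
    # this method takes everything as arguments, the signature objects in
    # the constructor are not used here:
    #
    #   sampler = PrecomputedTripletSampler(None, None, 'triplets_eval.h5')
    #   sampler.generate_triplets(X, ink_keys, triplets, 'triplets_eval.h5')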
@logged
class AdriaTripletSampler(BaseTripletSampler):
    """Adria's optimal way of sampling triplets in small datasets."""

    def __init__(self, *args, **kwargs):
        BaseTripletSampler.__init__(self, *args, **kwargs)
    def generate_triplets(self, num_triplets=1e6, frac_hard=0.3,
                          frac_neig=0.05, metric='jaccard', low_thr=0.1,
                          high_thr=0.5, plot=True):
        """Generate triplets.

        This function generates triplets, defining positives and negatives
        assuming a binary triplet signature (e.g. sign0) and computing all
        pairwise similarities across molecules.

        Args:
            num_triplets(int): Total number of triplets to generate.
            frac_hard(float): Fraction of triplets that should be hard cases.
            frac_neig(float): Fraction of neighbors we will consider.
            metric(str): Metric used to compute similarities; must be a
                distance metric that can be converted to a similarity by
                (1 - dist).
            low_thr(float): Low similarity threshold; any pair below this is
                a negative.
            high_thr(float): High similarity threshold; any pair above this
                is a positive.
            plot(bool): Save plots of the sampling.
        """
        self.__log.info('Generating Triplets...')
        self.__log.info('Triplets generated based on: %s'
                        % self.triplet_signature.data_path)
        self.__log.info('Triplets representation: %s'
                        % self.mol_signature.data_path)
        # this works with the triplet signature being sign0
        df = self.triplet_signature.as_dataframe()
        # later we will be saving the molecular representation in a different
        # signature (e.g. sign2), we need to use only those molecules
        df = df.loc[self.mol_signature.keys]
        # Getting similarities
        all_similarities = 1 - pdist(df, metric)
        df2 = pd.DataFrame(squareform(all_similarities),
                           index=df.index.values,
                           columns=df.index.values)
        # Defining derived parameters
        n_neigh = int(df2.shape[0] * frac_neig)
        frac_soft = 1 - frac_hard
        n_trip = int(np.round(num_triplets * frac_soft / df2.shape[0]))
        n_hard_trip = int(np.round(num_triplets * frac_hard / df2.shape[0]))
        ixs = np.array(df2.max()) >= low_thr
        df2 = df2.iloc[ixs, ixs]
        dgs = np.array(df2.columns)
        triplets = []
        hard_triplets = {0: [], 1: [], 2: []}
        for ix, dg in tqdm(enumerate(df2.index.values), total=df2.shape[0]):
            _triplets = []
            _hard_triplets = {0: [], 1: [], 2: []}
            # Getting similarity vector
            v = np.array(df2.iloc[ix])
            v[ix] = np.nan  # masking itself
            # Getting positives
            ixs = np.where(v >= high_thr)[0]
            if len(ixs) < n_neigh:
                ixs = np.argsort(v)[::-1]
                ixs = ixs[v[ixs] >= low_thr]
            if len(ixs) == 0:
                continue
            cutoff = v[ixs][min([n_neigh - 1, len(ixs) - 1])]
            ixs = v >= cutoff
            neighs = dgs[ixs]
            similarities = v[ixs]
            probs = similarities / np.sum(similarities)
            # Getting negatives (minor fix: remove itself)
            negs = np.array(list(set(dgs) - set(neighs.tolist() + [dg])))
            # --Getting the soft (easy) triplets
            for _ in range(n_trip):
                _triplets.append([dg, np.random.choice(neighs, p=probs),
                                  np.random.choice(negs)])
            triplets.extend(_triplets)
            # --Adding hard triplets
            scores = np.unique(similarities)  # unique and sorted
            cutoffs = np.unique([np.percentile(scores, pc)
                                 for pc in [0, 25, 50, 75, 100]])
            labels = np.arange(len(cutoffs))[:-1]
            groups = np.array(pd.cut(similarities, cutoffs, labels=labels))
            labels = [x for x in labels if x in groups]
            if len(labels) > 1:
                n_subhard = int(np.ceil(n_hard_trip / (len(labels) - 1)))
                hard_positives = neighs[groups == labels[-1]]
                hard_probs = similarities[groups == labels[-1]]
                hard_probs = hard_probs / np.sum(hard_probs)
                for i in range(len(labels) - 1):
                    hard_negatives = neighs[groups == labels[i]]
                    assert len(set(hard_negatives)
                               & set(hard_positives)) == 0
                    for _ in range(n_subhard):
                        _hard_triplets[labels[i]].append(
                            [dg,
                             np.random.choice(hard_positives, p=hard_probs),
                             np.random.choice(hard_negatives)])
            hard_triplets[0].extend(_hard_triplets[0])
            hard_triplets[1].extend(_hard_triplets[1])
            hard_triplets[2].extend(_hard_triplets[2])
        all_triplets = list(triplets)
        for g in list(hard_triplets):
            all_triplets.extend(list(hard_triplets[g]))
        all_triplets = np.array(all_triplets)
        self.__log.info('triplets: %i' % len(all_triplets))
        self.__log.info('easy triplets: %i (%.2f%%)' % (
            len(triplets), (100 * (len(triplets)) / len(all_triplets))))
        total_hard = np.sum([len(hard_triplets[g]) for g in hard_triplets])
        self.__log.info('hard triplets: %i (%.2f%%)' % (
            total_hard, (100 * (total_hard / len(all_triplets)))))
        for g in hard_triplets:
            self.__log.info('\t--> Q%i vs Q4: %i (%.2f%%)' % (
                g + 1, len(hard_triplets[g]),
                100 * len(hard_triplets[g]) / len(all_triplets)))
        all_triplets = pd.DataFrame(
            all_triplets, columns=['anchor', 'pos', 'neg']).sort_values(
            ['anchor', 'pos', 'neg']).reset_index(drop=True)
        ink_pos = dict(zip(self.mol_signature.keys, np.arange(len(df))))
        all_triplets_idxs = np.vectorize(ink_pos.get)(all_triplets.values)
        if plot:
            self.__log.info('Generating Triplets Plot...')
            import matplotlib.pyplot as plt
            import seaborn as sns
            fig, axes = plt.subplots(2, 2, figsize=(10, 10), dpi=100)
            ax1, ax2, ax3, ax4 = axes.flat
            # which fraction of molecules has a given number of features?
            sns.ecdfplot(df.sum(1), ax=ax1)
            ax1.set_title('# Features distribution')
            # what is the distribution of similarities?
            # which similarity do we always consider positive or negative?
            ax2.set_title('Similarity distribution, pos./neg. definition')
            pc = int((1 - frac_neig) * 100)
            pc_val = np.percentile(all_similarities, pc)
            ax2.axvline(low_thr, label='low_thr %.2f' % low_thr,
                        ls='-.', color='.5')
            ax2.axvline(pc_val, label='%.2f (P%i)' % (pc_val, pc),
                        color='.7')
            ax2.axvline(high_thr, label='high_thr %.2f' % high_thr,
                        ls='--', color='.5')
            sns.histplot(all_similarities, kde=True, ax=ax2)
            ax2.legend()
            # what fraction of molecules would we lose
            # (closest neighbor < low_thr)?
            sns.ecdfplot(df2.max(axis=1), ax=ax3)
            ax3.set_xlabel('Similarity to closest neighbor')
            ax3.axvline(low_thr, ls='-.', color='.5')
            ax3.axvline(high_thr, ls='--', color='.5')
            ax3.set_title('Closest neighbor of each mol.')
            lost_mols = np.sum(df2.max() < low_thr)
            if lost_mols > 0:
                ax3.annotate('Lost Mols.: %i' % lost_mols,
                             xy=(low_thr - (low_thr / 10), 0),
                             xycoords='data', xytext=(-10, -40),
                             textcoords='offset points',
                             arrowprops=dict(facecolor='red', shrink=0.05),
                             horizontalalignment='right',
                             verticalalignment='bottom')
            # What's the difference in similarity between A-P and A-N?
            ax4.set_title('Triplet difficulty and Anchor-Pos. vs. '
                          'Anchor-Neg.')
            k = df2.melt(ignore_index=False).reset_index().values
            pair2sim = dict(zip(zip(k[:, 0], k[:, 1]), k[:, 2]))
            v = []
            for x in triplets:
                pos = x[0], x[1]
                neg = x[0], x[2]
                if (pos in pair2sim) & (neg in pair2sim):
                    pos = pair2sim[pos]
                    neg = pair2sim[neg]
                    v.append(pos - neg)
            sns.ecdfplot(v, label='easy triplets', ax=ax4)
            for g in hard_triplets:
                v2 = []
                for x in hard_triplets[g]:
                    pos = x[0], x[1]
                    neg = x[0], x[2]
                    if (pos in pair2sim) & (neg in pair2sim):
                        pos = pair2sim[pos]
                        neg = pair2sim[neg]
                        v2.append(pos - neg)
                sns.ecdfplot(v2, label='hard triplets (Q%i vs Q4)' % (g + 1),
                             ax=ax4)
            ax4.set_xlabel('pos.-neg. similarity delta')
            ax4.legend()
            plt.savefig(self.out_file + '.png')
        self.save_triplets(all_triplets_idxs, **self.save_kwargs)
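    # End-to-end sketch (hypothetical signature objects: sign0 provides the
    # binary fingerprints used to define positives/negatives, sign2 the
    # dense representation saved alongside the triplets):
    #
    #   sampler = AdriaTripletSampler(sign0, sign2, 'triplets_eval.h5')
    #   sampler.generate_triplets(num_triplets=1e5, frac_hard=0.3,
    #                             metric='jaccard')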
@logged
class OldTripletSampler(BaseTripletSampler):
    """Used to be the monstrous NeighborTripletTraintest.

    Performs well on large spaces, less well on smaller ones.
    """

    def __init__(self, *args, **kwargs):
        BaseTripletSampler.__init__(self, *args, **kwargs)
    def generate_triplets(self, f_per=0.1, t_per=0.01, mean_center_x=True,
                          shuffle=True, check_distances=True,
                          split_names=['train', 'test'],
                          split_fractions=[.8, .2], suffix='eval',
                          x_dtype=np.float32, y_dtype=np.float32,
                          num_triplets=1e6, limit=100000, cpu=1):
        """Sample triplets using an approach suited for large spaces.

        Args:
            num_triplets(int): Total number of triplets to generate.
        """
        try:
            import faiss
            from chemicalchecker.core.signature_data import DataSignature
        except ImportError as err:
            raise err
        faiss.omp_set_num_threads(cpu)
        neighbors_sign = self.triplet_signature
        out_file = self.out_file
        ink_keys = self.mol_signature.keys
        _, X = self.mol_signature.get_vectors(ink_keys, dataset_name='x')
        # train test validation splits
        if len(split_names) != len(split_fractions):
            raise Exception("Split names and fractions should be the same "
                            "amount.")
        # Load the neighbors matrix and shuffle it
        neighbors_matrix = neighbors_sign[:]
        shuffle_idx = np.arange(neighbors_matrix.shape[0])
        np.random.shuffle(shuffle_idx)
        OldTripletSampler.__log.debug('%s %s' % (
            len(neighbors_matrix), str(X.shape)))
        if len(neighbors_matrix) != X.shape[0]:
            raise Exception("neighbors_matrix should be same length as X.")
        neighbors_matrix = neighbors_matrix[shuffle_idx]
        X = X[shuffle_idx]
        X_inks = np.array(neighbors_sign.keys)[shuffle_idx]
        OldTripletSampler.__log.debug(
            "{:<20} shape: {:>10}".format("input X", str(X.shape)))
        fullpath, _ = os.path.split(out_file)
        redundancy_path = os.path.join(fullpath, "redundancy_dict.pkl")
        # reduce redundancy, keep full-ref mapping
        # if not os.path.isfile(redundancy_path):
        OldTripletSampler.__log.info("Reducing redundancy")
        rnd = RNDuplicates(cpu=cpu)
        _, ref_matrix, full_ref_map = rnd.remove(
            neighbors_matrix.astype(np.float32))
        ref_full_map = dict()
        for key, value in full_ref_map.items():
            ref_full_map.setdefault(value, list()).append(key)
        full_refid_map = dict(
            zip(rnd.final_ids, np.arange(len(rnd.final_ids))))
        refid_full_map = {full_refid_map[k]: v
                          for k, v in ref_full_map.items()}
        # Limit signatures by the limit value
        size_original_ref_matrix = len(ref_matrix)
        OldTripletSampler.__log.info(
            "Original size ref_matrix: %s" % size_original_ref_matrix)
        OldTripletSampler.__log.info("Limit of %s" % limit)
        ref_matrix = ref_matrix[:limit]
        OldTripletSampler.__log.info("Final size: %s" % len(ref_matrix))
        # Set triplet factors
        triplet_per_mol = max(
            [int(np.ceil(num_triplets / ref_matrix.shape[0])), 3])
        easy_triplet_per_mol = max([int(np.ceil(triplet_per_mol * 0.8)), 1])
        medi_triplet_per_mol = max([int(np.ceil(triplet_per_mol * 0.15)), 1])
        hard_triplet_per_mol = max(
            [triplet_per_mol
             - (easy_triplet_per_mol + medi_triplet_per_mol), 1])
        OldTripletSampler.__log.info(
            "Triplet_per_mol: %s" % triplet_per_mol)
        OldTripletSampler.__log.info(
            "E triplet per mol: %s" % easy_triplet_per_mol)
        OldTripletSampler.__log.info(
            "M triplet per mol: %s" % medi_triplet_per_mol)
        OldTripletSampler.__log.info(
            "H triplet per mol: %s" % hard_triplet_per_mol)
        OldTripletSampler.__log.info("Triplet_per_mol: %s" % (
            easy_triplet_per_mol + medi_triplet_per_mol
            + hard_triplet_per_mol))
        assert(triplet_per_mol <= (easy_triplet_per_mol
                                   + medi_triplet_per_mol
                                   + hard_triplet_per_mol))
        # split chunks, get indexes of chunks for each split
        chunk_size = int(max(1, np.floor(ref_matrix.shape[0] / 100)))
        split_chunk_idx = self.get_split_indeces(
            int(np.floor(ref_matrix.shape[0] / chunk_size)) + 1,
            split_fractions)
        tot_split = float(sum([len(i) for i in split_chunk_idx]))
        real_fracs = ['%.2f' % (len(i) / tot_split)
                      for i in split_chunk_idx]
        OldTripletSampler.__log.info(
            'Fractions used: %s', ' '.join(real_fracs))
        # split ref matrix, keep ref-split mapping
        nr_matrix = dict()
        split_ref_map = dict()
        for split_name, chunks in zip(split_names, split_chunk_idx):
            # need total size and mapping of chunks
            src_dst = list()
            total_size = 0
            for dst, src in enumerate(sorted(chunks)):
                # source chunk start-end
                src_start = src * chunk_size
                src_end = (src * chunk_size) + chunk_size
                # check current chunk size to avoid overflowing
                curr_chunk_size = chunk_size
                if src_end > ref_matrix.shape[0]:
                    src_end = ref_matrix.shape[0]
                    curr_chunk_size = src_end - src_start
                # update total size
                total_size += curr_chunk_size
                # destination start-end
                dst_start = dst * chunk_size
                dst_end = (dst * chunk_size) + curr_chunk_size
                src_slice = (int(src_start), int(src_end))
                dst_slice = (int(dst_start), int(dst_end))
                src_dst.append((src_slice, dst_slice))
            # create chunk matrix
            cols = ref_matrix.shape[1]
            nr_matrix[split_name] = np.zeros((int(total_size), cols),
                                             dtype=ref_matrix.dtype)
            split_ref_map[split_name] = dict()
            ref_idxs = np.arange(ref_matrix.shape[0])
            for src_slice, dst_slice in tqdm(src_dst):
                src_chunk = slice(*src_slice)
                dst_chunk = slice(*dst_slice)
                OldTripletSampler.__log.debug(
                    "writing src: %s to dst: %s" % (src_slice, dst_slice))
                ref_src_chunk = ref_idxs[src_chunk]
                ref_dst_chunk = ref_idxs[dst_chunk]
                for src_id, dst_id in zip(ref_src_chunk, ref_dst_chunk):
                    split_ref_map[split_name][dst_id] = src_id
                nr_matrix[split_name][dst_chunk] = ref_matrix[src_chunk]
            OldTripletSampler.__log.debug(
                "nr_matrix %s %s", split_name, nr_matrix[split_name].shape)
        # for each split generate the NN index
        OldTripletSampler.__log.info('Generating NN indexes')
        NN = dict()
        for split_name in split_names:
            # create faiss index
            NN[split_name] = faiss.IndexFlatL2(
                nr_matrix[split_name].shape[1])
            # add data
            NN[split_name].add(nr_matrix[split_name])
        # mean centering columns
        if mean_center_x:
            OldTripletSampler.__log.info('Mean centering X')
            scaler = RobustScaler()
            X = scaler.fit_transform(X)
            if suffix is None:
                scaler_file = os.path.join(os.path.split(out_file)[0],
                                           'scaler.pkl')
            else:
                scaler_file = os.path.join(os.path.split(out_file)[0],
                                           'scaler_%s.pkl' % suffix)
            pickle.dump(scaler, open(scaler_file, 'wb'))
        # create dataset
        OldTripletSampler.__log.info('Traintest saving to %s', out_file)
        combo_dists = dict()
        with h5py.File(out_file, "w") as fh:
            fh.create_dataset('x', data=X)
            fh.create_dataset('x_ink', data=np.array(
                X_inks, dtype=DataSignature.string_dtype()))
            if mean_center_x:
                fh.create_dataset(
                    'scaler',
                    data=np.array([scaler_file],
                                  dtype=DataSignature.string_dtype()))
            # for each split combo generate triplets as
            # [anchor, positive, negative]
            combos = itertools.combinations_with_replacement(split_names, 2)
            for split1, split2 in combos:
                combo = '_'.join([split1, split2])
                OldTripletSampler.__log.debug("SPLIT: %s" % combo)
                # define F and T according to the split that is being used
                LB = 10000
                UB = 100000
                TMAX = 50
                TMIN = 5

                def get_t_max(N):
                    N = np.clip(N, LB, UB)
                    a = (TMAX - TMIN) / (LB - UB)
                    b = TMIN - a * UB
                    return int(a * N + b)

                t_limit = get_t_max(size_original_ref_matrix)
                f_limit = 300
                T = int(np.clip(t_per * nr_matrix[split2].shape[0],
                                10, t_limit))
                F = np.clip(10 * T, 200, f_limit)
                F = int(min(F, (nr_matrix[split2].shape[0] - 1)))
                OldTripletSampler.__log.info("T per: %s" % (t_per))
                OldTripletSampler.__log.info("F and T: %s %s" % (F, T))
                assert(T < F)
                OldTripletSampler.__log.info("Searching Neighbors")
                # remove self neighbors when splits are the same
                if split1 == split2:
                    # search NN in chunks
                    neig_idxs = list()
                    csize = 10000
                    for i in tqdm(range(0, len(nr_matrix[split2]), csize)):
                        chunk = slice(i, i + csize)
                        _, neig_idxs_chunk = NN[split1].search(
                            nr_matrix[split2][chunk], F + 1)
                        neig_idxs.append(neig_idxs_chunk)
                    neig_idxs = np.vstack(neig_idxs)
                    # the nearest neighbor within the same group is the
                    # molecule itself
                    # assert(all(neig_idxs[:, 0] ==
                    #            np.arange(0, len(neig_idxs))))
                    neig_idxs = neig_idxs[:, 1:]
                else:
                    _, neig_idxs = NN[split1].search(
                        nr_matrix[split2], F)
                # get probabilities for T
                t_prob = ((np.arange(T + 1)[::-1])
                          / np.sum(np.arange(T + 1)))[:T]
                assert(sum(t_prob) > 0.99)
                # save list of split indexes
                easy_a_split = list()
                easy_p_split = list()
                easy_n_split = list()
                medi_a_split = list()
                medi_p_split = list()
                medi_n_split = list()
                hard_a_split = list()
                hard_p_split = list()
                hard_n_split = list()
                OldTripletSampler.__log.info("Generating triplets")
                nn_set = set(range(neig_idxs.shape[0]))
                # idx refers to split2, everything else to split1
                for idx, row in enumerate(tqdm(neig_idxs)):
                    # Add anchors per type of triplet
                    easy_a_split.extend(
                        np.repeat(idx, easy_triplet_per_mol))
                    medi_a_split.extend(
                        np.repeat(idx, medi_triplet_per_mol))
                    hard_a_split.extend(
                        np.repeat(idx, hard_triplet_per_mol))
                    # positives are sampled from the top T NNs for each
                    # category
                    # Easy
                    e_p_indexes = np.random.choice(
                        T, easy_triplet_per_mol, replace=True, p=t_prob)
                    positives = neig_idxs[idx, e_p_indexes]
                    easy_p_split.extend(positives)
                    # Medium
                    m_p_indexes = np.random.choice(
                        T, medi_triplet_per_mol, replace=True, p=t_prob)
                    positives = neig_idxs[idx, m_p_indexes]
                    medi_p_split.extend(positives)
                    # Hard
                    h_p_indexes = np.random.choice(
                        T, hard_triplet_per_mol, replace=True, p=t_prob)
                    positives = neig_idxs[idx, h_p_indexes]
                    hard_p_split.extend(positives)
                    # medium negatives are sampled from F (in NN but not T)
                    m_negatives = np.random.choice(
                        neig_idxs[idx][T:], medi_triplet_per_mol,
                        replace=True)
                    medi_n_split.extend(m_negatives)
                    # hard negatives are sampled from T (but higher than
                    # positives)
                    hn_shifts = np.random.choice(
                        int(np.ceil(T / 2)), hard_triplet_per_mol,
                        replace=True) + 1
                    hn_indexes = hn_shifts + h_p_indexes
                    # with small T we still have to avoid getting out of T
                    # range
                    off_range = np.where(hn_indexes >= neig_idxs.shape[1])
                    hn_indexes[off_range] = neig_idxs.shape[1] - 1
                    h_negatives = neig_idxs[idx, hn_indexes]
                    hard_n_split.extend(h_negatives)
                    # easy negatives (sampled from everywhere; in general
                    # this should be fine although it may sample positives)
                    e_negatives = np.random.choice(
                        len(neig_idxs), easy_triplet_per_mol, replace=True)
                    easy_n_split.extend(e_negatives)
                # get reference ids
                OldTripletSampler.__log.info("Mapping triplets")
                easy_a_ref = [split_ref_map[split2][x] for x in easy_a_split]
                easy_p_ref = [split_ref_map[split1][x] for x in easy_p_split]
                easy_n_ref = [split_ref_map[split1][x] for x in easy_n_split]
                medi_a_ref = [split_ref_map[split2][x] for x in medi_a_split]
                medi_p_ref = [split_ref_map[split1][x] for x in medi_p_split]
                medi_n_ref = [split_ref_map[split1][x] for x in medi_n_split]
                hard_a_ref = [split_ref_map[split2][x] for x in hard_a_split]
                hard_p_ref = [split_ref_map[split1][x] for x in hard_p_split]
                hard_n_ref = [split_ref_map[split1][x] for x in hard_n_split]
                # choose a random molecule among the full analogs
                OldTripletSampler.__log.info("Resolving multiple options")
                easy_a_full = np.array(
                    [np.random.choice(refid_full_map[x]) for x in easy_a_ref])
                easy_p_full = np.array(
                    [np.random.choice(refid_full_map[x]) for x in easy_p_ref])
                easy_n_full = np.array(
                    [np.random.choice(refid_full_map[x]) for x in easy_n_ref])
                medi_a_full = np.array(
                    [np.random.choice(refid_full_map[x]) for x in medi_a_ref])
                medi_p_full = np.array(
                    [np.random.choice(refid_full_map[x]) for x in medi_p_ref])
                medi_n_full = np.array(
                    [np.random.choice(refid_full_map[x]) for x in medi_n_ref])
                hard_a_full = np.array(
                    [np.random.choice(refid_full_map[x]) for x in hard_a_ref])
                hard_p_full = np.array(
                    [np.random.choice(refid_full_map[x]) for x in hard_p_ref])
                hard_n_full = np.array(
                    [np.random.choice(refid_full_map[x]) for x in hard_n_ref])
                # stack triplets
                OldTripletSampler.__log.info("Stacking triplets")
                easy_triplets = np.vstack(
                    (easy_a_full, easy_p_full, easy_n_full)).T
                medium_triplets = np.vstack(
                    (medi_a_full, medi_p_full, medi_n_full)).T
                hard_triplets = np.vstack(
                    (hard_a_full, hard_p_full, hard_n_full)).T
                triplets = np.vstack(
                    (easy_triplets, medium_triplets, hard_triplets))
                # stack categories
                y = np.hstack((
                    np.full((easy_triplets.shape[0],), 0),
                    np.full((medium_triplets.shape[0],), 1),
                    np.full((hard_triplets.shape[0],), 2)))
                unique_ids = np.unique(triplets)
                OldTripletSampler.__log.info(
                    'Using %s molecules in triplets' % len(unique_ids))
                # get inchikeys of test or train molecules
                if split1 == split2:
                    ink_ids = np.array(sorted(unique_ids))
                    split_inks = np.array(np.array(X_inks[ink_ids]),
                                          DataSignature.string_dtype())
                    ds_name = "keys_%s" % split1
                    fh.create_dataset(ds_name, data=split_inks)
                # save to h5
                ds_name = "t_%s_%s" % (split1, split2)
                ys_name = "y_%s_%s" % (split1, split2)
                _, unique_idx = np.unique(triplets, axis=0,
                                          return_index=True)
                # check for all categories to still be there
                if len(np.unique(y[unique_idx])) < 3:
                    OldTripletSampler.__log.warning(
                        'Very few molecules available... triplets will be '
                        'repeated in the difficulty categories.')
                    # this can happen when we have very few molecules
                    ty = np.hstack([triplets, np.expand_dims(y, 1)])
                    tripletsy = np.unique(ty, axis=0)
                    triplets = tripletsy[:, :3]
                    y = tripletsy[:, -1]
                else:
                    triplets = triplets[unique_idx]
                    y = y[unique_idx]
                # shuffling
                shuffle_idxs = np.arange(triplets.shape[0])
                if shuffle:
                    np.random.shuffle(shuffle_idxs)
                triplets = triplets[shuffle_idxs]
                y = y[shuffle_idxs]
                OldTripletSampler.__log.info(
                    'Using %s unique triplets' % len(y))
                OldTripletSampler.__log.info(
                    'writing Name: %s E: %s M: %s H: %s T: %s',
                    ds_name, y[y == 0].shape[0], y[y == 1].shape[0],
                    y[y == 2].shape[0], triplets.shape[0])
                fh.create_dataset(ds_name, data=triplets)
                fh.create_dataset(ys_name, data=y)
                if check_distances:
                    import matplotlib.pyplot as plt
                    import seaborn as sns
                    num_of_dist_errors = 0
                    dis_limit = min(50000, len(shuffle_idxs))
                    dists = np.empty((dis_limit, 3))
                    for idx, row in enumerate(shuffle_idxs[:dis_limit]):
                        anchor = neighbors_matrix[triplets[row][0]]
                        positive = neighbors_matrix[triplets[row][1]]
                        negative = neighbors_matrix[triplets[row][2]]
                        category = y[row]
                        dis_ap = euclidean(anchor, positive)
                        dis_an = euclidean(anchor, negative)
                        dists[idx] = [dis_ap, dis_an, category]
                        if dis_ap > dis_an:
                            # OldTripletSampler.__log.warning(
                            #     'DIST ERROR %s %.2f %.2f %i' %
                            #     (triplets[row], dis_ap, dis_an, category))
                            num_of_dist_errors += 1
                    OldTripletSampler.__log.warning(
                        'TOTAL DIST ERRORS: %s' % num_of_dist_errors)
                    assert(len(np.unique(dists[:, 2])) == 3)
                    combo_dists[combo] = dists
        if check_distances:
            fig, axes = plt.subplots(3, 3, sharex=True, sharey=False,
                                     figsize=(10, 10))
            ax_idx = 0
            cat_names = ['easy', 'medium', 'hard']
            for combo, dists in combo_dists.items():
                for cat_id in [0, 1, 2]:
                    cat_mask = dists[:, 2] == cat_id
                    ax = axes.flatten()[ax_idx]
                    ax.set_title('%s %s' % (combo, cat_names[cat_id]))
                    sns.histplot(dists[cat_mask, 0], label='AP',
                                 color='green', kde=True, ax=ax)
                    sns.histplot(dists[cat_mask, 1], label='AN',
                                 color='red', kde=True, ax=ax)
                    ax.legend()
                    ax_idx += 1
            if suffix is None:
                plot_file = os.path.join(
                    os.path.split(out_file)[0], 'distances.png')
            else:
                plot_file = os.path.join(
                    os.path.split(out_file)[0], 'distances_%s.png' % suffix)
            plt.savefig(plot_file)
            plt.close()
        OldTripletSampler.__log.info(
            'OldTripletSampler saved to %s', out_file)
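    # Usage sketch for large spaces (hypothetical signature objects and
    # output path; requires faiss to be installed):
    #
    #   sampler = OldTripletSampler(neig_sign, sign2, 'triplets_eval.h5')
    #   sampler.generate_triplets(num_triplets=1e6, limit=100000, cpu=4)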