"""Splitter on Neighbor pairs."""
import os
import h5py
import pickle
import itertools
import numpy as np
from tqdm import tqdm
from scipy.spatial.distance import euclidean
from sklearn.preprocessing import RobustScaler
from chemicalchecker.util import logged
from chemicalchecker.util.remove_near_duplicates import RNDuplicates
@logged
class NeighborPairTraintest(object):
"""NeighborPairTraintest class."""
def __init__(self, hdf5_file, split, nr_neig=10, replace_nan=None):
"""Initialize a NeighborPairTraintest instnace.
We assume the file is containing diffrent splits.
e.g. "x_train", "y_train", "x_test", ...
"""
self._file = hdf5_file
self._f = None
self.replace_nan = replace_nan
self.nr_neig = nr_neig
self.x_name = "x"
if split is None:
self.p_name = "p"
self.y_name = "y"
else:
self.p_name = "p_%s" % split
self.y_name = "y_%s" % split
available_splits = self.get_split_names()
if split not in available_splits:
raise Exception("Split '%s' not found in %s!" %
(split, str(available_splits)))
def get_py_shapes(self):
"""Return the shapes of P and Y."""
self.open()
p_shape = self._f[self.p_name].shape
y_shape = self._f[self.y_name].shape
self.close()
return p_shape, y_shape
def get_xy_shapes(self):
"""Return the shapes of X and Y."""
self.open()
x_shape = self._f[self.x_name].shape
y_shape = self._f[self.y_name].shape
self.close()
return x_shape, y_shape
def get_split_names(self):
"""Return the names of the available split combinations."""
self.open()
if "split_names" in self._f:
split_names = [a.decode() for a in self._f["split_names"]]
else:
split_names = ['train', 'test']
self.close()
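# pair every split with itself and with the others,
# e.g. 'train_train', 'train_test', 'test_test'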
combos = itertools.combinations_with_replacement(split_names, 2)
return ['_'.join(x) for x in combos]
def open(self):
"""Open the HDF5."""
self._f = h5py.File(self._file, 'r')
#self.__log.info("HDF5 open %s", self._file)
def close(self):
"""Close the HDF5."""
try:
self._f.close()
#self.__log.info("HDF5 close %s", self._file)
except AttributeError:
self.__log.error('HDF5 file is not open yet.')
def get_py(self, beg_idx, end_idx):
"""Get a batch of pairs (P) and labels (Y)."""
features = self._f[self.p_name][beg_idx: end_idx]
# handle NaNs
if self.replace_nan is not None:
features[np.where(np.isnan(features))] = self.replace_nan
labels = self._f[self.y_name][beg_idx: end_idx]
return features, labels
def get_p(self, beg_idx, end_idx):
"""Get a batch of pairs (P)."""
features = self._f[self.p_name][beg_idx: end_idx]
# handle NaNs
if self.replace_nan is not None:
features[np.where(np.isnan(features))] = self.replace_nan
return features
def get_y(self, beg_idx, end_idx):
"""Get a batch of Y."""
labels = self._f[self.y_name][beg_idx: end_idx]
# handle NaNs
if self.replace_nan is not None:
labels[np.where(np.isnan(labels))] = self.replace_nan
return labels
def get_all_x(self):
"""Get full X."""
features = self._f[self.x_name][:]
# handle NaNs
if self.replace_nan is not None:
features[np.where(np.isnan(features))] = self.replace_nan
return features
def get_all_p(self):
"""Get full P."""
features = self._f[self.p_name][:]
# handle NaNs
if self.replace_nan is not None:
features[np.where(np.isnan(features))] = self.replace_nan
return features
def get_all_y(self):
"""Get full Y."""
labels = self._f[self.y_name][:]
return labels
@staticmethod
def get_split_indeces(rows, fractions):
"""Get random indices for the different splits."""
if not sum(fractions) == 1.0:
raise Exception("Split fractions should sum to 1.0")
# shuffle indices
idxs = list(range(rows))
np.random.shuffle(idxs)
# from fractions to split boundaries
splits = np.cumsum(fractions)
splits = splits[:-1]
splits *= len(idxs)
# np.int was removed in NumPy 1.24; use the builtin int
splits = splits.round().astype(int)
return np.split(idxs, splits)
@staticmethod
def create(X, out_file, neigbors_matrix=None, pos_neighbors=10,
neg_neighbors=100, scaler_dest=None,
mean_center_x=True, shuffle=True,
check_distances=True,
split_names=['train', 'test'], split_fractions=[.8, .2],
x_dtype=np.float32, y_dtype=np.float32, debug_test=False):
"""Create the HDF5 file with validation splits.
Args:
X(numpy.ndarray): features to train from.
out_file(str): path of the h5 file to write.
neigbors_matrix(numpy.ndarray): matrix for computing neighbors.
neigbors(int): Number of positive neighbors to include.
mean_center_x(bool): center each feature on its mean?
shuffle(bool): Shuffle positive and negatives.
split_names(list(str)): names for the split of data.
split_fractions(list(float)): fraction of data in each split.
x_dtype(type): numpy data type for X.
y_dtype(type): numpy data type for Y (np.float32 for regression,
np.int32 for classification.
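
Example (a minimal sketch; the file name is illustrative):
    NeighborPairTraintest.create(
        X, 'traintest.h5', pos_neighbors=10, neg_neighbors=100,
        split_names=['train', 'test'], split_fractions=[.8, .2])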
"""
# faiss is an optional dependency, imported lazily
try:
import faiss
except ImportError as err:
raise err
NeighborPairTraintest.__log.debug(
"{:<20} shape: {:>10}".format("input X", str(X.shape)))
# train test validation splits
if len(split_names) != len(split_fractions):
raise Exception("Split names and fractions must have the same length.")
# override parameters for debug
if debug_test:
pos_neighbors = 10
split_names = ['train', 'test']
split_fractions = [.8, .2]
# the neigbors_matrix is optional
if neigbors_matrix is None:
neigbors_matrix = X
else:
if len(neigbors_matrix) != len(X):
raise Exception("neigbors_matrix should be same length as X.")
# reduce redundancy, keep full-ref mapping
rnd = RNDuplicates(cpu=10)
_, ref_matrix, full_ref_map = rnd.remove(neigbors_matrix)
ref_full_map = np.array(rnd.final_ids)
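# full_ref_map maps each original row to its near-duplicate
# representative; ref_full_map maps reference rows back to
# original row indices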
rows = ref_matrix.shape[0]
if debug_test:
# we'll use this to later check that the mapping went fine
test = faiss.IndexFlatL2(neigbors_matrix.shape[1])
test.add(neigbors_matrix)
tmp = dict()
for key, value in full_ref_map.items():
tmp.setdefault(value, list()).append(key)
max_repeated = max([len(x) for x in tmp.values()])
_, test_neig = test.search(neigbors_matrix, max_repeated + 1)
# split rows in ~100 contiguous chunks, then randomly assign
# whole chunks to the splits (guard against rows < 100)
chunk_size = max(int(np.floor(rows / 100)), 1)
split_chunk_idx = NeighborPairTraintest.get_split_indeces(
int(np.floor(rows / chunk_size)) + 1,
split_fractions)
# split ref matrix, keep ref-split mapping
nr_matrix = dict()
split_ref_map = dict()
for split_name, chunks in zip(split_names, split_chunk_idx):
# need total size and mapping of chunks
src_dst = list()
total_size = 0
for dst, src in enumerate(sorted(chunks)):
# source chunk start-end
src_start = src * chunk_size
src_end = (src * chunk_size) + chunk_size
# check current chunk size to avoid overflowing
curr_chunk_size = chunk_size
if src_end > ref_matrix.shape[0]:
src_end = ref_matrix.shape[0]
curr_chunk_size = src_end - src_start
# update total size
total_size += curr_chunk_size
# destination start-end
dst_start = dst * chunk_size
dst_end = (dst * chunk_size) + curr_chunk_size
src_slice = (int(src_start), int(src_end))
dst_slice = (int(dst_start), int(dst_end))
src_dst.append((src_slice, dst_slice))
# create chunk matrix
cols = ref_matrix.shape[1]
nr_matrix[split_name] = np.zeros((int(total_size), cols),
dtype=ref_matrix.dtype)
split_ref_map[split_name] = dict()
ref_idxs = np.arange(ref_matrix.shape[0])
for src_slice, dst_slice in tqdm(src_dst):
src_chunk = slice(*src_slice)
dst_chunk = slice(*dst_slice)
NeighborPairTraintest.__log.debug(
"writing src: %s to dst: %s" % (src_slice, dst_slice))
ref_src_chunk = ref_idxs[src_chunk]
ref_dst_chunk = ref_idxs[dst_chunk]
for src_id, dst_id in zip(ref_src_chunk, ref_dst_chunk):
split_ref_map[split_name][dst_id] = src_id
nr_matrix[split_name][dst_chunk] = ref_matrix[src_chunk]
NeighborPairTraintest.__log.debug(
"nr_matrix %s %s", split_name, nr_matrix[split_name].shape)
# for each split generate NN
NN = dict()
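# IndexFlatL2 performs brute-force exact L2 nearest-neighbor search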
for split_name in split_names:
# create faiss index
NN[split_name] = faiss.IndexFlatL2(nr_matrix[split_name].shape[1])
# add data
NN[split_name].add(nr_matrix[split_name])
# center and scale columns (RobustScaler centers on the median
# and scales by the IQR, so it is robust to outliers)
if mean_center_x:
scaler = RobustScaler()
X = scaler.fit_transform(X)
if scaler_dest is None:
scaler_dest = os.path.split(out_file)[0]
scaler_file = os.path.join(scaler_dest, 'scaler.pkl')
# persist the fitted scaler for reuse at prediction time
pickle.dump(scaler, open(scaler_file, 'wb'))
# create dataset
NeighborPairTraintest.__log.info('Traintest saving to %s', out_file)
with h5py.File(out_file, "w") as fh:
fh.create_dataset('x', data=X)
# for each split combo generate pairs and ys
combos = itertools.combinations_with_replacement(split_names, 2)
#combos = [('train', 'train'), ('train', 'test'), ('test', 'test')]
for split1, split2 in combos:
# handle the case where we ask for more neighbors than molecules
if pos_neighbors > nr_matrix[split2].shape[0]:
combo_neig = nr_matrix[split2].shape[0]
NeighborPairTraintest.__log.warning(
'split %s is small, reducing pos_neighbors to %i' %
(split2, combo_neig))
else:
combo_neig = pos_neighbors
# remove self neighbors when splits are the same
if split1 == split2:
# search NN
dists, neig_idxs = NN[split1].search(nr_matrix[split2],
combo_neig + 1)
# within the same split, the nearest neighbor of each
# molecule is the molecule itself; drop it
assert(all(neig_idxs[:, 0] ==
np.arange(0, len(neig_idxs))))
neig_idxs = neig_idxs[:, 1:]
else:
_, neig_idxs = NN[split1].search(
nr_matrix[split2], combo_neig)
if debug_test:
# in debug mode redo the search (keeping self matches when
# splits coincide) so the mapping check below covers them
_, neig_idxs = NN[split1].search(
nr_matrix[split2], combo_neig)
# get positive pairs
# get first pair element idxs
idxs1 = np.repeat(
np.arange(nr_matrix[split2].shape[0]), combo_neig)
# get second pair elements idxs
idxs2_1 = neig_idxs.flatten()
assert(len(idxs1) == len(idxs2_1))
# map back to reference
idxs1_ref = np.array([split_ref_map[split2][x] for x in idxs1])
idxs2_1_ref = np.array(
[split_ref_map[split1][x] for x in idxs2_1])
# map back to full
idxs1_full = ref_full_map[idxs1_ref]
idxs2_1_full = ref_full_map[idxs2_1_ref]
# oversample the positives so that positive rows match the
# number of negative rows (neg_neighbors per anchor)
neg_pos_ratio = int(np.floor(neg_neighbors / combo_neig))
idxs1_full = np.repeat(idxs1_full, neg_pos_ratio)
idxs2_1_full = np.repeat(idxs2_1_full, neg_pos_ratio)
if debug_test:
# train ~= full
total = 0
ok = 0
for t1, t2 in zip(idxs1_full, idxs2_1_full):
total += 1
if t2 in test_neig[t1]:
ok += 1
print(split1, split2, ok / total, ok, combo_neig)
# get negative pairs
idxs2_0 = list()
for idx, row in enumerate(neig_idxs):
no_neig = set(range(nr_matrix[split2].shape[0])) - set(row)
# avoid fetching itself as negative!
if split1 == split2:
no_neig = no_neig - set([idx])
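# note: np.random.choice with replace=False raises if fewer than
# neg_neighbors non-neighbor candidates remain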
smpl = np.random.choice(
list(no_neig), neg_neighbors, replace=False)
idxs2_0.extend(smpl)
idxs2_0 = np.array(idxs2_0)
# map
idxs2_0_ref = np.array(
[split_ref_map[split1][x] for x in idxs2_0])
idxs2_0_full = ref_full_map[idxs2_0_ref]
# stack pairs and ys
pairs_1 = np.vstack((idxs1_full, idxs2_1_full)).T
y_1 = np.ones((1, pairs_1.shape[0]))
pairs_0 = np.vstack((idxs1_full, idxs2_0_full)).T
y_0 = np.zeros((1, pairs_0.shape[0]))
all_pairs = np.vstack((pairs_1, pairs_0))
all_ys = np.hstack((y_1, y_0)).T
# shuffling
shuffle_idxs = np.arange(all_ys.shape[0])
if shuffle:
np.random.shuffle(shuffle_idxs)
if check_distances:
import matplotlib.pyplot as plt
import seaborn as sns
d1 = list()
d0 = list()
for idx in range(len(all_ys))[:500]:
dist = euclidean(
neigbors_matrix[all_pairs[shuffle_idxs][idx][0]],
neigbors_matrix[all_pairs[shuffle_idxs][idx][1]])
if all_ys[shuffle_idxs][idx] == 1:
d1.append(dist)
else:
d0.append(dist)
name = "%s_%s" % (split1, split2)
plot_file = os.path.join(os.path.split(out_file)[0],
'dist_%s.png' % name)
# distplot was removed from recent seaborn; histplot replaces it
sns.histplot(d1, label='1', kde=True, stat='density')
sns.histplot(d0, label='0', kde=True, stat='density')
plt.legend()
plt.savefig(plot_file)
plt.close()
# save to h5
ds_name = "p_%s_%s" % (split1, split2)
NeighborPairTraintest.__log.info(
'writing %s %s %s', ds_name, pairs_1.shape, pairs_0.shape)
fh.create_dataset(ds_name, data=all_pairs[shuffle_idxs])
ds_name = "y_%s_%s" % (split1, split2)
NeighborPairTraintest.__log.info(
'writing %s %s', ds_name, all_ys.shape)
fh.create_dataset(ds_name, data=all_ys[shuffle_idxs])
NeighborPairTraintest.__log.info(
'NeighborPairTraintest saved to %s', out_file)
@staticmethod
def generator_fn(file_name, split, batch_size=None, only_x=False,
replace_nan=None, augment_scale=1,
augment_fn=None, augment_kwargs={},
mask_fn=None, shuffle=True,
sharedx=None):
"""Return the generator function that we can query for batches.
file_name(str): The H5 generated via `create`
split(str): One of 'train_train', 'train_test', or 'test_test'
batch_size(int): Size of a batch of data.
only_x(bool): Usually when predicting only X are useful.
replace_nan(bool): Value used for NaN replacement.
augment_scale(int): Augment the train size by this factor.
augment_fn(func): Function to augment data.
augment_kwargs(dict): Parameters for the aument functions.
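
Example (a minimal sketch; the file name is illustrative):
    shapes, dtypes, gen_fn = NeighborPairTraintest.generator_fn(
        'traintest.h5', 'train_train', batch_size=128)
    (x1, x2), y = next(gen_fn())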
"""
reader = NeighborPairTraintest(file_name, split)
reader.open()
# read shapes
x_shape = reader._f[reader.x_name].shape
y_shape = reader._f[reader.y_name].shape
p_shape = reader._f[reader.p_name].shape
# read data types
x_dtype = reader._f[reader.x_name].dtype
y_dtype = reader._f[reader.y_name].dtype
# no batch size -> return everything
if not batch_size:
batch_size = p_shape[0]
# keep X in memory for resolving pairs quickly
if sharedx is not None:
X = sharedx
else:
NeighborPairTraintest.__log.debug('Reading X in memory')
X = reader.get_all_x()
# the default mask function is the identity (no masking)
if mask_fn is None:
def mask_fn(*data):
return data
batch_beg_end = np.zeros((int(np.ceil(p_shape[0] / batch_size)), 2))
last = 0
for row in batch_beg_end:
row[0] = last
row[1] = last + batch_size
last = row[1]
batch_beg_end = batch_beg_end.astype(int)
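# the last batch may extend past p_shape[0]; h5py slicing clamps
# the range, so that batch is simply smaller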
NeighborPairTraintest.__log.debug('Generator ready')
def example_generator_fn():
# generator function yielding data
epoch = 0
batch_idx = 0
while True:
if batch_idx == len(batch_beg_end):
batch_idx = 0
epoch += 1
if shuffle:
np.random.shuffle(batch_beg_end)
# Traintest.__log.debug('EPOCH %i (caller: %s)', epoch,
# inspect.stack()[1].function)
beg_idx, end_idx = batch_beg_end[batch_idx]
pairs, y = reader.get_py(beg_idx, end_idx)
x1 = X[pairs[:, 0]]
x2 = X[pairs[:, 1]]
if augment_fn is not None:
tmp_x1 = list()
tmp_x2 = list()
tmp_y = list()
for i in range(augment_scale):
tmp_x1.append(augment_fn(
x1, **augment_kwargs))
tmp_x2.append(augment_fn(
x2, **augment_kwargs))
tmp_y.append(y)
x1 = np.vstack(tmp_x1)
x2 = np.vstack(tmp_x2)
y = np.vstack(tmp_y)
x1, x2, y = mask_fn(x1, x2, y)
if replace_nan is not None:
x1[np.where(np.isnan(x1))] = replace_nan
x2[np.where(np.isnan(x2))] = replace_nan
if only_x:
yield [x1, x2]
else:
yield [x1, x2], y
batch_idx += 1
pair_shape = (p_shape[0] * augment_scale, x_shape[1])
shapes = (pair_shape, pair_shape, y_shape)
dtypes = (x_dtype, x_dtype, y_dtype)
return shapes, dtypes, example_generator_fn
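
# A minimal sketch (not part of this module) of how the returned triple
# might be consumed with tf.data, assuming TensorFlow is available;
# numpy dtypes may need converting with tf.as_dtype:
#
#   shapes, dtypes, gen_fn = NeighborPairTraintest.generator_fn(
#       'traintest.h5', 'train_train', batch_size=128)
#   dataset = tf.data.Dataset.from_generator(
#       gen_fn, output_types=((dtypes[0], dtypes[1]), dtypes[2]))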