chemicalchecker.util.splitter.pairtraintest.PairTraintest
- class PairTraintest(hdf5_file, split, nr_neig=10, replace_nan=None)
Bases: object
PairTraintest class.
Initialize a PairTraintest instance.
We assume the file contains different splits, e.g. “x_train”, “y_train”, “x_test”, …
Methods
Close the HDF5.
Create the HDF5 file with validation splits.
generate_splits
Return the generator function that we can query for batches.
Get full X.
Get full X.
Get full X.
Get full Y.
Get a batch of X.
get_pos_neg
Get a batch of X and Y.
Return the shapes of X and Y.
Get random indices for different splits.
Return the names of the splits.
Return the shapes of X and Y.
Get a batch of Y.
Open the HDF5.
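One of the methods above returns random indices for different splits. The following is a minimal sketch of how such indices could be drawn from split fractions, assuming NumPy and a hypothetical helper `split_indices(n, fractions, seed)`; it is an illustration, not the library's implementation.

```python
import numpy as np

def split_indices(n, fractions, seed=42):
    """Hypothetical helper: randomly partition range(n) according to fractions."""
    if not np.isclose(sum(fractions), 1.0):
        raise ValueError("fractions must sum to 1")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)  # shuffled sample indices
    # cumulative split boundaries, rounded to whole samples
    bounds = np.round(np.cumsum(fractions) * n).astype(int)
    # cut the permutation at every boundary except the final one (== n)
    return np.split(perm, bounds[:-1])

train_idx, test_idx = split_indices(10, [0.8, 0.2])
print(len(train_idx), len(test_idx))
```

Every index lands in exactly one split, and fixing the seed makes the partition reproducible across runs.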
- static create(X1, X2, pairs, split_names, out_file, mean_center_x=True, shuffle=True)
Create the HDF5 file with validation splits.
- Parameters:
X (numpy.ndarray) – features to train from.
out_file (str) – path of the h5 file to write.
neigbors_matrix (numpy.ndarray) – matrix for computing neighbors.
neigbors (int) – Number of positive neighbors to include.
mean_center_x (bool) – center each feature on its mean?
shuffle (bool) – shuffle positives and negatives.
split_names (list(str)) – names for the split of data.
split_fractions (list(float)) – fraction of data in each split.
x_dtype (type) – numpy data type for X.
y_dtype (type) – numpy data type for Y (np.float32 for regression, np.int32 for classification).
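To illustrate what mean-centering and pair-based splitting could look like, here is a sketch assuming NumPy, that `pairs` holds (row in X1, row in X2) index pairs, and that each row of X1 and X2 has already been assigned to train or test. The split labels ‘train_train’, ‘train_test’ and ‘test_test’ are the ones accepted by generator_fn; the helpers `mean_center` and `assign_pair_splits` are hypothetical and the assignment rule is an assumption, not necessarily what create does.

```python
import numpy as np

def mean_center(X):
    # subtract each column's mean so every feature is centred on 0
    return X - X.mean(axis=0, keepdims=True)

def assign_pair_splits(pairs, train1, train2):
    """Hypothetical: label each (i, j) pair by the split of its two members.

    pairs  : int array of shape (n_pairs, 2), row indices into X1 and X2
    train1 : boolean mask over X1 rows, True if the row is in the train split
    train2 : boolean mask over X2 rows, True if the row is in the train split
    """
    in1 = train1[pairs[:, 0]]
    in2 = train2[pairs[:, 1]]
    # both in train -> train_train, both in test -> test_test, mixed -> train_test
    labels = np.where(in1 & in2, "train_train",
                      np.where(~in1 & ~in2, "test_test", "train_test"))
    return {name: pairs[labels == name]
            for name in ("train_train", "train_test", "test_test")}

X1 = np.array([[1.0, 2.0], [3.0, 4.0]])
print(mean_center(X1))
pairs = np.array([[0, 0], [0, 1], [1, 1]])
splits = assign_pair_splits(pairs, np.array([True, False]), np.array([True, False]))
print({k: v.tolist() for k, v in splits.items()})
```

Keeping ‘test_test’ pairs separate in this way means neither member of such a pair was seen during training.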
- static generator_fn(file_name, split, batch_size=None, only_x=False, replace_nan=None, mask_fn=None)
Return the generator function that we can query for batches.
- Parameters:
file_name (str) – the H5 file generated via create.
split (str) – one of ‘train_train’, ‘train_test’, or ‘test_test’.
batch_size (int) – size of a batch of data.
only_x (bool) – when predicting, usually only X is needed.
replace_nan (bool) – value used for NaN replacement.
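The batching behaviour described above can be sketched with in-memory arrays instead of an HDF5 file. This assumes NumPy; `make_generator` is a hypothetical stand-in for generator_fn that makes a single pass over the data, treats `batch_size=None` as "the whole split", and ignores `mask_fn`, whose behaviour is not documented here.

```python
import numpy as np

def make_generator(X, Y, batch_size=None, only_x=False, replace_nan=None):
    """Hypothetical in-memory analogue of generator_fn (single pass, no mask_fn)."""
    n = X.shape[0]
    # None means one batch holding the whole split
    size = max(n, 1) if batch_size is None else batch_size

    def gen():
        for start in range(0, n, size):
            # np.array copies, so NaN replacement never mutates the source
            x = np.array(X[start:start + size], dtype=float)
            if replace_nan is not None:
                x[np.isnan(x)] = replace_nan
            if only_x:
                yield x  # prediction: labels are not needed
            else:
                yield x, Y[start:start + size]

    return gen

X = np.array([[1.0, np.nan], [2.0, 3.0], [4.0, 5.0]])
Y = np.array([0, 1, 1])
batches = list(make_generator(X, Y, batch_size=2, replace_nan=0.0)())
print(batches[0][0])
```

Returning a function rather than a generator object lets a caller restart iteration by calling it again.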