chemicalchecker.util.splitter.neighborpair.NeighborPairTraintest
- class NeighborPairTraintest(hdf5_file, split, nr_neig=10, replace_nan=None)
Bases: object
NeighborPairTraintest class.
Initialize a NeighborPairTraintest instance.
The HDF5 file is assumed to contain the different splits, e.g. “x_train”, “y_train”, “x_test”, …
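To make the assumed file layout concrete, the sketch below replaces the HDF5 file with a plain dict of NumPy arrays whose keys follow the “x_<split>”/“y_<split>” pattern of the example names. This is an assumption for illustration; the real class reads datasets from HDF5. The helpers `split_names` and `get_batch` are hypothetical and only loosely mirror the split-name and batch methods listed below.

```python
import numpy as np

# Hypothetical in-memory stand-in for the HDF5 file: one array per split
# and per role ("x_" for features, "y_" for labels).
datasets = {
    "x_train": np.arange(20, dtype=np.float32).reshape(10, 2),
    "y_train": np.ones((10, 1), dtype=np.float32),
    "x_test": np.arange(6, dtype=np.float32).reshape(3, 2),
    "y_test": np.zeros((3, 1), dtype=np.float32),
}


def split_names(keys):
    """Collect split names from keys shaped like 'x_<split>' (hypothetical helper)."""
    return sorted({k[2:] for k in keys if k.startswith("x_")})


def get_batch(data, split, beg, end):
    """Return rows [beg, end) of X and Y for one split (hypothetical helper)."""
    return data["x_" + split][beg:end], data["y_" + split][beg:end]


names = split_names(datasets)
x, y = get_batch(datasets, "train", 0, 4)
print(names)             # ['test', 'train']
print(x.shape, y.shape)  # (4, 2) (4, 1)
```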
Methods
Close the HDF5 file.
Create the HDF5 file with validation splits.
Return the generator function that can be queried for batches.
Get full X.
Get full X.
Get full Y.
Get a batch of X.
Get a batch of X and Y.
Return the shapes of X and Y.
Get random indices for the different splits.
Return the names of the splits.
Return the shapes of X and Y.
Get a batch of Y.
Open the HDF5 file.
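One plausible reading of the entry about random indices for the different splits, together with the `split_fractions` parameter of `create`, is a shuffled permutation of row indices cut into consecutive chunks. The sketch below assumes that reading; `split_indices` and its `seed` argument are hypothetical names, and the real method's rounding behavior is not documented here.

```python
import numpy as np


def split_indices(n_samples, fractions, seed=0):
    """Hypothetical sketch: shuffle 0..n_samples-1 and cut it by fraction.

    Cut points are rounded to integers, so split sizes can differ from
    the exact fractions by one sample.
    """
    if not np.isclose(sum(fractions), 1.0):
        raise ValueError("split fractions must sum to 1")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    # Cumulative cut points, e.g. [0.8, 0.2] with 10 samples -> cut at 8.
    cuts = np.round(np.cumsum(fractions)[:-1] * n_samples).astype(int)
    return np.split(perm, cuts)


train_idx, test_idx = split_indices(10, [0.8, 0.2])
print(len(train_idx), len(test_idx))  # 8 2
```

Rounding the cumulative cut points, rather than each split size separately, guarantees that the splits together cover every row exactly once.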
- static create(X, out_file, neigbors_matrix=None, pos_neighbors=10, neg_neighbors=100, scaler_dest=None, mean_center_x=True, shuffle=True, check_distances=True, split_names=['train', 'test'], split_fractions=[0.8, 0.2], x_dtype=<class 'numpy.float32'>, y_dtype=<class 'numpy.float32'>, debug_test=False)
Create the HDF5 file with validation splits.
- Parameters:
X (numpy.ndarray) – features to train from.
out_file (str) – path of the h5 file to write.
neigbors_matrix (numpy.ndarray) – matrix for computing neighbors.
pos_neighbors (int) – Number of positive neighbors to include.
mean_center_x (bool) – Whether to center each feature on its mean.
shuffle (bool) – Whether to shuffle positives and negatives.
split_names (list(str)) – Names of the data splits.
split_fractions (list(float)) – Fraction of the data assigned to each split.
x_dtype (type) – NumPy data type for X.
y_dtype (type) – NumPy data type for Y (np.float32 for regression, np.int32 for classification).
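The parameter names suggest that `create` turns X into labelled pairs: each sample is paired with its nearest neighbors (positives) and with other samples (negatives). The sketch below illustrates that idea under stated assumptions and is not the library's actual algorithm: distances are brute-force Euclidean on the (optionally mean-centered) features, negatives are drawn uniformly from the non-neighbors, and `make_neighbor_pairs` is a hypothetical helper. In the real method the neighbors may instead be computed from `neigbors_matrix`.

```python
import numpy as np


def make_neighbor_pairs(X, pos_neighbors=2, neg_neighbors=3, mean_center_x=True, seed=0):
    """Hypothetical sketch: build (anchor, other, label) rows.

    Positives (label 1) are each sample's nearest neighbors by Euclidean
    distance; negatives (label 0) are drawn uniformly from the remaining samples.
    """
    X = np.asarray(X, dtype=np.float32)
    if mean_center_x:
        X = X - X.mean(axis=0)  # center each feature on its mean
    rng = np.random.default_rng(seed)
    # Brute-force pairwise squared distances; fine for small examples only.
    d = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(d, np.inf)  # a sample is not its own neighbor
    pairs = []
    for i in range(len(X)):
        pos = np.argsort(d[i])[:pos_neighbors]
        candidates = np.setdiff1d(np.arange(len(X)), np.append(pos, i))
        n_neg = min(neg_neighbors, len(candidates))
        neg = rng.choice(candidates, size=n_neg, replace=False)
        pairs += [(i, int(j), 1) for j in pos]
        pairs += [(i, int(j), 0) for j in neg]
    return np.array(pairs, dtype=np.int64)


X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
pairs = make_neighbor_pairs(X)
print(pairs.shape)  # (30, 3): 6 samples x (2 positives + 3 negatives)
```

The full distance matrix needs memory quadratic in the number of samples, so for realistic data sizes a nearest-neighbor index would be the usual replacement for that step.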
- static generator_fn(file_name, split, batch_size=None, only_x=False, replace_nan=None, augment_scale=1, augment_fn=None, augment_kwargs={}, mask_fn=None, shuffle=True, sharedx=None)
Return the generator function that can be queried for batches.
- Parameters:
file_name (str) – The H5 file generated via create.
split (str) – One of ‘train_train’, ‘train_test’, or ‘test_test’.
batch_size (int) – Size of a batch of data.
only_x (bool) – If True, provide only X (usually only X is needed when predicting).
replace_nan (bool) – Value used for NaN replacement.
augment_scale (int) – Augment the train size by this factor.
augment_fn (func) – Function to augment data.
augment_kwargs (dict) – Parameters for the augment function.
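The generator pattern described above can be pictured as a factory that returns a zero-argument function, each call of which starts a fresh pass over the data. The sketch below is a simplified, hypothetical analogue working on in-memory arrays instead of the H5 file: `make_generator_fn` is not part of the library, its `batch_size`, `only_x`, `replace_nan` and `shuffle` arguments only mirror the documented parameters, and augmentation is omitted.

```python
import numpy as np


def make_generator_fn(X, Y, batch_size=None, only_x=False, replace_nan=None,
                      shuffle=True, seed=0):
    """Hypothetical in-memory analogue of a batch-generator factory.

    Returns a zero-argument function; each call starts a new pass (epoch)
    over the data. Augmentation and HDF5 access are deliberately omitted.
    """
    n = len(X)
    batch_size = n if batch_size is None else batch_size  # None -> one full batch
    rng = np.random.default_rng(seed)  # shared across calls, so each epoch reshuffles

    def example_generator_fn():
        order = rng.permutation(n) if shuffle else np.arange(n)
        for beg in range(0, n, batch_size):
            idx = order[beg:beg + batch_size]
            x = X[idx]  # fancy indexing returns a copy, so X itself is not modified
            if replace_nan is not None:
                x[np.isnan(x)] = replace_nan  # fill missing feature values
            yield x if only_x else (x, Y[idx])

    return example_generator_fn


X = np.array([[1.0, np.nan], [2.0, 3.0], [np.nan, 4.0]], dtype=np.float32)
Y = np.array([1, 0, 1], dtype=np.int32)
gen_fn = make_generator_fn(X, Y, batch_size=2, replace_nan=0.0, shuffle=False)
batches = list(gen_fn())
print([xb.shape for xb, yb in batches])  # [(2, 2), (1, 2)]
```

Returning a function rather than a generator object lets a training loop restart iteration every epoch simply by calling it again.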