chemicalchecker.util.splitter.neighborpair.NeighborPairTraintest

class NeighborPairTraintest(hdf5_file, split, nr_neig=10, replace_nan=None)[source]

Bases: object

NeighborPairTraintest class.

Initialize a NeighborPairTraintest instance.

We assume the file contains different splits, e.g. “x_train”, “y_train”, “x_test”, …

Methods

close

Close the HDF5.

create

Create the HDF5 file with validation splits.

generator_fn

Return the generator function that we can query for batches.

get_all_p

Get full X.

get_all_x

Get full X.

get_all_y

Get full Y.

get_p

Get a batch of X.

get_py

Get a batch of X and Y.

get_py_shapes

Return the shapes of X and Y.

get_split_indeces

Get random indices for different splits.

get_split_names

Return the name of the splits.

get_xy_shapes

Return the shapes of X and Y.

get_y

Get a batch of Y.

open

Open the HDF5.

close()[source]

Close the HDF5.

static create(X, out_file, neigbors_matrix=None, pos_neighbors=10, neg_neighbors=100, scaler_dest=None, mean_center_x=True, shuffle=True, check_distances=True, split_names=['train', 'test'], split_fractions=[0.8, 0.2], x_dtype=<class 'numpy.float32'>, y_dtype=<class 'numpy.float32'>, debug_test=False)[source]

Create the HDF5 file with validation splits.

Parameters:
  • X (numpy.ndarray) – features to train from.

  • out_file (str) – path of the h5 file to write.

  • neigbors_matrix (numpy.ndarray) – matrix for computing neighbors.

  • pos_neighbors (int) – Number of positive neighbors to include.

  • mean_center_x (bool) – center each feature on its mean?

  • shuffle (bool) – Shuffle positive and negatives.

  • split_names (list(str)) – names for the split of data.

  • split_fractions (list(float)) – fraction of data in each split.

  • x_dtype (type) – numpy data type for X.

  • y_dtype (type) – numpy data type for Y (np.float32 for regression, np.int32 for classification).
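
As an illustration of what the mean_center_x option describes, here is a minimal sketch of centering each feature on its mean. It uses plain Python lists rather than numpy arrays, and the helper name `mean_center` is hypothetical; it is not the code used by create, and any additional scaling the library may apply is not shown.

```python
def mean_center(rows):
    """Subtract each column's mean from every value in that column.

    Hypothetical helper for illustration only; not part of chemicalchecker.
    """
    n_rows = len(rows)
    n_cols = len(rows[0])
    # One mean per feature (column).
    means = [sum(row[j] for row in rows) / n_rows for j in range(n_cols)]
    # Shift every value by the mean of its column.
    centered = [[row[j] - means[j] for j in range(n_cols)] for row in rows]
    return centered, means


X = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
X_centered, feature_means = mean_center(X)
```

After centering, every column of `X_centered` has mean zero, and `feature_means` holds the values that were subtracted.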

static generator_fn(file_name, split, batch_size=None, only_x=False, replace_nan=None, augment_scale=1, augment_fn=None, augment_kwargs={}, mask_fn=None, shuffle=True, sharedx=None)[source]

Return the generator function that we can query for batches.

Parameters:
  • file_name (str) – The H5 generated via create.

  • split (str) – One of ‘train_train’, ‘train_test’, or ‘test_test’.

  • batch_size (int) – Size of a batch of data.

  • only_x (bool) – Return only X; usually only X is needed when predicting.

  • replace_nan (bool) – Value used for NaN replacement.

  • augment_scale (int) – Augment the train size by this factor.

  • augment_fn (func) – Function to augment data.

  • augment_kwargs (dict) – Parameters for the augment function.
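
The general pattern of returning a generator function that is later queried for batches can be sketched as follows. This is a simplified stand-in that works on in-memory lists; it does not read HDF5, build neighbor pairs, or apply augmentation, and the name `make_batch_generator` and its `seed` argument are assumptions made for the example.

```python
import random


def make_batch_generator(x, y, batch_size=None, only_x=False,
                         shuffle=True, seed=None):
    """Return a zero-argument function that creates a fresh batch iterator.

    Hypothetical helper for illustration only; not the library's generator_fn.
    """
    n = len(x)
    # With no batch size, yield the whole dataset as a single batch.
    size = batch_size or max(n, 1)
    rng = random.Random(seed)

    def generator():
        order = list(range(n))
        if shuffle:
            rng.shuffle(order)
        for beg in range(0, n, size):
            idx = order[beg:beg + size]
            x_batch = [x[i] for i in idx]
            if only_x:
                yield x_batch
            else:
                yield x_batch, [y[i] for i in idx]

    return generator


x = [[0.0], [1.0], [2.0], [3.0], [4.0]]
y = [0, 1, 0, 1, 0]
gen_fn = make_batch_generator(x, y, batch_size=2, shuffle=False)
batches = list(gen_fn())
```

Returning a function instead of a generator object lets the caller start a new pass over the data simply by calling it again, which is the usual reason for this design in training loops.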

get_all_p()[source]

Get full X.

get_all_x()[source]

Get full X.

get_all_y()[source]

Get full Y.

get_p(beg_idx, end_idx)[source]

Get a batch of X.

get_py(beg_idx, end_idx)[source]

Get a batch of X and Y.

get_py_shapes()[source]

Return the shapes of X and Y.

static get_split_indeces(rows, fractions)[source]

Get random indices for different splits.
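
One common way to draw random indices for several splits is to shuffle the row indices and cut them at the cumulative fractions. The sketch below shows that approach under the assumption that the fractions sum to 1; the function name `split_indices` and the `seed` argument are hypothetical, and the library's own implementation may differ.

```python
import random


def split_indices(rows, fractions, seed=None):
    """Shuffle range(rows) and cut it into consecutive chunks sized by fractions.

    Hypothetical helper for illustration only; not the library's
    get_split_indeces.
    """
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    order = list(range(rows))
    random.Random(seed).shuffle(order)
    splits = []
    start = 0
    cumulative = 0.0
    for i, frac in enumerate(fractions):
        cumulative += frac
        # The last split takes whatever remains so rounding drops no row.
        is_last = i == len(fractions) - 1
        end = rows if is_last else int(round(cumulative * rows))
        splits.append(order[start:end])
        start = end
    return splits


train_idx, test_idx = split_indices(10, [0.8, 0.2], seed=0)
```

Every row index lands in exactly one split, so the splits are disjoint and together cover all rows.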

get_split_names()[source]

Return the name of the splits.

get_xy_shapes()[source]

Return the shapes of X and Y.

get_y(beg_idx, end_idx)[source]

Get a batch of Y.

open()[source]

Open the HDF5.