chemicalchecker.util.splitter.neighborerror.NeighborErrorTraintest

class NeighborErrorTraintest(hdf5_file, split, nr_neig=10, replace_nan=None)[source]

Bases: object

NeighborErrorTraintest class.

Initialize a NeighborErrorTraintest instance.

We assume the file contains different splits, e.g. “x_train”, “y_train”, “x_test”, …

Methods

close

Close the HDF5.

create

Create the HDF5 file with validation splits.

generator_fn

Return the generator function that we can query for batches.

get_all_p

Get full X.

get_all_x

Get full X.

get_all_y

Get full Y.

get_split_indeces

Get random indices for different splits.

get_split_names

Return the name of the splits.

get_t

Get a batch of X.

get_ty_shapes

Return the shapes of X and Y.

get_xy_shapes

Return the shapes of X and Y.

get_y

Get a batch of Y.

open

Open the HDF5.

close()[source]

Close the HDF5.

static create(to_predict, out_file, predict_fn, subsample_fn, max_x=10000, split_names=['train', 'test'], split_fractions=[0.8, 0.2], suffix='eval', x_dtype=<class 'numpy.float32'>, y_dtype=<class 'numpy.float32'>)[source]

Create the HDF5 file with validation splits.

Parameters:
  • X (numpy.ndarray) – features to train from.

  • out_file (str) – path of the h5 file to write.

  • neigbors_matrix (numpy.ndarray) – matrix for computing neighbors.

  • neigbors (int) – Number of positive neighbors to include.

  • mean_center_x (bool) – center each feature on its mean?

  • shuffle (bool) – Shuffle positive and negatives.

  • split_names (list(str)) – names for the split of data.

  • split_fractions (list(float)) – fraction of data in each split.

  • x_dtype (type) – numpy data type for X.

  • y_dtype (type) – numpy data type for Y (np.float32 for regression, np.int32 for classification).

static generator_fn(file_name, split, batch_size=None, replace_nan=None, augment_scale=1, augment_fn=None, augment_kwargs={}, mask_fn=None, shuffle=True, return_on_epoch=True, sharedx=None)[source]

Return the generator function that we can query for batches.

Parameters:
  • file_name (str) – the H5 file generated via create.

  • split (str) – one of ‘train_train’, ‘train_test’, or ‘test_test’.

  • batch_size (int) – size of a batch of data.

  • replace_nan (bool) – value used for NaN replacement.

  • augment_scale (int) – augment the train size by this factor.

  • augment_fn (func) – function to augment data.

  • augment_kwargs (dict) – parameters for the augment function.
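The batching behaviour described above can be illustrated with a minimal sketch. It is not the library's implementation: `make_batch_generator` is a hypothetical helper, it works on in-memory lists rather than an HDF5 file, and it covers only the `batch_size`, `replace_nan` and `shuffle` ideas (no augmentation or masking). It assumes that one call to the returned function yields a single pass over the data.

```python
import math
import random


def make_batch_generator(x, y, batch_size=None, replace_nan=None,
                         shuffle=True, seed=None):
    """Hypothetical helper: return a function that builds a batch generator."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same number of rows")
    n_rows = len(x)
    # Assumption: no batch size means one batch holding every row.
    size = batch_size if batch_size else max(n_rows, 1)
    rng = random.Random(seed)

    def clean(row):
        # Swap NaN entries for the replacement value, if one was given.
        if replace_nan is None:
            return list(row)
        return [replace_nan if math.isnan(v) else v for v in row]

    def generator():
        order = list(range(n_rows))
        if shuffle:
            rng.shuffle(order)
        # Walk the (possibly shuffled) row order in steps of `size`.
        for beg in range(0, n_rows, size):
            idx = order[beg:beg + size]
            yield [clean(x[i]) for i in idx], [y[i] for i in idx]

    return generator


x = [[1.0, float("nan")], [2.0, 3.0], [4.0, 5.0]]
y = [0.1, 0.2, 0.3]
gen_fn = make_batch_generator(x, y, batch_size=2, replace_nan=0.0,
                              shuffle=False)
batches = list(gen_fn())
```

Returning a function instead of a generator object lets a caller start a fresh pass whenever needed, which is one plausible reason for the "generator function" wording.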

get_all_p()[source]

Get full X.

get_all_x()[source]

Get full X.

get_all_y()[source]

Get full Y.

static get_split_indeces(rows, fractions)[source]

Get random indices for different splits.
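One way to produce such indices is sketched below. This is an assumption about the general technique (shuffle the row indices, then cut them at the cumulative fractions), not the method's actual code. `split_indices` and its `seed` argument are hypothetical names.

```python
import random


def split_indices(rows, fractions, seed=None):
    """Hypothetical helper: randomly partition range(rows) by fractions."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    rng = random.Random(seed)
    shuffled = list(range(rows))
    rng.shuffle(shuffled)

    splits = []
    start = 0
    cumulative = 0.0
    for i, fraction in enumerate(fractions):
        cumulative += fraction
        # The last split takes all remaining rows so rounding never drops one.
        is_last = i == len(fractions) - 1
        end = rows if is_last else int(round(cumulative * rows))
        # Sorting is a choice made for this sketch; increasing indices are
        # convenient when they are later used to select rows from an array.
        splits.append(sorted(shuffled[start:end]))
        start = end
    return splits


train_idx, test_idx = split_indices(10, [0.8, 0.2], seed=0)
```

With `rows=10` and fractions `[0.8, 0.2]`, the two lists hold 8 and 2 indices, are disjoint, and together cover every row.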

get_split_names()[source]

Return the name of the splits.

get_t(beg_idx, end_idx)[source]

Get a batch of X.

get_ty_shapes()[source]

Return the shapes of X and Y.

get_xy_shapes()[source]

Return the shapes of X and Y.

get_y(beg_idx, end_idx)[source]

Get a batch of Y.

open()[source]

Open the HDF5.