chemicalchecker.util.splitter.traintest.Traintest

class Traintest(hdf5_file, split, replace_nan=None)

Bases: object

Traintest class.

Initialize a Traintest instance.

We assume the file contains different splits, e.g. “x_train”, “y_train”, “x_test”, …
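The split check can be pictured with a minimal, pure-Python sketch of how a split name maps onto dataset names. It assumes the `x_<split>` / `y_<split>` pattern suggested by the examples above. The helpers `available_splits` and `split_dataset_names` are hypothetical and are not part of chemicalchecker. The error message mirrors the check quoted in the `sw_name` entry.

```python
def available_splits(dataset_names):
    """Collect split names from keys like 'x_train' or 'y_test' (assumed pattern)."""
    splits = set()
    for name in dataset_names:
        prefix, sep, split = name.partition("_")
        if sep and prefix in ("x", "y"):
            splits.add(split)
    return sorted(splits)


def split_dataset_names(dataset_names, split):
    """Return hypothetical (X, Y) dataset names for `split`, failing early if absent."""
    splits = available_splits(dataset_names)
    if split not in splits:
        raise Exception("Split '%s' not found in %s!" % (split, str(splits)))
    return "x_%s" % split, "y_%s" % split


keys = ["x_train", "y_train", "x_test", "y_test"]
print(split_dataset_names(keys, "train"))  # ('x_train', 'y_train')
```

Asking for a split that is not stored, such as "validation" here, raises the exception instead of failing later during a read.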

Methods

close

Close the HDF5.

create

Create the HDF5 file with validation splits for both X and Y.

create_signature_file

Create the HDF5 file with both X and Y, train and test.

generator_fn

Return the generator function that we can query for batches.

get_all_x

Get full X.

get_all_x_columns

Get all rows of X for a range of columns.

get_all_y

Get full Y.

get_split_indeces

Get random indices for the different splits.

get_split_names

Return the name of the splits.

get_sw

Get a batch of sample weights.

get_x

Get a batch of X.

get_x_shapes

Return the shapes of X.

get_xy

Get a batch of X and Y.

get_xy_shapes

Return the shapes of X and Y.

get_y

Get a batch of Y.

open

Open the HDF5.

split_h5

Create the HDF5 file with validation splits from an input file.

split_h5_blocks

Create the HDF5 file with validation splits from an input file.

Attributes

sw_name

    available_splits = self.get_split_names()
    if split not in available_splits:
        raise Exception("Split '%s' not found in %s!" % (split, str(available_splits)))

close()

Close the HDF5.

static create(X, Y, out_file, split_names=['train', 'test', 'validation'], split_fractions=[0.8, 0.1, 0.1], x_dtype=numpy.float32, y_dtype=numpy.float32, chunk_size=10000)

Create the HDF5 file with validation splits for both X and Y.

Parameters:
  • X (numpy.ndarray) – features to train from.

  • Y (numpy.ndarray) – labels to predict.

  • out_file (str) – path of the h5 file to write.

  • split_names (list(str)) – names for the split of data.

  • split_fractions (list(float)) – fraction of data in each split.

  • x_dtype (type) – numpy data type for X.

  • y_dtype (type) – numpy data type for Y (np.float32 for regression, np.int32 for classification).
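The meaning of `split_names` and `split_fractions` can be shown with a sketch that converts fractions into row counts and slices paired X/Y rows into `x_<name>` / `y_<name>` entries. This is an in-memory stand-in built on several assumptions. It does not shuffle, does not write HDF5, and ignores `x_dtype`, `y_dtype` and `chunk_size`. Its rounding rule (truncate, with the last split taking the remainder) may also differ from what `create` actually does. `split_sizes` and `partition` are hypothetical names.

```python
def split_sizes(n_rows, fractions):
    """Turn split fractions into integer row counts (assumed rounding rule)."""
    if abs(sum(fractions) - 1.0) > 1e-6:
        raise ValueError("split_fractions must sum to 1, got %r" % sum(fractions))
    # Truncate every split but the last; the tiny epsilon guards against float
    # products that land a hair below an integer. The last split takes the rest.
    sizes = [int(n_rows * f + 1e-9) for f in fractions[:-1]]
    sizes.append(n_rows - sum(sizes))
    return sizes


def partition(X, Y, split_names, split_fractions):
    """Slice paired rows into contiguous, unshuffled 'x_<name>'/'y_<name>' parts."""
    if len(X) != len(Y):
        raise ValueError("X and Y must have the same number of rows")
    if len(split_names) != len(split_fractions):
        raise ValueError("need exactly one fraction per split name")
    parts, start = {}, 0
    for name, size in zip(split_names, split_sizes(len(X), split_fractions)):
        parts["x_%s" % name] = X[start:start + size]
        parts["y_%s" % name] = Y[start:start + size]
        start += size
    return parts


X = list(range(10))
Y = [2 * v for v in X]
parts = partition(X, Y, ["train", "test", "validation"], [0.8, 0.1, 0.1])
print(parts["x_test"], parts["y_validation"])  # [8] [18]
```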

static create_signature_file(sign_from, sign_to, out_filename)

Create the HDF5 file with both X and Y, train and test.

static generator_fn(file_name, split, batch_size=None, only_x=False, sample_weights=False, shuffle=True, return_on_epoch=False)

Return the generator function that we can query for batches.
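A batch generator in the spirit of `generator_fn` can be sketched over in-memory lists instead of an HDF5 file. The semantics here are assumptions inferred from the parameter names:

- `batch_size=None` yields the whole split as one batch.
- `shuffle` reshuffles rows at the start of each epoch.
- `only_x` drops Y.
- `return_on_epoch=True` stops after one pass instead of cycling forever.

`make_generator` and its `seed` argument are hypothetical. `sample_weights` is not modelled.

```python
import random


def make_generator(X, Y, batch_size=None, only_x=False, shuffle=True,
                   return_on_epoch=False, seed=None):
    """Return a zero-argument generator function over in-memory rows.

    Assumed semantics (not taken from chemicalchecker): batch_size=None means
    one batch per epoch; return_on_epoch=True stops after a single pass,
    otherwise the generator cycles through epochs forever.
    """
    n_rows = len(X)
    size = batch_size or max(n_rows, 1)
    rng = random.Random(seed)

    def gen():
        order = list(range(n_rows))
        while True:
            if shuffle:
                rng.shuffle(order)  # new row order for every epoch
            for beg in range(0, n_rows, size):
                idx = order[beg:beg + size]
                x = [X[i] for i in idx]
                if only_x:
                    yield x
                else:
                    yield x, [Y[i] for i in idx]
            if return_on_epoch:
                return

    return gen


gen = make_generator(list(range(5)), list("abcde"), batch_size=2,
                     shuffle=False, return_on_epoch=True)
for x, y in gen():
    print(x, y)
```

Returning a generator *function* rather than a generator lets a caller restart iteration simply by calling it again.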

get_all_x()

Get full X.

get_all_x_columns(columns)

Get all rows of X for a range of columns.

Parameters:
  • columns (tuple(int, int)) – start and stop column indices.
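The documentation only says that `columns` holds start and stop indices. Assuming half-open Python slice semantics (stop excluded), the selection amounts to the following. With a NumPy array the same thing is `X[:, start:stop]`. `select_columns` is a hypothetical helper.

```python
def select_columns(X, columns):
    """Return every row of X restricted to columns[0]:columns[1] (stop excluded)."""
    start, stop = columns
    return [row[start:stop] for row in X]


X = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
print(select_columns(X, (1, 3)))  # [[2, 3], [6, 7]]
```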

get_all_y()

Get full Y.

static get_split_indeces(rows, fractions, random_state=None)

Get random indices for the different splits.
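Random index splitting can be sketched as follows, assuming `rows` is the number of rows and `fractions` sum to 1. The sketch shuffles `range(rows)` with a seeded `random.Random`, cuts at the cumulative fractions, and returns each split's indices sorted. Sorted indices are convenient for HDF5 reads because h5py fancy indexing requires increasing indices. `split_indices` is a hypothetical name, and the library may round differently.

```python
import random


def split_indices(rows, fractions, random_state=None):
    """Randomly partition range(rows) into len(fractions) sorted index lists."""
    order = list(range(rows))
    random.Random(random_state).shuffle(order)
    # Cut points at the cumulative fractions; the final cut is always `rows`,
    # so rounding leftovers land in the last split.
    cuts, acc = [0], 0.0
    for fraction in fractions[:-1]:
        acc += fraction
        cuts.append(int(rows * acc + 1e-9))
    cuts.append(rows)
    return [sorted(order[a:b]) for a, b in zip(cuts, cuts[1:])]


train, test, validation = split_indices(100, [0.8, 0.1, 0.1], random_state=42)
print(len(train), len(test), len(validation))  # 80 10 10
```

Passing the same `random_state` reproduces the same partition, which keeps splits stable across runs.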

get_split_names()

Return the name of the splits.

get_sw(beg_idx, end_idx)

Get a batch of sample weights.

get_x(beg_idx, end_idx)

Get a batch of X.
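Batch accessors such as `get_x` take a `[beg_idx, end_idx)` row range. The sketch below shows that slicing together with one plausible reading of the constructor's `replace_nan` argument, in which NaNs are replaced by the given value. That reading is an assumption, as are the helper names `get_batch` and `iter_batches`. With NumPy the replacement would typically be written `x[np.isnan(x)] = value`.

```python
import math


def get_batch(X, beg_idx, end_idx, replace_nan=None):
    """Copy rows beg_idx..end_idx-1; if replace_nan is set, swap NaNs for it."""
    batch = [list(row) for row in X[beg_idx:end_idx]]  # copy, leave X untouched
    if replace_nan is not None:
        for row in batch:
            for j, value in enumerate(row):
                if isinstance(value, float) and math.isnan(value):
                    row[j] = replace_nan
    return batch


def iter_batches(n_rows, batch_size):
    """Yield consecutive (beg_idx, end_idx) ranges that cover n_rows."""
    for beg in range(0, n_rows, batch_size):
        yield beg, min(beg + batch_size, n_rows)


X = [[1.0, float("nan")], [3.0, 4.0], [float("nan"), 6.0]]
for beg, end in iter_batches(len(X), 2):
    print(beg, end, get_batch(X, beg, end, replace_nan=0.0))
```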

get_x_shapes()

Return the shapes of X.

get_xy(beg_idx, end_idx)

Get a batch of X and Y.

get_xy_shapes()

Return the shapes of X and Y.

get_y(beg_idx, end_idx)

Get a batch of Y.

open()

Open the HDF5.

static split_h5(in_file, out_file, split_names=['train', 'test', 'validation'], split_fractions=[0.8, 0.1, 0.1], chunk_size=1000)

Create the HDF5 file with validation splits from an input file.

Parameters:
  • in_file (str) – path of the h5 file to read from.

  • out_file (str) – path of the h5 file to write.

  • split_names (list(str)) – names for the split of data.

  • split_fractions (list(float)) – fraction of data in each split.

static split_h5_blocks(in_file, out_file, split_names=['train', 'test', 'validation'], split_fractions=[0.8, 0.1, 0.1], block_size=1000, datasets=None)

Create the HDF5 file with validation splits from an input file.

Parameters:
  • in_file (str) – path of the h5 file to read from.

  • out_file (str) – path of the h5 file to write.

  • split_names (list(str)) – names for the split of data.

  • split_fractions (list(float)) – fraction of data in each split.

  • block_size (int) – size of the block to be used.

  • datasets (list) – only split the given datasets and ignore the others.
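The meaning of `block_size` is not spelled out above. One plausible reading, assumed here, is that rows are grouped into contiguous blocks of `block_size` and whole blocks (not single rows) are assigned to splits. Under that reading, neighbouring rows stay together and each split can be read in contiguous runs. The hypothetical `split_rows_by_blocks` sketches this idea in pure Python. It does not handle the `datasets` filter or any HDF5 I/O.

```python
import math
import random


def split_rows_by_blocks(rows, fractions, block_size=1000, random_state=None):
    """Assign whole blocks of consecutive rows to splits (assumed behaviour).

    Rows 0..block_size-1 form block 0, the next block_size rows block 1, and
    so on; blocks are shuffled and handed out as units, so a block never
    straddles two splits.
    """
    n_blocks = math.ceil(rows / block_size)
    blocks = list(range(n_blocks))
    random.Random(random_state).shuffle(blocks)
    splits, start = [], 0
    for i, fraction in enumerate(fractions):
        last = i == len(fractions) - 1
        count = n_blocks - start if last else int(n_blocks * fraction + 1e-9)
        chosen = sorted(blocks[start:start + count])
        # Expand block ids to row indices; the final block may be shorter.
        splits.append([r for b in chosen
                       for r in range(b * block_size,
                                      min((b + 1) * block_size, rows))])
        start += count
    return splits


parts = split_rows_by_blocks(95, [0.8, 0.1, 0.1], block_size=10, random_state=0)
print([len({r // 10 for r in part}) for part in parts])  # [8, 1, 1]
```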

sw_name

    available_splits = self.get_split_names()
    if split not in available_splits:
        raise Exception("Split '%s' not found in %s!" %
                        (split, str(available_splits)))