chemicalchecker.util.splitter.traintest.Traintest
- class Traintest(hdf5_file, split, replace_nan=None)[source]
Bases: object
Traintest class.
Initialize a Traintest instance.
We assume the file contains different splits, e.g. “x_train”, “y_train”, “x_test”, …
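A minimal usage sketch, assuming an HDF5 split file already produced by one of the create/split helpers below (the file name is hypothetical):

>>> from chemicalchecker.util.splitter.traintest import Traintest
>>> # open the 'train' split of a previously created file (hypothetical path)
>>> tt = Traintest("traintest.h5", "train", replace_nan=0.0)
>>> splits = tt.get_split_names()  # names of the splits stored in the file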
Methods
Close the HDF5.
Create the HDF5 file with validation splits for both X and Y.
Create the HDF5 file with both X and Y, train and test.
Return the generator function that we can query for batches.
Get full X.
Get all the X.
Get full Y.
Get random indices for different splits.
Return the names of the splits.
Get a batch of X.
Get a batch of X.
Return the shapes of X.
Get a batch of X and Y.
Return the shapes of X and Y.
Get a batch of Y.
Open the HDF5.
Create the HDF5 file with validation splits from an input file.
Create the HDF5 file with validation splits from an input file.
Attributes
sw_name
- static create(X, Y, out_file, split_names=['train', 'test', 'validation'], split_fractions=[0.8, 0.1, 0.1], x_dtype=<class 'numpy.float32'>, y_dtype=<class 'numpy.float32'>, chunk_size=10000)[source]
Create the HDF5 file with validation splits for both X and Y.
- Parameters:
X (numpy.ndarray) – features to train from.
Y (numpy.ndarray) – labels to predict.
out_file (str) – path of the h5 file to write.
split_names (list(str)) – names for the split of data.
split_fractions (list(float)) – fraction of data in each split.
x_dtype (type) – numpy data type for X.
y_dtype (type) – numpy data type for Y (np.float32 for regression, np.int32 for classification).
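A hedged example of writing a split file from in-memory arrays (shapes, values and the output path are illustrative only):

>>> import numpy as np
>>> from chemicalchecker.util.splitter.traintest import Traintest
>>> X = np.random.rand(1000, 128).astype(np.float32)  # features
>>> Y = np.random.rand(1000, 1).astype(np.float32)    # regression targets
>>> Traintest.create(X, Y, "traintest.h5",
...                  split_names=['train', 'test', 'validation'],
...                  split_fractions=[0.8, 0.1, 0.1])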
- static create_signature_file(sign_from, sign_to, out_filename)[source]
Create the HDF5 file with both X and Y, train and test.
- static generator_fn(file_name, split, batch_size=None, only_x=False, sample_weights=False, shuffle=True, return_on_epoch=False)[source]
Return the generator function that we can query for batches.
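A sketch of how the generator might be obtained for training; the path is hypothetical, and the assumption that the returned generator function yields (X, Y) batches should be checked against the source:

>>> from chemicalchecker.util.splitter.traintest import Traintest
>>> gen_fn = Traintest.generator_fn("traintest.h5", "train",
...                                 batch_size=32, shuffle=True,
...                                 return_on_epoch=True)
>>> # gen_fn can then be queried for batches, e.g. to feed a manual
>>> # training loop or a tf.data.Dataset.from_generator pipeline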
- get_all_x_columns(columns)[source]
Get all the X.
- Parameters:
columns (tuple(int,int)) – start, stop indexes.
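A small usage sketch (instance, file path and column range are illustrative):

>>> from chemicalchecker.util.splitter.traintest import Traintest
>>> tt = Traintest("traintest.h5", "train")
>>> x_cols = tt.get_all_x_columns((0, 10))  # X restricted to columns 0..9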
- static get_split_indeces(rows, fractions, random_state=None)[source]
Get random indices for different splits.
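A hedged example, assuming rows is the number of rows to split and that one index array is returned per requested fraction:

>>> from chemicalchecker.util.splitter.traintest import Traintest
>>> idx_train, idx_test, idx_val = Traintest.get_split_indeces(
...     1000, [0.8, 0.1, 0.1], random_state=42)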
- static split_h5(in_file, out_file, split_names=['train', 'test', 'validation'], split_fractions=[0.8, 0.1, 0.1], chunk_size=1000)[source]
Create the HDF5 file with validation splits from an input file.
- Parameters:
in_file (str) – path of the h5 file to read from.
out_file (str) – path of the h5 file to write.
split_names (list(str)) – names for the split of data.
split_fractions (list(float)) – fraction of data in each split.
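An illustrative call (both paths are hypothetical):

>>> from chemicalchecker.util.splitter.traintest import Traintest
>>> Traintest.split_h5("signatures.h5", "traintest.h5",
...                    split_names=['train', 'test', 'validation'],
...                    split_fractions=[0.8, 0.1, 0.1])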
- static split_h5_blocks(in_file, out_file, split_names=['train', 'test', 'validation'], split_fractions=[0.8, 0.1, 0.1], block_size=1000, datasets=None)[source]
Create the HDF5 file with validation splits from an input file.
- Parameters:
in_file (str) – path of the h5 file to read from.
out_file (str) – path of the h5 file to write.
split_names (list(str)) – names for the split of data.
split_fractions (list(float)) – fraction of data in each split.
block_size (int) – size of the block to be used.
datasets (list) – only split the given datasets and ignore the others.
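An illustrative call (paths and dataset names are hypothetical):

>>> from chemicalchecker.util.splitter.traintest import Traintest
>>> Traintest.split_h5_blocks("signatures.h5", "traintest_blocks.h5",
...                           split_fractions=[0.8, 0.1, 0.1],
...                           block_size=1000, datasets=['x', 'y'])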
- sw_name