chemicalchecker.util.splitter.neighbortriplet.TripletIterator

class TripletIterator(hdf5_file, split, replace_nan=None)

Bases: object

TripletIterator class.

Initialize a TripletIterator instance.

This allows iterating over train/test splits of triplets generated via a TripletSampler class. The file is assumed to contain several splits, e.g. “x_train”, “y_train”, “x_test”, …
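The sketch below illustrates this naming convention by collecting split names from dataset keys like the ones listed above. The helper split_names is hypothetical and not part of chemicalchecker. It assumes every key has the form <prefix>_<split> with a one-letter prefix, which may not hold for every file written by TripletSampler.

```python
import re


def split_names(dataset_keys):
    """Collect split names from dataset keys such as "x_train" or "y_test".

    Hypothetical helper: assumes each key is "<prefix>_<split>", where
    <prefix> is a single lowercase letter (e.g. x, y).
    """
    pattern = re.compile(r"^[a-z]_(.+)$")
    names = set()
    for key in dataset_keys:
        match = pattern.match(key)
        if match:  # keys without a split suffix are ignored
            names.add(match.group(1))
    return sorted(names)


print(split_names(["x_train", "y_train", "x_test", "y_test"]))
```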

Parameters:
  • hdf5_file (str) – The path to a file generated via a TripletSampler class.

  • split (str) – The HDF5 file typically contains ‘train’ and ‘test’ splits; the iterator reads only the given split.

  • replace_nan (float) – If None, NaN values are left unchanged; otherwise every NaN is replaced by this value.
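The replace_nan semantics can be sketched with NumPy. replace_nans is a hypothetical helper, not the library's code; it assumes batches are NumPy arrays and, when a replacement value is given, writes it into a float copy rather than modifying the input.

```python
import numpy as np


def replace_nans(batch, replace_nan=None):
    """Mimic the documented replace_nan behaviour on a NumPy array.

    Hypothetical helper: None returns the array untouched; any other value
    is written in place of every NaN in a float copy of the array.
    """
    if replace_nan is None:
        return batch
    out = np.array(batch, dtype=float, copy=True)
    out[np.isnan(out)] = replace_nan
    return out


x = np.array([[1.0, np.nan], [np.nan, 4.0]])
print(replace_nans(x, 0.0))
```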

Methods

close

Close the HDF5 file.

generator_fn

Return the generator function that iterates on batches.

get_all_p

Get full P.

get_all_x

Get full X.

get_all_y

Get full Y.

get_split_names

Return the names of the splits.

get_t

Get a batch of T (triplets).

get_ty_shapes

Return the shapes of T and Y.

get_x_columns

Get the columns of X selected by a mask.

get_xy_shapes

Return the shapes of X and Y.

get_y

Get a batch of Y.

open

Open the HDF5 file.

close()

Close the HDF5 file.

static generator_fn(file_name, split, replace_nan=None, batch_size=None, shuffle=True, train=True, augment_fn=None, augment_kwargs={}, mask_fn=None, trim_mask=None, sharedx=None, sharedx_trim=None, onlyself_notself=False, p_self_decay=False)

Return the generator function that iterates on batches.

A TripletIterator is initialized on the specified file and split, with optional masking, a shared X matrix, and more.

Parameters:
  • file_name (str) – The HDF5 file generated via a TripletSampler class.

  • split (str) – One of ‘train_train’, ‘train_test’, or ‘test_test’

  • replace_nan (float) – Value used to replace NaN; if None, NaN values are left unchanged.

  • batch_size (int) – Size of a batch of data.

  • shuffle (bool) – Shuffle the order of batches.

  • train (bool) – If True (train time), the augment function is applied.

  • augment_fn (func) – Function to augment data.

  • augment_kwargs (dict) – Parameters for the augment function.

  • mask_fn (func) – Function to mask data while iterating.

  • trim_mask (array) – Initial masking of data (spaces are excluded).

  • sharedx (matrix) – The preloaded X matrix.

  • sharedx_trim (matrix) – The preloaded and pre-trimmed X matrix.

  • onlyself_notself (bool) – When True, the iterator returns a quintuplet that also includes only_self and not_self.

  • p_self_decay (bool) – When True, the p_self probability decays within each batch and is reset at the start of every batch.
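To make the core parameters concrete, here is a minimal sketch of a generator function over triplets. It rests on assumptions this page does not confirm: the split stores a triplet index array T of shape (n, 3) whose columns point to rows of a shared X matrix (anchor, positive, negative), plus one label per triplet in Y. make_triplet_generator is hypothetical; the real generator_fn additionally handles replace_nan, mask_fn, trim_mask, sharedx_trim, onlyself_notself and p_self_decay, which are left out here.

```python
import numpy as np


def make_triplet_generator(x, t, y, batch_size, shuffle=True, train=True,
                           augment_fn=None, augment_kwargs=None):
    """Return a function yielding ((anchor, positive, negative), y) batches.

    Hypothetical sketch, not chemicalchecker code. Assumed layout:
      x: shared feature matrix, shape (n_molecules, n_features)
      t: triplet row indices into x, shape (n_triplets, 3)
         (columns: anchor, positive, negative)
      y: one label per triplet, shape (n_triplets,)
    """
    augment_kwargs = augment_kwargs or {}  # avoid a shared mutable default
    n_triplets = t.shape[0]
    starts = np.arange(0, n_triplets, batch_size)

    def generator():
        # Shuffle the order of batches, not the triplets inside a batch.
        order = np.random.permutation(starts) if shuffle else starts
        for beg in order:
            end = min(beg + batch_size, n_triplets)
            idx = t[beg:end]
            # Look up the features of each triplet member in the shared X.
            anchor, positive, negative = x[idx[:, 0]], x[idx[:, 1]], x[idx[:, 2]]
            if train and augment_fn is not None:
                # Augmentation is applied only at train time.
                anchor = augment_fn(anchor, **augment_kwargs)
                positive = augment_fn(positive, **augment_kwargs)
                negative = augment_fn(negative, **augment_kwargs)
            yield (anchor, positive, negative), y[beg:end]

    return generator


x = np.arange(12, dtype=float).reshape(6, 2)
t = np.array([[0, 1, 2], [3, 4, 5], [1, 0, 5]])
y = np.array([1, 0, 1])
gen = make_triplet_generator(x, t, y, batch_size=2, shuffle=False)
for (anchor, positive, negative), labels in gen():
    print(anchor.shape, labels)
```

Returning a function rather than a generator object mirrors the description of generator_fn and lets a consumer restart iteration (e.g. once per epoch) by calling it again.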

get_all_p()

Get full P.

get_all_x()

Get full X.

get_all_y()

Get full Y.

get_split_names()

Return the names of the splits.

get_t(beg_idx, end_idx)

Get a batch of T (triplets).
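get_t and get_y both take a beg_idx/end_idx pair. Assuming these follow Python slice semantics (end_idx exclusive), the pairs that cover a split of n_samples rows in fixed-size batches can be generated as below; batch_bounds is a hypothetical helper, not part of the class.

```python
def batch_bounds(n_samples, batch_size):
    """Yield (beg_idx, end_idx) pairs covering range(n_samples).

    Hypothetical helper; assumes end_idx is exclusive, as in Python slicing,
    so the last batch may be shorter than batch_size.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for beg_idx in range(0, n_samples, batch_size):
        yield beg_idx, min(beg_idx + batch_size, n_samples)


print(list(batch_bounds(10, 4)))  # [(0, 4), (4, 8), (8, 10)]
```

Each pair would then be passed to get_t(beg_idx, end_idx) and get_y(beg_idx, end_idx), presumably on an iterator whose file has been opened with open().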

get_ty_shapes()

Return the shapes of T and Y.

get_x_columns(mask)

Get the columns of X selected by mask.
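The format of mask is not documented on this page. Assuming it is a boolean or integer index over the columns of X, the selection amounts to standard NumPy column indexing, sketched here with the hypothetical helper select_columns on an in-memory array.

```python
import numpy as np


def select_columns(x, mask):
    """Return the columns of x selected by mask (hypothetical helper).

    mask may be a boolean array with one entry per column or a sequence of
    column indices; both are standard NumPy indexing along axis 1.
    """
    return x[:, mask]


x = np.arange(12).reshape(3, 4)
bool_mask = np.array([True, False, True, False])
print(select_columns(x, bool_mask))
```

When reading directly from an h5py dataset, only a subset of NumPy fancy indexing is supported (for example, index lists must be in increasing order), so behaviour on disk may differ from this in-memory sketch.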

get_xy_shapes()

Return the shapes of X and Y.

get_y(beg_idx, end_idx)

Get a batch of Y.

open()

Open the HDF5 file.
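Since open() and close() must be paired, a small context manager can guarantee the file is closed even if an error occurs while reading. opened is a hypothetical helper, and FakeIterator is a stand-in used only so the sketch runs without an HDF5 file. Assuming the constructor does not already leave the file open, usage with the real class would look like with opened(TripletIterator(hdf5_file, split)) as it: ...

```python
from contextlib import contextmanager


@contextmanager
def opened(iterator):
    """Call iterator.open() on entry and iterator.close() on exit.

    Hypothetical helper: works with any object exposing open() and close(),
    and closes the file even if the body of the with-block raises.
    """
    iterator.open()
    try:
        yield iterator
    finally:
        iterator.close()


class FakeIterator:
    """Stand-in for TripletIterator that only records open/close calls."""

    def __init__(self):
        self.calls = []

    def open(self):
        self.calls.append("open")

    def close(self):
        self.calls.append("close")


fake = FakeIterator()
with opened(fake):
    fake.calls.append("read")
print(fake.calls)  # ['open', 'read', 'close']
```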