chemicalchecker.util.splitter.neighbortriplet.BaseTripletSampler

class BaseTripletSampler(triplet_signature, mol_signature, out_file, save_kwargs={})

Bases: object

Base class for triplet samplers.

Methods

get_split_indeces

Get random indexes for different splits.

save_triplets

Save sampled triplets to file.

get_split_indeces(rows, fractions)

Get random row indexes for the different splits, with split sizes given by fractions.
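The idea of randomly partitioning rows by fractions can be sketched with NumPy as below. This is a minimal illustration, not the library's code: the function name split_indices_sketch, the seed parameter, and the check that fractions sum to 1 are all hypothetical additions.

```python
import numpy as np

def split_indices_sketch(rows, fractions, seed=None):
    """Hypothetical sketch: randomly partition range(rows) by the given fractions."""
    if not np.isclose(sum(fractions), 1.0):
        raise ValueError("fractions must sum to 1")
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(rows)  # random order of 0..rows-1
    # Cut points from cumulative fractions, e.g. [0.8, 0.2] with 10 rows -> [8]
    cuts = np.round(np.cumsum(fractions)[:-1] * rows).astype(int)
    return np.split(shuffled, cuts)

train_idx, test_idx = split_indices_sketch(10, [0.8, 0.2], seed=0)
```

Each row index lands in exactly one split, so the splits are disjoint and together cover all rows.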

save_triplets(triplets, mean_center_x=True, shuffle=True, split_names=['train', 'test'], split_fractions=[0.8, 0.2], suffix='eval', cpu=1, x_dtype=numpy.float32, y_dtype=numpy.float32)

Save sampled triplets to file.

This function performs the train/test split, shuffling, and normalization of the triplets before saving them.
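One way to picture what "saving triplets" involves is turning each (anchor, positive, negative) index triple into feature rows. The sketch below assumes a plain NumPy feature matrix X and a triplet array of shape (n, 3); the function name and the seed parameter are hypothetical, and only the float32 default mirrors the x_dtype default in the signature above.

```python
import numpy as np

def gather_triplet_features(X, triplets, shuffle=True, seed=None, x_dtype=np.float32):
    """Hypothetical sketch: map (anchor, positive, negative) row indexes to feature arrays."""
    X = np.asarray(X)
    triplets = np.asarray(triplets, dtype=int)
    if shuffle:
        # Reorder whole triplets so each anchor stays paired with its positive/negative
        rng = np.random.default_rng(seed)
        triplets = triplets[rng.permutation(len(triplets))]
    # Column 0 = anchor, 1 = positive, 2 = negative; X[idx] selects whole rows
    anchors = X[triplets[:, 0]].astype(x_dtype)
    positives = X[triplets[:, 1]].astype(x_dtype)
    negatives = X[triplets[:, 2]].astype(x_dtype)
    return anchors, positives, negatives

X = np.arange(12).reshape(4, 3)
a, p, n = gather_triplet_features(X, [[0, 1, 2], [3, 2, 1]], shuffle=False)
```

Shuffling the triplet rows, rather than each column independently, keeps every anchor matched with its own positive and negative.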

Parameters:
  • triplets (array) – Indexes of the anchor, positive, and negative for each triplet.

  • mean_center_x (bool) – Normalize the X data column-wise by mean centering.

  • shuffle (bool) – Shuffle the order of the triplets.

  • split_names (list of str) – Names of the splits.

  • split_fractions (list of float) – Fraction of the triplets assigned to each split.

  • suffix (str) – Suffix used to name the generated scaler.
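Column-wise mean centering (the mean_center_x option) can be sketched as below. It is an assumption, not confirmed by this page, that the column means are computed on the training split only and then applied to every split, which is the usual way to avoid leaking test statistics; the function name and the fit_on parameter are hypothetical.

```python
import numpy as np

def mean_center_splits(splits, fit_on="train"):
    """Hypothetical sketch: subtract the column means of one split from every split."""
    mean = splits[fit_on].mean(axis=0)  # one mean per column (feature)
    centered = {name: arr - mean for name, arr in splits.items()}
    return centered, mean

splits = {"train": np.array([[1.0, 2.0], [3.0, 4.0]]), "test": np.array([[2.0, 2.0]])}
centered, mean = mean_center_splits(splits)
```

After centering, every column of the training split has mean zero, while the test split is shifted by the same training means rather than by its own.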