chemicalchecker.util.splitter.neighbortriplet.BaseTripletSampler
- class BaseTripletSampler(triplet_signature, mol_signature, out_file, save_kwargs={})
Bases: object
Base class for triplet samplers.
Methods
Get random indexes for different splits.
Save sampled triplets to file.
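The method that draws random split indexes is listed above without its name or signature. The following is a minimal sketch of one way to do it, under two assumptions: the split fractions sum to 1, and they are applied to a random permutation of row indexes. The helper name `random_split_indexes` is hypothetical and is not part of the library's API.

```python
import numpy as np


def random_split_indexes(n_rows, fractions, seed=None):
    """Hypothetical helper: split range(n_rows) into disjoint random index sets.

    Assumes `fractions` sums to 1; the last split absorbs rounding leftovers.
    """
    if not np.isclose(sum(fractions), 1.0):
        raise ValueError("split fractions must sum to 1")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_rows)
    # Cumulative cut points, e.g. [0.8, 0.2] with 10 rows -> cut at row 8.
    # Rounding guards against float error such as 2.9999999 -> 2.
    cuts = np.round(np.cumsum(fractions)[:-1] * n_rows).astype(int)
    return np.split(perm, cuts)


splits = random_split_indexes(10, [0.8, 0.2], seed=0)
print([len(s) for s in splits])  # [8, 2]
```

Because the splits are taken as consecutive slices of a single permutation, they are disjoint and together cover every row.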
- save_triplets(triplets, mean_center_x=True, shuffle=True, split_names=['train', 'test'], split_fractions=[0.8, 0.2], suffix='eval', cpu=1, x_dtype=numpy.float32, y_dtype=numpy.float32)
Save sampled triplets to file.
This function saves the triplets after performing the train/test split, shuffling, and normalization.
- Parameters:
triplets (array) – Indexes of the anchor, positive, and negative for each triplet.
mean_center_x (bool) – Normalize (mean-center) the data column-wise.
shuffle (bool) – Shuffle the order of the triplets.
split_names (list of str) – Names of the splits.
split_fractions (list of float) – Fraction of the data assigned to each split, in the same order as split_names.
suffix (str) – Suffix for the generated scaler.
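The shuffling and column-wise normalization steps can be illustrated with a short sketch. Several parts of it are assumptions. The function name `shuffle_and_center` is hypothetical. `X` stands in for the data matrix that the triplet indexes point into. The column means are computed over all rows of `X`, because the source does not say whether the real scaler is fitted on every row or only on the training split. Triplets are stored as an integer array of shape (n, 3) with columns (anchor, positive, negative).

```python
import numpy as np


def shuffle_and_center(triplets, X, shuffle=True, mean_center_x=True,
                       x_dtype=np.float32, seed=None):
    """Hypothetical sketch of the shuffle and normalization steps.

    triplets: int array of shape (n, 3); each row holds the (anchor, positive,
    negative) row indexes into X.
    """
    triplets = np.asarray(triplets)
    if shuffle:
        # Permute whole rows so each (anchor, positive, negative) stays intact.
        triplets = triplets[np.random.default_rng(seed).permutation(len(triplets))]
    X = np.asarray(X, dtype=x_dtype)
    # Assumption: means are taken over all rows of X. They play the role of
    # the fitted "scaler", so the same shift can be applied to new data later.
    if mean_center_x:
        means = X.mean(axis=0)
    else:
        means = np.zeros(X.shape[1], dtype=x_dtype)
    return triplets, X - means, means


triplets = np.array([[0, 1, 2], [1, 0, 3], [2, 3, 0]])
X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
t, Xc, means = shuffle_and_center(triplets, X, seed=0)
print(means)  # [4. 5.]
```

The sketch shuffles whole rows rather than individual columns, so every anchor keeps its own positive and negative.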