Source code for chemicalchecker.util.transform.metric_learn

"""Metric learning.

We bias the sign1 to respect triplets sampled from the whole CC
(semi-supervised) or from the space itself (unsupervised).
"""
import os
import pickle
import numpy as np

from .base import BaseTransform

from chemicalchecker.util import logged
from chemicalchecker.tool.siamese import SiameseTriplets
from chemicalchecker.util.splitter import PrecomputedTripletSampler


[docs]@logged class MetricLearn(BaseTransform): """MetricLearn class.""" def __init__(self, sign1, tmp, name, max_keys): try: from tensorflow.keras.layers import Dense except ImportError: raise ImportError("requires tensorflow " + "https://www.tensorflow.org/") params = { 'epochs': 3, 'dropouts': [None, None], 'layers_sizes': None, 'learning_rate': "auto", 'batch_size': 128, 'activations': ["tanh", "tanh"], 'layers': [Dense, Dense], 'loss_func': 'orthogonal_tloss', 'margin': 1.0, 'alpha': 1.0 } if tmp: raise Exception("Metric learn is not prepared to work with tmp") BaseTransform.__init__(self, sign1, name, max_keys, tmp) if params["layers_sizes"] is None: input_size = self.sign_ref.shape[1] final_layer_size = sign1.info_h5["V_tmp"][1] params["layers_sizes"] = [int((input_size + final_layer_size) / 2), final_layer_size] self.params = params def _fit(self, triplets): params = self.params # Get signatures V = self.sign_ref[:] dest_h5 = os.path.join(self.model_path, self.name + "_eval.h5") # generate traintest file triplet_sampler = PrecomputedTripletSampler( None, self.sign_ref, dest_h5) triplet_sampler.generate_triplets( V, self.sign_ref.keys, triplets, mean_center_x=True, shuffle=True, split_names=['train_train', 'train_test', 'test_test'], split_fractions=[.8, .1, .1]) # train siamese network model_dir = os.path.join(self.model_path, self.name + "_eval") params["traintest_file"] = dest_h5 mod = SiameseTriplets(model_dir, evaluate=True, plot=False, **params) mod.fit() # generate traintest file for the final model dest_h5 = os.path.join(self.model_path, self.name + ".h5") triplet_sampler = PrecomputedTripletSampler( None, self.sign_ref, dest_h5) NeighborTripletTraintest.precomputed_triplets( V, triplets, dest_h5, mean_center_x=True, shuffle=True, split_names=['train_train'], split_fractions=[1.]) # train siamese network model_dir = os.path.join(self.model_path, self.name) params["traintest_file"] = dest_h5 params["learning_rate"] = mod.learning_rate mod = SiameseTriplets(model_dir, evaluate=False, plot=False, **params) mod.fit() # Save and predict self.model_dir = model_dir self._predict(self.sign_ref, find_scale=True) self._predict(self.sign, find_scale=False) self.save() def _predict(self, sign1, find_scale=False): mod = SiameseTriplets(self.model_dir) V = mod.predict(sign1[:]) if find_scale: p25 = np.percentile(V.ravel(), 25) p75 = np.percentile(V.ravel(), 75) with open(os.path.join(self.model_dir, "ml_percs.pkl"), "wb") as f: pickle.dump((p25, p75), f) else: with open(os.path.join(self.model_dir, "ml_percs.pkl"), "rb") as f: p25, p75 = pickle.load(f) m = 2. / (p75 - p25) n = 1 - m * p75 V = m * V + n self.overwrite(sign1=sign1, V=V, keys=sign1.keys)
[docs]@logged class UnsupervisedMetricLearn(MetricLearn): def __init__(self, sign1, tmp): MetricLearn.__init__(self, sign1, tmp, "unsupml", max_keys=None) def fit(self): self.sign_ref.neighbors(tmp=True) self.__log.debug("Now getting triplets") triplets = self.sign_ref.get_self_triplets(True) self.__log.debug("Doing unsupervised metric learning") self._fit(triplets) def predict(self, sign1): self._predict(sign1, find_scale=False)
[docs]@logged class SemiSupervisedMetricLearn(MetricLearn): def __init__(self, sign1, tmp): MetricLearn.__init__(self, sign1, tmp, "semiml", max_keys=None) def fit(self): self.__log.debug("Getting precalculated triplets throughout the CC") triplets = self.sign_ref.get_triplets(True) self._fit(triplets) def predict(self, sign1): self._predict(sign1, find_scale=False)