Source code for chemicalchecker.core.neig

"""Nearest Neighbor Signature.

Identify nearest neighbors and distances.
"""
import os
import h5py
import datetime
import numpy as np
from tqdm import tqdm
from numpy import linalg as LA
from bisect import bisect_left
from scipy.spatial.distance import euclidean, cosine  # resolved via eval() in check_distances

from .signature_base import BaseSignature
from .signature_data import DataSignature


from chemicalchecker.util import logged
from chemicalchecker.util.decorator import cached_property


@logged
class neig(BaseSignature, DataSignature):
    """Neighbors Signature class."""

    def __init__(self, signature_path, dataset, **params):
        """Initialize a Neighbors Signature.

        Args:
            signature_path(str): The path to the signature directory.
            metric(str): The metric used in the KNN algorithm, either
                'euclidean' or 'cosine' (default: 'cosine').
            k_neig(int): The number of neighbors to search for
                (default: 1000).
            cpu(int): The number of cores to use (default: 1).
            chunk(int): The size of the chunks in which data is read
                (default: 1000).
        """
        # Calling init on the base class to trigger file existence checks
        BaseSignature.__init__(self, signature_path, dataset, **params)
        self.__log.debug('signature path is: %s', signature_path)
        self.data_path = os.path.join(signature_path, "neig.h5")
        self.__log.debug('data_path: %s', self.data_path)
        DataSignature.__init__(self, self.data_path, ds_data='distances',
                               keys_name='row_keys')
        # defaults, possibly overridden by `params` below
        self.metric = "cosine"
        self.cpu = 1
        self.chunk = 1000
        self.k_neig = 1000
        self.norms_file = os.path.join(self.model_path, "norms.h5")
        self.index_filename = os.path.join(self.model_path,
                                           'faiss_neig.index')
        for param, value in params.items():
            self.__log.debug('parameter %s : %s', param, value)
        if "metric" in params:
            self.metric = params["metric"]
        if "cpu" in params:
            self.cpu = params["cpu"]
        if "k_neig" in params:
            self.k_neig = params["k_neig"]
        if "chunk" in params:
            self.chunk = params["chunk"]
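
    # Usage sketch (hedged; the path and dataset code are hypothetical,
    # and in practice instances are typically obtained via a
    # ChemicalChecker object rather than constructed directly):
    #
    #     n1 = neig('/path/to/cc/B1.001/neig1/reference', 'B1.001',
    #               metric='cosine', k_neig=500, cpu=4)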

    def fit(self, sign=None):
        """Fit the neighbors model given a signature."""
        try:
            import faiss
        except ImportError:
            raise ImportError("requires faiss "
                              "https://github.com/facebookresearch/faiss")
        # signature specific checks: always fit on the 'reference' molset
        if self.molset != "reference":
            self.__log.debug("Fit will be done with the reference neig1")
            self = self.get_molset("reference")
        if sign is None:
            sign = self.get_sign(
                'sign' + self.cctype[-1]).get_molset("reference")
        if sign.molset != "reference":
            self.__log.debug("Fit will be done with the reference sign")
            sign = self.get_sign(
                'sign' + self.cctype[-1]).get_molset("reference")
        if not sign.is_fit():
            raise Exception("sign is not fitted.")
        faiss.omp_set_num_threads(self.cpu)
        if os.path.isfile(sign.data_path):
            with h5py.File(sign.data_path, 'r') as dh5, \
                    h5py.File(self.data_path, 'w') as dh5out:
                if "keys" not in dh5.keys() or "V" not in dh5.keys():
                    raise Exception(
                        "H5 file " + sign.data_path +
                        " does not contain datasets 'keys' and 'V'")
                self.datasize = dh5["V"].shape
                self.data_type = dh5["V"].dtype
                k = min(self.datasize[0], self.k_neig)
                # rows and columns refer to the same molecules at fit time
                dh5out.create_dataset("row_keys", data=dh5["keys"][:])
                dh5out["col_keys"] = h5py.SoftLink('/row_keys')
                dh5out.create_dataset(
                    "indices", (self.datasize[0], k), dtype=np.int32)
                dh5out.create_dataset(
                    "distances", (self.datasize[0], k), dtype=np.float32)
                dh5out.create_dataset("shape", data=(self.datasize[0], k))
                dh5out.create_dataset(
                    "date",
                    data=[datetime.datetime.now().strftime(
                        "%Y-%m-%d %H:%M:%S").encode(
                            encoding='UTF-8', errors='strict')])
                dh5out.create_dataset(
                    "metric",
                    data=[self.metric.encode(
                        encoding='UTF-8', errors='strict')])
                # cosine similarity is the inner product of L2-normalized
                # vectors, hence IndexFlatIP; euclidean uses IndexFlatL2
                if self.metric == "euclidean":
                    index = faiss.IndexFlatL2(self.datasize[1])
                else:
                    index = faiss.IndexFlatIP(self.datasize[1])
                for chunk in sign.chunker():
                    data_temp = np.array(dh5["V"][chunk], dtype=np.float32)
                    if self.metric == "cosine":
                        normst = LA.norm(data_temp, axis=1)
                        index.add(data_temp / normst[:, None])
                    else:
                        index.add(data_temp)
                for chunk in sign.chunker():
                    data_temp = np.array(dh5["V"][chunk], dtype=np.float32)
                    if self.metric == "cosine":
                        normst = LA.norm(data_temp, axis=1)
                        Dt, It = index.search(data_temp / normst[:, None], k)
                    else:
                        Dt, It = index.search(data_temp, k)
                    dh5out["indices"][chunk] = It
                    if self.metric == "cosine":
                        # inner products on normalized vectors become
                        # cosine distances, clipped at zero to absorb
                        # numerical noise
                        dh5out["distances"][chunk] = np.maximum(0.0, 1.0 - Dt)
                    else:
                        dh5out["distances"][chunk] = Dt
        else:
            raise Exception("The file " + sign.data_path + " does not exist")
        faiss.write_index(index, self.index_filename)
        # also predict for full if available
        sign_full = self.get_sign('sign' + self.cctype[-1]).get_molset("full")
        if os.path.isfile(sign_full.data_path):
            self.predict(sign_full, self.get_molset("full").data_path)
        self.mark_ready()
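
    # Usage sketch (hedged): with `n1` as above and the corresponding
    # sign already fitted, a bare call builds the faiss index and the
    # neig.h5 file; the input signature is resolved automatically.
    #
    #     n1.fit()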

    def predict(self, sign, destination=None, validations=False):
        """Use the fitted index to find neighbors for a given signature."""
        try:
            import faiss
        except ImportError:
            raise ImportError("requires faiss "
                              "https://github.com/facebookresearch/faiss")
        if destination is None:
            raise Exception("There is no destination file specified")
        faiss.omp_set_num_threads(self.cpu)
        if os.path.isfile(sign.data_path):
            with h5py.File(sign.data_path, 'r') as dh5, \
                    h5py.File(destination, 'w') as dh5out:
                if "keys" not in dh5.keys() or "V" not in dh5.keys():
                    raise Exception(
                        "H5 file " + sign.data_path +
                        " does not contain datasets 'keys' and 'V'")
                self.datasize = dh5["V"].shape
                self.data_type = dh5["V"].dtype
                index = faiss.read_index(self.index_filename)
                k = min(self.k_neig, index.ntotal)
                dh5out.create_dataset("row_keys", data=dh5["keys"][:])
                # columns are the molecules that were indexed at fit time
                with h5py.File(self.data_path, 'r') as hr5:
                    dh5out.create_dataset("col_keys",
                                          data=hr5["row_keys"][:])
                dh5out.create_dataset(
                    "indices", (self.datasize[0], k), dtype=np.int32)
                dh5out.create_dataset(
                    "distances", (self.datasize[0], k), dtype=np.float32)
                dh5out.create_dataset("shape", data=(self.datasize[0], k))
                dh5out.create_dataset(
                    "date",
                    data=[datetime.datetime.now().strftime(
                        "%Y-%m-%d %H:%M:%S").encode(
                            encoding='UTF-8', errors='strict')])
                dh5out.create_dataset(
                    "metric",
                    data=[self.metric.encode(
                        encoding='UTF-8', errors='strict')])
                for chunk in sign.chunker():
                    data_temp = np.array(dh5["V"][chunk], dtype=np.float32)
                    if self.metric == "cosine":
                        normst = LA.norm(data_temp, axis=1)
                        Dt, It = index.search(data_temp / normst[:, None], k)
                    else:
                        Dt, It = index.search(data_temp, k)
                    dh5out["indices"][chunk] = It
                    if self.metric == "cosine":
                        dh5out["distances"][chunk] = np.maximum(0.0, 1.0 - Dt)
                    else:
                        dh5out["distances"][chunk] = Dt
        else:
            raise Exception("The file " + sign.data_path + " does not exist")
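
    # Usage sketch (hedged): `sign_new` stands for a hypothetical fitted
    # signature whose molecules we want neighbors for, searched against
    # the reference index built at fit time.
    #
    #     n1.predict(sign_new, destination='/tmp/neig_new.h5')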

    @cached_property
    def keys(self):
        """Get the list of keys (usually InChIKeys) in the signature."""
        self._check_data()
        self._check_dataset('row_keys')
        return self._get_all('row_keys')

    @cached_property
    def unique_keys(self):
        """Get the keys of the signature as a set."""
        return set(self.keys)

    def __getitem__(self, key):
        """Return the neighbors corresponding to the key.

        The key can be a string (then it is mapped through `self.keys`),
        an int, or a slice. String lookup is fast thanks to `bisect`, and
        membership is first checked against the `unique_keys` set.

        Returns:
            dict with keys:
                1. 'indices': the indices of the neighbors,
                2. 'keys': the InChIKeys of the neighbors,
                3. 'distances': the distances to the neighbors.
        """
        predictions = dict()
        if not os.path.isfile(self.data_path):
            raise Exception("Data file not available.")
        if isinstance(key, str):
            if key not in self.unique_keys:
                raise Exception("Key '%s' not found." % key)
            # map the string key to its row position
            key = bisect_left(self.keys, key)
        if isinstance(key, (slice, int)):
            with h5py.File(self.data_path, 'r') as hf:
                predictions["indices"] = hf['indices'][key]
                predictions["distances"] = hf['distances'][key]
                keys = hf['col_keys'][:].astype(str)
                predictions["keys"] = keys[predictions["indices"]]
        else:
            raise Exception("Key type %s not recognized." % type(key))
        return predictions
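
    # Usage sketch (hedged; the InChIKey is illustrative and may not be
    # present in a given dataset):
    #
    #     nn = n1['RZVAJINKPMORJF-UHFFFAOYSA-N']
    #     nn['keys'][:5], nn['distances'][:5]
    #     nn0 = n1[0]     # same dict structure, accessed by row position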

    def get_vectors(self, keys, include_nan=False, dataset_name='indices'):
        """Get vectors for a list of keys, sorted by default.

        Args:
            keys(list): A list of strings; only the subset overlapping
                with the signature keys is considered.
            include_nan(bool): Whether to include requested but absent
                molecule signatures as NaNs.
            dataset_name(str): Return any dataset in the H5 that is
                organized by sorted keys.
        """
        self.__log.debug("Fetching %s Neig rows from dataset %s"
                         % (len(keys), dataset_name))
        valid_keys = list(set(self.row_keys) & set(keys))
        idxs = np.argwhere(
            np.isin(list(self.row_keys), list(valid_keys),
                    assume_unique=True))
        inks, signs = list(), list()
        with h5py.File(self.data_path, 'r') as hf:
            dset = hf[dataset_name]
            col_keys = hf['col_keys'][:].astype(str)
            dset_shape = dset.shape
            for idx in sorted(idxs.flatten()):
                inks.append(self.row_keys[idx].astype(str))
                if dataset_name == 'indices':
                    # map neighbor indices back to their keys
                    signs.append(col_keys[dset[idx]])
                else:
                    signs.append(dset[idx])
        missed_inks = set(keys) - set(inks)
        # if missing signatures are requested, add NaNs
        if include_nan:
            inks.extend(list(missed_inks))
            dimensions = (len(missed_inks), dset_shape[1])
            nan_matrix = np.zeros(dimensions) * np.nan
            signs.append(nan_matrix)
            if missed_inks:
                self.__log.info(
                    "NaN returned for %s requested keys that are not "
                    "available.", len(missed_inks))
        elif missed_inks:
            self.__log.warn("The following %s requested keys are not "
                            "available:", len(missed_inks))
            self.__log.warn(" ".join(list(missed_inks)[:10]) + "...")
        if len(inks) == 0:
            self.__log.warn("No requested keys available!")
            return None, None
        inks, signs = np.stack(inks), np.vstack(signs)
        sort_idx = np.argsort(inks)
        return inks[sort_idx], signs[sort_idx]
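
    # Usage sketch (hedged; `my_keys` is a hypothetical list of
    # InChIKeys): absent keys come back as NaN rows when requested.
    #
    #     inks, neighbors = n1.get_vectors(my_keys, include_nan=True)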

    def get_kth_nearest(self, signatures, k=1000, distances=True, keys=True):
        """Return up to the k-th nearest neighbor.

        A k > 1 is useful when we expect, and want to exclude, a perfect
        match, i.e. when the signatures we query for are the same that
        have been used to generate the neighbors.

        Args:
            signatures(array): Matrix or list of signatures for which we
                want to find neighbors.
            k(int): Number of neighbors to find; if None, return the
                maximum possible.

        Returns:
            dict with keys:
                1. 'indices': the indices of the neighbors,
                2. 'keys': the InChIKeys of the neighbors,
                3. 'distances': the distances to the neighbors.
        """
        try:
            import faiss
        except ImportError:
            raise ImportError("requires faiss "
                              "https://github.com/facebookresearch/faiss")
        self.__log.info("Reading index file")
        with h5py.File(self.data_path, "r") as hw:
            metric_orig = hw["metric"][0]
            if not isinstance(metric_orig, str):
                metric_orig = metric_orig.decode()
        # open faiss model
        faiss.omp_set_num_threads(self.cpu)
        index = faiss.read_index(self.index_filename)
        # decide k
        max_k = index.ntotal
        if k is None:
            k = max_k
        if k > max_k:
            self.__log.warning("Maximum k is %s.", max_k)
            k = max_k
        # convert signatures to float32 as faiss is very picky
        data = np.array(signatures, dtype=np.float32)
        self.__log.info("Searching %s neighbors" % k)
        # get neighbors idx and distances
        if "cosine" in metric_orig:
            normst = LA.norm(data, axis=1)
            dists, idx = index.search(data / normst[:, None], k)
        else:
            dists, idx = index.search(data, k)
        predictions = dict()
        predictions["indices"] = idx
        if keys:
            with h5py.File(self.data_path, 'r') as hf:
                keys = hf['col_keys'][:]
            predictions["keys"] = keys[idx].astype(str)
        if distances:
            predictions["distances"] = dists
            if metric_orig == "cosine":
                # inner products on normalized vectors -> cosine distances
                predictions["distances"] = np.maximum(
                    0.0, 1.0 - predictions["distances"])
        return predictions
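
    # Usage sketch (hedged; `query_matrix` is a hypothetical array of
    # signatures that were part of the fitted data, so each row matches
    # itself at distance ~0 and is dropped):
    #
    #     res = n1.get_kth_nearest(query_matrix, k=6)
    #     others = res['keys'][:, 1:]     # discard the self-match column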

    def check_distances(self, n_sign=5, n_neig=10):
        """Verify stored distances against directly computed ones."""
        sign = self.get_sign('sign' + self.cctype[-1])
        # resolve the metric name to the scipy distance function
        dist_fn = eval(self.metric)
        for ink1 in tqdm(sign.keys[:n_sign], desc='Checking distances'):
            s1 = sign[ink1]
            nn = self[ink1]
            inks = nn['keys'][:n_neig]
            dists = nn['distances'][:n_neig]
            for ink2, dist in zip(inks, dists):
                # keys may come back as bytes depending on the H5 dtype
                if isinstance(ink2, bytes):
                    ink2 = ink2.decode()
                s2 = sign[ink2]
                comp_d = dist_fn(s1, s2)
                if not np.allclose(dist, comp_d, atol=1e-05):
                    self.__log.error('%s %s %.6f %.6f',
                                     ink1, ink2, dist, comp_d)
                np.testing.assert_allclose(dist, comp_d, atol=1e-05)

    @staticmethod
    def jaccard_similarity(n1, n2):
        """Compute the Jaccard similarity between two sets of neighbors.

        Args:
            n1(np.array): First set of neighbors; each row is a molecule
                and each column the idx of a neighbor.
            n2(np.array): Second set of neighbors; each row is a molecule
                and each column the idx of a neighbor.
        """
        res = list()
        for r1, r2 in zip(n1, n2):
            s1 = set(r1)
            s2 = set(r2)
            inter = len(set.intersection(s1, s2))
            uni = len(set.union(s1, s2))
            res.append(inter / float(uni))
        return np.array(res)
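
    # Worked example: the first rows share 2 of 4 distinct indices, the
    # second rows are identical.
    #
    #     neig.jaccard_similarity(np.array([[1, 2, 3], [7, 8, 9]]),
    #                             np.array([[2, 3, 4], [7, 8, 9]]))
    #     # -> array([0.5, 1. ])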

    def __iter__(self):
        """Iterate on neighbor indices and distances."""
        if not os.path.isfile(self.data_path):
            raise Exception("Data file %s not available." % self.data_path)
        with h5py.File(self.data_path, 'r') as hf:
            for i in range(self.shape[0]):
                yield hf['indices'][i], hf['distances'][i]

    @property
    def shape(self):
        """Get the shape of the neighbors matrix."""
        if not os.path.isfile(self.data_path):
            raise Exception("Data file %s not available." % self.data_path)
        with h5py.File(self.data_path, 'r') as hf:
            if 'shape' not in hf.keys():
                self.__log.warn("HDF5 file has no 'shape' dataset.")
                return hf['distances'].shape
            return hf['shape'][:]