Source code for chemicalchecker.core.clus

"""Cluster Signature.

Performs K-means or HDBSCAN clustering of signatures.
"""
import os
import csv
import glob
import h5py
import json
import bisect
import shelve
import joblib
import tempfile
import datetime
import numpy as np
from csvsort import csvsort
from sklearn.preprocessing import Normalizer
from scipy.spatial.distance import euclidean, pdist

from .signature_base import BaseSignature
from .signature_data import DataSignature

from chemicalchecker.util import logged
from chemicalchecker.util import Config
from chemicalchecker.util.plot import Plot
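
# Note: the heavy optional dependencies (faiss and hdbscan) are imported
# lazily inside fit() and predict(), so this module can be imported even when
# they are not installed.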


@logged
class clus(BaseSignature, DataSignature):
    """Cluster Signature class."""

    def __init__(self, signature_path, dataset, **params):
        """Initialize a Signature.

        Args:
            signature_path(str): The path to the signature directory.
            dataset(object): The dataset object with all related information.
            type(str): The type of clustering, either 'kmeans' or 'hdbscan'
                (default: 'kmeans').
            metric(str): The metric used in the KNN algorithm, 'euclidean' or
                'cosine' (default: 'euclidean').
            k_neig(int): The number of k neighbours to search for
                (default: None).
            cpu(int): The number of cores to use (default: 1).
            min_members(int): Minimum number of points per cluster (hdbscan)
                (default: 5).
            num_subdim(int): Splitting of the PQ encoder (kmeans) (default: 8).
            min_k(int): Minimum number of clusters (kmeans) (default: 1).
            max_k(int): Maximum number of clusters (kmeans) (default: None).
            n_points(int): Number of points to calculate (kmeans)
                (default: 100).
            balance(float): If 1, all clusters are of equal size. Greater
                values are increasingly more imbalanced (kmeans)
                (default: None).
            significance(float): Distance significance cutoff (kmeans)
                (default: 0.05).
        """
        # Calling init on the base class to trigger file existence checks
        BaseSignature.__init__(self, signature_path, dataset, **params)
        self.__log.debug('signature path is: %s', signature_path)
        self.data_path = os.path.join(signature_path, "clus.h5")
        DataSignature.__init__(self, self.data_path)
        self.__log.debug('data_path: %s', self.data_path)
        self.clustencoder_file = "clustencoder.h5"
        self.clustcentroids_file = "clustcentroids.h5"
        self.clust_info_file = "clust_stats.json"
        self.clust_output = 'clust.h5'
        self.bg_pq_euclideans_file = "bg_pq_euclideans.h5"
        self.hdbscan_file = "hdbscan.pkl"
        self.type = "kmeans"
        self.cpu = 1
        self.k_neig = None
        self.min_members = 5
        self.num_subdim = 8
        self.min_k = 1
        self.max_k = None
        self.n_points = 100
        self.balance = None
        self.significance = 0.05
        self.metric = "euclidean"
        for param, value in params.items():
            self.__log.debug('parameter %s : %s', param, value)
        if "metric" in params:
            self.metric = params["metric"]
        if "cpu" in params:
            self.cpu = params["cpu"]
        if "k_neig" in params:
            self.k_neig = params["k_neig"]
        if "min_members" in params:
            self.min_members = params["min_members"]
        if "num_subdim" in params:
            self.num_subdim = params["num_subdim"]
        if "min_k" in params:
            self.min_k = params["min_k"]
        if "max_k" in params:
            self.max_k = params["max_k"]
        if "n_points" in params:
            self.n_points = params["n_points"]
        if "balance" in params:
            self.balance = params["balance"]
        if "significance" in params:
            self.significance = params["significance"]
        if "type" in params:
            self.type = params["type"]
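
    # A minimal usage sketch (hedged, not executed here): the directory and
    # dataset code below are hypothetical placeholders, shown only to clarify
    # how the parameters documented above are typically passed.
    #
    #     sig = clus('/path/to/signatures/B1.001/clus', 'B1.001',
    #                type='kmeans', metric='cosine', cpu=4)
    #     sig.fit()
    #
    # Only the keywords explicitly checked in __init__ override the defaults.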

    def fit(self, sign=None, validations=True):
        """Fit cluster model given a signature."""
        try:
            import faiss
        except ImportError:
            raise ImportError("requires faiss " +
                              "https://github.com/facebookresearch/faiss")
        try:
            import hdbscan
        except ImportError:
            raise ImportError("requires hdbscan " +
                              "https://hdbscan.readthedocs.io/en/latest/")
        BaseSignature.fit(self)
        plot = Plot(self.dataset, self.stats_path)
        mappings = None
        if sign is None:
            sign = self.get_sign(
                'sign' + self.cctype[-1]).get_molset("reference")
        if os.path.isfile(sign.data_path):
            self.data = sign.data.astype(np.float32)
            self.data_type = self.data.dtype
            self.keys = sign.keys
            mappings = sign.mappings
        else:
            raise Exception("The file " + sign.data_path + " does not exist")
        tmp_dir = tempfile.mkdtemp(
            prefix='clus_' + self.dataset + "_", dir=Config().PATH.CC_TMP)
        self.__log.debug("Temporary files saved in " + tmp_dir)
        if self.type == "hdbscan":
            self.__log.info("Calculating HDBSCAN clusters")
            clusterer = hdbscan.HDBSCAN(
                min_cluster_size=int(np.max([2, self.min_members])),
                prediction_data=True).fit(self.data)
            self.__log.info("Saving the model")
            joblib.dump(clusterer, os.path.join(
                self.model_path, self.hdbscan_file))
            self.__log.info("Predicting...")
            labels, strengths = hdbscan.approximate_predict(
                clusterer, self.data)
            # Save
            self.__log.info("Saving matrix...")
            with h5py.File(self.data_path, "w") as hf:
                hf.create_dataset("labels", data=labels)
                hf.create_dataset("keys", data=np.array(
                    self.keys, DataSignature.string_dtype()))
                hf.create_dataset("strengths", data=strengths)
            if validations:
                self.__log.info("Doing validations")
                if mappings is not None:
                    inchikey_mappings = dict(mappings)
                else:
                    inchikey_mappings = None
                inchikey_clust = shelve.open(
                    os.path.join(tmp_dir, "clus.dict"), "n")
                for i in range(len(self.keys)):
                    lab = labels[i]
                    if lab == -1:
                        continue
                    inchikey_clust[str(self.keys[i])] = lab
                odds_moa, pval_moa = plot.label_validation(
                    inchikey_clust, "clus", prefix="moa",
                    inchikey_mappings=inchikey_mappings)
                odds_atc, pval_atc = plot.label_validation(
                    inchikey_clust, "clus", prefix="atc",
                    inchikey_mappings=inchikey_mappings)
                inchikey_clust.close()
            self.__log.info("Cleaning")
            for filename in glob.glob(os.path.join(tmp_dir, "clus.dict*")):
                os.remove(filename)
            os.rmdir(tmp_dir)
            self.metric = "hdbscan"
        if self.type == "kmeans":
            faiss.omp_set_num_threads(self.cpu)
            with h5py.File(sign.data_path, 'r') as dh5:
                if "elbow" not in dh5.keys():
                    Vn, Vm = self.data.shape[0], self.data.shape[1] / 2
                else:
                    Vn, Vm = self.data.shape[0], dh5["elbow"][0]
            if self.metric == "cosine":
                self.data = self._normalizer(self.data, False)
            if self.data.shape[1] < self.num_subdim:
                self.data = np.hstack(
                    (self.data,
                     np.zeros((self.data.shape[0],
                               self.num_subdim - self.data.shape[1]))))
            self.data = self.data.astype(np.float32)
            self.__log.info("Calculating k...")
            # Do reference distributions for the gap statistic
            if not self.max_k:
                self.max_k = int(np.sqrt(self.data.shape[0]))
            if self.k_neig is None:
                cluster_range = np.arange(
                    self.min_k, self.max_k,
                    step=np.max(
                        [int((self.max_k - self.min_k) / self.n_points), 1]))
                inertias = []
                disps = []
                bg_distances = sign.background_distances(self.metric)
                pvals = bg_distances["pvalue"]
                distance = bg_distances["distance"]
                sig_dist = distance[
                    bisect.bisect_left(pvals, self.significance)]
                for k in cluster_range:
                    niter = 20
                    d = self.data.shape[1]
                    kmeans = faiss.Kmeans(int(d), int(k), niter=niter)
                    kmeans.train(self.data)
                    D, labels = kmeans.index.search(self.data, 1)
                    inertias += [self._inertia(self.data, labels,
                                               kmeans.centroids)]
                    disps += [self._dispersion(kmeans.centroids, sig_dist,
                                               self.metric)]
                disps[0] = disps[1]
                # Smoothing, monotonizing, and combining the scores
                Ncs = np.arange(self.min_k, self.max_k)
                D = self._minmaxscaler(np.interp(
                    Ncs, cluster_range,
                    self._smooth(self._monotonize(np.array(disps), True),
                                 self.max_k)))
                I = self._minmaxscaler(np.interp(
                    Ncs, cluster_range,
                    self._smooth(self._monotonize(np.array(inertias), False),
                                 self.max_k)))
                alpha = Vm / (Vm + np.sqrt(Vn / 2.))
                S = np.abs((I**(1 - alpha)) - (D**(alpha)))
                S = self._minmaxscaler(-self._smooth(S, self.max_k))
                k = plot.clustering_plot(Ncs, I, D, S)
            else:
                k = self.k_neig
            self.__log.info("Clustering with k = %d" % k)
            niter = 20
            d = self.data.shape[1]
            kmeans = faiss.Kmeans(int(d), int(k), niter=niter)
            kmeans.train(self.data)
            D, labels = kmeans.index.search(self.data, 1)
            centroids = kmeans.centroids
            with h5py.File(os.path.join(self.model_path,
                                        self.clustcentroids_file), "w") as hf:
                hf.create_dataset("centroids", data=centroids)
            self.__log.info("Balancing...")
            labels = self._get_balance(self.data, centroids, labels,
                                       self.balance, k, tmp_dir)
            self.__log.info("Saving matrix...")
            with h5py.File(self.data_path, "w") as hf:
                hf.create_dataset("labels", data=labels)
                hf.create_dataset("keys", data=np.array(
                    self.keys, DataSignature.string_dtype()))
                hf.create_dataset("V", data=self.data)
            if validations:
                # MOA validation
                self.__log.info("Doing validations")
                if mappings is not None:
                    inchikey_mappings = dict(mappings)
                else:
                    inchikey_mappings = None
                inchikey_clust = shelve.open(
                    os.path.join(tmp_dir, "clus.dict"), "n")
                for i in range(len(self.keys)):
                    inchikey_clust[str(self.keys[i])] = labels[i]
                odds_moa, pval_moa = plot.label_validation(
                    inchikey_clust, "clus", prefix="moa",
                    inchikey_mappings=inchikey_mappings)
                odds_atc, pval_atc = plot.label_validation(
                    inchikey_clust, "clus", prefix="atc",
                    inchikey_mappings=inchikey_mappings)
                inchikey_clust.close()
            self.__log.info("Cleaning")
            for filename in glob.glob(os.path.join(tmp_dir, "clus.dict*")):
                os.remove(filename)
            os.rmdir(tmp_dir)
            faiss.write_index(kmeans.index, os.path.join(
                self.model_path, "kmeans.index"))
            if self.k_neig is None and validations:
                self.__log.info("Saving info")
                INFO = {
                    "k": int(k),
                    "odds_moa": odds_moa,
                    "pval_moa": pval_moa,
                    "odds_atc": odds_atc,
                    "pval_atc": pval_atc
                }
                with open(os.path.join(self.model_path,
                                       self.clust_info_file), 'w') as fp:
                    json.dump(INFO, fp)
        with h5py.File(self.data_path, "a") as hf:
            name = str(self.dataset) + "_clus"
            hf.create_dataset(
                "name",
                data=[name.encode(encoding='UTF-8', errors='strict')])
            hf.create_dataset(
                "date",
                data=[datetime.datetime.now().strftime(
                    "%Y-%m-%d %H:%M:%S").encode(encoding='UTF-8',
                                                errors='strict')])
            hf.create_dataset("metric", data=[self.metric.encode(
                encoding='UTF-8', errors='strict')])
            hf.create_dataset("normed", data=[False])
            hf.create_dataset("integerized", data=[False])
            hf.create_dataset("principal_components", data=[False])
            if mappings is not None:
                hf.create_dataset("mappings", data=np.array(
                    mappings, DataSignature.string_dtype()))
        # also predict for full if available
        sign_full = self.get_sign('sign' + self.cctype[-1]).get_molset("full")
        if os.path.isfile(sign_full.data_path):
            self.predict(sign_full, self.get_molset("full").data_path)
        self.mark_ready()
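
    # After a successful fit() the following artifacts exist, as written by
    # the code above (paths relative to model_path / signature_path):
    #   - kmeans: "kmeans.index" (faiss index), clustcentroids_file with the
    #     centroids, clust_info_file with k and the MOA/ATC validation stats
    #     (only when k was selected automatically and validations ran), and
    #     clus.h5 with "labels", "keys" and "V".
    #   - hdbscan: hdbscan_file (the pickled clusterer) and clus.h5 with
    #     "labels", "keys" and "strengths".
    # Both variants also append "name", "date", "metric" and a few boolean
    # flags to clus.h5.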

    def predict(self, sign, destination=None, validations=False):
        """Use the fitted models to go from input to output."""
        try:
            import faiss
        except ImportError:
            raise ImportError("requires faiss " +
                              "https://github.com/facebookresearch/faiss")
        try:
            import hdbscan
        except ImportError:
            raise ImportError("requires hdbscan " +
                              "https://hdbscan.readthedocs.io/en/latest/")
        plot = Plot(self.dataset, self.stats_path)
        mappings = None
        if os.path.isfile(sign.data_path):
            self.data = sign.data.astype(np.float32)
            self.data_type = self.data.dtype
            self.keys = sign.keys
            mappings = sign.mappings
        else:
            raise Exception("The file " + sign.data_path + " does not exist")
        if destination is None:
            raise Exception(
                "Predict method requires a destination file to output results")
        tmp_dir = tempfile.mkdtemp(
            prefix='sign_' + self.dataset + "_", dir=Config().PATH.CC_TMP)
        self.__log.debug("Temporary files saved in " + tmp_dir)
        if self.type == "hdbscan":
            self.__log.info("Reading HDBSCAN clusters")
            clusterer = joblib.load(os.path.join(
                self.model_path, self.hdbscan_file))
            self.__log.info("Predicting...")
            labels, strengths = hdbscan.approximate_predict(
                clusterer, self.data)
            # Save
            self.__log.info("Saving matrix...")
            with h5py.File(destination, "w") as hf:
                hf.create_dataset("labels", data=labels)
                hf.create_dataset("keys", data=np.array(
                    self.keys, DataSignature.string_dtype()))
                hf.create_dataset("strengths", data=strengths)
            if validations:
                self.__log.info("Doing validations")
                if mappings is not None:
                    inchikey_mappings = dict(mappings)
                else:
                    inchikey_mappings = None
                inchikey_clust = shelve.open(
                    os.path.join(tmp_dir, "clus.dict"), "n")
                for i in range(len(self.keys)):
                    lab = labels[i]
                    if lab == -1:
                        continue
                    inchikey_clust[str(self.keys[i])] = lab
                odds_moa, pval_moa = plot.label_validation(
                    inchikey_clust, "clus", prefix="moa",
                    inchikey_mappings=inchikey_mappings)
                odds_atc, pval_atc = plot.label_validation(
                    inchikey_clust, "clus", prefix="atc",
                    inchikey_mappings=inchikey_mappings)
                inchikey_clust.close()
            self.__log.info("Cleaning")
            for filename in glob.glob(os.path.join(tmp_dir, "clus.dict*")):
                os.remove(filename)
            os.rmdir(tmp_dir)
            with h5py.File(destination, "a") as hf:
                name = str(self.dataset) + "_clus"
                hf.create_dataset(
                    "name",
                    data=[name.encode(encoding='UTF-8', errors='strict')])
                hf.create_dataset(
                    "date",
                    data=[datetime.datetime.now().strftime(
                        "%Y-%m-%d %H:%M:%S").encode(encoding='UTF-8',
                                                    errors='strict')])
                hf.create_dataset("metric", data=["hdbscan"])
                hf.create_dataset("normed", data=[False])
                hf.create_dataset("integerized", data=[False])
                hf.create_dataset("principal_components", data=[False])
                if mappings is not None:
                    hf.create_dataset("mappings", data=np.array(
                        mappings, DataSignature.string_dtype()))
        if self.type == "kmeans":
            faiss.omp_set_num_threads(self.cpu)
            if not os.path.isfile(os.path.join(self.model_path,
                                               "kmeans.index")):
                raise Exception(
                    "There is no cluster info. Please run the fit method.")
            if self.metric == "cosine":
                self.data = self._normalizer(self.data, True)
            if self.data.shape[1] < self.num_subdim:
                self.data = np.hstack(
                    (self.data,
                     np.zeros((self.data.shape[0],
                               self.num_subdim - self.data.shape[1]))))
            self.data = self.data.astype(np.float32)
            index = faiss.read_index(os.path.join(
                self.model_path, "kmeans.index"))
            D, labels = index.search(self.data, 1)
            self.__log.info("Saving matrix...")
            with h5py.File(destination, "w") as hf:
                hf.create_dataset("labels", data=labels)
                hf.create_dataset("keys", data=np.array(
                    self.keys, DataSignature.string_dtype()))
                hf.create_dataset("V", data=self.data)
            if validations:
                # MOA validation
                self.__log.info("Doing validations")
                if mappings is not None:
                    inchikey_mappings = dict(mappings)
                else:
                    inchikey_mappings = None
                inchikey_clust = shelve.open(
                    os.path.join(tmp_dir, "clus.dict"), "n")
                for i in range(len(self.keys)):
                    inchikey_clust[str(self.keys[i])] = labels[i]
                odds_moa, pval_moa = plot.label_validation(
                    inchikey_clust, "clus", prefix="moa",
                    inchikey_mappings=inchikey_mappings)
                odds_atc, pval_atc = plot.label_validation(
                    inchikey_clust, "clus", prefix="atc",
                    inchikey_mappings=inchikey_mappings)
                inchikey_clust.close()
            self.__log.info("Cleaning")
            for filename in glob.glob(os.path.join(tmp_dir, "clus.dict*")):
                os.remove(filename)
            os.rmdir(tmp_dir)
            with h5py.File(destination, "a") as hf:
                name = str(self.dataset) + "_clus"
                hf.create_dataset(
                    "name",
                    data=[name.encode(encoding='UTF-8', errors='strict')])
                hf.create_dataset(
                    "date",
                    data=[datetime.datetime.now().strftime(
                        "%Y-%m-%d %H:%M:%S").encode(encoding='UTF-8',
                                                    errors='strict')])
                hf.create_dataset("metric", data=[self.metric.encode(
                    encoding='UTF-8', errors='strict')])
                hf.create_dataset("normed", data=[False])
                hf.create_dataset("integerized", data=[False])
                hf.create_dataset("principal_components", data=[False])
                if mappings is not None:
                    hf.create_dataset("mappings", data=np.array(
                        mappings, DataSignature.string_dtype()))
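
    # Usage sketch for predict() (hedged; the input signature and destination
    # path below are hypothetical). A fitted model must already be present in
    # model_path, and an explicit destination H5 file is mandatory:
    #
    #     sign_full = sig.get_sign('sign1').get_molset("full")
    #     sig.predict(sign_full, destination='/tmp/clus_full.h5')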

    def _smooth(self, x, max_k, window_len=None, window='hanning'):
        if window_len is None:
            window_len = int(max_k / 10) + 1
            if window_len % 2 == 0:
                window_len += 1
        if x.size <= window_len:
            self.__log.warning(
                "Input vector was smaller than or equal to the window size.")
            window_len = x.size - 1
            if window_len % 2 == 0:
                window_len += 1
        if window_len < 3:
            return x
        if window not in ['flat', 'hanning', 'hamming', 'bartlett',
                          'blackman']:
            raise ValueError(
                "Window must be one of 'flat', 'hanning', 'hamming', "
                "'bartlett', 'blackman'")
        s = np.r_[x[window_len - 1:0:-1], x, x[-2:-window_len - 1:-1]]
        if window == 'flat':  # moving average
            w = np.ones(window_len, 'd')
        else:
            w = eval('np.' + window + '(window_len)')
        n = int((window_len - 1) / 2)
        y = np.convolve(w / w.sum(), s, mode='valid')
        return y[n:-n]

    def _inertia(self, V_pqcode, labels, centroids):
        ines = 0
        for i in range(V_pqcode.shape[0]):
            ines += euclidean(V_pqcode[i], centroids[labels[i]])
        return ines

    def _dispersion(self, centroids, sig_dist, metric):
        if len(centroids) == 1:
            return None
        return np.sum(pdist(centroids, metric=metric) < sig_dist)

    def _monotonize(self, v, up=True):
        if up:
            return np.mean(np.array([np.maximum.accumulate(v),
                                     np.minimum.accumulate(v[::-1])[::-1]]),
                           axis=0)
        else:
            return np.mean(np.array([np.minimum.accumulate(v),
                                     np.maximum.accumulate(v[::-1])[::-1]]),
                           axis=0)

    def _minmaxscaler(self, v):
        v = np.array(v)
        Min = np.min(v)
        Max = np.max(v)
        return (v - Min) / (Max - Min)

    def _get_balance(self, V_pqcode, centroids, labels, balance, k, tmp):
        if balance is None:
            return labels
        if balance < 1:
            self.__log.info(
                "Balance is smaller than 1, which is not supported; "
                "skipping balancing.")
            return labels
        S = np.ceil((V_pqcode.shape[0] / k) * balance)
        clusts = [None] * V_pqcode.shape[0]
        counts = [0] * k
        tmpfile = os.path.join(tmp, "clus_dists.csv")
        with open(tmpfile, "w") as f:
            for i, v in enumerate(V_pqcode):
                for j, c in enumerate(centroids):
                    d = euclidean(c, v)
                    f.write("%d,%d,%010d\n" % (i, j, d))
        csvsort(tmpfile, [2], has_header=False)
        with open(tmpfile, "r") as f:
            for r in csv.reader(f):
                item_id = int(r[0])
                cluster_id = int(r[1])
                if counts[cluster_id] >= S:
                    continue
                if clusts[item_id] is None:
                    clusts[item_id] = cluster_id
                    counts[cluster_id] += 1
        os.remove(tmpfile)
        return clusts

    def _normalizer(self, V, recycle):
        FILE = self.model_path + "/normalizer.pkl"
        if not recycle or not os.path.exists(FILE):
            nlz = Normalizer(copy=True, norm="l2")
            V = nlz.fit_transform(V)
            joblib.dump(nlz, FILE)
        else:
            nlz = joblib.load(FILE)
            V = nlz.transform(V)
        return V.astype(np.float32)
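

# ----------------------------------------------------------------------------
# Illustrative, standalone sketch (not part of the class): how fit() combines
# the inertia and dispersion curves into a single k-selection score. The
# synthetic curves and the Vn/Vm values are assumptions for demonstration
# only, the smoothing step is omitted for brevity, and in fit() the actual
# selection is delegated to Plot.clustering_plot.
# ----------------------------------------------------------------------------
if __name__ == "__main__":

    def _demo_monotonize(v, up=True):
        # Average of forward/backward cumulative extrema, mirroring
        # clus._monotonize.
        if up:
            return np.mean([np.maximum.accumulate(v),
                            np.minimum.accumulate(v[::-1])[::-1]], axis=0)
        return np.mean([np.minimum.accumulate(v),
                        np.maximum.accumulate(v[::-1])[::-1]], axis=0)

    def _demo_minmax(v):
        # Rescale to [0, 1], mirroring clus._minmaxscaler.
        v = np.asarray(v, dtype=float)
        return (v - v.min()) / (v.max() - v.min())

    ks = np.arange(1, 51)
    rng = np.random.RandomState(0)
    # Synthetic behaviour: inertia decreases with k, dispersion increases.
    inertias = 1000. / ks + rng.rand(ks.size)
    disps = np.minimum(0.5 * ks, 20.) + rng.rand(ks.size)

    D = _demo_minmax(_demo_monotonize(disps, up=True))
    I = _demo_minmax(_demo_monotonize(inertias, up=False))

    # alpha weighs dimensionality (Vm) against sample size (Vn), as in fit().
    Vn, Vm = 10000, 64
    alpha = Vm / (Vm + np.sqrt(Vn / 2.))
    S = np.abs(I ** (1 - alpha) - D ** alpha)
    # fit() negates and rescales S before plotting; the lowest raw S is one
    # plausible choice of k.
    print("k with the lowest combined score:", ks[np.argmin(S)])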