"""Signature type 0.
A sufficiently-processed version of the raw data. Each bioactive space has
a peculiar format which might be categorical, discrete or continuous.
They usually show explicit knowledge, which enables connectivity and
interpretation.
"""
import os
import h5py
import datetime
import collections
import numpy as np
from .signature_data import DataSignature
from .signature_base import BaseSignature
from .preprocess import Preprocess
from chemicalchecker.util import logged
from chemicalchecker.util.sanitize import Sanitizer
from chemicalchecker.util.aggregate import Aggregate
from chemicalchecker.util.decorator import cached_property
from chemicalchecker.util.sampler.triplets import TripletSampler
@logged
class sign0(BaseSignature, DataSignature):
"""Signature type 0 class."""
def __init__(self, signature_path, dataset, **params):
"""Initialize a Signature.
Args:
signature_path (str): the path to the signature directory.
            dataset (str): dataset code (e.g. A1.001); here it only serves
                as the 'name' record of the H5 file.
"""
BaseSignature.__init__(self, signature_path, dataset, **params)
self.data_path = os.path.join(self.signature_path, "sign0.h5")
DataSignature.__init__(self, self.data_path, **params)
    def process_keys(self, keys, key_type, sort=False):
        """Validate the given keys according to the key type.

        If key_type is None, all keys are kept and no validation is
        performed.

        Args:
            keys (list): the keys to validate.
            key_type (str): 'inchikey', 'smiles' or None.
            sort (bool): sort the validated keys (default=False).

        Returns:
            keys (list): the processed InChIKeys
            keys_raw (list): the raw input keys
            idxs (list): indices of the valid keys
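
        Example (a minimal sketch; assumes `s0` is a sign0 instance, and
        the second key is malformed on purpose):
            >>> keys = ["RYYVLZVUVIJVGH-UHFFFAOYSA-N", "not_a_key"]
            >>> valid, raw, idxs = s0.process_keys(keys, "inchikey")
            >>> len(valid)  # only the well-formed InChIKey is kept
            1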
"""
        if key_type is None:
            return np.array(keys), np.array(keys), np.arange(len(keys))
keys_ = list()
keys_raw = list()
idxs = list()
if key_type.lower() == "inchikey":
self.__log.debug("Validating InChIKeys, only valid ones are kept.")
            for i, k in enumerate(keys):
                if isinstance(k, bytes):
                    k = k.decode()
                if len(k) == 27 and k[14] == "-" and k[25] == "-":
                    keys_.append(k)
                    keys_raw.append(k)
                    idxs.append(i)
                else:
                    self.__log.debug(
                        "skipping key '%s' due to format (entry %s)" % (k, i))
elif key_type.lower() == "smiles":
self.__log.debug(
"Validating SMILES, only valid ones are kept.")
from chemicalchecker.util.parser import Converter
conv = Converter()
for i, k in enumerate(keys):
if isinstance(k, bytes):
k = k.decode()
try:
keys_.append(conv.smiles_to_inchi(k)[0])
keys_raw.append(k)
idxs.append(i)
except Exception as ex:
self.__log.warning('Problem in conversion: %s' % str(ex))
continue
        else:
            raise Exception("key_type must be 'inchikey' or 'smiles'")
self.__log.info("Initial keys: %d / Final keys: %d" %
(len(keys), len(keys_)))
# perform sorting
keys_ = np.array(keys_)
keys_raw = np.array(keys_raw)
idxs = np.array(idxs)
if sort:
order = np.argsort(keys_)
keys_ = keys_[order]
keys_raw = keys_raw[order]
idxs = idxs[order]
return keys_, keys_raw, idxs
    def process_features(self, features, n):
        """Define feature names.

        Give an arbitrary name to features if they are not provided.
        Returns the feature names as a numpy array of strings.
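
        Example (illustrative; names are zero-padded to a fixed width):
            >>> s0.process_features(None, 12)[:2]
            array(['feature_00', 'feature_01'], dtype='<U10')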
"""
if features is None:
self.__log.debug(
"No features were provided, giving arbitrary names")
digits = int(np.log10(n)) + 1
features = []
for i in range(0, n):
s = "%d" % i
s = s.zfill(digits)
features += ["feature_%s" % s]
return np.array(features).astype(str)
    def get_data(self, pairs, X, keys, features, data_file, key_type,
agg_method):
"""Get data in the right format.
Input data for 'fit' or 'predict' can come in 2 main different format:
as matrix or as pairs. If a 'X' matrix is passed we also expect the row
identifier ('keys') and optionally column identifier ('features').
If 'pairs' (dense representation) are passed we expect a combination of
key and feature that can be associated with a value or not.
The information can be bundled in a H5 file or provided as argument.
Basic check are performed to ensure consistency of 'keys' and
'features'.

        Args:
            pairs (list): list of pairs (key, feature) or triples
                (key, feature, value).
            X (array): 2D matrix; rows correspond to molecules and columns
                correspond to features.
            keys (list): list of string identifiers for molecules.
            features (list): list of string identifiers for features.
            data_file (str): path to an input file; it must contain at
                least the datasets 'pairs', or 'X' and 'keys'.
            key_type (str): the type of molecule identifier used.
            agg_method (str): the aggregation method to use.
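
        Example (hypothetical; a single binary pair, assuming `s0` is a
        sign0 instance):
            >>> pairs = [("RYYVLZVUVIJVGH-UHFFFAOYSA-N", "GO:0008150")]
            >>> res = s0.get_data(pairs=pairs, X=None, keys=None,
            ...                   features=None, data_file=None,
            ...                   key_type="inchikey", agg_method="average")
            >>> res["X"].shape
            (1, 1)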
"""
# load data from the data file
if data_file is not None:
if not os.path.isfile(data_file):
raise Exception("File not found: %s" % data_file)
dh5 = h5py.File(data_file, 'r')
            # get pairs and values if available
            if "pairs" in dh5.keys():
                pairs = dh5["pairs"][:]
                if "values" in dh5.keys():
                    # attach each value to its pair so downstream code
                    # sees (key, feature, value) triples
                    pairs = [(p[0], p[1], v)
                             for p, v in zip(pairs, dh5["values"][:])]
# get matrix
if "X" in dh5.keys():
X = dh5["X"][:]
elif "V" in dh5.keys():
X = dh5["V"][:]
# get keys and features
if "keys" in dh5.keys():
keys = dh5["keys"][:]
if "features" in dh5.keys():
features = dh5["features"][:]
            dh5.close()
            if pairs is None and X is None:
                raise Exception("H5 file %s must contain datasets "
                                "'pairs' or 'X'" % data_file)
            if features is None and pairs is None:
                features = self.process_features(features, X.shape[1])
# handle pairs case
if pairs is not None:
self.__log.info("Input data are pairs")
if X is not None:
raise Exception(
"If you input pairs, X should not be specified!")
has_values = len(pairs[0]) != 2
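            # pairs of length 2 are binary; longer ones carry a value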
self.__log.debug("Processing keys and features")
if hasattr(pairs[0][0], 'decode'):
keys = list(set([x[0].decode() for x in pairs]))
else:
keys = list(set([x[0] for x in pairs]))
if hasattr(pairs[0][1], 'decode'):
features = list(set([x[1].decode() for x in pairs]))
else:
features = list(set([x[1] for x in pairs]))
self.__log.debug("Before processing:")
self.__log.debug("KEYS example: {}".format(keys[:10]))
self.__log.debug("key_type: {}".format(key_type))
keys, keys_raw, _ = self.process_keys(keys, key_type)
features = self.process_features(features, len(features))
keys_dict = dict((k, i) for i, k in enumerate(keys_raw))
features_dict = dict((k, i) for i, k in enumerate(features))
self.__log.debug("Iterating over pairs and doing matrix")
            pairs_ = collections.defaultdict(list)
            if not has_values:
                self.__log.debug("Binary pairs")
                for p in pairs:
                    # decode without mutating the pair (it may be a tuple)
                    k = p[0] if isinstance(p[0], str) else p[0].decode()
                    f = p[1] if isinstance(p[1], str) else p[1].decode()
                    if k not in keys_dict or f not in features_dict:
                        continue
                    pairs_[(keys_dict[k], features_dict[f])] += [1]
            else:
                self.__log.debug("Valued pairs")
                for p in pairs:
                    if p[0] not in keys_dict or p[1] not in features_dict:
                        continue
                    pairs_[(keys_dict[p[0]], features_dict[p[1]])] += [p[2]]
X = np.zeros((len(keys), len(features)))
self.__log.debug("Aggregating duplicates")
if agg_method == "average":
def do_agg(v):
return np.mean(v)
elif agg_method == "first":
def do_agg(v):
return v[0]
elif agg_method == "last":
def do_agg(v):
return v[-1]
for k, v in pairs_.items():
X[k[0], k[1]] = do_agg(v)
self.__log.debug("Setting input type")
input_type = "pairs"
else:
if X is None:
raise Exception(
"No data were provided! "
"X cannot be None if pairs or data_file aren't provided")
if keys is None:
raise Exception("keys cannot be None")
if features is None:
raise Exception("features cannot be None")
if X.shape[0] != len(keys):
raise Exception(
"number of rows of X must equal length of keys")
if X.shape[1] != len(features):
raise Exception(
"number of columns of X must equal length of features")
if len(features) != len(set(features)):
raise Exception("features must be unique")
self.__log.debug("Processing keys")
keys, keys_raw, idxs = self.process_keys(keys, key_type)
self.__log.debug("Processing features")
features = self.process_features(features, X.shape[1])
self.__log.debug("Only keeping idxs of relevance")
X = X[idxs]
self.__log.debug("Setting input type")
input_type = "matrix"
        if X.shape[0] != len(keys):
            raise Exception(
                "after processing, number of rows does not equal "
                "number of keys")
X, keys, keys_raw, features = self.sort(X, keys, keys_raw, features)
results = {
"X": X,
"keys": keys,
"keys_raw": keys_raw,
"features": features,
"input_type": input_type
}
return results
@cached_property
def agg_method(self):
"""Get the agg method of the signature."""
if not os.path.isfile(self.data_path):
raise Exception("Data file %s not available." % self.data_path)
with h5py.File(self.data_path, 'r') as hf:
if "agg_method" not in hf.keys():
self.__log.warn("No agg_method available for this signature!")
return None
if hasattr(hf["agg_method"][0], 'decode'):
return [k.decode() for k in hf["agg_method"][:]][0]
else:
return hf["agg_method"][0]
@cached_property
def input_type(self):
"""Get the input type done at fit time."""
if not os.path.isfile(self.data_path):
raise Exception("Data file %s not available." % self.data_path)
with h5py.File(self.data_path, 'r') as hf:
if "input_type" not in hf.keys():
self.__log.warn("No input_type available for this signature!")
return None
if hasattr(hf["input_type"][0], 'decode'):
return [k.decode() for k in hf["input_type"][:]][0]
else:
return hf["input_type"][0]
@cached_property
def key_type(self):
"""Get the key type done at fit time."""
if not os.path.isfile(self.data_path):
raise Exception("Data file %s not available." % self.data_path)
with h5py.File(self.data_path, 'r') as hf:
if "key_type" not in hf.keys():
self.__log.warn("No key_type available for this signature!")
return None
if hasattr(hf["key_type"][0], 'decode'):
return [k.decode() for k in hf["key_type"][:]][0]
else:
return hf["key_type"][0]
    def refresh(self):
        DataSignature.refresh(self)
        self._refresh("key_type")
        self._refresh("input_type")
        self._refresh("agg_method")
def sort(self, X, keys, keys_raw, features):
self.__log.debug("Sorting")
key_idxs = np.argsort(keys)
feature_idxs = np.argsort(features)
# sort all data
X = X[key_idxs]
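        # reorder columns in row-chunks of 2000 to avoid materializing a
        # second full-size copy of X in memory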
for i in range(0, X.shape[0], 2000):
chunk = slice(i, i + 2000)
X[chunk] = X[chunk, feature_idxs]
# sort keys
keys = keys[key_idxs]
keys_raw = keys_raw[key_idxs]
# sort features
features = features[feature_idxs]
return X, keys, keys_raw, features
    def fit(self, cc_root=None, pairs=None, X=None, keys=None, features=None,
data_file=None, key_type="inchikey", agg_method="average",
do_triplets=False, sanitize=True, sanitizer_kwargs={}, **kwargs):
"""Process the input data.
We produce a sign0 (full) and a sign0 (reference).
Data are sorted (keys and features).
Args:
cc_root(str): Path to a CC instance. This is important to produce
the triplets. If None specified, the same CC where the
signature is present will be used (default=None).
pairs(array of tuples or file): Data. If file it needs to H5 file
with dataset called 'pairs'.
X(matrix or file): Data. If file it needs to H5 file with datasets
called 'X', 'keys' and maybe 'features'.
keys(array): Row names.
key_type(str): Type of key. May be inchikey or smiles
(default='inchikey').
features(array): Column names (default=None).
data_file(str): Input data file in the form of H5 file and it
should contain the required data in datasets.
do_triplets(boolean): Draw triplets from the CC (default=True).
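
        Example (a minimal sketch; the path, dataset code and input
        arrays are hypothetical):
            >>> s0 = sign0("/path/to/sign0", "A1.001")
            >>> s0.fit(X=X, keys=keys, features=features)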
"""
BaseSignature.fit(self, **kwargs)
self.clear()
self.update_status("Getting data")
if pairs is None and X is None and data_file is None:
self.__log.debug("Runnning preprocess")
data_file = Preprocess.preprocess(self, **kwargs)
self.__log.debug("data_file is {}".format(data_file))
res = self.get_data(pairs=pairs, X=X, keys=keys, features=features,
data_file=data_file, key_type=key_type,
agg_method=agg_method)
X = res["X"]
keys = res["keys"]
keys_raw = res["keys_raw"]
features = res["features"]
input_type = res["input_type"]
if sanitize:
self.update_status("Sanitizing")
san = Sanitizer(**sanitizer_kwargs)
X, keys, keys_raw, features = san.transform(
V=X, keys=keys, keys_raw=keys_raw, features=features,
sign=None)
self.update_status("Aggregating")
agg = Aggregate(method=agg_method, input_type=input_type)
X, keys, keys_raw = agg.transform(V=X, keys=keys, keys_raw=keys_raw)
self.update_status("Saving H5")
with h5py.File(self.data_path, "w") as hf:
hf.create_dataset("name", data=np.array(
[str(self.dataset) + "sig"], DataSignature.string_dtype()))
sdate = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
hf.create_dataset(
"date", data=np.array([sdate], DataSignature.string_dtype()))
hf.create_dataset("V", data=X)
hf.create_dataset("keys", data=np.array(
keys, DataSignature.string_dtype()))
hf.create_dataset("features", data=np.array(
features, DataSignature.string_dtype()))
hf.create_dataset("keys_raw", data=np.array(
keys_raw, DataSignature.string_dtype()))
hf.create_dataset("agg_method", data=np.array(
[str(agg_method)], DataSignature.string_dtype()))
hf.create_dataset("input_type", data=np.array(
[str(input_type)], DataSignature.string_dtype()))
self.refresh()
# save reference
self.save_reference(overwrite=True)
# Making triplets
if do_triplets:
self.update_status("Sampling triplets")
cc = self.get_cc(cc_root)
sampler = TripletSampler(cc, self, save=True)
sampler.sample(**kwargs)
# finalize signature
BaseSignature.fit_end(self, **kwargs)
    def predict(self, pairs=None, X=None, keys=None, features=None,
data_file=None, key_type=None, merge=False, merge_method="new",
destination=None, chunk_size=10000):
"""Given data, produce a sign0.
Args:
pairs(array of tuples or file): Data. If file it needs to H5 file
with dataset called 'pairs'.
X(matrix or file): Data. If file it needs to H5 file with datasets
called 'X', 'keys' and maybe 'features'.
keys(array): Row names.
key_type(str): Type of key. May be inchikey or smiles. If None
specified, no filtering is applied (default=None).
features(array): Column names (default=None).
merge(bool): Merge queried data with the currently existing one.
merge_method(str): Merging method to be applied when a repeated key
is found. Can be 'average', 'old' or 'new' (default=new).
destination(str): Path to the H5 file. If none specified, a
(V, keys, features) tuple is returned.
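
        Example (hypothetical; assumes the signature was fitted on matrix
        input, and `X_new`/`new_keys` are the query data):
            >>> res = s0.predict(X=X_new, keys=new_keys,
            ...                  features=s0.features)
            >>> res["V"].shape[0] == len(res["keys"])
            True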
"""
self.__log.info("Predict START")
if merge:
self.__log.info("Merging. Loading existing signature.")
V_ = self[:]
keys_ = self.keys
keys_raw_ = self.keys_raw
if merge_method is not None:
if merge_method not in ["average", "new", "old"]:
raise Exception(
"merge_method must be None, 'average', 'new' or 'old'")
else:
self.__log.debug(
"Not merging, only predicting signature for the input data.")
V_ = None
keys_ = None
keys_raw_ = None
features_ = self.features
features_idx = dict((k, i) for i, k in enumerate(features_))
self.__log.debug("Preparing input data")
res = self.get_data(pairs=pairs, X=X, keys=keys, features=features,
data_file=data_file, key_type=key_type,
agg_method=self.agg_method)
X = res["X"]
keys = res["keys"]
keys_raw = res["keys_raw"]
input_type = res["input_type"]
features = res["features"]
if input_type != self.input_type:
raise Exception("Input type must be %s" % self.input_type)
self.__log.debug(
"Use same features arrangement as fitted signature.")
if len(set(features_) & set(features)) == 0:
raise Exception("No overlap between provided features and "
"expected ones. Check your feature names.")
        if len(set(features_) & set(features)) < len(set(features_)):
            self.__log.warning("Not all original features are covered; "
                               "missing columns will be set to 0.")
        W = np.zeros((len(keys), len(features_)), dtype=X.dtype)
for i in range(0, X.shape[0]):
for j in range(0, X.shape[1]):
feat = features[j]
if feat not in features_idx:
continue
W[i, features_idx[feat]] = X[i, j]
X = W
self.refresh()
self.__log.debug("Aggregating as fitted signature.")
agg = Aggregate(method=self.agg_method, input_type=input_type)
X, keys, keys_raw = agg.transform(V=X, keys=keys, keys_raw=keys_raw)
features = res["features"]
features = features_
if V_ is None:
V = X
else:
self.__log.debug("Stacking")
V = np.vstack((V_, X))
keys = np.append(keys_, keys)
keys_raw = np.append(keys_raw_, keys_raw)
self.__log.debug("Aggregating (merging) again")
            if merge_method is None:
                agg_method = self.agg_method
            elif merge_method == 'new':
                agg_method = 'first'
            elif merge_method == 'old':
                agg_method = 'last'
            else:
                agg_method = merge_method
agg = Aggregate(method=agg_method, input_type=input_type)
V, keys, keys_raw = agg.transform(
V=V, keys=keys, keys_raw=keys_raw)
if destination is None:
results = {
"V": V,
"keys": keys,
"features": features,
"keys_raw": keys_raw
}
return results
else:
if isinstance(destination, BaseSignature):
destination = destination.data_path
self.__log.debug("Saving H5 file in %s" % destination)
with h5py.File(destination, "w") as hf:
hf.create_dataset(
"name", data=np.array([str(self.dataset) + "sig"],
DataSignature.string_dtype()))
sdate = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
hf.create_dataset(
"date",
data=np.array([sdate], DataSignature.string_dtype()))
hf.create_dataset("V", data=V)
hf.create_dataset("keys", data=np.array(
keys, DataSignature.string_dtype()))
hf.create_dataset("features", data=np.array(
features, DataSignature.string_dtype()))
hf.create_dataset("keys_raw", data=np.array(
keys_raw, DataSignature.string_dtype()))
self.__log.debug("Predict DONE")
    def restrict_to_universe(self):
        """Restrict the keys to those contained in the universe."""
cc = self.get_cc()
universe = cc.universe # list of inchikeys belonging to the universe
self.__log.debug(
"--> getting the vectors from s0 corresponding to our "
"(restricted) universe")
# get the vectors from s0 corresponding to our (restricted) universe
inchk_univ, _ = self.get_vectors(keys=universe)
        # obtain a mask for sign0 in order to obtain a filtered H5 file.
        # Strangely, passing lists greatly improves the performance of
        # np.isin
self.__log.debug("--> Obtaining a mask")
mask = np.isin(list(self.keys), list(inchk_univ))
del inchk_univ # avoiding consuming too much memory
filtered_h5 = os.path.join(
os.path.dirname(self.data_path), 'sign0_univ.h5')
print("Creating", filtered_h5)
self.__log.debug("--> Creating file {}".format(filtered_h5))
self.make_filtered_copy(filtered_h5, mask)
        # after that, check that the file is OK and move it to sign0.h5
self.__log.debug("Done")
    def export_features(self, destination=None):
        """Export the feature names to an H5 file."""
        features = self.features
destination = self.model_path if destination is None else destination
fn = os.path.join(destination, "features_sign0_%s.h5" % self.dataset)
with h5py.File(fn, 'w') as hf_out:
hf_out.create_dataset("features", data=np.array(
features, DataSignature.string_dtype()))