"""Plot information on multiple Chemical Checker datasets."""
import os
import math
import h5py
import json
import pickle
import itertools
import numpy as np
import pandas as pd
import collections
from tqdm import tqdm
from scipy import interpolate
from scipy import stats
from functools import partial
from scipy.stats import gaussian_kde
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import r2_score
from matplotlib.patches import Polygon
from sklearn.preprocessing import robust_scale
from sklearn.metrics.cluster import contingency_matrix
from sklearn.metrics import matthews_corrcoef
from sklearn.manifold import MDS
from scipy.spatial.distance import cosine, euclidean
import matplotlib
matplotlib.use('Agg')
import seaborn as sns
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.collections import LineCollection
from chemicalchecker.util.parser import Converter
from chemicalchecker.util import logged
from chemicalchecker.util.decomposition import dataset_correlation
from chemicalchecker.util.plot.diagnosticsplot import DiagnosisPlot
@logged
class MultiPlot():
"""MultiPlot class.
Produce Chemical Checker plots using multiple datasets.
"""
def __init__(self, chemchecker, plot_path, limit_dataset=None, svg=False,
dpi=200, grey=False, style=None):
"""Initialize a MultiPlot instance.
Produce plots integrating data from multiple datasets.
Args:
chemchecker (str): A Chemical Checker instance.
plot_path (str): Destination folder for plot images.
"""
if not os.path.isdir(plot_path):
raise Exception("Folder to save plots does not exist")
self.__log.debug('Plots will be saved to %s', plot_path)
self.plot_path = plot_path
self.cc = chemchecker
if not limit_dataset:
self.datasets = list(self.cc.datasets)
else:
            if not isinstance(limit_dataset, list):
                self.datasets = [limit_dataset]
else:
self.datasets = limit_dataset
self.svg = svg
self.dpi = dpi
self.grey = grey
if style is None:
            self.style = ('ticks', {
                'font.family': 'sans-serif',
                'font.sans-serif': ['Arial'],
                'font.size': 16,
                'axes.grid': True})
else:
self.style = style
sns.set_style(*self.style)
def _rgb2hex(self, r, g, b):
return '#%02x%02x%02x' % (r, g, b)
    def cc_palette(self, coords):
"""Return a list of colors a.k.a. a palette."""
def rgb2hex(r, g, b):
return '#%02x%02x%02x' % (r, g, b)
colors = list()
for coord in coords:
if "A" in coord:
colors.append(rgb2hex(250, 100, 80))
elif "B" in coord:
colors.append(rgb2hex(200, 100, 225))
elif "C" in coord:
colors.append(rgb2hex(80, 120, 220))
elif "D" in coord:
colors.append(rgb2hex(120, 180, 60))
elif "E" in coord:
colors.append(rgb2hex(250, 150, 50))
return colors
    def cc_colors(self, coord, lightness=0):
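        # Example (illustrative, with grey=False): cc_colors('B1.001')
        # returns '#B16BA8'; lightness 0, 1 or 2 selects progressively
        # lighter shades of the same family color.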
colors = {
'A': ['#EA5A49', '#EE7B6D', '#F7BDB6'],
'B': ['#B16BA8', '#C189B9', '#D0A6CB'],
'C': ['#5A72B5', '#7B8EC4', '#9CAAD3'],
'D': ['#7CAF2A', '#96BF55', '#B0CF7F'],
'E': ['#F39426', '#F5A951', '#F8BF7D'],
'G': ['#333333', '#666666', '#999999']}
        if not self.grey:
            return colors[coord[:1]][lightness]
        else:
            return colors['G'][lightness]
    def cmap_discretize(self, cmap, N):
"""Return a discrete colormap from the continuous colormap cmap.
cmap: colormap instance, eg. cm.jet.
N: number of colors.
Example
x = resize(arange(100), (5,100))
djet = cmap_discretize(cm.jet, 5)
imshow(x, cmap=djet)
"""
        if isinstance(cmap, str):
cmap = plt.get_cmap(cmap)
colors_i = np.concatenate((np.linspace(0, 1., N), (0., 0., 0., 0.)))
colors_rgba = cmap(colors_i)
indices = np.linspace(0, 1., N + 1)
cdict = {}
for ki, key in enumerate(('red', 'green', 'blue')):
cdict[key] = [(indices[i], colors_rgba[i - 1, ki],
colors_rgba[i, ki]) for i in range(N + 1)]
# Return colormap object.
return matplotlib.colors.LinearSegmentedColormap(cmap.name + "_%d" % N, cdict, 1024)
def sign_adanet_stats(self, ctype, metric=None, compare=None):
# read stats fields
sign = self.cc.get_signature(ctype, 'full', 'E5.001')
stat_file = os.path.join(
sign.model_path, 'adanet_eval', 'stats_eval.pkl')
df = pd.read_pickle(stat_file)
# merge all stats to pandas
df = pd.DataFrame(columns=['coordinate'] + list(df.columns))
for ds in tqdm(self.datasets):
sign = self.cc.get_signature(ctype, 'full', ds)
stat_file = os.path.join(
sign.model_path, 'adanet_eval', 'stats_eval.pkl')
if not os.path.isfile(stat_file):
continue
tmpdf = pd.read_pickle(stat_file)
tmpdf['coordinate'] = ds
            df = pd.concat([df, tmpdf], ignore_index=True)
df = df.infer_objects()
outfile_csv = os.path.join(self.plot_path, 'sign2_adanet_stats.csv')
df.to_csv(outfile_csv)
outfile_pkl = os.path.join(self.plot_path, 'sign2_adanet_stats.pkl')
df.to_pickle(outfile_pkl)
if compare:
cdf = pd.read_pickle(compare)
cdf = cdf[cdf.algo == 'AdaNet'].copy()
cdf['algo'] = cdf.algo.apply(lambda x: x + '_STACK')
            df = pd.concat([df, cdf], ignore_index=True)
if metric:
all_metrics = [metric]
else:
all_metrics = ['mse', 'r2', 'explained_variance', 'pearson_std',
'pearson_avg', 'time', 'nn_layers', 'nr_variables']
for metric in all_metrics:
# sns.set_style("whitegrid")
g = sns.catplot(data=df, kind='point', x='dataset', y=metric,
hue="algo", col="coordinate", col_wrap=5,
col_order=self.datasets,
aspect=.8, height=3, dodge=True,
order=['train', 'test', 'validation'],
palette=['darkgreen', 'orange', 'darkgrey'])
if metric == 'r2':
for ax in g.axes.flatten():
ax.set_ylim(0, 1)
if metric == 'mse':
for ax in g.axes.flatten():
ax.set_ylim(0, 0.02)
if metric == 'explained_variance':
for ax in g.axes.flatten():
ax.set_ylim(0, 1)
if compare:
metric += '_CMP'
outfile = os.path.join(
self.plot_path, 'sign2_adanet_stats_%s.png' % metric)
plt.savefig(outfile, dpi=self.dpi)
plt.close('all')
    def sign2_node2vec_stats(self):
"""Plot the stats for sign2."""
# plot selected stats
stats = [
"nodes",
"edges",
#"zeroNodes",
"zeroInNodes",
#"zeroOutNodes",
#"nonZIODegNodes",
"Connected Components",
"Degree",
"Weights",
"AUC-ROC",
"MCC",
"Sign Range"]
# the following are aggregated
conncompo = [
"SccSz",
"WccSz",
]
degrees = [
#"Degree_min",
"Degree_25",
"Degree_50",
"Degree_75",
#"Degree_max"
]
weights = [
#"Weight_min",
"Weight_25",
"Weight_50",
"Weight_75",
#"Weight_max"
]
# move stats to pandas
df = pd.DataFrame(columns=['dataset'] + stats)
for ds in tqdm(self.datasets):
# get sign2 and stats file
sign2 = self.cc.get_signature('sign2', 'reference', ds)
graph_file = os.path.join(sign2.stats_path, "graph_stats.json")
if not os.path.isfile(graph_file):
                self.__log.warning('Graph stats %s not found', graph_file)
continue
graph_stat = json.load(open(graph_file, 'r'))
linkpred_file = os.path.join(sign2.stats_path, "linkpred.json")
            skip_linkpred = False
            if not os.path.isfile(linkpred_file):
                self.__log.warning(
                    'Node2vec stats %s not found', linkpred_file)
                skip_linkpred = True
if not skip_linkpred:
                linkpred_perf = json.load(open(linkpred_file, 'r'))
                linkpred_perf = {k: float(v)
                                 for k, v in linkpred_perf.items()}
# prepare row
for deg in degrees:
row = dict()
row.update(graph_stat)
if not skip_linkpred:
                    row.update(linkpred_perf)
row.update({"dataset": ds})
row.update({"Degree": graph_stat[deg]})
df.loc[len(df)] = pd.Series(row)
for conn in conncompo:
row = dict()
row.update(graph_stat)
if not skip_linkpred:
                    row.update(linkpred_perf)
row.update({"dataset": ds})
row.update({"Connected Components": graph_stat[conn]})
df.loc[len(df)] = pd.Series(row)
for wei in weights:
row = dict()
row.update(graph_stat)
if not skip_linkpred:
                    row.update(linkpred_perf)
row.update({"dataset": ds})
row.update({"Weights": graph_stat[wei]})
df.loc[len(df)] = pd.Series(row)
maxss = list()
minss = list()
for s in sign2.chunker(size=10000):
curr = sign2[s]
maxss.append(np.percentile(curr, 99))
minss.append(np.percentile(curr, 1))
row = {"dataset": ds, "Sign Range": np.mean(maxss)}
df.loc[len(df)] = pd.Series(row)
row = {"dataset": ds, "Sign Range": np.mean(minss)}
df.loc[len(df)] = pd.Series(row)
df = df.infer_objects()
sns.set(style="ticks")
sns.set_context("talk", font_scale=1.)
g = sns.PairGrid(df.sort_values("dataset", ascending=True),
x_vars=stats, y_vars=["dataset"],
height=10, aspect=.3)
g.map(sns.stripplot, size=10, dodge=False, jitter=False, # marker="|",
palette=self.cc_palette(self.datasets),
orient="h", linewidth=1, edgecolor="w")
for ax in g.axes.flat:
# Make the grid horizontal instead of vertical
ax.xaxis.grid(True, color='#e3e3e3')
ax.yaxis.grid(True)
g.axes.flat[0].set_xscale("log")
g.axes.flat[0].set_xlim(1e3, 3 * 1e6)
g.axes.flat[0].set_xlabel("Nodes")
g.axes.flat[1].set_xscale("log")
g.axes.flat[1].set_xlim(1e4, 1e8)
g.axes.flat[1].set_xlabel("Edges")
g.axes.flat[2].set_xlim(0, 5000)
g.axes.flat[2].set_xlabel("0 In Nodes")
g.axes.flat[3].set_xlim(0, 1.0)
g.axes.flat[4].set_xscale("log")
g.axes.flat[4].set_xlim(10, 1e3)
g.axes.flat[4].set_xlabel("Degree %tiles")
g.axes.flat[5].set_xlim(0, 1)
g.axes.flat[6].set_xlim(0.9, 1)
g.axes.flat[7].set_xlim(0.5, 1)
g.axes.flat[8].set_xlim(-1, 1)
# g.axes.flat[-1].set_xlim(1e1,1e3)
sns.despine(left=True, bottom=True)
outfile = os.path.join(self.plot_path, 'sign2_node2vec_stats.png')
plt.savefig(outfile, dpi=self.dpi)
plt.close('all')
def sign_feature_distribution_plot(self, cctype, molset, block_size=1000,
block_nr=10, sort=False):
sample_size = block_size * block_nr
fig, axes = plt.subplots(25, 1, sharey=True, sharex=True,
figsize=(10, 40), dpi=self.dpi)
for ds, ax in tqdm(zip(self.datasets, axes.flatten())):
sign = self.cc.get_signature(cctype, molset, ds)
if not os.path.isfile(sign.data_path):
continue
            if sign.shape[0] > sample_size:
                # sample `block_nr` random blocks among all available ones
                blocks = np.random.choice(
                    sign.shape[0] // block_size,
                    block_nr, replace=False)
block_mat = list()
for block in tqdm(blocks):
chunk = slice(block * block_size,
(block * block_size) + block_size)
block_mat.append(sign[chunk])
matrix = np.vstack(block_mat)
else:
matrix = sign[:]
df = pd.DataFrame(matrix).melt()
all_df = df.copy()
all_df['variable'] = 130
            df = pd.concat([df, all_df], ignore_index=True)
            if not sort:
                order = [130, -1] + list(range(matrix.shape[1]))
            else:
                order = [130, -1] + \
                    list(np.argsort(np.mean(matrix, axis=0))[::-1])
sns.pointplot(x='variable', y='value', data=df, order=order,
ax=ax, ci='sd', join=False, markers='.',
color=self.cc_palette([ds])[0])
ax.set_ylim(-1, 1)
ax.set_xlim(-2, 130)
ax.set_xticks([])
ax.set_xlabel('')
ax.set_ylabel(ds)
min_mean = np.min(np.mean(matrix, axis=0))
max_mean = np.max(np.mean(matrix, axis=0))
ax.fill_between([-2, 130], [max_mean, max_mean],
[min_mean, min_mean],
facecolor=self.cc_palette([ds])[0], alpha=0.3,
zorder=0)
max_std = max(np.std(matrix, axis=0))
ax.fill_between([-2, 130],
[max_mean + max_std, max_mean + max_std],
[min_mean - max_std, min_mean - max_std],
facecolor=self.cc_palette([ds])[0], alpha=0.2,
zorder=0)
sns.despine(bottom=True)
plt.tight_layout()
if not sort:
filename = os.path.join(
self.plot_path,
"%s_%s_feat_distrib.png" % (cctype, molset))
else:
filename = os.path.join(
self.plot_path,
"%s_%s_feat_distrib_sort.png" % (cctype, molset))
plt.savefig(filename, dpi=self.dpi)
plt.close()
def plot_adanet_subnetwork_layer_size(self, shapes=None, func=None):
if not shapes:
shapes = list()
for ds in self.cc.datasets:
sign1 = self.cc.get_signature('sign1', 'reference', ds)
x, y = sign1.shape
shapes.append((ds, x, y))
def layer_size(nr_samples, nr_features, nr_out=128, s_fact=7.):
heu_layer_size = (
1 / s_fact) * (np.sqrt(nr_samples) / .3 + ((nr_features + nr_out) / 5.))
heu_layer_size = np.power(2, np.ceil(np.log2(heu_layer_size)))
heu_layer_size = np.maximum(heu_layer_size, 32)
return heu_layer_size
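        # Worked example (illustrative): layer_size(10000, 1000) gives
        # (1 / 7) * (sqrt(10000) / .3 + (1000 + 128) / 5) ~= 80, rounded
        # up to the next power of two -> 128 neurons (floored at 32).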
if not func:
func = layer_size
x = np.logspace(2, 6, 500)
y = np.linspace(5, 5000, 500)
        X, Y = np.meshgrid(x, y)  # grid of points
        Z = func(X, Y)  # evaluate the function on the grid
# sns.set_style("whitegrid")
fig, ax = plt.subplots(figsize=(7, 5), dpi=self.dpi)
norm = matplotlib.colors.BoundaryNorm(
boundaries=[2**i for i in range(5, 11)], ncolors=256)
# drawing the function
im = ax.pcolormesh(X, Y, Z, norm=norm, cmap=plt.cm.Blues)
plt.xscale('log')
# adding the Contour lines with labels
# cset = ax.contour(Z, [2**i for i in range(2, 11)],linewidths = 2, cmap = plt.cm.Set2)
# plt.clabel(cset, inline=True, fmt='%1.1f', fontsize=10)
        plt.colorbar(im, label='Neurons')  # add the colorbar on the right
plt.ylim(5, 5000)
ax.set_xlabel("Molecules")
ax.set_ylabel("Features")
plt.tight_layout()
for ds, x, y in shapes:
plt.scatter(x, y, color=self.cc_palette([ds])[0], alpha=.3)
plt.text(x, y, "%s" % (ds[:2]),
ha="center", va="center",
bbox={"boxstyle": "circle", "color": self.cc_palette([ds])[
0]},
color='k', fontsize=10)
filename = os.path.join(self.plot_path, "layer_size.png")
plt.savefig(filename, dpi=self.dpi)
plt.close()
def sign2_grid_search_plot(self, grid_postfix=None):
grid_roots = list()
for ds in self.cc.datasets:
sign2 = self.cc.get_signature('sign2', 'reference', ds)
grid_roots.append(os.path.join(sign2.model_path,
'grid_search_%s' % grid_postfix))
file_names = list()
for grid_root in grid_roots:
file_names.extend([os.path.join(grid_root, name, 'stats.pkl') for name in os.listdir(
grid_root) if os.path.isfile(os.path.join(grid_root, name, 'stats.pkl'))])
cols = list(pd.read_pickle(file_names[0]).columns)
params = {n.rsplit("_", 1)[0]: n.rsplit("_", 1)[1]
for n in file_names[0].split('/')[-2].split("-")}
df = pd.DataFrame(columns=set(cols) | set(params.keys()))
for tmpdf_file in file_names:
tmpdf = pd.read_pickle(tmpdf_file)
params = {n.rsplit("_", 1)[0]: n.rsplit("_", 1)[1]
for n in tmpdf_file.split('/')[-2].split("-")}
for k, v in params.items():
tmpdf[k] = pd.Series([v] * len(tmpdf))
coordinate = tmpdf_file.split('/')[-6]
tmpdf['coordinate'] = pd.Series([coordinate] * len(tmpdf))
if 'Ext' in params["subnetwork_generator"]:
tmpdf = tmpdf[tmpdf.algo == 'AdaNet']
else:
tmpdf["subnetwork_generator"] = tmpdf.algo.map(
{"AdaNet": "StackDNNGenerator", "LinearRegression": "LinearRegression"})
            df = pd.concat([df, tmpdf], ignore_index=True)
# df['layer_size'] = df['layer_size'].astype(int)
# df['adanet_iterations'] = df['adanet_iterations'].astype(int)
# df['adanet_lambda'] = df['adanet_lambda'].astype(float)
df = df.infer_objects()
sns.set_context("talk")
netdf = pd.DataFrame(columns=list(df.columns) + ['layer', 'neurons'])
for index, row in df.iterrows():
for layer, size in enumerate(row.architecture[:-1]):
new_row = row.to_dict()
new_row['layer'] = layer + 1
new_row['neurons'] = size
netdf.loc[len(netdf)] = pd.Series(new_row)
sns.set_context("notebook")
# sns.set_style("whitegrid")
hue_order = ["StackDNNGenerator", "ExtendDNNGenerator"]
g = sns.catplot(data=netdf, kind='bar', x='layer', y='neurons',
hue="subnetwork_generator", col="coordinate", col_wrap=5,
col_order=self.datasets, hue_order=hue_order,
aspect=1.2, height=3, dodge=True,
palette=['forestgreen', 'orange'])
for ax in g.axes.flatten():
            ax.set_yscale('log', base=2)
ax.set_title("")
filename = os.path.join(
self.plot_path, "sign2_%s_grid_search_NN.png" % (grid_postfix))
plt.savefig(filename, dpi=self.dpi)
plt.close()
hue_order = ["StackDNNGenerator", "ExtendDNNGenerator"]
g = sns.catplot(data=netdf[netdf.subnetwork_generator == 'StackDNNGenerator'], kind='bar', x='layer', y='neurons',
hue="subnetwork_generator", col="coordinate", col_wrap=5,
col_order=self.datasets, hue_order=[
"StackDNNGenerator"],
aspect=1.2, height=3, dodge=True,
palette=['forestgreen', 'orange'])
for ax in g.axes.flatten():
            ax.set_yscale('log', base=2)
ax.set_title("")
filename = os.path.join(
self.plot_path, "sign2_%s_grid_search_NN_stackonly.png" % (grid_postfix))
plt.savefig(filename, dpi=self.dpi)
plt.close()
for metric in ['pearson_avg', 'time', 'r2', 'pearson_std', 'explained_variance']:
# sns.set_style("whitegrid")
hue_order = ["StackDNNGenerator",
"ExtendDNNGenerator", "LinearRegression"]
if metric == 'time':
sharey = False
else:
sharey = True
g = sns.catplot(data=df, kind='point', x='dataset', y=metric,
hue="subnetwork_generator", col="coordinate", col_wrap=5,
col_order=self.datasets, hue_order=hue_order, sharey=sharey,
order=['train', 'test', 'validation'], aspect=1.2, height=3,
palette=['forestgreen', 'orange', 'darkgrey'])
for ax in g.axes.flatten():
if metric == 'pearson_avg':
ax.set_ylim(0.5, 1)
ax.set_title("")
filename = os.path.join(
self.plot_path, "sign2_%s_grid_search_%s.png" % (grid_postfix, metric))
plt.savefig(filename, dpi=self.dpi)
plt.close()
g = sns.catplot(data=df[df.subnetwork_generator != 'ExtendDNNGenerator'], kind='point', x='dataset', y=metric,
hue="subnetwork_generator", col="coordinate", col_wrap=5,
col_order=self.datasets, hue_order=[
"StackDNNGenerator", "LinearRegression"],
aspect=1.2, height=3, dodge=True, sharey=sharey,
order=['train', 'test', 'validation'],
palette=['forestgreen', 'darkgrey'])
for ax in g.axes.flatten():
if metric == 'pearson_avg':
ax.set_ylim(0.5, 1)
ax.set_title("")
filename = os.path.join(
self.plot_path, "sign2_%s_grid_search_%s_stackonly.png" % (grid_postfix, metric))
plt.savefig(filename, dpi=self.dpi)
plt.close()
def sign2_grid_search_node2vec_plot(self, grid_postfix=None):
grid_roots = list()
for ds in self.cc.datasets:
sign2 = self.cc.get_signature('sign2', 'reference', ds)
grid_roots.append(os.path.join(sign2.model_path,
'grid_search_%s' % grid_postfix))
file_names = list()
for grid_root in grid_roots:
file_names.extend([os.path.join(grid_root, name, 'linkpred.test.json') for name in os.listdir(
grid_root) if os.path.isfile(os.path.join(grid_root, name, 'linkpred.test.json'))])
file_names.extend([os.path.join(grid_root, name, 'linkpred.train.json') for name in os.listdir(
grid_root) if os.path.isfile(os.path.join(grid_root, name, 'linkpred.train.json'))])
cols = json.load(open(file_names[0], 'r')).keys()
params = {n.rsplit("_", 1)[0]: n.rsplit("_", 1)[1]
for n in file_names[0].split('/')[-2].split("-")}
columns = list(set(cols) | set(params.keys()))
df = pd.DataFrame(columns=columns + ['coordinate', 'dataset'])
for tmpdf_file in file_names:
row = json.load(open(tmpdf_file, 'r'))
row = {k: float(v) for k, v in row.items()}
row['coordinate'] = tmpdf_file.split('/')[-6]
if 'train' in tmpdf_file:
row['dataset'] = 'train'
else:
row['dataset'] = 'test'
params = {n.rsplit("_", 1)[0]: n.rsplit("_", 1)[1]
for n in tmpdf_file.split('/')[-2].split("-")}
row.update(params)
df.loc[len(df)] = pd.Series(row)
df['d'] = df['d'].astype(int)
df = df.infer_objects()
sns.set_context("talk")
# sns.set_style("ticks")
g = sns.relplot(data=df, kind='line', x='d', y='AUC-ROC',
hue="coordinate", col="coordinate", col_wrap=5,
col_order=self.datasets, style="dataset",
palette=self.cc_palette(self.datasets),
aspect=1, height=2.475, legend=False, lw=3)
g.fig.set_size_inches(16.5, 16.5)
g.set_titles("")
coords = {0: "$\\bf{A}$", 5: "$\\bf{B}$",
10: "$\\bf{C}$", 15: "$\\bf{D}$", 20: "$\\bf{E}$"}
for idx, ax in enumerate(g.axes.flatten()):
            ax.set_xscale('log', base=2)
ax.set_xticks([2, 16, 128, 1024])
ax.set_yticks([.8, .9, 1.0])
ax.set_ylim([.78, 1.02])
if not idx % 5:
ax.set_ylabel(coords[idx])
if idx >= 20:
ax.set_xlabel("$\\bf{%s}$" % ((idx % 5) + 1))
if idx == 24:
lines = [matplotlib.lines.Line2D(
[0], [0], color=".15",
linewidth=2, linestyle=ls) for ls in ['--', '-']]
labels = ['Train', 'Test']
ax.legend(lines, labels, frameon=False)
sns.despine(top=False, right=False, left=False, bottom=False)
filename = os.path.join(
self.plot_path, "sign2_%s_grid_search_node2vec.png" % (grid_postfix))
plt.tight_layout()
plt.savefig(filename, dpi=self.dpi)
plt.close()
def sign3_grid_search_plot(self, grid_roots):
file_names = list()
for grid_root in grid_roots:
file_names.extend([os.path.join(grid_root, name, 'adanet', 'stats.pkl') for name in os.listdir(
grid_root) if os.path.isfile(os.path.join(grid_root, name, 'adanet', 'stats.pkl'))])
cols = list(pd.read_pickle(file_names[0]).columns)
df = pd.DataFrame(columns=list(set(cols)) + ['subnetwork_generator'])
for tmpdf_file in file_names:
tmpdf = pd.read_pickle(tmpdf_file)
coordinate = tmpdf_file.split('/')[-3]
tmpdf['coordinate'] = pd.Series(
[coordinate.split("_")[0]] * len(tmpdf))
if 'STACK' in tmpdf_file:
tmpdf = tmpdf[tmpdf.algo == 'AdaNet']
tmpdf["subnetwork_generator"] = tmpdf.algo.map(
{"AdaNet": "StackDNNGenerator", "LinearRegression": "LinearRegression"})
else:
tmpdf["subnetwork_generator"] = tmpdf.algo.map(
{"AdaNet": "ExtendDNNGenerator", "LinearRegression": "LinearRegression"})
            df = pd.concat([df, tmpdf], ignore_index=True)
# df['layer_size'] = df['layer_size'].astype(int)
# df['adanet_iterations'] = df['adanet_iterations'].astype(int)
# df['adanet_lambda'] = df['adanet_lambda'].astype(float)
df = df.infer_objects()
sns.set_context("talk")
netdf = pd.DataFrame(columns=list(df.columns) + ['layer', 'neurons'])
for index, row in df.iterrows():
for layer, size in enumerate(row.architecture[:-1]):
new_row = row.to_dict()
new_row['layer'] = layer + 1
new_row['neurons'] = size
netdf.loc[len(netdf)] = pd.Series(new_row)
sns.set_context("notebook")
# sns.set_style("whitegrid")
hue_order = ["StackDNNGenerator", "ExtendDNNGenerator"]
g = sns.catplot(data=netdf, kind='bar', x='layer', y='neurons',
hue="subnetwork_generator", col="coordinate", col_wrap=5,
col_order=self.datasets, hue_order=hue_order,
aspect=1.2, height=3, dodge=True,
palette=['forestgreen', 'orange'])
for ax in g.axes.flatten():
            ax.set_yscale('log', base=2)
ax.set_title("")
filename = os.path.join(
self.plot_path, "sign3_crossfit_NN.png")
plt.savefig(filename, dpi=self.dpi)
plt.close()
hue_order = ["StackDNNGenerator", "ExtendDNNGenerator"]
g = sns.catplot(data=netdf[netdf.subnetwork_generator == 'StackDNNGenerator'], kind='bar', x='layer', y='neurons',
hue="subnetwork_generator", col="coordinate", col_wrap=5,
col_order=self.datasets, hue_order=[
"StackDNNGenerator"],
aspect=1.2, height=3, dodge=True,
palette=['forestgreen', 'orange'])
for ax in g.axes.flatten():
            ax.set_yscale('log', base=2)
ax.set_title("")
filename = os.path.join(
self.plot_path, "sign3_crossfit_NN_stackonly.png")
plt.savefig(filename, dpi=self.dpi)
plt.close()
for metric in ['pearson_avg', 'time', 'r2', 'pearson_std', 'explained_variance']:
# sns.set_style("whitegrid")
hue_order = ["StackDNNGenerator",
"ExtendDNNGenerator", "LinearRegression"]
if metric == 'time':
sharey = False
else:
sharey = True
g = sns.catplot(data=df, kind='point', x='dataset', y=metric,
hue="subnetwork_generator", col="coordinate", col_wrap=5,
col_order=self.datasets, hue_order=hue_order, sharey=sharey,
order=['train', 'test', 'validation'], aspect=1.2, height=3,
palette=['forestgreen', 'orange', 'darkgrey'])
for ax in g.axes.flatten():
ax.set_title("")
filename = os.path.join(
self.plot_path, "sign3_crossfit_%s.png" % (metric))
plt.savefig(filename, dpi=self.dpi)
plt.close()
g = sns.catplot(data=df[df.subnetwork_generator != 'ExtendDNNGenerator'], kind='point', x='dataset', y=metric,
hue="subnetwork_generator", col="coordinate", col_wrap=5,
col_order=self.datasets, hue_order=[
"StackDNNGenerator", "LinearRegression"],
aspect=1.2, height=3, dodge=True, sharey=sharey,
order=['train', 'test', 'validation'],
palette=['forestgreen', 'darkgrey'])
for ax in g.axes.flatten():
ax.set_title("")
filename = os.path.join(
self.plot_path, "sign3_crossfit_%s_stackonly.png" % (metric))
plt.savefig(filename, dpi=self.dpi)
plt.close()
def sign3_all_crossfit_plot(self, crossfit_dir):
file_names = list()
for name in os.listdir(crossfit_dir):
filename = os.path.join(crossfit_dir, name, 'adanet', 'stats.pkl')
if not os.path.isfile(filename):
print("File not found: %s", filename)
continue
file_names.append(filename)
cols = list(pd.read_pickle(file_names[0]).columns)
dfs = list()
for tmpdf_file in file_names:
tmpdf = pd.read_pickle(tmpdf_file)
coordinate = tmpdf_file.split('/')[-3]
tmpdf['coordinate_from'] = pd.Series(
[coordinate.split("_")[0]] * len(tmpdf))
tmpdf['coordinate_to'] = pd.Series(
[coordinate.split("_")[1]] * len(tmpdf))
# also find dataset size
traintest_file = os.path.join(os.path.split(tmpdf_file)[
0], '..', 'traintest.h5')
with h5py.File(traintest_file, 'r') as fh:
train_size = fh['x_train'].shape[0]
tmpdf['train_size'] = pd.Series([train_size] * len(tmpdf))
dfs.append(tmpdf)
        df = pd.DataFrame(columns=list(set(cols)) +
                          ['coordinate_from', 'coordinate_to', 'train_size'])
        df = pd.concat([df] + dfs, ignore_index=True)
df = df.infer_objects()
adanet_test = df[(df.algo == 'AdaNet') & (df.dataset == 'test')]
adanet_train = df[(df.algo == 'AdaNet') & (df.dataset == 'train')]
metrics = ['pearson_avg', 'time', 'r2', 'train_size',
'pearson_std', 'explained_variance', 'nr_variables']
for idx, met in enumerate(metrics):
piv = adanet_test.pivot(
index='coordinate_from', columns='coordinate_to', values=met)
if met == 'train_size':
piv = np.log10(piv)
met = 'log_train_size'
if met == 'nr_variables':
piv = np.log10(piv)
met = 'log_nr_variables'
ax = plt.axes()
col_start = np.linspace(0, 3, len(metrics))[idx]
col_rot = 1. / len(metrics)
cubehelix = sns.cubehelix_palette(
start=col_start, rot=col_rot, as_cmap=True)
cubehelix.set_under(".9")
cmap = self.cmap_discretize(cubehelix, 5)
cmap.set_under(".9")
if met == 'pearson_avg':
sns.heatmap(piv, cmap=cmap,
linecolor='grey',
square=True, vmin=0., vmax=1., ax=ax)
else:
sns.heatmap(piv, cmap=cmap,
linecolor='grey',
square=True, ax=ax)
"""
elif met in ['F1', 'AUC-ROC', 'AUC-PR']:
sns.heatmap(piv, cmap=cmap,
linecolor='grey',
square=True, vmin=0.5, vmax=1., ax=ax)
elif met in ['validation_neg_median', 'validation_pos_median']:
sns.heatmap(piv, cmap=cmap_discretize(cubehelix, 10),
linecolor='grey',
square=True, vmin=-1., vmax=1., ax=ax)
"""
for grid in range(0, 26, 5):
ax.axhline(y=grid, color='grey', linewidth=0.5)
ax.axvline(x=grid, color='grey', linewidth=0.5)
ax.set_title(met)
plt.tight_layout()
filename = os.path.join(
self.plot_path, "sign3_all_crossfit_train_delta_%s.png" % (met))
plt.savefig(filename, dpi=self.dpi)
plt.close()
piv_test = adanet_test.pivot(
index='coordinate_from', columns='coordinate_to', values='pearson_avg')
piv_train = adanet_train.pivot(
index='coordinate_from', columns='coordinate_to', values='pearson_avg')
piv = piv_train - piv_test
overfit = piv.stack().reset_index()
overfit = overfit.rename(index=str, columns={0: 'overfit_pearson_avg'})
odf = pd.merge(adanet_test, overfit, how='left',
left_on=['coordinate_from', 'coordinate_to'],
right_on=['coordinate_from', 'coordinate_to'])
odf = odf[(odf.coordinate_from != odf.coordinate_to)]
odf['pair'] = odf['coordinate_from'].apply(
lambda x: x[:2]) + "_" + odf['coordinate_to'].apply(lambda x: x[:2])
# odf['capped_train_size'] = np.minimum(odf.train_size,20000)
        odf['log10_train_size'] = np.log10(odf.train_size)
odf['pearson_avg_train'] = odf.overfit_pearson_avg + odf.pearson_avg
# sns.set_style("whitegrid")
sns.set_context("talk")
order = sorted(odf.coordinate_to.unique())
sns.relplot(x="pearson_avg_train", y="pearson_avg",
hue='coordinate_from', hue_order=order,
palette=self.cc_palette(order),
col='coordinate_to', col_wrap=5, col_order=order,
size="log10_train_size", sizes=(5, 100), data=odf,
facet_kws={'xlim': (0., 1.), 'ylim': (0., 1.)})
filename = os.path.join(
self.plot_path, "sign3_overfit_vs_trainsize.png")
plt.savefig(filename, dpi=self.dpi)
plt.close()
    def spy_sign2_universe_matrix(self, universe_h5, datasets,
                                  chunk_size=1000):
# the matrix is too big for any plotting attempt
# we go by density bins
bins = list()
with h5py.File(universe_h5, 'r') as hf:
for i in tqdm(range(0, hf['x_test'].shape[0], chunk_size)):
chunk = slice(i, i + chunk_size)
matrix = hf['x_test'][chunk]
presence = (~np.isnan(matrix[:, 0::128])).astype(int)
curr_bin = np.sum(presence, axis=0) / float(presence.shape[0])
bins.append(curr_bin)
binned = np.vstack(bins)
# do some column-wise smoothing
def smooth_fn(r):
s = interpolate.interp1d(np.arange(len(r)), r)
xnew = np.arange(0, len(r) - 1, .1)
return s(xnew)
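        # e.g. a column of length 50 is resampled to 490 points
        # (np.arange(0, 49, .1)), i.e. roughly 10x finer along rows.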
smooth = np.vstack([smooth_fn(c) for c in binned.T]).T
# plot
sns.set_context('talk')
# sns.set_style("white")
fig, ax = plt.subplots(figsize=(14, 12))
cmap = matplotlib.cm.viridis
cmap.set_bad(cmap.colors[0])
im = ax.imshow(smooth * 100, aspect='auto',
norm=matplotlib.colors.LogNorm(),
cmap=cmap)
# ax.yaxis.set_xticklabels(
# np.arange(0, binned.shape[0] * chunk_size, chunk_size))
ax.set_xlabel("Bioativity Spaces")
ax.xaxis.set_label_position('top')
ax.set_xticks(np.arange(binned.shape[1]))
ax.set_xticklabels([ds[:2] for ds in datasets])
        ax.tick_params(labelbottom=False, labeltop=True)
ax.xaxis.labelpad = 20
ax.set_yticklabels([])
thousands = binned.shape[0] * chunk_size / 1000.
ax.set_ylabel("%ik Molecules" % thousands)
# colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.ax.set_ylabel('Coverage')
cbar.ax.yaxis.set_label_position('left')
cbar.set_ticks([0.01, 0.1, 1., 10, 100])
cbar.set_ticklabels(['0.01%', '0.1%', '1%', '10%', '100%'])
'''
# also mark dataset medians
ds_avg = np.median(binned, axis=0) * 100
cbar2 = plt.colorbar(im, ax=ax)
cbar2.ax.set_ylabel('Coverage')
cbar2.ax.yaxis.set_label_position('left')
cbar2.set_ticks(ds_avg[np.argsort(ds_avg)])
cbar2.set_ticklabels(
np.array([ds[:2] for ds in datasets])[np.argsort(ds_avg)])
'''
# save
plt.tight_layout()
filename = os.path.join("sign2_universe.png")
plt.savefig(filename, dpi=self.dpi)
plt.close()
def spy_augment(self, matrix, augment_fn, epochs=1):
nr_samples, nr_features = matrix.shape
        nrows = int(np.ceil(epochs / 5))
        fig, axes = plt.subplots(nrows=nrows, ncols=5, figsize=(15, 15))
for idx in range(epochs):
ax = axes.flatten()[idx]
aug_mat, _ = augment_fn(matrix, True)
ax.spy(aug_mat)
ax.set_yticklabels([])
ax.set_xticklabels([])
filename = os.path.join(self.plot_path, "spy.png")
plt.savefig(filename, dpi=self.dpi)
plt.close()
def sign3_adanet_performance_all_plot(self, metric="pearson", suffix=None,
stat_filename="stats_eval.pkl"):
# sns.set_style("whitegrid")
# sns.set_style({'font.family': 'sans-serif', 'font.serif': ['Arial']})
fig, axes = plt.subplots(25, 1, sharey=True, sharex=False,
figsize=(20, 70), dpi=self.dpi)
adanet_dir = 'adanet_eval'
if suffix is not None:
adanet_dir = 'adanet_%s' % suffix
for ds, ax in tqdm(zip(self.datasets, axes.flatten())):
s3 = self.cc.get_signature('sign3', 'full', ds)
perf_file = os.path.join(
s3.model_path, adanet_dir, stat_filename)
if not os.path.isfile(perf_file):
continue
df = pd.read_pickle(perf_file)
sns.barplot(x='from', y=metric, data=df, hue="split",
hue_order=['train', 'test'],
ax=ax, color=self.cc_palette([ds])[0])
ax.set_ylim(0, 1)
sns.stripplot(x='from', y='coverage', data=df, hue="split",
hue_order=['train', 'test'],
ax=ax, jitter=False, palette=['pink', 'crimson'],
alpha=.9)
ax.get_legend().remove()
ax.set_xlabel('')
ax.set_ylabel(ds)
for idx, p in enumerate(ax.patches):
if "%.2f" % p.get_height() == 'nan':
continue
val = "%.2f" % p.get_height()
ax.annotate(val[1:],
(p.get_x() + p.get_width() / 2., 0),
ha='center', va='center', fontsize=11,
color='k', rotation=90, xytext=(0, 20),
textcoords='offset points')
ax.grid(axis='y', linestyle="-",
color=self.cc_palette([ds])[0], lw=0.3)
ax.spines["bottom"].set_color(self.cc_palette([ds])[0])
ax.spines["top"].set_color(self.cc_palette([ds])[0])
ax.spines["right"].set_color(self.cc_palette([ds])[0])
ax.spines["left"].set_color(self.cc_palette([ds])[0])
plt.tight_layout()
filename = os.path.join(self.plot_path, "adanet_performance_all.png")
plt.savefig(filename, dpi=self.dpi)
plt.close('all')
def sign3_adanet_performance_overall(self, metric="pearson", suffix=None,
not_self=True):
# sns.set_style("whitegrid")
sns.set_context("talk")
fig, axes = plt.subplots(5, 5, sharey=True, sharex=False,
figsize=(10, 10), dpi=self.dpi)
adanet_dir = 'adanet_eval'
if suffix is not None:
adanet_dir = 'adanet_%s' % suffix
for ds, ax in tqdm(zip(self.datasets, axes.flatten())):
s3 = self.cc.get_signature('sign3', 'full', ds)
perf_file = os.path.join(
s3.model_path, adanet_dir, 'stats.pkl')
if not os.path.isfile(perf_file):
continue
df = pd.read_pickle(perf_file)
if not_self:
if ds in ['B4.001', 'C3.001', 'C4.001', 'C5.001']:
df = df[df['from'] == 'not-BX|CX']
else:
df = df[df['from'] == 'not-%s' % ds]
sns.barplot(x='from', y=metric, data=df, hue="split",
hue_order=['train', 'test'], alpha=.8,
ax=ax, color=self.cc_palette([ds])[0])
ax.set_ylim(0, 1)
ax.get_legend().remove()
ax.set_xlabel('')
ax.set_ylabel('')
# ax.set_xticklabels([ds])
for idx, p in enumerate(ax.patches):
if "%.2f" % p.get_height() == 'nan':
continue
val = "%.2f" % p.get_height()
ax.annotate(val[1:],
(p.get_x() + p.get_width() / 2., 0),
ha='center', va='center', fontsize=11,
color='k', rotation=90, xytext=(0, 20),
textcoords='offset points')
ax.grid(axis='y', linestyle="-",
color=self.cc_palette([ds])[0], lw=0.3)
ax.spines["bottom"].set_color(self.cc_palette([ds])[0])
ax.spines["top"].set_color(self.cc_palette([ds])[0])
ax.spines["right"].set_color(self.cc_palette([ds])[0])
ax.spines["left"].set_color(self.cc_palette([ds])[0])
ax.legend(loc='upper right', fontsize='small')
plt.tight_layout()
if suffix is not None:
filename = os.path.join(
self.plot_path, "adanet_performance_%s.png" % suffix)
else:
filename = os.path.join(
self.plot_path, "adanet_performance_overall.png")
plt.savefig(filename, dpi=self.dpi)
plt.close('all')
def sign3_adanet_performance_overall_heatmap(self, metric="pearson",
split='test',
suffix=None, not_self=True):
adanet_dir = 'adanet_eval'
if suffix is not None:
adanet_dir = 'adanet_%s' % suffix
df = pd.DataFrame()
for ds in tqdm(self.datasets):
s3 = self.cc.get_signature('sign3', 'full', ds)
perf_file = os.path.join(s3.model_path, adanet_dir,
'stats_eval.pkl')
if not os.path.isfile(perf_file):
continue
sdf = pd.read_pickle(perf_file)
sel = sdf[(sdf['split'] == split)].groupby(
'from', as_index=False)[metric].mean()
sel['to'] = ds[:2]
            df = pd.concat([df, sel], ignore_index=True)
df['from'] = df['from'].map({ds: ds[:2] for ds in self.cc.datasets})
df = df.dropna()
fig, ax = plt.subplots(1, 1, figsize=(6, 5), dpi=self.dpi)
cmap = plt.cm.get_cmap('plasma_r', 5)
sns.heatmap(df.pivot('from', 'to', metric), vmin=0, vmax=1,
linewidths=.5, square=True, cmap=cmap)
plt.title('set: %s, metric: %s' % (split, metric))
plt.tight_layout()
filename = os.path.join(
self.plot_path, "adanet_perf_heatmap_%s_%s.png" % (split, metric))
plt.savefig(filename, dpi=self.dpi)
plt.close('all')
def sign3_coverage_heatmap(self, sign2_coverage):
cov = sign2_coverage.get_h5_dataset('V')
df = pd.DataFrame(columns=['from', 'to', 'coverage'])
for ds_from, ds_to in tqdm(itertools.product(self.datasets, self.datasets)):
idx_from = self.datasets.index(ds_from)
idx_to = self.datasets.index(ds_to)
mask_to = cov[:, idx_to].astype(bool)
tot_to = np.count_nonzero(mask_to)
having_from = np.count_nonzero(cov[mask_to, idx_from])
coverage = having_from / float(tot_to)
df.loc[len(df)] = pd.Series({
'from': ds_from[:2],
'to': ds_to[:2],
'coverage': coverage})
fig, ax = plt.subplots(1, 1, figsize=(6, 5), dpi=self.dpi)
cmap = plt.cm.get_cmap('plasma_r', 5)
sns.heatmap(df.pivot('from', 'to', 'coverage'), vmin=0, vmax=1,
linewidths=.5, square=True, cmap=cmap)
plt.title('Coverage')
plt.tight_layout()
filename = os.path.join(
self.plot_path, "sign3_coverage_heatmap.png")
plt.savefig(filename, dpi=self.dpi)
plt.close('all')
def sign3_coverage_barplot(self, sign2_coverage):
# sns.set_style(*self.style)
cov = sign2_coverage.get_h5_dataset('V')
df = pd.DataFrame(columns=['from', 'to', 'coverage'])
for ds_from, ds_to in tqdm(itertools.product(self.datasets, self.datasets)):
idx_from = self.datasets.index(ds_from)
idx_to = self.datasets.index(ds_to)
mask_to = cov[:, idx_to].astype(bool)
tot_to = np.count_nonzero(mask_to)
having_from = np.count_nonzero(cov[mask_to, idx_from])
coverage = having_from / float(tot_to)
df.loc[len(df)] = pd.Series({
'from': ds_from[:2],
'to': ds_to[:2],
'coverage': coverage})
cross_cov = df.pivot('from', 'to', 'coverage').values
fracs = (cross_cov / np.sum(cross_cov, axis=0)).T
totals = dict()
for ds in self.datasets:
idx = self.datasets.index(ds)
coverage_col = cov[:, idx].astype(bool)
totals[ds] = np.count_nonzero(coverage_col)
columns = ['dataset'] + self.datasets
df2 = pd.DataFrame(columns=columns)
for ds, frac in list(zip(self.datasets, fracs))[::-1]:
row = zip(columns, [ds[:2]] +
(frac * np.log10(totals[ds])).tolist())
df2.loc[len(df2)] = pd.Series(dict(row))
df2.set_index('dataset', inplace=True)
fig, ax = plt.subplots(1, 1, figsize=(3, 10), dpi=self.dpi)
colors = [
'#EA5A49', '#EE7B6D', '#EA5A49', '#EE7B6D', '#EA5A49',
'#C189B9', '#B16BA8', '#C189B9', '#B16BA8', '#C189B9',
'#5A72B5', '#7B8EC4', '#5A72B5', '#7B8EC4', '#5A72B5',
'#96BF55', '#7CAF2A', '#96BF55', '#7CAF2A', '#96BF55',
'#F39426', '#F5A951', '#F39426', '#F5A951', '#F39426']
df2.plot.barh(stacked=True, ax=ax, legend=False, color=colors, lw=0)
sns.despine(left=True, trim=True)
plt.tick_params(left=False)
ax.set_ylabel('')
ax.set_xlabel('Molecule Coverage')
# ax.xaxis.tick_top()
ax.set_xticks([0, 2, 4, 6])
ax.set_xticklabels(
[r'$10^{0}$', r'$10^{2}$', r'$10^{4}$', r'$10^{6}$', ])
ax.tick_params(labelsize=14)
ax.grid(False)
plt.tight_layout()
filename = os.path.join(
self.plot_path, "sign3_coverage_barplot.png")
plt.savefig(filename, dpi=self.dpi)
if self.svg:
plt.savefig(filename.replace('.png', '.svg'), dpi=self.dpi)
plt.close('all')
fig = plt.figure(figsize=(3, 10))
plt.subplots_adjust(left=0.05, right=0.95, bottom=0.02,
top=0.95, hspace=0.1)
gs = fig.add_gridspec(2, 1)
gs.set_height_ratios((50, 1))
gs_ds = gs[0].subgridspec(1, 2, wspace=0.1, hspace=0.0)
ax_cov = fig.add_subplot(gs_ds[0])
ax_xcov = fig.add_subplot(gs_ds[1])
cdf = pd.DataFrame([(x[:2], totals[x]) for x in sorted(totals.keys(), reverse=True)],
columns=['dataset', 'coverage'])
ax_cov.barh(range(25), cdf['coverage'],
color=list(reversed(colors)), lw=0)
for y, name in enumerate(cdf['dataset'].tolist()):
ax_cov.text(2, y - 0.06, name, size=12,
va='center', ha='left',
color='white', fontweight='bold')
ax_cov.set_yticks(range(1, 26))
ax_cov.set_yticklabels([])
# ax_cov.set_xlim(0, 110)
# ax_cov.xaxis.set_label_position('top')
# ax_cov.xaxis.tick_top()
plt.tick_params(left=False)
ax_cov.set_ylabel('')
ax_cov.set_xlabel('Molecules', fontsize=16)
ax_cov.set_xscale('log')
ax_cov.set_xlim(1, 1e6)
ax_cov.set_ylim(-1, 25)
# ax.xaxis.tick_top()
ax_cov.set_xticks([1e1, 1e3, 1e5])
# ax_cov.set_xticklabels(
# [r'$\mathregular{10^{1}}$', r'$\mathregular{10^{3}}$',
# r'$\mathregular{10^{5}}$', ])
ax_cov.grid(False)
# ax_cov.tick_params(labelsize=14)
ax_cov.tick_params(left=False, labelsize=14, pad=0, direction='inout')
sns.despine(ax=ax_cov, left=True, bottom=True, top=False, trim=True)
ax_cov.xaxis.set_label_position('top')
ax_cov.xaxis.tick_top()
ax_xcov.tick_params(left=False, bottom=False, top=True,
labelbottom=False, labeltop=True)
cmap_tmp = plt.get_cmap("magma", 14)
cmap = matplotlib.colors.ListedColormap(
cmap_tmp(np.linspace(0, 1, 14))[4:], 10)
for i in range(25):
ax_xcov.barh(25 - i, 1, left=range(25),
color=[cmap(1 - cross_cov[i, x]) for x in range(25)],
lw=0)
ax_xcov.set_ylim(0, 26)
ax_xcov.set_yticks([])
ax_xcov.set_yticklabels([])
ax_xcov.set_xticks([])
ax_xcov.set_xticklabels([])
ax_xcov.set_yticks(range(1, 26))
for y, color in enumerate(colors):
rect = matplotlib.patches.Rectangle(
(y, 25.75), 1, 0.5, lw=0, edgecolor='w', facecolor=color,
clip_on=False)
ax_xcov.add_patch(rect)
main_colors = ['#EA5A49', '#B16BA8', '#5A72B5', '#7CAF2A', '#F39426']
for y, col in enumerate(main_colors):
ax_xcov.set_yticks(range(1, 26))
rect = matplotlib.patches.Rectangle(
(y * 5, 26), 5, 1, lw=0.1, edgecolor='w', facecolor=col,
clip_on=False)
ax_xcov.add_patch(rect)
rx, ry = rect.get_xy()
cx = rx + rect.get_width() / 2.0
cy = ry + rect.get_height() / 2.0
ax_xcov.text(cx, cy, 'ABCDE'[y], weight='bold',
size=12, ha='center', va='center',
color='white', clip_on=False)
sns.despine(ax=ax_xcov, left=True, bottom=True)
ax_xcov.tick_params(left=False, bottom=False, top=False,
labelbottom=False, labeltop=False)
ax_xcov.grid(False)
ax_cbar = fig.add_subplot(gs[1])
cbar = matplotlib.colorbar.ColorbarBase(
ax_cbar, cmap=cmap, orientation='horizontal',
ticklocation='top')
cbar.ax.set_xlabel('Overlap', fontsize=16)
cbar.ax.tick_params(labelsize=12)
cbar.set_ticks([1, .8, .6, .4, .2, .0])
cbar.set_ticklabels(['0', '', '', '', '', '1'])
poly = Polygon([(0.05, 2.0), (0.05, 1.4), (1, 1.4)],
transform=ax_cbar.get_xaxis_transform(),
clip_on=False,
facecolor='black')
ax_cbar.add_patch(poly)
outfile = os.path.join(self.plot_path, 'sign3_coverage_barplot.png')
plt.savefig(outfile, dpi=self.dpi)
if self.svg:
plt.savefig(outfile.replace('.png', '.svg'), dpi=self.dpi)
plt.close('all')
def cctype_CCA(self, cctype1='sign4', cctype2='sign1', limit=10000):
cca_file = os.path.join(
self.plot_path, "%s_%s_CCA.pkl" % (cctype1, cctype2))
if not os.path.isfile(cca_file):
df = pd.DataFrame(columns=['from', 'to', cctype2, cctype1])
for i in range(len(self.datasets)):
ds_from = self.datasets[i]
s2_from = self.cc.get_signature(
cctype2, 'full', ds_from)
s3_from = self.cc.get_signature(
cctype1, 'full', ds_from)[:limit]
for j in range(i + 1):
ds_to = self.datasets[j]
if ds_to == ds_from:
df.loc[len(df)] = pd.Series({
'from': ds_from[:2],
'to': ds_to[:2],
cctype1: 1.0,
cctype2: 1.0})
continue
s2_to = self.cc.get_signature(
cctype2, 'full', ds_to)
s3_to = self.cc.get_signature(
cctype1, 'full', ds_to)[:limit]
s3_res = dataset_correlation(s3_from, s3_to)
# shared keys
shared_inks = s2_from.unique_keys & s2_to.unique_keys
shared_inks = sorted(list(shared_inks))[:limit]
mask_from = np.isin(list(s2_from.keys), list(shared_inks))
mask_to = np.isin(list(s2_to.keys), list(shared_inks))
ss2_from = s2_from[
:limit * 10][mask_from[:limit * 10]][:limit]
ss2_to = s2_to[:limit * 10][mask_to[:limit * 10]][:limit]
min_size = min(len(ss2_to), len(ss2_from))
print(ds_from, ds_to, 'S2 min_size', min_size)
if min_size < 10:
df.loc[len(df)] = pd.Series({
'from': ds_from[:2],
'to': ds_to[:2],
cctype1: s3_res[0],
cctype2: 0.0})
df.loc[len(df)] = pd.Series({
'from': ds_to[:2],
'to': ds_from[:2],
cctype1: s3_res[3],
cctype2: 0.0})
continue
s2_res = dataset_correlation(
ss2_from[:min_size], ss2_to[:min_size])
df.loc[len(df)] = pd.Series({
'from': ds_from[:2],
'to': ds_to[:2],
cctype1: s3_res[0],
cctype2: s2_res[0]})
df.loc[len(df)] = pd.Series({
'from': ds_to[:2],
'to': ds_from[:2],
cctype1: s3_res[3],
cctype2: s2_res[3]})
print(ds_from, ds_to, 's3 %.2f' %
s3_res[0], 's2 %.2f' % s2_res[0])
df.to_pickle(cca_file)
df = pd.read_pickle(cca_file)
# sns.set_style("ticks")
# sns.set_style({'font.family': 'sans-serif', 'font.serif': ['Arial']})
# CCA heatmaps
for cca_id in [cctype2, cctype1]:
fig, ax = plt.subplots(1, 1, figsize=(6, 5), dpi=self.dpi)
cmap = plt.cm.get_cmap('plasma_r', 5)
sns.heatmap(df.pivot('from', 'to', cca_id), vmin=0, vmax=1,
linewidths=.2, square=True, cmap=cmap, ax=ax)
ax.set_xlabel('')
ax.set_ylabel('')
bottom, top = ax.get_ylim()
ax.set_ylim(bottom + 0.5, top - 0.5)
plt.title('Canonical Correlation Analysis')
plt.tight_layout()
filename = os.path.join(
self.plot_path, "%s_%s_heatmap.png" % (cctype1, cca_id))
plt.savefig(filename, dpi=self.dpi)
plt.close('all')
# combined heatmap
cca3 = df.pivot('from', 'to', cctype1).values
cca3_avg = np.zeros_like(cca3)
for i, j in itertools.product(range(25), range(25)):
cca3_avg[i, j] = np.mean((cca3[i, j], cca3[j, i]))
cca2 = df.pivot('from', 'to', cctype2).values
cca2_avg = np.zeros_like(cca2)
for i, j in itertools.product(range(25), range(25)):
cca2_avg[i, j] = np.mean((cca2[i, j], cca2[j, i]))
combined = np.tril(cca3_avg) + np.triu(cca2_avg)
np.fill_diagonal(combined, np.nan)
fig = plt.figure(figsize=(6, 6))
plt.subplots_adjust(left=0.05, right=0.90, bottom=0.08, top=0.9,
wspace=.02, hspace=.02)
gs = fig.add_gridspec(2, 2)
gs.set_height_ratios((1, 10))
gs.set_width_ratios((20, 1))
ax_dist = fig.add_subplot(gs[0, 0])
cca2_dist = cca2_avg[np.triu_indices_from(cca2_avg, 1)]
cca3_dist = cca3_avg[np.tril_indices_from(cca3_avg, -1)]
sns.kdeplot(cca2_dist, ax=ax_dist, shade=True, bw=.2, color='blue',
clip=(min(cca2_dist), max(cca2_dist)))
sns.kdeplot(cca3_dist, ax=ax_dist, shade=True, bw=.2, color='red',
clip=(min(cca3_dist), max(cca3_dist)))
ax_dist.text(0.05, 0.35, "Observed", transform=ax_dist.transAxes,
name='Arial', size=14, color='blue')
ax_dist.text(0.05, 0.8, "Predicted", transform=ax_dist.transAxes,
name='Arial', size=14, color='red')
fig.text(0.5, 1.7, 'Canonical Correlation Analysis',
ha='center', va='center', transform=ax_dist.transAxes,
name='Arial', size=16)
# sns.despine(ax=ax_dist, left=True, bottom=True)
ax_dist.set_axis_off()
ax_hm = fig.add_subplot(gs[1, 0])
ax_cbar = fig.add_subplot(gs[1, 1])
cmap = plt.cm.get_cmap('plasma_r', 5)
sns.heatmap(combined, vmin=0, vmax=1,
linewidths=.2, square=True, cmap=cmap,
ax=ax_hm, cbar_ax=ax_cbar)
ax_hm.set_xlabel('')
ax_hm.set_ylabel('')
ax_hm.set_xticklabels([x[:2] for x in self.datasets], rotation=90)
ax_hm.set_yticklabels([x[:2] for x in self.datasets], rotation=0)
bottom, top = ax_hm.get_ylim()
ax_hm.set_ylim(bottom + 0.5, top - 0.5)
filename = os.path.join(
self.plot_path, "%s_%s_CCA_heatmap.png" % (cctype1, cctype2))
plt.savefig(filename, dpi=self.dpi)
plt.close('all')
# scatter plot comparing sign1 and sign4
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=self.dpi)
sns.scatterplot(x=cctype2, y=cctype1, data=df, ax=ax)
filename = os.path.join(
self.plot_path, "%s_%s_CCA_scatter.png" % (cctype1, cctype2))
plt.savefig(filename, dpi=self.dpi)
plt.close('all')
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=self.dpi)
for i in range(len(self.datasets)):
for j in range(len(self.datasets)):
dsx = self.datasets[i][:2]
dsy = self.datasets[j][:2]
if dsx == dsy:
continue
c1 = self.cc_colors(dsx)
c2 = self.cc_colors(dsy)
marker_style = dict(color=c1, markerfacecoloralt=c2,
markeredgecolor='w', markeredgewidth=0,
markersize=8, marker='o')
x = df[(df['from'] == dsx) & (df['to'] == dsy)][
cctype2].tolist()[0]
y = df[(df['from'] == dsx) & (df['to'] == dsy)][
cctype1].tolist()[0]
ax.plot(x, y, alpha=0.9, fillstyle='right', **marker_style)
ax.set_xlabel('%s CCA' % cctype2,
fontdict=dict(name='Arial', size=16))
ax.set_ylabel('%s CCA' % cctype1,
fontdict=dict(name='Arial', size=16))
ax.tick_params(labelsize=14)
sns.despine(ax=ax, trim=True)
filename = os.path.join(
self.plot_path, "%s_%s_CCA_comparison.png" % (cctype1, cctype2))
plt.savefig(filename, dpi=self.dpi)
plt.close('all')
# also plot as MDS projection
cca = df.pivot('from', 'to', cctype1).values
# make average of upper and lower triangular matrix
cca_avg = np.zeros_like(cca)
for i, j in itertools.product(range(25), range(25)):
cca_avg[i, j] = np.mean((cca[i, j], cca[j, i]))
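        # MDS expects dissimilarities, so the averaged CCA correlations
        # are converted with 1 - cca_avg before embedding.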
proj = MDS(dissimilarity='precomputed', random_state=0)
coords = proj.fit_transform(1 - cca_avg)
def anno_dist(ax, idx1, idx2, coords):
p1 = coords[idx1]
p2 = coords[idx2]
            coords_dist = np.linalg.norm(p1 - p2)
            dist = 1 - cca_avg[idx1, idx2]
angle = math.degrees(math.atan2(p1[1] - p2[1], p1[0] - p2[0]))
if angle < 0:
angle += 5
else:
angle -= 5
ax.annotate(
'', xy=p1, xycoords='data',
xytext=p2, textcoords='data',
zorder=1,
arrowprops=dict(
arrowstyle="-", color="0.2",
shrinkA=10, shrinkB=10,
patchA=None, patchB=None,
connectionstyle="bar,angle=%.2f,fraction=-%.2f" % (
angle, (1 - np.power(coords_dist, 1 / 3))),
),)
# ax.plot([p1[0],p2[0]], [p1[1],p2[1]])
midpoint = (p1[0] + p2[0]) / 2, (p1[1] + p2[1]) / 2
ax.annotate('%.2f' % dist, xy=midpoint,
xytext=midpoint, textcoords='data', xycoords='data',
rotation=angle, va='center', ha='center')
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=self.dpi)
colors = [self.cc_colors(ds) for ds in self.datasets]
markers = ["$\u2776$", "$\u2777$", "$\u2778$", "$\u2779$", "$\u277a$"]
for i, color in enumerate(colors):
            ax.scatter(coords[i, 0], coords[i, 1], c=color, s=400,
                       edgecolor='none', marker=markers[i % 5], zorder=2)
anno_dist(ax, 19, 15, coords)
# anno_dist(ax, 20, 22,coords)
anno_dist(ax, 1, 2, coords)
anno_dist(ax, 9, 20, coords)
filename = os.path.join(
self.plot_path, "%s_CCA_MDS.svg" % cctype1)
plt.axis('off')
plt.savefig(filename, dpi=self.dpi)
plt.close('all')
def all_sign_validations(self, sign_types=None, molsets=None, valset='moa'):
if sign_types is None:
sign_types = ['sign1', 'sign2', 'sign3']
if molsets is None:
molsets = ['reference', 'full']
#['atc_auc', 'atc_cov', 'atc_ks_d', 'atc_ks_p',
# 'moa_auc', 'moa_cov', 'moa_ks_d', 'moa_ks_p']
df = pd.DataFrame(
columns=['sign_molset', 'dataset', 'metric', 'value'])
for ds in self.datasets:
for molset in molsets:
for sign_type in sign_types:
try:
sign = self.cc.get_signature(sign_type, molset, ds)
                    except Exception as err:
                        self.__log.warning(
                            "Skipping %s %s %s: %s",
                            sign_type, molset, ds, str(err))
continue
stat_file = os.path.join(
sign.stats_path, 'validation_stats.json')
if not os.path.isfile(stat_file):
continue
stats = json.load(open(stat_file, 'r'))
for k, v in stats.items():
row = {
'sign_molset': '_'.join([sign_type, molset]),
'dataset': ds,
'metric': k,
'value': float(v),
}
if 'cov' in k:
row['value'] /= 100.
df.loc[len(df)] = pd.Series(row)
print(df)
# sns.set_style("whitegrid")
fig, axes = plt.subplots(5, 5, sharey=True, sharex=True,
figsize=(15, 15), dpi=self.dpi)
for ds, ax in tqdm(zip(self.datasets, axes.flatten())):
ds_color = self.cc_palette([ds])[0]
sns.barplot(x='sign_molset', y='value',
data=df[(df.dataset == ds) & (
df.metric == '%s_auc' % valset)],
ax=ax, alpha=1,
color=ds_color)
sns.stripplot(x='sign_molset', y='value',
data=df[(df.dataset == ds) & (
df.metric == '%s_cov' % valset)],
size=10, marker="o", edgecolor='k', linewidth=2,
ax=ax, jitter=False, alpha=1, color='w')
ax.set_xlabel('')
ax.set_ylabel('')
# ax.set_xticklabels([ds])
for idx, p in enumerate(ax.patches):
if "%.2f" % p.get_height() == 'nan':
continue
val = p.get_height()
if val > 1.0:
val = "%.1f" % p.get_height()
else:
val = "%.2f" % p.get_height()
ax.annotate(val,
(p.get_x() + p.get_width() / 2., 0),
ha='center', va='center', fontsize=11,
color='k', rotation=90, xytext=(0, 20),
textcoords='offset points')
if ds.startswith('E'):
for label in ax.get_xticklabels():
label.set_ha("right")
label.set_rotation(45)
ax.set_ylim(0, 1)
ax.grid(axis='y', linestyle="-",
color=ds_color, lw=0.3)
ax.spines["bottom"].set_color(ds_color)
ax.spines["top"].set_color(ds_color)
ax.spines["right"].set_color(ds_color)
ax.spines["left"].set_color(ds_color)
plt.tight_layout()
filename = os.path.join(self.plot_path,
"sign_validation_%s_%s_%s.png"
% (valset, '_'.join(sign_types),
'_'.join(molsets)))
plt.savefig(filename, dpi=self.dpi)
plt.close('all')
"""
for metric in metrics:
# sns.set_style("whitegrid")
fig, axes = plt.subplots(5, 5, sharey=True, sharex=False,
figsize=(10, 10), dpi=self.dpi)
for ds, ax in tqdm(zip(self.cc.datasets, axes.flatten())):
cdf = df[df.dataset == ds][metric]
s1v = cdf[df.sign_type == 'sign1'].iloc[0]
s2v = cdf[df.sign_type == 'sign2'].iloc[0]
s3v = cdf[df.sign_type == 'sign3'].iloc[0]
sns.barplot(x=['s2-s1', 's3-s1'], y=[100 * (s2v - s1v) / s1v, 100 * (s3v - s1v) / s1v],
ax=ax, alpha=.8, color=self.cc_palette([ds])[0])
ax.set_xlabel('')
ax.set_ylabel('')
# ax.set_xticklabels([ds])
for idx, p in enumerate(ax.patches):
if "%.2f" % p.get_height() == 'nan':
continue
val = p.get_height()
if val > 1.0:
val = "%.1f" % p.get_height()
else:
val = "%.2f" % p.get_height()
ax.annotate(val,
(p.get_x() + p.get_width() / 2., 0),
ha='center', va='center', fontsize=11,
color='k', rotation=90, xytext=(0, 20),
textcoords='offset points')
ax.grid(axis='y', linestyle="-",
color=self.cc_palette([ds])[0], lw=0.3)
ax.spines["bottom"].set_color(self.cc_palette([ds])[0])
ax.spines["top"].set_color(self.cc_palette([ds])[0])
ax.spines["right"].set_color(self.cc_palette([ds])[0])
ax.spines["left"].set_color(self.cc_palette([ds])[0])
plt.tight_layout()
filename = os.path.join(self.plot_path,
"sign_validation_deltas_%s.png" % metric)
plt.savefig(filename, dpi=self.dpi)
plt.close('all')
"""
def cctype_similarity_search(self, cctype='sign4', cctype_ref='sign1',
limit=10000, limit_neig=50000, sign_cap=200000):
from chemicalchecker.core.signature_data import DataSignature
import faiss
def mask_exclude(idxs, x1_data):
x1_data_transf = np.copy(x1_data)
for idx in idxs:
# set current space to nan
col_slice = slice(idx * 128, (idx + 1) * 128)
x1_data_transf[:, col_slice] = np.nan
            # return the matrix with the selected spaces masked out
            return x1_data_transf
def mask_keep(idxs, x1_data):
# we will fill an array of NaN with values we want to keep
            x1_data_transf = np.full_like(x1_data, np.nan, dtype=np.float32)
for idx in idxs:
# copy column from original data
col_slice = slice(idx * 128, (idx + 1) * 128)
x1_data_transf[:, col_slice] = x1_data[:, col_slice]
            # return the matrix where only the selected spaces are kept
            return x1_data_transf
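        # Illustrative example: signatures are stacked as 128-column
        # blocks per space, so mask_keep([1], X) returns X with NaN
        # everywhere except columns 128:256, while mask_exclude([1], X)
        # sets exactly those columns to NaN.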
def background_distances(matrix, metric, sample_pairs=100000, unflat=True,
memory_safe=False):
PVALRANGES = np.array([0, 0.001, 0.01, 0.1] +
list(np.arange(1, 100)) + [100]) / 100.
            # resolve the metric name to one of the distance functions
            # imported from scipy.spatial.distance
            metric_fn = {'cosine': cosine, 'euclidean': euclidean}[metric]
if matrix.shape[0]**2 < sample_pairs:
print("Requested more pairs then possible combinations")
sample_pairs = matrix.shape[0]**2 - matrix.shape[0]
bg = list()
done = set()
tries = 1e6
tr = 0
while len(bg) < sample_pairs and tr < tries:
tr += 1
i = np.random.randint(0, matrix.shape[0] - 1)
j = np.random.randint(i + 1, matrix.shape[0])
if (i, j) not in done:
dist = metric_fn(matrix[i], matrix[j])
bg.append(dist)
done.add((i, j))
            # p-values as percentiles
i = 0
PVALS = [(0, 0., i)] # DISTANCE, RANK, INTEGER
i += 1
percs = PVALRANGES[1:-1] * 100
for perc in percs:
PVALS += [(np.percentile(bg, perc), perc / 100., i)]
i += 1
PVALS += [(np.max(bg), 1., i)]
# prepare returned dictionary
bg_distances = dict()
if not unflat:
bg_distances["distance"] = np.array([p[0] for p in PVALS])
bg_distances["pvalue"] = np.array([p[1] for p in PVALS])
else:
# Remove flat regions whenever we observe them
dists = [p[0] for p in PVALS]
pvals = np.array([p[1] for p in PVALS])
top_pval = np.min(
[1. / sample_pairs, np.min(pvals[pvals > 0]) / 10.])
pvals[pvals == 0] = top_pval
pvals = np.log10(pvals)
dists_ = sorted(set(dists))
pvals_ = [pvals[dists.index(d)] for d in dists_]
dists = np.interp(pvals, pvals_, dists_)
thrs = [(dists[t], PVALS[t][1], PVALS[t][2])
for t in range(len(PVALS))]
bg_distances["distance"] = np.array([p[0] for p in thrs])
bg_distances["pvalue"] = np.array([p[1] for p in thrs])
return bg_distances
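        # e.g. bg = background_distances(mat, 'euclidean') returns arrays
        # where bg['distance'][i] is the distance at the percentile given
        # by bg['pvalue'][i] (from 0.0 up to 1.0).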
def jaccard_similarity(n1, n2):
"""Compute Jaccard similarity."""
s1 = set(n1)
s2 = set(n2)
inter = len(set.intersection(s1, s2))
uni = len(set.union(s1, s2))
return inter / float(uni)
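        # e.g. jaccard_similarity([1, 2], [2, 3]) == 1 / 3.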
def overlap(n1, n2):
"""Compute Overlap."""
            s1 = set(n1)
            s2 = set(n2)
            inter = len(set.intersection(s1, s2))
            return float(inter) / len(s1)
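        # e.g. overlap([1, 2], [2, 3]) == 0.5, i.e. the fraction of n1
        # that is also in n2 (asymmetric, unlike Jaccard).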
outfile = os.path.join(
self.plot_path, '%s_simsearch.pkl' % cctype)
if not os.path.isfile(outfile):
df = pd.DataFrame(
columns=['dataset', 'nthr', 'cthr', 'dthr',
'log-odds-ratio', 'logodds_err'])
# [(5,'B1.001'),(15,'D1.001')]:
for ds_idx, ds in list(enumerate(self.datasets)):
sign = self.cc.get_signature(cctype, 'full', ds)
signref = self.cc.get_signature(cctype_ref, 'full', ds)
# get siamese train/test inks
traintest_file = os.path.join(
sign.model_path, 'traintest_eval.h5')
tt = DataSignature(traintest_file)
train_inks = tt.get_h5_dataset('keys_train')[:limit_neig]
train_inks = np.sort(train_inks)
train_mask = np.isin(
list(signref.keys), list(train_inks), assume_unique=True)
test_inks = tt.get_h5_dataset('keys_test')[:limit]
test_inks = np.sort(test_inks)
test_mask = np.isin(
list(signref.keys), list(test_inks), assume_unique=True)
# get train/test sign1
slice_cap = slice(0, sign_cap)
signref_V = signref.get_h5_dataset('V', mask=slice_cap)
train_signref = signref_V[train_mask[slice_cap]]
test_signref = signref_V[test_mask[slice_cap]]
# predict train/test sign4
print('REFERENCE SIGNATURE:', cctype_ref)
print('train_signref', train_signref.shape)
print('test_signref', test_signref.shape)
# predict train/test sign4
input_file = DataSignature(
os.path.join(sign.model_path, 'train.h5'))
input_x = input_file.get_h5_dataset('x', mask=slice_cap)
train_input = input_x[train_mask[slice_cap]]
test_input = input_x[test_mask[slice_cap]]
print('PREDICTION INPUT:', traintest_file)
print('train_input', train_input.shape)
print('test_input', test_input.shape)
                # load the eval siamese predictor
predict_fn = sign.get_predict_fn(
smiles=False, model='siamese_eval')
train_sign = predict_fn(train_input)
test_sign = predict_fn(mask_exclude([ds_idx], test_input))
print('PREDICTION OUTPUT:', traintest_file)
print('train_sign', train_sign.shape)
print('test_sign', test_sign.shape)
# get confidence for test sign4
conf_mask = np.isin(
list(sign.keys), list(test_inks), assume_unique=True)
test_confidence = sign.get_h5_dataset(
'confidence')[conf_mask][:len(test_sign)]
print('test_confidence', test_confidence.shape)
# make train sign1 neig
train_signref_neig = faiss.IndexFlatL2(train_signref.shape[1])
train_signref_neig.add(train_signref.astype(np.float32))
# make train sign4 neig
train_sign_neig = faiss.IndexFlatL2(train_sign.shape[1])
train_sign_neig.add(train_sign.astype(np.float32))
# find test sign1 neighbors
signref_neig_dist, signref_neig_idx = train_signref_neig.search(
test_signref.astype(np.float32), 100)
signref_neig_dist = np.sqrt(signref_neig_dist)
# find test sign4 neighbors
sign_neig_dist, sign_neig_idx = train_sign_neig.search(
test_sign.astype(np.float32), 100)
sign_neig_dist = np.sqrt(sign_neig_dist)
# check various thresholds
# get sign ref background distances thresholds
back = background_distances(train_signref, 'euclidean')
dthrs = list()
dthrs.append((back['distance'][1], back['pvalue'][1]))
dthrs.append((back['distance'][5], back['pvalue'][5]))
dthrs.append((back['distance'][-1], back['pvalue'][-1]))
nthrs = [10] # top neighbors
cthrs = [0, .5, .8] # confidence
for nthr, cthr, dthr in itertools.product(nthrs, cthrs, dthrs):
hits = 0
rnd_hits = collections.defaultdict(int)
# limit original-space neighbors by distance
d_mask = signref_neig_dist < dthr[0]
# limit signature molecules by confidence
c_mask = test_confidence > cthr
for row in tqdm(range(signref_neig_dist.shape[0])):
if not c_mask[row]:
continue
# select top n valid neighbors
ref_neig = signref_neig_idx[row][d_mask[row]][:nthr]
# if no neighbors we skip the molecule
if len(ref_neig) == 0:
continue
# compare to sign neighbors
sign_neig = sign_neig_idx[row][:nthr]
hits += len(set(ref_neig).intersection(sign_neig))
# compute random background
rnd_idxs = np.arange(train_signref_neig.ntotal)
for i in range(1000):
rnd_neig = np.random.choice(rnd_idxs, nthr)
rnd_hits[i] += len(set(ref_neig).intersection(rnd_neig))
rnd_hits = [v for k, v in rnd_hits.items()]
rnd_mean = np.mean(rnd_hits)
rnd_std = np.std(rnd_hits)
logodds = np.log2(hits / rnd_mean)
logodds_std = np.log2(hits / (rnd_mean + rnd_std))
logodds_err = abs(logodds - logodds_std)
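# e.g. hits=40 against rnd_mean=10 yields log2(4) = 2, i.e. a four-fold
# enrichment over the random background (illustrative numbers)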
print(nthr, cthr, dthr, 'log-odds-ratio', logodds)
df.loc[len(df)] = pd.Series({
'dataset': ds,
'nthr': nthr,
'cthr': cthr,
'dthr': dthr[1],
'log-odds-ratio': logodds,
'logodds_err': logodds_err,
})
df.to_pickle(outfile)
df = pd.read_pickle(outfile)
nthrs = df['nthr'].unique()
dthrs = df['dthr'].unique()
max_odds = df['log-odds-ratio'].describe()['75%'] * 1.5
min_odds = 0
for nthr, dthr in itertools.product(nthrs, dthrs):
fdf = df[(df.nthr == nthr) & (df.dthr == dthr)]
if len(fdf) == 0:
continue
fig = plt.figure(constrained_layout=True, figsize=(5, 10))
gs = fig.add_gridspec(5, 5, wspace=0.1, hspace=0.1)
plt.subplots_adjust(left=0.14, right=.92, bottom=0.06, top=.95)
axes = list()
for row, col in itertools.product(range(5), range(5)):
axes.append(fig.add_subplot(gs[row, col]))
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor='none', top=False,
bottom=False, left=False, right=False)
plt.grid(False)
plt.xlabel("Confidence", size=18)
plt.ylabel("Log-Odds Ratio", size=18)
for ds, ax in zip(self.datasets[:], axes):
dsdf = fdf[fdf.dataset == ds]
if len(dsdf) == 0:
continue
ax.errorbar([1, 2, 3], dsdf['log-odds-ratio'],
yerr=dsdf['logodds_err'], fmt='-',
color=self.cc_colors(ds, 0),
ecolor=self.cc_colors(ds, 1),
elinewidth=2, capsize=0)
ax.set_ylabel('')
ax.set_xlabel('')
ax.set_ylim(min_odds, max_odds)
ax.set_xlim(0.5, 3.5)
x = [str(i) for i in dsdf['cthr'].unique()]
ax.xaxis.set_ticklabels(x)
# axis ticks
if ds[:2] == 'E1':
# set the alignment for outer ticklabels
ticklabels = ax.get_yticklabels()
ticklabels[0].set_va("bottom")
ticklabels[-1].set_va("top")
elif ds[1] == '1':
# set the alignment for outer ticklabels
ticklabels = ax.get_yticklabels()
ticklabels[0].set_va("bottom")
ticklabels[-1].set_va("top")
ax.xaxis.set_ticklabels([])
elif ds[0] == 'E':
ax.yaxis.set_ticklabels([])
else:
ax.xaxis.set_ticklabels([])
ax.yaxis.set_ticklabels([])
# axis labels
if ds[0] == 'A':
ax.set_xlabel(ds[1], fontsize=18, labelpad=15)
ax.xaxis.set_label_position('top')
if ds[1] == '5':
ax.set_ylabel(ds[0], fontsize=18, rotation=0, va='center',
labelpad=15)
ax.yaxis.set_label_position('right')
outfile = os.path.join(
self.plot_path, 'simsearch_%s_%s.png' % (nthr, dthr))
print(outfile)
plt.savefig(outfile, dpi=self.dpi)
plt.close('all')
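# The diagnosis_* methods below share the same 5x5 grid layout, one panel
# per CC space. A minimal usage sketch (assumes an initialized
# ChemicalChecker instance `cc`; the destination folder is illustrative):
#   mp = MultiPlot(cc, '/tmp/plots')
#   mp.diagnosis_projections('sign3')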
def diagnosis_projections(self, cctype):
fig = plt.figure(constrained_layout=True, figsize=(8, 8))
gs = fig.add_gridspec(5, 5, wspace=0.1, hspace=0.1)
plt.subplots_adjust(left=0.08, right=.95, bottom=0.08, top=.95)
axes = list()
for row, col in itertools.product(range(5), range(5)):
axes.append(fig.add_subplot(gs[row, col]))
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor='none', top=False,
bottom=False, left=False, right=False)
plt.grid(False)
plt.xlabel("Dim 1", size=18)
plt.ylabel("Dim 2", size=18)
for ds, ax in zip(self.datasets[:], axes):
sign = self.cc.get_signature(cctype, 'full', ds)
diag = DiagnosisPlot(self.cc, sign)
diag.projection(ax=ax)
ax.set_xlabel("")
ax.set_ylabel("")
ax.set_title('')
# axis ticks
if ds[:2] == 'E1':
# set the alignment for outer ticklabels
ticklabels = ax.get_yticklabels()
ticklabels[0].set_va("bottom")
ticklabels[-1].set_va("top")
elif ds[1] == '1':
# set the alignment for outer ticklabels
ticklabels = ax.get_yticklabels()
ticklabels[0].set_va("bottom")
ticklabels[-1].set_va("top")
ax.xaxis.set_ticklabels([])
elif ds[0] == 'E':
ax.yaxis.set_ticklabels([])
else:
ax.xaxis.set_ticklabels([])
ax.yaxis.set_ticklabels([])
# axis labels
if ds[0] == 'A':
ax.set_xlabel(ds[1], fontsize=18, labelpad=15)
ax.xaxis.set_label_position('top')
if ds[1] == '5':
ax.set_ylabel(ds[0], fontsize=18, rotation=0, va='center',
labelpad=15)
ax.yaxis.set_label_position('right')
outfile = os.path.join(
self.plot_path, 'diagnosis_projections.png')
print(outfile)
plt.savefig(outfile, dpi=self.dpi)
plt.close('all')
def diagnosis_confidences_projection(self, cctype):
fig = plt.figure(constrained_layout=True, figsize=(7, 7))
gs = fig.add_gridspec(5, 5, wspace=0.1, hspace=0.1)
plt.subplots_adjust(left=0.08, right=.95, bottom=0.08, top=.95)
axes = list()
for row, col in itertools.product(range(5), range(5)):
axes.append(fig.add_subplot(gs[row, col]))
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor='none', top=False,
bottom=False, left=False, right=False)
plt.grid(False)
plt.xlabel("Dim 1", size=18)
plt.ylabel("Dim 2", size=18)
for ds, ax in zip(self.datasets[:], axes):
sign = self.cc.get_signature(cctype, 'full', ds)
diag = DiagnosisPlot(self.cc, sign)
diag.confidences_projection(ax=ax)
ax.set_xlabel("")
ax.set_ylabel("")
ax.set_title('')
# axis ticks
if ds[:2] == 'E1':
# set the alignment for outer ticklabels
ticklabels = ax.get_yticklabels()
ticklabels[0].set_va("bottom")
ticklabels[-1].set_va("top")
elif ds[1] == '1':
# set the alignment for outer ticklabels
ticklabels = ax.get_yticklabels()
ticklabels[0].set_va("bottom")
ticklabels[-1].set_va("top")
ax.xaxis.set_ticklabels([])
elif ds[0] == 'E':
ax.yaxis.set_ticklabels([])
else:
ax.xaxis.set_ticklabels([])
ax.yaxis.set_ticklabels([])
# axis labels
if ds[0] == 'A':
ax.set_xlabel(ds[1], fontsize=18, labelpad=15)
ax.xaxis.set_label_position('top')
if ds[1] == '5':
ax.set_ylabel(ds[0], fontsize=18, rotation=0, va='center',
labelpad=15)
ax.yaxis.set_label_position('right')
outfile = os.path.join(
self.plot_path, 'diagnosis_confidences_projection.png')
print(outfile)
plt.savefig(outfile, dpi=self.dpi)
plt.close('all')
def diagnosis_euclidean_distances(self, cctype):
fig = plt.figure(constrained_layout=True, figsize=(10, 10))
gs = fig.add_gridspec(5, 5, wspace=0.1, hspace=0.1)
plt.subplots_adjust(left=0.08, right=.95, bottom=0.08, top=.95)
axes = list()
for row, col in itertools.product(range(5), range(5)):
axes.append(fig.add_subplot(gs[row, col]))
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor='none', top=False,
bottom=False, left=False, right=False)
plt.grid(False)
plt.xlabel("Euclidean Distance", size=18)
plt.ylabel("Density", size=18)
for ds, ax in zip(self.datasets[:], axes):
sign = self.cc.get_signature(cctype, 'full', ds)
diag = DiagnosisPlot(self.cc, sign)
diag.euclidean_distances(ax=ax)
ax.set_xlabel("")
ax.set_ylabel("")
ax.set_title('')
ax.set_xlim(0, 2)
ax.set_ylim(0, 5)
# axis ticks
if ds[:2] == 'E1':
# set the alignment for outer ticklabels
ticklabels = ax.get_yticklabels()
ticklabels[0].set_va("bottom")
ticklabels[-1].set_va("top")
elif ds[1] == '1':
# set the alignment for outer ticklabels
ticklabels = ax.get_yticklabels()
ticklabels[0].set_va("bottom")
ticklabels[-1].set_va("top")
ax.xaxis.set_ticklabels([])
elif ds[0] == 'E':
ax.yaxis.set_ticklabels([])
else:
ax.xaxis.set_ticklabels([])
ax.yaxis.set_ticklabels([])
# axis labels
if ds[0] == 'A':
ax.set_xlabel(ds[1], fontsize=18, labelpad=15)
ax.xaxis.set_label_position('top')
if ds[1] == '5':
ax.set_ylabel(ds[0], fontsize=18, rotation=0, va='center',
labelpad=15)
ax.yaxis.set_label_position('right')
outfile = os.path.join(
self.plot_path, 'diagnosis_euclidean_distances.png')
print(outfile)
plt.savefig(outfile, dpi=self.dpi)
plt.close('all')
def sign3_neig2_jaccard(self, limit=2000):
df = pd.DataFrame(columns=['dataset', 'confidence', 'jaccard'])
for ds in self.datasets[:]:
s2 = self.cc.get_signature('sign2', 'full', ds)
n2 = self.cc.get_signature('neig2', 'reference', ds)
s3 = self.cc.get_signature('sign3', 'full', ds)
# decide sample molecules
inks = s2.keys
s3_conf = s3.get_h5_dataset('confidence')
s3_intensity = s3.get_h5_dataset('intensity_norm')
s2_mask = np.isin(list(s3.keys), list(s2.keys), assume_unique=True)
s2_conf = s3_conf[s2_mask]
s2_intensity = s3_intensity[s2_mask]
high_conf = s2_conf > .9
inks_high = inks[high_conf][:limit]
# get sign2 and sign3
_, s2_data = s2.get_vectors(inks_high)
_, s3_data = s3.get_vectors(inks_high)
_, s3_data_conf = s3.get_vectors(
inks_high, dataset_name='confidence')
s3_data_conf = s3_data_conf.flatten()
# get idxs of nearest neighbors of s2 and s3
k = 10
n2_s2 = n2.get_kth_nearest(
list(s2_data), k=k, distances=False, keys=False)
n2_s3 = n2.get_kth_nearest(
list(s3_data), k=k, distances=False, keys=False)
jacc = n2.jaccard_similarity(n2_s2['indices'], n2_s3['indices'])
df = df.append(pd.DataFrame(
{'dataset': ds, 'confidence': 'high', 'jaccard': jacc}), ignore_index=True)
print('***** HIGH', len(jacc), np.mean(jacc),
stats.spearmanr(jacc, s3_data_conf))
all_conf = np.ones_like(s2_conf).astype(bool)
# match the sample size of the high-confidence subset
new_limit = min(len(jacc), limit)
inks_all = inks[all_conf][:new_limit]
# get sign2 and sign3
_, s2_data = s2.get_vectors(inks_all)
_, s3_data = s3.get_vectors(inks_all)
_, s3_data_conf = s3.get_vectors(
inks_all, dataset_name='confidence')
s3_data_conf = s3_data_conf.flatten()
# get idxs of nearest neighbors of s2 and s3
k = 10
n2_s2 = n2.get_kth_nearest(
list(s2_data), k=k, distances=False, keys=False)
n2_s3 = n2.get_kth_nearest(
list(s3_data), k=k, distances=False, keys=False)
jacc = n2.jaccard_similarity(n2_s2['indices'], n2_s3['indices'])
df = df.append(pd.DataFrame(
{'dataset': ds, 'confidence': 'all', 'jaccard': jacc}), ignore_index=True)
print('***** ALL', len(jacc), np.mean(jacc),
stats.spearmanr(jacc, s3_data_conf))
# sns.set_style("whitegrid")
f, axes = plt.subplots(5, 5, figsize=(4, 6), sharex=True, sharey='row')
plt.subplots_adjust(left=0.16, right=0.99, bottom=0.12, top=0.99,
wspace=.08, hspace=.1)
for ds, ax in zip(self.datasets[:], axes.flat):
sns.barplot(data=df[df.dataset == ds], y='jaccard', x='confidence',
order=['all', 'high'],
ax=ax,
palette=[self.cc_colors(ds, 2), self.cc_colors(ds, 0)])
ax.set_ylabel('')
ax.set_xlabel('')
ax.set_ylim(0, 1)
if ds[:2] == 'E1':
sns.despine(ax=ax, offset=3, trim=True)
ax.set_yticks([0, 1])
ax.set_yticklabels(['0', '1'])
elif ds[1] == '1':
sns.despine(ax=ax, bottom=True, offset=3, trim=True)
ax.tick_params(bottom=False)
ax.set_yticks([0, 1])
ax.set_yticklabels(['0', '1'])
elif ds[0] == 'E':
sns.despine(ax=ax, left=True, offset=3, trim=True)
ax.tick_params(left=False)
ax.set_xticks([0, 1])
ax.set_xticklabels(['All', 'High'])
else:
sns.despine(ax=ax, bottom=True, left=True, offset=3, trim=True)
ax.tick_params(bottom=False, left=False)
f.text(0.5, 0.04, 'Confidence', ha='center', va='center')
f.text(0.06, 0.5, 'Jaccard Similarity', ha='center',
va='center', rotation='vertical')
outfile = os.path.join(
self.plot_path, 'sign3_neig2_jaccard.png')
plt.savefig(outfile, dpi=self.dpi)
plt.close('all')
def sign_property_distribution(self, cctype, molset, prop, xlim=None,
ylim=None, known_delta=True):
# sns.set_style("whitegrid")
# sns.set_style({'font.family': 'sans-serif', 'font.serif': ['Arial']})
fig = plt.figure(constrained_layout=True, figsize=(6, 6))
gs = fig.add_gridspec(5, 5, wspace=0.1, hspace=0.1)
plt.subplots_adjust(left=0.1, right=.95, bottom=0.1, top=.95)
axes = list()
for row, col in itertools.product(range(5), range(5)):
axes.append(fig.add_subplot(gs[row, col]))
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor='none', top=False,
bottom=False, left=False, right=False)
plt.grid(False)
plt.xlabel(prop.capitalize(), size=18)
plt.ylabel("Molecules", size=18)
for ds, ax in zip(self.datasets, axes):
try:
sign = self.cc.get_signature(cctype, molset, ds)
except Exception:
continue
if not os.path.isfile(sign.data_path):
continue
# decide sample molecules
prop_data = sign.get_h5_dataset(prop)
if known_delta:
known_mask = sign.get_h5_dataset('known')
plot_data = [prop_data[known_mask], prop_data[~known_mask]]
colors = [self.cc_colors(ds, 0), self.cc_colors(ds, 2)]
print(ds, prop, np.min(plot_data[0]), np.max(plot_data[0]))
print(ds, prop, np.min(plot_data[1]), np.max(plot_data[1]))
ax.hist(prop_data, color=self.cc_colors(ds, 2),
histtype='step', fill=True,
density=False, bins=10, log=True,
range=xlim, alpha=.9, stacked=True)
ax.hist(plot_data[0], color=self.cc_colors(ds, 0),
histtype='step', fill=True,
density=False, bins=10, log=True,
range=xlim, alpha=.9, stacked=True)
else:
plot_data = [prop_data]
colors = [self.cc_colors(ds, 0)]
print(ds, prop, np.min(plot_data[0]), np.max(plot_data[0]))
ax.hist(plot_data, color=colors,
histtype='step',
density=False, bins=20, log=True,
range=xlim, alpha=.9, stacked=True)
# set limits and ticks (only when an explicit x-range is given)
if xlim:
ax.set_xlim(xlim)
xmin, xmax = xlim
ax.set_xticks(np.linspace(xmin, xmax, 5))
ticks_str = ['%.1f' % x for x in np.linspace(xmin, xmax, 5)]
ticks_str[0] = '%i' % xmin
ticks_str[-1] = '%i' % xmax
ax.set_xticklabels(ticks_str)
if ylim:
ax.set_ylim(ylim)
# ax.set_yscale('log')
ax.set_yticks([1e2, 1e4, 1e6])
# axis ticks
if ds[:2] == 'E1':
# set the alignment for outer ticklabels
ticklabels = ax.get_yticklabels()
ticklabels[0].set_va("bottom")
ticklabels[-1].set_va("top")
elif ds[1] == '1':
# set the alignment for outer ticklabels
ticklabels = ax.get_yticklabels()
ticklabels[0].set_va("bottom")
ticklabels[-1].set_va("top")
ax.xaxis.set_ticklabels([])
elif ds[0] == 'E':
ax.yaxis.set_ticklabels([])
else:
ax.xaxis.set_ticklabels([])
ax.yaxis.set_ticklabels([])
# axis labels
if ds[0] == 'A':
ax.set_xlabel(ds[1], fontsize=16, labelpad=8)
ax.xaxis.set_label_position('top')
if ds[1] == '5':
ax.set_ylabel(ds[0], fontsize=16, rotation=0, va='center',
labelpad=12)
ax.yaxis.set_label_position('right')
# plt.minorticks_off()
outfile = os.path.join(
self.plot_path, '%s.png' % '_'.join([cctype, molset, prop]))
plt.savefig(outfile, dpi=self.dpi)
if self.svg:
plt.savefig(outfile.replace('.png', '.svg'), dpi=self.dpi)
plt.close('all')
def sign3_error_predictors(self, sign2_universe_presence):
from chemicalchecker.tool.adanet import AdaNet
df = pd.DataFrame(columns=['dataset', 'sign2_count', 'algo', 'mse',
'pearson', 'mae'])
errors = dict()
for ds in self.cc.datasets:
# get real data
s3 = self.cc.get_signature('sign3', 'full', ds)
s2 = self.cc.get_signature('sign2', 'full', ds)
s2_idxs = np.argwhere(
np.isin(list(s3.keys), list(s2.keys), assume_unique=True)).flatten()
ss2 = s2[:100000]
s2_idxs = s2_idxs[:100000]
ss3 = s3[:][s2_idxs]
with h5py.File(sign2_universe_presence, 'r') as fh:
x_real = fh['V'][:][s2_idxs]
y_real = np.log10(np.expand_dims(
np.mean(((ss2 - ss3)**2), axis=1), 1))
row = {
'dataset': ds,
'sign2_count': len(x_real)
}
# load predictors
eval_err_path = os.path.join(s3.model_path, 'adanet_error_eval')
error_pred_fn = AdaNet.predict_fn(
os.path.join(eval_err_path, 'savedmodel'))
lr = pickle.load(
open(os.path.join(eval_err_path, 'LinearRegression.pkl'), 'rb'))
rf = pickle.load(
open(os.path.join(eval_err_path, 'RandomForest.pkl'), 'rb'))
predictions = {
'NeuralNetwork': AdaNet.predict(x_real, error_pred_fn)[:, 0],
'LinearRegression': lr.predict(x_real),
'RandomForest': rf.predict(x_real)
}
y_flat = y_real[:, 0]
errors[ds] = list()
for algo, pred in predictions.items():
errors[ds].append((algo, y_flat - pred))
row['algo'] = algo
row['mse'] = np.mean((y_flat - pred)**2)
row['mae'] = np.mean(np.abs(y_flat - pred))
row['pearson'] = np.corrcoef(y_flat, pred)[0][1]
df.loc[len(df)] = pd.Series(row)
print(row)
df['algo'] = df.algo.map(
{'LinearRegression': 'LR', 'RandomForest': 'RF', 'NeuralNetwork': 'NN'})
# sns.set_style("whitegrid")
sns.set_context("talk")
fig, axes = plt.subplots(5, 5, sharey=True, sharex=True,
figsize=(10, 15), dpi=self.dpi)
for ds, ax in tqdm(zip(self.datasets, axes.flatten())):
sns.barplot(x=[x[0] for x in errors[ds]],
y=[x[1] for x in errors[ds]], ax=ax)
# ax.set_ylim(-0.15, .15)
# ax.get_legend().remove()
ax.set_xlabel('')
ax.set_ylabel('')
# ax.set_xticklabels()
ax.grid(axis='y', linestyle="-",
color=self.cc_palette([ds])[0], lw=0.3)
ax.spines["bottom"].set_color(self.cc_palette([ds])[0])
ax.spines["top"].set_color(self.cc_palette([ds])[0])
ax.spines["right"].set_color(self.cc_palette([ds])[0])
ax.spines["left"].set_color(self.cc_palette([ds])[0])
# plt.tight_layout()
filename = os.path.join(self.plot_path, "sign3_error_predictors.png")
plt.savefig(filename, dpi=self.dpi)
plt.close('all')
return df
def sign3_confidence_distribution(self):
# sns.set_style("whitegrid")
fig, axes = plt.subplots(5, 5, sharey=True, sharex=True,
figsize=(10, 10), dpi=self.dpi)
for ds, ax in zip(self.datasets, axes.flat):
sign3 = self.cc.get_signature('sign3', 'full', ds)
error_file = os.path.join(sign3.model_path, 'error.h5')
with h5py.File(error_file, "r") as hf:
keys = hf['keys'][:]
train_log_mse = hf['log_mse_consensus'][:]
train_log_mse_real = hf['log_mse'][:]
# test is anything that wasn't in the confidence distribution
test_keys = list(sign3.unique_keys - set(keys))
test_idxs = np.where(
np.isin(list(sign3.keys), test_keys, assume_unique=True))[0]
train_idxs = np.where(
~np.isin(list(sign3.keys), test_keys, assume_unique=True))[0]
# decide sample molecules
s3_stddev = sign3.get_h5_dataset('stddev_norm')
s3_intensity = sign3.get_h5_dataset('intensity_norm')
s3_experr = sign3.get_h5_dataset('exp_error_norm')
s3_conf = (s3_intensity * (1 - s3_stddev))**(1 / 2.)
s3_conf_new = (s3_intensity * (1 - s3_stddev)
* (1 - s3_experr))**(1 / 3.)
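# e.g. intensity=0.9, stddev=0.2, exp_error=0.3 gives the 'new' score
# (0.9 * 0.8 * 0.7) ** (1 / 3.) ~= 0.80 (illustrative values)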
# s3_conf = s3.get_h5_dataset('confidence')
pc_inte = abs(stats.pearsonr(
s3_intensity[train_idxs], train_log_mse)[0])
pc_stddev = abs(stats.pearsonr(
s3_stddev[train_idxs], train_log_mse)[0])
pc_experr = abs(stats.pearsonr(
s3_experr[train_idxs], train_log_mse)[0])
# weight each term by how well it tracks the training error
s3_conf_new_w = np.average(
[s3_intensity, (1 - s3_stddev), (1 - s3_experr)], axis=0,
weights=[pc_inte, pc_stddev, pc_experr])
df = pd.DataFrame({'train': True, 'confidence': s3_conf[train_idxs],
'kind': 'old'})
df = df.append(pd.DataFrame({'train': False, 'confidence': s3_conf[
test_idxs], 'kind': 'old'}), ignore_index=True)
df = df.append(pd.DataFrame({'train': True, 'confidence': s3_conf_new[
train_idxs], 'kind': 'new'}), ignore_index=True)
df = df.append(pd.DataFrame({'train': False, 'confidence': s3_conf_new[
test_idxs], 'kind': 'new'}), ignore_index=True)
df = df.append(pd.DataFrame({'train': True, 'confidence': s3_conf_new_w[
train_idxs], 'kind': 'test'}), ignore_index=True)
df = df.append(pd.DataFrame({'train': False, 'confidence': s3_conf_new_w[
test_idxs], 'kind': 'test'}), ignore_index=True)
# get idx of nearest neighbors of s2
sns.boxplot(data=df, y='confidence', x='kind', hue='train',
order=['old', 'new', 'test'],
hue_order=[True, False],
color=self.cc_palette([ds])[0], ax=ax,)
# ax.set_yscale('log')
ax.get_legend().remove()
ax.set_ylim(0, 1)
ax.grid(axis='y', linestyle="-",
color=self.cc_palette([ds])[0], lw=0.3)
ax.spines["bottom"].set_color(self.cc_palette([ds])[0])
ax.spines["top"].set_color(self.cc_palette([ds])[0])
ax.spines["right"].set_color(self.cc_palette([ds])[0])
ax.spines["left"].set_color(self.cc_palette([ds])[0])
outfile = os.path.join(
self.plot_path, 'confidence_distribution_new.png')
plt.savefig(outfile, dpi=self.dpi)
if self.svg:
plt.savefig(outfile.replace('.png', '.svg'), dpi=self.dpi)
plt.close('all')
def sign3_test_error_distribution(self):
from chemicalchecker.tool.adanet import AdaNet
from chemicalchecker.util.splitter import Traintest
from chemicalchecker.core.sign3 import subsample_x_only
def row_wise_correlation(X, Y):
var1 = (X.T - np.mean(X, axis=1)).T
var2 = (Y.T - np.mean(Y, axis=1)).T
cov = np.mean(var1 * var2, axis=1)
return cov / (np.std(X, axis=1) * np.std(Y, axis=1))
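# Sanity check (hypothetical arrays): for X, Y of shape (n, d),
# row_wise_correlation(X, Y)[i] matches np.corrcoef(X[i], Y[i])[0, 1]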
def mask_exclude(idxs, x_data, y_data):
x_data_transf = np.copy(x_data)
for idx in idxs:
# set current space to nan
col_slice = slice(idx * 128, (idx + 1) * 128)
x_data_transf[:, col_slice] = np.nan
# drop rows that only contain NaNs
not_nan = np.isfinite(x_data_transf).any(axis=1)
x_data_transf = x_data_transf[not_nan]
y_data_transf = y_data[not_nan]
return x_data_transf, y_data_transf
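# Each CC space occupies a 128-column block, so e.g. mask_exclude([2], X)
# sets columns 256:384 to NaN and drops molecules left with no finite
# value at all (index 2 is illustrative)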
# sns.set_style("whitegrid")
fig, axes = plt.subplots(5, 5, sharey=True, sharex=True,
figsize=(10, 10), dpi=self.dpi)
all_dss = list(self.datasets)
for ds, ax in zip(all_dss, axes.flat):
s3 = self.cc.get_signature('sign3', 'full', ds)
# filter most correlated spaces
ds_corr = s3.get_h5_dataset('datasets_correlation')
corr_spaces = np.array(list(self.cc.datasets))[
ds_corr > .9].tolist()
self.__log.info('masking %s' % str(corr_spaces))
if ds in corr_spaces:
dss = corr_spaces
else:
dss = [ds]
idxs = [all_dss.index(d) for d in dss]
mask_fn = partial(mask_exclude, idxs)
# load DNN
predict_fn = AdaNet.predict_fn(os.path.join(
s3.model_path, 'adanet_eval', 'savedmodel'))
# load X Y data
traintest_file = os.path.join(s3.model_path, 'traintest.h5')
traintest = Traintest(traintest_file, 'test')
traintest.open()
x_test, y_test = traintest.get_xy(0, 1000)
y_pred_nomask = AdaNet.predict(x_test, predict_fn)
# compute the unmasked error before masking, while predictions and
# targets are still row-aligned (mask_exclude may drop rows)
mse_nomask = np.mean((y_pred_nomask - y_test)**2, axis=1)
x_test, y_test = mask_fn(x_test, y_test)
traintest.close()
# get the predictions and consensus
self.__log.info('prediction consensus 5')
y_pred, samples = AdaNet.predict(x_test, predict_fn,
subsample_x_only,
consensus=True,
samples=5)
y_pred_consensus = np.mean(samples, axis=1)
self.__log.info('prediction consensus 10')
y_pred, samples = AdaNet.predict(x_test, predict_fn,
subsample_x_only,
consensus=True,
samples=10)
y_pred_consensus_10 = np.mean(samples, axis=1)
self.__log.info('prediction consensus 20')
y_pred, samples = AdaNet.predict(x_test, predict_fn,
subsample_x_only,
consensus=True,
samples=20)
y_pred_consensus_20 = np.mean(samples, axis=1)
self.__log.info('plotting')
mse = np.mean((y_pred - y_test)**2, axis=1)
mse_consensus = np.mean((y_pred_consensus - y_test)**2, axis=1)
mse_consensus_10 = np.mean(
(y_pred_consensus_10 - y_test)**2, axis=1)
mse_consensus_20 = np.mean(
(y_pred_consensus_20 - y_test)**2, axis=1)
sns.distplot(np.log10(mse_nomask), ax=ax,
color='orange', label='cons. 1 nomask')
sns.distplot(np.log10(mse), ax=ax,
color='red', label='cons. 1')
sns.distplot(np.log10(mse_consensus), ax=ax,
color='green', label='cons. 5')
sns.distplot(np.log10(mse_consensus_10), ax=ax,
color='blue', label='cons. 10')
sns.distplot(np.log10(mse_consensus_20), ax=ax,
color='purple', label='cons. 20')
'''
corr_test = row_wise_correlation(y_pred_test, y_true_test)
sns.distplot(corr_test, ax=ax, bins=20, hist_kws={'range': (0, 1)},
color='grey', label='%s mols.' % y_pred_test.shape[0])
corr_test_comp = row_wise_correlation(y_pred_test.T, y_true_test.T)
sns.distplot(corr_test_comp, ax=ax,
color=self.cc_palette([ds])[0], label='128 comp.')
'''
# err_test = np.mean((y_pred_test - y_true_test)**2, axis=1)
# pc_corr_err = stats.pearsonr(corr_test, err_test)[0]
# ax.text(0.05, 0.85, "p: {:.2f}".format(pc_corr_err),
# transform=ax.transAxes, size=10)
# ax.set_xlim(0, 1)
ax.legend(prop={'size': 3})
ax.grid(axis='y', linestyle="-",
color=self.cc_palette([ds])[0], lw=0.3)
ax.spines["bottom"].set_color(self.cc_palette([ds])[0])
ax.spines["top"].set_color(self.cc_palette([ds])[0])
ax.spines["right"].set_color(self.cc_palette([ds])[0])
ax.spines["left"].set_color(self.cc_palette([ds])[0])
outfile = os.path.join(
self.plot_path, 'sign3_test_error_distribution.png')
plt.savefig(outfile, dpi=self.dpi)
plt.close('all')
def sign3_correlation_distribution(self):
from chemicalchecker.tool.adanet import AdaNet
from chemicalchecker.util.splitter import Traintest
def row_wise_correlation(X, Y):
var1 = (X.T - np.mean(X, axis=1)).T
var2 = (Y.T - np.mean(Y, axis=1)).T
cov = np.mean(var1 * var2, axis=1)
return cov / (np.std(X, axis=1) * np.std(Y, axis=1))
def mask_exclude(idxs, x_data, y_data):
x_data_transf = np.copy(x_data)
for idx in idxs:
# set current space to nan
col_slice = slice(idx * 128, (idx + 1) * 128)
x_data_transf[:, col_slice] = np.nan
# drop rows that only contain NaNs
not_nan = np.isfinite(x_data_transf).any(axis=1)
x_data_transf = x_data_transf[not_nan]
y_data_transf = y_data[not_nan]
return x_data_transf, y_data_transf
# sns.set_style("whitegrid")
fig, axes = plt.subplots(5, 5, sharey=False, sharex=True,
figsize=(10, 10), dpi=self.dpi)
all_dss = list(self.datasets)
for ds, ax in zip(all_dss, axes.flat):
s3 = self.cc.get_signature('sign3', 'full', ds)
# filter most correlated spaces
ds_corr = s3.get_h5_dataset('datasets_correlation')
self.__log.info(str(list(zip(self.datasets, ds_corr))))
# load X Y data
traintest_file = os.path.join(s3.model_path, 'traintest.h5')
traintest = Traintest(traintest_file, 'test')
traintest.open()
x_test, y_test = traintest.get_xy(0, 1000)
traintest.close()
# load DNN
predict_fn = AdaNet.predict_fn(os.path.join(
s3.model_path, 'adanet_eval', 'savedmodel'))
# check various correlations thresholds
colors = ['firebrick', 'gold', 'forestgreen']
for corr_thr, color in zip([.7, .9, 1.0], colors):
corr_spaces = np.array(list(self.datasets))[
ds_corr > corr_thr].tolist()
self.__log.info('masking %s' % str(corr_spaces))
idxs = [all_dss.index(d) for d in corr_spaces]
x_thr, y_true = mask_exclude(idxs, x_test, y_test)
y_pred = AdaNet.predict(x_thr, predict_fn)
corr_test_comp = row_wise_correlation(y_pred.T, y_true.T)
self.__log.info('%.2f N(%.2f,%.2f)' % (
corr_thr, np.mean(corr_test_comp), np.std(corr_test_comp)))
sns.distplot(corr_test_comp, ax=ax,
color=color, label='%.2f' % corr_thr)
ax.legend(prop={'size': 6})
ax.grid(axis='y', linestyle="-",
color=self.cc_palette([ds])[0], lw=0.3)
ax.spines["bottom"].set_color(self.cc_palette([ds])[0])
ax.spines["top"].set_color(self.cc_palette([ds])[0])
ax.spines["right"].set_color(self.cc_palette([ds])[0])
ax.spines["left"].set_color(self.cc_palette([ds])[0])
outfile = os.path.join(
self.plot_path, 'sign3_correlation_distribution.png')
plt.savefig(outfile, dpi=self.dpi)
if self.svg:
plt.savefig(outfile.replace('.png', '.svg'), dpi=self.dpi)
plt.close('all')
def sign3_test_distribution(self, cctype='sign4', limit=10000,
options=[['Pearson'], [True], [False]]):
def row_wise_correlation(X, Y):
var1 = (X.T - np.mean(X, axis=1)).T
var2 = (Y.T - np.mean(Y, axis=1)).T
cov = np.mean(var1 * var2, axis=1)
return cov / (np.std(X, axis=1) * np.std(Y, axis=1))
all_dss = list(self.datasets)
all_dfs = list()
for ds in all_dss:
sign = self.cc.get_signature(cctype, 'full', ds)
pred_file = os.path.join(
sign.model_path, 'siamese_eval', 'plot_preds.pkl')
if not os.path.isfile(pred_file):
self.__log.warning('%s not found!' % pred_file)
continue
if options is None:
options = [
['log10MSE', 'R2', 'Pearson', 'MCC'],
[True, False],
[True, False]
]
preds = pickle.load(open(pred_file, 'rb'))
for metric, scaled, comp_wise in itertools.product(*options):
y_true = preds['test']['ONLY-SELF']
y_pred = preds['test']['NOT-SELF']
if comp_wise:
y_true = y_true.T
y_pred = y_pred.T
if scaled:
y_true = robust_scale(y_true)
y_pred = robust_scale(y_pred)
if metric == 'log10MSE':
values = np.log10(np.mean((y_true - y_pred)**2, axis=1))
elif metric == 'R2':
values = r2_score(y_true, y_pred, multioutput='raw_values')
elif metric == 'Pearson':
values = row_wise_correlation(y_true, y_pred)
elif metric == 'MCC':
y_true = y_true > 0
y_pred = y_pred > 0
values = [matthews_corrcoef(
y_true[i], y_pred[i]) for i in range(len(y_true))]
_df = pd.DataFrame(
dict(dataset=ds, scaled=scaled, comp_wise=comp_wise,
metric=metric, value=values))
all_dfs.append(_df)
df = pd.concat(all_dfs)
# sns.set_style("ticks")
# sns.set_style({'font.family': 'sans-serif', 'font.serif': ['Arial']})
if options is None:
options = [
['log10MSE', 'R2', 'Pearson', 'MCC'],
[True, False],
[True, False]
]
for metric, scaled, comp_wise in itertools.product(*options):
odf = df[(df.scaled == scaled) & (df.comp_wise == comp_wise) & (
df.metric == metric)]
xmin = np.floor(np.percentile(odf.value, 5))
xmax = np.ceil(np.percentile(odf.value, 95))
fig, axes = plt.subplots(
26, 1, sharex=True, figsize=(3, 10), dpi=self.dpi)
fig.subplots_adjust(left=0.05, right=.95, bottom=0.08,
top=1, wspace=0, hspace=-.3)
for idx, (ds, ax) in enumerate(zip(all_dss, axes.flat)):
color = self.cc_colors(ds, idx % 2)
color2 = self.cc_colors(ds, (idx % 2) + 1)
values = odf[(odf.dataset == ds)].value.tolist()
sns.kdeplot(values, ax=ax, clip_on=True, shade=True,
alpha=1, lw=0, bw=.15, color=color)
sns.kdeplot(values, ax=ax, clip_on=True,
color=color2, lw=2, bw=.15)
ax.axhline(y=0, lw=2, clip_on=True, color=color)
ax.set_xlim(xmin, xmax)
ax.tick_params(axis='x', colors=color)
ax.set_yticks([])
ax.set_xticks([])
ax.patch.set_alpha(0)
sns.despine(ax=ax, bottom=True, left=True, trim=True)
ax.grid(False)
ax = axes.flat[-1]
ax.set_yticks([])
ax.set_xlim(xmin, xmax)
ax.set_xticks(np.linspace(xmin, xmax, 5))
ax.set_xticklabels(
['%.1f' % x for x in np.linspace(xmin, xmax, 5)])
ax.grid(False)
# ax.set_xticklabels(['0','0.5','1'])
xlabel = metric
if comp_wise:
xlabel += ' Comp.'
else:
xlabel += ' Mol.'
if scaled:
xlabel += ' scaled'
ax.set_xlabel(xlabel,
fontdict=dict(name='Arial', size=16))
ax.tick_params(labelsize=14)
ax.patch.set_alpha(0)
sns.despine(ax=ax, bottom=False, left=True, trim=True)
fname = 'sign3_test_dist_%s' % metric
if comp_wise:
fname += '_comp'
if scaled:
fname += '_scaled'
print(fname)
print(odf.value.describe())
outfile = os.path.join(self.plot_path, fname + '.png')
plt.savefig(outfile, dpi=self.dpi)
if self.svg:
plt.savefig(outfile.replace('.png', '.svg'), dpi=self.dpi)
plt.close('all')
def sign3_mfp_predictor(self):
# sns.set_style("whitegrid")
f, axes = plt.subplots(5, 5, figsize=(9, 9), sharex=True, sharey=True)
for ds, ax in zip(self.datasets, axes.flat):
try:
s3 = self.cc.get_signature('sign3', 'full', ds)
except Exception:
continue
if not os.path.isfile(s3.data_path):
continue
stat_file = os.path.join(s3.model_path,
'adanet_sign0_A1.001_eval',
'stats_sign0_A1.001_eval.pkl')
df = pd.read_pickle(stat_file)
df['component_cat'] = pd.cut(
df.component,
bins=[-1, 127, 128, 129, 130, 131, 132],
labels=['signature', 'stddev', 'intensity', 'exp_error',
'novelty', 'confidence'])
# get idx of nearest neighbors of s2
sns.barplot(x='component_cat', y='mse', data=df, hue='split',
hue_order=['train', 'test'],
ax=ax, color=self.cc_palette([ds])[0])
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
ax.legend(prop={'size': 6})
ax.get_legend().remove()
ax.grid(axis='y', linestyle="-",
color=self.cc_palette([ds])[0], lw=0.3)
ax.spines["bottom"].set_color(self.cc_palette([ds])[0])
ax.spines["top"].set_color(self.cc_palette([ds])[0])
ax.spines["right"].set_color(self.cc_palette([ds])[0])
ax.spines["left"].set_color(self.cc_palette([ds])[0])
f.text(0.06, 0.5, 'MSE', ha='center',
va='center', rotation='vertical')
outfile = os.path.join(
self.plot_path, 'sign3_mfp_predictor.png')
plt.savefig(outfile, dpi=self.dpi)
plt.close('all')
def sign3_mfp_confidence_predictor(self):
# sns.set_style("whitegrid")
f, axes = plt.subplots(5, 5, figsize=(9, 9), sharex=True, sharey=True)
for ds, ax in zip(self.datasets, axes.flat):
try:
s3 = self.cc.get_signature('sign3', 'full', ds)
except Exception:
continue
if not os.path.isfile(s3.data_path):
continue
stat_file = os.path.join(s3.model_path,
'adanet_sign0_A1.001_conf_eval',
'stats_sign0_A1.001_conf_eval.pkl')
df = pd.read_pickle(stat_file)
df['component_cat'] = df.component.astype('category')
df['component_cat'] = df['component_cat'].cat.rename_categories(
['stddev', 'intensity', 'exp_error', 'novelty', 'confidence'])
# get idx of nearest neighbors of s2
sns.barplot(x='component_cat', y='mse', data=df, hue='split',
hue_order=['train', 'test'],
ax=ax, color=self.cc_palette([ds])[0])
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
ax.legend(prop={'size': 6})
ax.get_legend().remove()
ax.grid(axis='y', linestyle="-",
color=self.cc_palette([ds])[0], lw=0.3)
ax.spines["bottom"].set_color(self.cc_palette([ds])[0])
ax.spines["top"].set_color(self.cc_palette([ds])[0])
ax.spines["right"].set_color(self.cc_palette([ds])[0])
ax.spines["left"].set_color(self.cc_palette([ds])[0])
f.text(0.06, 0.5, 'MSE', ha='center',
va='center', rotation='vertical')
outfile = os.path.join(
self.plot_path, 'sign3_mfp_confidence_predictor.png')
plt.savefig(outfile, dpi=self.dpi)
plt.close('all')
@staticmethod
def quick_gaussian_kde(x, y, limit=1000):
xl = x[:limit]
yl = y[:limit]
xy = np.vstack([xl, yl])
try:
c = gaussian_kde(xy)(xy)
except Exception as ex:
MultiPlot.__log.warning('Could not compute KDE: %s' % str(ex))
c = np.arange(len(xl))
order = c.argsort()
return xl, yl, c, order
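# Typical call: x, y, c, order = MultiPlot.quick_gaussian_kde(xs, ys)
# followed by ax.scatter(x[order], y[order], c=c[order]) so the densest
# points are drawn on top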
def sign3_confidence_summary(self, limit=5000):
from chemicalchecker.core.signature_data import DataSignature
fig = plt.figure(constrained_layout=True, figsize=(12, 12))
gs = fig.add_gridspec(5, 5, wspace=0.1, hspace=0.1)
plt.subplots_adjust(left=0.08, right=.95, bottom=0.08, top=.95)
axes = list()
for row, col in itertools.product(range(5), range(5)):
axes.append(fig.add_subplot(gs[row, col]))
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor='none', top=False,
bottom=False, left=False, right=False)
plt.grid(False)
plt.xlabel("Correlation", labelpad=25, size=18)
plt.ylabel("Applicability", labelpad=25, size=18)
for ds, ax in zip(self.datasets, axes):
sign = self.cc.get_signature('sign3', 'full', ds)
# get data
confidence_path = os.path.join(sign.model_path, 'confidence_eval')
known_dist = DataSignature(os.path.join(
confidence_path, 'data.h5'))
correlation = known_dist.get_h5_dataset('y_test')
conf_feats = known_dist.get_h5_dataset('x_test')
applicability = conf_feats[:, 0]
robustness = conf_feats[:, 1]
prior = conf_feats[:, 2]
prior_sign = conf_feats[:, 3]
intensity = conf_feats[:, 4]
# get confidence
confidence_file = os.path.join(confidence_path, 'confidence.pkl')
calibration_file = os.path.join(confidence_path, 'calibration.pkl')
conf_mdl = (pickle.load(open(confidence_file, 'rb')),
pickle.load(open(calibration_file, 'rb')))
# and estimate confidence
conf_feats = np.vstack(
[applicability, robustness, prior, prior_sign, intensity]).T
conf_estimate = conf_mdl[0].predict(conf_feats)
confidence = conf_mdl[1].predict(np.expand_dims(conf_estimate, 1))
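# conf_mdl chains two fitted models: a regressor mapping the five
# confidence features to a raw estimate, then a calibrator mapping
# that estimate to the final confidence score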
# compute pearson rho
rhos = dict()
rho_confidence = stats.pearsonr(correlation, confidence)[0]
rhos['d'] = abs(stats.pearsonr(correlation, applicability)[0])
rhos['r'] = abs(stats.pearsonr(correlation, robustness)[0])
rhos['p'] = abs(stats.pearsonr(correlation, prior)[0])
rhos['s'] = abs(stats.pearsonr(correlation, prior_sign)[0])
rhos['i'] = abs(stats.pearsonr(correlation, intensity)[0])
# scatter gaussian
x, y, c, order = self.quick_gaussian_kde(
correlation, confidence, limit)
colors = [self.cc_colors(ds, 2), self.cc_colors(ds, 0)]
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
'', colors)
ax.scatter(x[order], y[order], c=c[order],
cmap=cmap, s=15, edgecolor='', alpha=.9)
ax.text(0.05, 0.85, r"$\rho$: {:.2f}".format(rho_confidence),
transform=ax.transAxes, name='Arial', size=14,
bbox=dict(facecolor='white', alpha=0.8))
ax.set_ylim(-1.0, 1.0)
ax.set_xlim(-1.0, 1.0)
ax.set_ylabel('')
ax.set_xlabel('')
ax.plot(ax.get_xlim(), ax.get_ylim(), ls="--", c=".9")
ax.set_yticks([-1.0, -0.5, 0.0, 0.5, 1.0])
ax.set_yticklabels(['-1', '-0.5', '0', '0.5', '1'])
ax.set_xticks([-1.0, -0.5, 0.0, 0.5, 1.0])
ax.set_xticklabels(['-1', '-0.5', '0', '0.5', '1'])
ax.tick_params(labelsize=14, direction='inout')
# axis ticks
if ds[:2] == 'E1':
# set the alignment for outer ticklabels
ticklabels = ax.get_yticklabels()
ticklabels[0].set_va("bottom")
ticklabels[-1].set_va("top")
elif ds[1] == '1':
# set the alignment for outer ticklabels
ticklabels = ax.get_yticklabels()
ticklabels[0].set_va("bottom")
ticklabels[-1].set_va("top")
ax.xaxis.set_ticklabels([])
elif ds[0] == 'E':
ax.yaxis.set_ticklabels([])
else:
ax.xaxis.set_ticklabels([])
ax.yaxis.set_ticklabels([])
# axis labels
if ds[0] == 'A':
ax.set_xlabel(ds[1], fontsize=18, labelpad=15)
ax.xaxis.set_label_position('top')
if ds[1] == '5':
ax.set_ylabel(ds[0], fontsize=18, rotation=0, va='center',
labelpad=15)
ax.yaxis.set_label_position('right')
# pies
wp = {'linewidth': 0, 'antialiased': True}
colors = [self.cc_colors(ds, 1), 'lightgrey']
for i, (name, rho) in enumerate(rhos.items()):
# [x0, y0, width, height]
bounds = [0.05 + (0.18 * i), 0.05, 0.18, 0.18]
inset_ax = ax.inset_axes(bounds)
inset_ax.pie([abs(rho), 1 - abs(rho)], wedgeprops=wp,
counterclock=False, startangle=90, colors=colors)
inset_ax.pie([1.0], radius=0.5, colors=[
'white'], wedgeprops=wp)
inset_ax.text(0.5, 1.2, name, ha='center', va='center',
transform=inset_ax.transAxes, name='Arial',
style='italic', size=12)
outfile = os.path.join(self.plot_path,
'sign3_confidence_summary.png')
plt.savefig(outfile, dpi=self.dpi)
if self.svg:
plt.savefig(outfile.replace('.png', '.svg'), dpi=self.dpi)
plt.close('all')
def sign3_novel_confidence_distribution(self):
# sns.set_style("whitegrid")
f, axes = plt.subplots(5, 5, figsize=(9, 9), sharex=True, sharey=True)
for ds, ax in zip(self.datasets, axes.flat):
try:
s3 = self.cc.get_signature('sign3', 'full', ds)
except Exception:
continue
if not os.path.isfile(s3.data_path):
continue
# get novelty and confidence
# nov = s3.get_h5_dataset('novelty')
out = s3.get_h5_dataset('outlier')
conf = s3.get_h5_dataset('confidence')
# get really novel molecules
# min_known = min(nov[out == 0])
nov_confs = conf[out == -1]
print(ds, len(nov_confs))
sns.distplot(nov_confs, color=self.cc_palette([ds])[0],
kde=False, norm_hist=False, ax=ax, bins=20,
hist_kws={'range': (0, 1)})
ax.set_xlim(0, 1)
ax.set_yscale('log')
ax.grid(axis='y', linestyle="-",
color=self.cc_palette([ds])[0], lw=0.3)
ax.grid(axis='x', linestyle="-",
color=self.cc_palette([ds])[0], lw=0.3)
ax.spines["bottom"].set_color(self.cc_palette([ds])[0])
ax.spines["top"].set_color(self.cc_palette([ds])[0])
ax.spines["right"].set_color(self.cc_palette([ds])[0])
ax.spines["left"].set_color(self.cc_palette([ds])[0])
ax.text(0.05, 0.85, "%i" % len(nov_confs),
transform=ax.transAxes, size=8)
f.text(0.5, 0.04, 'Confidence', ha='center', va='center')
f.text(0.06, 0.5, 'Novel Molecules', ha='center',
va='center', rotation='vertical')
outfile = os.path.join(
self.plot_path, 'sign3_novel_confidence_distribution.png')
plt.savefig(outfile, dpi=self.dpi)
plt.close('all')
def sign3_examplary_test_correlation(self, limit=1000,
molecules=[
'MBUVEWMHONZEQD-UHFFFAOYSA-N'],
examplary_ds=['B1.001', 'D1.001', 'E4.001']):
from chemicalchecker.core import DataSignature
def mask_keep(idxs, x1_data):
# we will fill an array of NaN with values we want to keep
x1_data_transf = np.full_like(x1_data, np.nan, dtype=np.float32)
for idx in idxs:
# copy column from original data
col_slice = slice(idx * 128, (idx + 1) * 128)
x1_data_transf[:, col_slice] = x1_data[:, col_slice]
# rows with no kept space remain all-NaN; callers may filter them
return x1_data_transf
def mask_exclude(idxs, x1_data):
x1_data_transf = np.copy(x1_data)
for idx in idxs:
# set current space to nan
col_slice = slice(idx * 128, (idx + 1) * 128)
x1_data_transf[:, col_slice] = np.nan
# note: unlike the mask_exclude used elsewhere in this module, this
# variant does not drop rows that only contain NaNs
return x1_data_transf
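# mask_keep and mask_exclude are complementary over the 128-column
# blocks: e.g. mask_keep([3], X) retains only columns 384:512, while
# mask_exclude([3], X) blanks exactly those columns (index 3 is
# illustrative)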
def row_wise_correlation(X, Y):
var1 = (X.T - np.mean(X, axis=1)).T
var2 = (Y.T - np.mean(Y, axis=1)).T
cov = np.mean(var1 * var2, axis=1)
return cov / (np.std(X, axis=1) * np.std(Y, axis=1))
def plot_molecule(ax, smiles, size=6):
from rdkit import Chem
from rdkit.Chem import Draw
figure = Draw.MolToMPL(Chem.MolFromSmiles(smiles))
# RDKit only draws to a new figure, so copy its artists over to our ax
for child in figure.axes[0].get_children():
if isinstance(child, matplotlib.lines.Line2D):
ax.plot(*child.get_data(), c=child.get_color())
if isinstance(child, matplotlib.text.Annotation):
ax.text(*child.get_position(), s=child.get_text(),
color=child.get_color(), ha=child.get_ha(),
va=child.get_va(), family=child.get_family(),
size=size,
bbox={'facecolor': 'white',
'boxstyle': 'circle,pad=0.2'})
plt.close(figure)
correlations = dict()
true_pred = dict()
for ds in examplary_ds:
ds_idx = np.argwhere(np.isin(self.datasets, ds)).flatten()
s2 = self.cc.get_signature('sign2', 'full', ds)
s4 = self.cc.get_signature('sign3', 'full', ds)
traintest_file = os.path.join(s4.model_path, 'traintest_eval.h5')
traintest = DataSignature(traintest_file)
inks = traintest.get_h5_dataset('keys_train')[:100000]
inks = np.sort(inks)
test_mask = np.isin(list(s2.keys), list(inks),
assume_unique=True)
sign2_matrix = os.path.join(s4.model_path, 'train.h5')
X = DataSignature(sign2_matrix)
feat = X.get_h5_dataset('x', mask=test_mask)
predict_fn = s4.get_predict_fn(smiles=False, model='siamese_eval')
y_true = predict_fn(mask_keep(ds_idx, feat))
y_pred = predict_fn(mask_exclude(ds_idx, feat))
true_pred[ds] = dict()
for i, ink in enumerate(inks):
tp = (y_true[i], y_pred[i])
true_pred[ds][ink] = tp
ink_corrs = list(zip(inks, row_wise_correlation(y_true, y_pred)))
correlations[ds] = dict(ink_corrs)
if molecules is None:
mols = list()
for k, v in true_pred.items():
mols.append(set(v.keys()))
shared_inks = set.intersection(*mols)
print(shared_inks)
else:
shared_inks = molecules
for mol in tqdm(shared_inks):
# sns.set_style("whitegrid")
# sns.set_style({'font.family': 'sans-serif', 'font.serif': ['Arial']})
fig = plt.figure(constrained_layout=True, figsize=(4, 4))
fig.set_constrained_layout_pads(w_pad=0., h_pad=0.,
hspace=0., wspace=0.)
gs = fig.add_gridspec(2, 1)
gs.set_height_ratios((1, 1))
# gs.set_height_ratios((1, 2, 2, 2))
'''
fig.text(0.5, 0.02, 'Actual Signature',
ha='center', va='center',
name='Arial', size=16)
fig.text(0.04, 0.45, 'Predicted', ha='center',
va='center', rotation='vertical',
name='Arial', size=16)
'''
# plot molecule
ax_mol = fig.add_subplot(gs[0])
# get smiles
converter = Converter()
smiles = converter.inchi_to_smiles(
converter.inchikey_to_inchi(mol)[0]['standardinchi'])
plot_molecule(ax_mol, smiles, size=8)
# fix placement
ax_mol.set_axis_off()
l, b, w, h = ax_mol.get_position().bounds
new_w = w + 0.05
new_h = h + 0.05
ax_mol.set_position([0.5 - (new_w / 2.), b + 0.05, new_w, new_h])
ax_mol.axis('equal')
gss_ds = gs[1].subgridspec(1, len(examplary_ds))
mccs = list()
for idx, (ds, sub) in enumerate(zip(examplary_ds, gss_ds)):
gs_ds = sub.subgridspec(2, 2, wspace=0.0, hspace=0.0)
gs_ds.set_height_ratios((1, 5))
gs_ds.set_width_ratios((5, 1))
ax_main = fig.add_subplot(gs_ds[1, 0])
ax_top = fig.add_subplot(gs_ds[0, 0], sharex=ax_main)
ax_top.text(-0.2, 0.2, "%s" % ds[:2],
color='black',
transform=ax_top.transAxes,
name='Arial', size=14, weight='bold')
ax_top.set_axis_off()
ax_right = fig.add_subplot(gs_ds[1, 1], sharey=ax_main)
ax_right.set_axis_off()
true, pred = true_pred[ds][mol]
T = true > 0
P = pred > 0
from scipy.signal import find_peaks
kde_range = np.linspace(-1, 1, 1000)
peaks_true_i, _ = find_peaks(gaussian_kde(true)(kde_range))
peaks_pred_i, _ = find_peaks(gaussian_kde(pred)(kde_range))
peaks_true = kde_range[peaks_true_i]
peaks_pred = kde_range[peaks_pred_i]
coords = np.array(
[[max(peaks_true), max(peaks_pred)],
[min(peaks_true), max(peaks_pred)],
[min(peaks_true), min(peaks_pred)],
[max(peaks_true), min(peaks_pred)]])
x_range = (min(peaks_true) - 0.1, max(peaks_true) + 0.1)
y_range = (min(peaks_pred) - 0.1, max(peaks_pred) + 0.1)
ax_main.set_xlim(x_range)
ax_main.set_ylim(y_range)
s = np.array(
[sum(T & P), sum(~T & P), sum(~T & ~P), sum(T & ~P)])
ax_main.scatter(coords[:, 0], coords[:, 1], s=s * 10,
color=self.cc_colors(ds, 1),
edgecolor="black", lw=0.8)
# sns.despine(ax=ax_main, offset=3, trim=True)
ax_main.set_xlabel('Actual', size=14)
ax_main.set_ylabel('Inferred', size=14)
if idx != 0:
ax_main.set_ylabel(' ')
ax_main.set_yticks([0])
ax_main.set_yticklabels([])
ax_main.set_xticks([0])
ax_main.set_xticklabels([])
ax_main.tick_params(labelsize=14, direction='inout')
# ax_main.set_aspect('equal')
mcc = matthews_corrcoef(T, P)
mccs.append(mcc)
ax_main.text(0.5, 0.5, r"$MCC$: {:.2f}".format(mcc),
transform=ax_main.transAxes, name='Arial',
ha='center', va='center',
size=12, bbox=dict(facecolor='white', alpha=0.8))
sns.distplot(true, ax=ax_top, kde=True,
hist=False, kde_kws=dict(shade=True, bw=.2),
color=self.cc_colors(ds, 1))
sns.distplot(pred, ax=ax_right, vertical=True, kde=True,
hist=False, kde_kws=dict(shade=True, bw=.2),
color=self.cc_colors(ds, 1))
'''
if idx == 0:
ax_main.set_xlabel(' ')
if idx == 1:
ax_main.set_ylabel(' ')
if idx == 2:
ax_main.set_ylabel(' ')
ax_main.set_xlabel(' ')
'''
# plt.tight_layout()
spaces = '_'.join(['%s-%.1f' % (ds[:2], m)
for ds, m in zip(examplary_ds, mccs)])
outfile = os.path.join(
self.plot_path, 'sign3_%s_%s.png' % (spaces, mol))
plt.savefig(outfile, dpi=self.dpi)
if self.svg:
plt.savefig(outfile.replace('.png', '.svg'), dpi=self.dpi)
plt.close('all')
def cctype_validation_comparison(self, cctype1='sign4', cctype2='sign2', valtype='moa'):
pklfile = os.path.join(
self.plot_path, '%s_%s_%s_comparison.pkl' % (cctype1, cctype2, valtype))
if not os.path.isfile(pklfile):
data = dict()
for ds in self.datasets:
if 'dataset' not in data:
data['dataset'] = list()
data['dataset'].append(ds[:2])
# sign2
s2 = self.cc.get_signature(cctype2, 'full', ds)
stat_file = os.path.join(
s2.stats_path, 'validation_stats.json')
if not os.path.isfile(stat_file):
s2.validate()
# use a local name to avoid shadowing scipy's stats module
val_stats = json.load(open(stat_file, 'r'))
for k, v in val_stats.items():
if k + '_%s' % cctype2 not in data:
data[k + '_%s' % cctype2] = list()
data[k + '_%s' % cctype2].append(v)
# sign4
s3 = self.cc.get_signature(cctype1, 'full', ds)
stat_file = os.path.join(
s3.stats_path, 'validation_stats.json')
if not os.path.isfile(stat_file):
s3.validate()
val_stats = json.load(open(stat_file, 'r'))
for k, v in val_stats.items():
if k + '_0.0' not in data:
data[k + '_0.0'] = list()
data[k + '_0.0'].append(v)
# other confidences thresholds
for thr in np.arange(0.1, 0.9, 0.1):
s3_conf = self.cc.get_signature(
cctype1, 'conf%.1f' % thr, ds)
if not os.path.isfile(s3_conf.data_path):
conf_mask = s3.get_h5_dataset('confidence') > thr
s3.make_filtered_copy(s3_conf.data_path, conf_mask)
stat_file = os.path.join(s3_conf.stats_path,
'validation_stats.json')
if not os.path.isfile(stat_file):
s3_conf.validate()
val_stats = json.load(open(stat_file, 'r'))
for k, v in val_stats.items():
if k + '_%.1f' % thr not in data:
data[k + '_%.1f' % thr] = list()
data[k + '_%.1f' % thr].append(v)
df = pd.DataFrame(data)
df = df.infer_objects()
df.sort_values("dataset", ascending=False, inplace=True)
df.to_pickle(pklfile)
def gradient_arrow(ax, start, end, xs=None, cmap="plasma", head=None, n=50, lw=3):
# Arrow shaft: LineCollection
if xs is None:
x = np.linspace(start[0], end[0], n)
else:
x = xs
n = len(xs)
cmap = plt.get_cmap(cmap, n)
y = np.linspace(start[1], end[1], n)
points = np.array([x, y]).T.reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)
lc = LineCollection(segments, cmap=cmap, linewidth=lw)
lc.set_array(np.linspace(0, 1, n))
ax.add_collection(lc)
# Arrow head: Triangle
tricoords = [(0, -0.4), (0.5, 0), (0, 0.4), (0, -0.4)]
angle = np.arctan2(end[1] - start[1], end[0] - start[0])
rot = matplotlib.transforms.Affine2D().rotate(angle)
tricoords2 = rot.transform(tricoords)
tri = matplotlib.path.Path(tricoords2, closed=True)
if head is None:
# color the head like the final gradient segment
head = cmap(n - 1)
ax.scatter(end[0], end[1], c=head, s=(2 * lw)**2,
marker=tri, cmap=cmap, vmin=0)
ax.autoscale_view()
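# Example (hypothetical axis and values): gradient_arrow(ax, (0, 1), (10, 1),
# lw=5) draws a plasma-gradient arrow from x=0 to x=10 at y=1; passing
# `xs` instead spaces the gradient stops at the given x positions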
df = pd.read_pickle(pklfile)
# sns.set_style("ticks")
# sns.set_style({'font.family': 'sans-serif', 'font.serif': ['Arial']})
fig = plt.figure(figsize=(3, 10))
plt.subplots_adjust(left=0.14, right=0.96, bottom=0.01,
top=0.92, hspace=0.1)
gs = fig.add_gridspec(2, 1)
gs.set_height_ratios((30, 1))
gs_ds = gs[0].subgridspec(1, 2, wspace=0.1, hspace=0.0)
ax_cov = fig.add_subplot(gs_ds[0])
ax_roc = fig.add_subplot(gs_ds[1])
for ds in self.datasets:
y = 25 - self.datasets.index(ds) - 1
start = df[df.dataset == ds[:2]][
'%s_cov_%s' % (valtype, cctype2)].tolist()[0]
end = df[df.dataset == ds[:2]]['%s_cov_0.0' % valtype].tolist()[0]
js = [cctype2] + ['%.1f' %
f for f in reversed(np.arange(0.0, 0.9, 0.1))]
# js = ['sign2','0.5','0.0']
covs = [df[df.dataset == ds[:2]]['%s_cov_%s' % (valtype, j)].tolist()[0]
for j in js]
cmap = plt.get_cmap("plasma", len(covs))
gradient_arrow(ax_cov, (start, y), (end, y), xs=covs, lw=7)
start = df[df.dataset == ds[:2]][
'%s_auc_%s' % (valtype, cctype2)].tolist()[0]
aucs = [df[df.dataset == ds[:2]]['%s_auc_%s' % (valtype, j)].tolist()[0]
for j in js]
end = aucs[np.argmax(covs)]
cmap = plt.get_cmap("plasma", len(covs))
ax_roc.scatter(start, y, color=cmap(0), s=60)
ax_roc.scatter(end, y, color=cmap(np.argmax(covs)), s=60)
ax_cov.grid(False)
ax_cov.set_yticks(range(0, 25))
ax_cov.set_yticklabels(df['dataset'])
ax_cov.set_xlim(0, 110)
ax_cov.set_xticks([0, 100])
ax_cov.set_xticklabels(['0', '100'])
ax_cov.set_xlabel('Coverage', fontdict=dict(name='Arial', size=14))
ax_cov.xaxis.set_label_position('top')
# ax_cov.xaxis.tick_top()
sns.despine(ax=ax_cov, left=True, bottom=True, top=False, trim=True)
ax_cov.tick_params(left=False, labelsize=14, direction='inout',
bottom=False, top=True, labelbottom=False, labeltop=True)
# set the alignment for outer ticklabels
ticklabels = ax_cov.get_xticklabels()
ticklabels[0].set_ha("left")
ticklabels[-1].set_ha("right")
ax_roc.grid(False)
ax_roc.set_yticks([])
ax_roc.set_yticklabels([])
ax_roc.set_xlim(0.5, 1)
ax_roc.set_xticks([0.5, 1])
ax_roc.set_xticklabels(['0.5', '1'])
ax_roc.set_xlabel('AUROC', fontdict=dict(name='Arial', size=14))
ax_roc.xaxis.set_label_position('top')
sns.despine(ax=ax_roc, left=True, bottom=True, top=False, trim=True)
ax_roc.tick_params(left=False, labelsize=14, direction='inout',
bottom=False, top=True, labelbottom=False, labeltop=True)
# set the alignment for outer ticklabels
ticklabels = ax_roc.get_xticklabels()
ticklabels[0].set_ha("left")
ticklabels[-1].set_ha("right")
ax_cbar = fig.add_subplot(gs[1])
cbar = matplotlib.colorbar.ColorbarBase(
ax_cbar, cmap=cmap, orientation='horizontal',
ticklocation='top')
cbar.ax.set_xlabel('Confidence filter',
fontdict=dict(name='Arial', size=14))
cbar.ax.tick_params(labelsize=14, )
cbar.set_ticks([1, .8, .6, .4, .2, .0])
cbar.set_ticklabels(
list(reversed(['1', '0.8', '0.6', '0.4', '0.2', '0'])))
outfile = os.path.join(
self.plot_path, '%s_%s_%s_comparison.png' % (cctype1, cctype2, valtype))
plt.savefig(outfile, dpi=self.dpi)
plt.close('all')