"""CCFit task.
This task allows submitting HPC jobs which will call the `fit` method
for a specific signature type (e.g. 'sign1').
It allows submitting jobs for all spaces of the CC at once and passing
parameters specific for each of them.
We should avoid adding too much logic in here, and simply pass `args` and
`kwargs` to signatures. The specific signature classes should check that
the parameters are OK.
"""
import os
import shutil
import tempfile
from chemicalchecker.database import Dataset
from chemicalchecker.core import ChemicalChecker
from chemicalchecker.util.pipeline import BaseTask
from chemicalchecker.util import logged, HPC
VALID_TYPES = ['sign', 'neig', 'clus', 'proj']
FIT_SCRIPT = """
import sys
import os
import pickle
import logging
import chemicalchecker
from chemicalchecker import ChemicalChecker, Config
logging.log(logging.DEBUG, 'chemicalchecker: {{}}'.format(
chemicalchecker.__path__))
logging.log(logging.DEBUG, 'CWD: {{}}'.format(os.getcwd()))
config = Config()
task_id = sys.argv[1] # <TASK_ID>
filename = sys.argv[2] # <FILE>
inputs = pickle.load(open(filename, 'rb')) # load pickled data
sign_args = inputs[task_id][0][0]
sign_kwargs = inputs[task_id][0][1]
fit_args = inputs[task_id][0][2]
fit_kwargs = inputs[task_id][0][3]
cc = ChemicalChecker('{cc_root}')
sign = cc.get_signature(*sign_args, **sign_kwargs)
sign.fit(*fit_args, **fit_kwargs)
print('JOB DONE')
"""
@logged
class CCFit(BaseTask):
    """Pipeline task submitting one HPC `fit` job per CC dataset."""

    def __init__(self, cc_root, cctype, molset, **params):
        """Initialize CC fit task.

        Args:
            cc_root (str): The CC root path (Required)
            cctype (str): The CC type where the fit is applied (Required)
            molset (str): The signature molset (e.g. `full` or `reference`)
                on which the `fit` method will be called.
            name (str): The name of the task (default: cctype)
            datasets (list): The list of dataset codes to apply the fit
                (Optional, by default 'essential' which includes all essential
                CC datasets)
            sign_args (dict): A dictionary where key is dataset code and
                value is a list with all dataset specific parameters for
                initializing the signature. (Optional)
            sign_kwargs (dict): A dictionary where key is dataset code and
                value is a dictionary with all dataset specific key-worded
                parameters for initializing the signature. (Optional)
            fit_args (dict): A dictionary where key is dataset code and
                value is a list with all dataset specific parameters for
                calling the signature `fit` method. (Optional)
            fit_kwargs (dict): A dictionary where key is dataset code and
                value is a dictionary with all dataset specific key-worded
                parameters for calling the signature `fit` method. (Optional)
            hpc_kwargs (dict): A dictionary where key is dataset code and
                value is a dictionary with key-worded parameters for the
                `HPC` module. (Optional)
            reference_datasets (list): List of reference datasets for fitting
                sign3. (specific for `sign3`)
        """
        if not any(cctype.startswith(t) for t in VALID_TYPES):
            # NB: the '%s' was previously left unformatted; now the offending
            # cctype is included in the error message.
            raise Exception("cctype '%s' is not recognized." % cctype)
        self.name = params.get('name', cctype)
        BaseTask.__init__(self, self.name)
        self.cctype = cctype
        self.cc_root = cc_root
        self.molset = molset
        self.datasets = params.get('datasets', 'essential')
        if self.datasets == 'essential':
            self.datasets = [ds.code for ds in Dataset.get(essential=True)]
        self.sign_args = params.get('sign_args', {})
        self.sign_kwargs = params.get('sign_kwargs', {})
        self.fit_args = params.get('fit_args', {})
        self.fit_kwargs = params.get('fit_kwargs', {})
        self.hpc_kwargs = params.get('hpc_kwargs', {})
        # default reference datasets: all exemplary CC datasets
        def_ref_datasets = [ds.code for ds in Dataset.get(exemplary=True)]
        self.ref_datasets = params.get('reference_datasets', def_ref_datasets)

    def run(self):
        """Run the task.

        Skips datasets that are already fitted, serializes per-dataset
        signature/fit parameters, writes the job script, submits one HPC
        job array element per dataset, and finally verifies that every
        signature reports `is_fit()`.

        Raises:
            Exception: if any dataset fit did not complete and the task is
                not marked ready by other means.
        """
        # exclude datasets that have been already fitted
        cc = ChemicalChecker(self.cc_root)
        dataset_codes = list()
        for ds in self.datasets:
            sign = cc.get_signature(self.cctype, self.molset, ds)
            if sign.is_fit():
                continue
            dataset_codes.append(ds)
        if len(dataset_codes) == 0:
            self.__log.warning('All dataset are already fitted.')
            self.mark_ready()
            return
        # Preparing dataset_params
        # for each dataset we want to define a set of signature parameters
        # (i.e. sign_args/sign_kwargs, used when loading the signature) and a
        # set of parameters used when calling the 'fit' method
        # (i.e. fit_args/fit_kwargs)
        # FIXME can be further harmonized fixing individual signature classes
        dataset_params = list()
        for ds_code in dataset_codes:
            sign_args = list()
            fit_args = list()
            sign_args.extend(self.sign_args.get(ds_code, list()))
            fit_args.extend(self.fit_args.get(ds_code, list()))
            sign_kwargs = dict()
            fit_kwargs = dict()
            sign_kwargs.update(self.sign_kwargs.get(ds_code, dict()))
            fit_kwargs.update(self.fit_kwargs.get(ds_code, dict()))
            # we add arguments which are used by CCFit but are also needed
            # by a signature fit; cc.get_signature expects
            # (cctype, molset, dataset_code, ...) positionally
            sign_args.insert(0, self.cctype)
            sign_args.insert(1, self.molset)
            sign_args.insert(2, ds_code)
            # prepare it as tuple that will be serialized
            dataset_params.append(
                (sign_args, sign_kwargs, fit_args, fit_kwargs))
            self.__log.info('%s sign_args: %s', ds_code, str(sign_args))
            self.__log.info('%s sign_kwargs: %s', ds_code, str(sign_kwargs))
            self.__log.info('%s fit_args: %s', ds_code, str(fit_args))
            self.__log.info('%s fit_kwargs: %s', ds_code, str(fit_kwargs))
        # Create script file that will launch signx fit for each dataset
        job_path = tempfile.mkdtemp(
            prefix='jobs_%s_' % self.cctype, dir=self.tmpdir)
        script_name = os.path.join(job_path, self.cctype + '_script.py')
        script_content = FIT_SCRIPT.format(cc_root=self.cc_root)
        with open(script_name, 'w') as fh:
            fh.write(script_content)
        # HPC job parameters (user-provided hpc_kwargs take precedence)
        params = {}
        params["num_jobs"] = len(dataset_codes)
        params["jobdir"] = job_path
        params["job_name"] = "CC_" + self.cctype.upper()
        params["elements"] = dataset_params
        params["wait"] = True
        params.update(self.hpc_kwargs)
        # prepare job command and submit job
        cc_config_path = self.config.config_path
        cc_package = os.path.join(self.config.PATH.CC_REPO, 'package')
        singularity_image = self.config.PATH.SINGULARITY_IMAGE
        command = ("SINGULARITYENV_PYTHONPATH={} SINGULARITYENV_CC_CONFIG={} "
                   "singularity exec {} python {} <TASK_ID> <FILE>").format(
            cc_package, cc_config_path, singularity_image, script_name)
        self.__log.debug('CMD CCFIT: %s', command)
        # submit jobs
        cluster = HPC.from_config(self.config)
        jobs = cluster.submitMultiJob(command, **params)
        self.__log.info("Job with jobid '%s' ended.", str(jobs))
        # Check if signatures are indeed fitted
        dataset_not_done = []
        for ds_code in dataset_codes:
            sign = cc.get_signature(
                self.cctype, self.molset, ds_code)
            if sign.is_fit():
                continue
            dataset_not_done.append(ds_code)
            self.__log.warning(
                self.cctype + " fit failed for dataset code: " + ds_code)
        if len(dataset_not_done) == 0:
            self.mark_ready()
            if not self.keep_jobs:
                self.__log.info("Deleting job path: %s", job_path)
                shutil.rmtree(job_path, ignore_errors=True)
        else:
            if not self.custom_ready():
                raise Exception("Not all dataset fits are done")

    def execute(self, context):
        """Same as run but for Airflow.

        Args:
            context (dict): Airflow context; `context['params']['tmpdir']`
                is used as the temporary directory for job files.
        """
        self.tmpdir = context['params']['tmpdir']
        self.run()