Source code for chemicalchecker.util.pipeline.tasks.task_cc_predict

"""CCPredict task.

This task allows submitting HPC jobs which will call the `predict` method
for a specific signature type (e.g. 'sign1').
It allows submitting jobs for all spaces of the CC at once and passing
parameters specific for each of them.
We should avoid adding too much logic in here, and simply pass `args` and
`kwargs` to signatures. The specific signature classes should check that
the parameters are OK.
"""
import os
import shutil
import tempfile

from chemicalchecker.database import Dataset
from chemicalchecker.core import ChemicalChecker
from chemicalchecker.util.pipeline import BaseTask
from chemicalchecker.util import logged, HPC

# Accepted signature-type prefixes: a `cctype` is valid when it starts with
# one of these (e.g. 'sign1' starts with 'sign').
VALID_TYPES = ['sign', 'neig', 'clus', 'proj']

# Template of the stand-alone script executed by each HPC job.
# It is rendered with `str.format(cc_root=..., predict_fn=...)`, which is why
# literal braces inside the script are doubled (`{{}}`).
# The script receives a task id and a pickle file on the command line, picks
# the (sign_args, sign_kwargs, predict_args, predict_kwargs) tuple for its
# task, loads the signature and calls the requested predict function on it.
predict_SCRIPT = """
import sys
import os
import pickle
import logging
import chemicalchecker
from chemicalchecker import ChemicalChecker, Config
logging.log(logging.DEBUG, 'chemicalchecker: {{}}'.format(
    chemicalchecker.__path__))
logging.log(logging.DEBUG, 'CWD: {{}}'.format(os.getcwd()))
config = Config()
task_id = sys.argv[1]  # <TASK_ID>
filename = sys.argv[2]  # <FILE>
with open(filename, 'rb') as fh:  # load pickled data
    inputs = pickle.load(fh)
sign_args = inputs[task_id][0][0]
sign_kwargs = inputs[task_id][0][1]
predict_args = inputs[task_id][0][2]
predict_kwargs = inputs[task_id][0][3]
cc = ChemicalChecker('{cc_root}')
sign = cc.get_signature(*sign_args, **sign_kwargs)
sign.{predict_fn}(*predict_args, **predict_kwargs)
print('JOB DONE')
"""


@logged
class CCPredict(BaseTask):
    """Task submitting HPC jobs that call a `predict`-like signature method.

    One set of parameters is prepared for every requested dataset whose
    signature is already fitted; the sets are handed to the HPC module
    together with a generated script (see `predict_SCRIPT`).
    """

    def __init__(self, cc_root, cctype, molset, **params):
        """Initialize CC predict task.

        Args:
            cc_root (str): The CC root path (Required)
            cctype (str): The CC type where the predict is applied (Required)
            molset (str): The signature molset (e.g. `full` or `reference`)
                on which the `predict` method will be called.
            name (str): The name of the task (default:cctype)
            datasets (list): The list of dataset codes to apply the predict
                (Optional, by default 'essential' which includes all
                essential CC datasets)
            sign_args (dict): A dictionary where key is dataset code and
                value is a list with all dataset specific parameters for
                initializing the signature. (Optional)
            sign_kwargs (dict): A dictionary where key is dataset code and
                value is a dictionary with all dataset specific key-worded
                parameters for initializing the signature. (Optional)
            predict_fn (str): The name of the predict function to call.
                By default is `predict`.
            predict_args (dict): A dictionary where key is dataset code and
                value is a list with all dataset specific parameters for
                calling the signature `predict` method. (Optional)
            predict_kwargs (dict): A dictionary where key is dataset code and
                value is a dictionary with all dataset specific key-worded
                parameters for calling the signature `predict` method.
                (Optional)
            hpc_kwargs (dict): A dictionary with key-worded parameters for
                the `HPC` module; it is merged over the job parameters
                prepared in `run`, so its entries take precedence.
                (Optional)

        Raises:
            Exception: If `cctype` does not start with any of `VALID_TYPES`.
        """
        if not any(cctype.startswith(t) for t in VALID_TYPES):
            # the original message never interpolated the offending value
            raise Exception("cctype '%s' is not recognized." % cctype)
        self.name = params.get('name', cctype)
        BaseTask.__init__(self, self.name)
        self.cctype = cctype
        self.cc_root = cc_root
        self.molset = molset
        self.datasets = params.get('datasets', 'essential')
        if self.datasets == 'essential':
            self.datasets = [ds.code for ds in Dataset.get(essential=True)]
        self.sign_args = params.get('sign_args', {})
        self.sign_kwargs = params.get('sign_kwargs', {})
        self.predict_fn = params.get('predict_fn', 'predict')
        self.predict_args = params.get('predict_args', {})
        self.predict_kwargs = params.get('predict_kwargs', {})
        self.hpc_kwargs = params.get('hpc_kwargs', {})

    def _fitted_datasets(self):
        """Return the requested dataset codes whose signature is fitted."""
        cc = ChemicalChecker(self.cc_root)
        dataset_codes = list()
        for ds in self.datasets:
            sign = cc.get_signature(self.cctype, self.molset, ds)
            if not sign.is_fit():
                self.__log.warning('Dataset %s should be fitted first.', ds)
                continue
            dataset_codes.append(ds)
        return dataset_codes

    def _build_dataset_params(self, dataset_codes):
        """Return one (sign_args, sign_kwargs, predict_args, predict_kwargs)
        tuple per dataset code, in the same order as `dataset_codes`."""
        # for each dataset we want to define a set of signature parameters
        # (used when loading the signature) and a set of parameters used
        # when calling the 'predict' method
        # FIXME can be further harmonized fixing individual signature classes
        dataset_params = list()
        for ds_code in dataset_codes:
            # arguments used by CCPredict that are also needed by the
            # signature come first, followed by the dataset specific ones;
            # copies are made so the user supplied containers stay untouched
            sign_args = [self.cctype, self.molset, ds_code]
            sign_args.extend(self.sign_args.get(ds_code, list()))
            predict_args = list(self.predict_args.get(ds_code, list()))
            sign_kwargs = dict(self.sign_kwargs.get(ds_code, dict()))
            predict_kwargs = dict(self.predict_kwargs.get(ds_code, dict()))
            # prepare it as tuple that will be serialized
            dataset_params.append(
                (sign_args, sign_kwargs, predict_args, predict_kwargs))
            self.__log.info('%s sign_args: %s', ds_code, str(sign_args))
            self.__log.info('%s sign_kwargs: %s', ds_code, str(sign_kwargs))
            # NOTE(review): logging of predict_args / predict_kwargs was
            # disabled in the original code (presumably too verbose) --
            # confirm before re-enabling.
        return dataset_params

    def run(self):
        """Run the task.

        Skips datasets that are not fitted, writes the job script in a
        temporary job directory, submits one multi-job to the HPC and marks
        the task as ready. The job directory is removed afterwards unless
        `self.keep_jobs` is set.

        `self.tmpdir`, `self.config`, `self.keep_jobs` and `mark_ready` are
        not defined in this class (presumably provided by `BaseTask`).
        """
        # exclude dataset that have not been fitted
        dataset_codes = self._fitted_datasets()
        if len(dataset_codes) == 0:
            self.__log.warning('All dataset should be fitted first.')
            self.mark_ready()
            return

        dataset_params = self._build_dataset_params(dataset_codes)

        # Create script file that will launch signx predict for each dataset
        job_path = tempfile.mkdtemp(
            prefix='jobs_%s_' % self.cctype, dir=self.tmpdir)
        script_name = os.path.join(job_path, self.cctype + '_script.py')
        script_content = predict_SCRIPT.format(cc_root=self.cc_root,
                                               predict_fn=self.predict_fn)
        with open(script_name, 'w') as fh:
            fh.write(script_content)

        # HPC job parameters; user supplied hpc_kwargs override the defaults
        params = {}
        params["num_jobs"] = len(dataset_codes)
        params["jobdir"] = job_path
        params["job_name"] = "CC_" + self.cctype.upper()
        params["elements"] = dataset_params
        params["wait"] = True
        params.update(self.hpc_kwargs)

        # prepare job command and submit job
        cc_config_path = self.config.config_path
        cc_package = os.path.join(self.config.PATH.CC_REPO, 'package')
        singularity_image = self.config.PATH.SINGULARITY_IMAGE
        command = ("SINGULARITYENV_PYTHONPATH={} SINGULARITYENV_CC_CONFIG={} "
                   "singularity exec {} python {} <TASK_ID> <FILE>").format(
            cc_package, cc_config_path, singularity_image, script_name)
        self.__log.debug('CMD CCPREDICT: %s', command)

        # submit jobs
        cluster = HPC.from_config(self.config)
        jobs = cluster.submitMultiJob(command, **params)
        self.__log.info("Job with jobid '%s' ended.", str(jobs))
        self.mark_ready()
        if not self.keep_jobs:
            self.__log.info("Deleting job path: %s", job_path)
            shutil.rmtree(job_path, ignore_errors=True)

    def execute(self, context):
        """Same as run but for Airflow.

        Args:
            context (dict): Airflow context; `context['params']['tmpdir']`
                is used as the temporary directory for the job files.
        """
        self.tmpdir = context['params']['tmpdir']
        self.run()