chemicalchecker.tool.adanet.adanet_wrap.AdaNetWrapper

class AdaNetWrapper(traintest_file, **kwargs)

Bases: object

Wrapper class adapted from the scripted examples in AdaNet’s GitHub repository.

https://github.com/tensorflow/adanet/blob/master/adanet/examples/tutorials/adanet_objective.ipynb

Methods

architecture

Extract the ensemble architecture from evaluation results.

get_trainable_variables

Return the weights of the trained neural network.

input_fn

Generate an input function for the Estimator.

predict

Load model and return predictions.

predict_fn

Load model and return the predict function.

predict_online

Predict on the given test set in batches, without exhausting memory.

print_model_architechture

Print the architecture of the NAS network.

save_model

Save the trained model to the given directory.

save_performances

Save stats and make plots.

train_and_evaluate

Train and evaluate AdaNet.

architecture()

Extract the ensemble architecture from evaluation results.

static get_trainable_variables(model_dir)

Return the weights of the trained neural network.

Parameters:

model_dir (str) – path of the saved model.

input_fn(split, training, augmentation=False)

Generate an input function for the Estimator.

Parameters:
  • split (str) – the split to use within the traintest file.

  • training (bool) – whether we are training or evaluating.

  • augmentation (func) – a function to augment the data, or False if no augmentation is desired.
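The `training`/`augmentation` contract above can be illustrated without TensorFlow. The following is a minimal sketch, assuming the chosen split has already been loaded into NumPy arrays `X` and `y`; `make_input_fn`, `batch_size` and `seed` are hypothetical names, and the real method builds an Estimator input function rather than a plain Python generator.

```python
import numpy as np


def make_input_fn(X, y, training, augmentation=False, batch_size=32, seed=0):
    """Hypothetical sketch: return a zero-argument function yielding (X, y) batches.

    Rows are shuffled only when training; if `augmentation` is a callable it
    is applied to each feature batch, while `augmentation=False` disables it.
    """
    def input_fn():
        idx = np.arange(len(X))
        if training:
            # Shuffle only while training; evaluation keeps the original order.
            np.random.default_rng(seed).shuffle(idx)
        for start in range(0, len(idx), batch_size):
            batch = idx[start:start + batch_size]
            xb, yb = X[batch], y[batch]
            if augmentation:
                # Any callable is truthy, so False skips this step.
                xb = augmentation(xb)
            yield xb, yb
    return input_fn
```

Returning a closure mirrors the Estimator convention of passing a callable rather than the data itself, so the data pipeline is only built when training or evaluation starts.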

static predict(features, predict_fn=None, mask_fn=None, probs=False, samples=10, model_dir=None, consensus=False)

Load model and return predictions.

Parameters:
  • model_dir (str) – path of the saved model to load.

  • features (matrix) – a NumPy matrix of input features (X).

  • predict_fn (func) – the predict function returned by predict_fn.

  • probs (bool) – if the model is a classifier, return the probabilities.

  • consensus (bool) – also return a sampling for consensus calculation (regression only).
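For a classifier, `probs` switches between returning class probabilities and a single prediction per sample. A minimal sketch, assuming `predict_fn` returns a matrix of per-class probabilities, that `mask_fn` takes and returns a feature matrix, and that hard predictions are the argmax class; `predict_sketch` is a hypothetical name, and the sketch does not model `samples` or `consensus`.

```python
import numpy as np


def predict_sketch(features, predict_fn, mask_fn=None, probs=False):
    """Hypothetical sketch of the predict contract for a classifier.

    `predict_fn` is assumed to map a feature matrix to per-class
    probabilities of shape (n_samples, n_classes).
    """
    if mask_fn is not None:
        # Optionally hide part of the input before predicting.
        features = mask_fn(features)
    scores = np.asarray(predict_fn(features))
    if probs:
        return scores
    # Otherwise return the most likely class index for each sample.
    return np.argmax(scores, axis=1)
```

Usage with a toy two-class `predict_fn`: `predict_sketch(X, fake_fn)` yields labels, while `predict_sketch(X, fake_fn, probs=True)` yields the full probability matrix.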

static predict_fn(model_dir)

Load model and return the predict function.

Parameters:

model_dir (str) – path of the saved model to load.

static predict_online(h5_file, split, predict_fn=None, mask_fn=None, batch_size=1000, limit=None, probs=False, n_classes=None, model_dir=None)

Predict on the given test set in batches, without exhausting memory.

Parameters:
  • model_dir (str) – path of the saved model to load.

  • h5_file (str) – path to an HDF5 file compatible with Traintest.

  • split (str) – the split to use within the h5_file.

  • predict_fn (func) – the predict function returned by predict_fn.

  • mask_fn (func) – a function masking part of the input.

  • batch_size (int) – batch size used when reading the Traintest file.

  • limit (int) – maximum number of predictions.

  • probs (bool) – if the model is a classifier, return the probabilities.
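The memory-saving idea is to read and predict one batch at a time and stop after `limit` rows. A minimal sketch, assuming the split is exposed as a sliceable array-like `X` (an `h5py` dataset would qualify) and `predict_fn` maps a batch to an array; `predict_in_batches` is a hypothetical name, and the real method reads through Traintest rather than slicing directly.

```python
import numpy as np


def predict_in_batches(X, predict_fn, mask_fn=None, batch_size=1000, limit=None):
    """Hypothetical sketch: predict on an array-like chunk by chunk.

    Only one batch of `X` is materialized at a time, so `X` can be a lazily
    read dataset rather than an in-memory array.
    """
    n = X.shape[0] if limit is None else min(limit, X.shape[0])
    outputs = []
    for start in range(0, n, batch_size):
        stop = min(start + batch_size, n)
        # Slicing reads just this batch from disk when X is an HDF5 dataset.
        batch = np.asarray(X[start:stop])
        if mask_fn is not None:
            batch = mask_fn(batch)
        outputs.append(np.asarray(predict_fn(batch)))
    if not outputs:
        return np.empty((0,))
    return np.concatenate(outputs, axis=0)
```

Only the predictions accumulate in memory; the input features are never loaded all at once, which is what keeps peak memory bounded by `batch_size`.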

static print_model_architechture(model_dir)

Print the architecture of the NAS network.

Parameters:

model_dir (str) – path of the saved model.

save_model(model_dir)

Save the trained model to the given directory.

Parameters:

model_dir (str) – path where to save the model.

save_performances(output_dir, plot, suffix=None, extra_predictors=None, do_plot=True)

Save stats and make plots.

train_and_evaluate(evaluate=True)

Train and evaluate AdaNet.