chemicalchecker.tool.adanet.adanet_wrap.AdaNetWrapper
- class AdaNetWrapper(traintest_file, **kwargs)
Bases: object
Wrapper class adapted from the scripted examples in AdaNet's GitHub repository:
https://github.com/tensorflow/adanet/blob/master/adanet/examples/tutorials/adanet_objective.ipynb
Methods
Extract the ensemble architecture from evaluation results.
get_trainable_variables – Return the weights of the trained neural network.
input_fn – Generate an input function for the Estimator.
predict – Load model and return predictions.
predict_fn – Load model and return the predict function.
predict_online – Predict on the given test set in batches, without exhausting memory.
print_model_architechture – Print out the NAS network architecture structure.
save_model – Save the model to model_dir.
Save stats and make plots.
Train and evaluate AdaNet.
- static get_trainable_variables(model_dir)
Return the weights of the trained neural network.
- Parameters:
model_dir (str) – path of the saved model.
- input_fn(split, training, augmentation=False)
Generate an input function for the Estimator.
- Parameters:
split (str) – the split to use within the traintest file.
training (bool) – whether we are training or evaluating.
augmentation (func) – a function to augment the data, or False if no augmentation is desired.
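The following sketch gives a rough picture of what an input function like this does. It builds a zero-argument callable that yields (X, y) batches from one split. Batches are shuffled only when training, and augmentation is applied only when a callable is passed.

This is a plain-Python stand-in, not the wrapper's code. The real method reads a Traintest HDF5 file and feeds a TensorFlow Estimator. Here, `make_input_fn`, the in-memory `splits` dict, `batch_size` and `seed` are all hypothetical. Restricting augmentation to training batches is also an assumption.

```python
import random
from typing import Callable, Dict, Iterator, List, Tuple, Union

Row = List[float]
Batch = Tuple[List[Row], List[float]]


def make_input_fn(
    splits: Dict[str, Tuple[List[Row], List[float]]],
    split: str,
    training: bool,
    augmentation: Union[Callable[[Row], Row], bool] = False,
    batch_size: int = 2,
    seed: int = 0,
) -> Callable[[], Iterator[Batch]]:
    """Return a zero-argument callable yielding (X, y) batches for one split.

    `splits` is a hypothetical in-memory stand-in for the Traintest file.
    """
    X, y = splits[split]

    def input_fn() -> Iterator[Batch]:
        order = list(range(len(X)))
        if training:
            # Shuffle only when training (fixed seed, so every pass uses
            # the same order); evaluation keeps the stored order.
            random.Random(seed).shuffle(order)
        for start in range(0, len(order), batch_size):
            idx = order[start:start + batch_size]
            xs = [X[i] for i in idx]
            if training and callable(augmentation):
                # Assumption: augmentation is applied to training batches only.
                xs = [augmentation(row) for row in xs]
            yield xs, [y[i] for i in idx]

    return input_fn
```

The function returns a factory rather than a list of batches, so each call starts a fresh pass over the split.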
- static predict(features, predict_fn=None, mask_fn=None, probs=False, samples=10, model_dir=None, consensus=False)
Load model and return predictions.
- Parameters:
model_dir (str) – path of the saved model.
features (matrix) – a NumPy matrix of input features (X).
predict_fn (func) – the prediction function returned by predict_fn().
mask_fn (func) – a function masking part of the input.
probs (bool) – if the model is a classifier, return class probabilities.
consensus (bool) – also return a sampling for consensus calculation (regression only).
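One way to picture the consensus option is re-running the model on several perturbed copies of the input. The sketch below makes two assumptions. First, mask_fn is stochastic. Second, the "sampling" consists of `samples` predictions on independently masked inputs.

`predict_with_consensus`, `make_random_mask` and `consensus_spread` are hypothetical helpers written in plain Python. They are not the wrapper's implementation.

```python
import random
import statistics
from typing import Callable, List, Optional

Row = List[float]
PredictFn = Callable[[List[Row]], List[float]]
MaskFn = Callable[[List[Row]], List[Row]]


def make_random_mask(p: float, seed: int = 0) -> MaskFn:
    """Hypothetical mask_fn: zero each feature independently with probability p."""
    rng = random.Random(seed)

    def mask_fn(rows: List[Row]) -> List[Row]:
        return [[0.0 if rng.random() < p else v for v in row] for row in rows]

    return mask_fn


def predict_with_consensus(
    features: List[Row],
    predict_fn: PredictFn,
    mask_fn: Optional[MaskFn] = None,
    samples: int = 10,
    consensus: bool = False,
):
    """Point predictions, plus `samples` predictions on masked inputs if consensus."""
    preds = predict_fn(features)
    if not consensus:
        return preds
    if mask_fn is None:
        raise ValueError("consensus sampling needs a mask_fn")
    # Each sample re-predicts on a freshly masked copy of the whole input.
    sampled = [predict_fn(mask_fn(features)) for _ in range(samples)]
    return preds, sampled


def consensus_spread(sampled: List[List[float]]) -> List[float]:
    """Per-row population std. dev. across samples; 0.0 means full agreement."""
    return [statistics.pstdev(col) for col in zip(*sampled)]
```

A low spread for a row means the samples agree. How the real method turns its sampling into a consensus value is not specified in this documentation.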
- static predict_fn(model_dir)
Load model and return the predict function.
- Parameters:
model_dir (str) – path of the saved model.
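predict accepts a predict_fn that was loaded earlier. A caller that predicts many times can therefore pay the model-loading cost only once.

The sketch below shows that load-once, reuse-many pattern in plain Python. `load_predict_fn`, the toy linear "model" in `SAVED_MODELS` and the `LOAD_COUNT` counter are hypothetical stand-ins for restoring the saved model from model_dir.

```python
from typing import Callable, Dict, List, Optional

Row = List[float]
PredictFn = Callable[[List[Row]], List[float]]

# Hypothetical "disk": model directory -> weight of a toy linear model.
SAVED_MODELS: Dict[str, float] = {"models/toy": 2.0}
# Counts how often each model directory has been loaded.
LOAD_COUNT: Dict[str, int] = {}


def load_predict_fn(model_dir: str) -> PredictFn:
    """Stand-in for the expensive step of restoring a saved model."""
    LOAD_COUNT[model_dir] = LOAD_COUNT.get(model_dir, 0) + 1
    weight = SAVED_MODELS[model_dir]

    def predict_fn(rows: List[Row]) -> List[float]:
        return [weight * sum(row) for row in rows]

    return predict_fn


def predict(
    features: List[Row],
    predict_fn: Optional[PredictFn] = None,
    model_dir: Optional[str] = None,
) -> List[float]:
    # Reuse a predict function the caller already holds; otherwise load one.
    if predict_fn is None:
        if model_dir is None:
            raise ValueError("pass either predict_fn or model_dir")
        predict_fn = load_predict_fn(model_dir)
    return predict_fn(features)
```

You can load once with `fn = load_predict_fn("models/toy")` and then call `predict(batch, predict_fn=fn)` for every batch. That keeps LOAD_COUNT at 1. Passing only model_dir instead reloads the model on every call.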
- static predict_online(h5_file, split, predict_fn=None, mask_fn=None, batch_size=1000, limit=None, probs=False, n_classes=None, model_dir=None)
Predict on the given test set in batches, without exhausting memory.
- Parameters:
model_dir (str) – path of the saved model.
h5_file (str) – path to an HDF5 file compatible with Traintest.
split (str) – the split to use within the h5_file.
predict_fn (func) – the predict function returned by predict_fn.
mask_fn (func) – a function masking part of the input.
batch_size (int) – batch size used when reading the Traintest file.
limit (int) – maximum number of predictions.
probs (bool) – if the model is a classifier, return class probabilities.
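The memory saving in predict_online comes from reading and predicting one slice of the input at a time, rather than loading the whole split. Below is a minimal sketch of that loop.

It assumes `rows` supports slicing that reads only the requested rows, as h5py datasets do. `predict_in_batches` is a hypothetical helper. The real method's handling of probs, n_classes and mask_fn is not reproduced.

```python
from typing import Callable, List, Optional, Sequence

Row = List[float]


def predict_in_batches(
    rows: Sequence[Row],
    predict_fn: Callable[[List[Row]], List[float]],
    batch_size: int = 1000,
    limit: Optional[int] = None,
) -> List[float]:
    """Predict chunk by chunk so only one batch of inputs is held at a time.

    `rows` stands in for a lazily indexed dataset (e.g. an HDF5 dataset).
    """
    total = len(rows) if limit is None else min(limit, len(rows))
    preds: List[float] = []
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        batch = list(rows[start:stop])  # read only this slice
        preds.extend(predict_fn(batch))
    return preds
```

The returned predictions are still accumulated in full. Only the input side is bounded by batch_size.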
- static print_model_architechture(model_dir)
Print out the NAS network architecture structure.
- Parameters:
model_dir (str) – path of the saved model.
- save_model(model_dir)
Save the model to model_dir.
- Parameters:
model_dir (str) – path where to save the model.