Train Module

class visualize_training.train.ModelManager(model, train_dataloader, test_dataloader, config, full_dataloader=None)

The ModelManager is one of the core modules of Visualization Training. It is the main class and handles model training, data collection, metric calculation, hookpoint configuration, and metrics management.

attach_hooks(layers)

Configure hookpoints on the specified layers.

Parameters:

layers (List) – List of layers on which hookpoints need to be configured.
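The page does not describe how hookpoints work internally. As a rough illustration of the general pattern, the framework-free sketch below registers a recording callback on each layer, receives every layer output on the forward pass, and returns handles that can later be removed. `Layer`, `HookHandle`, `register_forward_hook`, and the cache layout are hypothetical stand-ins (loosely modeled on PyTorch's forward hooks), not the visualize_training implementation.

```python
from typing import Callable, Dict, List


class HookHandle:
    """Hypothetical handle: calling .remove() unregisters one hook."""

    def __init__(self, hooks: Dict[int, Callable], hook_id: int) -> None:
        self._hooks = hooks
        self._id = hook_id

    def remove(self) -> None:
        self._hooks.pop(self._id, None)


class Layer:
    """Hypothetical toy layer that scales its input by a fixed weight."""

    def __init__(self, name: str, weight: float) -> None:
        self.name = name
        self.weight = weight
        self._hooks: Dict[int, Callable] = {}
        self._next_id = 0

    def register_forward_hook(self, fn: Callable[["Layer", float, float], None]) -> HookHandle:
        hook_id = self._next_id
        self._next_id += 1
        self._hooks[hook_id] = fn
        return HookHandle(self._hooks, hook_id)

    def __call__(self, x: float) -> float:
        out = x * self.weight
        # Every registered hook sees the layer, its input, and its output.
        for fn in list(self._hooks.values()):
            fn(self, x, out)
        return out


def attach_hooks(layers: List[Layer], cache: Dict[str, List[float]]) -> List[HookHandle]:
    """Attach one recording hook per layer and return the handles."""

    def record(layer: Layer, inp: float, out: float) -> None:
        cache.setdefault(layer.name, []).append(out)

    return [layer.register_forward_hook(record) for layer in layers]


def remove_hooks(handles: List[HookHandle]) -> None:
    """Unregister every hook and forget the handles."""
    for h in handles:
        h.remove()
    handles.clear()


layers = [Layer("fc1", 2.0), Layer("fc2", 0.5)]
cache: Dict[str, List[float]] = {}
handles = attach_hooks(layers, cache)

x = 3.0
for layer in layers:
    x = layer(x)
print(cache)  # {'fc1': [6.0], 'fc2': [3.0]}

remove_hooks(handles)
layers[0](1.0)  # no hook fires any more
print(cache)  # {'fc1': [6.0], 'fc2': [3.0]}
```

Returning handles from the attach step keeps removal explicit, which is one way a separate remove_hooks() method can undo exactly what attach_hooks() set up.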

clear_metrics_cache()

Clear the metrics cache.

clear_weights_biases_cache()

Clear the current weights and biases cache.

compute_metrics()

Compute the required metrics from the weights and biases cache. For the complete list of metrics, refer to source/metrics.
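The actual metrics are listed in source/metrics and are not reproduced here. The sketch below only shows the general shape of such a step, assuming a hypothetical cache layout of `{step: {layer_name: flat list of parameter values}}` and three illustrative per-layer statistics (L2 norm, mean, population standard deviation) chosen for the example.

```python
import math
import statistics
from typing import Dict, List

# Hypothetical cache layout: {step: {layer_name: flat list of parameter values}}
WeightsCache = Dict[int, Dict[str, List[float]]]
Metrics = Dict[int, Dict[str, Dict[str, float]]]


def compute_metrics(cache: WeightsCache) -> Metrics:
    """Compute illustrative per-layer summary statistics for each cached step."""
    metrics: Metrics = {}
    for step, layers in cache.items():
        metrics[step] = {}
        for name, values in layers.items():
            metrics[step][name] = {
                "l2_norm": math.sqrt(math.fsum(v * v for v in values)),
                "mean": statistics.fmean(values),
                "std": statistics.pstdev(values),
            }
    return metrics


cache = {0: {"fc1": [3.0, 4.0]}, 1: {"fc1": [0.0, 0.0]}}
m = compute_metrics(cache)
print(m[0]["fc1"])  # {'l2_norm': 5.0, 'mean': 3.5, 'std': 0.5}
```

Keying results by step and then by layer keeps each metric easy to plot as a time series per layer.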

remove_hooks()

Remove all the hooks currently configured in the model.

save_metrics()

Save the metrics to disk in an organized directory structure.
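The exact directory structure used by save_metrics() is not specified on this page. As one possible layout, the sketch below assumes a hypothetical `<root>/<layer_name>/<metric_name>.json` scheme in which each file maps a training step to a metric value; the input dictionary shape is also an assumption.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict

Metrics = Dict[int, Dict[str, Dict[str, float]]]


def save_metrics(metrics: Metrics, root: Path) -> None:
    """Write one JSON file per (layer, metric), each mapping step -> value.

    Hypothetical layout: <root>/<layer_name>/<metric_name>.json
    """
    by_file: Dict[Path, Dict[str, float]] = {}
    for step, layers in metrics.items():
        for layer, values in layers.items():
            for metric, value in values.items():
                path = root / layer / f"{metric}.json"
                # JSON object keys must be strings, so steps are stored as str.
                by_file.setdefault(path, {})[str(step)] = value
    for path, series in by_file.items():
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(series, indent=2))


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "metrics"
    save_metrics(
        {0: {"fc1": {"mean": 3.5, "std": 0.5}}, 1: {"fc1": {"mean": 0.0, "std": 0.0}}},
        root,
    )
    files = sorted(p.relative_to(root).as_posix() for p in root.rglob("*.json"))
    loaded = json.loads((root / "fc1" / "mean.json").read_text())

print(files)   # ['fc1/mean.json', 'fc1/std.json']
print(loaded)  # {'0': 3.5, '1': 0.0}
```

Grouping all steps for one metric into a single file trades a little write granularity for simpler loading when plotting a metric over training.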

train()

Primary training function. It runs the training loop and performs part of the metrics collection.
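To illustrate what "a training loop with part of the metrics collection" can look like, here is a minimal sketch that fits a toy model y = w*x + b with plain SGD and snapshots the parameters into a weights cache after every epoch. The model, data, `train` signature, and cache format are all hypothetical and stand in for the real model and dataloaders passed to ModelManager.

```python
from typing import Dict, List, Tuple


def train(
    data: List[Tuple[float, float]],
    epochs: int,
    lr: float,
) -> Tuple[float, float, Dict[int, Dict[str, float]]]:
    """Fit y = w*x + b with SGD on squared error, caching parameters per epoch."""
    w, b = 0.0, 0.0
    weights_cache: Dict[int, Dict[str, float]] = {}
    for epoch in range(epochs):
        for x, y in data:
            err = (w * x + b) - y   # prediction error
            w -= lr * 2 * err * x   # gradient of err**2 with respect to w
            b -= lr * 2 * err       # gradient of err**2 with respect to b
        # Metrics-collection step: snapshot the parameters for this epoch.
        weights_cache[epoch] = {"w": w, "b": b}
    return w, b, weights_cache


# Noise-free data from y = 2x + 1 on x in [0, 0.9].
data = [(i / 10, 2 * (i / 10) + 1) for i in range(10)]
w, b, cache = train(data, epochs=500, lr=0.1)
```

With this noise-free data, w and b should approach 2 and 1, and the cache holds one snapshot per epoch that a later metrics step could summarize.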

train_and_save_metrics()

Wrapper function that trains the model and saves the metrics in a single call.