training

Utilities for managing the training process

class pydrobert.torch.training.TrainingStateController(params, state_csv_path=None, state_dir=None, warn=True, reduce_op=None)[source]

Controls the state of training a model

This class is used to help both control and persist experiment information like the current epoch, the model parameters, and model error. It assumes that the values stored in params have not changed when resuming a run. It is also used to control learning rates and early stopping.

Parameters:
  • params (TrainingStateParams) –

  • state_csv_path (Optional[str]) –

    A path to where training state information is stored. It stores in comma-separated-values format the following information. Note that stored values represent the state after updates due to epoch results, such as the learning rate. That way, an experiment can be resumed without worrying about updating the loaded results.

    1. "epoch": the epoch associated with this row of information

    2. "es_resume_cd": the number of epochs left before the early stopping criterion begins/resumes

    3. "es_patience_cd": the number of epochs left that must pass without much improvement before training halts due to early stopping

    4. "rlr_resume_cd": the number of epochs left before the criterion for reducing the learning rate begins/resumes

    5. "rlr_patience_cd": the number of epochs left that must pass without much improvement before the learning rate is reduced

    6. "lr": the learning rate of the optimizer after any updates

    7. "train_met": the mean training metric in exponent format. Lower values of the metric are assumed to be better

    8. "val_met": the mean validation metric in exponent format. Lower values of the metric are assumed to be better

    9. Any additional entries added through add_entry()

    If unset, the history will not be stored/loaded.

  • state_dir (Optional[str]) – A path to a directory to store/load model and optimizer states. If unset, the information will not be stored/loaded.

  • warn (bool) – Whether to warn, using the warnings module, when a format string does not contain the "epoch" field.

  • reduce_op (Optional[ReduceOp]) – The op used to combine metrics and other reducible values across workers in a distributed environment. See the note below for more details.

Examples

>>> params = TrainingStateParams(num_epochs=5)
>>> model = torch.nn.Linear(10, 1)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> controller = TrainingStateController(
...     params,
...     state_csv_path='log.csv',
...     state_dir='states')
>>> # load previous
>>> controller.load_model_and_optimizer_for_epoch(model, optimizer)
>>> for epoch in range(params.num_epochs):
...     # do training loop for epoch
...     train_loss, val_loss = 0.1, 0.01
...     if not controller.update_for_epoch(
...             model, optimizer, train_loss, val_loss):
...         break  # early stopping

Warning

Prior to v0.4.0, the cache of history was updated automatically (reading from the state CSV) whenever get_last_epoch(), get_best_epoch(), add_entry(), or get_info() (when the info was missing from the cache) was called. Now, the cache is only updated automatically on initialization and with calls to add_entry(). The cache may still be updated manually via update_cache(). There was no good reason to continuously update the cache, as any updates to the underlying file by other processes could ruin the control flow anyway.

Notes

TrainingStateController has rudimentary support for distributed training via torch.nn.parallel.DistributedDataParallel. Please read the tutorial to understand the basics of the environment before continuing.

Simple training loops involving a TrainingStateController, like the one in the example above, should work with only the usual distributed boilerplate (spawning the pool, initializing the process group, and wrapping the model with DistributedDataParallel). The controller should be created and update_for_epoch() called in each worker.

The only values which require coordination across workers by default are the training and validation metrics; the rest – early stopping, learning rate reduction, current epoch, and so on – will stay in sync across workers. If a custom entry in the state history needs to be coordinated (i.e. it depends directly on the data seen over the epoch, not on the training or validation metrics), the reduce flag of add_entry() can be set to True and that value will likewise be coordinated on the call to update_for_epoch(). reduce_op determines how the relevant values are coordinated. An average is taken by default by first dividing each copy of the value by the world size and then summing the copies together via torch.distributed.ReduceOp.SUM. See torch.distributed.ReduceOp for other options.
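
As a rough sketch (assuming the process group has already been initialized in each worker and that run_one_epoch is a hypothetical routine returning the worker's mean training and validation metrics), each worker might run:

>>> model = torch.nn.parallel.DistributedDataParallel(torch.nn.Linear(10, 1))
>>> optimizer = torch.optim.Adam(model.parameters())
>>> controller = TrainingStateController(
...     params, state_csv_path='log.csv', state_dir='states')
>>> controller.load_model_and_optimizer_for_epoch(model, optimizer)
>>> while controller.continue_training():
...     train_met, val_met = run_one_epoch(model)  # hypothetical
...     # metrics are averaged across workers before being recorded
...     controller.update_for_epoch(model, optimizer, train_met, val_met)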

add_entry(name, typ=<class 'str'>, fmt='{}', reduce=False)[source]

Add an entry to be stored and retrieved at every epoch

This method is useful when training loops need specialized, persistent information on every epoch. Prior to the first time any information is saved via update_for_epoch(), this method can be called with an entry name and an optional typ. The user is then expected to provide a keyword argument with that name every time update_for_epoch() is called. The values of those entries can be retrieved via get_info(), cast to typ, for any saved epoch.

Parameters:
  • name (str) – The name/key of the entry.

  • typ (type) – Should be a type that is deserialized from a string via typ(str_obj) and serialized to a string via fmt.format(obj).

  • fmt (str) – The format string used to serialize the objects into strings.

  • reduce (bool) – If True and in a distributed environment, the value will be synchronized across workers via a reduction op on each call to update_for_epoch(). See the notes in the class documentation for more information.

Examples

>>> params = TrainingStateParams()
>>> controller = TrainingStateController(params)
>>> model = torch.nn.Linear(10, 1)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> controller.add_entry('important_value', int)
>>> controller.update_for_epoch(
...     model, optimizer, 0.1, 0.1, important_value=3)
>>> controller.update_for_epoch(
...     model, optimizer, 0.2, 0.01, important_value=4)
>>> assert controller[1]['important_value'] == 3
>>> assert controller[2]['important_value'] == 4

Notes

add_entry() must be called prior to update_for_epoch() or save_info_to_hist(), or it may corrupt the experiment history. However, the controller can safely ignore additional entries when loading history from a CSV. Thus, there is no need to call add_entry() if no new training is to be done (unless those entries are needed outside of training).
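
As a hedged illustration of that last point, suppose the entries from the example above had been written to a history CSV (i.e. state_csv_path='log.csv' had been passed). A fresh controller could recover the typed values like so:

>>> controller = TrainingStateController(params, state_csv_path='log.csv')
>>> controller.add_entry('important_value', int)  # re-register so values are cast to int
>>> assert controller.get_info(1)['important_value'] == 3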

continue_training(epoch=None)[source]

Return a boolean on whether to continue training

Useful when resuming training. Will check the training history at the target epoch and determine whether training should continue from that point, based on the total number of epochs and the early stopping criterion.

Parameters:

epoch (Optional[int]) – The epoch to check the history of. If unset, the last epoch will be inferred.

Returns:

cont (bool) – True if training should continue
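
A brief sketch of resuming an experiment with this method (reusing the controller, model, and optimizer from the class-level example; the epoch-by-epoch loop itself is elided):

>>> last = controller.get_last_epoch()
>>> if controller.continue_training(last):
...     controller.load_model_and_optimizer_for_epoch(model, optimizer, last)
...     # ... run the training loop from epoch last + 1 onwards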

delete_model_and_optimizer_for_epoch(epoch)[source]

Delete state dicts for model and optimizer off of disk, if they exist

This method does nothing if the epoch records or the files do not exist.

Parameters:

epoch (int) –

get_best_epoch(train_met=False)[source]

Get the epoch that has led to the best validation metric value so far

The “best” is the lowest recorded validation metric. In the case of ties, the earlier epoch is chosen.

Parameters:

train_met (bool) – If True, look for the best training metric value instead.

Returns:

epoch (int) – The corresponding ‘best’ epoch, or 0 if no epochs have run

Notes

Negligible differences between epochs are determined by TrainingStateController.METRIC_PRECISION, which is relative to the metrics' base-10 representation. This is in contrast to the early stopping criteria and learning rate annealing, whose thresholds are absolute.
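
For instance, continuing the class-level example, a hedged sketch of looking up the best epoch's history entry:

>>> best = controller.get_best_epoch()           # 0 if no epochs have completed
>>> best_info = controller.get_info(best, None)  # None if nothing is recorded for it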

get_info(epoch, *default)[source]

Get history entries for a specific epoch

If there’s an entry present for epoch, return it. The value is a dictionary with the keys “epoch”, “es_resume_cd”, “es_patience_cd”, “rlr_resume_cd”, “rlr_patience_cd”, “lr”, “train_met”, and “val_met”, as well as any additional entries specified through add_entry().

If there's no entry for epoch and no additional arguments were passed to this method, a KeyError is raised. If an additional argument was passed, it is returned instead.
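
For example, a short sketch continuing the add_entry() example above:

>>> info = controller.get_info(2)          # raises KeyError if epoch 2 is missing
>>> info = controller.get_info(100, None)  # returns None instead of raising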

get_last_epoch()[source]

Return the last finished epoch from training, or 0 if no history

load_model_and_optimizer_for_epoch(model, optimizer, epoch=None, strict=True)[source]

Load up model and optimizer states, or initialize them

Parameters:
  • model (Module) – Model state will be loaded into this.

  • optimizer (Optimizer) – Optimizer state will be loaded into this.

  • epoch (Optional[int]) – The epoch from which the states should be loaded. We look for the appropriately named files in self.state_dir. If epoch is unset, the last epoch in recorded history will be loaded. If it’s 0, the model and optimizer are initialized with states for the beginning of the experiment.

  • strict (bool) – Whether to strictly enforce that the keys in model.state_dict() match those that were saved.

load_model_for_epoch(model, epoch=None, strict=True)[source]

Load up just the model, or initialize it

Parameters:
  • model (Module) – Model state will be loaded into this.

  • epoch (Optional[int]) – The epoch from which the states should be loaded. We look for the appropriately named files in self.state_dir. If epoch is None, the best epoch in recorded history will be loaded. If it’s 0, the model is initialized with states from the beginning of the experiment.

  • strict (bool) – Whether to strictly enforce that the keys in model.state_dict() match those that were saved.
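
For instance, a brief sketch of evaluating the best checkpoint (model and controller as in the class-level example):

>>> controller.load_model_for_epoch(model)  # defaults to the best recorded epoch
>>> model = model.eval()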

save_info_to_hist(info)[source]

Append history entries to the history CSV

This is called automatically during update_for_epoch(). Does not save if there is no file to save to (i.e. self.state_csv_path is None). Values are appended to the end of the CSV file; no checking is performed for maintaining a valid history.

Parameters:

info (dict) –

save_model_and_optimizer_with_info(model, optimizer, info)[source]

Save model and optimizer state dictionaries to files given epoch info

This is called automatically during update_for_epoch(). Does not save if there is no directory to save to (i.e. self.state_dir is None). Format strings from self.params are formatted with the values from info to construct the base names of each file.

Parameters:
  • model (Module) – The model whose state dictionary will be saved.

  • optimizer (Optimizer) – The optimizer whose state dictionary will be saved.

  • info (dict) – The history dictionary. Entries can be used in the state dict’s path’s format strings.

update_cache()[source]

Update the cache with history stored in state_csv_path
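
For example, if another process has appended rows to the history CSV since this controller was constructed, the cache can be refreshed manually before querying it:

>>> controller.update_cache()
>>> last = controller.get_last_epoch()  # now reflects the contents of state_csv_path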

update_for_epoch(model, optimizer, train_met, val_met, epoch=None, best_is_train=False, **kwargs)[source]

Update history, model, and optimizer after latest epoch results

Parameters:
  • model (Module) – The model after the epoch that just finished.

  • optimizer (Optimizer) – The optimizer after the epoch that just finished.

  • train_met (float) – Mean value of metric on training set for epoch.

  • val_met (float) – Mean value of metric on validation set for epoch.

  • epoch (Optional[int]) – The epoch that just finished. If unset, it is inferred to be one after the last epoch in the history.

  • best_is_train (bool) – Whether to judge the best model on record by the training metric (False means by the validation metric).

  • **kwargs – Additional keyword arguments can be used to specify the values of entries specified via add_entry().

Returns:

cont (bool) – Whether to continue training. This can be set to False either by hitting the max number of epochs or by early stopping.

class pydrobert.torch.training.TrainingStateParams(*, early_stopping_burnin, early_stopping_patience, early_stopping_threshold, keep_last_and_best_only, log10_learning_rate, num_epochs, reduce_lr_burnin, reduce_lr_cooldown, reduce_lr_factor, reduce_lr_log10_epsilon, reduce_lr_patience, reduce_lr_threshold, saved_model_fmt, saved_optimizer_fmt, seed, name)[source]

Parameters controlling a TrainingStateController

This class implements the pydrobert.param.optuna.TunableParameterized interface

classmethod get_tunable()[source]

Returns a set of tunable parameters

classmethod suggest_params(trial, base=None, only=None, prefix='')[source]

Populate a parameterized instance with values from trial
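
A hedged sketch of using this interface with Optuna (it assumes optuna is installed, that the named parameters are among those returned by get_tunable(), and that train_and_validate is a hypothetical routine that trains with the suggested parameters and returns a validation metric):

>>> import optuna
>>> def objective(trial):
...     params = TrainingStateParams.suggest_params(
...         trial, only={'log10_learning_rate', 'reduce_lr_factor'}, prefix='tsc_')
...     return train_and_validate(params)  # hypothetical
>>> study = optuna.create_study(direction='minimize')
>>> study.optimize(objective, n_trials=10)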