training
Utilities for managing the training process
- class pydrobert.torch.training.TrainingStateController(params, state_csv_path=None, state_dir=None, warn=True, reduce_op=None)[source]
Controls the state of training a model
This class is used to help both control and persist experiment information like the current epoch, the model parameters, and model error. It assumes that the values stored in params have not changed when resuming a run. It is also used to control learning rates and early stopping.
- Parameters:
params (TrainingStateParams) –
state_csv_path (Optional[str]) – A path to where training state information is stored. The file is in comma-separated-values (CSV) format and records the following information for each epoch. Note that stored values represent the state after any updates due to that epoch's results, such as a reduction of the learning rate. That way, an experiment can be resumed without having to update the loaded values.
“epoch”: the epoch associated with this row of information
“es_resume_cd”: the number of epochs left before the early stopping criterion begins/resumes
“es_patience_cd”: the number of epochs left that must pass without much improvement before training halts due to early stopping
“rlr_resume_cd”: the number of epochs left before the criterion for reducing the learning rate begins/resumes
“rlr_patience_cd”: the number of epochs left that must pass without much improvement before the learning rate is reduced
“lr”: the learning rate of the optimizer after any updates
“train_met”: the mean training metric, in exponent (scientific) notation. Lower values are assumed to be better
“val_met”: the mean validation metric, in exponent (scientific) notation. Lower values are assumed to be better
any additional entries added through add_entry()
If unset, the history will not be stored/loaded.
state_dir (Optional[str]) – A path to a directory in which to store/load model and optimizer states. If unset, the states will not be stored/loaded.
warn (bool) – Whether to warn, using the warnings module, when a format string does not contain the “epoch” field.
reduce_op (Optional[ReduceOp]) – The op used to combine metrics and other reducible values in a distributed environment. See the notes below for more details.
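Because the history is plain CSV, it can also be inspected outside the controller. Below is a minimal sketch using only the standard csv module. It assumes the file has a header row whose names match the columns listed above. The sample rows are invented for illustration, and read_history is a hypothetical helper, not part of pydrobert.torch.

```python
import csv
import io

# Hypothetical history file; the numbers are invented for illustration only.
SAMPLE_CSV = """\
epoch,es_resume_cd,es_patience_cd,rlr_resume_cd,rlr_patience_cd,lr,train_met,val_met
1,0,3,0,2,1.00000000e-03,1.50000000e+00,1.20000000e+00
2,0,2,0,1,1.00000000e-03,9.00000000e-01,1.25000000e+00
"""

INT_COLS = ("epoch", "es_resume_cd", "es_patience_cd",
            "rlr_resume_cd", "rlr_patience_cd")
FLOAT_COLS = ("lr", "train_met", "val_met")


def read_history(fp):
    """Map each epoch to its row, casting the standard columns."""
    history = {}
    for row in csv.DictReader(fp):
        info = dict(row)  # any add_entry() columns are left as strings
        for col in INT_COLS:
            info[col] = int(info[col])
        for col in FLOAT_COLS:
            info[col] = float(info[col])  # float() parses exponent notation
        history[info["epoch"]] = info
    return history


history = read_history(io.StringIO(SAMPLE_CSV))
print(history[2]["val_met"])  # 1.25
```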
Examples
>>> import torch
>>> from pydrobert.torch.training import (
...     TrainingStateController, TrainingStateParams)
>>> params = TrainingStateParams(num_epochs=5)
>>> model = torch.nn.Linear(10, 1)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> controller = TrainingStateController(
...     params,
...     state_csv_path='log.csv',
...     state_dir='states')
>>> # load previous states, if any
>>> controller.load_model_and_optimizer_for_epoch(model, optimizer)
>>> for epoch in range(params.num_epochs):
...     # do training loop for epoch
...     train_loss, val_loss = 0.1, 0.01
...     if not controller.update_for_epoch(
...             model, optimizer, train_loss, val_loss):
...         break  # early stopping
Warning
Prior to v0.4.0, the cache of history was updated automatically (reading from state_csv_path) whenever get_last_epoch(), get_best_epoch(), add_entry(), or get_info() (when the info was missing from the cache) was called. Now, the cache is only updated automatically on initialization and with calls to add_entry(). The cache may still be updated manually via update_cache(). There was no good reason to continuously update the cache, as any updates to the underlying file by other processes could ruin the control flow anyway.
Notes
TrainingStateController has rudimentary support for distributed training via torch.nn.parallel.DistributedDataParallel. Please read the tutorial to understand the basics of the environment before continuing.
Simple training loops involving a TrainingStateController, like in the example above, should work with only the usual distributed boilerplate (spawning the pool, initializing the process group, and wrapping the model with DistributedDataParallel). The controller should be created and update_for_epoch() called in each worker.
By default, the only values which require coordination across workers are the training and validation metrics; the rest (early stopping, learning rate reduction, current epoch, and so on) will stay in sync across workers. If a custom entry in the state history needs to be coordinated (i.e. it depends directly on the data seen over the epoch, not on the training or validation metrics), the reduce flag of add_entry() can be set to True and that value will likewise be coordinated on the call to update_for_epoch().
reduce_op determines how the relevant values are coordinated. By default, an average is taken by first dividing each copy of the value by the world size and then summing the copies together via torch.distributed.ReduceOp.SUM. See torch.distributed.ReduceOp for other options.
- add_entry(name, typ=<class 'str'>, fmt='{}', reduce=False)[source]
Add an entry to be stored and retrieved at every epoch
This method is useful when training loops need specialized, persistent information on every epoch. Before any information is saved via update_for_epoch() for the first time, this method can be called with an entry name and an optional typ. The user is then expected to provide a keyword argument with that name every time update_for_epoch() is called. The values of those entries can be retrieved via get_info(), cast to typ, for any saved epoch.
- Parameters:
name (str) – The name/key of the entry.
typ (type) – Should be a type that is deserialized from a string via typ(str_obj) and serialized to a string via fmt.format(obj).
fmt (str) – The format string used to serialize the objects into strings.
reduce (bool) – If True and in a distributed environment, the value will be synchronized across workers via a reduction op on each call to update_for_epoch(). See the notes in the class documentation for more information.
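The default averaging described in the class notes can be made concrete with a single-process simulation in plain Python, with no torch.distributed calls. Each simulated worker divides its own copy of the value by the world size. A sum reduction then leaves every worker with the same total, which equals the mean. simulate_mean_reduction is a hypothetical helper. In a real job, the sum step would presumably be a collective such as torch.distributed.all_reduce with ReduceOp.SUM, but the exact call the controller makes is an assumption here.

```python
def simulate_mean_reduction(per_worker_values):
    """Simulate the default averaging of one value across workers.

    per_worker_values[i] is worker i's local copy of the value.
    """
    world_size = len(per_worker_values)
    # Step 1: every worker divides its own copy by the world size.
    scaled = [value / world_size for value in per_worker_values]
    # Step 2: a SUM reduction adds the scaled copies, and every worker
    # ends up holding the same total, which is the mean of the copies.
    total = sum(scaled)
    return [total] * world_size


print(simulate_mean_reduction([1.0, 2.0, 3.0, 6.0]))  # [3.0, 3.0, 3.0, 3.0]
```

Dividing before summing gives the same mean as summing and then dividing. It just does the division locally, so a single SUM op suffices.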
Examples
>>> import torch
>>> from pydrobert.torch.training import (
...     TrainingStateController, TrainingStateParams)
>>> params = TrainingStateParams()
>>> controller = TrainingStateController(params)
>>> model = torch.nn.Linear(10, 1)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> controller.add_entry('important_value', int)
>>> controller.update_for_epoch(
...     model, optimizer, 0.1, 0.1, important_value=3)
>>> controller.update_for_epoch(
...     model, optimizer, 0.2, 0.01, important_value=4)
>>> assert controller[1]['important_value'] == 3
>>> assert controller[2]['important_value'] == 4
Notes
add_entry() must be called prior to update_for_epoch() or save_info_to_hist(), or it may corrupt the experiment history. However, the controller can safely ignore additional entries when loading history from a CSV. Thus, there is no need to call add_entry() if no new training is to be done (unless those entries are needed outside of training).
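The typ/fmt pair has to round-trip: a value written as fmt.format(obj) must come back intact as typ(str_obj). The sketch below checks a few choices in plain Python. serialize and deserialize are hypothetical helpers that mirror the two documented expressions; they are not methods of the controller.

```python
def serialize(obj, fmt="{}"):
    """How an entry is written to the CSV: fmt.format(obj)."""
    return fmt.format(obj)


def deserialize(str_obj, typ=str):
    """How an entry is read back: typ(str_obj)."""
    return typ(str_obj)


# An int with the default fmt round-trips cleanly.
print(deserialize(serialize(3), int))  # 3

# A float in exponent notation also survives the trip.
s = serialize(0.25, "{:.3e}")
print(s, deserialize(s, float))  # 2.500e-01 0.25

# Pitfall: bool("False") is True, so bool does not round-trip.
print(deserialize(serialize(False), bool))  # True
```

For flags, one workaround is to store them as integers (typ=int, fmt='{:d}') and convert after retrieval.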
- continue_training(epoch=None)[source]
Return a boolean on whether to continue training
Useful when resuming training. Will check the training history at the target epoch and determine whether training should continue from that point, based on the total number of epochs and the early stopping criterion.
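The decision can be sketched over a plain dictionary of history rows keyed by epoch. This is a hypothetical re-implementation, not the method's code. It assumes three things: an unset epoch means the last recorded one; an empty history means training should start; and early stopping has fired once “es_patience_cd” has counted down to 0, which is how that column is described above.

```python
def continue_training(history, num_epochs, epoch=None):
    """Sketch of the decision, given history rows keyed by epoch.

    Assumptions (not taken from the library): an unset epoch means the
    last recorded one, an empty history means training should start, and
    early stopping has fired once es_patience_cd has counted down to 0.
    """
    if epoch is None:
        epoch = max(history, default=0)
    if epoch == 0:
        return True  # nothing has been trained yet
    if epoch >= num_epochs:
        return False  # the epoch budget is used up
    return history[epoch]["es_patience_cd"] > 0


history = {
    1: {"epoch": 1, "es_patience_cd": 2},
    2: {"epoch": 2, "es_patience_cd": 0},  # patience exhausted at epoch 2
}
print(continue_training(history, num_epochs=5, epoch=1))  # True
print(continue_training(history, num_epochs=5))  # False
print(continue_training({}, num_epochs=5))  # True
```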
- delete_model_and_optimizer_for_epoch(epoch)[source]
Delete the model and optimizer state dicts for epoch from disk, if they exist
This method does nothing if the epoch records or the files do not exist.
- Parameters:
epoch (int) –
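The "delete only if present" behaviour corresponds to pathlib's missing_ok flag, which requires Python 3.8 or later. This is a minimal sketch. delete_if_present and the file names model_005.pt and optim_005.pt are hypothetical; the real names come from the format strings in TrainingStateParams.

```python
import tempfile
from pathlib import Path


def delete_if_present(*paths):
    """Remove each file that exists and silently skip the rest."""
    for path in paths:
        Path(path).unlink(missing_ok=True)  # Python 3.8+


with tempfile.TemporaryDirectory() as state_dir:
    model_path = Path(state_dir) / "model_005.pt"
    optim_path = Path(state_dir) / "optim_005.pt"
    model_path.write_bytes(b"")  # only the model file exists
    before = (model_path.exists(), optim_path.exists())
    delete_if_present(model_path, optim_path)  # no error for the missing file
    after = (model_path.exists(), optim_path.exists())

print(before, after)  # (True, False) (False, False)
```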
- get_best_epoch(train_met=False)[source]
Get the epoch that has led to the best validation metric so far
The “best” is the lowest recorded validation metric (or training metric, if train_met is True). In the case of ties, the earlier epoch is chosen.
- Parameters:
train_met (bool) – If True, look for the best training metric value instead.
- Returns:
epoch (int) – The corresponding ‘best’ epoch, or 0 if no epochs have run.
Notes
Negligible differences between epochs are determined by TrainingStateController.METRIC_PRECISION, which is relative to the metrics' base-10 magnitude. This is in contrast to the early stopping and learning rate annealing criteria, whose thresholds are absolute.
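The tie-breaking rule can be sketched in plain Python. The sketch assumes that "relative to the metrics' base-10 magnitude" means metrics are compared after being written in exponent notation with a fixed number of digits. The precision value 8 and the helpers rounded and best_epoch are hypothetical stand-ins, not the controller's code.

```python
METRIC_PRECISION = 8  # hypothetical: digits kept after the decimal point


def rounded(met, precision=METRIC_PRECISION):
    """Round to a fixed number of significant digits via exponent format."""
    return float(f"{met:.{precision}e}")


def best_epoch(history, key="val_met"):
    """Earliest epoch with the lowest rounded metric; 0 if history is empty."""
    best, best_met = 0, float("inf")
    for epoch in sorted(history):
        met = rounded(history[epoch][key])
        if met < best_met:  # strict '<' keeps the earlier epoch on ties
            best, best_met = epoch, met
    return best


history = {
    1: {"val_met": 0.5},
    2: {"val_met": 0.3000000001},  # equal to epoch 3 at 8 digits
    3: {"val_met": 0.3},
}
print(best_epoch(history))  # 2
```

Without rounding, epoch 3 would win by a negligible margin. With rounding, the two epochs tie and the earlier one is kept.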
- get_info(epoch, *default)[source]
Get history entries for a specific epoch
If there’s an entry present for epoch, return it. The value is a dictionary with the keys “epoch”, “es_resume_cd”, “es_patience_cd”, “rlr_resume_cd”, “rlr_patience_cd”, “lr”, “train_met”, and “val_met”, as well as any additional entries specified through add_entry().
If there’s no entry for epoch and no additional arguments were passed to this method, a KeyError is raised. If an additional argument was passed to this method, it is returned instead.
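The optional default follows the same pattern as dict.pop. Because the default is taken through *default rather than a keyword, even None counts as a supplied default. This is a minimal sketch of that lookup over a plain dictionary of history rows. get_info here is a hypothetical free function, not the method itself.

```python
def get_info(history, epoch, *default):
    """Look up the row for epoch, falling back to a default if one is given."""
    if epoch in history:
        return history[epoch]
    if default:
        return default[0]  # any supplied value, including None
    raise KeyError(epoch)


history = {1: {"epoch": 1, "val_met": 0.3}}
print(get_info(history, 1)["val_met"])  # 0.3
print(get_info(history, 5, None))  # None
try:
    get_info(history, 5)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 5
```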
- load_model_and_optimizer_for_epoch(model, optimizer, epoch=None, strict=True)[source]
Load up model and optimizer states, or initialize them
- Parameters:
model (Module) – Model state will be loaded into this.
optimizer (Optimizer) – Optimizer state will be loaded into this.
epoch (Optional[int]) – The epoch from which the states should be loaded. We look for the appropriately named files in self.state_dir. If epoch is unset, the last epoch in recorded history will be loaded. If it’s 0, the model and optimizer are initialized with states for the beginning of the experiment.
strict (bool) – Whether to strictly enforce that the keys in model.state_dict() match those that were saved.
- load_model_for_epoch(model, epoch=None, strict=True)[source]
Load up just the model, or initialize it
- Parameters:
model (Module) – Model state will be loaded into this.
epoch (Optional[int]) – The epoch from which the states should be loaded. We look for the appropriately named files in self.state_dir. If epoch is None, the best epoch in recorded history will be loaded. If it’s 0, the model is initialized with states from the beginning of the experiment.
strict (bool) – Whether to strictly enforce that the keys in model.state_dict() match those that were saved.
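The two loaders resolve an unset epoch differently. load_model_and_optimizer_for_epoch() falls back to the last epoch, which suits resuming training. load_model_for_epoch() falls back to the best epoch, which typically suits evaluation. The hypothetical helper below captures just that rule. Treating an empty history as epoch 0 is an assumption.

```python
def resolve_load_epoch(epoch, recorded_epochs, best_epoch, with_optimizer):
    """Return the epoch whose states would be loaded (0 = initial states)."""
    if epoch is not None:
        return epoch  # an explicit request, including 0, is used as-is
    if with_optimizer:
        # model + optimizer: resume from the most recent epoch
        return max(recorded_epochs, default=0)
    return best_epoch  # model only: use the best epoch on record


recorded, best = [1, 2, 3, 4], 2
print(resolve_load_epoch(None, recorded, best, with_optimizer=True))  # 4
print(resolve_load_epoch(None, recorded, best, with_optimizer=False))  # 2
print(resolve_load_epoch(0, recorded, best, with_optimizer=True))  # 0
```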
- save_info_to_hist(info)[source]
Append history entries to the history CSV
This is called automatically during update_for_epoch(). It does not save if there is no file to save to (i.e. self.state_csv_path is None). Values are appended to the end of the CSV file; no checking is performed to maintain a valid history.
- Parameters:
info (dict) –
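Appending without validation is easy to reproduce with the csv module. This is a minimal sketch under three assumptions: a header row is written only when the file is new, metrics are formatted in exponent notation with 8 digits, and only a subset of columns is shown. append_info and the column subset are hypothetical.

```python
import csv
import os
import tempfile


def append_info(path, info, fieldnames):
    """Append one history row, writing a header only for a new file."""
    is_new = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, "a", newline="") as fp:
        writer = csv.DictWriter(fp, fieldnames=fieldnames)
        if is_new:
            writer.writeheader()
        writer.writerow(info)  # no check that epochs stay in order


fields = ["epoch", "lr", "val_met"]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "log.csv")
    for epoch, val_met in [(1, 0.5), (2, 0.25)]:
        row = {"epoch": epoch, "lr": "{:.8e}".format(1e-3),
               "val_met": "{:.8e}".format(val_met)}
        append_info(path, row, fields)
    with open(path, newline="") as fp:
        lines = fp.read().splitlines()

print(lines[-1])  # 2,1.00000000e-03,2.50000000e-01
```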
- save_model_and_optimizer_with_info(model, optimizer, info)[source]
Save model and optimizer state dictionaries to files given epoch info
This is called automatically during update_for_epoch(). It does not save if there is no directory to save to (i.e. self.state_dir is None). Format strings from self.params are formatted with the values from info to construct the base names of each file.
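The standard library can do both jobs here: build file names from format strings, and detect a missing “epoch” field, which is the condition the warn parameter refers to. This is a minimal sketch. The format string below is a made-up example of what saved_model_fmt might hold, not the library default. state_basename and has_field are hypothetical helpers.

```python
import string
import warnings


def has_field(fmt, field="epoch"):
    """True if the format string references `field` by name."""
    return any(name == field for _, name, _, _ in string.Formatter().parse(fmt))


def state_basename(fmt, info):
    """Fill fmt with the epoch info; warn if the name ignores the epoch."""
    if not has_field(fmt):
        # without {epoch}, every epoch would write to the same file name
        warnings.warn(f"format string {fmt!r} has no 'epoch' field")
    return fmt.format(**info)


info = {"epoch": 7, "val_met": 0.25}
print(state_basename("model_{epoch:03d}_{val_met:.2f}.pt", info))
# model_007_0.25.pt
```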
- update_for_epoch(model, optimizer, train_met, val_met, epoch=None, best_is_train=False, **kwargs)[source]
Update history, model, and optimizer after latest epoch results
- Parameters:
model (Module) – The model after the epoch that just finished.
optimizer (Optimizer) – The optimizer after the epoch that just finished.
train_met (float) – Mean value of metric on training set for epoch.
val_met (float) – Mean value of metric on validation set for epoch.
epoch (Optional[int]) – The epoch that just finished. If unset, it is inferred to be one after the last epoch in the history.
best_is_train (bool) – Whether to judge the best model on record by the training set (False uses the validation set).
**kwargs – Additional keyword arguments can be used to specify the values of entries specified via add_entry().
- Returns:
cont (bool) – Whether to continue training. This is False when either the maximum number of epochs has been reached or early stopping has been triggered.
- class pydrobert.torch.training.TrainingStateParams(*, early_stopping_burnin, early_stopping_patience, early_stopping_threshold, keep_last_and_best_only, log10_learning_rate, num_epochs, reduce_lr_burnin, reduce_lr_cooldown, reduce_lr_factor, reduce_lr_log10_epsilon, reduce_lr_patience, reduce_lr_threshold, saved_model_fmt, saved_optimizer_fmt, seed, name)[source]
Parameters controlling a TrainingStateController
This class implements the pydrobert.param.optuna.TunableParameterized interface.
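The reduce_lr_* parameters suggest a reduce-on-plateau schedule driven by the “rlr_resume_cd” and “rlr_patience_cd” countdowns described for TrainingStateController. The sketch below is a hypothetical plain-Python reading of those names, not the library's implementation. It assumes:

- log10_learning_rate is the base-10 log of the initial rate;
- burn-in and cooldown pause the check;
- an improvement must beat the best value by more than the absolute reduce_lr_threshold.

reduce_lr_log10_epsilon is not modelled. The defaults in PlateauParams are invented for the example.

```python
from dataclasses import dataclass


@dataclass
class PlateauParams:
    """Hypothetical stand-in for a few TrainingStateParams fields."""
    log10_learning_rate: float = 0.0
    reduce_lr_factor: float = 0.5
    reduce_lr_patience: int = 1
    reduce_lr_cooldown: int = 1
    reduce_lr_threshold: float = 0.0
    reduce_lr_burnin: int = 0


def lr_schedule(val_mets, p):
    """Return the learning rate in effect after each epoch."""
    lr = 10 ** p.log10_learning_rate
    resume_cd, patience_cd = p.reduce_lr_burnin, p.reduce_lr_patience
    best, lrs = float("inf"), []
    for val_met in val_mets:
        improved = val_met < best - p.reduce_lr_threshold  # absolute threshold
        if improved:
            best = val_met
        if resume_cd > 0:
            resume_cd -= 1  # burn-in or cooldown: the criterion is paused
        elif improved:
            patience_cd = p.reduce_lr_patience  # reset the patience countdown
        else:
            patience_cd -= 1
            if patience_cd <= 0:
                lr *= p.reduce_lr_factor  # plateau: reduce the rate
                resume_cd = p.reduce_lr_cooldown
                patience_cd = p.reduce_lr_patience
        lrs.append(lr)
    return lrs


print(lr_schedule([1.0, 0.9, 0.9, 0.9, 0.8], PlateauParams()))
# [1.0, 1.0, 0.5, 0.5, 0.5]
```

In this run, the rate halves after the first epoch without improvement. It then holds for one cooldown epoch before the criterion resumes.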