- class pytorch_lightning.callbacks.StochasticWeightAveraging(swa_epoch_start=0.8, swa_lrs=None, annealing_epochs=10, annealing_strategy='cos', avg_fn=None, device=torch.device('cpu'))¶
Implements the Stochastic Weight Averaging (SWA) Callback to average a model.
Stochastic Weight Averaging was proposed in
Averaging Weights Leads to Wider Optima and Better Generalization by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson (UAI 2018).
This documentation is highly inspired by PyTorch’s work on SWA. The callback arguments follow the scheme defined in PyTorch’s swa_utils package.
For an explanation of SWA, see PyTorch’s blog post on Stochastic Weight Averaging.
StochasticWeightAveraging is in beta and subject to change.
StochasticWeightAveraging is currently not supported for multiple optimizers/schedulers.
StochasticWeightAveraging currently only supports running on every epoch.
SWA can easily be activated directly from the Trainer as follows:
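A minimal sketch of enabling the callback from the Trainer; the `swa_lrs=1e-2` value is purely illustrative, not a recommendation:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

# Pass the callback to the Trainer; all other arguments keep their defaults
# (SWA starts at 80% of max_epochs, cosine annealing over 10 epochs).
trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])
```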
swa_epoch_start (Union[int, float]) – If provided as an int, the procedure will start from the swa_epoch_start-th epoch. If provided as a float between 0 and 1, the procedure will start from the int(swa_epoch_start * max_epochs) epoch.
annealing_strategy (str) – Specifies the annealing strategy (default: “cos”):
"cos". For cosine annealing.
"linear". For linear annealing.
avg_fn (Optional[Callable[[FloatTensor, FloatTensor, LongTensor], FloatTensor]]) – the averaging function used to update the parameters; the function must take in the current value of the AveragedModel parameter, the current value of the model parameter, and the number of models already averaged; if None, an equally weighted average is used (default: None).
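To make the two scheduling arguments concrete, here is a sketch with two hypothetical helpers (not part of the callback's API): one resolves `swa_epoch_start` into a concrete start epoch, the other shows what the `"cos"` and `"linear"` annealing strategies compute for the learning rate:

```python
import math


def resolve_swa_start(swa_epoch_start, max_epochs):
    """Hypothetical helper: a float in (0, 1) is a fraction of max_epochs,
    an int is used directly as the starting epoch."""
    if isinstance(swa_epoch_start, float):
        return int(swa_epoch_start * max_epochs)
    return swa_epoch_start


def annealed_lr(initial_lr, swa_lr, step, total_steps, strategy="cos"):
    """Hypothetical helper: anneal from initial_lr toward swa_lr."""
    t = min(step / max(total_steps, 1), 1.0)
    if strategy == "cos":
        # Cosine annealing: half a cosine period from initial_lr to swa_lr.
        return swa_lr + (initial_lr - swa_lr) * (1 + math.cos(math.pi * t)) / 2
    # Linear annealing: straight-line interpolation.
    return initial_lr + (swa_lr - initial_lr) * t
```

With the default `swa_epoch_start=0.8` and `max_epochs=100`, averaging would begin at epoch 80.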
- static avg_fn(averaged_model_parameter, model_parameter, num_averaged)¶
- Return type¶
FloatTensor
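The equally weighted default described above is a running mean; a sketch on plain floats (the callback applies the same formula per tensor parameter):

```python
def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
    """Equally weighted running average, matching the documented default:
    avg_new = avg + (param - avg) / (num_averaged + 1).
    """
    return averaged_model_parameter + (
        model_parameter - averaged_model_parameter
    ) / (num_averaged + 1)
```

For parameter values 1.0, 2.0, 3.0 seen in sequence, this yields the plain mean 2.0.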
- on_before_accelerator_backend_setup(trainer, pl_module)¶
Called before the accelerator is set up.
- on_fit_start(trainer, pl_module)¶
Called when fit begins.
- on_train_end(trainer, pl_module)¶
Called when the train ends.
- on_train_epoch_end(trainer, *args)¶
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, either:
- Implement training_epoch_end in the LightningModule and access outputs via the module, or
- Cache data across train-batch hooks inside the callback implementation and post-process it in this hook.
- on_train_epoch_start(trainer, pl_module)¶
Called when the train epoch begins.
- static update_parameters(average_model, model, n_averaged, avg_fn)¶
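A sketch of what the update step does, assuming plain lists of floats stand in for the averaged model's and current model's parameter tensors (the real method operates on tensors in place):

```python
def update_parameters(average_params, model_params, n_averaged, avg_fn):
    """Sketch: copy the first model seen, then fold each subsequent model
    into the running average via avg_fn."""
    for i, (avg_p, p) in enumerate(zip(average_params, model_params)):
        average_params[i] = p if n_averaged == 0 else avg_fn(avg_p, p, n_averaged)
    return average_params
```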