StochasticWeightAveraging(swa_epoch_start=0.8, swa_lrs=None, annealing_epochs=10, annealing_strategy='cos', avg_fn=None, device=torch.device('cpu'))¶
Implements the Stochastic Weight Averaging (SWA) callback to average a model's weights during training.
Stochastic Weight Averaging was proposed in
Averaging Weights Leads to Wider Optima and Better Generalization by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson (UAI 2018).
This documentation is highly inspired by PyTorch’s work on SWA. The callback arguments follow the scheme defined in PyTorch’s torch.optim.swa_utils package.
For an explanation of SWA, see the PyTorch blog post “Stochastic Weight Averaging in PyTorch”.
StochasticWeightAveraging is in beta and subject to change.
StochasticWeightAveraging is currently not supported for multiple optimizers/schedulers.
StochasticWeightAveraging currently only supports updating the average once per epoch.
SWA can easily be activated directly from the Trainer as follows:
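For example (a minimal sketch; the max_epochs and swa_lrs values are illustrative, not defaults):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import StochasticWeightAveraging

    # Average weights over the last 20% of training, annealing the learning
    # rate toward 0.05 with the default cosine strategy.
    trainer = pl.Trainer(
        max_epochs=100,
        callbacks=[StochasticWeightAveraging(swa_epoch_start=0.8, swa_lrs=0.05)],
    )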
Parameters
swa_epoch_start (Union[int, float]) – if provided as an int, the procedure will start from the swa_epoch_start-th epoch; if provided as a float between 0 and 1, the procedure will start from epoch int(swa_epoch_start * max_epochs). For example, with max_epochs=100 and swa_epoch_start=0.8, averaging begins at epoch 80.
annealing_strategy (str) – Specifies the annealing strategy (default: "cos"):
"cos" for cosine annealing
"linear" for linear annealing
avg_fn (Optional[Callable[[FloatTensor, FloatTensor, LongTensor], FloatTensor]]) – the averaging function used to update the parameters; the function must take in the current value of the AveragedModel parameter, the current value of the model parameter, and the number of models already averaged; if None, an equally weighted average is used (default: None)
avg_fn(averaged_model_parameter, model_parameter, num_averaged)¶
Return type: FloatTensor
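For example, a custom avg_fn can replace the default equal weighting with an exponential moving average (a sketch; ema_avg_fn and the decay value are illustrative, not part of the API):

    import torch
    from pytorch_lightning.callbacks import StochasticWeightAveraging

    def ema_avg_fn(averaged_model_parameter: torch.Tensor,
                   model_parameter: torch.Tensor,
                   num_averaged: torch.LongTensor,
                   decay: float = 0.99) -> torch.Tensor:
        # Keep `decay` of the running average and blend in the remainder
        # from the current model parameter.
        return decay * averaged_model_parameter + (1.0 - decay) * model_parameter

    swa = StochasticWeightAveraging(avg_fn=ema_avg_fn)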
on_before_accelerator_backend_setup(trainer, pl_module)¶
Called before the accelerator is set up.
on_fit_start(trainer, pl_module)¶
Called when fit begins.
on_train_end(trainer, pl_module)¶
Called when the train ends.
on_train_epoch_end(trainer, pl_module)¶
Called when the train epoch ends.
on_train_epoch_start(trainer, pl_module)¶
Called when the train epoch begins.
update_parameters(average_model, model, n_averaged, avg_fn)¶
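Conceptually, update_parameters folds the current model’s parameters into the running average kept by average_model and increments the counter of averaged models. A minimal sketch of that update (assumed behavior, not Lightning’s exact implementation; names mirror the signature above):

    import torch

    @torch.no_grad()
    def update_parameters_sketch(average_model, model, n_averaged, avg_fn):
        # Blend each parameter of `model` into the matching parameter
        # of `average_model`.
        for p_avg, p_model in zip(average_model.parameters(), model.parameters()):
            p = p_model.detach().to(p_avg.device)
            if n_averaged == 0:
                # First snapshot: copy the weights directly.
                p_avg.copy_(p)
            else:
                p_avg.copy_(avg_fn(p_avg, p, n_averaged))
        # One more model folded into the average (in place, assuming
        # n_averaged is a zero-dim LongTensor as in avg_fn's signature).
        n_averaged += 1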