Callback
A callback is a self-contained program that can be reused across projects.
Lightning has a callback system to execute callbacks when needed. Callbacks should capture NON-ESSENTIAL
logic that is NOT required for your lightning module to run.
An overall Lightning system should have:
- Trainer for all engineering.
- LightningModule for all research code.
- Callbacks for non-essential code.
Example:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Callback


    class MyPrintingCallback(Callback):
        # Runs before the Trainer has finished constructing itself.
        def on_init_start(self, trainer):
            print('Starting to init trainer!')

        def on_init_end(self, trainer):
            print('trainer is init now')

        def on_train_end(self, trainer, pl_module):
            print('do something when training ends')


    trainer = Trainer(callbacks=[MyPrintingCallback()])

Output:

    Starting to init trainer!
    trainer is init now
We successfully extended functionality without polluting our super clean
lightning module research code.
Examples
You can do pretty much anything with callbacks.
Built-in Callbacks
Lightning has a few built-in callbacks.
Note: For a richer collection of callbacks, check out our bolts library.
| Callback | Description |
|---|---|
| BackboneFinetuning | Finetune a backbone model based on a user-defined learning rate schedule. |
| BaseFinetuning | This class implements the base logic for writing your own Finetuning Callback. |
| Callback | Abstract base class used to build new callbacks. |
| EarlyStopping | Monitor a metric and stop training when it stops improving. |
| GPUStatsMonitor | Automatically monitors and logs GPU stats during the training stage. |
| GradientAccumulationScheduler | Change gradient accumulation factor according to scheduling. |
| LambdaCallback | Create a simple callback on the fly using lambda functions. |
| LearningRateMonitor | Automatically monitors and logs the learning rate for learning rate schedulers during training. |
| ModelCheckpoint | Save the model after every epoch by monitoring a quantity. |
| ModelPruning | Model pruning Callback, using PyTorch’s prune utilities. |
| ProgressBar | This is the default progress bar used by Lightning. |
| ProgressBarBase | The base class for progress bars in Lightning. |
| QuantizationAwareTraining | Quantization allows speeding up inference and decreasing memory requirements by performing computations and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision. |
| StochasticWeightAveraging | Implements the Stochastic Weight Averaging (SWA) Callback to average a model. |
Persisting State
Some callbacks require internal state in order to function properly. You can optionally
choose to persist your callback’s state as part of model checkpoint files using the callback hooks
on_save_checkpoint() and on_load_checkpoint().
However, you must follow two constraints:
Your returned state must be able to be pickled.
You can only use one instance of that class in the Trainer callbacks list. We don’t support persisting state for multiple callbacks of the same class.
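The two hooks can be sketched as a save/restore pair. The example below is a minimal illustration, not a built-in: BatchCounter and its batches_seen attribute are hypothetical names, and the hook signatures follow the ones listed on this page (other Lightning releases may differ). The fallback base class is only there so the sketch can run without Lightning installed.

```python
try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    # Stand-in base class (assumption) so the sketch runs without Lightning.
    class Callback:
        pass


class BatchCounter(Callback):
    """Hypothetical callback that counts training batches across restarts."""

    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
                           batch_idx, dataloader_idx):
        self.batches_seen += 1

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # The returned state must be picklable; a dict of ints qualifies.
        return {"batches_seen": self.batches_seen}

    def on_load_checkpoint(self, callback_state):
        # Receives the dict that on_save_checkpoint returned.
        self.batches_seen = callback_state["batches_seen"]
```

Because state is persisted per callback class, only one BatchCounter instance should appear in the Trainer callbacks list.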
Best Practices
The following are best practices when using/designing callbacks.
Callbacks should be isolated in their functionality.
Your callback should not rely on the behavior of other callbacks in order to work properly.
Do not manually call methods from the callback.
Directly calling methods (e.g. on_validation_end) is strongly discouraged.
Whenever possible, your callbacks should not depend on the order in which they are executed.
Available Callback hooks
setup
Callback.setup(trainer, pl_module, stage)
Called when fit or test begins.
Return type: None
teardown
Callback.teardown(trainer, pl_module, stage)
Called when fit or test ends.
Return type: None
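Since setup and teardown bracket a fit or test run, they suit acquiring and releasing a resource. The following is a minimal sketch under that assumption: ScratchDir is a hypothetical callback, the temporary directory is purely illustrative, and the fallback base class only exists so the code runs without Lightning installed.

```python
import shutil
import tempfile

try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    # Stand-in base class (assumption) so the sketch runs without Lightning.
    class Callback:
        pass


class ScratchDir(Callback):
    """Hypothetical callback that owns a temporary directory per run."""

    def __init__(self):
        self.path = None

    def setup(self, trainer, pl_module, stage):
        # The stage name is used as the directory prefix.
        self.path = tempfile.mkdtemp(prefix=f"{stage}_")

    def teardown(self, trainer, pl_module, stage):
        # Remove the directory created in setup, if any.
        if self.path is not None:
            shutil.rmtree(self.path, ignore_errors=True)
            self.path = None
```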
on_init_start
Callback.on_init_start(trainer)
Called when the trainer initialization begins, model has not yet been set.
Return type: None
on_init_end
Callback.on_init_end(trainer)
Called when the trainer initialization ends, model has not yet been set.
Return type: None
on_fit_start
Callback.on_fit_start(trainer, pl_module)
Called when fit begins.
Return type: None
on_fit_end
Callback.on_fit_end(trainer, pl_module)
Called when fit ends.
Return type: None
on_sanity_check_start
Callback.on_sanity_check_start(trainer, pl_module)
Called when the validation sanity check starts.
Return type: None
on_sanity_check_end
Callback.on_sanity_check_end(trainer, pl_module)
Called when the validation sanity check ends.
Return type: None
on_train_batch_start
Callback.on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the train batch begins.
Return type: None
on_train_batch_end
Callback.on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the train batch ends.
Return type: None
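The paired batch hooks above can be combined to measure how long each training batch takes. This is a minimal sketch: BatchTimer and its durations list are hypothetical names, the signatures are the ones listed on this page, and the fallback base class only exists so the code runs without Lightning installed.

```python
import time

try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    # Stand-in base class (assumption) so the sketch runs without Lightning.
    class Callback:
        pass


class BatchTimer(Callback):
    """Hypothetical callback that records wall-clock time per train batch."""

    def __init__(self):
        self.durations = []
        self._start = None

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx,
                             dataloader_idx):
        self._start = time.perf_counter()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
                           batch_idx, dataloader_idx):
        # Skip if the start hook never fired for this batch.
        if self._start is not None:
            self.durations.append(time.perf_counter() - self._start)
            self._start = None
```

The callback keeps its own state and does not rely on any other callback, in line with the best practices above.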
on_train_epoch_start
Callback.on_train_epoch_start(trainer, pl_module)
Called when the train epoch begins.
Return type: None
on_train_epoch_end
Callback.on_train_epoch_end(trainer, pl_module, outputs)
Called when the train epoch ends.
Return type: None
on_validation_epoch_start
Callback.on_validation_epoch_start(trainer, pl_module)
Called when the validation epoch begins.
Return type: None
on_validation_epoch_end
Callback.on_validation_epoch_end(trainer, pl_module)
Called when the validation epoch ends.
Return type: None
on_test_epoch_start
Callback.on_test_epoch_start(trainer, pl_module)
Called when the test epoch begins.
Return type: None
on_test_epoch_end
Callback.on_test_epoch_end(trainer, pl_module)
Called when the test epoch ends.
Return type: None
on_epoch_start
Callback.on_epoch_start(trainer, pl_module)
Called when any train/val/test epoch begins.
Return type: None
on_epoch_end
Callback.on_epoch_end(trainer, pl_module)
Called when any train/val/test epoch ends.
Return type: None
on_batch_start
Callback.on_batch_start(trainer, pl_module)
Called when the training batch begins.
Return type: None
on_validation_batch_start
Callback.on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the validation batch begins.
Return type: None
on_validation_batch_end
Callback.on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the validation batch ends.
Return type: None
on_test_batch_start
Callback.on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the test batch begins.
Return type: None
on_test_batch_end
Callback.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the test batch ends.
Return type: None
on_batch_end
Callback.on_batch_end(trainer, pl_module)
Called when the training batch ends.
Return type: None
on_train_start
Callback.on_train_start(trainer, pl_module)
Called when training begins.
Return type: None
on_train_end
Callback.on_train_end(trainer, pl_module)
Called when training ends.
Return type: None
on_pretrain_routine_start
Callback.on_pretrain_routine_start(trainer, pl_module)
Called when the pretrain routine begins.
Return type: None
on_pretrain_routine_end
Callback.on_pretrain_routine_end(trainer, pl_module)
Called when the pretrain routine ends.
Return type: None
on_validation_start
Callback.on_validation_start(trainer, pl_module)
Called when the validation loop begins.
Return type: None
on_validation_end
Callback.on_validation_end(trainer, pl_module)
Called when the validation loop ends.
Return type: None
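One possible use of on_validation_end is to keep a history of a logged validation metric. The sketch below assumes the LightningModule logs a metric under the key "val_loss" (an assumption, any logged key works) and reads it from trainer.callback_metrics. ValMetricHistory is a hypothetical name, and the fallback base class only exists so the code runs without Lightning installed.

```python
try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    # Stand-in base class (assumption) so the sketch runs without Lightning.
    class Callback:
        pass


class ValMetricHistory(Callback):
    """Hypothetical callback that stores one metric value per validation run."""

    def __init__(self, monitor="val_loss"):
        self.monitor = monitor  # metric key; "val_loss" is an assumed default
        self.history = []

    def on_validation_end(self, trainer, pl_module):
        value = trainer.callback_metrics.get(self.monitor)
        # The metric is absent if the module never logged it.
        if value is not None:
            self.history.append(float(value))
```

Note that the validation sanity check also runs the validation loop, so the first recorded value may come from it.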
on_test_start
Callback.on_test_start(trainer, pl_module)
Called when the test begins.
Return type: None
on_test_end
Callback.on_test_end(trainer, pl_module)
Called when the test ends.
Return type: None
on_keyboard_interrupt
Callback.on_keyboard_interrupt(trainer, pl_module)
Called when the training is interrupted by KeyboardInterrupt.
Return type: None
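As an illustration, this hook can note where a run was interrupted so that later cleanup code can react. The sketch is hypothetical: InterruptRecorder is not a built-in, it only reads trainer.global_step, and the fallback base class only exists so the code runs without Lightning installed.

```python
try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    # Stand-in base class (assumption) so the sketch runs without Lightning.
    class Callback:
        pass


class InterruptRecorder(Callback):
    """Hypothetical callback that remembers the step of a Ctrl+C interrupt."""

    def __init__(self):
        self.interrupted_at_step = None

    def on_keyboard_interrupt(self, trainer, pl_module):
        self.interrupted_at_step = trainer.global_step
        print(f"Training interrupted at global step {trainer.global_step}")
```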
on_save_checkpoint
Callback.on_save_checkpoint(trainer, pl_module, checkpoint)
Called when saving a model checkpoint; use it to persist state.
Parameters:
trainer – the current Trainer instance.
pl_module (LightningModule) – the current LightningModule instance.
checkpoint (Dict[str, Any]) – the checkpoint dictionary that will be saved.
Return type: dict
Returns: The callback state.
on_load_checkpoint
Callback.on_load_checkpoint(callback_state)
Called when loading a model checkpoint; use it to reload state.
Parameters:
callback_state (Dict[str, Any]) – the callback state returned by on_save_checkpoint.
Return type: None
on_after_backward
Callback.on_after_backward(trainer, pl_module)
Called after loss.backward() and before optimizers do anything.
Return type: None
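Because gradients are populated at this point and no optimizer has touched them yet, the hook is a natural place to inspect them. The sketch below records the global L2 gradient norm after each backward pass. GradNormRecorder is a hypothetical name, the code assumes pl_module.parameters() yields parameters whose grad supports norm(2) as PyTorch tensors do, and the fallback base class only exists so the code runs without Lightning installed.

```python
try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    # Stand-in base class (assumption) so the sketch runs without Lightning.
    class Callback:
        pass


class GradNormRecorder(Callback):
    """Hypothetical callback that tracks the global L2 norm of gradients."""

    def __init__(self):
        self.norms = []

    def on_after_backward(self, trainer, pl_module):
        squared = 0.0
        for param in pl_module.parameters():
            # Parameters that did not take part in the loss have no gradient.
            if param.grad is not None:
                squared += float(param.grad.norm(2)) ** 2
        self.norms.append(squared ** 0.5)
```

Converting each norm to a Python float forces a device synchronization on GPU, so in practice one might record only every few steps.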
on_before_zero_grad
Callback.on_before_zero_grad(trainer, pl_module, optimizer)
Called after optimizer.step() and before optimizer.zero_grad().
Return type: None