Callback¶
A callback is a self-contained program that can be reused across projects.
Lightning has a callback system to execute callbacks when needed. Callbacks should capture NON-ESSENTIAL logic that is NOT required for your LightningModule to run.
An overall Lightning system should have:
Trainer for all engineering
LightningModule for all research code.
Callbacks for non-essential code.
Example:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class MyPrintingCallback(Callback):

    def on_init_start(self, trainer):
        print('Starting to init trainer!')

    def on_init_end(self, trainer):
        print('trainer is init now')

    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')


trainer = Trainer(callbacks=[MyPrintingCallback()])

Output:

Starting to init trainer!
trainer is init now
We successfully extended functionality without polluting our super clean LightningModule research code.
Examples¶
You can do pretty much anything with callbacks.
Built-in Callbacks¶
Lightning has a few built-in callbacks.
Note
For a richer collection of callbacks, check out our bolts library.
Callback | Abstract base class used to build new callbacks.
EarlyStopping | Monitor a validation metric and stop training when it stops improving.
GPUStatsMonitor | Automatically monitors and logs GPU stats during the training stage.
GradientAccumulationScheduler | Change the gradient accumulation factor according to a schedule.
LearningRateMonitor | Automatically monitors and logs the learning rate for learning rate schedulers during training.
ModelCheckpoint | Save the model after every epoch by monitoring a quantity.
ProgressBar | This is the default progress bar used by Lightning.
ProgressBarBase | The base class for progress bars in Lightning.
Persisting State¶
Some callbacks require internal state in order to function properly. You can optionally choose to persist your callback’s state as part of model checkpoint files using the callback hooks on_save_checkpoint() and on_load_checkpoint().
However, you must follow two constraints:
The state you return must be picklable.
You can only use one instance of that class in the Trainer callbacks list. We don’t support persisting state for multiple callbacks of the same class.
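The picklability constraint above can be illustrated with a minimal sketch. BatchCounter is a hypothetical callback, not part of Lightning. The signatures on_save_checkpoint(trainer, pl_module) and on_load_checkpoint(checkpointed_state) are assumptions based on the release this page describes and may differ in other releases. The hooks are called by hand only to simulate what the Trainer does when it saves and loads a checkpoint, and the fallback base class exists only so the sketch runs without Lightning installed.

```python
import pickle

try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    class Callback:
        """Stand-in base class so this sketch runs without Lightning."""


class BatchCounter(Callback):
    """Hypothetical callback that counts finished training batches."""

    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, outputs,
                           batch, batch_idx, dataloader_idx):
        self.batches_seen += 1

    def on_save_checkpoint(self, trainer, pl_module):
        # Return only plain, picklable data.
        return {"batches_seen": self.batches_seen}

    def on_load_checkpoint(self, checkpointed_state):
        # Assumed signature: receives what on_save_checkpoint returned.
        self.batches_seen = checkpointed_state["batches_seen"]


# Simulate a save/load round trip by hand (normally the Trainer does this).
counter = BatchCounter()
for batch_idx in range(3):
    counter.on_train_batch_end(None, None, None, None, batch_idx, 0)

# pickle.dumps raises if the returned state cannot be pickled.
blob = pickle.dumps(counter.on_save_checkpoint(None, None))

restored = BatchCounter()
restored.on_load_checkpoint(pickle.loads(blob))
print(restored.batches_seen)  # 3
```

Returning a small dictionary of built-in types keeps the state easy to pickle and easy to inspect inside a checkpoint file.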
Best Practices¶
The following are best practices when using/designing callbacks.
Callbacks should be isolated in their functionality.
Your callback should not rely on the behavior of other callbacks in order to work properly.
Do not manually call methods from the callback.
Directly calling methods (e.g., on_validation_end) is strongly discouraged.
Whenever possible, your callbacks should not depend on the order in which they are executed.
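As a rough illustration of isolation and order independence, the sketch below defines two hypothetical callbacks that each keep only their own state. simulate_batches is a made-up stand-in for the Trainer's dispatch loop, used here purely to check the property; in real code the Trainer calls the hooks. The fallback base class is only there so the sketch runs without Lightning installed.

```python
try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    class Callback:
        """Stand-in base class so this sketch runs without Lightning."""


class StepCounter(Callback):
    """Hypothetical callback: counts finished training batches."""

    def __init__(self):
        self.steps = 0

    def on_train_batch_end(self, trainer, pl_module, outputs,
                           batch, batch_idx, dataloader_idx):
        self.steps += 1


class LastBatchIndex(Callback):
    """Hypothetical callback: remembers the most recent batch index."""

    def __init__(self):
        self.last_idx = None

    def on_train_batch_end(self, trainer, pl_module, outputs,
                           batch, batch_idx, dataloader_idx):
        self.last_idx = batch_idx


def simulate_batches(callbacks, num_batches):
    """Hypothetical stand-in for the Trainer calling each callback in turn."""
    for batch_idx in range(num_batches):
        for callback in callbacks:
            callback.on_train_batch_end(None, None, None, None, batch_idx, 0)


def run(counter_first):
    counter, tracker = StepCounter(), LastBatchIndex()
    callbacks = [counter, tracker] if counter_first else [tracker, counter]
    simulate_batches(callbacks, num_batches=4)
    return counter.steps, tracker.last_idx


# Neither callback reads the other's state, so the order does not matter.
print(run(True) == run(False))  # True
```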
Available Callback hooks¶
on_init_start¶
Callback.on_init_start(trainer)
Called when the trainer initialization begins; the model has not yet been set.
on_init_end¶
Callback.on_init_end(trainer)
Called when the trainer initialization ends; the model has not yet been set.
on_save_checkpoint¶
Callback.on_save_checkpoint(trainer, pl_module)
Called when saving a model checkpoint; use it to persist state.
on_sanity_check_start¶
Callback.on_sanity_check_start(trainer, pl_module)
Called when the validation sanity check starts.
on_sanity_check_end¶
Callback.on_sanity_check_end(trainer, pl_module)
Called when the validation sanity check ends.
on_train_batch_start¶
Callback.on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the train batch begins.
on_train_batch_end¶
Callback.on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the train batch ends.
on_train_epoch_start¶
Callback.on_train_epoch_start(trainer, pl_module)
Called when the train epoch begins.
on_train_epoch_end¶
Callback.on_train_epoch_end(trainer, pl_module, outputs)
Called when the train epoch ends.
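The paired epoch hooks lend themselves to simple timing. The sketch below is a hypothetical EpochTimer callback (not part of Lightning) that uses the signatures listed above. The hooks are invoked by hand at the end only to show the behavior without a full training run, and the fallback base class is only there so the sketch runs without Lightning installed.

```python
import time

try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    class Callback:
        """Stand-in base class so this sketch runs without Lightning."""


class EpochTimer(Callback):
    """Hypothetical callback: records the wall-clock time of each train epoch."""

    def __init__(self):
        self.durations = []   # seconds, one entry per finished epoch
        self._start = None

    def on_train_epoch_start(self, trainer, pl_module):
        self._start = time.perf_counter()

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        # Guard against an end hook firing without a matching start.
        if self._start is not None:
            self.durations.append(time.perf_counter() - self._start)
            self._start = None


# Simulate one epoch by hand; with Lightning installed you would instead
# pass the callback to the Trainer: Trainer(callbacks=[EpochTimer()]).
timer = EpochTimer()
timer.on_train_epoch_start(None, None)
timer.on_train_epoch_end(None, None, outputs=[])
print(len(timer.durations))  # 1
```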
on_validation_epoch_start¶
Callback.on_validation_epoch_start(trainer, pl_module)
Called when the validation epoch begins.
on_validation_epoch_end¶
Callback.on_validation_epoch_end(trainer, pl_module)
Called when the validation epoch ends.
on_test_epoch_start¶
Callback.on_test_epoch_start(trainer, pl_module)
Called when the test epoch begins.
on_test_epoch_end¶
Callback.on_test_epoch_end(trainer, pl_module)
Called when the test epoch ends.
on_batch_start¶
Callback.on_batch_start(trainer, pl_module)
Called when the training batch begins.
on_validation_batch_start¶
Callback.on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the validation batch begins.
on_validation_batch_end¶
Callback.on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the validation batch ends.
on_test_batch_start¶
Callback.on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the test batch begins.
on_test_batch_end¶
Callback.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the test batch ends.
on_batch_end¶
Callback.on_batch_end(trainer, pl_module)
Called when the training batch ends.
on_pretrain_routine_start¶
Callback.on_pretrain_routine_start(trainer, pl_module)
Called when the pretrain routine begins.
on_pretrain_routine_end¶
Callback.on_pretrain_routine_end(trainer, pl_module)
Called when the pretrain routine ends.
on_validation_start¶
Callback.on_validation_start(trainer, pl_module)
Called when the validation loop begins.
on_validation_end¶
Callback.on_validation_end(trainer, pl_module)
Called when the validation loop ends.
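A common use of on_validation_end is to inspect the metrics logged during validation. The sketch below is a hypothetical BestValLoss callback that reads trainer.callback_metrics; the metric name "val_loss" is an assumption about what the LightningModule logs. A SimpleNamespace object stands in for the Trainer so the sketch runs on its own, and the fallback base class is only there in case Lightning is not installed.

```python
from types import SimpleNamespace

try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    class Callback:
        """Stand-in base class so this sketch runs without Lightning."""


class BestValLoss(Callback):
    """Hypothetical callback: remembers the lowest monitored metric seen."""

    def __init__(self, monitor="val_loss"):
        self.monitor = monitor      # assumed metric name
        self.best = float("inf")

    def on_validation_end(self, trainer, pl_module):
        value = trainer.callback_metrics.get(self.monitor)
        if value is None:
            return  # the metric was not logged in this validation run
        # float() also handles single-element tensors.
        self.best = min(self.best, float(value))


# Simulate three validation runs with a stand-in for the Trainer.
tracker = BestValLoss()
for loss in (0.9, 0.4, 0.6):
    fake_trainer = SimpleNamespace(callback_metrics={"val_loss": loss})
    tracker.on_validation_end(fake_trainer, None)

print(tracker.best)  # 0.4
```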
on_keyboard_interrupt¶
Callback.on_keyboard_interrupt(trainer, pl_module)
Called when the training is interrupted by KeyboardInterrupt.
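As a small sketch of how this hook might be used, the hypothetical InterruptRecorder callback below notes the trainer's global_step at the moment of interruption, for example to report how far training got. A SimpleNamespace object stands in for the Trainer so the sketch runs on its own, and the fallback base class is only there in case Lightning is not installed.

```python
from types import SimpleNamespace

try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    class Callback:
        """Stand-in base class so this sketch runs without Lightning."""


class InterruptRecorder(Callback):
    """Hypothetical callback: remembers the step at which training was interrupted."""

    def __init__(self):
        self.interrupted_at_step = None

    def on_keyboard_interrupt(self, trainer, pl_module):
        self.interrupted_at_step = trainer.global_step


# Simulate the hook firing with a stand-in for the Trainer.
recorder = InterruptRecorder()
recorder.on_keyboard_interrupt(SimpleNamespace(global_step=120), None)
print(recorder.interrupted_at_step)  # 120
```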