
Callback


A callback is a self-contained program that can be reused across projects.

Lightning has a callback system that executes your callbacks at specific points (hooks) during fitting and testing. Callbacks should capture NON-ESSENTIAL logic that is NOT required for your LightningModule to run.

The Trainer calls each hook at a specific point in its fit and test loops; every available hook is described under Available Callback hooks below.

An overall Lightning system should have:

  1. Trainer for all engineering code.

  2. LightningModule for all research code.

  3. Callbacks for non-essential code.


Example:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

class MyPrintingCallback(Callback):

    def on_init_start(self, trainer):
        print('Starting to init trainer!')

    def on_init_end(self, trainer):
        print('trainer is init now')

    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')

trainer = Trainer(callbacks=[MyPrintingCallback()])
# Output:
# Starting to init trainer!
# trainer is init now

We successfully extended functionality without polluting our super clean LightningModule research code.
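Behind the scenes, the Trainer keeps the callbacks list and calls the matching hook on each callback when it reaches the corresponding point in its loop. The toy model below shows that dispatch pattern in plain Python so it runs without Lightning installed. ToyCallback, ToyTrainer and RecordingCallback are hypothetical names, and the hook sequence is deliberately simplified; it is not Lightning's actual implementation or exact hook order.

```python
class ToyCallback:
    """Hypothetical base class: every hook is a no-op unless a subclass overrides it."""

    def on_init_start(self, trainer):
        pass

    def on_init_end(self, trainer):
        pass

    def on_train_start(self, trainer, pl_module):
        pass

    def on_train_end(self, trainer, pl_module):
        pass


class ToyTrainer:
    """Hypothetical, heavily simplified trainer that only illustrates hook dispatch."""

    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])
        self._call_hook("on_init_start", self)
        # ... engineering setup (devices, loggers, ...) would happen here ...
        self._call_hook("on_init_end", self)

    def _call_hook(self, hook_name, *args):
        # Invoke the named hook on every registered callback, in list order.
        for callback in self.callbacks:
            getattr(callback, hook_name)(*args)

    def fit(self, model):
        self._call_hook("on_train_start", self, model)
        # ... the training loop runs here; it never contains callback logic ...
        self._call_hook("on_train_end", self, model)


class RecordingCallback(ToyCallback):
    """Records which hooks fired, without touching the model."""

    def __init__(self):
        self.events = []

    def on_init_start(self, trainer):
        self.events.append("on_init_start")

    def on_init_end(self, trainer):
        self.events.append("on_init_end")

    def on_train_end(self, trainer, pl_module):
        self.events.append("on_train_end")


recorder = RecordingCallback()
trainer = ToyTrainer(callbacks=[recorder])
trainer.fit(model=object())
print(recorder.events)  # ['on_init_start', 'on_init_end', 'on_train_end']
```

The model object is only passed through, so the recording logic never leaks into it; that is the same separation the MyPrintingCallback example achieves.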



Built-in Callbacks

Lightning has a few built-in callbacks.

Note

For a richer collection of callbacks, check out our Bolts library (PyTorch Lightning Bolts).

Callback                        Abstract base class used to build new callbacks.
EarlyStopping                   Monitors a validation metric and stops training when it stops improving.
GPUStatsMonitor                 Automatically monitors and logs GPU stats during the training stage.
GradientAccumulationScheduler   Changes the gradient accumulation factor according to a schedule.
LearningRateMonitor             Automatically monitors and logs the learning rate of learning rate schedulers during training.
ModelCheckpoint                 Saves the model after every epoch by monitoring a quantity.
ProgressBar                     The default progress bar used by Lightning.
ProgressBarBase                 The base class for progress bars in Lightning.
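The idea behind EarlyStopping ("stop training when it stops improving") can be sketched in a few lines. This is a hypothetical, stand-alone re-implementation of patience-based early stopping, not Lightning's code: the class name ToyEarlyStopping and its check() method are invented for illustration, while the patience, min_delta and mode knobs mirror arguments that Lightning's EarlyStopping also accepts.

```python
class ToyEarlyStopping:
    """Hypothetical patience-based early stopping, for illustration only."""

    def __init__(self, patience=3, min_delta=0.0, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience      # checks allowed without improvement
        self.min_delta = min_delta    # minimum change that counts as improvement
        self.mode = mode              # "min" for losses, "max" for accuracies
        self.best = None
        self.wait_count = 0

    def _improved(self, current):
        if self.best is None:
            return True
        if self.mode == "min":
            return current < self.best - self.min_delta
        return current > self.best + self.min_delta

    def check(self, current):
        """Record one validation result; return True when training should stop."""
        if self._improved(current):
            self.best = current
            self.wait_count = 0
        else:
            self.wait_count += 1
        return self.wait_count >= self.patience


stopper = ToyEarlyStopping(patience=2)
stopped_epoch = None
for epoch, val_loss in enumerate([1.0, 0.8, 0.81, 0.79, 0.80, 0.82]):
    if stopper.check(val_loss):
        stopped_epoch = epoch
        break
print(stopped_epoch, stopper.best)  # 5 0.79
```

The counter resets on every improvement (epoch 3 here), so training only stops after two consecutive non-improving checks.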


Persisting State

Some callbacks require internal state in order to function properly. You can optionally persist your callback’s state as part of model checkpoint files using the callback hooks on_save_checkpoint() and on_load_checkpoint(): the state returned by on_save_checkpoint() is stored in the checkpoint and handed back to on_load_checkpoint() when that checkpoint is loaded. However, you must follow two constraints:

  1. The state you return must be picklable.

  2. You can use only one instance of that class in the Trainer callbacks list; we don’t support persisting state for multiple callbacks of the same class.
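Both constraints can be illustrated with a toy model in plain Python. CounterCallback reuses the on_train_batch_end, on_save_checkpoint and on_load_checkpoint signatures listed under Available Callback hooks; save_callback_states and load_callback_states are hypothetical helpers (not Lightning functions) that store each state under the callback's class name, which is one simple way to see why two instances of the same class would overwrite each other. pickle.dumps stands in for writing the checkpoint file.

```python
import pickle


class CounterCallback:
    """Toy stateful callback: counts training batches and persists the count."""

    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self.batches_seen += 1

    def on_save_checkpoint(self, trainer, pl_module):
        # Constraint 1: whatever is returned here must be picklable.
        return {"batches_seen": self.batches_seen}

    def on_load_checkpoint(self, checkpointed_state):
        self.batches_seen = checkpointed_state["batches_seen"]


def save_callback_states(callbacks):
    # Hypothetical helper: one entry per callback class, so a second
    # instance of the same class silently replaces the first (constraint 2).
    return {type(cb).__qualname__: cb.on_save_checkpoint(None, None) for cb in callbacks}


def load_callback_states(callbacks, states):
    # Hypothetical helper: hand each callback the state saved for its class.
    for cb in callbacks:
        key = type(cb).__qualname__
        if key in states:
            cb.on_load_checkpoint(states[key])


counter = CounterCallback()
for batch_idx in range(3):
    # Simulate the trainer finishing three training batches.
    counter.on_train_batch_end(None, None, None, None, batch_idx, 0)

checkpoint_bytes = pickle.dumps(save_callback_states([counter]))

restored = CounterCallback()
load_callback_states([restored], pickle.loads(checkpoint_bytes))
print(restored.batches_seen)  # 3

# Two instances of the same class collapse into a single saved entry.
print(len(save_callback_states([CounterCallback(), CounterCallback()])))  # 1
```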

Best Practices

The following are best practices when using or designing callbacks.

  1. Callbacks should be isolated in their functionality.

  2. Your callback should not rely on the behavior of other callbacks in order to work properly.

  3. Do not manually call the callback’s hook methods; directly calling methods such as on_validation_end is strongly discouraged.

  4. Whenever possible, your callbacks should not depend on the order in which they are executed.


Available Callback hooks

setup

Callback.setup(trainer, pl_module, stage)

Called when fit or test begins.

teardown

Callback.teardown(trainer, pl_module, stage)

Called when fit or test ends.
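setup and teardown are natural places to acquire and release per-stage resources. A minimal sketch, assuming stage arrives as the string "fit" or "test"; StageLogCallback is a hypothetical name, and an in-memory io.StringIO buffer stands in for a real log file. The loop at the end only simulates what the Trainer would do.

```python
import io


class StageLogCallback:
    """Hypothetical callback: opens a buffer in setup and closes it in teardown."""

    def __init__(self):
        self._buffers = {}   # open buffers, keyed by stage
        self.finished = []   # contents of buffers that have been closed

    def setup(self, trainer, pl_module, stage):
        buffer = io.StringIO()
        buffer.write(f"{stage} started\n")
        self._buffers[stage] = buffer

    def teardown(self, trainer, pl_module, stage):
        buffer = self._buffers.pop(stage)
        buffer.write(f"{stage} finished\n")
        self.finished.append(buffer.getvalue())  # read before closing
        buffer.close()


callback = StageLogCallback()
# Simulate the trainer running fit and then test.
for stage in ("fit", "test"):
    callback.setup(None, None, stage)
    callback.teardown(None, None, stage)
print(callback.finished)  # ['fit started\nfit finished\n', 'test started\ntest finished\n']
```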

on_init_start

Callback.on_init_start(trainer)

Called when the trainer initialization begins; the model has not yet been set.

on_init_end

Callback.on_init_end(trainer)

Called when the trainer initialization ends; the model has not yet been set.

on_fit_start

Callback.on_fit_start(trainer, pl_module)

Called when fit begins.

on_fit_end

Callback.on_fit_end(trainer, pl_module)

Called when fit ends.

on_sanity_check_start

Callback.on_sanity_check_start(trainer, pl_module)

Called when the validation sanity check starts.

on_sanity_check_end

Callback.on_sanity_check_end(trainer, pl_module)

Called when the validation sanity check ends.
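The sanity-check hooks let a callback tell the short validation run that precedes training apart from regular validation. In this minimal sketch the hypothetical ValBatchCounter sets its own flag in on_sanity_check_start and on_sanity_check_end, and ignores on_validation_batch_end calls (signature as given in this list) that arrive while the flag is set. The calls at the end only simulate the Trainer.

```python
class ValBatchCounter:
    """Hypothetical callback that counts validation batches outside the sanity check."""

    def __init__(self):
        self.in_sanity_check = False
        self.counted = 0
        self.skipped = 0

    def on_sanity_check_start(self, trainer, pl_module):
        self.in_sanity_check = True

    def on_sanity_check_end(self, trainer, pl_module):
        self.in_sanity_check = False

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if self.in_sanity_check:
            self.skipped += 1   # sanity-check batches are not "real" validation
        else:
            self.counted += 1


counter = ValBatchCounter()
# Simulate two sanity-check batches followed by three regular validation batches.
counter.on_sanity_check_start(None, None)
for i in range(2):
    counter.on_validation_batch_end(None, None, None, None, i, 0)
counter.on_sanity_check_end(None, None)
for i in range(3):
    counter.on_validation_batch_end(None, None, None, None, i, 0)
print(counter.skipped, counter.counted)  # 2 3
```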

on_train_batch_start

Callback.on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

Called when the training batch begins.

on_train_batch_end

Callback.on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

Called when the training batch ends.
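The paired batch hooks are a natural place for lightweight instrumentation. This sketch, with the hypothetical name BatchTimer, starts a timer in on_train_batch_start and stops it in on_train_batch_end using the signatures listed above; time.perf_counter is the standard-library clock, and time.sleep stands in for the real forward and backward pass.

```python
import time


class BatchTimer:
    """Hypothetical callback that measures wall-clock time per training batch."""

    def __init__(self):
        self._start = None
        self.durations = []   # seconds, one entry per finished batch

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        self._start = time.perf_counter()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self.durations.append(time.perf_counter() - self._start)
        self._start = None


timer = BatchTimer()
# Simulate the trainer processing four batches.
for batch_idx in range(4):
    timer.on_train_batch_start(None, None, None, batch_idx, 0)
    time.sleep(0.001)   # stand-in for the forward/backward pass
    timer.on_train_batch_end(None, None, None, None, batch_idx, 0)
print(len(timer.durations))  # 4
```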

on_train_epoch_start

Callback.on_train_epoch_start(trainer, pl_module)

Called when the training epoch begins.

on_train_epoch_end

Callback.on_train_epoch_end(trainer, pl_module, outputs)

Called when the training epoch ends.

on_validation_epoch_start

Callback.on_validation_epoch_start(trainer, pl_module)

Called when the validation epoch begins.

on_validation_epoch_end

Callback.on_validation_epoch_end(trainer, pl_module)

Called when the validation epoch ends.

on_test_epoch_start

Callback.on_test_epoch_start(trainer, pl_module)

Called when the test epoch begins.

on_test_epoch_end

Callback.on_test_epoch_end(trainer, pl_module)

Called when the test epoch ends.

on_epoch_start

Callback.on_epoch_start(trainer, pl_module)

Called when the epoch begins.

on_epoch_end

Callback.on_epoch_end(trainer, pl_module)

Called when the epoch ends.

on_batch_start

Callback.on_batch_start(trainer, pl_module)

Called when the training batch begins.

on_validation_batch_start

Callback.on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

Called when the validation batch begins.

on_validation_batch_end

Callback.on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

Called when the validation batch ends.

on_test_batch_start

Callback.on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

Called when the test batch begins.

on_test_batch_end

Callback.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

Called when the test batch ends.

on_batch_end

Callback.on_batch_end(trainer, pl_module)

Called when the training batch ends.

on_train_start

Callback.on_train_start(trainer, pl_module)

Called when training begins.

on_train_end

Callback.on_train_end(trainer, pl_module)

Called when training ends.

on_pretrain_routine_start

Callback.on_pretrain_routine_start(trainer, pl_module)

Called when the pretrain routine begins.

on_pretrain_routine_end

Callback.on_pretrain_routine_end(trainer, pl_module)

Called when the pretrain routine ends.

on_validation_start

Callback.on_validation_start(trainer, pl_module)

Called when the validation loop begins.

on_validation_end

Callback.on_validation_end(trainer, pl_module)

Called when the validation loop ends.
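on_validation_end is a common place to react to freshly computed validation metrics. A minimal sketch, assuming the Trainer exposes the latest logged values as a callback_metrics dictionary; the metric name "val_loss" and the class name BestMetricTracker are hypothetical. types.SimpleNamespace stands in for a real Trainer so the example runs on its own.

```python
from types import SimpleNamespace


class BestMetricTracker:
    """Hypothetical callback that remembers the best (lowest) value of one metric."""

    def __init__(self, monitor="val_loss"):
        self.monitor = monitor          # hypothetical metric name
        self.best = float("inf")
        self.history = []

    def on_validation_end(self, trainer, pl_module):
        # Assumes trainer.callback_metrics maps metric names to scalar values.
        value = trainer.callback_metrics.get(self.monitor)
        if value is None:
            return                      # metric not logged this time
        value = float(value)            # float() also works on single-element tensors
        self.history.append(value)
        self.best = min(self.best, value)


tracker = BestMetricTracker()
# Simulate three validation runs with a fake trainer object.
for loss in [0.9, 0.7, 0.75]:
    fake_trainer = SimpleNamespace(callback_metrics={"val_loss": loss})
    tracker.on_validation_end(fake_trainer, None)
print(tracker.best, tracker.history)  # 0.7 [0.9, 0.7, 0.75]
```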

on_test_start

Callback.on_test_start(trainer, pl_module)

Called when the test begins.

on_test_end

Callback.on_test_end(trainer, pl_module)

Called when the test ends.

on_keyboard_interrupt

Callback.on_keyboard_interrupt(trainer, pl_module)

Called when training is interrupted by a KeyboardInterrupt.

on_save_checkpoint

Callback.on_save_checkpoint(trainer, pl_module)

Called when saving a model checkpoint; use it to persist state.

on_load_checkpoint

Callback.on_load_checkpoint(checkpointed_state)

Called when loading a model checkpoint; use it to reload state.