
Callbacks

Lightning has a callback system to execute arbitrary code. Callbacks should capture NON-ESSENTIAL logic that is NOT required for your LightningModule to run.

An overall Lightning system should have:

  1. Trainer for all engineering code.

  2. LightningModule for all research code.

  3. Callbacks for non-essential code.

Example:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

class MyPrintingCallback(Callback):

    def on_init_start(self, trainer):
        print('Starting to init trainer!')

    def on_init_end(self, trainer):
        print('trainer is init now')

    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')

trainer = Trainer(callbacks=[MyPrintingCallback()])
# prints:
# Starting to init trainer!
# trainer is init now

We successfully extended functionality without polluting our super clean LightningModule research code.
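To make this division of labor concrete, here is a minimal, framework-free sketch of how a trainer-like object can dispatch hooks to a list of callbacks while the model only holds the research logic. `MiniTrainer`, `BaseCallback`, `EventRecorder` and `DummyModel` are hypothetical stand-ins written for illustration, not Lightning classes, and the hook sequence is a simplification of what the real Trainer does.

```python
class BaseCallback:
    """Hypothetical stand-in for a callback base class: every hook is a no-op."""

    def on_train_start(self, trainer, pl_module):
        pass

    def on_epoch_start(self, trainer, pl_module):
        pass

    def on_epoch_end(self, trainer, pl_module):
        pass

    def on_train_end(self, trainer, pl_module):
        pass


class MiniTrainer:
    """Hypothetical trainer: it owns the loop and calls hooks on each callback."""

    def __init__(self, max_epochs, callbacks=None):
        self.max_epochs = max_epochs
        self.current_epoch = 0
        self.callbacks = callbacks or []

    def _call_hook(self, name, model):
        # Every callback receives the same (trainer, model) arguments.
        for callback in self.callbacks:
            getattr(callback, name)(self, model)

    def fit(self, model):
        self._call_hook('on_train_start', model)
        for epoch in range(self.max_epochs):
            self.current_epoch = epoch
            self._call_hook('on_epoch_start', model)
            model.train_one_epoch()  # the research code lives in the model
            self._call_hook('on_epoch_end', model)
        self._call_hook('on_train_end', model)


class EventRecorder(BaseCallback):
    """Non-essential logic: record progress without touching the model."""

    def __init__(self):
        self.events = []

    def on_epoch_end(self, trainer, pl_module):
        self.events.append(f'epoch {trainer.current_epoch} done')

    def on_train_end(self, trainer, pl_module):
        self.events.append('training done')


class DummyModel:
    def __init__(self):
        self.epochs_run = 0

    def train_one_epoch(self):
        self.epochs_run += 1


recorder = EventRecorder()
model = DummyModel()
MiniTrainer(max_epochs=2, callbacks=[recorder]).fit(model)
print(recorder.events)  # ['epoch 0 done', 'epoch 1 done', 'training done']
```

Removing `recorder` from the callbacks list changes nothing about how `DummyModel` trains, which is the property the callback system is meant to preserve.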


Callback Base

Abstract base class used to build new callbacks.

class pytorch_lightning.callbacks.base.Callback

Bases: abc.ABC

on_batch_end(trainer, pl_module)

Called when the training batch ends.

on_batch_start(trainer, pl_module)

Called when the training batch begins.

on_epoch_end(trainer, pl_module)

Called when the epoch ends.

on_epoch_start(trainer, pl_module)

Called when the epoch begins.

on_init_end(trainer)

Called when the trainer initialization ends; the model has not yet been set.

on_init_start(trainer)

Called when the trainer initialization begins; the model has not yet been set.

on_sanity_check_end(trainer, pl_module)

Called when the validation sanity check ends.

on_sanity_check_start(trainer, pl_module)

Called when the validation sanity check starts.

on_test_batch_end(trainer, pl_module)

Called when the test batch ends.

on_test_batch_start(trainer, pl_module)

Called when the test batch begins.

on_test_end(trainer, pl_module)

Called when testing ends.

on_test_start(trainer, pl_module)

Called when testing begins.

on_train_end(trainer, pl_module)

Called when training ends.

on_train_start(trainer, pl_module)

Called when training begins.

on_validation_batch_end(trainer, pl_module)

Called when the validation batch ends.

on_validation_batch_start(trainer, pl_module)

Called when the validation batch begins.

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

on_validation_start(trainer, pl_module)

Called when the validation loop begins.
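Hooks receive the trainer (and, for most hooks, the LightningModule) so that a callback can read training state without the model knowing about it. The sketch below is a hypothetical `ValLossHistory` callback; it assumes the trainer exposes its latest logged metrics as a dict-like `callback_metrics` attribute and the epoch as `current_epoch`. To stay runnable without Lightning installed, `SimpleNamespace` objects stand in for the trainer and module.

```python
from types import SimpleNamespace


class ValLossHistory:
    """Hypothetical callback that records the monitored metric per epoch.

    In a real project this would subclass pytorch_lightning.callbacks.Callback,
    which supplies no-op versions of every other hook.
    """

    def __init__(self, monitor='val_loss'):
        self.monitor = monitor
        self.history = []

    def on_validation_end(self, trainer, pl_module):
        # Assumption: the trainer exposes its latest logged metrics as the
        # dict-like trainer.callback_metrics, plus trainer.current_epoch.
        value = trainer.callback_metrics.get(self.monitor)
        if value is not None:
            self.history.append((trainer.current_epoch, float(value)))


# SimpleNamespace objects stand in for the real Trainer and LightningModule.
history = ValLossHistory()
fake_module = SimpleNamespace()
for epoch, loss in enumerate([0.9, 0.7, 0.65]):
    fake_trainer = SimpleNamespace(current_epoch=epoch,
                                   callback_metrics={'val_loss': loss})
    history.on_validation_end(fake_trainer, fake_module)

print(history.history)  # [(0, 0.9), (1, 0.7), (2, 0.65)]
```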


Early Stopping

Stop training when a monitored quantity has stopped improving.

class pytorch_lightning.callbacks.early_stopping.EarlyStopping(monitor='val_loss', min_delta=0.0, patience=3, verbose=False, mode='auto', strict=True)

Bases: pytorch_lightning.callbacks.base.Callback

Parameters
  • monitor (str) – quantity to be monitored. Default: 'val_loss'.

  • min_delta (float) – minimum change in the monitored quantity to qualify as an improvement, i.e., an absolute change of less than min_delta will count as no improvement. Default: 0.

  • patience (int) – number of epochs with no improvement after which training will be stopped. Default: 3.

  • verbose (bool) – verbosity mode. Default: False.

  • mode (str) – one of {auto, min, max}. In min mode, training will stop when the quantity monitored has stopped decreasing; in max mode it will stop when the quantity monitored has stopped increasing; in auto mode, the direction is automatically inferred from the name of the monitored quantity. Default: 'auto'.

  • strict (bool) – whether to crash the training if monitor is not found in the metrics. Default: True.
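The interplay of monitor, min_delta, patience and mode can be illustrated with a small, framework-free sketch. `EarlyStoppingSketch` is a hypothetical class, not the library's implementation. In particular, the rule that auto mode maximizes any metric whose name contains 'acc' and minimizes everything else is an assumption about how the direction is inferred.

```python
import math


class EarlyStoppingSketch:
    """Hypothetical early-stopping bookkeeping, driven once per epoch."""

    def __init__(self, monitor='val_loss', min_delta=0.0, patience=3, mode='auto'):
        if mode == 'auto':
            # Assumption: names containing 'acc' are maximized, others minimized.
            mode = 'max' if 'acc' in monitor else 'min'
        if mode not in ('min', 'max'):
            raise ValueError(f"mode must be 'auto', 'min' or 'max', got {mode!r}")
        self.monitor = monitor
        self.min_delta = abs(min_delta)
        self.patience = patience
        self.mode = mode
        self.best = math.inf if mode == 'min' else -math.inf
        self.wait = 0  # epochs since the last improvement

    def improved(self, current):
        # An improvement must beat the best value by more than min_delta.
        if self.mode == 'min':
            return current < self.best - self.min_delta
        return current > self.best + self.min_delta

    def should_stop(self, logs):
        current = logs[self.monitor]
        if self.improved(current):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStoppingSketch(monitor='val_loss', min_delta=0.01, patience=2)
stopped_epoch = None
for epoch, loss in enumerate([1.0, 0.8, 0.795, 0.792, 0.5]):
    if stopper.should_stop({'val_loss': loss}):
        stopped_epoch = epoch
        break
print(stopped_epoch)  # 3
```

The drops to 0.795 and 0.792 are smaller than min_delta, so they count as "no improvement", and the second of them exhausts a patience of 2 before the large drop to 0.5 is ever seen.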

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import EarlyStopping
>>> early_stopping = EarlyStopping('val_loss')
>>> trainer = Trainer(early_stop_callback=early_stopping)

_validate_condition_metric(logs)

Checks that the metric used as the early stopping condition (monitor) is available in logs.

on_epoch_end(trainer, pl_module)

Called when the epoch ends.

on_train_end(trainer, pl_module)

Called when training ends.

on_train_start(trainer, pl_module)

Called when training begins.


Model Checkpointing

Automatically save model checkpoints during training.

class pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(filepath=None, monitor='val_loss', verbose=False, save_top_k=1, save_weights_only=False, mode='auto', period=1, prefix='')

Bases: pytorch_lightning.callbacks.base.Callback

Save the model after every epoch.

Parameters
  • filepath (Optional[str]) –

    path to save the model file. Can contain named formatting options to be auto-filled.

    Example:

    # custom path
    # saves a file like: my/path/epoch_0.ckpt
    >>> checkpoint_callback = ModelCheckpoint('my/path/')
    
    # save any arbitrary metrics like `val_loss`, etc. in name
    # saves a file like: my/path/epoch=2-val_loss=0.20-other_metric=0.30.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
    ... )
    

    If set to None, it will be set to a default location during Trainer construction.

  • monitor (str) – quantity to monitor.

  • verbose (bool) – verbosity mode. Default: False.

  • save_top_k (int) – if save_top_k == k, the best k models according to the monitored quantity will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Note that the monitored quantity is checked every period epochs. If save_top_k >= 2 and the callback is called multiple times inside an epoch, a version count starting with v0 will be appended to the name of the saved file.

  • mode (str) – one of {auto, min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max; for val_loss, this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity.

  • save_weights_only (bool) – if True, only the model’s weights will be saved; otherwise the full checkpoint, including training state such as optimizer states, is saved.

  • period (int) – interval (number of epochs) between checkpoints.
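The save_top_k and mode rules above amount to keeping a bounded set of (path, score) pairs and evicting the worst one when a better checkpoint arrives. The sketch below models only that decision; `TopKTracker` and its `update` method are hypothetical names, no files are written, and the library's own bookkeeping may differ in detail.

```python
class TopKTracker:
    """Hypothetical bookkeeping for 'keep the best k checkpoints'."""

    def __init__(self, save_top_k=1, mode='min'):
        self.save_top_k = save_top_k
        self.mode = mode
        self.kept = {}  # path -> monitored value

    def _better(self, a, b):
        return a < b if self.mode == 'min' else a > b

    def update(self, path, value):
        """Return (saved, evicted_path) for a new candidate checkpoint."""
        if self.save_top_k == 0:
            return False, None
        if self.save_top_k == -1 or len(self.kept) < self.save_top_k:
            self.kept[path] = value
            return True, None
        # The worst kept checkpoint has the highest value in 'min' mode
        # and the lowest value in 'max' mode.
        pick_worst = max if self.mode == 'min' else min
        worst_path = pick_worst(self.kept, key=self.kept.get)
        if self._better(value, self.kept[worst_path]):
            del self.kept[worst_path]
            self.kept[path] = value
            return True, worst_path
        return False, None


tracker = TopKTracker(save_top_k=2, mode='min')
for epoch, val_loss in enumerate([0.5, 0.4, 0.45, 0.3]):
    tracker.update(f'epoch={epoch}.ckpt', val_loss)
print(sorted(tracker.kept))  # ['epoch=1.ckpt', 'epoch=3.ckpt']
```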

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' whenever 'val_loss' has a new min
>>> checkpoint_callback = ModelCheckpoint(filepath='my/path/')
>>> trainer = Trainer(checkpoint_callback=checkpoint_callback)

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist_epoch=02-val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
... )

format_checkpoint_name(epoch, metrics, ver=None)

Generate a filename according to the defined template.

Example:

>>> import os
>>> from pytorch_lightning.callbacks import ModelCheckpoint
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'missing=0.ckpt'
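The expansion shown in these examples can be reproduced with a short, standalone sketch: each `{key` in the template is rewritten to `key={key` so the metric name appears in the file name, and keys absent from the metrics default to 0. This `format_checkpoint_name` is a hypothetical re-implementation for illustration (it omits the directory, prefix and version suffix), not the library's source.

```python
import re


def format_checkpoint_name(template, epoch, metrics):
    """Expand a template like '{epoch}-{val_loss:.2f}' into a checkpoint filename."""
    values = dict(metrics, epoch=epoch)  # copy; do not mutate the caller's dict
    name = template
    # Capture '{key' up to the first ':' or '}'; dict.fromkeys removes
    # duplicates so a repeated key is only rewritten once.
    for group in dict.fromkeys(re.findall(r'(\{.*?)[:}]', template)):
        key = group[1:]  # strip the leading '{'
        name = name.replace(group, key + '={' + key)
        values.setdefault(key, 0)  # missing metrics are filled with 0
    return name.format(**values) + '.ckpt'


print(format_checkpoint_name('{epoch}-{val_loss:.2f}', 2, {'val_loss': 0.123456}))
# epoch=2-val_loss=0.12.ckpt
```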

on_validation_end(trainer, pl_module)

Called when the validation loop ends.


Gradient Accumulator

Change gradient accumulation factor according to scheduling.

class pytorch_lightning.callbacks.gradient_accumulation_scheduler.GradientAccumulationScheduler(scheduling)

Bases: pytorch_lightning.callbacks.base.Callback

Change gradient accumulation factor according to scheduling.

Parameters
  • scheduling (dict) – scheduling in format {epoch: accumulation_factor}.

Warning

Epochs indexing starts from “1” until v0.6.x, but will start from “0” in v0.8.0.

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import GradientAccumulationScheduler

# at epoch 5 start accumulating every 2 batches
>>> accumulator = GradientAccumulationScheduler(scheduling={5: 2})
>>> trainer = Trainer(callbacks=[accumulator])

# alternatively, pass the scheduling dict directly to the Trainer
>>> trainer = Trainer(accumulate_grad_batches={5: 2})
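With an accumulation factor of k, gradients from k consecutive batches are accumulated before each optimizer step. The sketch below shows one way to read a schedule such as {5: 2, 10: 4}: the factor of the largest scheduled epoch that is not after the current epoch applies. `accumulation_factor` is a hypothetical helper; it assumes a factor of 1 (no accumulation) before the first scheduled epoch, and it compares epoch numbers as given, so whether they count from 0 or 1 follows the version note in the warning.

```python
def accumulation_factor(scheduling, epoch):
    """Return the accumulation factor in effect at `epoch` (hypothetical helper)."""
    factor = 1  # assumption: no accumulation before the first scheduled epoch
    for start_epoch in sorted(scheduling):
        if epoch >= start_epoch:
            factor = scheduling[start_epoch]
        else:
            break
    return factor


schedule = {5: 2, 10: 4}
print([accumulation_factor(schedule, e) for e in range(1, 13)])
# [1, 1, 1, 1, 2, 2, 2, 2, 2, 4, 4, 4]
```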

on_epoch_start(trainer, pl_module)

Called when the epoch begins.


Progress Bars

Use or override one of the progress bar callbacks.

class pytorch_lightning.callbacks.progress.ProgressBar(refresh_rate=1, process_position=0)

Bases: pytorch_lightning.callbacks.progress.ProgressBarBase

This is the default progress bar used by Lightning. It prints to stdout using the tqdm package and shows up to four different bars:

  • sanity check progress: the progress during the sanity check run

  • main progress: shows training + validation progress combined. It also accounts for multiple validation runs during training when val_check_interval is used.

  • validation progress: only visible during validation; shows total progress over all validation datasets.

  • test progress: only active when testing; shows total progress over all test datasets.

For infinite datasets, the progress bar never ends.

If you want to customize the default tqdm progress bars used by Lightning, you can override specific methods of the callback class and pass your custom implementation to the Trainer:

Example:

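from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBar

class LitProgressBar(ProgressBar):

    def init_validation_tqdm(self):
        # reuse the default validation bar and only change its description
        bar = super().init_validation_tqdm()
        bar.set_description('running validation ...')
        return bar

bar = LitProgressBar()
trainer = Trainer(callbacks=[bar])

Parameters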
  • refresh_rate (int) – Determines at which rate (in number of batches) the progress bars get updated. Set it to 0 to disable the display. By default, the Trainer uses this implementation of the progress bar and sets the refresh rate to the value provided to the progress_bar_refresh_rate argument in the Trainer.

  • process_position (int) – Set this to a value greater than 0 to offset the progress bars by this many lines. This is useful when you have progress bars defined elsewhere and want to show all of them together. This corresponds to process_position in the Trainer.
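The refresh_rate semantics (redraw every N batches, 0 disables the display) can be pictured with a tiny counter. `RefreshCounter` is a hypothetical helper for illustration only; it does not use tqdm and is not the ProgressBar's actual code.

```python
class RefreshCounter:
    """Hypothetical helper mimicking a bar that redraws every `refresh_rate` batches."""

    def __init__(self, refresh_rate=1):
        self.refresh_rate = refresh_rate
        self.batch_idx = 0
        self.redraws = 0

    @property
    def is_enabled(self):
        # A refresh rate of 0 disables the display entirely.
        return self.refresh_rate > 0

    def on_batch_end(self):
        self.batch_idx += 1
        # The enabled check runs first, so a rate of 0 never reaches the modulo.
        if self.is_enabled and self.batch_idx % self.refresh_rate == 0:
            self.redraws += 1


fast = RefreshCounter(refresh_rate=20)
off = RefreshCounter(refresh_rate=0)
for _ in range(100):
    fast.on_batch_end()
    off.on_batch_end()
print(fast.redraws, off.redraws)  # 5 0
```

A larger refresh_rate trades display smoothness for less terminal output, which matters most in notebooks and log files.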

disable()

You should provide a way to disable the progress bar. The Trainer will call this to disable the output on processes that have a rank different from 0, e.g., in multi-node training.

Return type

None

enable()

You should provide a way to enable the progress bar. The Trainer calls this in pre-training routines, such as the learning rate finder, to temporarily disable and re-enable the main progress bar.

Return type

None

init_sanity_tqdm()

Override this to customize the tqdm bar for the validation sanity run.

Return type

tqdm

init_test_tqdm()

Override this to customize the tqdm bar for testing.

Return type

tqdm

init_train_tqdm()

Override this to customize the tqdm bar for training.

Return type

tqdm

init_validation_tqdm()

Override this to customize the tqdm bar for validation.

Return type

tqdm

on_batch_end(trainer, pl_module)

Called when the training batch ends.

on_epoch_start(trainer, pl_module)

Called when the epoch begins.

on_sanity_check_end(trainer, pl_module)

Called when the validation sanity check ends.

on_sanity_check_start(trainer, pl_module)

Called when the validation sanity check starts.

on_test_batch_end(trainer, pl_module)

Called when the test batch ends.

on_test_end(trainer, pl_module)

Called when testing ends.

on_test_start(trainer, pl_module)

Called when testing begins.

on_train_end(trainer, pl_module)

Called when training ends.

on_train_start(trainer, pl_module)

Called when training begins.

on_validation_batch_end(trainer, pl_module)

Called when the validation batch ends.

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

on_validation_start(trainer, pl_module)

Called when the validation loop begins.

class pytorch_lightning.callbacks.progress.ProgressBarBase

Bases: pytorch_lightning.callbacks.base.Callback

The base class for progress bars in Lightning. It is a Callback that keeps track of the batch progress in the Trainer. You should implement your highly custom progress bars with this as the base class.

Example:

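import sys

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBarBase

class LitProgressBar(ProgressBarBase):

    def __init__(self):
        super().__init__()  # don't forget this :)
        # use a separate attribute so the enable()/disable() methods are not shadowed
        self._enabled = True

    def enable(self):
        self._enabled = True

    def disable(self):
        self._enabled = False

    def on_batch_end(self, trainer, pl_module):
        super().on_batch_end(trainer, pl_module)  # don't forget this :)
        if not self._enabled:
            return
        percent = (self.train_batch_idx / self.total_train_batches) * 100
        sys.stdout.write(f'{percent:.01f} percent complete \r')
        sys.stdout.flush()

bar = LitProgressBar()
trainer = Trainer(callbacks=[bar])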

disable()

You should provide a way to disable the progress bar. The Trainer will call this to disable the output on processes that have a rank different from 0, e.g., in multi-node training.

enable()

You should provide a way to enable the progress bar. The Trainer calls this in pre-training routines, such as the learning rate finder, to temporarily disable and re-enable the main progress bar.

on_batch_end(trainer, pl_module)

Called when the training batch ends.

on_epoch_start(trainer, pl_module)

Called when the epoch begins.

on_init_end(trainer)

Called when the trainer initialization ends; the model has not yet been set.

on_test_batch_end(trainer, pl_module)

Called when the test batch ends.

on_test_start(trainer, pl_module)

Called when testing begins.

on_train_start(trainer, pl_module)

Called when training begins.

on_validation_batch_end(trainer, pl_module)

Called when the validation batch ends.

on_validation_start(trainer, pl_module)

Called when the validation loop begins.

property test_batch_idx

The current batch index being processed during testing. Use this to update your progress bar.

Return type

int

property total_test_batches

The total number of test batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the test dataloader is of infinite size.

Return type

int

property total_train_batches

The total number of training batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the training dataloader is of infinite size.

Return type

int

property total_val_batches

The total number of validation batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the validation dataloader is of infinite size.

Return type

int

property train_batch_idx

The current batch index being processed during training. Use this to update your progress bar.

Return type

int

property val_batch_idx

The current batch index being processed during validation. Use this to update your progress bar.

Return type

int

pytorch_lightning.callbacks.progress.convert_inf(x)

tqdm does not support inf values, so they are converted to None.
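A minimal sketch of this conversion follows. It is a standalone illustration rather than the library's source: None is the value tqdm accepts for an unknown total, in which case only basic progress statistics (count and rate, no percentage) are shown.

```python
import math
from typing import Optional, Union


def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
    """Return None for an infinite total, otherwise return x unchanged."""
    # An infinite dataloader length is mapped to None before it is passed
    # to tqdm as the bar's total.
    if x == math.inf:
        return None
    return x


print(convert_inf(float('inf')))  # None
print(convert_inf(120))  # 120
```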


Logging of learning rates

Log learning rates for lr schedulers during training.

class pytorch_lightning.callbacks.lr_logger.LearningRateLogger

Bases: pytorch_lightning.callbacks.base.Callback

Automatically logs learning rate for learning rate schedulers during training.

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LearningRateLogger
>>> lr_logger = LearningRateLogger()
>>> trainer = Trainer(callbacks=[lr_logger])

Logging names are automatically determined based on the optimizer class name. In case of multiple optimizers of the same type, they will be named Adam, Adam-1 etc. If an optimizer has multiple parameter groups, they will be named Adam/pg1, Adam/pg2 etc. To control naming, add a 'name' key to the learning rate scheduler dictionary.

Example:

import torch

# inside your LightningModule
def configure_optimizers(self):
    optimizer = torch.optim.Adam(...)
    lr_scheduler = {'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...),
                    'name': 'my_logging_name'}
    return [optimizer], [lr_scheduler]
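The naming scheme for logged learning rates can be sketched as a pure function. `lr_log_names` is a hypothetical helper that follows the scheme described for this callback (Adam, Adam-1, ... for repeated optimizer types; /pg1, /pg2, ... for parameter groups; an explicit 'name' key overrides the default). The names actually logged by LearningRateLogger may differ in detail, for example by carrying an extra prefix, and applying the /pgN suffix to custom names is an assumption here.

```python
def lr_log_names(schedulers):
    """Build log names from (optimizer_class_name, num_param_groups, custom_name) tuples.

    custom_name is None when the scheduler dictionary has no 'name' key.
    """
    names = []
    seen = set()
    for opt_name, num_groups, custom_name in schedulers:
        if custom_name is not None:
            base = custom_name
        else:
            base, i = opt_name, 1
            # Same optimizer type seen before: Adam, Adam-1, Adam-2, ...
            while base in seen:
                base = f'{opt_name}-{i}'
                i += 1
        seen.add(base)
        if num_groups > 1:
            # One entry per parameter group: Adam/pg1, Adam/pg2, ...
            names.extend(f'{base}/pg{g}' for g in range(1, num_groups + 1))
        else:
            names.append(base)
    return names


print(lr_log_names([('Adam', 1, None), ('Adam', 1, None), ('SGD', 2, None),
                    ('Adam', 1, 'my_logging_name')]))
# ['Adam', 'Adam-1', 'SGD/pg1', 'SGD/pg2', 'my_logging_name']
```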

_extract_lr(trainer, interval)

Extracts learning rates for lr schedulers and saves the information into a dict structure.

on_batch_start(trainer, pl_module)

Called when the training batch begins.

on_epoch_start(trainer, pl_module)

Called when the epoch begins.

on_train_start(trainer, pl_module)

Called before training; determines unique names for all lr schedulers, in case there are multiple schedulers of the same type or multiple parameter groups.