Callback¶
A callback is a self-contained program that can be reused across projects.
Lightning has a callback system to execute callbacks when needed. Callbacks should capture NON-ESSENTIAL logic that is NOT required for your LightningModule to run.
An overall Lightning system should have:

- Trainer for all engineering.
- LightningModule for all research code.
- Callbacks for non-essential code.
Example:
class MyPrintingCallback(Callback):

    def on_init_start(self, trainer):
        print('Starting to init trainer!')

    def on_init_end(self, trainer):
        print('trainer is init now')

    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')

trainer = Trainer(callbacks=[MyPrintingCallback()])
Output:

Starting to init trainer!
trainer is init now
We successfully extended functionality without polluting our super clean LightningModule research code.
Examples¶
You can do pretty much anything with callbacks.
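For example, a small callback can collect logged metrics for later inspection. A minimal sketch, assuming the LightningModule logs a 'loss' metric (the MetricsTracker name is illustrative, not part of the library):

from pytorch_lightning import Callback, Trainer

class MetricsTracker(Callback):

    def __init__(self):
        self.losses = []

    def on_train_epoch_end(self, trainer, pl_module):
        # trainer.callback_metrics holds the most recently logged values;
        # 'loss' is assumed to be logged by the LightningModule
        loss = trainer.callback_metrics.get('loss')
        if loss is not None:
            self.losses.append(float(loss))

tracker = MetricsTracker()
trainer = Trainer(callbacks=[tracker])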
Callback Hooks¶
Subclass this class and override any of the relevant hooks.

class pytorch_lightning.callbacks.base.Callback
Bases: abc.ABC

Abstract base class used to build new callbacks.
- on_batch_end(trainer, pl_module): Called when the training batch ends.
- on_batch_start(trainer, pl_module): Called when the training batch begins.
- on_epoch_end(trainer, pl_module): Called when the epoch ends.
- on_epoch_start(trainer, pl_module): Called when the epoch begins.
- on_fit_end(trainer, pl_module): Called when fit ends.
- on_fit_start(trainer, pl_module): Called when fit begins.
- on_init_end(trainer): Called when the trainer initialization ends; the model has not yet been set.
- on_init_start(trainer): Called when the trainer initialization begins; the model has not yet been set.
- on_keyboard_interrupt(trainer, pl_module): Called when training is interrupted by KeyboardInterrupt.
- on_pretrain_routine_end(trainer, pl_module): Called when the pretrain routine ends.
- on_pretrain_routine_start(trainer, pl_module): Called when the pretrain routine begins.
- on_sanity_check_end(trainer, pl_module): Called when the validation sanity check ends.
- on_sanity_check_start(trainer, pl_module): Called when the validation sanity check starts.
- on_test_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx): Called when the test batch ends.
- on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx): Called when the test batch begins.
- on_test_end(trainer, pl_module): Called when the test ends.
- on_test_epoch_end(trainer, pl_module): Called when the test epoch ends.
- on_test_epoch_start(trainer, pl_module): Called when the test epoch begins.
- on_test_start(trainer, pl_module): Called when the test begins.
- on_train_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx): Called when the train batch ends.
- on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx): Called when the train batch begins.
- on_train_end(trainer, pl_module): Called when the train ends.
- on_train_epoch_end(trainer, pl_module): Called when the train epoch ends.
- on_train_epoch_start(trainer, pl_module): Called when the train epoch begins.
- on_train_start(trainer, pl_module): Called when the train begins.
- on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx): Called when the validation batch ends.
- on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx): Called when the validation batch begins.
- on_validation_end(trainer, pl_module): Called when the validation loop ends.
- on_validation_epoch_end(trainer, pl_module): Called when the validation epoch ends.
- on_validation_epoch_start(trainer, pl_module): Called when the validation epoch begins.
- on_validation_start(trainer, pl_module): Called when the validation loop begins.
- setup(trainer, pl_module, stage): Called when fit or test begins.
- teardown(trainer, pl_module, stage): Called when fit or test ends.
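Since setup and teardown receive the stage name ('fit' or 'test'), they are a natural place for stage-specific resource handling. A minimal sketch, assuming you want a per-stage log file (the callback and file names are illustrative):

from pytorch_lightning import Callback

class FileHandleCallback(Callback):

    def setup(self, trainer, pl_module, stage):
        # stage is 'fit' or 'test'; open one file per stage
        self.log_file = open(f'{stage}_events.txt', 'w')

    def teardown(self, trainer, pl_module, stage):
        # release the resource when the stage finishes
        self.log_file.close()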
Built-in Callbacks¶
Lightning has a few built-in callbacks.
Note
For a richer collection of callbacks, check out our bolts library.
Early Stopping¶
Monitor a validation metric and stop training when it stops improving.
class pytorch_lightning.callbacks.early_stopping.EarlyStopping(monitor='val_loss', min_delta=0.0, patience=3, verbose=False, mode='auto', strict=True)
Bases: pytorch_lightning.callbacks.base.Callback

Parameters:
- monitor (str) – quantity to be monitored. Default: 'val_loss'. Note: has no effect when using EvalResult or TrainResult.
- min_delta (float) – minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement. Default: 0.0.
- patience (int) – number of validation epochs with no improvement after which training will be stopped. Default: 3.
- mode (str) – one of {auto, min, max}. In min mode, training will stop when the quantity monitored has stopped decreasing; in max mode it will stop when the quantity monitored has stopped increasing; in auto mode, the direction is automatically inferred from the name of the monitored quantity. Default: 'auto'.
- strict (bool) – whether to crash the training if monitor is not found in the validation metrics. Default: True.
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import EarlyStopping
>>> early_stopping = EarlyStopping('val_loss')
>>> trainer = Trainer(early_stop_callback=early_stopping)
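The thresholds can be tightened for noisier metrics; a hedged sketch monitoring an accuracy-style metric (this assumes your LightningModule logs 'val_acc'):

>>> early_stopping = EarlyStopping(
...     monitor='val_acc',
...     min_delta=0.01,  # require an absolute improvement of at least 0.01
...     patience=5,      # allow 5 validation epochs without improvement
...     mode='max',      # higher accuracy is better
... )
>>> trainer = Trainer(early_stop_callback=early_stopping)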
- on_train_end(trainer, pl_module): Called when the train ends.
- on_train_epoch_end(trainer, pl_module): Called when the train epoch ends.
- on_validation_end(trainer, pl_module): Called when the validation loop ends.
- on_validation_epoch_end(trainer, pl_module): Called when the validation epoch ends.
GPU Usage Logger¶
Log GPU memory and GPU usage during training.
class pytorch_lightning.callbacks.gpu_usage_logger.GpuUsageLogger(memory_utilisation=True, gpu_utilisation=True, intra_step_time=False, inter_step_time=False, fan_speed=False, temperature=False)
Bases: pytorch_lightning.callbacks.base.Callback

Automatically logs GPU memory and GPU usage during the training stage.

Parameters:
- memory_utilisation (bool) – Set to True to log used, free, and percentage of memory utilisation at the start and end of each step. Default: True. From nvidia-smi --help-query-gpu: memory.used = `Total memory allocated by active contexts.`; memory.free = `Total free memory.`
- gpu_utilisation (bool) – Set to True to log percentage of GPU utilisation at the start and end of each step. Default: True.
- intra_step_time (bool) – Set to True to log the time of each step. Default: False.
- inter_step_time (bool) – Set to True to log the time between the end of one step and the start of the next. Default: False.
- fan_speed (bool) – Set to True to log percentage of fan speed. Default: False.
- temperature (bool) – Set to True to log the memory and GPU temperature in degrees C. Default: False.
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import GpuUsageLogger
>>> gpu_usage = GpuUsageLogger()
>>> trainer = Trainer(callbacks=[gpu_usage])
GPU usage is mainly based on the nvidia-smi --query-gpu command. The descriptions of the queries used here, as they appear in nvidia-smi --help-query-gpu:

- “fan.speed”: `The fan speed value is the percent of maximum speed that the device's fan is currently intended to run at. It ranges from 0 to 100 %. Note: The reported speed is the intended fan speed. If the fan is physically blocked and unable to spin, this output will not match the actual fan speed. Many parts do not report fan speeds because they rely on cooling via fans in the surrounding enclosure.`
- “memory.used”: `Total memory allocated by active contexts.`
- “memory.free”: `Total free memory.`
- “utilization.gpu”: `Percent of time over the past sample period during which one or more kernels was executing on the GPU. The sample period may be between 1 second and 1/6 second depending on the product.`
- “utilization.memory”: `Percent of time over the past sample period during which global (device) memory was being read or written. The sample period may be between 1 second and 1/6 second depending on the product.`
- “temperature.gpu”: `Core GPU temperature, in degrees C.`
- “temperature.memory”: `HBM memory temperature, in degrees C.`
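A sketch enabling the optional timing and thermal queries via the constructor arguments documented above:

>>> gpu_usage = GpuUsageLogger(
...     intra_step_time=True,   # log the duration of each step
...     inter_step_time=True,   # log the gap between consecutive steps
...     fan_speed=True,         # log fan.speed
...     temperature=True,       # log temperature.gpu and temperature.memory
... )
>>> trainer = Trainer(callbacks=[gpu_usage])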
- on_batch_end(trainer, pl_module): Called when the training batch ends.
- on_batch_start(trainer, pl_module): Called when the training batch begins.
- on_train_epoch_start(trainer, pl_module): Called when the train epoch begins.
Gradient Accumulator¶
Change gradient accumulation factor according to scheduling.
Trainer also calls optimizer.step() for the last indivisible step number.
class pytorch_lightning.callbacks.gradient_accumulation_scheduler.GradientAccumulationScheduler(scheduling)
Bases: pytorch_lightning.callbacks.base.Callback

Change gradient accumulation factor according to scheduling.
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import GradientAccumulationScheduler

# at epoch 5 start accumulating every 2 batches
>>> accumulator = GradientAccumulationScheduler(scheduling={5: 2})
>>> trainer = Trainer(callbacks=[accumulator])

# alternatively, pass the scheduling dict directly to the Trainer
>>> trainer = Trainer(accumulate_grad_batches={5: 2})
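The scheduling dict may contain several milestones; each factor applies from its epoch until the next key takes over. A sketch with illustrative epochs and factors:

# accumulate every 2 batches from epoch 5, then every 4 batches from epoch 10
>>> accumulator = GradientAccumulationScheduler(scheduling={5: 2, 10: 4})
>>> trainer = Trainer(callbacks=[accumulator])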
- on_epoch_start(trainer, pl_module): Called when the epoch begins.
Learning Rate Logger¶
Log the learning rate of LR schedulers during training.

class pytorch_lightning.callbacks.lr_logger.LearningRateLogger(logging_interval=None)
Bases: pytorch_lightning.callbacks.base.Callback

Automatically logs the learning rate for learning rate schedulers during training.

Parameters:
- logging_interval (Optional[str]) – set to 'epoch' or 'step' to log the lr of all optimizers at the same interval; set to None to log at individual intervals according to the interval key of each scheduler. Default: None.
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LearningRateLogger
>>> lr_logger = LearningRateLogger(logging_interval='step')
>>> trainer = Trainer(callbacks=[lr_logger])
Logging names are automatically determined based on the optimizer class name. In the case of multiple optimizers of the same type, they will be named Adam, Adam-1, etc. If an optimizer has multiple parameter groups, they will be named Adam/pg1, Adam/pg2, etc. To control naming, pass in a name key when constructing the learning rate scheduler:
Example:
def configure_optimizers(self):
    optimizer = torch.optim.Adam(...)
    lr_scheduler = {
        'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...),
        'name': 'my_logging_name',
    }
    return [optimizer], [lr_scheduler]
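As a hedged illustration of the automatic naming above, two unnamed optimizers of the same class would be logged as Adam and Adam-1 (the generator/discriminator attributes are illustrative):

def configure_optimizers(self):
    # without a 'name' key, these appear in logs as 'Adam' and 'Adam-1'
    opt_g = torch.optim.Adam(self.generator.parameters(), lr=1e-3)
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
    return [opt_g, opt_d]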
- on_batch_start(trainer, pl_module): Called when the training batch begins.
- on_epoch_start(trainer, pl_module): Called when the epoch begins.
- on_train_start(trainer, pl_module): Called before training; determines unique names for all lr schedulers in the case of multiple schedulers of the same type or multiple parameter groups.
Model Checkpointing¶
Automatically save model checkpoints during training.
class pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(filepath=None, monitor='val_loss', verbose=False, save_last=False, save_top_k=1, save_weights_only=False, mode='auto', period=1, prefix='')
Bases: pytorch_lightning.callbacks.base.Callback

Save the model after every epoch if it improves.

After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.

Parameters:
- filepath – path to save the model file. Can contain named formatting options to be auto-filled.
Example:
# custom path
# saves a file like: my/path/epoch_0.ckpt
>>> checkpoint_callback = ModelCheckpoint('my/path/')

# save any arbitrary metrics like `val_loss`, etc. in name
# saves a file like: my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
... )
By default, filepath is None and will be set at runtime to the location specified by the Trainer's default_root_dir or weights_save_path arguments; if the Trainer uses a logger, the path will also contain logger name and version.

- save_last (bool) – always saves the model at the end of the epoch. Default: False.
- save_top_k (int) – if save_top_k == k, the best k models according to the quantity monitored will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Please note that the monitors are checked every period epochs. If save_top_k >= 2 and the callback is called multiple times inside an epoch, the name of the saved file will be appended with a version count starting with v0.
- mode (str) – one of {auto, min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max; for val_loss, this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity.
- save_weights_only (bool) – if True, then only the model's weights will be saved (model.save_weights(filepath)), else the full model is saved (model.save(filepath)).
- period (int) – interval (number of epochs) between checkpoints.
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' whenever 'val_loss' has a new min
>>> checkpoint_callback = ModelCheckpoint(filepath='my/path/')
>>> trainer = Trainer(checkpoint_callback=checkpoint_callback)

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist_epoch=02_val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(filepath='my/path/')
trainer = Trainer(checkpoint_callback=checkpoint_callback)
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path
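A sketch combining the parameters above to keep the three best checkpoints ranked by validation loss (the path and values are illustrative):

>>> checkpoint_callback = ModelCheckpoint(
...     filepath='my/path/{epoch}-{val_loss:.2f}',
...     monitor='val_loss',
...     mode='min',      # lower val_loss is better
...     save_top_k=3,    # keep only the 3 best checkpoints
...     save_last=True,  # additionally save the most recent epoch
... )
>>> trainer = Trainer(checkpoint_callback=checkpoint_callback)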
- format_checkpoint_name(epoch, metrics, ver=None): Generate a filename according to the defined template.

Example:

>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'missing=0.ckpt'
- on_train_start(trainer, pl_module): Determines the model checkpoint save directory at runtime. References attributes from the trainer's logger to determine where to save checkpoints. The base path for saving weights is set in this priority:

  1. Checkpoint callback's path (if passed in)
  2. The default_root_dir from trainer if trainer has no logger
  3. The weights_save_path from trainer, if user provides it
  4. User provided weights_saved_path

  The base path gets extended with the logger name and version (if these are available) and the subfolder “checkpoints”.
- on_validation_end(trainer, pl_module): Called when the validation loop ends.
Progress Bars¶
Use or override one of the progress bar callbacks.
class pytorch_lightning.callbacks.progress.ProgressBar(refresh_rate=1, process_position=0)
Bases: pytorch_lightning.callbacks.progress.ProgressBarBase

This is the default progress bar used by Lightning. It prints to stdout using the tqdm package and shows up to four different bars:

- sanity check progress: the progress during the sanity check run
- main progress: shows training + validation progress combined. It also accounts for multiple validation runs during training when val_check_interval is used.
- validation progress: only visible during validation; shows total progress over all validation datasets.
- test progress: only active when testing; shows total progress over all test datasets.

For infinite datasets, the progress bar never ends.

If you want to customize the default tqdm progress bars used by Lightning, you can override specific methods of the callback class and pass your custom implementation to the Trainer:

Example:

class LitProgressBar(ProgressBar):

    def init_validation_tqdm(self):
        bar = super().init_validation_tqdm()
        bar.set_description('running validation ...')
        return bar

bar = LitProgressBar()
trainer = Trainer(callbacks=[bar])
Parameters:
- refresh_rate (int) – Determines at which rate (in number of batches) the progress bars get updated. Set it to 0 to disable the display. By default, the Trainer uses this implementation of the progress bar and sets the refresh rate to the value provided to the progress_bar_refresh_rate argument in the Trainer.
- process_position (int) – Set this to a value greater than 0 to offset the progress bars by this many lines. This is useful when you have progress bars defined elsewhere and want to show all of them together. This corresponds to process_position in the Trainer.
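For instance, to refresh the bar every 10 batches and offset it by one line (the values are illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBar

# update the display every 10 batches; shift it down one line so another
# progress bar can occupy position 0
bar = ProgressBar(refresh_rate=10, process_position=1)
trainer = Trainer(callbacks=[bar])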
- disable(): You should provide a way to disable the progress bar. The Trainer will call this to disable the output on processes that have a rank different from 0, e.g., in multi-node training. Return type: None
- enable(): You should provide a way to enable the progress bar. The Trainer will call this in e.g. pre-training routines like the learning rate finder to temporarily enable and disable the main progress bar. Return type: None
- init_sanity_tqdm(): Override this to customize the tqdm bar for the validation sanity run. Return type: tqdm
- init_test_tqdm(): Override this to customize the tqdm bar for testing. Return type: tqdm
- init_train_tqdm(): Override this to customize the tqdm bar for training. Return type: tqdm
- init_validation_tqdm(): Override this to customize the tqdm bar for validation. Return type: tqdm
- on_epoch_start(trainer, pl_module): Called when the epoch begins.
- on_sanity_check_end(trainer, pl_module): Called when the validation sanity check ends.
- on_sanity_check_start(trainer, pl_module): Called when the validation sanity check starts.
- on_test_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx): Called when the test batch ends.
- on_test_end(trainer, pl_module): Called when the test ends.
- on_test_start(trainer, pl_module): Called when the test begins.
- on_train_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx): Called when the train batch ends.
- on_train_end(trainer, pl_module): Called when the train ends.
- on_train_start(trainer, pl_module): Called when the train begins.
- on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx): Called when the validation batch ends.
- on_validation_end(trainer, pl_module): Called when the validation loop ends.
- on_validation_start(trainer, pl_module): Called when the validation loop begins.
class pytorch_lightning.callbacks.progress.ProgressBarBase
Bases: pytorch_lightning.callbacks.base.Callback

The base class for progress bars in Lightning. It is a Callback that keeps track of the batch progress in the Trainer. You should implement your highly custom progress bars with this as the base class.

Example:

import sys

class LitProgressBar(ProgressBarBase):

    def __init__(self):
        super().__init__()  # don't forget this :)
        self.enable = True

    def disable(self):
        self.enable = False

    def on_train_batch_end(self, trainer, pl_module):
        super().on_train_batch_end(trainer, pl_module)  # don't forget this :)
        percent = (self.train_batch_idx / self.total_train_batches) * 100
        sys.stdout.flush()
        sys.stdout.write(f'{percent:.01f} percent complete \r')

bar = LitProgressBar()
trainer = Trainer(callbacks=[bar])
- disable(): You should provide a way to disable the progress bar. The Trainer will call this to disable the output on processes that have a rank different from 0, e.g., in multi-node training.
- enable(): You should provide a way to enable the progress bar. The Trainer will call this in e.g. pre-training routines like the learning rate finder to temporarily enable and disable the main progress bar.
- on_epoch_start(trainer, pl_module): Called when the epoch begins.
- on_init_end(trainer): Called when the trainer initialization ends; the model has not yet been set.
- on_test_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx): Called when the test batch ends.
- on_test_start(trainer, pl_module): Called when the test begins.
- on_train_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx): Called when the train batch ends.
- on_train_start(trainer, pl_module): Called when the train begins.
- on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx): Called when the validation batch ends.
- on_validation_start(trainer, pl_module): Called when the validation loop begins.
- property test_batch_idx: The current batch index being processed during testing. Use this to update your progress bar. Return type: int
- property total_test_batches: The total number of test batches during testing, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the test dataloader is of infinite size. Return type: int
- property total_train_batches: The total number of training batches during training, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the training dataloader is of infinite size. Return type: int
- property total_val_batches: The total number of validation batches during validation, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the validation dataloader is of infinite size. Return type: int
- property train_batch_idx: The current batch index being processed during training. Use this to update your progress bar. Return type: int
- pytorch_lightning.callbacks.progress.convert_inf(x): tqdm does not support inf values, so infinite values are converted to None.
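A minimal sketch of what such a helper does; this is an illustrative reimplementation, not the library source:

import math

def convert_inf(x):
    # tqdm cannot render an infinite total, so map it to None,
    # which tqdm treats as "length unknown"
    if x is not None and math.isinf(x):
        return None
    return x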
Best Practices¶
The following are best practices when using/designing callbacks.
- Callbacks should be isolated in their functionality. Your callback should not rely on the behavior of other callbacks in order to work properly.
- Do not manually call methods from the callback; directly calling methods (e.g. on_validation_end) is strongly discouraged.
- Whenever possible, your callbacks should not depend on the order in which they are executed.
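As a hedged illustration of the first two points, a callback that keeps its own state instead of reading state from, or calling into, other callbacks:

from pytorch_lightning import Callback

class BatchCounter(Callback):

    def __init__(self):
        # self-contained state; no reliance on other callbacks
        self.batches_seen = 0

    def on_batch_end(self, trainer, pl_module):
        self.batches_seen += 1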