model_checkpoint

Classes

ModelCheckpoint: Save the model after every epoch by monitoring a quantity.

Model Checkpointing

Automatically save model checkpoints during training.

class pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(filepath=None, monitor=None, verbose=False, save_last=None, save_top_k=None, save_weights_only=False, mode='auto', period=1, prefix='', dirpath=None, filename=None)[source]

Bases: pytorch_lightning.callbacks.base.Callback

Save the model after every epoch by monitoring a quantity.

After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.

Parameters
  • filepath (Optional[str]) –

    path to save the model file.

    Warning

    Deprecated since version 1.0.

    Use dirpath + filename instead. Will be removed in v1.2

  • monitor (Optional[str]) – quantity to monitor. By default it is None, which saves a checkpoint only for the last epoch.

  • verbose (bool) – verbosity mode. Default: False.

  • save_last (Optional[bool]) – When True, always saves the model at the end of the epoch to a file named last.ckpt. Default: None.

  • save_top_k (Optional[int]) – If save_top_k == k, the best k models according to the quantity monitored will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Note that the monitored quantity is checked every period epochs. If save_top_k >= 2 and the callback is called multiple times within a single epoch, the name of the saved file will be appended with a version count starting with v0.

  • mode (str) – one of {auto, min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max, for val_loss this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity.

  • save_weights_only (bool) – if True, then only the model’s weights will be saved (model.save_weights(filepath)), else the full model is saved (model.save(filepath)).

  • period (int) – Interval (number of epochs) between checkpoints.

  • dirpath (Union[str, Path, None]) –

    directory to save the model file.

    Example:

    # custom path
    # saves a file like: my/path/epoch=0.ckpt
    >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
    

    By default, dirpath is None and will be set at runtime to the location specified by Trainer’s default_root_dir or weights_save_path arguments; if the Trainer uses a logger, the path will also contain the logger name and version.

  • filename (Optional[str]) –

    checkpoint filename. Can contain named formatting options to be auto-filled.

    Example:

    # save any arbitrary metrics like `val_loss`, etc. in name
    # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     dirpath='my/path',
    ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
    ... )
    

    By default, filename is None and will be set to '{epoch}'.

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' at every epoch
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
>>> trainer = Trainer(callbacks=[checkpoint_callback])

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val_loss',
...     dirpath='my/path/',
...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path
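
A combined sketch of the options above. MyModel is a hypothetical LightningModule used only for illustration, and 'val_loss' is assumed to be logged during validation:

# keep the 3 best checkpoints ranked by val_loss, and also always save last.ckpt
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,
    dirpath='my/path/',
    filename='sample-{epoch:02d}-{val_loss:.2f}'
)
trainer = Trainer(callbacks=[checkpoint_callback])
model = MyModel()  # hypothetical LightningModule that logs 'val_loss'
trainer.fit(model)

# best_model_path and best_model_score are populated once training finishes
best_model = MyModel.load_from_checkpoint(checkpoint_callback.best_model_path)
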
format_checkpoint_name(epoch, metrics, ver=None)[source]

Generate a filename according to the defined template.

Example:

>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'missing=0.ckpt'
>>> ckpt = ModelCheckpoint(filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'epoch=0.ckpt'
Return type

str

on_load_checkpoint(checkpointed_state)[source]

Called when loading a model checkpoint; use this hook to reload state.

on_pretrain_routine_start(trainer, pl_module)[source]

When the pretrain routine starts, the checkpoint directory is built on the fly.

on_save_checkpoint(trainer, pl_module)[source]

Called when saving a model checkpoint; use this hook to persist state.

Return type

Dict[str, Any]
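
Together with on_load_checkpoint() above, this hook round-trips the callback's internal state through the checkpoint file. A minimal sketch of the pairing, using only the signatures documented here (the contents of the returned dict are an internal detail and not shown):

# illustrative only: the Trainer normally calls these hooks for you
state = checkpoint_callback.on_save_checkpoint(trainer, model)  # Dict[str, Any]
# ... the Trainer embeds `state` in the checkpoint it writes to disk ...
checkpoint_callback.on_load_checkpoint(state)  # restores the state when resuming
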

on_validation_end(trainer, pl_module)[source]

Checkpoints can be saved at the end of the validation loop.

save_checkpoint(trainer, pl_module)[source]

Performs the main logic around saving a checkpoint. This method runs on all ranks; it is the responsibility of self.save_function to handle correct behaviour in distributed training, i.e., saving only on rank 0.
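
A minimal sketch of the rank-0 guard described above, assuming a hypothetical custom save function (my_save_function and its signature are illustrative assumptions; by default the Trainer assigns save_function itself):

from pytorch_lightning.utilities import rank_zero_only

@rank_zero_only
def my_save_function(trainer, filepath, weights_only=False):
    # only the process with global rank 0 executes this body;
    # all other ranks return immediately
    trainer.save_checkpoint(filepath, weights_only)
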

to_yaml(filepath=None)[source]

Saves the best_k_models dict containing the checkpoint paths with the corresponding scores to a YAML file.
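
A short usage sketch; the output filename is arbitrary:

# writes the {checkpoint path: monitored score} mapping to a YAML file
checkpoint_callback.to_yaml('best_k_models.yaml')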