
Model Checkpointing

Automatically save model checkpoints during training.

class pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=None, save_weights_only=False, mode='auto', period=1, prefix='')[source]

Bases: pytorch_lightning.callbacks.base.Callback

Save the model after every epoch by monitoring a quantity.

After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.

Parameters
  • dirpath (Union[str, Path, None]) –

    directory to save the model file.

    Example:

    # custom path
    # saves a file like: my/path/epoch=0-step=10.ckpt
    >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
    

    By default, dirpath is None and will be set at runtime to the location specified by the Trainer's default_root_dir or weights_save_path argument, and, if the Trainer uses a logger, the path will also contain the logger name and version.

  • filename (Optional[str]) –

    checkpoint filename. Can contain named formatting options to be auto-filled.

    Example:

    # save any arbitrary metrics like `val_loss`, etc. in name
    # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     dirpath='my/path',
    ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
    ... )
    

    By default, filename is None and will be set to '{epoch}-{step}'.

  • monitor (Optional[str]) – quantity to monitor. By default it is None, which saves a checkpoint only for the last epoch. A typical configuration that monitors a logged metric is sketched after this parameter list.

  • verbose (bool) – verbosity mode. Default: False.

  • save_last (Optional[bool]) – When True, always saves the model at the end of the epoch to a file last.ckpt. Default: None.

  • save_top_k (Optional[int]) – if save_top_k == k, the best k models according to the quantity monitored will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Please note that the monitored quantity is checked every period epochs. If save_top_k >= 2 and the callback is called multiple times within an epoch, the name of the saved file will be appended with a version count starting with v1.

  • mode (str) –

    one of {auto, min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max, for val_loss this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity.

    Warning

    Setting mode='auto' has been deprecated in v1.1 and will be removed in v1.3.

  • save_weights_only (bool) – if True, then only the model’s weights will be saved (model.save_weights(filepath)), else the full model is saved (model.save(filepath)).

  • period (int) – Interval (number of epochs) between checkpoints.

  • prefix (str) –

    A string to put at the beginning of checkpoint filename.

    Warning

    This argument has been deprecated in v1.1 and will be removed in v1.3
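
Putting several of these arguments together, the following is a minimal sketch of a typical configuration. It assumes the LightningModule logs a metric under the key 'val_loss' (via self.log) so that the callback has something to monitor; the key and the directory 'my/path/' are placeholders.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# keep the 3 best checkpoints ranked by the logged `val_loss`
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',                # lower val_loss is better
    save_top_k=3,
    save_last=True,            # additionally write `last.ckpt` each epoch
    dirpath='my/path/',
    filename='{epoch}-{val_loss:.2f}',
)
trainer = Trainer(callbacks=[checkpoint_callback])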

Note

For extra customization, ModelCheckpoint includes the following attributes:

  • CHECKPOINT_JOIN_CHAR = "-"

  • CHECKPOINT_NAME_LAST = "last"

  • FILE_EXTENSION = ".ckpt"

  • STARTING_VERSION = 1

For example, you can change the default last checkpoint name by doing checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"
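
Continuing that note, a minimal sketch of overriding the attribute on a callback instance (the resulting file name, e.g. something like epoch=4-last.ckpt, depends on the current epoch):

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(save_last=True)
# the file written by `save_last` now follows this template instead of `last.ckpt`
checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"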

Raises
  • MisconfigurationException – If save_top_k is neither None nor an integer >= -1; if monitor is None while save_top_k is not one of None, -1, or 0; or if mode is not one of "min", "max", or "auto".

  • ValueError – If trainer.save_checkpoint is None.

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' at every epoch
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
>>> trainer = Trainer(callbacks=[checkpoint_callback])

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val_loss',
...     dirpath='my/path/',
...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path
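
Continuing the example above, a short sketch of reusing the best checkpoint once training has finished (MyModel stands in for your own LightningModule subclass):

best_path = checkpoint_callback.best_model_path    # path to the best checkpoint
best_score = checkpoint_callback.best_model_score  # its monitored score

# reload the weights saved in that checkpoint
model = MyModel.load_from_checkpoint(best_path)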
file_exists(filepath, trainer)[source]

Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state from diverging between ranks.

Return type

bool

format_checkpoint_name(epoch, step, metrics, ver=None)[source]

Generate a filename according to the defined template.

Example:

>>> import os
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
>>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
'missing=0.ckpt'
>>> ckpt = ModelCheckpoint(filename='{step}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {}))
'step=0.ckpt'
Return type

str

on_load_checkpoint(callback_state)[source]

Called when loading a model checkpoint; use this hook to reload state.

Parameters

callback_state (Dict[str, Any]) – the callback state returned by on_save_checkpoint.

on_pretrain_routine_start(trainer, pl_module)[source]

When the pretrain routine starts, the checkpoint directory is built on the fly.

on_save_checkpoint(trainer, pl_module, checkpoint)[source]

Called when saving a model checkpoint; use this hook to persist state.

Parameters
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint (Dict[str, Any]) – the checkpoint dictionary that will be saved.

Return type

Dict[str, Any]

Returns

The callback state.

on_validation_end(trainer, pl_module)[source]

Checkpoints can be saved at the end of the validation loop.

save_checkpoint(trainer, pl_module)[source]

Performs the main logic around saving a checkpoint. This method runs on all ranks; it is the responsibility of self.save_function to handle correct behaviour in distributed training, i.e., saving only on rank 0.

to_yaml(filepath=None)[source]

Saves the best_k_models dict containing the checkpoint paths with the corresponding scores to a YAML file.
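
For example, a short sketch of dumping the tracked scores after training (the file name here is arbitrary):

# writes the {checkpoint_path: monitored_score} pairs to the given YAML file
checkpoint_callback.to_yaml('best_k_models.yaml')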