pytorch_lightning.callbacks.model_checkpoint module

Model Checkpointing

Automatically save model checkpoints during training.

class pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(filepath=None, monitor='val_loss', verbose=False, save_last=False, save_top_k=1, save_weights_only=False, mode='auto', period=1, prefix='')[source]

Bases: pytorch_lightning.callbacks.base.Callback

Save the model after every epoch if the monitored quantity improves.

After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.

Parameters
  • filepath (Optional[str]) –

    path to save the model file. Can contain named formatting options to be auto-filled.

    Example:

    # custom path
    # saves a file like: my/path/epoch=0.ckpt
    >>> checkpoint_callback = ModelCheckpoint('my/path/')
    
    # save any arbitrary metrics like `val_loss`, etc. in name
    # saves a file like: my/path/epoch=2-val_loss=0.20-other_metric=0.30.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
    ... )
    

    Can also be set to None, in which case it is set to the default location during trainer construction.

  • monitor (str) – quantity to monitor.

  • verbose (bool) – verbosity mode. Default: False.

  • save_last (bool) – always saves the model at the end of the epoch. Default: False.

  • save_top_k (int) – if save_top_k == k, the best k models according to the monitored quantity are saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Note that the monitored quantity is checked every period epochs. If save_top_k >= 2 and the callback is called multiple times within an epoch, a version suffix starting at v0 is appended to the name of the saved file.

  • mode (str) – one of {auto, min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max, for val_loss this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity.

  • save_weights_only (bool) – if True, only the model’s weights are saved; otherwise the full checkpoint (including, for example, optimizer state) is saved.

  • period (int) – Interval (number of epochs) between checkpoints.
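
The interaction between save_top_k and mode described above can be sketched in plain Python. This is a minimal illustration, not the callback's actual implementation: TopKTracker and resolve_mode are hypothetical names, and the rule used for 'auto' (names containing 'acc' are maximized, everything else is minimized) is an assumption chosen to match the val_acc / val_loss guidance in the parameter description.

```python
def resolve_mode(monitor, mode="auto"):
    """Return 'min' or 'max' for a monitored quantity."""
    if mode in ("min", "max"):
        return mode
    if mode != "auto":
        raise ValueError(f"mode must be 'auto', 'min' or 'max', got {mode!r}")
    # Assumed heuristic: accuracy-like names are maximized, others minimized.
    return "max" if "acc" in monitor else "min"


class TopKTracker:
    """Hypothetical helper that mimics a top-k checkpoint retention policy."""

    def __init__(self, save_top_k=1, mode="min"):
        self.save_top_k = save_top_k
        self.mode = mode
        self.best_k = {}  # checkpoint path -> monitored score

    def _worst_path(self):
        # In 'min' mode the worst kept checkpoint has the highest score.
        pick = max if self.mode == "min" else min
        return pick(self.best_k, key=self.best_k.get)

    def should_save(self, score):
        if self.save_top_k == 0:
            return False  # never save
        if self.save_top_k == -1 or len(self.best_k) < self.save_top_k:
            return True  # save everything, or the top-k slots are not full yet
        worst = self.best_k[self._worst_path()]
        return score < worst if self.mode == "min" else score > worst

    def update(self, path, score):
        """Return (saved, evicted_path) for a new candidate checkpoint."""
        if not self.should_save(score):
            return False, None
        evicted = None
        if self.save_top_k > 0 and len(self.best_k) == self.save_top_k:
            evicted = self._worst_path()
            del self.best_k[evicted]
        self.best_k[path] = score
        return True, evicted

    @property
    def best_path(self):
        pick = min if self.mode == "min" else max
        return pick(self.best_k, key=self.best_k.get)


# Keep the two checkpoints with the lowest validation loss.
tracker = TopKTracker(save_top_k=2, mode=resolve_mode("val_loss"))
losses = [0.9, 0.5, 0.7, 0.4]
history = [tracker.update(f"epoch={e}.ckpt", loss) for e, loss in enumerate(losses)]
```

With these losses, the checkpoints from epochs 1 and 3 remain at the end, and the ones from epochs 0 and 2 are evicted as better scores arrive.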

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' whenever 'val_loss' has a new min
>>> checkpoint_callback = ModelCheckpoint(filepath='my/path/')
>>> trainer = Trainer(checkpoint_callback=checkpoint_callback)

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist_epoch=02-val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(filepath='my/path/')
trainer = Trainer(checkpoint_callback=checkpoint_callback)
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path

_del_model(filepath)[source]

_do_check_save(filepath, current, epoch)[source]

_save_model(filepath)[source]

check_monitor_top_k(current)[source]

format_checkpoint_name(epoch, metrics, ver=None)[source]

Generate a filename according to the defined template.

Example:

>>> import os
>>> from pytorch_lightning.callbacks import ModelCheckpoint
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'missing=0.ckpt'
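
The expansion shown in the doctests above, where each `{key:spec}` field becomes `key=value` and a missing metric falls back to 0, can be reproduced with a short standalone function. This is a sketch only: format_name is a hypothetical helper written to match the outputs above, not the library's own code, and it ignores the ver argument and the prefix option.

```python
import re


def format_name(template, epoch, metrics):
    """Expand '{key:spec}' fields into 'key=value' and append '.ckpt'."""
    values = dict(metrics, epoch=epoch)
    # Any field named in the template but absent from the metrics defaults to 0.
    for key in re.findall(r"\{(\w+)", template):
        values.setdefault(key, 0)
    # Prefix each field with its own name: '{epoch:03d}' -> 'epoch={epoch:03d}'.
    labelled = re.sub(r"\{(\w+)", r"\1={\1", template)
    return labelled.format(**values) + ".ckpt"


name = format_name("{epoch}-{val_loss:.2f}", 2, {"val_loss": 0.123456})
```

Because the field name is injected before standard str.format runs, any valid format spec (such as `03d` or `.2f`) carries through unchanged.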
on_train_start(trainer, pl_module)[source]

Determine model checkpoint save directory at runtime. References attributes from the Trainer’s logger to determine where to save checkpoints.
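
One way such a runtime lookup could work is sketched below. Everything in it is an assumption made for illustration: the function name resolve_checkpoint_dir, the logger attributes save_dir, name and version, and the `<save_dir>/<name>/version_<n>/checkpoints` layout are not confirmed by this page.

```python
import os
from types import SimpleNamespace


def resolve_checkpoint_dir(dirpath, logger, default_root):
    """Pick a checkpoint directory (hypothetical logic, for illustration)."""
    if dirpath is not None:
        return dirpath  # an explicit path always wins
    if logger is None:
        return os.path.join(default_root, "checkpoints")
    # Assumed attribute names on the logger object.
    version = logger.version
    if isinstance(version, int):
        version = f"version_{version}"
    return os.path.join(logger.save_dir, logger.name, version, "checkpoints")


# Stand-in for a logger; a real logger object would be supplied by the Trainer.
fake_logger = SimpleNamespace(save_dir="logs", name="default", version=3)
ckpt_dir = resolve_checkpoint_dir(None, fake_logger, default_root=".")
```

Deferring this decision until training starts lets the directory depend on values, such as the logger version, that are only known once the Trainer is fully set up.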

on_validation_end(trainer, pl_module)[source]

Called when the validation loop ends.

property best[source]
property kth_best_model[source]