ModelCheckpoint

class pytorch_lightning.callbacks.ModelCheckpoint(filepath=None, monitor=None, verbose=False, save_last=None, save_top_k=None, save_weights_only=False, mode='auto', period=1, prefix='')[source]

Bases: pytorch_lightning.callbacks.base.Callback

Save the model after every epoch by monitoring a quantity.

After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.

Parameters
  • filepath (Optional[str]) –

    path to save the model file. Can contain named formatting options to be auto-filled.

    Example:

    # custom path
    # saves a file like: my/path/epoch=0.ckpt
    >>> checkpoint_callback = ModelCheckpoint('my/path/')
    
    # save any arbitrary metrics like `val_loss`, etc. in name
    # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
    ... )
    

    By default, filepath is None and will be set at runtime to the location specified by Trainer’s default_root_dir or weights_save_path arguments, and if the Trainer uses a logger, the path will also contain logger name and version.

  • monitor (Optional[str]) – quantity to monitor. By default it is None, which saves a checkpoint only for the last epoch (a sketch combining this with save_top_k and mode follows this parameter list).

  • verbose (bool) – verbosity mode. Default: False.

  • save_last (Optional[bool]) – When True, always saves the model at the end of the epoch to a file last.ckpt. Default: None.

  • save_top_k (Optional[int]) – if save_top_k == k, the best k models according to the quantity monitored will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Note that the monitored quantity is checked every period epochs. If save_top_k >= 2 and the callback is called multiple times within an epoch, the name of the saved file will be appended with a version count starting with v0.

  • mode (str) – one of {auto, min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc this should be max; for val_loss it should be min. In auto mode, the direction is automatically inferred from the name of the monitored quantity.

  • save_weights_only (bool) – if True, then only the model’s weights will be saved (model.save_weights(filepath)), else the full model is saved (model.save(filepath)).

  • period (int) – Interval (number of epochs) between checkpoints.
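
A minimal sketch tying these options together (the hyperparameter values are illustrative, not defaults). It assumes the LightningModule logs the monitored quantity, e.g. self.log('val_loss', loss) in validation_step:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# keep the 3 checkpoints with the lowest val_loss, refresh last.ckpt
# every epoch, and check the monitored value once per epoch
checkpoint_callback = ModelCheckpoint(
    filepath='my/path/{epoch}-{val_loss:.2f}',
    monitor='val_loss',   # must be logged by the LightningModule
    save_top_k=3,
    mode='min',           # lower val_loss is better
    save_last=True,
    period=1,
)
trainer = Trainer(checkpoint_callback=checkpoint_callback)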

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' at every epoch
>>> checkpoint_callback = ModelCheckpoint(filepath='my/path/')
>>> trainer = Trainer(checkpoint_callback=checkpoint_callback)

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val_loss',
...     filepath='my/path/sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(filepath='my/path/')
trainer = Trainer(checkpoint_callback=checkpoint_callback)
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path
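
The stored path can be fed straight back into the LightningModule. A hedged sketch, where MyModel stands in for your LightningModule subclass (hypothetical name):

# load the weights tracked as best by the callback
best_model = MyModel.load_from_checkpoint(checkpoint_callback.best_model_path)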

format_checkpoint_name(epoch, metrics, ver=None)[source]

Generate a filename according to the defined template.

Example:

>>> import os
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'missing=0.ckpt'

Return type

str

on_load_checkpoint(checkpointed_state)[source]

Called when loading a model checkpoint, use to reload state.

on_pretrain_routine_start(trainer, pl_module)[source]

When the pretrain routine starts, the checkpoint directory is built on the fly.

on_save_checkpoint(trainer, pl_module)[source]

Called when saving a model checkpoint, use to persist state.

Return type

Dict[str, Any]
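
Together with on_load_checkpoint, this hook round-trips the callback's bookkeeping through the checkpoint file, so a resumed run keeps tracking the same best model. A rough sketch of the flow (the exact keys in the state dict are an assumption, not a documented contract):

# during Trainer.save_checkpoint the callback contributes its state ...
state = checkpoint_callback.on_save_checkpoint(trainer, pl_module)
# ... and on resume the Trainer hands it back, restoring e.g.
# best_model_path and best_model_score (assumed keys)
checkpoint_callback.on_load_checkpoint(state)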

on_validation_end(trainer, pl_module)[source]

Checkpoints can be saved at the end of the validation loop.

save_checkpoint(trainer, pl_module)[source]

Performs the main logic around saving a checkpoint. This method runs on all ranks; it is the responsibility of self.save_function to handle correct behaviour in distributed training, i.e., saving only on rank 0.
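
A hedged illustration of that contract using Lightning's rank_zero_only decorator (the save_function signature shown is an assumption; by default the Trainer wires its own save_checkpoint in here):

from pytorch_lightning.utilities import rank_zero_only

# trainer / checkpoint_callback: the instances from the example above.
# Hypothetical replacement save_function: the decorator turns the call
# into a no-op on every rank except 0.
@rank_zero_only
def save_fn(filepath, weights_only=False):
    trainer.save_checkpoint(filepath, weights_only)

checkpoint_callback.save_function = save_fn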

to_yaml(filepath=None)[source]

Saves the best_k_models dict containing the checkpoint paths with the corresponding scores to a YAML file.
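
A small usage sketch (the file name is arbitrary): dump the tracked mapping after training and read it back with PyYAML:

import yaml

# write the {checkpoint_path: monitored_score} mapping ...
checkpoint_callback.to_yaml('best_k_models.yaml')

# ... and read it back
with open('best_k_models.yaml') as f:
    best_k = yaml.safe_load(f)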
