pytorch_lightning.callbacks.model_checkpoint module

Model Checkpointing
Automatically save model checkpoints during training.
class pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(filepath=None, monitor='val_loss', verbose=False, save_last=False, save_top_k=1, save_weights_only=False, mode='auto', period=1, prefix='')

Bases: pytorch_lightning.callbacks.base.Callback
Save the model after every epoch if it improves.
After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.

Parameters:
- filepath (Optional[str]) – path to save the model file. Can contain named formatting options to be auto-filled.

  Example:

      # custom path
      # saves a file like: my/path/epoch_0.ckpt
      >>> checkpoint_callback = ModelCheckpoint('my/path/')

      # save any arbitrary metrics like `val_loss`, etc. in name
      # saves a file like: my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
      >>> checkpoint_callback = ModelCheckpoint(
      ...     filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
      ... )
  By default, filepath is None and will be set at runtime to the location specified by the Trainer's default_root_dir or weights_save_path arguments. If the Trainer uses a logger, the path will also contain the logger name and version.

- monitor (str) – quantity to monitor. Default: 'val_loss'.

- verbose (bool) – verbosity mode. Default: False.

- save_last (bool) – always saves the model at the end of the epoch. Default: False.

- save_top_k (int) – if save_top_k == k, the best k models according to the quantity monitored will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Note that the monitored quantity is only checked every period epochs. If save_top_k >= 2 and the callback is called multiple times within a single epoch, the name of the saved file will be appended with a version count starting with v0.

- mode (str) – one of {auto, min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc this should be max, for val_loss this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity. A combined sketch follows this list.

- save_weights_only (bool) – if True, only the model's weights will be saved (model.save_weights(filepath)); otherwise the full model is saved (model.save(filepath)).

- period (int) – interval (number of epochs) between checkpoints.

- prefix (str) – string prefix prepended to checkpoint filenames. Default: ''.
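For instance, a configuration that keeps the three best checkpoints ranked by validation accuracy might look like the sketch below (this assumes your validation step logs a 'val_acc' metric; the path and metric name are illustrative):

    >>> checkpoint_callback = ModelCheckpoint(
    ...     filepath='my/path/{epoch}-{val_acc:.2f}',
    ...     monitor='val_acc',   # rank checkpoints by validation accuracy
    ...     mode='max',          # higher val_acc is better
    ...     save_top_k=3,        # keep only the 3 best checkpoints
    ...     period=1,            # check the monitored quantity every epoch
    ... )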
Example:
    >>> from pytorch_lightning import Trainer
    >>> from pytorch_lightning.callbacks import ModelCheckpoint

    # saves checkpoints to 'my/path/' whenever 'val_loss' has a new min
    >>> checkpoint_callback = ModelCheckpoint(filepath='my/path/')
    >>> trainer = Trainer(checkpoint_callback=checkpoint_callback)

    # save epoch and val_loss in name
    # saves a file like: my/path/sample-mnist_epoch=02_val_loss=0.32.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
    ... )

    # retrieve the best checkpoint after training
    checkpoint_callback = ModelCheckpoint(filepath='my/path/')
    trainer = Trainer(checkpoint_callback=checkpoint_callback)
    model = ...
    trainer.fit(model)
    checkpoint_callback.best_model_path
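Once training has finished, the best checkpoint can be loaded back for inference or further use. A minimal sketch, assuming MyModel is a hypothetical LightningModule subclass:

    # load the best-scoring checkpoint recorded by the callback
    model = MyModel.load_from_checkpoint(checkpoint_callback.best_model_path)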
format_checkpoint_name(epoch, metrics, ver=None)

Generate a filename according to the defined template.
Example:
    >>> tmpdir = os.path.dirname(__file__)
    >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
    >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
    'epoch=0.ckpt'
    >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
    >>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
    'epoch=005.ckpt'
    >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
    >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
    'epoch=2-val_loss=0.12.ckpt'
    >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
    >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
    'missing=0.ckpt'
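The name-filling semantics shown above can be approximated in a few lines of plain Python. This is an illustrative sketch of the behaviour, not the library's actual implementation (fill_template is a hypothetical helper):

    import re

    def fill_template(template, epoch, metrics):
        # find each named group, e.g. '{epoch' or '{val_loss' (format spec excluded)
        groups = re.findall(r'(\{.*?)[:\}]', template)
        metrics = dict(metrics, epoch=epoch)
        for group in groups:
            name = group[1:]
            # rewrite '{epoch}' -> 'epoch={epoch}' so the value is labelled in the name
            template = template.replace(group, name + '=' + group)
            # missing metrics fall back to 0, as in the '{missing:d}' example above
            metrics.setdefault(name, 0)
        return template.format(**metrics) + '.ckpt'

    assert fill_template('{epoch}-{val_loss:.2f}', 2, {'val_loss': 0.123456}) == 'epoch=2-val_loss=0.12.ckpt'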
on_train_start(trainer, pl_module)

Determines the model checkpoint save directory at runtime. References attributes from the trainer's logger to determine where to save checkpoints. The base path for saving weights is set in this priority:
1. The checkpoint callback's filepath (if passed in)
2. The default_root_dir from the trainer, if the trainer has no logger
3. The user-provided weights_save_path from the trainer
The base path is then extended with the logger name and version (if available) and the subfolder "checkpoints".
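A simplified sketch of this priority order (an illustration of the rules above, not the library's exact code; resolve_checkpoint_dir is a hypothetical helper):

    import os

    def resolve_checkpoint_dir(callback_filepath, trainer):
        # 1. an explicit filepath on the callback wins outright
        if callback_filepath is not None:
            return callback_filepath
        # 2. no logger: fall back to the trainer's default_root_dir
        if trainer.logger is None:
            return os.path.join(trainer.default_root_dir, 'checkpoints')
        # 3. otherwise prefer a user-provided weights_save_path over default_root_dir
        base = trainer.weights_save_path or trainer.default_root_dir
        # extend with logger name/version and the "checkpoints" subfolder
        return os.path.join(base, str(trainer.logger.name),
                            f'version_{trainer.logger.version}', 'checkpoints')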