model_checkpoint

Classes

ModelCheckpoint: Save the model after every epoch by monitoring a quantity.

Model Checkpointing

Automatically save model checkpoints during training.
class pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=None, save_weights_only=False, mode='auto', period=1, prefix='')

Bases: pytorch_lightning.callbacks.base.Callback

Save the model after every epoch by monitoring a quantity.

After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.

Parameters
dirpath (Union[str, Path, None]) – directory to save the model file.

Example:

    # custom path
    # saves a file like: my/path/epoch=0-step=10.ckpt
    >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')

By default, dirpath is None and will be set at runtime to the location specified by Trainer’s default_root_dir or weights_save_path arguments, and if the Trainer uses a logger, the path will also contain logger name and version.

filename (Optional[str]) – checkpoint filename. Can contain named formatting options to be auto-filled.
Example:

    # save any arbitrary metrics like `val_loss`, etc. in name
    # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     dirpath='my/path',
    ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
    ... )
By default, filename is None and will be set to '{epoch}-{step}'.

monitor (Optional[str]) – quantity to monitor. By default it is None, which saves a checkpoint only for the last epoch.

save_last (Optional[bool]) – When True, always saves the model at the end of the epoch to a file last.ckpt. Default: None.
save_top_k (Optional[int]) – if save_top_k == k, the best k models according to the quantity monitored will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Please note that the monitors are checked every period epochs. If save_top_k >= 2 and the callback is called multiple times inside an epoch, the name of the saved file will be appended with a version count starting with v1.
mode (str) – one of {auto, min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc this should be max, for val_loss this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity. (See the combined configuration sketch after this parameter list.)

Warning

Setting mode='auto' has been deprecated in v1.1 and will be removed in v1.3.

save_weights_only (bool) – if True, then only the model’s weights will be saved (model.save_weights(filepath)), else the full model is saved (model.save(filepath)).

period (int) – Interval (number of epochs) between checkpoints.

prefix (str) – A string to put at the beginning of the checkpoint filename.
Warning
This argument has been deprecated in v1.1 and will be removed in v1.3.
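Putting the parameters above together, the sketch below (not part of the original example set) combines monitor, mode, save_top_k, save_last and a filename template; the 'val_loss' key is an assumed metric name that your LightningModule would have to log.

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Keep the 3 best checkpoints ranked by a logged `val_loss` metric (an
    # assumed metric name), always write last.ckpt, and embed epoch and
    # val_loss in the checkpoint filename.
    checkpoint_callback = ModelCheckpoint(
        dirpath='my/path/',
        filename='{epoch}-{val_loss:.2f}',
        monitor='val_loss',
        mode='min',      # lower val_loss is better
        save_top_k=3,    # keep only the 3 best checkpoints
        save_last=True,  # additionally write last.ckpt every epoch
    )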
Note
For extra customization, ModelCheckpoint includes the following attributes:
CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_NAME_LAST = "last"
FILE_EXTENSION = ".ckpt"
STARTING_VERSION = 1
For example, you can change the default last checkpoint name by doing
checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"
Raises

MisconfigurationException – If save_top_k is neither None nor greater than or equal to -1, if monitor is None and save_top_k is none of None, -1, and 0, or if mode is none of "min", "max", and "auto".

ValueError – If trainer.save_checkpoint is None.
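To illustrate the first condition (a hedged sketch, not from the original docs): requesting a top-k selection without a monitored quantity is rejected when the callback is constructed.

    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    try:
        # save_top_k=3 needs a quantity to rank checkpoints by, so combining it
        # with monitor=None raises MisconfigurationException.
        ModelCheckpoint(monitor=None, save_top_k=3)
    except MisconfigurationException as err:
        print(err)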
Example:
    >>> from pytorch_lightning import Trainer
    >>> from pytorch_lightning.callbacks import ModelCheckpoint

    # saves checkpoints to 'my/path/' at every epoch
    >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
    >>> trainer = Trainer(callbacks=[checkpoint_callback])

    # save epoch and val_loss in name
    # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     monitor='val_loss',
    ...     dirpath='my/path/',
    ...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
    ... )

    # retrieve the best checkpoint after training
    checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
    trainer = Trainer(callbacks=[checkpoint_callback])
    model = ...
    trainer.fit(model)
    checkpoint_callback.best_model_path
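Continuing the last snippet above, the recorded best checkpoint can be loaded back into a module once training has finished. A minimal sketch, assuming MyLightningModule is a placeholder for your own LightningModule subclass:

    # path to and score of the best checkpoint found during training
    best_path = checkpoint_callback.best_model_path
    best_score = checkpoint_callback.best_model_score

    # reload the best weights into a fresh instance of your module
    # (`MyLightningModule` is a placeholder, not part of this API)
    best_model = MyLightningModule.load_from_checkpoint(best_path)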
file_exists(filepath, trainer)

Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state from diverging between ranks.

Return type: bool
format_checkpoint_name(epoch, step, metrics, ver=None)

Generate a filename according to the defined template.
Example:
    >>> tmpdir = os.path.dirname(__file__)
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
    >>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
    'epoch=0.ckpt'
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
    >>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
    'epoch=005.ckpt'
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
    >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
    'epoch=2-val_loss=0.12.ckpt'
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
    >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
    'missing=0.ckpt'
    >>> ckpt = ModelCheckpoint(filename='{step}')
    >>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {}))
    'step=0.ckpt'
Return type: str
on_load_checkpoint(callback_state)

Called when loading a model checkpoint; use it to reload state.
on_pretrain_routine_start(trainer, pl_module)

When the pretrain routine starts, we build the checkpoint directory on the fly.
on_save_checkpoint(trainer, pl_module, checkpoint)

Called when saving a model checkpoint; use it to persist state.