model_checkpoint
Classes
ModelCheckpoint | Save the model after every epoch by monitoring a quantity.
Model Checkpointing
Automatically save model checkpoints during training.
class pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=None, save_weights_only=False, mode='auto', period=1, prefix='')
Bases: pytorch_lightning.callbacks.base.Callback
Save the model after every epoch by monitoring a quantity.
After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.
Parameters
dirpath (Union[str, Path, None]) – directory to save the model file.
Example:
    # custom path
    # saves a file like: my/path/epoch=0-step=10.ckpt
    >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
By default, dirpath is None and will be set at runtime to the location specified by Trainer's default_root_dir or weights_save_path arguments. If the Trainer uses a logger, the path will also contain the logger name and version.
filename (Optional[str]) – checkpoint filename. Can contain named formatting options to be auto-filled.
Example:
    # save any arbitrary metrics like `val_loss`, etc. in name
    # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     dirpath='my/path',
    ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
    ... )
By default, filename is None and will be set to '{epoch}-{step}'.
monitor (Optional[str]) – quantity to monitor. By default it is None, which saves a checkpoint only for the last epoch.
save_last (Optional[bool]) – When True, always saves the model at the end of the epoch to a file last.ckpt. Default: None.
save_top_k (Optional[int]) – if save_top_k == k, the best k models according to the monitored quantity will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Please note that the monitors are checked every period epochs. If save_top_k >= 2 and the callback is called multiple times inside an epoch, the name of the saved file will be appended with a version count starting with v1.
mode (str) – one of {auto, min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max; for val_loss, this should be min; etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity.
Warning
Setting mode='auto' has been deprecated in v1.1 and will be removed in v1.3.
save_weights_only (bool) – if True, then only the model's weights will be saved (model.save_weights(filepath)), else the full model is saved (model.save(filepath)).
period (int) – Interval (number of epochs) between checkpoints.
prefix (str) – A string to put at the beginning of checkpoint filename.
Warning
The prefix argument has been deprecated in v1.1 and will be removed in v1.3.
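The save_top_k and mode rules described in the parameter list can be read as a small retention policy. The following is a minimal sketch of that policy, not the library's implementation: sketch_top_k is a hypothetical helper, it assumes exactly one monitored score per epoch, and it ignores period and mode='auto'.

```python
def sketch_top_k(scores, k, mode="min"):
    """Return the epochs whose checkpoints would still be kept (hypothetical helper).

    scores: one monitored value per epoch (assumption: one check per epoch).
    k: same meaning as save_top_k (0 keeps nothing, -1 keeps everything).
    """
    if mode not in ("min", "max"):
        raise ValueError("this sketch only handles 'min' and 'max'")
    if k == 0:
        return []
    if k == -1:
        return list(range(len(scores)))

    kept = {}  # epoch -> score of the checkpoints currently kept
    for epoch, score in enumerate(scores):
        if len(kept) < k:
            kept[epoch] = score
            continue
        # Find the worst checkpoint among those currently kept.
        if mode == "min":
            worst = max(kept, key=kept.get)
            is_better = score < kept[worst]
        else:
            worst = min(kept, key=kept.get)
            is_better = score > kept[worst]
        # Replace the worst checkpoint only if the new score beats it.
        if is_better:
            del kept[worst]
            kept[epoch] = score
    return sorted(kept)
```

With the scores [0.9, 0.5, 0.7, 0.4], k=2 and mode='min', this sketch keeps epochs 1 and 3.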
Note
For extra customization, ModelCheckpoint includes the following attributes:
    CHECKPOINT_JOIN_CHAR = "-"
    CHECKPOINT_NAME_LAST = "last"
    FILE_EXTENSION = ".ckpt"
    STARTING_VERSION = 1
For example, you can change the default last checkpoint name by doing
    checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"
Raises
MisconfigurationException – If save_top_k is neither None nor greater than or equal to -1, if monitor is None and save_top_k is none of None, -1, and 0, or if mode is none of "min", "max", and "auto".
ValueError – If trainer.save_checkpoint is None.
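The three MisconfigurationException conditions are easier to read as explicit checks. The sketch below restates them in plain Python under stated assumptions: sketch_validate and SketchMisconfigurationError are hypothetical stand-ins, not the library's function or exception class, and the error messages are invented for illustration.

```python
class SketchMisconfigurationError(Exception):
    """Hypothetical stand-in for the library's MisconfigurationException."""


def sketch_validate(save_top_k=None, monitor=None, mode="auto"):
    """Restate the documented argument checks (hypothetical helper)."""
    # save_top_k must be None or an integer >= -1.
    if save_top_k is not None and save_top_k < -1:
        raise SketchMisconfigurationError("save_top_k must be None or >= -1")
    # Without a monitored quantity, only None, -1 and 0 make sense.
    if monitor is None and save_top_k not in (None, -1, 0):
        raise SketchMisconfigurationError("save_top_k needs a monitor")
    # mode must be one of the three documented values.
    if mode not in ("min", "max", "auto"):
        raise SketchMisconfigurationError("mode must be min, max or auto")
```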
Example:
    >>> from pytorch_lightning import Trainer
    >>> from pytorch_lightning.callbacks import ModelCheckpoint

    # saves checkpoints to 'my/path/' at every epoch
    >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
    >>> trainer = Trainer(callbacks=[checkpoint_callback])

    # save epoch and val_loss in name
    # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     monitor='val_loss',
    ...     dirpath='my/path/',
    ...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
    ... )

    # retrieve the best checkpoint after training
    checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
    trainer = Trainer(callbacks=[checkpoint_callback])
    model = ...
    trainer.fit(model)
    checkpoint_callback.best_model_path
file_exists(filepath, trainer)
Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state from diverging between ranks.
Return type: bool
format_checkpoint_name(epoch, step, metrics, ver=None)
Generate a filename according to the defined template.
Example:
    >>> import os
    >>> tmpdir = os.path.dirname(__file__)
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
    >>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
    'epoch=0.ckpt'
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
    >>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
    'epoch=005.ckpt'
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
    >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
    'epoch=2-val_loss=0.12.ckpt'
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
    >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
    'missing=0.ckpt'
    >>> ckpt = ModelCheckpoint(filename='{step}')
    >>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {}))
    'step=0.ckpt'
Return type: str
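The template expansion shown in the doctests (each {name} or {name:spec} field becomes name=value, and a metric that was not logged is filled with 0) can be approximated without the library. This is a minimal sketch that reproduces only the behaviour visible in those examples: sketch_checkpoint_name is a hypothetical helper, and it does not handle dirpath, prefix, or the ver argument.

```python
import re


def sketch_checkpoint_name(template, epoch, step, metrics, extension=".ckpt"):
    """Expand a filename template the way the doctests suggest (hypothetical helper)."""
    # epoch and step are always available, alongside any logged metrics.
    values = dict(metrics, epoch=epoch, step=step)

    def expand(match):
        name = match.group(1)
        spec = match.group(2) or ""
        # Assumption taken from the '{missing:d}' doctest: unknown names become 0.
        value = values.get(name, 0)
        return f"{name}={format(value, spec)}"

    # Match '{name}' and '{name:format_spec}' fields.
    return re.sub(r"\{(\w+)(?::([^}]*))?\}", expand, template) + extension


print(sketch_checkpoint_name("{epoch}-{val_loss:.2f}", 2, 3, {"val_loss": 0.123456}))
# prints: epoch=2-val_loss=0.12.ckpt
```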
on_load_checkpoint(callback_state)
Called when loading a model checkpoint; use to reload state.
on_pretrain_routine_start(trainer, pl_module)
When the pretrain routine starts, the checkpoint directory is built on the fly.
on_save_checkpoint(trainer, pl_module, checkpoint)
Called when saving a model checkpoint; use to persist state.
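The two hooks on_save_checkpoint and on_load_checkpoint form a save/restore pair for the callback's own state. The sketch below illustrates that round trip with a plain Python class. SketchCheckpointState is hypothetical, and the choice of which fields to persist (monitor, best_model_score, best_model_path) is an assumption made for illustration, not a statement of what the library stores.

```python
class SketchCheckpointState:
    """Hypothetical stand-in showing how callback state can be persisted."""

    def __init__(self, monitor=None):
        self.monitor = monitor
        self.best_model_path = ""
        self.best_model_score = None

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Return the state that should travel with the checkpoint (assumed fields).
        return {
            "monitor": self.monitor,
            "best_model_score": self.best_model_score,
            "best_model_path": self.best_model_path,
        }

    def on_load_checkpoint(self, callback_state):
        # Restore the fields from a previously saved state dictionary.
        self.best_model_score = callback_state["best_model_score"]
        self.best_model_path = callback_state["best_model_path"]


# Round trip: state saved by one instance is restored into a fresh one.
saved = SketchCheckpointState(monitor="val_loss")
saved.best_model_score = 0.32
saved.best_model_path = "my/path/epoch=2.ckpt"
state = saved.on_save_checkpoint(trainer=None, pl_module=None, checkpoint={})

restored = SketchCheckpointState(monitor="val_loss")
restored.on_load_checkpoint(state)
```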