pytorch_lightning.trainer.training_io module

Lightning can automate saving and loading checkpoints.

Checkpointing is enabled by default and saves to the current working directory. To change the checkpoint path, pass in:

Trainer(default_root_dir='/your/path/to/save/checkpoints')

To modify the behavior of checkpointing, pass in your own callback.

import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# DEFAULTS used by the Trainer
checkpoint_callback = ModelCheckpoint(
    filepath=os.getcwd(),
    save_top_k=1,
    verbose=True,
    monitor='val_loss',
    mode='min',
    prefix=''
)

trainer = Trainer(checkpoint_callback=checkpoint_callback)

Restoring training session

You might want to not only load a model but also continue training it. Use this method to restore the trainer state as well. Training will continue from the epoch and global step where you last left off. However, the dataloaders will start from the first batch again (if you shuffled, this shouldn't matter).

Lightning will restore the session if you pass a logger with the same version and there’s a saved checkpoint.

from pytorch_lightning import Trainer

trainer = Trainer(
    resume_from_checkpoint=PATH
)

# this fit call loads model weights and trainer state
# the trainer continues seamlessly from where you left off
# without having to do anything else.
trainer.fit(model)

The trainer restores:

  • global_step

  • current_epoch

  • All optimizers

  • All lr_schedulers

  • Model weights

You can even change the logic of your model, as long as the weights and the “architecture” of the system stay the same. If you add a layer, for instance, restoring might not work.
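
For example, one quick way to check compatibility before resuming is to compare the tensor names and shapes stored in the checkpoint against the current model. check_architecture_matches below is a hypothetical helper, not part of Lightning:

import torch

def check_architecture_matches(checkpoint_path, model):
    # Hypothetical helper: returns True if the checkpoint's weights line up
    # with the current model. If you added or removed a layer, the keys
    # (or shapes) will differ and restoring the weights will fail.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    saved = {k: v.shape for k, v in checkpoint['state_dict'].items()}
    current = {k: v.shape for k, v in model.state_dict().items()}
    return saved == current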

At a rough level, here’s what happens inside the Trainer (pytorch_lightning.trainer.training_io):

# restore training progress
self.global_step = checkpoint['global_step']
self.current_epoch = checkpoint['epoch']

# restore the optimizers
optimizer_states = checkpoint['optimizer_states']
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
    optimizer.load_state_dict(opt_state)

# restore the lr schedulers
lr_schedulers = checkpoint['lr_schedulers']
for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
    scheduler['scheduler'].load_state_dict(lrs_state)

# uses the model you passed into trainer
model.load_state_dict(checkpoint['state_dict'])
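
If you want to look inside a saved checkpoint yourself, a minimal sketch (assuming 'example.ckpt' is a placeholder path to a checkpoint Lightning has written) is to load it with torch.load and inspect the keys used above:

import torch

checkpoint = torch.load('example.ckpt', map_location='cpu')

# the keys read by the restore logic above
print(checkpoint['epoch'], checkpoint['global_step'])
print(list(checkpoint['state_dict'].keys()))   # model weights
print(len(checkpoint['optimizer_states']))     # one state per optimizer
print(len(checkpoint['lr_schedulers']))        # one state per scheduler
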
class pytorch_lightning.trainer.training_io.TrainerIOMixin[source]

Bases: abc.ABC

_atomic_save(checkpoint, filepath)[source]

Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.

This will create a temporary checkpoint with a suffix of .part, then copy it to the final location once saving is finished.

Parameters
  • checkpoint – The object to save. Built to be used with the dump_checkpoint method, but can deal with anything which torch.save accepts.

  • filepath (str) – The path to which the checkpoint will be saved. This points to the file that the checkpoint will be stored in.
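
As a rough sketch (not the library’s exact code), the pattern described above looks like this:

import os
import torch

def atomic_save_sketch(checkpoint, filepath):
    # write to a temporary '.part' file first so a crash mid-save never
    # leaves a truncated checkpoint at the final location
    tmp_path = filepath + '.part'
    torch.save(checkpoint, tmp_path)
    # move the finished file into place
    os.replace(tmp_path, filepath)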

dump_checkpoint(weights_only=False)[source]

Create a model checkpoint.

Parameters
  • weights_only (bool) – save only the model weights

Return type
  dict

Returns
  structured dictionary
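
For example, assuming trainer is a Trainer whose fit() has already been called, you could build the checkpoint dictionary in memory and inspect it (a sketch; the exact keys depend on your Lightning version):

# build the checkpoint dict without writing it to disk
checkpoint = trainer.dump_checkpoint(weights_only=False)
print(sorted(checkpoint.keys()))

# or write it to a file through the save_checkpoint helper listed further down
trainer.save_checkpoint('example.ckpt')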

get_model()[source]
hpc_load(folderpath, on_gpu)[source]
hpc_save(folderpath, logger)[source]
max_ckpt_in_folder(path, name_key='ckpt_')[source]
register_slurm_signal_handlers()[source]
restore(checkpoint_path, on_gpu)[source]

Restore training state from checkpoint. Also restores all training state like:

  • epoch

  • callbacks

  • schedulers

  • optimizer

restore_hpc_weights_if_needed(model)[source]

If there is a set of HPC weights, use them as a signal to restore the model.

restore_training_state(checkpoint)[source]

Restore trainer state from the given checkpoint. The model will also get its chance to update its state.

restore_weights(model)[source]

We attempt to restore weights in this order:

  1. HPC weights.

  2. If there are no HPC weights, restore the checkpoint_path weights.

  3. Otherwise, don’t restore any weights.
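
A rough sketch of that priority (has_hpc_weights is a hypothetical stand-in for the real HPC-weight detection, which lives inside this mixin):

def restore_weights_sketch(trainer, has_hpc_weights):
    # has_hpc_weights: hypothetical callable standing in for the real check
    if has_hpc_weights(trainer.weights_save_path):
        # 1. HPC weights take priority
        trainer.hpc_load(trainer.weights_save_path, trainer.on_gpu)
    elif trainer.resume_from_checkpoint is not None:
        # 2. otherwise restore from the explicit checkpoint path
        trainer.restore(trainer.resume_from_checkpoint, trainer.on_gpu)
    # 3. otherwise keep the model's freshly initialized weights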

save_checkpoint(filepath, weights_only=False)[source]
sig_handler(signum, frame)[source]
term_handler(signum, frame)[source]
accumulate_grad_batches: int = None[source]
checkpoint_callback: ... = None[source]
early_stop_callback: ... = None[source]
global_rank: int = None[source]
logger: LightningLoggerBase = None[source]
lr_schedulers: ... = None[source]
model: LightningModule = None[source]
num_training_batches: int = None[source]
on_gpu: bool = None[source]
on_tpu: bool = None[source]
optimizers: ... = None[source]
resume_from_checkpoint: ... = None[source]
root_gpu: ... = None[source]
scaler: ... = None[source]
use_amp: bool = None[source]
use_ddp: bool = None[source]
use_ddp2: bool = None[source]
use_horovod: bool = None[source]
weights_save_path: str = None[source]