Saving and loading weights¶
Lightning can automate saving and loading checkpoints.
Checkpoint saving¶
A Lightning checkpoint has everything needed to restore a training session including:
16-bit scaling factor (apex)
Current epoch
Global step
Model state_dict
State of all optimizers
State of all learningRate schedulers
State of all callbacks
The hyperparameters used for that model if passed in as hparams (Argparse.Namespace)
Automatic saving¶
Checkpointing is enabled by default to the current working directory. To change the checkpoint path pass in:
trainer = Trainer(default_root_dir='/your/path/to/save/checkpoints')
To modify the behavior of checkpointing pass in your own callback.
from pytorch_lightning.callbacks import ModelCheckpoint
# DEFAULTS used by the Trainer
checkpoint_callback = ModelCheckpoint(
filepath=os.getcwd(),
save_top_k=1,
verbose=True,
monitor='checkpoint_on',
mode='min',
prefix=''
)
trainer = Trainer(checkpoint_callback=checkpoint_callback)
Or disable it by passing
trainer = Trainer(checkpoint_callback=False)
The Lightning checkpoint also saves the arguments passed into the LightningModule init under the module_arguments key in the checkpoint.
class MyLightningModule(LightningModule):
def __init__(self, learning_rate, *args, **kwargs):
super().__init__()
# all init args were saved to the checkpoint
checkpoint = torch.load(CKPT_PATH)
print(checkpoint['module_arguments'])
# {'learning_rate': the_value}
Manual saving¶
You can manually save checkpoints and restore your model from the checkpointed state.
model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")
Checkpoint Loading¶
To load a model along with its weights, biases and module_arguments use following method.
model = MyLightingModule.load_from_checkpoint(PATH)
print(model.learning_rate)
# prints the learning_rate you used in this checkpoint
model.eval()
y_hat = model(x)
But if you don’t want to use the values saved in the checkpoint, pass in your own here
class LitModel(LightningModule):
def __init__(self, in_dim, out_dim):
super().__init__()
self.save_hyperparameters()
self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)
you can restore the model like this
# if you train and save the model like this it will use these values when loading
# the weights. But you can overwrite this
LitModel(in_dim=32, out_dim=10)
# uses in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)
# uses in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
Restoring Training State¶
If you don’t just want to load weights, but instead restore the full training, do the following:
model = LitModel()
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')
# automatically restores model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model)