Saving and loading weights
Lightning can automate saving and loading checkpoints.
Checkpoint saving
A Lightning checkpoint contains everything needed to restore a training session, including:
16-bit scaling factor (apex)
Current epoch
Global step
Model state_dict
State of all optimizers
State of all learning rate schedulers
State of all callbacks
The hyperparameters used for that model, if passed in as hparams (argparse.Namespace)
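Conceptually, a checkpoint bundles the items above into one serializable dictionary. The plain-Python sketch below illustrates that idea; the key names and the hypothetical build_checkpoint helper are assumptions chosen for illustration, not a description of Lightning's actual checkpoint format (which varies between versions), and pickle stands in for the real on-disk serialization.

```python
import io
import pickle
from argparse import Namespace


# Hypothetical helper and key names, for illustration only; real Lightning
# checkpoints may use different keys depending on the version.
def build_checkpoint(epoch, global_step, model_state, optimizer_states,
                     scheduler_states, callback_states, hparams=None,
                     amp_scaling_state=None):
    """Bundle everything needed to resume training into one dict."""
    checkpoint = {
        "epoch": epoch,
        "global_step": global_step,
        "state_dict": model_state,
        "optimizer_states": optimizer_states,
        "lr_schedulers": scheduler_states,
        "callbacks": callback_states,
    }
    if amp_scaling_state is not None:
        # Only present when 16-bit training is used.
        checkpoint["amp_scaling_state"] = amp_scaling_state
    if hparams is not None:
        # Store hparams as a plain dict so it can be rebuilt later.
        checkpoint["hparams"] = vars(hparams)
    return checkpoint


ckpt = build_checkpoint(
    epoch=3,
    global_step=1200,
    model_state={"l1.weight": [[0.1, 0.2]], "l1.bias": [0.0]},  # lists stand in for tensors
    optimizer_states=[{"lr": 0.001}],
    scheduler_states=[{"last_epoch": 3}],
    callback_states={},
    hparams=Namespace(learning_rate=0.001),
)

# Round-trip through bytes, as writing to and reading from a file would.
buffer = io.BytesIO()
pickle.dump(ckpt, buffer)
buffer.seek(0)
restored = pickle.load(buffer)
print(sorted(restored))
```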
Automatic saving
Checkpointing is enabled by default, and checkpoints are saved to the current working directory. To change the checkpoint path, pass in:
trainer = Trainer(default_save_path='/your/path/to/save/checkpoints')
To modify the checkpointing behavior, pass in your own ModelCheckpoint callback:
import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# DEFAULTS used by the Trainer
checkpoint_callback = ModelCheckpoint(
    filepath=os.getcwd(),
    save_top_k=1,  # an int: how many of the best checkpoints to keep
    verbose=True,
    monitor='val_loss',
    mode='min',
    prefix=''
)

trainer = Trainer(checkpoint_callback=checkpoint_callback)
Or disable it by passing:
trainer = Trainer(checkpoint_callback=False)
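The save_top_k and mode arguments of the callback decide which checkpoints are kept. As a rough sketch (a simplified, hypothetical TopKTracker class, not Lightning's implementation), the rule with mode='min' is: keep the k checkpoints with the lowest monitored value, and evict the worst one whenever a better value arrives. The special values 0 (save nothing) and -1 (save everything) follow how ModelCheckpoint's documentation describes save_top_k, but check the docs for your version.

```python
class TopKTracker:
    """Simplified illustration of save_top_k with mode='min' or 'max'.

    Not Lightning's implementation; it only shows which checkpoints
    survive when each epoch reports a monitored value.
    """

    def __init__(self, save_top_k=1, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.save_top_k = save_top_k
        self.mode = mode
        self.kept = {}  # checkpoint path -> monitored value

    def _is_better(self, new, old):
        return new < old if self.mode == "min" else new > old

    def _worst_path(self):
        # Worst kept checkpoint: highest value for 'min', lowest for 'max'.
        pick = max if self.mode == "min" else min
        return pick(self.kept, key=self.kept.get)

    def update(self, path, value):
        """Return True if `path` should be saved, evicting the worst if full."""
        if self.save_top_k == 0:
            return False
        if self.save_top_k == -1 or len(self.kept) < self.save_top_k:
            self.kept[path] = value
            return True
        worst = self._worst_path()
        if self._is_better(value, self.kept[worst]):
            del self.kept[worst]  # a real callback would delete the file here
            self.kept[path] = value
            return True
        return False


tracker = TopKTracker(save_top_k=2, mode="min")
for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.5]):
    tracker.update(f"epoch={epoch}.ckpt", val_loss)
print(sorted(tracker.kept))
```

With save_top_k=2 and mode='min', the two epochs with the lowest val_loss (epochs 1 and 3) are the ones left on disk.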
The Lightning checkpoint also saves the hparams (hyperparameters) passed into the LightningModule's __init__.
Note: hparams is an argparse.Namespace.
from argparse import Namespace

from pytorch_lightning import LightningModule

# usually these come from command line args
args = Namespace(learning_rate=0.001)

# define your module to have hparams as the first arg
# this means your checkpoint will have everything that went into making
# this model (in this case, learning rate)
class MyLightningModule(LightningModule):
    def __init__(self, hparams, *args, **kwargs):
        super().__init__()
        self.hparams = hparams
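A Namespace converts trivially to a plain dictionary and back, which is what makes hyperparameters easy to store next to the weights. A minimal standard-library sketch; the --batch_size flag is a hypothetical addition, and the exact way Lightning serializes hparams inside the checkpoint is not shown here.

```python
from argparse import ArgumentParser, Namespace

# Hypothetical command-line flags, parsed from an explicit list so the
# example runs without a real command line.
parser = ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=0.001)
parser.add_argument("--batch_size", type=int, default=32)
args = parser.parse_args(["--learning_rate", "0.01"])

# A Namespace converts to a plain dict, which is easy to store
# alongside the weights...
hparams_dict = dict(vars(args))

# ...and the dict converts back into an equal Namespace when loading.
restored = Namespace(**hparams_dict)
print(restored.learning_rate, restored.batch_size)
```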
Manual saving
You can manually save checkpoints and restore your model from the checkpointed state.
model = MyLightningModule(args)
trainer = Trainer()
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyLightningModule.load_from_checkpoint(checkpoint_path="example.ckpt")
Checkpoint loading
To load a model along with its weights, biases, and hyperparameters, use the following method:
model = MyLightningModule.load_from_checkpoint(PATH)
model.eval()
# x is a batch of input data
y_hat = model(x)
The above only works if you used hparams in your model definition:
from torch import nn

class LitModel(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.l1 = nn.Linear(hparams.in_dim, hparams.out_dim)
If your model does not use hparams and instead takes individual parameters:
class LitModel(LightningModule):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.l1 = nn.Linear(in_dim, out_dim)
Then pass those same parameters to load_from_checkpoint as keyword arguments when restoring it:
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
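The hparams pattern needs no extra arguments because the checkpoint carries the constructor inputs along with the weights. The sketch below mimics that mechanism with toy classes and a hypothetical load_from_checkpoint function (plain Python lists stand in for tensors). It is not Lightning's implementation, only the shape of the idea: rebuild the object from stored hparams when they exist, otherwise from caller-supplied keyword arguments, and then load the saved state.

```python
from argparse import Namespace


class _ToyStateMixin:
    """Minimal state_dict / load_state_dict pair, named after nn.Module's."""

    def state_dict(self):
        return {"weights": list(self.weights)}

    def load_state_dict(self, state):
        self.weights = list(state["weights"])


class HparamsModel(_ToyStateMixin):
    """Toy model configured through a single hparams Namespace."""

    def __init__(self, hparams):
        self.hparams = hparams
        self.weights = [0.0] * hparams.in_dim


class ArgsModel(_ToyStateMixin):
    """Toy model configured through individual constructor arguments."""

    def __init__(self, in_dim, out_dim):
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.weights = [0.0] * in_dim


def load_from_checkpoint(cls, checkpoint, **kwargs):
    """Hypothetical sketch: rebuild `cls`, then copy the saved weights in."""
    if "hparams" in checkpoint:
        # The constructor inputs were saved, so nothing else is needed.
        model = cls(Namespace(**checkpoint["hparams"]))
    else:
        # Nothing describes the constructor; the caller must supply it.
        model = cls(**kwargs)
    model.load_state_dict(checkpoint["state_dict"])
    return model


# Model that uses hparams: the checkpoint alone is enough.
trained = HparamsModel(Namespace(in_dim=3, out_dim=1))
trained.weights = [0.5, -1.0, 2.0]
ckpt = {"hparams": vars(trained.hparams), "state_dict": trained.state_dict()}
restored = load_from_checkpoint(HparamsModel, ckpt)

# Model without hparams: the constructor arguments must be passed again.
plain = ArgsModel(in_dim=3, out_dim=1)
plain.weights = [1.0, 2.0, 3.0]
plain_ckpt = {"state_dict": plain.state_dict()}
restored_plain = load_from_checkpoint(ArgsModel, plain_ckpt, in_dim=3, out_dim=1)
```

Calling load_from_checkpoint(ArgsModel, plain_ckpt) without in_dim and out_dim raises a TypeError, which is the toy equivalent of the restriction described above.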
Restoring training state
If you want to restore the full training session rather than just the weights, do the following:
model = LitModel(in_dim=128, out_dim=10)
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')
# automatically restores model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model)
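Restoring the training state means the loop continues from the saved epoch and step counters instead of starting at zero. A toy, framework-free sketch of that bookkeeping; the train function is hypothetical, and it assumes the checkpoint records the last completed epoch (Lightning's exact convention may differ between versions).

```python
def train(num_epochs, steps_per_epoch, checkpoint=None):
    """Toy loop that returns a checkpoint dict for the last finished epoch.

    Assumption for this sketch: checkpoint["epoch"] is the last epoch that
    completed, so training resumes at the epoch after it.
    """
    start_epoch, global_step = 0, 0
    if checkpoint is not None:
        start_epoch = checkpoint["epoch"] + 1
        global_step = checkpoint["global_step"]

    last = checkpoint
    for epoch in range(start_epoch, num_epochs):
        for _ in range(steps_per_epoch):
            global_step += 1  # stands in for one optimizer step
        last = {"epoch": epoch, "global_step": global_step}
    return last


# Train for 2 epochs, "stop", then resume up to 5 epochs in total.
interrupted = train(num_epochs=2, steps_per_epoch=10)
resumed = train(num_epochs=5, steps_per_epoch=10, checkpoint=interrupted)

# For comparison: one uninterrupted 5-epoch run.
uninterrupted = train(num_epochs=5, steps_per_epoch=10)
```

Resuming after two epochs ends with the same epoch and step counters as a single five-epoch run, which is the property that resuming from a checkpoint is meant to preserve.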