Saving and loading weights

Lightning can automate saving and loading checkpoints.

Checkpoint saving

A Lightning checkpoint has everything needed to restore a training session (a sketch for inspecting these entries follows the list), including:

  • 16-bit scaling factor (apex)

  • Current epoch

  • Global step

  • Model state_dict

  • State of all optimizers

  • State of all learning rate schedulers

  • State of all callbacks

  • The hyperparameters used for that model, if passed in as hparams (argparse.Namespace)
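Because a checkpoint is an ordinary dictionary serialized with torch.save, you can open it with torch.load and inspect these entries directly. A minimal sketch, assuming a checkpoint file such as example.ckpt already exists on disk (the exact key names can vary between Lightning versions):

import torch

# a Lightning checkpoint is a plain dictionary saved with torch.save
checkpoint = torch.load("example.ckpt", map_location="cpu")

# top-level entries such as the epoch, global step, optimizer states, ...
print(checkpoint.keys())

# the model weights are typically stored under 'state_dict'
print(checkpoint["state_dict"].keys())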

Automatic saving

Checkpointing is enabled by default to the current working directory. To change the checkpoint path, pass in:

trainer = Trainer(default_save_path='/your/path/to/save/checkpoints')

To modify the behavior of checkpointing, pass in your own callback:

import os

from pytorch_lightning.callbacks import ModelCheckpoint

# DEFAULTS used by the Trainer
checkpoint_callback = ModelCheckpoint(
    filepath=os.getcwd(),   # directory where checkpoints are written
    save_top_k=1,           # keep only the best checkpoint
    verbose=True,
    monitor='val_loss',     # quantity to monitor
    mode='min',             # lower val_loss is better
    prefix=''
)

trainer = Trainer(checkpoint_callback=checkpoint_callback)

Or disable it by passing

trainer = Trainer(checkpoint_callback=False)

The Lightning checkpoint also saves the hparams (hyperparams) passed into the LightningModule init.

Note

hparams is a Namespace.

from argparse import Namespace

# usually these come from command line args
args = Namespace(learning_rate=0.001)

# define your module to have hparams as the first arg
# this means your checkpoint will have everything that went into making
# this model (in this case, learning rate)
class MyLightningModule(LightningModule):

    def __init__(self, hparams, *args, **kwargs):
        super().__init__()
        self.hparams = hparams
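
In practice the Namespace usually comes from the command line via argparse. A minimal sketch of that flow, assuming a single --learning_rate flag (the argument name is just an example):

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--learning_rate', type=float, default=0.001)

# parse_args() returns an argparse.Namespace, which can be passed straight to the module
hparams = parser.parse_args()
model = MyLightningModule(hparams)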

Manual saving

You can manually save checkpoints and restore your model from the checkpointed state.

model = MyLightningModule(hparams)
trainer = Trainer()
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyLightningModule.load_from_checkpoint(checkpoint_path="example.ckpt")

Checkpoint loading

To load a model along with its weights, biases and hyperparameters, use the following method:

model = MyLightningModule.load_from_checkpoint(PATH)
model.eval()
y_hat = model(x)
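
If the checkpoint was saved on a GPU machine and you want to load it somewhere else, a map_location can be passed through to torch.load. A minimal sketch, assuming the load_from_checkpoint of your installed Lightning version accepts map_location:

import torch

# load a GPU-trained checkpoint onto the CPU
# (map_location is forwarded to torch.load; assumes this keyword is supported by your version)
model = MyLightningModule.load_from_checkpoint(
    checkpoint_path=PATH,
    map_location=torch.device('cpu'),
)
model.eval()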

The above only works if you used hparams in your model definition

class LitModel(LightningModule):

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.l1 = nn.Linear(hparams.in_dim, hparams.out_dim)

But if you don’t, and instead pass in individual parameters,

class LitModel(LightningModule):

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.l1 = nn.Linear(in_dim, out_dim)

you can restore the model by passing those values in at load time:

model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

Restoring training state

If you want to restore the full training state, not just the model weights, do the following:

model = LitModel(in_dim=128, out_dim=10)
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')

# automatically restores model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model)