Cloud-based checkpoints (advanced)

Cloud checkpoints

Lightning works with local filesystems and with several cloud storage providers, such as S3 on AWS, GCS on Google Cloud, and ADL on Azure.

PyTorch Lightning uses fsspec internally to handle all filesystem operations.
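
As a rough illustration of URL-based filesystem access, the sketch below uses fsspec directly. Its built-in in-memory filesystem (the memory:// protocol) stands in for a cloud bucket, so the example runs without credentials. The bucket and file names are hypothetical, and this is not Lightning's internal code path. For real cloud URLs, fsspec generally needs the matching backend package installed, for example s3fs for s3://, gcsfs for gcs://, or adlfs for Azure.

```python
import fsspec
from fsspec.core import url_to_fs

# "memory://" is fsspec's built-in in-memory filesystem; it is used here
# only as a stand-in for a cloud bucket (hypothetical path).
url = "memory://my_bucket/ckpts/example.ckpt"

# Write bytes through the same URL-based interface used for cloud protocols.
payload = b"checkpoint bytes"
with fsspec.open(url, "wb") as f:
    f.write(payload)

# Resolve the URL into a filesystem object and a path on that filesystem.
fs, path = url_to_fs(url)
print(type(fs).__name__, fs.exists(path))

# Read the data back via the URL.
with fsspec.open(url, "rb") as f:
    restored = f.read()
```

Because the protocol prefix selects the backend, switching storage location should only require changing the URL, which is the same idea behind passing an s3:// path to the Trainer.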


Save a cloud checkpoint

To save to a remote filesystem, prepend a protocol such as “s3://” to the default_root_dir used for writing and reading model data.

from lightning import Trainer

# `default_root_dir` is the default path used for logs and checkpoints
trainer = Trainer(default_root_dir="s3://my_bucket/data/")
trainer.fit(model)

Resume training from a cloud checkpoint

To resume training from a cloud checkpoint, pass its URL as ckpt_path.

# `tmpdir` is a placeholder for the directory where new logs and checkpoints are written
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
trainer.fit(model, ckpt_path="s3://my_bucket/ckpts/classifier.ckpt")



Modularize your checkpoints

Checkpoints can also save the state of datamodules and callbacks.


Modify a checkpoint anywhere

When you need to change the contents of a checkpoint before saving or loading, use the on_save_checkpoint() and on_load_checkpoint() hooks of your LightningModule.

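import lightning as L


class LitModel(L.LightningModule):
    def on_save_checkpoint(self, checkpoint):
        # Add any picklable object to the checkpoint dictionary
        checkpoint["something_cool_i_want_to_save"] = self.my_cool_pickable_object

    def on_load_checkpoint(self, checkpoint):
        # Restore the object from the loaded checkpoint dictionary
        self.my_cool_pickable_object = checkpoint["something_cool_i_want_to_save"]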

Use the above approach when you need to couple this behavior to your LightningModule for reproducibility reasons. Otherwise, use the on_save_checkpoint() and on_load_checkpoint() hooks of a Callback instead:

import lightning as L


class LitCallback(L.Callback):
    # Callback hooks also receive the trainer and the LightningModule
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        checkpoint["something_cool_i_want_to_save"] = self.my_cool_pickable_object

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        self.my_cool_pickable_object = checkpoint["something_cool_i_want_to_save"]

Resume from a partial checkpoint

Loading a checkpoint is normally “strict”, meaning the parameter names in the checkpoint must match the parameter names in the model; otherwise, PyTorch raises an error. To load only a partial checkpoint, disable strict loading by setting self.strict_loading = False in the LightningModule. A common use case is a pretrained feature extractor or encoder that you don’t update during training and don’t want included in the checkpoint:

import lightning as L

class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()

        # This model only trains the decoder, we don't save the encoder
        self.encoder = from_pretrained(...).requires_grad_(False)
        self.decoder = Decoder()

        # Set to False because we only care about the decoder
        self.strict_loading = False

    def state_dict(self):
        # Don't save the encoder, it is not being trained
        return {k: v for k, v in super().state_dict().items() if "encoder" not in k}

Since strict_loading is set to False, you won’t get any key errors when resuming from the checkpoint with the Trainer:

trainer = Trainer()
model = LitModel()

# Will load weights with `.load_state_dict(strict=model.strict_loading)`
trainer.fit(model, ckpt_path="path/to/checkpoint")
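
To make the effect of the strict flag concrete, here is a small sketch in plain PyTorch, outside of Lightning. TinyModel and its layer sizes are hypothetical stand-ins for the encoder/decoder model above. It shows that loading a state dict without the encoder weights fails under strict loading and succeeds with strict=False, which reports the skipped keys instead of raising.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a model with an encoder and a decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.decoder = nn.Linear(4, 2)


source = TinyModel()

# Simulate a partial checkpoint: drop every encoder entry from the state dict.
partial_state = {
    k: v for k, v in source.state_dict().items() if not k.startswith("encoder.")
}

target = TinyModel()

# Strict loading raises because the encoder keys are missing.
try:
    target.load_state_dict(partial_state, strict=True)
    strict_failed = False
except RuntimeError:
    strict_failed = True

# Non-strict loading succeeds and reports which keys were not found.
result = target.load_state_dict(partial_state, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)

# The decoder weights now match the source; the encoder keeps its own init.
decoder_matches = torch.equal(target.decoder.weight, source.decoder.weight)
```

Checking missing_keys after a non-strict load is a cheap way to confirm that only the intended parameters were skipped.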