.. _checkpointing_advanced:
##################################
Cloud-based checkpoints (advanced)
##################################
*****************
Cloud checkpoints
*****************
Lightning is integrated with the major remote file systems, including local filesystems and several cloud storage providers such as
`S3 <https://aws.amazon.com/s3/>`_ on `AWS <https://aws.amazon.com/>`_, `GCS <https://cloud.google.com/storage>`_ on `Google Cloud <https://cloud.google.com/>`_,
or `ADL <https://azure.microsoft.com/solutions/data-lake/>`_ on `Azure <https://azure.microsoft.com/>`_.

PyTorch Lightning uses `fsspec <https://filesystem-spec.readthedocs.io/>`_ internally to handle all filesystem operations.
----
Save a cloud checkpoint
=======================
To save to a remote filesystem, prepend a protocol like ``s3://`` to the root dir used for writing and reading model data.

.. code-block:: python

    # `default_root_dir` is the default path used for logs and checkpoints
    trainer = Trainer(default_root_dir="s3://my_bucket/data/")
    trainer.fit(model)
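fsspec picks the filesystem backend from the protocol prefix of the path. As a rough, standard-library-only illustration of that routing (``is_remote`` is a hypothetical helper for this sketch, not part of Lightning or fsspec):

```python
from urllib.parse import urlsplit


def is_remote(path: str) -> bool:
    """Hypothetical helper: paths with a non-local protocol prefix
    (s3://, gs://, adl://, ...) are routed to remote backends."""
    return urlsplit(path).scheme not in ("", "file")


# A cloud URL splits into protocol, bucket, and key
url = urlsplit("s3://my_bucket/data/")
print(url.scheme, url.netloc, url.path)  # s3 my_bucket /data/

print(is_remote("s3://my_bucket/data/"))  # True
print(is_remote("/tmp/checkpoints"))  # False
```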
----
Resume training from a cloud checkpoint
=======================================
To resume training from a cloud checkpoint, pass the checkpoint's cloud URL as ``ckpt_path``.

.. code-block:: python

    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
    trainer.fit(model, ckpt_path="s3://my_bucket/ckpts/classifier.ckpt")
----
***************************
Modularize your checkpoints
***************************
Checkpoints can also save the state of :doc:`datamodules <../extensions/datamodules_state>` and :doc:`callbacks <../extensions/callbacks_state>`.
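Datamodules and callbacks participate in checkpointing through ``state_dict``/``load_state_dict`` hooks that Lightning calls when writing and restoring a checkpoint. As a minimal, framework-free sketch of that contract (``FoldTracker`` is a stand-in class, not the Lightning API), any component exposing these two methods can round-trip its state through the checkpoint dict:

```python
class FoldTracker:
    """Stand-in for a datamodule/callback that persists state in checkpoints."""

    def __init__(self):
        self.current_fold = 0

    def state_dict(self):
        # Called when the checkpoint is written
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state):
        # Called when the checkpoint is restored
        self.current_fold = state["current_fold"]


# Round-trip the state through a plain dict, as a checkpoint would
tracker = FoldTracker()
tracker.current_fold = 3
saved = tracker.state_dict()

restored = FoldTracker()
restored.load_state_dict(saved)
print(restored.current_fold)  # 3
```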
----
****************************
Modify a checkpoint anywhere
****************************
When you need to change the components of a checkpoint before saving or loading, use the :meth:`~lightning.pytorch.core.hooks.CheckpointHooks.on_save_checkpoint` and :meth:`~lightning.pytorch.core.hooks.CheckpointHooks.on_load_checkpoint` hooks of your ``LightningModule``.
.. code-block:: python

    import lightning as L


    class LitModel(L.LightningModule):
        def on_save_checkpoint(self, checkpoint):
            checkpoint["something_cool_i_want_to_save"] = my_cool_picklable_object

        def on_load_checkpoint(self, checkpoint):
            my_cool_picklable_object = checkpoint["something_cool_i_want_to_save"]
Use the approach above when you need to couple this behavior to your ``LightningModule`` for reproducibility. Otherwise, callbacks also provide the :meth:`~lightning.pytorch.callbacks.callback.Callback.on_save_checkpoint` and :meth:`~lightning.pytorch.callbacks.callback.Callback.on_load_checkpoint` hooks, which you should use instead:
.. code-block:: python

    import lightning as L


    class LitCallback(L.Callback):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            checkpoint["something_cool_i_want_to_save"] = my_cool_picklable_object

        def on_load_checkpoint(self, trainer, pl_module, checkpoint):
            my_cool_picklable_object = checkpoint["something_cool_i_want_to_save"]
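Under the hood, the ``checkpoint`` argument passed to these hooks is a plain dictionary that Lightning later serializes. A small dict-only sketch of what the hooks above do (the keys and values here are illustrative, not a real checkpoint):

```python
# Simulated checkpoint dict; real ones also hold optimizer states, loop state, etc.
checkpoint = {"state_dict": {"layer.weight": [0.1]}, "epoch": 3}
my_cool_picklable_object = {"notes": "best run so far"}

# What on_save_checkpoint does: stash any picklable object under a custom key
checkpoint["something_cool_i_want_to_save"] = my_cool_picklable_object

# What on_load_checkpoint does: read it back once the checkpoint is loaded
restored = checkpoint["something_cool_i_want_to_save"]
print(restored)  # {'notes': 'best run so far'}
```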
----
********************************
Resume from a partial checkpoint
********************************
Loading a checkpoint is normally "strict", meaning the parameter names in the checkpoint must match the parameter names in the model; otherwise PyTorch raises an error.
When you want to load only part of a checkpoint, disable strict loading by setting ``self.strict_loading = False`` in the ``LightningModule`` to avoid these errors.
A common use case is when you have a pretrained feature extractor or encoder that you don't update during training, and you don't want it included in the checkpoint:
.. code-block:: python

    import lightning as L


    class LitModel(L.LightningModule):
        def __init__(self):
            super().__init__()
            # This model only trains the decoder; we don't save the encoder
            self.encoder = from_pretrained(...).requires_grad_(False)
            self.decoder = Decoder()
            # Set to False because we only care about the decoder
            self.strict_loading = False

        def state_dict(self):
            # Don't save the encoder, it is not being trained
            return {k: v for k, v in super().state_dict().items() if "encoder" not in k}
Since ``strict_loading`` is set to ``False``, you won't get any key errors when resuming the checkpoint with the Trainer:
.. code-block:: python

    trainer = Trainer()
    model = LitModel()

    # Will load weights with `.load_state_dict(strict=model.strict_loading)`
    trainer.fit(model, ckpt_path="path/to/checkpoint")
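To see what the filtered ``state_dict`` ends up containing, here is the same key filter applied to a plain dict, with toy lists standing in for tensors:

```python
# Toy stand-in for super().state_dict(): lists instead of tensors
full_state = {
    "encoder.weight": [1.0],
    "encoder.bias": [0.0],
    "decoder.weight": [2.0],
}

# The filter from LitModel.state_dict(): drop all encoder parameters
trimmed = {k: v for k, v in full_state.items() if "encoder" not in k}
print(trimmed)  # {'decoder.weight': [2.0]}
```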