Custom Checkpointing IO¶
The Checkpoint IO API is experimental and subject to change.
Lightning supports modifying the checkpointing save/load functionality through the CheckpointIO, which encapsulates the save/load logic that is managed by the TrainingTypePlugin. CheckpointIO can be extended to include your custom save/load functionality to and from a path. The CheckpointIO object can be passed to either a Trainer object or a TrainingTypePlugin as shown below.
```python
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import CheckpointIO, SingleDevicePlugin


class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None
    ) -> None:
        ...

    def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]:
        ...


custom_checkpoint_io = CustomCheckpointIO()

# Pass into the Trainer object
model = MyModel()
trainer = Trainer(
    plugins=[custom_checkpoint_io],
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)

# Pass into a TrainingTypePlugin
model = MyModel()
device = torch.device("cpu")
trainer = Trainer(
    plugins=SingleDevicePlugin(device, checkpoint_io=custom_checkpoint_io),
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)
```
Some TrainingTypePlugins do not support custom CheckpointIO, as their checkpointing logic is not modifiable.