Shortcuts

Save DataModule stateΒΆ

When a checkpoint is created, it asks every DataModule for their state. If your DataModule defines the state_dict and load_state_dict methods, the checkpoint will automatically track and restore your DataModules.

class LitDataModule(pl.DataModuler):
    def state_dict(self):
        # track whatever you want here
        state = {"current_train_batch_index": self.current_train_batch_index}
        return state

    def load_state_dict(self, state_dict):
        # restore the state based on what you tracked in (def state_dict)
        self.current_train_batch_index = state_dict["current_train_batch_index"]

© Copyright Copyright (c) 2018-2022, Lightning AI et al... Revision dbb5ca8d.

Built with Sphinx using a theme provided by Read the Docs.