
Model Hooks

There are cases when you might want to do something different at different parts of the training/validation loop. To enable a hook, simply override the method in your LightningModule and the trainer will call it at the correct time.
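
For example, a minimal sketch (LitModel and the printed messages are illustrative, not part of the API):

import pytorch_lightning as pl

class LitModel(pl.LightningModule):

    def on_train_start(self):
        # called at the beginning of training
        print('training is starting')

    def on_train_epoch_end(self):
        # called at the end of every training epoch
        print('finished a training epoch')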

Contributing: If there’s a hook you’d like to add, simply:

  1. Fork PyTorchLightning.

  2. Add the hook to pytorch_lightning.core.hooks.ModelHooks.

  3. Add it in the correct place in pytorch_lightning.trainer where it should be called.


Hooks lifecycle

Training set-up

  • prepare_data()

  • setup()

  • init_optimizers()

  • configure_apex()

  • train_dataloader()

  • test_dataloader()

  • val_dataloader()

  • summarize()

  • restore_weights()

Warning

prepare_data is called only from global_rank=0. Don’t assign state here (self.something); use setup for that.


Training loop

  • on_epoch_start()

  • on_batch_start()

  • tbptt_split_batch()

  • training_step()

  • training_step_end() (optional)

  • on_before_zero_grad()

  • backward()

  • on_after_backward()

  • optimizer.step()

  • on_batch_end()

  • training_epoch_end()

  • on_epoch_end()


Validation loop

  • model.zero_grad()

  • model.eval()

  • torch.set_grad_enabled(False)

  • validation_step()

  • validation_step_end()

  • validation_epoch_end()

  • model.train()

  • torch.set_grad_enabled(True)

  • on_post_performance_check()


Test loop

  • model.zero_grad()

  • model.eval()

  • torch.set_grad_enabled(False)

  • test_step()

  • test_step_end()

  • test_epoch_end()

  • model.train()

  • torch.set_grad_enabled(True)

  • on_post_performance_check()


General hooks

class pytorch_lightning.core.hooks.ModelHooks[source]

Bases: object

backward(trainer, loss, optimizer, optimizer_idx)[source]

Override backward with your own implementation if you need to.

Parameters
  • trainer – Pointer to the trainer

  • loss (Tensor) – Loss is already scaled by accumulated grads

  • optimizer (Optimizer) – Current optimizer being used

  • optimizer_idx (int) – Index of the current optimizer being used

Called to perform the backward pass. Feel free to override as needed.

The loss passed in has already been scaled for accumulated gradients if requested.

Example:

def backward(self, trainer, loss, optimizer, optimizer_idx):
    loss.backward()
Return type

None

on_after_backward()[source]

Called in the training loop after loss.backward() and before optimizers do anything. This is the ideal place to inspect or log gradient information.

Example:

def on_after_backward(self):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        for name, param in self.named_parameters():
            if param.grad is not None:
                self.logger.experiment.add_histogram(
                    tag=name, values=param.grad,
                    global_step=self.trainer.global_step)
Return type

None

on_batch_end()[source]

Called in the training loop after the batch.

Warning

Deprecated in 0.9.0, will be removed in 1.0.0 (use on_train_batch_end instead).

Return type

None

on_batch_start(batch)[source]

Called in the training loop before anything happens for that batch.

If you return -1 here, you will skip training for the rest of the current epoch.

Parameters

batch (Any) – The batched data as it is returned by the training DataLoader.

Warning

Deprecated in 0.9.0, will be removed in 1.0.0 (use on_train_batch_start instead).

Return type

None

on_before_zero_grad(optimizer)[source]

Called after optimizer.step() and before optimizer.zero_grad().

Called in the training loop after taking an optimizer step and before zeroing grads. This is a good place to inspect weight information once the weights have been updated.

This is where it is called:

for optimizer in optimizers:
    optimizer.step()
    model.on_before_zero_grad(optimizer) # < ---- called here
    optimizer.zero_grad()
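
For example, you could override it to inspect the freshly updated weights, e.g. by logging their norms (a minimal sketch assuming a TensorBoard logger; the tag names are illustrative):

def on_before_zero_grad(self, optimizer):
    # weights have just been updated by optimizer.step()
    for name, param in self.named_parameters():
        self.logger.experiment.add_scalar(
            'weight_norm/' + name, param.data.norm(),
            global_step=self.trainer.global_step)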
Parameters

optimizer (Optimizer) – The optimizer for which grads should be zeroed.

Return type

None

on_epoch_end()[source]

Called in the training loop at the very end of the epoch.

Return type

None

on_epoch_start()[source]

Called in the training loop at the very beginning of the epoch.

Return type

None

on_fit_end()[source]

Called at the very end of fit. If running DDP, it is called on every process.

on_fit_start()[source]

Called at the very beginning of fit. If running DDP, it is called on every process.
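
For example, the pair can bracket the whole fit call, e.g. for simple timing (a minimal sketch; the _fit_start_time attribute is a hypothetical name, not part of the API):

import time
import pytorch_lightning as pl

class TimedModel(pl.LightningModule):

    def on_fit_start(self):
        # record when fit begins (hypothetical attribute)
        self._fit_start_time = time.monotonic()

    def on_fit_end(self):
        elapsed = time.monotonic() - self._fit_start_time
        print(f'fit took {elapsed:.1f} seconds')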

on_pretrain_routine_end()[source]

Called at the end of the pretrain routine (between fit and train start).

  • fit

  • pretrain_routine start

  • pretrain_routine end

  • training_start

Return type

None

on_pretrain_routine_start()[source]

Called at the beginning of the pretrain routine (between fit and train start).

  • fit

  • pretrain_routine start

  • pretrain_routine end

  • training_start

Return type

None

on_test_batch_end(batch, batch_idx, dataloader_idx)[source]

Called in the test loop after the batch.

Parameters
  • batch (Any) – The batched data as it is returned by the test DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_test_batch_start(batch, batch_idx, dataloader_idx)[source]

Called in the test loop before anything happens for that batch.

Parameters
  • batch (Any) – The batched data as it is returned by the test DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_test_epoch_end()[source]

Called in the test loop at the very end of the epoch.

Return type

None

on_test_epoch_start()[source]

Called in the test loop at the very beginning of the epoch.

Return type

None

on_test_model_eval()[source]

Sets the model to eval mode during the test loop.

Return type

None

on_test_model_train()[source]

Sets the model to train mode during the test loop.

Return type

None

on_train_batch_end(batch, batch_idx, dataloader_idx)[source]

Called in the training loop after the batch.

Parameters
  • batch (Any) – The batched data as it is returned by the training DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_train_batch_start(batch, batch_idx, dataloader_idx)[source]

Called in the training loop before anything happens for that batch.

If you return -1 here, you will skip training for the rest of the current epoch.
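
For example, you could cut an epoch short once a condition is met (a minimal sketch; max_batches_per_epoch is a hypothetical attribute):

def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
    # skip the rest of the epoch after a fixed number of batches
    if batch_idx >= self.max_batches_per_epoch:
        return -1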

Parameters
  • batch (Any) – The batched data as it is returned by the training DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_train_end()[source]

Called at the end of training, before the logger experiment is closed.

Return type

None

on_train_epoch_end()[source]

Called in the training loop at the very end of the epoch.

Return type

None

on_train_epoch_start()[source]

Called in the training loop at the very beginning of the epoch.

Return type

None

on_train_start()[source]

Called at the beginning of training, before the sanity check.

Return type

None

on_validation_batch_end(batch, batch_idx, dataloader_idx)[source]

Called in the validation loop after the batch.

Parameters
  • batch (Any) – The batched data as it is returned by the validation DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_validation_batch_start(batch, batch_idx, dataloader_idx)[source]

Called in the validation loop before anything happens for that batch.

Parameters
  • batch (Any) – The batched data as it is returned by the validation DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_validation_epoch_end()[source]

Called in the validation loop at the very end of the epoch.

Return type

None

on_validation_epoch_start()[source]

Called in the validation loop at the very beginning of the epoch.

Return type

None

on_validation_model_eval()[source]

Sets the model to eval mode during the val loop.

Return type

None

on_validation_model_train()[source]

Sets the model to train mode during the val loop.

Return type

None

setup(stage)[source]

Called at the beginning of fit and test. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters

stage (str) – either ‘fit’ or ‘test’

Example:

class LitModel(...):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = some_other_state()

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
teardown(stage)[source]

Called at the end of fit and test.

Parameters

stage (str) – either ‘fit’ or ‘test’
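
For example, teardown can release whatever setup created for the given stage (a minimal sketch; the dataset attributes are hypothetical):

def teardown(self, stage):
    # free anything allocated in setup for this stage
    if stage == 'fit':
        self.train_dataset = None
        self.val_dataset = None
    elif stage == 'test':
        self.test_dataset = None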

class pytorch_lightning.core.hooks.DataHooks[source]

Bases: object

prepare_data()[source]

Use this to download and prepare data.

Warning

DO NOT set state on the model (use setup instead), since this is NOT called on every GPU when using DDP/TPU.

Example:

def prepare_data(self):
    # good
    download_data()
    tokenize()
    etc()

    # bad
    self.split = data_split
    self.some_state = some_other_state()

In DDP, prepare_data can be called in two ways, controlled by the Trainer(prepare_data_per_node) flag:

  1. Once per node. This is the default and is only called on LOCAL_RANK=0.

  2. Once in total. Only called on GLOBAL_RANK=0.

Example:

# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
Trainer(prepare_data_per_node=True)

# call on GLOBAL_RANK=0 (great for shared file systems)
Trainer(prepare_data_per_node=False)

This is called before requesting the dataloaders:

model.prepare_data()
    if ddp/tpu: init()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
Return type

None

test_dataloader()[source]

Implement one or multiple PyTorch DataLoaders for testing.

The dataloader you return will not be called every epoch unless you set reload_dataloaders_every_epoch to True.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data.

  • fit()

  • prepare_data()

  • setup()

  • train_dataloader()

  • val_dataloader()

  • test_dataloader()

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return type

Union[DataLoader, List[DataLoader]]

Returns

Single or multiple PyTorch DataLoaders.

Example

def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

Note

In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.
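
For example, a matching test_step for multiple dataloaders might look like this (a minimal sketch; the model call and loss computation are illustrative):

import torch.nn.functional as F

def test_step(self, batch, batch_idx, dataloader_idx):
    x, y = batch
    y_hat = self(x)
    # dataloader_idx matches the order of the loaders returned above
    loss = F.cross_entropy(y_hat, y)
    return {f'test_loss_dl{dataloader_idx}': loss}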

train_dataloader()[source]

Implement a PyTorch DataLoader for training.

Return type

DataLoader

Returns

Single PyTorch DataLoader.

The dataloader you return will not be called every epoch unless you set reload_dataloaders_every_epoch to True.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data.

  • fit()

  • prepare_data()

  • setup()

  • train_dataloader()

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Example

def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader
transfer_batch_to_device(batch, device)[source]

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

The data types listed below (and any arbitrary nesting of them) are supported out of the box:

  • torch.Tensor or anything that implements .to(...)

  • list

  • dict

  • tuple

  • torchtext.data.batch.Batch

For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, …).

Example:

def transfer_batch_to_device(self, batch, device):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    else:
        batch = super().transfer_batch_to_device(batch, device)
    return batch
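
Here, CustomBatch stands in for any user-defined container; a minimal sketch of such a class might be:

class CustomBatch:
    # a simple wrapper holding tensors that the default transfer logic
    # does not know how to move
    def __init__(self, samples, targets):
        self.samples = samples
        self.targets = targets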
Parameters
  • batch (Any) – A batch of data that needs to be transferred to a new device.

  • device (device) – The target device as defined in PyTorch.

Return type

Any

Returns

A reference to the data on the new device.

Note

This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as argument (unless you know what you are doing).

Note

This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support for your custom batch objects, you need to define your custom DistributedDataParallel or LightningDistributedDataParallel and override configure_ddp().

See also

  • move_data_to_device()

  • apply_to_collection()

val_dataloader()[source]

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be called every epoch unless you set reload_dataloaders_every_epoch to True.

It’s recommended that all data downloads and preparation happen in prepare_data().

  • fit()

  • prepare_data()

  • train_dataloader()

  • val_dataloader()

  • test_dataloader()

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return type

Union[DataLoader, List[DataLoader]]

Returns

Single or multiple PyTorch DataLoaders.

Examples

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Note

In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.