
Lightning Module interface


A LightningModule is a strict subclass of torch.nn.Module. It provides a standard interface for the Trainer to interact with the model.

The easiest thing to do is copy the minimal example below and modify accordingly.

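Otherwise, to define a LightningModule, implement the following methods:

Required: training_step, train_dataloader, configure_optimizers

Optional: training_end, validation_step, validation_end, test_step, test_end, val_dataloader, test_dataloader, on_save_checkpoint, on_load_checkpoint, add_model_specific_args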


Minimal example

import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as pl

class CoolModel(pl.LightningModule):

    def __init__(self):
        super(CoolModel, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def test_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}

    def test_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        return {'avg_test_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @pl.data_loader
    def train_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        # can also return a list of val dataloaders
        # NOTE: for brevity this reuses the MNIST training split; validate on held-out data in practice
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def test_dataloader(self):
        # OPTIONAL
        # can also return a list of test dataloaders
        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)

How do these methods fit into the broader training?

[Figure: the LightningModule interface alongside the parts of a research project. Each method corresponds to one part of the project; Lightning automates everything not shown in blue.]
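The following toy, pure-Python sketch shows the order in which a trainer might call these hooks for one epoch. ToyModule and ToyTrainer are hypothetical stand-ins, and numbers replace tensors. The real Lightning Trainer also handles devices, precision, checkpointing and logging, and its internals differ.

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule that records each hook call."""

    def __init__(self):
        self.calls = []

    def configure_optimizers(self):
        self.calls.append("configure_optimizers")
        return "optimizer"  # placeholder for a real torch.optim optimizer

    def train_dataloader(self):
        return [1.0, 2.0, 3.0]  # each "batch" is just a number in this sketch

    def val_dataloader(self):
        return [4.0, 5.0]

    def training_step(self, batch, batch_nb):
        self.calls.append("training_step")
        return {"loss": batch * 0.5}

    def validation_step(self, batch, batch_nb):
        self.calls.append("validation_step")
        return {"val_loss": batch}

    def validation_end(self, outputs):
        self.calls.append("validation_end")
        mean = sum(o["val_loss"] for o in outputs) / len(outputs)
        return {"progress_bar": {"val_loss": mean}}


class ToyTrainer:
    """Hypothetical trainer that only demonstrates the order of hook calls."""

    def fit(self, model, max_epochs=1):
        model.configure_optimizers()
        history = []
        for _ in range(max_epochs):
            for batch_nb, batch in enumerate(model.train_dataloader()):
                output = model.training_step(batch, batch_nb)
                # a real trainer would backpropagate output["loss"] and step the optimizer here
            val_outputs = [model.validation_step(b, i)
                           for i, b in enumerate(model.val_dataloader())]
            history.append(model.validation_end(val_outputs))
        return history


model = ToyModule()
history = ToyTrainer().fit(model)
print(history)      # [{'progress_bar': {'val_loss': 4.5}}]
print(model.calls)
```

Everything in your research code lives in the module's methods. The loop that calls them belongs to the trainer.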

Required Methods

training_step

def training_step(self, batch, batch_nb)

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something specific to your model.

Params

batch: the output of your dataloader (a tensor, tuple, or list)
batch_nb: integer index of the current batch

Return

Dictionary or OrderedDict with these keys:

loss (required): scalar tensor
progress_bar (optional): dict of values to display in the progress bar; values must be tensors
log (optional): dict of metrics to send to the logger; values must be tensors (no images, etc.)

Example

def training_step(self, batch, batch_nb):
    x, y = batch

    # implement your own
    out = self.forward(x)
    loss = self.loss(out, y)

    logger_logs = {'training_loss': loss} # optional (MUST ALL BE TENSORS)

    # if using TestTubeLogger or TensorboardLogger you can nest scalars
    logger_logs = {'losses': logger_logs} # optional (MUST ALL BE TENSORS)

    output = {
        'loss': loss, # required
        'progress_bar': {'training_loss': loss}, # optional (MUST ALL BE TENSORS)
        'log': logger_logs
    }

    # return a dict
    return output

If you define multiple optimizers, this step will also be called with an additional optimizer_idx param.

# Multiple optimizers (e.g. GANs)
def training_step(self, batch, batch_nb, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...
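For each batch, the trainer conceptually calls training_step once per optimizer. It passes the optimizer's position in the sequence returned by configure_optimizers. The simplified sketch below shows this. ToyGAN and run_batch are hypothetical, strings stand in for optimizers, and the real backward/step calls are left as a comment.

```python
class ToyGAN:
    """Hypothetical module whose training_step branches on optimizer_idx."""

    def training_step(self, batch, batch_nb, optimizer_idx):
        if optimizer_idx == 0:
            return {"loss": batch + 1}  # pretend generator loss
        if optimizer_idx == 1:
            return {"loss": batch * 2}  # pretend discriminator loss
        raise ValueError(f"unexpected optimizer_idx {optimizer_idx}")


def run_batch(model, batch, batch_nb, optimizers):
    """Call training_step once per optimizer, in the order configure_optimizers returned them."""
    losses = []
    for optimizer_idx, optimizer in enumerate(optimizers):
        output = model.training_step(batch, batch_nb, optimizer_idx)
        # a real trainer would backpropagate output["loss"] and call optimizer.step() here
        losses.append((optimizer, output["loss"]))
    return losses


result = run_batch(ToyGAN(), batch=3, batch_nb=0, optimizers=["gen_opt", "disc_opt"])
print(result)  # [('gen_opt', 4), ('disc_opt', 6)]
```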

If you use truncated backpropagation through time (TBPTT), training_step also receives an additional argument containing the hidden states from the previous step.

# Truncated back-propagation through time
def training_step(self, batch, batch_nb, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    ...
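The idea behind truncated backprop is to cut a long sequence into chunks and process them one after another. The hidden state from each chunk is carried into the next. The sketch below illustrates this in plain Python.

Several pieces are assumptions of the sketch, not Lightning's implementation:
- the helper names run_tbptt and split_sequence;
- the chunk-length parameter truncated_bptt_steps;
- returning the new state under a 'hiddens' key.

```python
def split_sequence(seq, truncated_bptt_steps):
    """Cut a sequence of time steps into consecutive chunks of at most truncated_bptt_steps."""
    return [seq[i:i + truncated_bptt_steps] for i in range(0, len(seq), truncated_bptt_steps)]


def run_tbptt(training_step, seq, batch_nb, truncated_bptt_steps):
    """Hypothetical loop: feed each chunk to training_step with the previous chunk's hiddens."""
    hiddens = None  # nothing to carry into the first chunk
    outputs = []
    for chunk in split_sequence(seq, truncated_bptt_steps):
        output = training_step(chunk, batch_nb, hiddens)
        # with real tensors you would detach the hidden state here so gradients stop at the chunk boundary
        hiddens = output["hiddens"]
        outputs.append(output)
    return outputs


def toy_training_step(batch, batch_nb, hiddens):
    # the "hidden state" is just a running sum of every value seen so far
    state = (hiddens or 0) + sum(batch)
    return {"loss": state, "hiddens": state}


outputs = run_tbptt(toy_training_step, [1, 2, 3, 4, 5], batch_nb=0, truncated_bptt_steps=2)
print([o["hiddens"] for o in outputs])  # [3, 10, 15]
```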

You can also return -1 instead of a dict to stop the current loop. This is useful if you want to break out of the current training epoch early.
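Returning -1 works as a sentinel that the loop checks after each step. The pure-Python sketch below shows that check. run_epoch and the stop condition in stop_on_none_step are hypothetical.

```python
def run_epoch(training_step, batches):
    """Hypothetical loop that ends the epoch early when training_step returns -1."""
    losses = []
    for batch_nb, batch in enumerate(batches):
        output = training_step(batch, batch_nb)
        if output == -1:
            break  # skip the remaining batches of this epoch
        losses.append(output["loss"])
    return losses


def stop_on_none_step(batch, batch_nb):
    if batch is None:  # hypothetical condition for ending the epoch
        return -1
    return {"loss": batch * 2}


losses = run_epoch(stop_on_none_step, [1, 2, None, 4])
print(losses)  # [2, 4]
```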


training_end

def training_end(self, outputs)

OPTIONAL
With the dp or ddp2 backends, each GPU runs training_step on only a slice of the batch. Sometimes you need the outputs of every slice together. For instance, with negative sampling you could run a batch via dp and use ALL the outputs in a single softmax across the full batch (i.e. the softmax denominator covers the full batch).

In that case, define training_end to perform those calculations.

Params

outputs: what you returned from training_step. With dp or ddp2, the outputs from every GPU are combined, so they cover the full batch.

Return

Dictionary or OrderedDict with these keys:

loss (required): scalar tensor
progress_bar (optional): dict of values to display in the progress bar; values must be tensors
log (optional): dict of metrics to send to the logger; values must be tensors (no images, etc.)

Example

# WITHOUT training_end
# if used in DP or DDP2, this batch is 1/nb_gpus large
def training_step(self, batch, batch_nb):
    # batch is 1/nb_gpus big
    x, y = batch

    out = self.forward(x)
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return {'loss': loss}

# --------------
# with training_end to do softmax over the full batch
def training_step(self, batch, batch_nb):
    # batch is 1/nb_gpus big
    x, y = batch

    out = self.forward(x)
    return {'out': out}

def training_end(self, outputs):
    # this out is now the full size of the batch
    out = outputs['out']

    # this softmax now uses the full batch size
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return {'loss': loss}
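The following framework-free sketch shows why the full batch matters for the softmax. The logits are hypothetical per-GPU outputs, and lists stand in for tensors. Normalizing each half separately gives two distributions, while concatenating first (what training_end makes possible) gives one denominator over the whole batch.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a flat list of scores."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# hypothetical training_step outputs from two GPUs, each holding half of the batch
gpu0_out = [2.0, 1.0]
gpu1_out = [0.5, 3.0]

# WITHOUT training_end: each GPU normalises over its own half of the batch
per_gpu = softmax(gpu0_out) + softmax(gpu1_out)

# WITH training_end: the halves are combined first, so one denominator covers the full batch
full_batch = softmax(gpu0_out + gpu1_out)

print(round(sum(per_gpu), 6))     # 2.0 (two separate distributions)
print(round(sum(full_batch), 6))  # 1.0 (one distribution over the whole batch)
```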



train_dataloader

@pl.data_loader
def train_dataloader(self)

Called by Lightning during the training loop. Use the @pl.data_loader decorator: it ensures this function is not called until the data are actually needed.
If you want to change the data every epoch, DON'T use the data_loader decorator.
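The following sketch shows the build-on-first-use, then-reuse behaviour described above. It also explains why the data would not change between epochs. lazy_data_loader is a hypothetical decorator written for illustration and is not Lightning's source.

```python
import functools


def lazy_data_loader(fn):
    """Hypothetical decorator: build the loader on first access, then return the cached one."""
    attr_name = "_lazy_" + fn.__name__

    @functools.wraps(fn)
    def wrapper(self):
        if not hasattr(self, attr_name):
            setattr(self, attr_name, fn(self))  # first call: build and cache
        return getattr(self, attr_name)

    return wrapper


class ToyModel:
    def __init__(self):
        self.build_count = 0  # counts how often the loader is actually constructed

    @lazy_data_loader
    def train_dataloader(self):
        self.build_count += 1
        return [[1, 2], [3, 4]]  # stands in for a torch DataLoader


model = ToyModel()
print(model.build_count)  # 0: nothing is built when the model is created
first = model.train_dataloader()
second = model.train_dataloader()
print(model.build_count, first is second)  # 1 True: built once, then reused
```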

Return

PyTorch DataLoader

Example

@pl.data_loader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.hparams.batch_size,
        shuffle=True
    )
    return loader

configure_optimizers

def configure_optimizers(self)

Set up as many optimizers and (optionally) learning-rate schedulers as you need. Normally you need one, but for GANs or something more esoteric you might have several. Lightning calls .backward() on the loss and .step() on each optimizer for you. If you use 16-bit precision, it handles that too.

Note: If you use multiple optimizers, training_step will receive an additional optimizer_idx parameter.
Note 2: If you use LBFGS, Lightning handles the closure function for you.

Return

Return any of these 3 options:
Single optimizer
List or Tuple - List of optimizers
Two lists - The first list has multiple optimizers, the second a list of learning-rate schedulers
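One way to see the three accepted return forms is to map each one to a pair of lists: (optimizers, schedulers). normalize_optimizers is a hypothetical helper written for illustration, not the library's parsing code, and strings stand in for optimizer objects.

```python
def normalize_optimizers(result):
    """Hypothetical helper mapping the three return forms to (optimizers, schedulers)."""
    # Two lists: ([optimizers...], [schedulers...])
    if (isinstance(result, (list, tuple)) and len(result) == 2
            and all(isinstance(part, list) for part in result)):
        return list(result[0]), list(result[1])
    # List or tuple of optimizers
    if isinstance(result, (list, tuple)):
        return list(result), []
    # Single optimizer
    return [result], []


# strings stand in for torch.optim objects
print(normalize_optimizers("adam"))                                # (['adam'], [])
print(normalize_optimizers(("gen_opt", "disc_opt")))               # (['gen_opt', 'disc_opt'], [])
print(normalize_optimizers((["gen_opt", "disc_opt"], ["sched"])))  # (['gen_opt', 'disc_opt'], ['sched'])
```

The "two lists" check runs first because a pair of lists is also a tuple and would otherwise be mistaken for a tuple of two optimizers.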

Example

from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

# most cases
def configure_optimizers(self):
    opt = Adam(self.parameters(), lr=0.01)
    return opt

# multiple optimizer case (e.g. GAN)
def configure_optimizers(self):
    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
    return generator_opt, discriminator_opt

# example with learning rate schedulers
def configure_optimizers(self):
    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
    discriminator_sched = CosineAnnealingLR(discriminator_opt, T_max=10)
    return [generator_opt, discriminator_opt], [discriminator_sched]

If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step hook.
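The exact arguments optimizer_step receives depend on your Lightning version. The sketch below therefore shows only the kind of scheduling decision you might make inside an override. The rule is that the generator steps every batch and the discriminator every second batch. should_step and step_every are hypothetical names.

```python
def should_step(optimizer_idx, batch_nb, step_every=(1, 2)):
    """Hypothetical rule: optimizer i steps on batches where batch_nb % step_every[i] == 0."""
    return batch_nb % step_every[optimizer_idx] == 0


# generator (idx 0) steps every batch, discriminator (idx 1) every second batch
schedule = [[should_step(i, b) for i in (0, 1)] for b in range(4)]
print(schedule)  # [[True, True], [True, False], [True, True], [True, False]]
```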

Optional Methods

validation_step

# if you have one val dataloader:
def validation_step(self, batch, batch_nb)

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_nb, dataloader_idx)

OPTIONAL
If you don't need to validate, you don't need to implement this method. In this step you'd normally generate examples or calculate anything of interest, such as accuracy.

When validation_step is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are re-enabled.

The dict you return here will be available in the validation_end method.

Params

batch: the output of your dataloader (a tensor, tuple, or list)
batch_nb: integer index of the current batch
dataloader_idx: integer index of the dataloader this batch came from (only if multiple val datasets are used)

Return

dict (optional): Dict or OrderedDict passed to validation_end

Example

from collections import OrderedDict

import torch
import torchvision

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_nb):
    x, y = batch

    # implement your own
    out = self.forward(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0) # assumes a TensorBoard-style logger

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # all optional...
    # return whatever you need for the collation function validation_end
    output = OrderedDict({
        'val_loss': loss,
        'val_acc': torch.tensor(val_acc), # everything must be a tensor
    })

    # return an optional dict
    return output
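The accuracy line above takes the argmax over classes and compares it with the labels. Below is a framework-free version of the same computation, using hypothetical nested lists in place of the logits and label tensors.

```python
def accuracy(logits, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]  # argmax per row
    correct = sum(p == y for p, y in zip(preds, labels))
    return correct / len(labels)


logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
labels = [1, 0, 0]
print(accuracy(logits, labels))  # 0.6666666666666666
```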

If you pass in multiple validation datasets, validation_step will have an additional argument.

# CASE 2: multiple validation datasets
def validation_step(self, batch, batch_nb, dataloader_idx):
    # dataloader_idx tells you which dataset this batch came from
    ...

The dataloader_idx corresponds to the order of the dataloaders returned by val_dataloader.
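The sketch below shows how the outputs are shaped in each case. A single dataloader yields a flat list of step outputs, while several dataloaders yield one inner list per dataloader, indexed in order. run_validation and toy_validation_step are hypothetical helpers, not Lightning's evaluation loop.

```python
def run_validation(validation_step, dataloaders):
    """Hypothetical loop building the outputs that validation_end receives."""
    if len(dataloaders) == 1:
        # one dataloader: a flat list of step outputs
        return [validation_step(batch, batch_nb) for batch_nb, batch in enumerate(dataloaders[0])]
    # several dataloaders: one inner list per dataloader, in the order val_dataloader returned them
    return [
        [validation_step(batch, batch_nb, dataloader_idx)
         for batch_nb, batch in enumerate(loader)]
        for dataloader_idx, loader in enumerate(dataloaders)
    ]


def toy_validation_step(batch, batch_nb, dataloader_idx=0):
    return {"val_loss": batch, "source": dataloader_idx}


single = run_validation(toy_validation_step, [[1.0, 2.0]])
multi = run_validation(toy_validation_step, [[1.0, 2.0], [5.0]])
print(single)  # [{'val_loss': 1.0, 'source': 0}, {'val_loss': 2.0, 'source': 0}]
print(multi)   # [[{'val_loss': 1.0, 'source': 0}, {'val_loss': 2.0, 'source': 0}], [{'val_loss': 5.0, 'source': 1}]]
```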


validation_end

def validation_end(self, outputs)

If you didn't define a validation_step, this won't be called. Called at the end of the validation loop with the outputs of validation_step.

The returned dict is used for display and logging. Values under 'progress_bar' appear in the progress bar, and values under 'log' go to the logger. If you don't need to display or log anything, don't return anything.
Any keys present in 'log', 'progress_bar' or the rest of the dictionary are available for callbacks to access.

Params

outputs: list of the outputs you returned from validation_step or, if there are multiple dataloaders, a list containing one list of outputs per dataloader

Return

Dictionary or OrderedDict with these optional keys:

progress_bar: dict of values to display in the progress bar; values must be tensors
log: dict of metrics to send to the logger; values must be tensors (no images, etc.)

Example

With a single dataloader

def validation_end(self, outputs):
    """
    Called at the end of validation to aggregate outputs
    :param outputs: list of individual outputs of each validation step
    :return:
    """
    val_loss_mean = 0
    val_acc_mean = 0
    for output in outputs:
        val_loss_mean += output['val_loss']
        val_acc_mean += output['val_acc']

    val_loss_mean /= len(outputs)
    val_acc_mean /= len(outputs)
    tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}

    # show val_loss and val_acc in progress bar but only log val_loss
    results = {
        'progress_bar': tqdm_dict,
        'log': {'val_loss': val_loss_mean.item()}
    }
    return results

With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each validation step for that dataloader.

def validation_end(self, outputs):
    """
    Called at the end of validation to aggregate outputs
    :param outputs: list of list of individual outputs of each validation step
    :return:
    """
    val_loss_mean = 0
    val_acc_mean = 0
    i = 0
    for dataloader_outputs in outputs:
        for output in dataloader_outputs:
            val_loss_mean += output['val_loss']
            val_acc_mean += output['val_acc']
            i += 1

    val_loss_mean /= i
    val_acc_mean /= i
    tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}

    # show val_loss and val_acc in progress bar but only log val_loss
    results = {
        'progress_bar': tqdm_dict,
        'log': {'val_loss': val_loss_mean.item()}
    }
    return results
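The examples above average the per-batch means. When the last batch is smaller than the others, this differs from the mean over all samples. A weighted average fixes that. The sketch assumes a hypothetical 'batch_size' key that you would add in validation_step yourself (e.g. from len(y)).

```python
def weighted_mean(outputs, key, size_key="batch_size"):
    """Average a per-batch metric, weighting each batch by its number of samples."""
    total = sum(o[key] * o[size_key] for o in outputs)
    count = sum(o[size_key] for o in outputs)
    return total / count


# hypothetical validation_step outputs: a full batch of 32 and a final batch of 8
outputs = [{"val_loss": 1.0, "batch_size": 32}, {"val_loss": 3.0, "batch_size": 8}]

simple = sum(o["val_loss"] for o in outputs) / len(outputs)
weighted = weighted_mean(outputs, "val_loss")
print(simple, weighted)  # 2.0 1.4
```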

test_step

# if you have one test dataloader:
def test_step(self, batch, batch_nb)

# if you have multiple test dataloaders:
def test_step(self, batch, batch_nb, dataloader_idx)

OPTIONAL
If you don't need to test, you don't need to implement this method. In this step you'd normally generate examples or calculate anything of interest, such as accuracy.

When test_step is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of testing, the model goes back to training mode and gradients are re-enabled.

The dict you return here will be available in the test_end method.

This function is used when you execute trainer.test().

Params

batch: the output of your dataloader (a tensor, tuple, or list)
batch_nb: integer index of the current batch
dataloader_idx: integer index of the dataloader this batch came from (only if multiple test datasets are used)

Return

dict (optional): Dict or OrderedDict passed to test_end. All values must be tensors.

Example

from collections import OrderedDict

import torch

# CASE 1: A single test dataset
def test_step(self, batch, batch_nb):
    x, y = batch

    # implement your own
    out = self.forward(x)
    loss = self.loss(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # all optional...
    # return whatever you need for the collation function test_end
    output = OrderedDict({
        'test_loss': loss,
        'test_acc': torch.tensor(test_acc), # everything must be a tensor
    })

    # return an optional dict
    return output

If you pass in multiple test datasets, test_step will have an additional argument.

# CASE 2: multiple test datasets
def test_step(self, batch, batch_nb, dataloader_idx):
    # dataloader_idx tells you which dataset this batch came from
    ...

The dataloader_idx corresponds to the order of the dataloaders returned by test_dataloader.


test_end

def test_end(self, outputs)

If you didn't define a test_step, this won't be called.

Called at the end of the test loop with the outputs of every test_step.

The returned dict is used for display and logging. Values under 'progress_bar' appear in the progress bar, and values under 'log' go to the logger. If you don't need to display or log anything, don't return anything.

Params

outputs: list of the outputs you returned from test_step or, if there are multiple dataloaders, a list containing one list of outputs per dataloader

Return

dict (optional): Dict or OrderedDict with 'progress_bar' and 'log' entries, as in validation_end

Example

def test_end(self, outputs):
    """
    Called at the end of test to aggregate outputs
    :param outputs: list of individual outputs of each test step
    :return:
    """
    test_loss_mean = 0
    test_acc_mean = 0
    for output in outputs:
        test_loss_mean += output['test_loss']
        test_acc_mean += output['test_acc']

    test_loss_mean /= len(outputs)
    test_acc_mean /= len(outputs)
    tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}

    # show test_loss and test_acc in progress bar but only log test_loss
    results = {
        'progress_bar': tqdm_dict,
        'log': {'test_loss': test_loss_mean.item()}
    }
    return results

With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each test step for that dataloader.

def test_end(self, outputs):
    """
    Called at the end of test to aggregate outputs
    :param outputs: list of list of individual outputs of each test step
    :return:
    """
    test_loss_mean = 0
    test_acc_mean = 0
    i = 0
    for dataloader_outputs in outputs:
        for output in dataloader_outputs:
            test_loss_mean += output['test_loss']
            test_acc_mean += output['test_acc']
            i += 1

    test_loss_mean /= i 
    test_acc_mean /= i
    tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}

    # show test_loss and test_acc in progress bar but only log test_loss
    results = {
        'progress_bar': tqdm_dict,
        'log': {'test_loss': test_loss_mean.item()}
    }
    return results
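The example above pools every batch from every test dataloader into one mean, which mixes the datasets together. To report each test set separately, average within each inner list instead. In the sketch below, per_dataloader_means and the '<key>_<idx>' naming are hypothetical choices.

```python
def per_dataloader_means(outputs, key):
    """Return one mean per test dataloader instead of pooling every batch together."""
    return {
        f"{key}_{dataloader_idx}": sum(o[key] for o in loader_outputs) / len(loader_outputs)
        for dataloader_idx, loader_outputs in enumerate(outputs)
    }


# hypothetical test_step outputs from two test dataloaders
outputs = [
    [{"test_loss": 1.0}, {"test_loss": 3.0}],  # dataloader 0
    [{"test_loss": 10.0}],                     # dataloader 1
]
means = per_dataloader_means(outputs, "test_loss")
print(means)  # {'test_loss_0': 2.0, 'test_loss_1': 10.0}
```

For comparison, the pooled mean of the same outputs is 14.0 / 3 (about 4.67), which describes neither dataset.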

on_save_checkpoint

def on_save_checkpoint(self, checkpoint)

Called by Lightning to checkpoint your model. Lightning saves the training state (current epoch, global_step, etc.) and the model state_dict. If you want to save anything else, use this method to add your own key-value pairs to the checkpoint dict.

Return

Nothing

Example

def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_picklable_object

on_load_checkpoint

def on_load_checkpoint(self, checkpoint)

Called by Lightning to restore your model. Lightning auto-restores the global step, epoch, etc. It also restores the model state_dict. If you saved something with on_save_checkpoint, this is your chance to restore it.

Return

Nothing

Example

def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
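The two hooks form a round trip: whatever on_save_checkpoint adds to the dict comes back in on_load_checkpoint. The sketch below simulates that round trip. It makes several swaps and assumptions:
- it uses pickle and an in-memory buffer instead of Lightning's checkpoint file;
- the 'epoch', 'global_step' and 'state_dict' entries are illustrative stand-ins;
- ToyModel and its vocab attribute are hypothetical.

```python
import io
import pickle


class ToyModel:
    def __init__(self):
        self.vocab = None  # extra state that is not part of the state_dict

    def on_save_checkpoint(self, checkpoint):
        checkpoint["vocab"] = self.vocab

    def on_load_checkpoint(self, checkpoint):
        self.vocab = checkpoint["vocab"]


def save_checkpoint(model, fileobj):
    # stand-ins for the entries Lightning itself stores (epoch, global_step, state_dict, ...)
    checkpoint = {"epoch": 3, "global_step": 1200, "state_dict": {}}
    model.on_save_checkpoint(checkpoint)  # let the model add its own keys
    pickle.dump(checkpoint, fileobj)


def load_checkpoint(model, fileobj):
    checkpoint = pickle.load(fileobj)
    model.on_load_checkpoint(checkpoint)  # let the model restore its own keys
    return checkpoint


original = ToyModel()
original.vocab = {"<pad>": 0, "hello": 1}

buffer = io.BytesIO()  # in-memory file instead of a checkpoint on disk
save_checkpoint(original, buffer)
buffer.seek(0)

restored = ToyModel()
checkpoint = load_checkpoint(restored, buffer)
print(restored.vocab)  # {'<pad>': 0, 'hello': 1}
```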

val_dataloader

@pl.data_loader
def val_dataloader(self)

OPTIONAL
If you don't need a validation dataset and a validation_step, you don't need to implement this method.

Called by Lightning during the validation loop. Use the @pl.data_loader decorator: it ensures this function is not called until the data are actually needed.
If you want to change the data every epoch, DON'T use the data_loader decorator.

Return

PyTorch DataLoader or list of PyTorch Dataloaders.

Example

@pl.data_loader
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.hparams.batch_size,
        shuffle=False  # evaluation does not need shuffled batches
    )

    return loader

# can also return multiple dataloaders   
@pl.data_loader
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]   

If you return multiple val dataloaders, validation_step receives an additional argument, dataloader_idx, that matches the order of the list returned here.


test_dataloader

@pl.data_loader
def test_dataloader(self)

OPTIONAL
If you don't need a test dataset and a test_step, you don't need to implement this method.

Called by Lightning during the test loop. Use the @pl.data_loader decorator: it ensures this function is not called until the data are actually needed. If you want to change the data every epoch, DON'T use the data_loader decorator.

Return

PyTorch DataLoader or list of PyTorch DataLoaders.

Example

@pl.data_loader
def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.hparams.batch_size,
        shuffle=False  # evaluation does not need shuffled batches
    )

    return loader

add_model_specific_args

@staticmethod
def add_model_specific_args(parent_parser, root_dir)

Lightning has a list of default argparse arguments. This method is your chance to add or modify arguments specific to your model. The parsed hyperparameters are then available anywhere in your model through self.hparams.

Return

An argument parser

Example

import os

from test_tube import HyperOptArgumentParser

@staticmethod
def add_model_specific_args(parent_parser, root_dir):
    parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])

    # param overwrites
    # parser.set_defaults(gradient_clip_val=5.0)

    # network params
    parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
    parser.add_argument('--in_features', default=28*28, type=int)
    parser.add_argument('--out_features', default=10, type=int)
    parser.add_argument('--hidden_dim', default=50000, type=int) # use 500 for CPU, 50000 for GPU to see speed difference

    # data
    parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str)

    # training params (opt)
    parser.opt_list('--learning_rate', default=0.001, type=float, options=[0.0001, 0.0005, 0.001, 0.005],
                    tunable=False)
    parser.opt_list('--batch_size', default=256, type=int, options=[32, 64, 128, 256], tunable=False)
    parser.opt_list('--optimizer_name', default='adam', type=str, options=['adam'], tunable=False)
    return parser
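If you don't use test_tube, a similar parent/child pattern works with the standard-library argparse module. This is swapped in here and has no equivalent of opt_list's hyperparameter-search options. The flag names are copied from the example above, while --gpus is a hypothetical trainer-level flag standing in for the parent parser's arguments.

```python
import argparse
import os


def add_model_specific_args(parent_parser, root_dir):
    """Plain-argparse sketch: extend the parent parser with model hyperparameters."""
    # add_help=False because the parent parser already defines -h/--help
    parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)

    # network params
    parser.add_argument("--drop_prob", default=0.2, type=float)
    parser.add_argument("--in_features", default=28 * 28, type=int)
    parser.add_argument("--out_features", default=10, type=int)
    parser.add_argument("--hidden_dim", default=50000, type=int)

    # data
    parser.add_argument("--data_root", default=os.path.join(root_dir, "mnist"), type=str)

    # training params
    parser.add_argument("--learning_rate", default=0.001, type=float)
    parser.add_argument("--batch_size", default=256, type=int)
    parser.add_argument("--optimizer_name", default="adam", type=str, choices=["adam"])
    return parser


parent = argparse.ArgumentParser()
parent.add_argument("--gpus", default=None, type=str)  # hypothetical trainer-level flag
parser = add_model_specific_args(parent, root_dir=os.getcwd())
hparams = parser.parse_args(["--batch_size", "64"])
print(hparams.batch_size, hparams.learning_rate, hparams.in_features)  # 64 0.001 784
```

You would then pass hparams to your model's constructor so the values are reachable through self.hparams.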