
Welcome to ⚡ PyTorch Lightning

[Animation: converting a standard training loop to a Lightning loop]

PyTorch Lightning is the deep learning framework for professional AI researchers and machine learning engineers who need maximal flexibility without sacrificing performance at scale. Lightning evolves with you as your projects go from idea to paper/production.


Install Lightning

Pip users

pip install pytorch-lightning

Conda users

conda install pytorch-lightning -c conda-forge

Or read the advanced install guide

We are fully compatible with any stable PyTorch version v1.10 and above.


Get Started


Current Lightning Users

Lightning in 15 minutes

Required background: None

Goal: In this guide, we’ll walk you through the 7 key steps of a typical Lightning workflow.

PyTorch Lightning is the deep learning framework with “batteries included” for professional AI researchers and machine learning engineers who need maximal flexibility while super-charging performance at scale.

Lightning organizes PyTorch code to remove boilerplate and unlock scalability.


By organizing PyTorch code, Lightning enables:

Full flexibility

Try any ideas using raw PyTorch without the boilerplate.

Reproducible + Readable

Decoupling research and engineering code enables reproducibility and better readability.

Simple multi-GPU training

Use multiple GPUs, TPUs, HPUs, etc. without code changes.

Built-in testing

We've done all the testing so you don't have to.


1: Install PyTorch Lightning

For pip users

pip install pytorch-lightning

For conda users

conda install pytorch-lightning -c conda-forge

Or read the advanced install guide


2: Define a LightningModule

A LightningModule enables your PyTorch nn.Modules to play together in complex ways inside the training_step (there are also optional validation_step and test_step methods).

import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

3: Define a dataset

Lightning supports ANY iterable (DataLoader, numpy, etc…) for the train/val/test/predict splits.

# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

4: Train the model

The Lightning Trainer “mixes” any LightningModule with any dataset and abstracts away all the engineering complexity needed for scale.

# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

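The Lightning Trainer automates 40+ tricks, including epoch and batch iteration, the optimizer.zero_grad(), loss.backward() and optimizer.step() calls, switching between train and eval mode, checkpoint saving and loading, and logging.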


5: Use the model

Once you’ve trained the model, you can export it to ONNX or TorchScript and put it into production, or simply load the weights and run predictions.

# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

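# embed 4 fake images (random pixel values) on the same device as the model
import torch

fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
with torch.no_grad():
    embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)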

6: Visualize training

If you have TensorBoard installed, you can use it to visualize experiments.

Run this on your command line and open your browser to http://localhost:6006/

tensorboard --logdir .

7: Supercharge training

Enable advanced training features using Trainer arguments. These are state-of-the-art techniques that are automatically integrated into your training loop without changes to your code.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

# train on 4 GPUs
trainer = Trainer(
    devices=4,
    accelerator="gpu",
)

# train 1T+ parameter models with DeepSpeed/FSDP
trainer = Trainer(
    devices=4,
    accelerator="gpu",
    strategy="deepspeed_stage_2",
    precision=16,
)

# 20+ helpful flags for rapid idea iteration
trainer = Trainer(
    max_epochs=10,
    min_epochs=5,
    overfit_batches=1,
)

# access the latest state of the art techniques
trainer = Trainer(callbacks=[StochasticWeightAveraging(...)])

Maximize flexibility

Lightning’s core guiding principle is to always provide maximal flexibility without ever hiding any of PyTorch.

Lightning offers 5 added degrees of flexibility depending on your project’s complexity.


Customize training loop

Inject custom code anywhere in the training loop using any of the 20+ methods (hooks) available in the LightningModule.

class LitAutoEncoder(pl.LightningModule):
    def backward(self, loss, optimizer, optimizer_idx):
        loss.backward()

Extend the Trainer

If you have multiple lines of code with similar functionalities, you can use callbacks to easily group them together and toggle all of those lines on or off at the same time.

# AWSCheckpoints is a placeholder for your own Callback subclass
trainer = Trainer(callbacks=[AWSCheckpoints()])

Use a raw PyTorch loop

For certain types of work at the bleeding-edge of research, Lightning offers experts full control of their training loops in various ways.


Next steps

Depending on your use case, you might want to check one of these out next.

Installation

Warning

pip install pytorch-lightning has been deprecated and will stop being updated in June 2023. Use pip install lightning instead.


Apple Silicon (M1/M2/M3) Macs

Until all ML-related Python packages are updated to work with Apple Silicon, you’ll need to set two environment variables when installing.

# needed for M1/M2/M3
export GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1
export GRPC_PYTHON_BUILD_SYSTEM_ZLIB=1

python -m pip install -U lightning

Install with pip

Install Lightning inside a virtual env or conda environment with pip:

python -m pip install lightning

Install with Conda

If you don’t have conda installed, follow the Conda Installation Guide. Lightning can be installed with conda using the following command:

conda install pytorch-lightning -c conda-forge

You can also use Conda Environments:

conda activate my_env
conda install pytorch-lightning -c conda-forge

Build from Source

Install the nightly build from source. Note that it contains all the bug fixes and new features that have not been published yet. This is the bleeding edge, so use it at your own discretion.

pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U

Install future patch releases from source. Note that patch releases contain only bug fixes for the most recent major release.

pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/release/stable.zip -U

Optimized for model development

If you are deploying models built with Lightning in production and require fewer dependencies, try using the optimized lightning[pytorch] package:

pip install pytorch-lightning

Custom PyTorch Version

To use a specific PyTorch version, visit the PyTorch Installation Page.


Optimized for ML workflows (Lightning Apps)

If you are deploying workflows built with Lightning in production and require fewer dependencies, try using the optimized lightning[apps] package:

pip install lightning-app

Basic skills

Learn the basics of model development with Lightning. Researchers and machine learning engineers should start here.


Intermediate skills

Learn to scale up your models and enable collaborative model development at academic or industry research labs.


Advanced skills

Configure all aspects of Lightning for advanced use cases.


Expert skills

Customize and extend Lightning for things like custom hardware or distributed strategies.


LightningModule

A LightningModule organizes your PyTorch code into 6 sections:

  • Computations (init).

  • Train Loop (training_step)

  • Validation Loop (validation_step)

  • Test Loop (test_step)

  • Prediction Loop (predict_step)

  • Optimizers and LR Schedulers (configure_optimizers)



Notice a few things.

  1. It is the SAME code.

  2. The PyTorch code IS NOT abstracted - just organized.

  3. All the other code that’s not in the LightningModule has been automated for you by the Trainer.


net = Net()
trainer = Trainer()
trainer.fit(net)

  4. There are no .cuda() or .to(device) calls required. Lightning does these for you.


# don't do in Lightning
x = torch.Tensor(2, 3)
x = x.cuda()
x = x.to(device)

# do this instead
x = x  # leave it alone!

# or to init a new tensor
new_x = torch.Tensor(2, 3)
new_x = new_x.to(x)

  5. When running under a distributed strategy, Lightning handles the distributed sampler for you by default.


# Don't do in Lightning...
data = MNIST(...)
sampler = DistributedSampler(data)
DataLoader(data, sampler=sampler)

# do this instead
data = MNIST(...)
DataLoader(data)

  6. A LightningModule is a torch.nn.Module but with added functionality. Use it as such!


net = Net.load_from_checkpoint(PATH)
net.freeze()
out = net(x)

Thus, to use Lightning, you just need to organize your code, which takes about 30 minutes (and, let’s be real, you probably should do that anyway).


Starter Example

Here are the only required methods.

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

Which you can train by doing:

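import os
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))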
trainer = pl.Trainer(max_epochs=1)
model = LitModel()

trainer.fit(model, train_dataloaders=train_loader)

The LightningModule has many convenience methods, but the core ones you need to know about are:

Name                   Description
__init__               Define computations here
forward                Use for inference only (separate from training_step)
training_step          The complete training loop
validation_step        The complete validation loop
test_step              The complete test loop
predict_step           The complete prediction loop
configure_optimizers   Define optimizers and LR schedulers


Training

Training Loop

To activate the training loop, override the training_step() method.

class LitClassifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

Under the hood, Lightning does the following (pseudocode):

# put model in train mode and enable gradient calculation
model.train()
torch.set_grad_enabled(True)

outs = []
for batch_idx, batch in enumerate(train_dataloader):
    loss = training_step(batch, batch_idx)
    outs.append(loss.detach())

    # clear gradients
    optimizer.zero_grad()

    # backward
    loss.backward()

    # update parameters
    optimizer.step()
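
To see the zero_grad, backward, step sequence from the pseudocode above actually converge, here is a toy, dependency-free sketch that fits y = 2x with a single scalar weight. Param, SGD and SquaredErrorLoss are hypothetical stand-ins for PyTorch tensors, optimizers and losses, not Lightning or PyTorch classes; only the loop structure mirrors the pseudocode.

```python
class Param:
    """Hypothetical stand-in for a tensor that requires grad."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


class SGD:
    """Hypothetical stand-in for torch.optim.SGD on scalar parameters."""

    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad


class SquaredErrorLoss:
    """Loss (w * x - y) ** 2 that can backpropagate into w."""

    def __init__(self, w, x, y):
        self.w, self.x, self.y = w, x, y
        self.value = (w.value * x - y) ** 2

    def backward(self):
        # d/dw (w*x - y)^2 = 2 * (w*x - y) * x, accumulated like autograd does
        self.w.grad += 2 * (self.w.value * self.x - self.y) * self.x


def training_step(w, batch, batch_idx):
    x, y = batch
    return SquaredErrorLoss(w, x, y)


w = Param(0.0)
optimizer = SGD([w], lr=0.05)
train_dataloader = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # samples from y = 2x

for epoch in range(10):
    outs = []
    for batch_idx, batch in enumerate(train_dataloader):
        loss = training_step(w, batch, batch_idx)
        outs.append(loss.value)  # plays the role of loss.detach()
        optimizer.zero_grad()  # clear gradients
        loss.backward()  # backward
        optimizer.step()  # update parameters

print(round(w.value, 4))
```

The weight ends up at (approximately) 2.0, the slope of the data, which is the same outcome the real loop produces with tensors and autograd.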

Train Epoch-level Metrics

If you want to calculate epoch-level metrics and log them, use log().

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss

The log() method automatically reduces the requested metrics across a complete epoch and across devices. Here’s the pseudocode of what it does under the hood:

outs = []
for batch_idx, batch in enumerate(train_dataloader):
    # forward
    loss = training_step(batch, batch_idx)
    outs.append(loss)

    # clear gradients
    optimizer.zero_grad()

    # backward
    loss.backward()

    # update parameters
    optimizer.step()

epoch_metric = torch.mean(torch.stack([x for x in outs]))
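
A minimal sketch of what on_step and on_epoch mean for a single metric, using a hypothetical EpochMetricLogger class rather than Lightning's logger. It reduces with a plain mean, like the pseudocode above; Lightning's real implementation also handles details this sketch ignores, such as batch sizes and syncing across devices.

```python
from collections import defaultdict


class EpochMetricLogger:
    """Hypothetical sketch of on_step / on_epoch logging; not Lightning's implementation."""

    def __init__(self):
        self.step_logs = defaultdict(list)  # values emitted at every step
        self._epoch_values = defaultdict(list)  # values buffered until epoch end

    def log(self, name, value, on_step=True, on_epoch=False):
        if on_step:
            self.step_logs[name].append(value)
        if on_epoch:
            self._epoch_values[name].append(value)

    def epoch_end(self):
        # plain mean over the epoch, like torch.mean(torch.stack(outs)) above
        means = {name: sum(vals) / len(vals) for name, vals in self._epoch_values.items()}
        self._epoch_values.clear()
        return means


logger = EpochMetricLogger()
for loss in [1.0, 2.0, 3.0]:
    logger.log("train_loss", loss, on_step=True, on_epoch=True)
epoch_metrics = logger.epoch_end()
print(dict(logger.step_logs), epoch_metrics)  # {'train_loss': [1.0, 2.0, 3.0]} {'train_loss': 2.0}
```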

Train Epoch-level Operations

If you need to do something with all the outputs of each training_step(), override the training_epoch_end() method.

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    return {"loss": loss, "other_stuff": preds}


def training_epoch_end(self, training_step_outputs):
    # each output is the dict returned by training_step
    all_preds = torch.stack([out["other_stuff"] for out in training_step_outputs])
    ...

The matching pseudocode is:

outs = []
for batch_idx, batch in enumerate(train_dataloader):
    # forward
    loss = training_step(batch, batch_idx)
    outs.append(loss)

    # clear gradients
    optimizer.zero_grad()

    # backward
    loss.backward()

    # update parameters
    optimizer.step()

training_epoch_end(outs)

Training with DataParallel

When training with a strategy that splits each batch across GPUs, such as DataParallel (DP), you might sometimes need to aggregate the per-GPU outputs on the main GPU for processing.

In this case, implement the training_step_end() method. It receives the outputs from all devices, so you can aggregate them to get the effective result.

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {"loss": loss, "pred": pred}


def training_step_end(self, batch_parts):
    # predictions from each GPU
    predictions = batch_parts["pred"]
    # losses from each GPU
    losses = batch_parts["loss"]

    gpu_0_prediction = predictions[0]
    gpu_1_prediction = predictions[1]

    # do something with both outputs
    return (losses[0] + losses[1]) / 2


def training_epoch_end(self, training_step_outputs):
    for out in training_step_outputs:
        ...

Here is the Lightning training pseudo-code for DP:

outs = []
for batch_idx, train_batch in enumerate(train_dataloader):
    batches = split_batch(train_batch)
    dp_outs = []
    for sub_batch in batches:
        # 1
        dp_out = training_step(sub_batch, batch_idx)
        dp_outs.append(dp_out)

    # 2
    out = training_step_end(dp_outs)
    outs.append(out)

# do something with the outputs for all batches
# 3
training_epoch_end(outs)
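
The numbered steps of the DP pseudocode above (1: training_step per sub-batch, 2: training_step_end on the collected outputs, 3: training_epoch_end over all batches) can be run end-to-end with plain Python numbers. The split_batch signature (with an explicit num_devices argument) and the toy loss and prediction definitions are hypothetical; real DP splits tensors across GPUs.

```python
def split_batch(batch, num_devices):
    """Split a list batch into num_devices contiguous chunks (hypothetical helper)."""
    chunk = -(-len(batch) // num_devices)  # ceiling division
    return [batch[i:i + chunk] for i in range(0, len(batch), chunk)]


def training_step(sub_batch, batch_idx):
    # toy "loss": mean of the sub-batch values; toy "pred": doubled values
    loss = sum(sub_batch) / len(sub_batch)
    return {"loss": loss, "pred": [2 * v for v in sub_batch]}


def training_step_end(dp_outs):
    # combine the per-device outputs on the "main device"
    losses = [out["loss"] for out in dp_outs]
    preds = [p for out in dp_outs for p in out["pred"]]
    return {"loss": sum(losses) / len(losses), "pred": preds}


def training_epoch_end(outs):
    return sum(out["loss"] for out in outs) / len(outs)


train_dataloader = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
outs = []
for batch_idx, train_batch in enumerate(train_dataloader):
    # 1: one training_step call per device
    dp_outs = [training_step(sb, batch_idx) for sb in split_batch(train_batch, num_devices=2)]
    # 2: aggregate across devices
    outs.append(training_step_end(dp_outs))

# 3: epoch-level processing over all batches
print(training_epoch_end(outs))  # 4.5
```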

Validation

Validation Loop

To activate the validation loop while training, override the validation_step() method.

class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss)

Under the hood, Lightning does the following (pseudocode):

# ...
for batch_idx, batch in enumerate(train_dataloader):
    loss = model.training_step(batch, batch_idx)
    loss.backward()
    # ...

    if validate_at_some_point:
        # disable grads + batchnorm + dropout
        torch.set_grad_enabled(False)
        model.eval()

        # ----------------- VAL LOOP ---------------
        for val_batch_idx, val_batch in enumerate(val_dataloader):
            val_out = model.validation_step(val_batch, val_batch_idx)
        # ----------------- VAL LOOP ---------------

        # enable grads + batchnorm + dropout
        torch.set_grad_enabled(True)
        model.train()

You can also run just the validation loop on your validation dataloaders by overriding validation_step() and calling validate().

model = Model()
trainer = Trainer()
trainer.validate(model)

Note

It is recommended to validate on a single device to ensure each sample/batch gets evaluated exactly once. This helps ensure that benchmarking for research papers is done correctly. Otherwise, in a multi-device setting, samples could be duplicated when DistributedSampler is used, e.g., with strategy="ddp": the sampler replicates some samples on some devices so that all devices have the same batch size when the inputs are uneven.
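
To see where the duplicates come from, here is a dependency-free sketch of how DistributedSampler assigns indices when shuffling is off and drop_last=False, based on our reading of PyTorch's sampler. distributed_indices is a hypothetical helper name, not a PyTorch API.

```python
import math


def distributed_indices(dataset_len, num_replicas, rank):
    """Sketch of DistributedSampler(shuffle=False, drop_last=False) index assignment."""
    num_samples = math.ceil(dataset_len / num_replicas)  # samples per rank
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # pad by repeating indices from the start so the total divides evenly
    padding = total_size - len(indices)
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # each rank takes every num_replicas-th index, starting at its own rank
    return indices[rank:total_size:num_replicas]


per_rank = [distributed_indices(10, num_replicas=4, rank=r) for r in range(4)]
print(per_rank)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
```

With 10 samples on 4 devices, each device gets 3 samples, so samples 0 and 1 are evaluated twice and would be counted twice in any metric averaged across devices.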

Validation Epoch-level Metrics

If you need to do something with all the outputs of each validation_step(), override the validation_epoch_end() method. Note that this method is called before training_epoch_end().

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return pred


def validation_epoch_end(self, validation_step_outputs):
    all_preds = torch.stack(validation_step_outputs)
    ...
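
The ordering described above, where validation_epoch_end() runs before training_epoch_end(), can be sketched in plain Python for an epoch that validates once at its end. The hook functions here are hypothetical stand-ins that only record their calls; the real Trainer can also validate at other points, which this sketch does not model.

```python
calls = []


def training_step(batch, batch_idx):
    calls.append("training_step")
    return {"loss": float(batch)}


def validation_step(batch, batch_idx):
    calls.append("validation_step")
    return float(batch)


def validation_epoch_end(outputs):
    calls.append("validation_epoch_end")


def training_epoch_end(outputs):
    calls.append("training_epoch_end")


def run_epoch(train_dataloader, val_dataloader):
    """Hypothetical sketch of one epoch that validates once, at the end."""
    train_outs = [training_step(b, i) for i, b in enumerate(train_dataloader)]
    # validation runs at the end of the training epoch ...
    val_outs = [validation_step(b, i) for i, b in enumerate(val_dataloader)]
    validation_epoch_end(val_outs)
    # ... and only afterwards are the collected training outputs handed over
    training_epoch_end(train_outs)


run_epoch(train_dataloader=[1, 2], val_dataloader=[3])
print(calls)
```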

Validating with DataParallel

When validating with a strategy that splits each batch across GPUs, such as DataParallel (DP), you might sometimes need to aggregate the per-GPU outputs on the main GPU for processing.

In this case, implement the validation_step_end() method. It receives the outputs from all devices, so you can aggregate them to get the effective result.

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {"loss": loss, "pred": pred}


def validation_step_end(self, batch_parts):
    # predictions from each GPU
    predictions = batch_parts["pred"]
    # losses from each GPU
    losses = batch_parts["loss"]

    gpu_0_prediction = predictions[0]
    gpu_1_prediction = predictions[1]

    # do something with both outputs
    return (losses[0] + losses[1]) / 2


def validation_epoch_end(self, validation_step_outputs):
    for out in validation_step_outputs:
        ...

Here is the Lightning validation pseudo-code for DP:

outs = []
for batch_idx, batch in enumerate(val_dataloader):
    batches = split_batch(batch)
    dp_outs = []
    for sub_batch in batches:
        # 1
        dp_out = validation_step(sub_batch, batch_idx)
        dp_outs.append(dp_out)

    # 2
    out = validation_step_end(dp_outs)
    outs.append(out)

# do something with the outputs for all batches
# 3
validation_epoch_end(outs)

Testing

Test Loop

The process for enabling a test loop is the same as for enabling a validation loop (refer to the section above for details): override the test_step() method.

The only difference is that the test loop is only called when test() is used.

model = Model()
trainer = Trainer()
trainer.fit(model)

# load the best checkpoint from the fit call above before testing
trainer.test(model, ckpt_path="best")

There are two ways to call test():

# call after training
trainer = Trainer()
trainer.fit(model)

# auto-loads the best weights from the previous run
trainer.test(dataloaders=test_dataloader)

# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model, dataloaders=test_dataloader)

Note

It is recommended to test on a single device to ensure each sample/batch gets evaluated exactly once. This helps ensure that benchmarking for research papers is done correctly. Otherwise, in a multi-device setting, samples could be duplicated when DistributedSampler is used, e.g., with strategy="ddp": the sampler replicates some samples on some devices so that all devices have the same batch size when the inputs are uneven.


Inference

Prediction Loop

By default, the predict_step() method runs the forward() method. In order to customize this behaviour, simply override the predict_step() method.

For example, let’s override predict_step to try out Monte Carlo Dropout:

class LitMCdropoutModel(pl.LightningModule):
    def __init__(self, model, mc_iteration):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout()
        self.mc_iteration = mc_iteration

    def predict_step(self, batch, batch_idx):
        # assumes each batch is an (input, target) pair
        x, _ = batch

        # enable Monte Carlo Dropout
        self.dropout.train()

        # take average of `self.mc_iteration` iterations
        pred = torch.vstack([self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]).mean(dim=0)
        return pred

Under the hood, Lightning does the following (pseudocode):

# disable grads + batchnorm + dropout
torch.set_grad_enabled(False)
model.eval()
all_preds = []

for batch_idx, batch in enumerate(predict_dataloader):
    pred = model.predict_step(batch, batch_idx)
    all_preds.append(pred)
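
The Monte Carlo Dropout idea from the predict_step example can be sketched without PyTorch: run a stochastic forward pass several times with dropout active and average the results element-wise. model_fn, dropout and mc_dropout_predict are hypothetical helpers. Like nn.Dropout, the sketch uses inverted dropout (survivors scaled by 1 / (1 - p)), so the average approaches the dropout-free output as mc_iteration grows.

```python
import random


def dropout(values, p, rng):
    """Inverted dropout: zero each value with prob p, scale survivors by 1 / (1 - p)."""
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in values]


def mc_dropout_predict(forward_fn, x, p, mc_iteration, seed=0):
    rng = random.Random(seed)
    # run the stochastic forward pass mc_iteration times ...
    samples = [dropout(forward_fn(x), p, rng) for _ in range(mc_iteration)]
    # ... and average element-wise, like .mean(dim=0) in predict_step
    return [sum(col) / mc_iteration for col in zip(*samples)]


def model_fn(x):
    """Hypothetical deterministic 'model'."""
    return [2.0 * v for v in x]


print(mc_dropout_predict(model_fn, [1.0, 2.0], p=0.5, mc_iteration=1000))
```

The spread of the individual samples (not computed here) is what MC Dropout typically uses as an uncertainty estimate.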

There are two ways to call predict():

# call after training
trainer = Trainer()
trainer.fit(model)

# auto-loads the best weights from the previous run
predictions = trainer.predict(dataloaders=predict_dataloader)

# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
predictions = trainer.predict(model, dataloaders=predict_dataloader)

Inference in Research

If you want to perform inference with the system, you can add a forward method to the LightningModule.

Note

When using forward, you are responsible for calling eval() and using the no_grad() context manager.

class Autoencoder(pl.LightningModule):
    def forward(self, x):
        return self.decoder(x)


model = Autoencoder()
model.eval()
with torch.no_grad():
    reconstruction = model(embedding)

The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure, such as text generation:

class Seq2Seq(pl.LightningModule):
    def forward(self, x):
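        # self.embedding is a hypothetical embedding layer; calling self(x) here would recurse into forward
        embeddings = self.embedding(x)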
        hidden_states = self.encoder(embeddings)
        for h in hidden_states:
            # decode
            ...
        return decoded

In the case where you want to scale your inference, you should be using predict_step().

class Autoencoder(pl.LightningModule):
    def forward(self, x):
        return self.decoder(x)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # this calls forward
        return self(batch)


data_module = ...
model = Autoencoder()
trainer = Trainer(accelerator="gpu", devices=2)
trainer.predict(model, data_module)

Inference in Production

For cases like production, you might want to try different models inside the same LightningModule.

from torchmetrics.functional import accuracy


class ClassificationTask(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"val_acc": acc, "val_loss": loss}
        self.log_dict(metrics)
        return metrics

    def test_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"test_acc": acc, "test_loss": loss}
        self.log_dict(metrics)
        return metrics

    def _shared_eval_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)
        return loss, acc

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        y_hat = self.model(x)
        return y_hat

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.02)

Then pass in any arbitrary model to be fit with this task:

for model in [resnet50(), vgg16(), BidirectionalRNN()]:
    task = ClassificationTask(model)

    trainer = Trainer(accelerator="gpu", devices=2)
    trainer.fit(task, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

Tasks can be arbitrarily complex, such as GAN training, self-supervised learning, or even RL.

class GANTask(pl.LightningModule):
    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

    ...

When used like this, the model can be separated from the Task and thus used in production without needing to keep it in a LightningModule.

The following example shows how you can run inference in the Python runtime:

task = ClassificationTask(model)
trainer = Trainer(accelerator="gpu", devices=2)
trainer.fit(task, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt")

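# use model after training or load weights and drop into the production system
# `model` was not saved as a hyperparameter, so it must be passed in again
task = ClassificationTask.load_from_checkpoint("best_model.ckpt", model=model)
model = task.model
x = ...
model.eval()
with torch.no_grad():
    y_hat = model(x)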

Check out Inference in Production guide to learn about the possible ways to perform inference in production.


Save Hyperparameters

Often we train many versions of a model. You might share that model or come back to it a few months later, at which point it is very useful to know how that model was trained (e.g., what learning rate or neural network was used).

Lightning has a standardized way of saving the information for you in checkpoints and YAML files. The goal here is to improve readability and reproducibility.

save_hyperparameters

Use save_hyperparameters() within your LightningModule’s __init__ method. It will enable Lightning to store all the provided arguments under the self.hparams attribute. These hyperparameters will also be stored within the model checkpoint, which simplifies model re-instantiation after training.

class LitMNIST(LightningModule):
    def __init__(self, layer_1_dim=128, learning_rate=1e-2):
        super().__init__()
        # call this to save (layer_1_dim=128, learning_rate=1e-2) to the checkpoint
        self.save_hyperparameters()

        # equivalent
        self.save_hyperparameters("layer_1_dim", "learning_rate")

        # Now possible to access layer_1_dim from hparams
        self.hparams.layer_1_dim

In addition, loggers that support it will automatically log the contents of self.hparams.

Excluding hyperparameters

By default, every parameter of the __init__ method is considered a hyperparameter of the LightningModule. However, some parameters sometimes need to be excluded from saving, for example when they are not serializable. In this case, exclude them explicitly; they must then be provided again when reloading the LightningModule:

class LitMNIST(LightningModule):
    def __init__(self, loss_fx, generator_network, layer_1_dim=128):
        super().__init__()
        self.layer_1_dim = layer_1_dim
        self.loss_fx = loss_fx

        # call this to save only (layer_1_dim=128) to the checkpoint
        self.save_hyperparameters("layer_1_dim")

        # equivalent
        self.save_hyperparameters(ignore=["loss_fx", "generator_network"])
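
To illustrate how save_hyperparameters() can find the __init__ arguments without you passing them, here is a simplified sketch that reads them from the caller's frame with Python's inspect module. HParamsSketch and LitMNISTSketch are hypothetical; Lightning's real method does more (for example, its self.hparams supports attribute access such as self.hparams.layer_1_dim, and the values are also written to checkpoints).

```python
import inspect


class HParamsSketch:
    """Hypothetical, simplified stand-in for LightningModule.save_hyperparameters()."""

    def save_hyperparameters(self, *names, ignore=()):
        # look at the caller's frame, i.e. the __init__ that called us
        frame = inspect.currentframe().f_back
        try:
            arg_info = inspect.getargvalues(frame)
            init_args = {n: arg_info.locals[n] for n in arg_info.args if n != "self"}
        finally:
            del frame  # avoid keeping a reference cycle to the frame
        if names:  # keep only explicitly requested arguments
            init_args = {n: init_args[n] for n in names}
        self.hparams = {k: v for k, v in init_args.items() if k not in ignore}


class LitMNISTSketch(HParamsSketch):
    def __init__(self, loss_fx, generator_network, layer_1_dim=128):
        self.loss_fx = loss_fx
        self.save_hyperparameters(ignore=["loss_fx", "generator_network"])


model = LitMNISTSketch(loss_fx=abs, generator_network=object())
print(model.hparams)  # {'layer_1_dim': 128}
```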

load_from_checkpoint

LightningModules that have hyperparameters automatically saved with save_hyperparameters() can conveniently be loaded and instantiated directly from a checkpoint with load_from_checkpoint():

# hyperparameters saved with save_hyperparameters() are restored automatically
model = LitMNIST.load_from_checkpoint(PATH)

If parameters were excluded, they need to be provided at the time of loading:

# the excluded parameters were `loss_fx` and `generator_network`
model = LitMNIST.load_from_checkpoint(PATH, loss_fx=torch.nn.SomeOtherLoss, generator_network=MyGenerator())

Child Modules

Research projects tend to test different approaches to the same dataset. This is very easy to do in Lightning with inheritance.

For example, imagine we now want to train an AutoEncoder to use as a feature extractor for images. The only things that change in the LitAutoEncoder model are the __init__, forward, training_step, validation_step and test_step methods.

class Encoder(torch.nn.Module):
    ...


class Decoder(torch.nn.Module):
    ...


class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.decoder(self.encoder(x))


class LitAutoEncoder(LightningModule):
    def __init__(self, auto_encoder):
        super().__init__()
        self.auto_encoder = auto_encoder
        self.metric = torch.nn.MSELoss()

    def forward(self, x):
        return self.auto_encoder.encoder(x)

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x_hat = self.auto_encoder(x)
        loss = self.metric(x, x_hat)
        return loss

    def validation_step(self, batch, batch_idx):
        self._shared_eval(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        self._shared_eval(batch, batch_idx, "test")

    def _shared_eval(self, batch, batch_idx, prefix):
        x, _ = batch
        x_hat = self.auto_encoder(x)
        loss = self.metric(x, x_hat)
        self.log(f"{prefix}_loss", loss)

and we can train this using the Trainer:

auto_encoder = AutoEncoder()
lightning_module = LitAutoEncoder(auto_encoder)
trainer = Trainer()
trainer.fit(lightning_module, train_dataloader, val_dataloader)

And remember that the forward method should define the practical use of a LightningModule. In this case, we want to use the LitAutoEncoder to extract image representations:

some_images = torch.Tensor(32, 1, 28, 28)
representations = lightning_module(some_images)

LightningModule API

Methods
all_gather
LightningModule.all_gather(data, group=None, sync_grads=False)

Allows users to call self.all_gather() from the LightningModule, thus making the all_gather operation accelerator agnostic. all_gather is a function provided by accelerators to gather a tensor from several distributed processes.

Parameters
  • data (Union[Tensor, Dict, List, Tuple]) – int, float, tensor of shape (batch, …), or a (possibly nested) collection thereof.

  • group (Optional[Any]) – the process group to gather results from. Defaults to all processes (world)

  • sync_grads (bool) – flag that allows users to synchronize gradients for the all_gather operation

Return type

Union[Tensor, Dict, List, Tuple]

Returns

A tensor of shape (world_size, batch, …), or if the input was a collection the output will also be a collection with tensors of this shape.
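
A plain-Python sketch of the output structure described above: every leaf gains a leading world_size dimension, while collections keep their shape. Here Python lists stand in for tensors (the real method also recurses into lists), and gather_collection is a hypothetical helper that receives every rank's input at once instead of communicating between processes.

```python
def gather_collection(per_rank_data):
    """Hypothetical sketch of all_gather's output structure.

    per_rank_data[r] is what rank r passed in. Dicts and tuples are treated as
    collections and gathered element by element; anything else is a leaf
    ("tensor") that gains a new leading world_size dimension.
    """
    first = per_rank_data[0]
    if isinstance(first, dict):
        return {k: gather_collection([d[k] for d in per_rank_data]) for k in first}
    if isinstance(first, tuple):
        return tuple(gather_collection(list(items)) for items in zip(*per_rank_data))
    return list(per_rank_data)  # leaf: shape (world_size, ...)


rank0 = {"preds": [0.1, 0.9], "count": 2}
rank1 = {"preds": [0.7, 0.3], "count": 2}
print(gather_collection([rank0, rank1]))  # {'preds': [[0.1, 0.9], [0.7, 0.3]], 'count': [2, 2]}
```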

configure_callbacks
LightningModule.configure_callbacks()

Configure model-specific callbacks. When the model gets attached, e.g., when .fit() or .test() gets called, the list or a callback returned here will be merged with the list of callbacks passed to the Trainer’s callbacks argument. If a callback returned here has the same type as one or several callbacks already present in the Trainer’s callbacks list, it will take priority and replace them. In addition, Lightning will make sure ModelCheckpoint callbacks run last.

Return type

Union[Sequence[Callback], Callback]

Returns

A callback or a list of callbacks which will extend the list of callbacks in the Trainer.

Example:

def configure_callbacks(self):
    early_stop = EarlyStopping(monitor="val_acc", mode="max")
    checkpoint = ModelCheckpoint(monitor="val_loss")
    return [early_stop, checkpoint]
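
The merge rules described above (a model callback replaces Trainer callbacks of the same type, and ModelCheckpoint callbacks run last) can be sketched in plain Python. The callback classes below are empty stand-ins that only share names with Lightning's callbacks, and merge_callbacks is a hypothetical helper, not the Trainer's actual code.

```python
class Callback:
    """Stand-in base class; not Lightning's Callback."""

    def __init__(self, source):
        self.source = source  # who configured it: "trainer" or "model"

    def __repr__(self):
        return f"{type(self).__name__}({self.source!r})"


class EarlyStopping(Callback):
    pass


class ModelCheckpoint(Callback):
    pass


class LearningRateMonitor(Callback):
    pass


def merge_callbacks(trainer_callbacks, model_callbacks):
    """Hypothetical sketch of merging configure_callbacks() into the Trainer's list."""
    if isinstance(model_callbacks, Callback):  # a single callback is allowed
        model_callbacks = [model_callbacks]
    model_types = {type(cb) for cb in model_callbacks}
    # model callbacks replace Trainer callbacks of the same type
    merged = [cb for cb in trainer_callbacks if type(cb) not in model_types]
    merged += model_callbacks
    # ModelCheckpoint callbacks are moved to the end, keeping their relative order
    others = [cb for cb in merged if not isinstance(cb, ModelCheckpoint)]
    checkpoints = [cb for cb in merged if isinstance(cb, ModelCheckpoint)]
    return others + checkpoints


trainer_callbacks = [EarlyStopping("trainer"), ModelCheckpoint("trainer"), LearningRateMonitor("trainer")]
merged = merge_callbacks(trainer_callbacks, [EarlyStopping("model")])
print(merged)  # [LearningRateMonitor('trainer'), EarlyStopping('model'), ModelCheckpoint('trainer')]
```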

configure_optimizers
LightningModule.configure_optimizers()

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you need one, but in the case of GANs or similar models you might have multiple.

Return type

Any

Returns

Any of these 6 options.

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • Tuple of dictionaries as described above, with an optional "frequency" key.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
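
The reason for this requirement is visible in plain PyTorch: ReduceLROnPlateau.step() takes the monitored value as an argument, so whoever calls it has to know which metric to pass. The loop below feeds hand-picked validation losses (illustrative numbers, not results) to the scheduler directly, outside of Lightning.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# patience=0: halve the LR as soon as the metric fails to improve once
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=0)

# illustrative validation losses; step() needs the monitored value itself
for val_loss in [1.0, 0.8, 0.9, 0.95]:
    scheduler.step(val_loss)

print(optimizer.param_groups[0]["lr"])  # 0.025
```

The two non-improving values (0.9 and 0.95) each halve the learning rate. In a LightningModule, the "monitor" key is what tells Lightning which logged value to hand to step().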

from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau


# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            # "frequency" is an int: how often the scheduler is updated.
            # If "monitor" references validation metrics, then "frequency"
            # should be set to a multiple of "trainer.check_val_every_n_epoch".
            "frequency": 1,
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

A metric becomes available to monitor once you log it with self.log('metric_to_track', metric_val) in your LightningModule.

Note

The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:

  • In the former case, all optimizers will operate on the given batch in each optimization step.

  • In the latter, only one optimizer will operate on the given batch at every step.

This is different from the frequency value specified in the lr_scheduler_config mentioned above.

import torch


def configure_optimizers(self):
    optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
    optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
    return [
        {"optimizer": optimizer_one, "frequency": 5},
        {"optimizer": optimizer_two, "frequency": 10},
    ]

In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.
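
The cycling described above can be expressed as a small pure-Python function. active_optimizer_index is a hypothetical helper for illustration, assuming the cycle is counted in training batches as described; it is not Lightning's internal code.

```python
from itertools import accumulate


def active_optimizer_index(batch_idx, frequencies):
    """Return which optimizer handles ``batch_idx`` under the cycling rule.

    Hypothetical helper for illustration; not Lightning's internal code.
    """
    if not frequencies or any(f <= 0 for f in frequencies):
        raise ValueError("frequencies must be positive integers")
    position = batch_idx % sum(frequencies)
    # cumulative upper bounds, e.g. [5, 10] -> [5, 15]
    for idx, upper in enumerate(accumulate(frequencies)):
        if position < upper:
            return idx
    raise AssertionError("unreachable for positive frequencies")


schedule = [active_optimizer_index(i, [5, 10]) for i in range(16)]
print(schedule)
# [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
```

With frequencies [5, 1], as in the WGAN example further below, the critic optimizer handles five batches for every generator batch.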

Examples:

from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR


# most cases: no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step'  # called after each training step
    }
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1}
    )

Note

Some things to know:

  • Lightning calls .backward() and .step() on each optimizer as needed.

  • If a learning rate scheduler is specified in configure_optimizers() with the key "interval" (default “epoch”) in its scheduler configuration, Lightning calls the scheduler’s .step() method automatically when automatic optimization is used.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.

  • If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, gradients will be calculated only for the parameters of the current optimizer at each training step.

  • If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.
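
To see how "interval" and "frequency" from the lr_scheduler_config combine, the sketch below lists when scheduler.step() would be called under the assumptions stated in its docstring. scheduler_step_points is a hypothetical helper that models only those configuration keys, not Lightning's training loop.

```python
def scheduler_step_points(interval, frequency, num_epochs, steps_per_epoch):
    """List the (epoch, batch_idx) points after which scheduler.step() runs.

    Hypothetical sketch of the "interval"/"frequency" keys, assuming one
    optimizer update per batch and a step counter that keeps growing
    across epochs. Lightning's own bookkeeping may differ in details.
    """
    if interval not in ("epoch", "step"):
        raise ValueError(f"unknown interval: {interval!r}")
    points = []
    global_step = 0
    for epoch in range(num_epochs):
        for batch_idx in range(steps_per_epoch):
            global_step += 1  # one optimizer update per batch (assumption)
            if interval == "step" and global_step % frequency == 0:
                points.append((epoch, batch_idx))
        if interval == "epoch" and (epoch + 1) % frequency == 0:
            points.append((epoch, steps_per_epoch - 1))
    return points


print(scheduler_step_points("epoch", 2, num_epochs=4, steps_per_epoch=3))
# [(1, 2), (3, 2)]
print(scheduler_step_points("step", 4, num_epochs=2, steps_per_epoch=3))
# [(1, 0)]
```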

forward
LightningModule.forward(*args, **kwargs)

Same as torch.nn.Module.forward().

Parameters
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Return type

Any

Returns

Your model’s output

freeze
LightningModule.freeze()

Freeze all params for inference.

Example:

model = MyLightningModule(...)
model.freeze()

Return type

None
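
In plain PyTorch terms, freezing a module for inference typically means turning off requires_grad on every parameter and switching to eval mode, and unfreezing reverses both. The helpers below are a hypothetical sketch of that idea, assuming freeze() and unfreeze() behave along these lines; they are not a copy of Lightning's implementation.

```python
import torch


def freeze(module: torch.nn.Module) -> None:
    """Hypothetical sketch: disable gradients and switch to eval mode."""
    for param in module.parameters():
        param.requires_grad = False
    module.eval()  # e.g. dropout off, batch norm uses running statistics


def unfreeze(module: torch.nn.Module) -> None:
    """Reverse of freeze(): re-enable gradients and training mode."""
    for param in module.parameters():
        param.requires_grad = True
    module.train()


net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5))
freeze(net)
print(net.training, any(p.requires_grad for p in net.parameters()))  # False False
unfreeze(net)
print(net.training, all(p.requires_grad for p in net.parameters()))  # True True
```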

log
LightningModule.log(name, value, prog_bar=False, logger=None, on_step=None, on_epoch=None, reduce_fx='mean', enable_graph=False, sync_dist=False, sync_dist_group=None, add_dataloader_idx=True, batch_size=None, metric_attribute=None, rank_zero_only=False)

Log a key, value pair.

Example:

self.log('train_loss', loss)

The default behavior per hook is documented in Automatic Logging.

Parameters
  • name (str) – key to log.

  • value (Union[Metric, Tensor, int, float, Mapping[str, Union[Metric, Tensor, int, float]]]) – value to log. Can be a float, Tensor, Metric, or a dictionary of the former.

  • prog_bar (bool) – if True logs to the progress bar.

  • logger (Optional[bool]) – if True logs to the logger.

  • on_step (Optional[bool]) – if True logs at this step. The default value is determined by the hook. See Automatic Logging for details.

  • on_epoch (Optional[bool]) – if True logs epoch accumulated metrics. The default value is determined by the hook. See Automatic Logging for details.

  • reduce_fx (Union[str, Callable]) – reduction function over step values for end of epoch. torch.mean() by default.

  • enable_graph (bool) – if True, will not auto-detach the graph.

  • sync_dist (bool) – if True, reduces the metric across devices. Use with care as this may lead to a significant communication overhead.

  • sync_dist_group (Optional[Any]) – the DDP group to sync across.

  • add_dataloader_idx (bool) – if True, appends the index of the current dataloader to the name (when using multiple dataloaders). If False, you need to give unique names for each dataloader so the values do not mix.

  • batch_size (Optional[int]) – Current batch size. This is inferred directly from the loaded batch, but for some data structures you might need to provide it explicitly.

  • metric_attribute (Optional[str]) – To restore the metric state, Lightning requires the reference of the torchmetrics.Metric in your model. This is found automatically if it is a model attribute.

  • rank_zero_only (bool) – Whether the value will be logged only on rank 0. Enabling it prevents synchronization, which would otherwise deadlock because not all processes would perform this log call.

Return type

None
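
The batch_size argument matters once values are aggregated with on_epoch=True. The pure-Python sketch below assumes the default "mean" reduction weights each step value by its batch size (a simplified model of Lightning's bookkeeping) and shows how a smaller final batch would skew an unweighted average. EpochMeanTracker is a hypothetical name, and the numbers are made up.

```python
class EpochMeanTracker:
    """Hypothetical sketch of an epoch-level, batch-size-weighted mean."""

    def __init__(self):
        self.weighted_sum = 0.0
        self.total_batch_size = 0

    def update(self, value, batch_size):
        # each step value counts once per sample in its batch
        self.weighted_sum += value * batch_size
        self.total_batch_size += batch_size

    def compute(self):
        return self.weighted_sum / self.total_batch_size


tracker = EpochMeanTracker()
# (loss, batch_size) pairs; the last batch of the epoch is smaller
step_values = [(1.0, 32), (2.0, 32), (4.0, 8)]
for value, batch_size in step_values:
    tracker.update(value, batch_size)

plain_mean = sum(v for v, _ in step_values) / len(step_values)
print(round(tracker.compute(), 4), round(plain_mean, 4))  # 1.7778 2.3333
```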

log_dict
LightningModule.log_dict(dictionary, prog_bar=False, logger=None, on_step=None, on_epoch=None, reduce_fx='mean', enable_graph=False, sync_dist=False, sync_dist_group=None, add_dataloader_idx=True, batch_size=None, rank_zero_only=False)

Log a dictionary of values at once.

Example:

values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
self.log_dict(values)

Parameters
  • dictionary (Mapping[str, Union[Metric, Tensor, int, float, Mapping[str, Union[Metric, Tensor, int, float]]]]) – key value pairs. The values can be a float, Tensor, Metric, a dictionary of the former or a MetricCollection.

  • prog_bar (bool) – if True logs to the progress bar.

  • logger (Optional[bool]) – if True logs to the logger.

  • on_step (Optional[bool]) – if True logs at this step. None auto-logs for training_step but not validation/test_step. The default value is determined by the hook. See Automatic Logging for details.

  • on_epoch (Optional[bool]) – if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step. The default value is determined by the hook. See Automatic Logging for details.

  • reduce_fx (Union[str, Callable]) – reduction function over step values for end of epoch. torch.mean() by default.

  • enable_graph (bool) – if True, will not auto-detach the graph.

  • sync_dist (bool) – if True, reduces the metric across GPUs/TPUs. Use with care as this may lead to a significant communication overhead.

  • sync_dist_group (Optional[Any]) – the DDP group to sync across.

  • add_dataloader_idx (bool) – if True, appends the index of the current dataloader to the name (when using multiple). If False, you need to give unique names for each dataloader so the values do not mix.

  • batch_size (Optional[int]) – Current batch size. This is inferred directly from the loaded batch, but for some data structures you might need to provide it explicitly.

  • rank_zero_only (bool) – Whether the value will be logged only on rank 0. Enabling it prevents synchronization, which would otherwise deadlock because not all processes would perform this log call.

Return type

None

lr_schedulers
LightningModule.lr_schedulers()

Returns the learning rate scheduler(s) that are being used during training. Useful for manual optimization.

Return type

Union[None, List[Union[LRScheduler, ReduceLROnPlateau]], LRScheduler, ReduceLROnPlateau]

Returns

A single scheduler, or a list of schedulers in case multiple ones are present, or None if no schedulers were returned in configure_optimizers().

manual_backward
LightningModule.manual_backward(loss, *args, **kwargs)

Call this directly from your training_step() when doing optimizations manually. By using this, Lightning can ensure that all the proper scaling gets applied when using mixed precision.

See manual optimization for more examples.

Example:

def training_step(self, batch, batch_idx):
    opt = self.optimizers()
    loss = ...
    opt.zero_grad()
    # automatically applies scaling, etc...
    self.manual_backward(loss)
    opt.step()

Parameters
  • loss (Tensor) – The tensor on which to compute gradients. Must have a graph attached.

  • *args – Additional positional arguments to be forwarded to backward()

  • **kwargs – Additional keyword arguments to be forwarded to backward()

Return type

None
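
For orientation, here is the plain-PyTorch loop that manual optimization mirrors. Inside a LightningModule you would get the optimizer from self.optimizers() and call self.manual_backward(loss) instead of loss.backward(), so that precision-specific scaling is applied. This sketch runs outside Lightning on a toy regression problem with made-up data.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(1, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# toy data: y = 2x, generated for illustration only
x = torch.linspace(-1.0, 1.0, steps=16).unsqueeze(1)
y = 2.0 * x

losses = []
for _ in range(50):
    loss = torch.nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()  # in a LightningModule: self.manual_backward(loss)
    opt.step()
    losses.append(loss.item())

print(losses[0] > losses[-1])  # True
```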

optimizers
LightningModule.optimizers(use_pl_optimizer: Literal[True] = True) → Union[pytorch_lightning.core.optimizer.LightningOptimizer, List[pytorch_lightning.core.optimizer.LightningOptimizer]]
LightningModule.optimizers(use_pl_optimizer: Literal[False]) → Union[torch.optim.optimizer.Optimizer, List[torch.optim.optimizer.Optimizer]]
LightningModule.optimizers(use_pl_optimizer: bool) → Union[torch.optim.optimizer.Optimizer, pytorch_lightning.core.optimizer.LightningOptimizer, lightning_fabric.wrappers._FabricOptimizer, List[torch.optim.optimizer.Optimizer], List[pytorch_lightning.core.optimizer.LightningOptimizer], List[lightning_fabric.wrappers._FabricOptimizer]]

Returns the optimizer(s) that are being used during training. Useful for manual optimization.

Parameters

use_pl_optimizer (bool) – If True, will wrap the optimizer(s) in a LightningOptimizer for automatic handling of precision and profiling.

Return type

Union[Optimizer, LightningOptimizer, _FabricOptimizer, List[Optimizer], List[LightningOptimizer], List[_FabricOptimizer]]

Returns

A single optimizer, or a list of optimizers in case multiple ones are present.

print
LightningModule.print(*args, **kwargs)

Prints only from process 0. Use this in any distributed mode to log only once.

Parameters
  • *args – The thing to print. The same as for Python’s built-in print function.

  • **kwargs – The same as for Python’s built-in print function.

Example:

def forward(self, x):
    self.print(x, 'in forward')

Return type

None

predict_step
LightningModule.predict_step(batch, batch_idx, dataloader_idx=0)

Step function called during predict(). By default, it calls forward(). Override to add any processing logic.

The predict_step() is used to scale inference on multiple devices.

To prevent an OOM error, you can use the BasePredictionWriter callback to write the predictions to disk or a database after each batch or at the end of the epoch.

Use the BasePredictionWriter with spawn-based strategies, such as Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(accelerator="tpu", devices=8), because predictions won’t be returned in those cases.

Example

from pytorch_lightning import LightningModule, Trainer


class MyModel(LightningModule):

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

dm = ...
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)

Parameters
  • batch (Any) – Current batch.

  • batch_idx (int) – Index of current batch.

  • dataloader_idx (int) – Index of the current dataloader.

Return type

Any

Returns

Predicted output

save_hyperparameters
LightningModule.save_hyperparameters(*args, ignore=None, frame=None, logger=True)

Save arguments to hparams attribute.

Parameters
  • args (Any) – a single object (dict, Namespace, or OmegaConf), or string names of arguments from the class __init__

  • ignore (Union[Sequence[str], str, None]) – an argument name or a list of argument names from class __init__ to be ignored

  • frame (Optional[frame]) – a frame object. Default is None

  • logger (bool) – Whether to send the hyperparameters to the logger. Default: True

Example:
>>> from pytorch_lightning.core.mixins import HyperparametersMixin
>>> class ManuallyArgsModel(HyperparametersMixin):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # manually assign arguments
...         self.save_hyperparameters('arg1', 'arg3')
...     def forward(self, *args, **kwargs):
...         ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
>>> from pytorch_lightning.core.mixins import HyperparametersMixin
>>> class AutomaticArgsModel(HyperparametersMixin):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # equivalent automatic
...         self.save_hyperparameters()
...     def forward(self, *args, **kwargs):
...         ...
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg2": abc
"arg3": 3.14
>>> from argparse import Namespace
>>> from pytorch_lightning.core.mixins import HyperparametersMixin
>>> class SingleArgModel(HyperparametersMixin):
...     def __init__(self, params):
...         super().__init__()
...         # manually assign single argument
...         self.save_hyperparameters(params)
...     def forward(self, *args, **kwargs):
...         ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.hparams
"p1": 1
"p2": abc
"p3": 3.14
>>> from pytorch_lightning.core.mixins import HyperparametersMixin
>>> class ManuallyArgsModel(HyperparametersMixin):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # pass argument(s) to ignore as a string or in a list
...         self.save_hyperparameters(ignore='arg2')
...     def forward(self, *args, **kwargs):
...         ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14

Return type

None
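
To illustrate which names end up in hparams and how ignore filters them, here is a plain-Python approximation built on inspect.signature. collect_init_args and ExampleModel are hypothetical, and the sketch does not reproduce how Lightning itself locates the constructor arguments.

```python
import inspect


def collect_init_args(cls, *args, ignore=(), **kwargs):
    """Hypothetical sketch: map constructor arguments to their names,
    applying defaults and dropping any names listed in ``ignore``."""
    if isinstance(ignore, str):
        ignore = [ignore]
    # bind None to ``self`` so the remaining arguments line up with __init__
    bound = inspect.signature(cls.__init__).bind(None, *args, **kwargs)
    bound.apply_defaults()
    params = dict(bound.arguments)
    params.pop("self", None)
    return {name: value for name, value in params.items() if name not in ignore}


class ExampleModel:
    def __init__(self, arg1, arg2, arg3=3.14):
        pass


print(collect_init_args(ExampleModel, 1, "abc"))
# {'arg1': 1, 'arg2': 'abc', 'arg3': 3.14}
print(collect_init_args(ExampleModel, 1, "abc", ignore="arg2"))
# {'arg1': 1, 'arg3': 3.14}
```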

toggle_optimizer
LightningModule.toggle_optimizer(optimizer, optimizer_idx)

Makes sure only the gradients of the current optimizer’s parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.

This is only called automatically when automatic optimization is enabled and multiple optimizers are used. It works with untoggle_optimizer() to make sure param_requires_grad_state is properly reset.

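Parameters
  • optimizer – The optimizer to toggle.

  • optimizer_idx (int) – The index of the optimizer to toggle.

Return type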

None
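
Conceptually, toggling boils down to flipping requires_grad flags. The sketch below, written against plain torch.optim optimizers, switches off gradients for every parameter not owned by the active optimizer and records the original flags so they can be restored, in the spirit of toggle_optimizer() and untoggle_optimizer(). toggle and untoggle are hypothetical helpers, not Lightning's code.

```python
import torch


def toggle(optimizers, active_idx):
    """Hypothetical sketch: disable gradients for parameters that the
    active optimizer does not own, returning the original flags."""
    saved = {}
    for opt in optimizers:
        for group in opt.param_groups:
            for param in group["params"]:
                # setdefault keeps the first (original) flag for shared params
                saved.setdefault(param, param.requires_grad)
                param.requires_grad = False
    # give the active optimizer's parameters their original flags back
    for group in optimizers[active_idx].param_groups:
        for param in group["params"]:
            param.requires_grad = saved[param]
    return saved


def untoggle(saved):
    """Restore the requires_grad flags recorded by toggle()."""
    for param, flag in saved.items():
        param.requires_grad = flag


gen = torch.nn.Linear(2, 2)
dis = torch.nn.Linear(2, 1)
opt_gen = torch.optim.SGD(gen.parameters(), lr=0.1)
opt_dis = torch.optim.SGD(dis.parameters(), lr=0.1)

state = toggle([opt_gen, opt_dis], active_idx=1)
print([p.requires_grad for p in gen.parameters()])  # [False, False]
print([p.requires_grad for p in dis.parameters()])  # [True, True]
untoggle(state)
print(all(p.requires_grad for p in gen.parameters()))  # True
```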

test_step
LightningModule.test_step(*args, **kwargs)

Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.

# the pseudocode for these calls
test_outs = []
for test_batch in test_data:
    out = test_step(test_batch)
    test_outs.append(out)
test_epoch_end(test_outs)

Parameters
  • batch – The output of your DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch (only if multiple test dataloaders are used).

Return type

Union[Tensor, Dict[str, Any], None]

Returns

Any of:

  • Any object or value

  • None - Testing will skip to the next batch

# if you have one test dataloader:
def test_step(self, batch, batch_idx):
    ...


# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0):
    ...

Examples:

import torch
import torchvision


# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...

Note

If you don’t need to test, you don’t need to implement this method.

Note

When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.
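
In plain PyTorch, the state described in this note corresponds roughly to the pattern below. It is a sketch of what the surrounding loop takes care of, assuming torch.no_grad() as the gradient switch; it is not Lightning's evaluation loop. The same pattern applies to validation_step().

```python
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
batch = torch.randn(2, 4)

model.eval()  # dropout becomes a no-op, so repeated passes agree
with torch.no_grad():  # no autograd graph is recorded
    out = model(batch)
    out_again = model(batch)
model.train()  # back to training mode afterwards

print(out.requires_grad, torch.equal(out, out_again), model.training)
# False True True
```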

test_step_end
LightningModule.test_step_end(*args, **kwargs)

Use this when testing with DP because test_step() will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.

Note

If you later switch to ddp or some other mode, this will still be called so that you don’t have to change your code.

# pseudocode
sub_batches = split_batches_for_dp(batch)
step_output = [test_step(sub_batch) for sub_batch in sub_batches]
test_step_end(step_output)

Parameters

step_output – What you return in test_step() for each batch part.

Return type

Union[Tensor, Dict[str, Any], None]

Returns

None or anything

# WITHOUT test_step_end
# if used in DP, this batch is 1/num_gpus large
def test_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self(x)
    loss = self.softmax(out)
    self.log("test_loss", loss)


# --------------
# with test_step_end to do softmax over the full batch
def test_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self.encoder(x)
    return out


def test_step_end(self, output_results):
    # this out is now the full size of the batch
    all_test_step_outs = output_results
    loss = nce_loss(all_test_step_outs)
    self.log("test_loss", loss)

See also

See the Multi GPU Training guide for more details.

test_epoch_end
LightningModule.test_epoch_end(outputs)

Called at the end of a test epoch with the output of all test steps.

# the pseudocode for these calls
test_outs = []
for test_batch in test_data:
    out = test_step(test_batch)
    test_outs.append(out)
test_epoch_end(test_outs)

Parameters

outputs (Union[List[Union[Tensor, Dict[str, Any]]], List[List[Union[Tensor, Dict[str, Any]]]]]) – List of outputs you defined in test_step_end(), or, if there are multiple dataloaders, a list containing a list of outputs for each dataloader.

Return type

None

Returns

None

Note

If you didn’t define a test_step(), this won’t be called.

Examples

With a single dataloader:

def test_epoch_end(self, outputs):
    # do something with the outputs of all test batches;
    # this assumes test_step() returned a dict with a "predictions" key
    all_test_preds = [out["predictions"] for out in outputs]

    some_result = calc_all_results(all_test_preds)
    self.log("some_result", some_result)

With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each test step for that dataloader.

def test_epoch_end(self, outputs):
    final_value = 0
    for dataloader_outputs in outputs:
        for test_step_out in dataloader_outputs:
            # do something
            final_value += test_step_out

    self.log("final_metric", final_value)

to_onnx
LightningModule.to_onnx(file_path, input_sample=None, **kwargs)

Saves the model in ONNX format.

Parameters
  • file_path (Union[str, Path]) – The path of the file the ONNX model should be saved to.

  • input_sample (Optional[Any]) – An input for tracing. Default: None (Use self.example_input_array)

  • **kwargs – Will be passed to torch.onnx.export function.

Example

>>> class SimpleModel(LightningModule):
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
>>> import os, tempfile
>>> with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
...     model = SimpleModel()
...     input_sample = torch.randn((1, 64))
...     model.to_onnx(tmpfile.name, input_sample, export_params=True)
...     os.path.isfile(tmpfile.name)
True

Return type

None

to_torchscript
LightningModule.to_torchscript(file_path=None, method='script', example_inputs=None, **kwargs)

By default, compiles the whole model to a ScriptModule. If you want to use tracing, please provide the argument method='trace' and make sure that either the example_inputs argument is provided or the model has example_input_array set. If you would like to customize the modules that are scripted, you should override this method. If you want to return multiple modules, we recommend using a dictionary.

Parameters
  • file_path (Union[str, Path, None]) – Path where to save the torchscript. Default: None (no file saved).

  • method (Optional[str]) – Whether to use TorchScript’s script or trace method. Default: ‘script’

  • example_inputs (Optional[Any]) – An input to be used to do tracing when method is set to ‘trace’. Default: None (uses example_input_array)

  • **kwargs – Additional arguments that will be passed to the torch.jit.script() or torch.jit.trace() function.

Note

  • Requires the implementation of the forward() method.

  • The exported script will be set to evaluation mode.

  • It is recommended that you install the latest supported version of PyTorch to use this feature without limitations. See also the torch.jit documentation for supported features.

Example

>>> class SimpleModel(LightningModule):
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
...
>>> import os
>>> model = SimpleModel()
>>> scripted = model.to_torchscript(file_path="model.pt")
>>> os.path.isfile("model.pt")
True
>>> traced = model.to_torchscript(file_path="model_trace.pt", method='trace',
...                               example_inputs=torch.randn(1, 64))
>>> os.path.isfile("model_trace.pt")
True

Return type

Union[ScriptModule, Dict[str, ScriptModule]]

Returns

This LightningModule as a torchscript, regardless of whether file_path is defined or not.
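
The two method values correspond to PyTorch's own torch.jit.script and torch.jit.trace. The sketch below calls those functions directly on a plain nn.Module copy of the SimpleModel from the example (an assumption made so it runs without Lightning) and checks that both compiled versions reproduce the eager output.

```python
import torch


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))


model = SimpleModel().eval()
example = torch.randn(1, 64)

scripted = torch.jit.script(model)  # analogous to method='script'
traced = torch.jit.trace(model, example)  # analogous to method='trace'

with torch.no_grad():
    reference = model(example)
    print(torch.allclose(scripted(example), reference),
          torch.allclose(traced(example), reference))  # True True
```

Tracing records only the operations executed for the example input, so data-dependent control flow is a reason to prefer scripting.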

training_step
LightningModule.training_step(*args, **kwargs)

Here you compute and return the training loss and, optionally, additional metrics, for example for the progress bar or logger.

Parameters
  • batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.

  • batch_idx (int) – Index of this batch.

  • optimizer_idx (int) – When using multiple optimizers, this argument will also be present.

  • hiddens (Any) – Passed in if truncated_bptt_steps > 0.

Return type

Union[Tensor, Dict[str, Any]]

Returns

Any of:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'

  • None - Training will skip to the next batch. This is only for automatic optimization.

    This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...

If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.

# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    x, y = batch
    out, hiddens = self.lstm(x, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}

Note

The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.

Note

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
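
The normalization keeps the accumulated gradient on the same scale as one large batch. Assuming equal-sized micro-batches and a mean-reduced loss, dividing each micro-batch loss by the number of accumulation steps reproduces the full-batch gradient, as this plain-PyTorch check shows (random data, no Lightning involved).

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
x = torch.randn(8, 3)
y = torch.randn(8)


def loss_fn(xb, yb):
    # mean-reduced squared error of a linear model
    return ((xb @ w - yb) ** 2).mean()


# gradient of the mean loss over the full batch of 8 samples
(full_grad,) = torch.autograd.grad(loss_fn(x, y), w)

# accumulate over 4 micro-batches of 2, dividing each loss by 4
accumulate = 4
w.grad = None
for xb, yb in zip(x.chunk(accumulate), y.chunk(accumulate)):
    (loss_fn(xb, yb) / accumulate).backward()

print(torch.allclose(w.grad, full_grad, atol=1e-6))  # True
```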

training_step_end
LightningModule.training_step_end(step_output)

Use this when training with DP because training_step() will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.

Note

If you later switch to ddp or some other mode, this will still be called so that you don’t have to change your code.

# pseudocode
sub_batches = split_batches_for_dp(batch)
step_output = [training_step(sub_batch) for sub_batch in sub_batches]
training_step_end(step_output)

Parameters

step_output (Union[Tensor, Dict[str, Any]]) – What you return in training_step for each batch part.

Return type

Union[Tensor, Dict[str, Any]]

Returns

Anything

When using the DP strategy, only a portion of the batch is inside the training_step:

def training_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self(x)

    # softmax uses only a portion of the batch in the denominator
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return loss

If you wish to do something with all the parts of the batch, then use this method to do it:

def training_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self.encoder(x)
    return {"pred": out}


def training_step_end(self, training_step_outputs):
    gpu_0_pred = training_step_outputs[0]["pred"]
    gpu_1_pred = training_step_outputs[1]["pred"]
    gpu_n_pred = training_step_outputs[n]["pred"]

    # this softmax now uses the full batch
    loss = nce_loss([gpu_0_pred, gpu_1_pred, gpu_n_pred])
    return loss
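
Why the full batch matters can be shown with plain tensors: a softmax taken across the batch dimension gives different results when each replica only sees its own slice. The split into two chunks below imitates what DP does conceptually; it uses neither Lightning nor DataParallel, and the scores are random.

```python
import torch

torch.manual_seed(0)
scores = torch.randn(8)  # one score per sample in the full batch

# what each of 2 DP replicas would see: half of the batch
chunks = torch.chunk(scores, 2)
per_replica = torch.cat([torch.softmax(c, dim=0) for c in chunks])

# what a *_step_end hook can do once the outputs are gathered again
full_batch = torch.softmax(torch.cat(chunks), dim=0)

print(round(per_replica.sum().item(), 4), round(full_batch.sum().item(), 4))
# 2.0 1.0
```

Each replica normalizes over its own half, so the per-replica probabilities sum to 2 instead of 1.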

See also

See the Multi GPU Training guide for more details.

training_epoch_end
LightningModule.training_epoch_end(outputs)

Called at the end of the training epoch with the outputs of all training steps. Use this in case you need to do something with all the outputs returned by training_step().

# the pseudocode for these calls
train_outs = []
for train_batch in train_data:
    out = training_step(train_batch)
    train_outs.append(out)
training_epoch_end(train_outs)

Parameters

outputs (List[Union[Tensor, Dict[str, Any]]]) – List of outputs you defined in training_step(). If there are multiple optimizers or when using truncated_bptt_steps > 0, the lists have the dimensions (n_batches, tbptt_steps, n_optimizers). Dimensions of length 1 are squeezed.

Return type

None

Returns

None

Note

If this method is not overridden, this won’t be called.

def training_epoch_end(self, training_step_outputs):
    # do something with all training_step outputs
    for out in training_step_outputs:
        ...

unfreeze
LightningModule.unfreeze()

Unfreeze all parameters for training.

Example:

model = MyLightningModule(...)
model.unfreeze()

Return type

None

untoggle_optimizer
LightningModule.untoggle_optimizer(optimizer_idx)

Resets the state of required gradients that were toggled with toggle_optimizer().

This is only called automatically when automatic optimization is enabled and multiple optimizers are used.

Parameters

optimizer_idx (int) – The index of the optimizer to untoggle.

Return type

None

validation_step
LightningModule.validation_step(*args, **kwargs)

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)

Parameters
  • batch – The output of your DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch (only if multiple validation dataloaders are used).

Return type

Union[Tensor, Dict[str, Any], None]

Returns

  • Any object or value

  • None - Validation will skip to the next batch

# pseudocode of order
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    if defined("validation_step_end"):
        out = validation_step_end(out)
    val_outs.append(out)
validation_epoch_end(val_outs)

# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    ...

Examples:

# CASE 1: A single validation dataset
import torch
import torchvision


def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # (or generated text, or anything else worth inspecting)
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate accuracy
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / len(y)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
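The accuracy computation above (argmax over the class dimension, then the fraction of matches) can be checked without PyTorch. This is a minimal pure-Python sketch, assuming the model outputs are given as one list of class scores per sample; argmax and batch_accuracy are hypothetical helpers.

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest score in one row (one sample's class scores)."""
    return max(range(len(row)), key=lambda i: row[i])


def batch_accuracy(logits: List[Sequence[float]], labels: List[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    labels_hat = [argmax(row) for row in logits]
    correct = sum(1 for y, y_hat in zip(labels, labels_hat) if y == y_hat)
    return correct / len(labels)


logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]
labels = [1, 1, 1, 0]
print(batch_accuracy(logits, labels))
```

Here the predictions are [1, 0, 1, 0], so three of the four samples match their labels.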

If you pass in multiple val dataloaders, validation_step() receives an additional dataloader_idx argument. We recommend giving it a default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...

Note

If you don’t need to validate, you don’t need to implement this method.

Note

When validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

validation_step_end
LightningModule.validation_step_end(*args, **kwargs)

Use this when validating with DP (DataParallel), because validation_step() will operate on only part of the batch. However, this is still optional and only needed for things like a softmax or NCE loss computed over the full batch.

Note

If you later switch to DDP or some other mode, this will still be called so that you don’t have to change your code.

# pseudocode
sub_batches = split_batches_for_dp(batch)
step_output = [validation_step(sub_batch) for sub_batch in sub_batches]
validation_step_end(step_output)
Parameters

step_output – What you return in validation_step() for each batch part.

Return type

Union[Tensor, Dict[str, Any], None]

Returns

None or anything

# WITHOUT validation_step_end
# if used in DP, this batch is 1/num_gpus large
def validation_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self.encoder(x)
    loss = self.softmax(out)
    loss = nce_loss(loss)
    self.log("val_loss", loss)


# --------------
# with validation_step_end to do softmax over the full batch
def validation_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self(x)
    return out


def validation_step_end(self, val_step_outputs):
    for out in val_step_outputs:
        ...
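To see why validation_step_end() matters for a softmax under DP, the framework-free sketch below splits a batch of scores into per-device chunks (as in the pseudocode above) and compares a softmax taken per chunk with one taken over the recombined full batch. split_batches_for_dp, the contiguous chunking rule, and the scores are hypothetical; real DP scatters tensors across GPUs.

```python
import math
from typing import List


def split_batches_for_dp(batch: List[float], num_devices: int) -> List[List[float]]:
    """Hypothetical stand-in for DP scattering: contiguous, near-equal chunks."""
    size = math.ceil(len(batch) / num_devices)
    return [batch[i:i + size] for i in range(0, len(batch), size)]


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a flat list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


batch = [2.0, 1.0, 0.5, 3.0]
sub_batches = split_batches_for_dp(batch, num_devices=2)

# WITHOUT a step_end hook: each device normalizes only over its own chunk.
per_chunk = [p for chunk in sub_batches for p in softmax(chunk)]

# WITH a step_end hook: gather the per-device outputs, then normalize once.
gathered = [s for chunk in sub_batches for s in chunk]
full_batch = softmax(gathered)

print(sum(per_chunk), sum(full_batch))
```

Each per-chunk softmax sums to 1 on its own, so the two chunks together sum to 2, while the softmax over the gathered batch is a single distribution that sums to 1.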

See also

See the Multi GPU Training guide for more details.

validation_epoch_end
LightningModule.validation_epoch_end(outputs)

Called at the end of the validation epoch with the outputs of all validation steps.

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
Parameters

outputs (Union[List[Union[Tensor, Dict[str, Any]]], List[List[Union[Tensor, Dict[str, Any]]]]]) – List of outputs you defined in validation_step(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.

Return type

None

Returns

None

Note

If you didn’t define a validation_step(), this won’t be called.

Examples

With a single dataloader:

def validation_epoch_end(self, val_step_outputs):
    for out in val_step_outputs:
        ...

With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each validation step for that dataloader.

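def validation_epoch_end(self, outputs):
    # outputs[i] holds the validation_step() outputs of dataloader i
    for dataloader_idx, dataloader_outs in enumerate(outputs):
        for out in dataloader_outs:
            ...

    final_value = ...  # an aggregate computed from the outputs above
    self.log("final_metric", final_value)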
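The list-of-lists structure can be reduced as sketched below, assuming each validation_step() returned a plain float loss. per_dataloader_means is a hypothetical helper; averaging the per-dataloader means into one final_value is just one possible choice.

```python
from statistics import mean
from typing import Dict, List


def per_dataloader_means(outputs: List[List[float]]) -> Dict[int, float]:
    """Reduce a list of lists (one inner list per val dataloader) to one mean each."""
    return {idx: mean(dataloader_outs) for idx, dataloader_outs in enumerate(outputs)}


# outputs[i] holds the validation_step() values produced for dataloader i
outputs = [[0.4, 0.6], [1.0, 2.0, 3.0]]
metrics = per_dataloader_means(outputs)
final_value = mean(metrics.values())
print(metrics, final_value)
```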

Properties

These are properties available in a LightningModule.

current_epoch

The number of epochs run.

def training_step(self, batch, batch_idx):
    if self.current_epoch == 0:
        ...
device

The device the module is on. Use it to keep your code device agnostic.

def training_step(self, batch, batch_idx):
    z = torch.rand(2, 3, device=self.device)
global_rank

The global_rank is the index of the current process across all nodes and devices. Lightning performs some operations, such as logging and weight checkpointing, only when global_rank=0. You usually do not need to use this property, but it is useful to know how to access it if needed.

def training_step(self, batch, batch_idx):
    if self.global_rank == 0:
        # do something only once across all the nodes
        ...
global_step

The number of optimizer steps taken (does not reset each epoch). This includes multiple optimizers and TBPTT steps (if enabled).

def training_step(self, batch, batch_idx):
    self.logger.experiment.log_image(..., step=self.global_step)
hparams

The arguments passed to LightningModule.__init__() and saved by calling save_hyperparameters() can be accessed through the hparams attribute.

from torch.optim import Adam


def __init__(self, learning_rate):
    super().__init__()
    self.save_hyperparameters()


def configure_optimizers(self):
    return Adam(self.parameters(), lr=self.hparams.learning_rate)
logger

The current logger being used (TensorBoard or another supported logger).

def training_step(self, batch, batch_idx):
    # the generic logger (same no matter if tensorboard or other supported logger)
    self.logger

    # the particular logger
    tensorboard_logger = self.logger.experiment
loggers

The list of loggers currently being used by the Trainer.

def training_step(self, batch, batch_idx):
    # List of Logger objects
    loggers = self.loggers
    for logger in loggers:
        logger.log_metrics({"foo": 1.0})
local_rank

The local_rank is the index of the current process across all the devices for the current node. You usually do not need to use this property, but it is useful to know how to access it if needed. For example, if using 10 machines (or nodes), the GPU at index 0 on each machine has local_rank = 0.

def training_step(self, batch, batch_idx):
    if self.local_rank == 0:
        # do something only once across each node
        ...
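The relationship between global_rank, local_rank, and the node index can be sketched with the usual numbering convention for homogeneous clusters, where every node runs the same number of processes. The formula and the helper name global_rank_of are assumptions for illustration, as is the choice of 8 GPUs per machine in the 10-machine example above; Lightning computes these values for you.

```python
def global_rank_of(node_rank: int, local_rank: int, devices_per_node: int) -> int:
    """Assumed convention: ranks are numbered node by node, device by device."""
    if not 0 <= local_rank < devices_per_node:
        raise ValueError("local_rank must be in [0, devices_per_node)")
    return node_rank * devices_per_node + local_rank


# 10 nodes, assuming 8 GPUs each
num_nodes, devices_per_node = 10, 8
world_size = num_nodes * devices_per_node
first_process_per_node = [global_rank_of(n, 0, devices_per_node) for n in range(num_nodes)]
print(world_size, first_process_per_node)
```

Every entry in first_process_per_node has local_rank = 0, but only the first one also has global_rank = 0.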
precision

The type of precision used:

def training_step(self, batch, batch_idx):
    if self.precision == 16:
        ...
trainer

A reference to the Trainer this module is attached to.

def training_step(self, batch, batch_idx):
    max_steps = self.trainer.max_steps
    any_flag = self.trainer.any_flag
prepare_data_per_node

If set to True, prepare_data() is called on LOCAL_RANK=0 of every node. If set to False, it is called only on NODE_RANK=0, LOCAL_RANK=0.

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True
automatic_optimization

When set to False, Lightning does not automate the optimization process. This means you are responsible for handling your optimizers. However, we do take care of precision and any accelerators used.

See manual optimization for details.

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_step(self, batch, batch_idx):
    opt = self.optimizers(use_pl_optimizer=True)

    loss = ...
    opt.zero_grad()
    self.manual_backward(loss)
    opt.step()

This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note that automatic optimization can still be used with multiple optimizers by relying on the optimizer_idx parameter. Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research.

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_step(self, batch, batch_idx):
    # access your optimizers; use_pl_optimizer=True (the default) returns them
    # wrapped as LightningOptimizer, use_pl_optimizer=False returns the raw optimizers
    opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

    gen_loss = ...
    opt_a.zero_grad()
    self.manual_backward(gen_loss)
    opt_a.step()

    disc_loss = ...
    opt_b.zero_grad()
    self.manual_backward(disc_loss)
    opt_b.step()
example_input_array

Set and access example_input_array, which represents a single example batch of inputs.

def __init__(self):
    self.example_input_array = ...
    self.generator = ...


def on_train_epoch_end(self):
    # generate some images using the example_input_array
    gen_images = self.generator(self.example_input_array)
truncated_bptt_steps

Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of a much longer sequence. This is made possible by passing training batches, split along the time dimension into chunks of size k, to training_step(). To keep the same forward propagation behavior, the hidden states must be carried over between consecutive time-dimension splits.

If this is enabled, your batches will automatically get truncated and the Trainer will apply Truncated Backprop to it.

(Williams et al. “An efficient gradient-based algorithm for on-line training of recurrent network trajectories.”)
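The requirement that hidden states be carried between splits can be checked with a toy recurrence. The sketch below uses a hypothetical scalar RNN, h_t = 0.5 * h_{t-1} + x_t, and shows that processing a sequence in chunks of k steps gives the same final hidden state as one full pass only when each chunk starts from the previous chunk's hidden state. It models the forward pass only; gradients and the truncation of backpropagation are not simulated.

```python
from typing import List, Tuple


def run_rnn(xs: List[float], h: float = 0.0, decay: float = 0.5) -> Tuple[List[float], float]:
    """Toy scalar recurrence h_t = decay * h_{t-1} + x_t; returns outputs and last state."""
    outs = []
    for x in xs:
        h = decay * h + x
        outs.append(h)
    return outs, h


def run_in_chunks(xs: List[float], k: int, carry_hidden: bool) -> float:
    """Process xs in chunks of k steps, optionally carrying the hidden state across chunks."""
    h = 0.0
    for start in range(0, len(xs), k):
        chunk = xs[start:start + k]
        _, h = run_rnn(chunk, h if carry_hidden else 0.0)
    return h


sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
_, full_h = run_rnn(sequence)
print(full_h)
print(run_in_chunks(sequence, k=2, carry_hidden=True))
print(run_in_chunks(sequence, k=2, carry_hidden=False))
```

With carry_hidden=False the last chunk starts from zero and "forgets" everything before it, which is exactly what returning "hiddens" from training_step() avoids.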

Tutorial

from pytorch_lightning import LightningModule
from torch import nn


class MyModel(LightningModule):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        # batch_first has to be set to True
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

        ...

        # Important: This property activates truncated backpropagation through time
        # Setting this value to 2 splits the batch into sequences of size 2
        self.truncated_bptt_steps = 2

    # Truncated back-propagation through time
    def training_step(self, batch, batch_idx, hiddens):
        x, y = batch

        # the training step must be updated to accept a ``hiddens`` argument
        # hiddens are the hiddens from the previous truncated backprop step
        out, hiddens = self.lstm(x, hiddens)

        ...

        return {"loss": ..., "hiddens": hiddens}

Lightning takes care of splitting your batch along the time dimension, which is assumed to be the second dimension of your batches. This is why batch_first=True is set in the example above.

# the second dimension is the time dimension
# (batch, time, ...)
sub_batch = batch[:, 0:t, ...]
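The splitting idea can be pictured with nested Python lists standing in for a batch-first tensor of shape (batch, time). The hypothetical split_along_time helper below keeps every sample and slices only the time axis into windows of split_size steps, the same idea as indexing batch[:, t:t + split_size]. It is a sketch, not Lightning's tbptt_split_batch().

```python
from typing import List, Sequence


def split_along_time(batch: Sequence[Sequence[int]], split_size: int) -> List[List[List[int]]]:
    """Split a (batch, time) nested list into time windows of length split_size."""
    if split_size < 1:
        raise ValueError("split_size must be positive")
    seq_len = len(batch[0])
    splits = []
    for t in range(0, seq_len, split_size):
        # keep every sample in the batch, slice only the time axis
        splits.append([list(sample[t:t + split_size]) for sample in batch])
    return splits


batch = [[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]]  # 2 samples, 5 time steps
for sub_batch in split_along_time(batch, split_size=2):
    print(sub_batch)
```

The last window is shorter when the sequence length is not a multiple of split_size.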

To modify how the batch is split, override the pytorch_lightning.core.module.LightningModule.tbptt_split_batch() method:

class LitMNIST(LightningModule):
    def tbptt_split_batch(self, batch, split_size):
        # do your own splitting on the batch and return
        # a list with one sub-batch per truncated step
        splits = ...
        return splits

Hooks

This is the pseudocode to describe the structure of fit(). The inputs and outputs of each function are not represented for simplicity. Please check each function’s API reference for more information.

def fit(self):
    if local_rank == 0:
        # prepare_data is called on LOCAL_RANK=0 of every node by default,
        # or only on GLOBAL_RANK=0 when prepare_data_per_node=False
        prepare_data()

    configure_callbacks()

    with parallel(devices):
        # devices can be GPUs, TPUs, ...
        train_on_device(model)


def train_on_device(model):
    # called PER DEVICE
    setup("fit")
    configure_optimizers()
    on_fit_start()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:
        fit_loop()
    on_train_end()

    on_fit_end()
    teardown("fit")


def fit_loop():
    on_train_epoch_start()

    for batch in train_dataloader():
        on_train_batch_start()

        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()

        training_step()

        on_before_zero_grad()
        optimizer_zero_grad()

        on_before_backward()
        backward()
        on_after_backward()

        on_before_optimizer_step()
        configure_gradient_clipping()
        optimizer_step()

        on_train_batch_end()

        if should_check_val:
            val_loop()
    # end training epoch
    training_epoch_end()

    on_train_epoch_end()


def val_loop():
    on_validation_model_eval()  # calls `model.eval()`
    torch.set_grad_enabled(False)

    on_validation_start()
    on_validation_epoch_start()

    val_outs = []
    for batch_idx, batch in enumerate(val_dataloader()):
        on_validation_batch_start(batch, batch_idx)

        batch = on_before_batch_transfer(batch)
        batch = transfer_batch_to_device(batch)
        batch = on_after_batch_transfer(batch)

        out = validation_step(batch, batch_idx)

        on_validation_batch_end(batch, batch_idx)
        val_outs.append(out)

    validation_epoch_end(val_outs)

    on_validation_epoch_end()
    on_validation_end()

    # set up for train
    on_validation_model_train()  # calls `model.train()`
    torch.set_grad_enabled(True)
backward
LightningModule.backward(loss, optimizer, optimizer_idx, *args, **kwargs)

Called to perform backward on the loss returned in training_step(). Override this hook with your own implementation if you need to.

Parameters
  • loss (Tensor) – The loss tensor returned by training_step(). If gradient accumulation is used, the loss here holds the normalized value (scaled by 1 / accumulation steps).

  • optimizer (Optional[Steppable]) – Current optimizer being used. None if using manual optimization.

  • optimizer_idx (Optional[int]) – Index of the current optimizer being used. None if using manual optimization.
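The 1 / accumulation steps normalization of loss can be checked with plain arithmetic: if each of k accumulated micro-batch losses is scaled by 1 / k before backward, their summed contribution equals the mean loss over those k batches, and because backward is linear, the accumulated gradients equal the gradient of that mean. The sketch below uses hypothetical loss values for accumulate_grad_batches=4.

```python
from typing import List


def accumulated_scaled_loss(losses: List[float]) -> float:
    """Sum of per-batch losses, each scaled by 1 / accumulation steps."""
    k = len(losses)
    return sum(loss / k for loss in losses)


# hypothetical losses from 4 accumulated micro-batches
losses = [0.8, 0.6, 1.0, 0.2]
print(accumulated_scaled_loss(losses), sum(losses) / len(losses))
```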

Example:

def backward(self, loss, optimizer, optimizer_idx):
    loss.backward()
Return type

None

on_before_backward
LightningModule.on_before_backward(loss)

Called before loss.backward().

Parameters

loss (Tensor) – Loss divided by number of batches for gradient accumulation and scaled if using native AMP.

Return type

None

on_after_backward
LightningModule.on_after_backward()

Called after loss.backward() and before optimizers are stepped.

Note

If using native AMP, the gradients will not be unscaled at this point. Use on_before_optimizer_step() if you need the unscaled gradients.

Return type

None

on_before_zero_grad
LightningModule.on_before_zero_grad(optimizer)

Called after training_step() and before optimizer.zero_grad().

Called in the training loop after taking an optimizer step and before zeroing grads. This is a good place to inspect weight information after the weights have been updated.

This is where it is called:

for optimizer in optimizers:
    out = training_step(...)

    model.on_before_zero_grad(optimizer) # < ---- called here
    optimizer.zero_grad()

    backward()
Parameters

optimizer (Optimizer) – The optimizer for which grads should be zeroed.

Return type

None

on_fit_start
LightningModule.on_fit_start()

Called at the very beginning of fit.

If using DDP, it is called on every process.

Return type

None

on_fit_end
LightningModule.on_fit_end()

Called at the very end of fit.

If using DDP, it is called on every process.

Return type

None

on_load_checkpoint
LightningModule.on_load_checkpoint(checkpoint)

Called by Lightning to restore your model. If you saved something with on_save_checkpoint(), this is your chance to restore it.

Parameters

checkpoint (Dict[str, Any]) – Loaded checkpoint

Example:

def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

Note

Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.

Return type

None

on_save_checkpoint
LightningModule.on_save_checkpoint(checkpoint)

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters

checkpoint (Dict[str, Any]) – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.

Example:

def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_picklable_object

Note

Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.

Return type

None
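The checkpoint passed to on_save_checkpoint() and on_load_checkpoint() is an ordinary dict, so the save/restore contract can be sketched without Lightning. Below, a hypothetical ExtraStateModule adds one extra entry when saving and reads it back when loading. pickle stands in for the serialization Lightning performs, and the "state_dict" key with an empty dict is only a placeholder for what the framework stores.

```python
import pickle
from typing import Any, Dict


class ExtraStateModule:
    """Hypothetical object implementing the two checkpoint hooks described above."""

    def __init__(self) -> None:
        self.something_cool_i_want_to_save = {"seen_ids": [3, 1, 4]}

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # add extra, picklable state to the dict before it is written
        checkpoint["something_cool_i_want_to_save"] = self.something_cool_i_want_to_save

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # restore the extra state saved above
        self.something_cool_i_want_to_save = checkpoint["something_cool_i_want_to_save"]


saver = ExtraStateModule()
checkpoint: Dict[str, Any] = {"state_dict": {}}  # placeholder for framework-managed entries
saver.on_save_checkpoint(checkpoint)
blob = pickle.dumps(checkpoint)  # stand-in for writing the .ckpt file

restorer = ExtraStateModule()
restorer.something_cool_i_want_to_save = None
restorer.on_load_checkpoint(pickle.loads(blob))
print(restorer.something_cool_i_want_to_save)
```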

load_from_checkpoint
classmethod LightningModule.load_from_checkpoint(checkpoint_path, map_location=None, hparams_file=None, strict=True, **kwargs)

Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to __init__ in the checkpoint under "hyper_parameters".

Any arguments specified through **kwargs will override args stored in "hyper_parameters".

Parameters
  • checkpoint_path (Union[str, Path, IO]) – Path to the checkpoint. This can also be a URL or a file-like object.

  • map_location (Union[device, str, int, Callable[[Union[device, str, int]], Union[device, str, int]], Dict[Union[device, str, int], Union[device, str, int]], None]) – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().

  • hparams_file (Union[str, Path, None]) –

    Optional path to a .yaml or .csv file with hierarchical structure as in this example:

    drop_prob: 0.2
    dataloader:
        batch_size: 32
    

    You most likely won’t need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don’t have the hyperparameters saved, use this method to pass in a .yaml file with the hparams you’d like to use. These will be converted into a dict and passed into your LightningModule for use.

    If your model’s hparams argument is a Namespace and the .yaml file has a hierarchical structure, you need to refactor your model to treat hparams as a dict.

  • strict (bool) – Whether to strictly enforce that the keys in checkpoint_path match the keys returned by this module’s state dict.

  • **kwargs – Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.

Return type

Self

Returns

LightningModule instance with loaded weights and hyperparameters (if available).

Note

load_from_checkpoint is a class method. You should use your LightningModule class to call it instead of the LightningModule instance.

Example:

# load weights without mapping ...
model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1':'cuda:0'}
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    map_location=map_location
)

# or load weights and hyperparameters from separate files.
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
model = MyLightningModule.load_from_checkpoint(
    PATH,
    num_layers=128,
    pretrained_ckpt_path=NEW_PATH,
)

# predict
model.eval()
model.freeze()
y_hat = model(x)
on_train_start
LightningModule.on_train_start()

Called at the beginning of training after sanity check.

Return type

None

on_train_end
LightningModule.on_train_end()

Called at the end of training, before the logger experiment is closed.

Return type

None

on_validation_start
LightningModule.on_validation_start()

Called at the beginning of validation.

Return type

None

on_validation_end
LightningModule.on_validation_end()

Called at the end of validation.

Return type

None

on_test_batch_start
LightningModule.on_test_batch_start(batch, batch_idx, dataloader_idx)

Called in the test loop before anything happens for that batch.

Parameters
  • batch (Any) – The batched data as it is returned by the test DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_test_batch_end
LightningModule.on_test_batch_end(outputs, batch, batch_idx, dataloader_idx)

Called in the test loop after the batch.

Parameters
  • outputs (Union[Tensor, Dict[str, Any], None]) – The outputs of test_step_end(test_step(x))

  • batch (Any) – The batched data as it is returned by the test DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_test_epoch_start
LightningModule.on_test_epoch_start()

Called in the test loop at the very beginning of the epoch.

Return type

None

on_test_epoch_end
LightningModule.on_test_epoch_end()

Called in the test loop at the very end of the epoch.

Return type

None

on_test_start
LightningModule.on_test_start()

Called at the beginning of testing.

Return type

None

on_test_end
LightningModule.on_test_end()

Called at the end of testing.

Return type

None

on_predict_batch_start
LightningModule.on_predict_batch_start(batch, batch_idx, dataloader_idx)

Called in the predict loop before anything happens for that batch.

Parameters
  • batch (Any) – The batched data as it is returned by the predict DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_predict_batch_end
LightningModule.on_predict_batch_end(outputs, batch, batch_idx, dataloader_idx)

Called in the predict loop after the batch.

Parameters
  • outputs (Optional[Any]) – The outputs of predict_step(x)

  • batch (Any) – The batched data as it is returned by the predict DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_predict_epoch_start
LightningModule.on_predict_epoch_start()

Called in the predict loop at the very beginning of the epoch.

Return type

None

on_predict_epoch_end
LightningModule.on_predict_epoch_end(results)

Called in the predict loop at the very end of the epoch.

Return type

None

on_predict_start
LightningModule.on_predict_start()

Called at the beginning of predicting.

Return type

None

on_predict_end
LightningModule.on_predict_end()

Called at the end of predicting.

Return type

None

on_train_batch_start
LightningModule.on_train_batch_start(batch, batch_idx)

Called in the training loop before anything happens for that batch.

If you return -1 here, you will skip training for the rest of the current epoch.

Parameters
  • batch (Any) – The batched data as it is returned by the training DataLoader.

  • batch_idx (int) – the index of the batch

Return type

Optional[int]
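The effect of returning -1 from on_train_batch_start() can be simulated with a plain loop. The sketch below is a hypothetical stand-in for the training loop (run_train_epoch is not a Lightning function): it stops iterating over the rest of the epoch as soon as the hook returns -1.

```python
from typing import Callable, List, Optional


def run_train_epoch(
    batches: List[str],
    on_train_batch_start: Callable[[str, int], Optional[int]],
) -> List[int]:
    """Return the indices of batches that were actually trained on."""
    trained = []
    for batch_idx, batch in enumerate(batches):
        if on_train_batch_start(batch, batch_idx) == -1:
            break  # skip the rest of the current epoch
        trained.append(batch_idx)  # training_step() etc. would run here
    return trained


def stop_at_third_batch(batch: str, batch_idx: int) -> Optional[int]:
    return -1 if batch_idx == 2 else None


print(run_train_epoch(["b0", "b1", "b2", "b3"], stop_at_third_batch))
```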

on_train_batch_end
LightningModule.on_train_batch_end(outputs, batch, batch_idx)

Called in the training loop after the batch.

Parameters
  • outputs (Union[Tensor, Dict[str, Any]]) – The outputs of training_step_end(training_step(x))

  • batch (Any) – The batched data as it is returned by the training DataLoader.

  • batch_idx (int) – the index of the batch

Return type

None

on_train_epoch_start
LightningModule.on_train_epoch_start()

Called in the training loop at the very beginning of the epoch.

Return type

None

on_train_epoch_end
LightningModule.on_train_epoch_end()

Called in the training loop at the very end of the epoch.

To access all batch outputs at the end of the epoch, either:

  1. Implement training_epoch_end in the LightningModule OR

  2. Cache data across steps on the attribute(s) of the LightningModule and access them in this hook
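Option 2 can be sketched as follows. The class below only imitates the two relevant hooks so that it runs without Lightning; the attribute name training_step_outputs and the placeholder float "loss" are assumptions. The key points are appending in the step hook, reducing in on_train_epoch_end(), and clearing the cache so it does not grow across epochs.

```python
from typing import List


class CachingModule:
    """Hypothetical stand-in that mimics two LightningModule hooks."""

    def __init__(self) -> None:
        self.training_step_outputs: List[float] = []
        self.epoch_means: List[float] = []

    def training_step(self, batch: List[float], batch_idx: int) -> float:
        loss = sum(batch) / len(batch)  # placeholder "loss"
        self.training_step_outputs.append(loss)  # cache on the module
        return loss

    def on_train_epoch_end(self) -> None:
        outs = self.training_step_outputs
        self.epoch_means.append(sum(outs) / len(outs))
        self.training_step_outputs.clear()  # free memory before the next epoch


module = CachingModule()
for epoch in range(2):
    for batch_idx, batch in enumerate([[1.0, 3.0], [2.0, 4.0]]):
        module.training_step(batch, batch_idx)
    module.on_train_epoch_end()
print(module.epoch_means)
```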

Return type

None

on_validation_batch_start
LightningModule.on_validation_batch_start(batch, batch_idx, dataloader_idx)

Called in the validation loop before anything happens for that batch.

Parameters
  • batch (Any) – The batched data as it is returned by the validation DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_validation_batch_end
LightningModule.on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)

Called in the validation loop after the batch.

Parameters
  • outputs (Union[Tensor, Dict[str, Any], None]) – The outputs of validation_step_end(validation_step(x))

  • batch (Any) – The batched data as it is returned by the validation DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_validation_epoch_start
LightningModule.on_validation_epoch_start()

Called in the validation loop at the very beginning of the epoch.

Return type

None

on_validation_epoch_end
LightningModule.on_validation_epoch_end()

Called in the validation loop at the very end of the epoch.

Return type

None

configure_sharded_model
LightningModule.configure_sharded_model()

Hook to create modules in a distributed-aware context. This is useful when using sharded plugins, where we’d like to shard the model instantly; for extremely large models this can save memory and initialization time.

This hook is called during each of the fit/val/test/predict stages in the same process, so ensure that the implementation of this hook is idempotent.
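Because the hook may run once per stage in the same process, a common way to make it idempotent is to guard module creation so that a second call is a no-op. The sketch below shows only that guard; the class, the build_layer factory, and the call sequence are hypothetical, and no actual sharding takes place.

```python
from typing import Any, Callable, Optional


class ShardedModelSketch:
    """Hypothetical module whose heavy layer is created lazily and only once."""

    def __init__(self, build_layer: Callable[[], Any]) -> None:
        self.build_layer = build_layer
        self.layer: Optional[Any] = None
        self.build_count = 0

    def configure_sharded_model(self) -> None:
        # idempotent: do nothing if the layer already exists
        if self.layer is not None:
            return
        self.layer = self.build_layer()
        self.build_count += 1


model = ShardedModelSketch(build_layer=lambda: object())
for stage in ("fit", "validate", "test"):
    model.configure_sharded_model()
print(model.build_count)
```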

Return type

None

on_validation_model_eval
LightningModule.on_validation_model_eval()

Sets the model to eval during the val loop.

Return type

None

on_validation_model_train
LightningModule.on_validation_model_train()

Sets the model to train during the val loop.

Return type

None

on_test_model_eval
LightningModule.on_test_model_eval()

Sets the model to eval during the test loop.

Return type

None

on_test_model_train
LightningModule.on_test_model_train()

Sets the model to train during the test loop.

Return type

None

on_before_optimizer_step
LightningModule.on_before_optimizer_step(optimizer, optimizer_idx)

Called before optimizer.step().

If using gradient accumulation, the hook is called once the gradients have been accumulated. See: accumulate_grad_batches.

If using native AMP, the gradients will be unscaled before calling this hook. See these docs for more information on the scaling of gradients.

If clipping gradients, the gradients will not have been clipped yet.
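To make the accumulation behaviour concrete, the sketch below lists the batch indices after which an optimizer step (and therefore this hook) would happen for a given accumulate_grad_batches. It assumes a step every accumulate_grad_batches batches plus one on the last batch of the epoch so leftover gradients are not dropped; treat that end-of-epoch rule as an assumption about the Trainer, not a guarantee. optimizer_step_batches is a hypothetical helper.

```python
from typing import List


def optimizer_step_batches(num_batches: int, accumulate_grad_batches: int) -> List[int]:
    """Batch indices after which an optimizer step is assumed to happen."""
    steps = []
    for batch_idx in range(num_batches):
        accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
        is_last_batch = batch_idx == num_batches - 1
        if accumulation_done or is_last_batch:
            steps.append(batch_idx)
    return steps


print(optimizer_step_batches(num_batches=10, accumulate_grad_batches=4))
```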

Parameters
  • optimizer (Optimizer) – Current optimizer being used.

  • optimizer_idx (int) – Index of the current optimizer being used.

Example:

def on_before_optimizer_step(self, optimizer, optimizer_idx):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        for k, v in self.named_parameters():
            self.logger.experiment.add_histogram(
                tag=k, values=v.grad, global_step=self.trainer.global_step
            )
Return type

None

configure_gradient_clipping
LightningModule.configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None)

Perform gradient clipping for the optimizer parameters. Called before optimizer_step().

Parameters
  • optimizer (Optimizer) – Current optimizer being used.

  • optimizer_idx (int) – Index of the current optimizer being used.

  • gradient_clip_val (Union[int, float, None]) – The value at which to clip gradients. By default, the value passed to the Trainer is available here.

  • gradient_clip_algorithm (Optional[str]) – The gradient clipping algorithm to use. By default, the value passed to the Trainer is available here.
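The two common clipping algorithms can be sketched on a flat list of gradient values. This illustrates the math only and is not Lightning's or PyTorch's implementation: "value" clamps each entry to [-clip_val, clip_val], while "norm" rescales all entries together when their global L2 norm exceeds the limit, preserving the gradient direction. The small epsilon guarding the division is an assumed value.

```python
import math
from typing import List


def clip_by_value(grads: List[float], clip_val: float) -> List[float]:
    """Clamp every gradient entry into [-clip_val, clip_val]."""
    return [max(-clip_val, min(clip_val, g)) for g in grads]


def clip_by_norm(grads: List[float], max_norm: float, eps: float = 1e-6) -> List[float]:
    """Rescale all entries so the global L2 norm is at most (about) max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = max_norm / (total_norm + eps)
    if scale >= 1.0:
        return list(grads)  # already within the limit
    return [g * scale for g in grads]


grads = [3.0, -4.0]  # global L2 norm is 5.0
print(clip_by_value(grads, 1.0))
print(clip_by_norm(grads, 1.0))
```

Value clipping turns [3.0, -4.0] into [1.0, -1.0] and changes the direction; norm clipping yields roughly [0.6, -0.8], the same direction with unit length.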

Example:

# Perform gradient clipping on gradients associated with discriminator (optimizer_idx=1) in GAN
def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
    if optimizer_idx == 1:
        # Lightning will handle the gradient clipping
        self.clip_gradients(
            optimizer,
            gradient_clip_val=gradient_clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm
        )
    else:
        # implement your own custom logic to clip gradients for generator (optimizer_idx=0)
        ...
Return type

None

optimizer_step
LightningModule.optimizer_step(epoch, batch_idx, optimizer, optimizer_idx=0, optimizer_closure=None, on_tpu=False, using_lbfgs=False)

Override this method to adjust the default way the Trainer calls each optimizer.

By default, Lightning calls step() and zero_grad() as shown in the example once per optimizer. This method (and zero_grad()) won’t be called during the accumulation phase when Trainer(accumulate_grad_batches != 1). Overriding this hook has no benefit with manual optimization.

Parameters
  • epoch (int) – Current epoch

  • batch_idx (int) – Index of current batch

  • optimizer (Union[Optimizer, LightningOptimizer]) – A PyTorch optimizer

  • optimizer_idx (int) – If you used multiple optimizers, this indexes into that list.

  • optimizer_closure (Optional[Callable[[], Any]]) – The optimizer closure. This closure must be executed as it includes the calls to training_step(), optimizer.zero_grad(), and backward().

  • on_tpu (bool) – True if TPU backward is required

  • using_lbfgs (bool) – True if the matching optimizer is torch.optim.LBFGS

Examples:

# DEFAULT
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_lbfgs):
    optimizer.step(closure=optimizer_closure)

# Alternating schedule for optimizer steps (i.e.: GANs)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_lbfgs):
    # update generator opt every step
    if optimizer_idx == 0:
        optimizer.step(closure=optimizer_closure)

    # update discriminator opt every 2 steps
    if optimizer_idx == 1:
        if (batch_idx + 1) % 2 == 0:
            optimizer.step(closure=optimizer_closure)
        else:
            # call the closure by itself to run `training_step` + `backward` without an optimizer step
            optimizer_closure()

    # ...
    # add as many optimizers as you want

Here’s another example showing how to use this for more advanced things such as learning rate warm-up:

# learning rate warm-up
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu,
    using_lbfgs,
):
    # update params
    optimizer.step(closure=optimizer_closure)

    # manually warm up lr without a scheduler
    if self.trainer.global_step < 500:
        lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
        for pg in optimizer.param_groups:
            pg["lr"] = lr_scale * self.learning_rate
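The warm-up factor used above is a linear ramp over the first 500 optimizer steps. The pure-Python sketch below evaluates the same formula so the schedule can be inspected without a Trainer; warmup_lr is a hypothetical helper and the base learning rate of 1e-3 is an arbitrary assumption.

```python
def warmup_lr(global_step: int, base_lr: float, warmup_steps: int = 500) -> float:
    """Learning rate produced by the linear warm-up rule in the example above."""
    if global_step >= warmup_steps:
        # the example stops touching the lr here; its last value used lr_scale == 1.0
        return base_lr
    lr_scale = min(1.0, float(global_step + 1) / warmup_steps)
    return lr_scale * base_lr


base_lr = 1e-3
for step in (0, 249, 499, 500, 10_000):
    print(step, warmup_lr(step, base_lr))
```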
Return type

None

optimizer_zero_grad
LightningModule.optimizer_zero_grad(epoch, batch_idx, optimizer, optimizer_idx)

Override this method to change the default behaviour of optimizer.zero_grad().

Parameters
  • epoch (int) – Current epoch

  • batch_idx (int) – Index of current batch

  • optimizer (Optimizer) – A PyTorch optimizer

  • optimizer_idx (int) – If you used multiple optimizers this indexes into that list.

Examples:

# DEFAULT
def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
    optimizer.zero_grad()

# Set gradients to `None` instead of zero to improve performance (not required on `torch>=2.0.0`).
def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
    optimizer.zero_grad(set_to_none=True)

See torch.optim.Optimizer.zero_grad() for the explanation of the above example.

Return type

None

prepare_data
LightningModule.prepare_data()

Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.

Warning

DO NOT assign state to the model here (use setup() instead), since this method is NOT called on every device.

Example:

def prepare_data(self):
    # good
    download_data()
    tokenize()
    # ... any other one-time preparation

    # bad
    self.split = data_split
    self.some_state = some_other_state()

In a distributed environment, prepare_data can be called in two ways (controlled by prepare_data_per_node):

  1. Once per node. This is the default and is only called on LOCAL_RANK=0.

  2. Once in total. Only called on GLOBAL_RANK=0.

Example:

from pytorch_lightning import LightningDataModule


# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True


# call on GLOBAL_RANK=0 (great for shared file systems)
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = False
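As a rough mental model of the two modes, the sketch below decides which processes would run prepare_data() for a hypothetical job with 2 nodes and 4 processes per node. The helper calls_prepare_data and the rank layout (global_rank = node * 4 + local_rank) are illustrative assumptions, not Lightning internals.

```python
def calls_prepare_data(local_rank, global_rank, prepare_data_per_node=True):
    """Hypothetical helper: would this process run prepare_data()?"""
    if prepare_data_per_node:
        # once per node: the first process on every node
        return local_rank == 0
    # once in total: only the first process of the whole job
    return global_rank == 0


# hypothetical layout: 2 nodes x 4 processes, global_rank = node * 4 + local_rank
ranks = [(node * 4 + local, local) for node in range(2) for local in range(4)]
per_node = [g for g, l in ranks if calls_prepare_data(l, g, True)]
once_total = [g for g, l in ranks if calls_prepare_data(l, g, False)]
print(per_node, once_total)
```

With a shared file system, the second mode avoids two nodes downloading into the same directory at once.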

This is called before requesting the dataloaders:

model.prepare_data()
initialize_distributed()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()

Return type

None

setup
LightningModule.setup(stage)

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters

stage (str) – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = something_else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)

Return type

None

tbptt_split_batch
LightningModule.tbptt_split_batch(batch, split_size)[source]

When using truncated backpropagation through time, each batch must be split along the time dimension. Lightning handles this by default, but for custom behavior override this function.

Parameters
  • batch (Any) – Current batch

  • split_size (int) – The size of the split

Return type

List[Any]

Returns

List of batch splits. Each split will be passed to training_step() to enable truncated back propagation through time. The default implementation splits root level Tensors and Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length.

Examples:

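import collections.abc

import torch


def tbptt_split_batch(self, batch, split_size):
    # length of the time dimension (dim=1) of every tensor or sequence in the batch
    time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.abc.Sequence))]
    assert len(time_dims) >= 1, "Unable to determine batch time dimension"
    assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"

    splits = []
    for t in range(0, time_dims[0], split_size):
        batch_split = []
        for x in batch:
            if isinstance(x, torch.Tensor):
                # slice the time dimension of a (batch, time, ...) tensor
                split_x = x[:, t:t + split_size]
            elif isinstance(x, collections.abc.Sequence):
                # slice every sample of a sequence of sequences
                split_x = [sample[t:t + split_size] for sample in x]
            else:
                # leave anything else (e.g. metadata) untouched
                split_x = x
            batch_split.append(split_x)
        splits.append(batch_split)
    return splits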

Note

Called in the training loop after on_train_batch_start() if truncated_bptt_steps > 0. Each returned batch split is passed separately to training_step().
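Each split covers a window of at most split_size time steps, so when the time length is not a multiple of split_size the last split is shorter. The dependency-free sketch below lists the window boundaries produced by the range(0, time_dims[0], split_size) loop in the example and applies them to a plain nested list. The helper tbptt_windows is hypothetical.

```python
def tbptt_windows(time_len, split_size):
    """Hypothetical helper: (start, stop) indices of each split along the time dim."""
    return [(t, min(t + split_size, time_len)) for t in range(0, time_len, split_size)]


# a sequence of 10 time steps split into chunks of 4
print(tbptt_windows(10, 4))

# slice a plain nested list (2 samples x 10 steps) the way x[:, t:t + split_size] slices a tensor
batch = [[f"s{b}t{t}" for t in range(10)] for b in range(2)]
splits = [[row[start:stop] for row in batch] for start, stop in tbptt_windows(10, 4)]
```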

teardown
LightningModule.teardown(stage)

Called at the end of fit (train + validate), validate, test, or predict.

Parameters

stage (str) – either 'fit', 'validate', 'test', or 'predict'

Return type

None

train_dataloader
LightningModule.train_dataloader()

Implement one or more PyTorch DataLoaders for training.

Return type

Union[DataLoader, Sequence[DataLoader], Sequence[Sequence[DataLoader]], Sequence[Dict[str, DataLoader]], Dict[str, DataLoader], Dict[str, Dict[str, DataLoader]], Dict[str, Sequence[DataLoader]]]

Returns

A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, please see this section.

The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data(). These methods are called in the following order:

  • fit()

  • prepare_data()

  • setup()

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Example:

import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST

# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}

val_dataloader
LightningModule.val_dataloader()

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return type

Union[DataLoader, Sequence[DataLoader]]

Returns

A torch.utils.data.DataLoader or a sequence of them specifying validation samples.

Examples:

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Note

In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.

test_dataloader
LightningModule.test_dataloader()

Implement one or multiple PyTorch DataLoaders for testing.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data(). These methods are called in the following order:

  • test()

  • prepare_data()

  • setup()

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return type

Union[DataLoader, Sequence[DataLoader]]

Returns

A torch.utils.data.DataLoader or a sequence of them specifying testing samples.

Example:

def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

Note

In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.

predict_dataloader
LightningModule.predict_dataloader()

Implement one or multiple PyTorch DataLoaders for prediction.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return type

Union[DataLoader, Sequence[DataLoader]]

Returns

A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.

Note

In the case where you return multiple prediction dataloaders, the predict_step() will have an argument dataloader_idx which matches the order here.

transfer_batch_to_device
LightningModule.transfer_batch_to_device(batch, device, dataloader_idx)

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

The data types listed below (and any arbitrary nesting of them) are supported out of the box:

  • torch.Tensor or anything that implements .to(...)

  • list

  • dict

  • tuple

For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, …).

Note

This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as argument (unless you know what you are doing). To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.

Note

This hook only runs on single GPU training and DDP (no data-parallel). Data-Parallel support will come in the near future.

Parameters
  • batch (Any) – A batch of data that needs to be transferred to a new device.

  • device (device) – The target device as defined in PyTorch.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Return type

Any

Returns

A reference to the data on the new device.

Example:

def transfer_batch_to_device(self, batch, device, dataloader_idx):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    elif dataloader_idx == 0:
        # skip device transfer for the first dataloader or anything you wish
        pass
    else:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch

Raises
  • MisconfigurationException – If using data-parallel, Trainer(strategy='dp').

  • MisconfigurationException – If using IPUs, Trainer(accelerator='ipu').

See also

  • move_data_to_device()

  • apply_to_collection()
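To make "arbitrary nesting" concrete, here is a hypothetical, simplified mover (not Lightning's move_data_to_device()) that walks lists, tuples and dicts and calls .to(device) on anything that has it. FakeTensor is a stand-in so the sketch runs without PyTorch; named tuples and dataclasses, which a real implementation would also need to handle, are left out.

```python
def move_to_device(data, device):
    """Hypothetical recursive mover for objects with .to() nested in lists, tuples and dicts."""
    if hasattr(data, "to"):
        return data.to(device)
    if isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        # plain lists and tuples only; namedtuples would need special handling
        return type(data)(move_to_device(item, device) for item in data)
    # anything else (ints, strings, objects without .to) is returned unchanged
    return data


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        return FakeTensor(device)


batch = {"x": [FakeTensor(), FakeTensor()], "y": (FakeTensor(), 3), "name": "sample"}
moved = move_to_device(batch, "cuda:0")
```

A custom batch class only needs an override of transfer_batch_to_device() when it is neither one of these containers nor exposes its own .to() method.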

on_before_batch_transfer
LightningModule.on_before_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

Note

To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.

Note

This hook only runs on single GPU training and DDP (no data-parallel). Data-Parallel support will come in the near future.

Parameters
  • batch (Any) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Return type

Any

Returns

A batch of data

Example:

def on_before_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = transforms(batch['x'])
    return batch

See also

  • on_after_batch_transfer()

  • transfer_batch_to_device()

on_after_batch_transfer
LightningModule.on_after_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

Note

To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.

Note

This hook only runs on single GPU training and DDP (no data-parallel). Data-Parallel support will come in the near future.

Parameters
  • batch (Any) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Return type

Any

Returns

A batch of data

Example:

def on_after_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = gpu_transforms(batch['x'])
    return batch

Raises
  • MisconfigurationException – If using data-parallel, Trainer(strategy='dp').

  • MisconfigurationException – If using IPUs, Trainer(accelerator='ipu').

See also

  • on_before_batch_transfer()

  • transfer_batch_to_device()
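The three batch-transfer hooks are chained: on_before_batch_transfer() sees the batch before it is moved, transfer_batch_to_device() moves it, and on_after_batch_transfer() sees it on the target device. The sketch below records that order with hypothetical stand-ins (HookOrderDemo, apply_batch_transfer) rather than a real LightningModule and Trainer.

```python
class HookOrderDemo:
    """Hypothetical stand-in for a LightningModule that records the batch-transfer hooks."""

    def __init__(self):
        self.calls = []

    def on_before_batch_transfer(self, batch, dataloader_idx):
        self.calls.append("on_before_batch_transfer")  # batch is still on the host
        return batch

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        self.calls.append("transfer_batch_to_device")  # a real hook would move the data here
        return batch

    def on_after_batch_transfer(self, batch, dataloader_idx):
        self.calls.append("on_after_batch_transfer")  # batch is now on `device`
        return batch


def apply_batch_transfer(module, batch, device, dataloader_idx=0):
    """Hypothetical driver that chains the three hooks in order."""
    batch = module.on_before_batch_transfer(batch, dataloader_idx)
    batch = module.transfer_batch_to_device(batch, device, dataloader_idx)
    return module.on_after_batch_transfer(batch, dataloader_idx)


demo = HookOrderDemo()
apply_batch_transfer(demo, {"x": [1, 2, 3]}, "cuda:0")
print(demo.calls)
```

This is why CPU-side augmentations belong in on_before_batch_transfer() and GPU-side ones in on_after_batch_transfer().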

Trainer

Once you’ve organized your PyTorch code into a LightningModule, the Trainer automates everything else.


This abstraction achieves the following:

  1. You maintain control over all aspects via PyTorch code without an added abstraction.

  2. The trainer uses best practices embedded by contributors and users from top AI labs such as Facebook AI Research, NYU, MIT, Stanford, etc…

  3. The trainer allows overriding any key part that you don’t want automated.



Basic use

This is the basic use of the trainer:

model = MyLightningModule()

trainer = Trainer()
trainer.fit(model, train_dataloader, val_dataloader)

Under the hood

Under the hood, the Lightning Trainer handles the training loop details for you. Some examples include:

  • Automatically enabling/disabling grads

  • Running the training, validation and test dataloaders

  • Calling the Callbacks at the appropriate times

  • Putting batches and computations on the correct devices

Here’s the pseudocode for what the trainer does under the hood (showing the train loop only):

# put model in train mode
model.train()
torch.set_grad_enabled(True)

losses = []
for batch in train_dataloader:
    # calls hooks like this one
    on_train_batch_start()

    # train step
    loss = training_step(batch)

    # clear gradients
    optimizer.zero_grad()

    # backward
    loss.backward()

    # update parameters
    optimizer.step()

    losses.append(loss)
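For comparison, the same loop written as runnable plain PyTorch looks roughly like this. The linear model, SGD optimizer, random data and training_step function are hypothetical stand-ins; the real Trainer additionally runs hooks, callbacks, logging and device placement around these steps.

```python
import torch
from torch import nn

torch.manual_seed(0)

# hypothetical toy setup: a linear model, plain SGD and 5 random batches
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
train_dataloader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(5)]


def training_step(batch):
    x, y = batch
    return nn.functional.mse_loss(model(x), y)


# put model in train mode
model.train()
torch.set_grad_enabled(True)

losses = []
for batch in train_dataloader:
    loss = training_step(batch)  # train step
    optimizer.zero_grad()  # clear gradients
    loss.backward()  # backward
    optimizer.step()  # update parameters
    losses.append(loss.item())  # keep a float so the autograd graph can be freed
```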

Trainer in Python scripts

In Python scripts, it’s recommended you use a main function to call the Trainer.

from argparse import ArgumentParser


def main(hparams):
    model = LightningModule()
    trainer = Trainer(accelerator=hparams.accelerator, devices=hparams.devices)
    trainer.fit(model)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--accelerator", default=None)
    parser.add_argument("--devices", default=None)
    args = parser.parse_args()

    main(args)

So you can run it like so:

python main.py --accelerator 'gpu' --devices 2

Note

Pro-tip: You don’t need to define all flags manually. Lightning can add them automatically:

from argparse import ArgumentParser


def main(args):
    model = LightningModule()
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    main(args)

So you can run it like so:

python main.py --accelerator 'gpu' --devices 2 --max_steps 10 --limit_train_batches 10 --any_trainer_arg x

Note

If you want to stop a training run early, you can press “Ctrl + C” on your keyboard. The trainer will catch the KeyboardInterrupt and attempt a graceful shutdown, including running accelerator callback on_train_end to clean up memory. The trainer object will also set an attribute interrupted to True in such cases. If you have a callback which shuts down compute resources, for example, you can conditionally run the shutdown logic for only uninterrupted runs.


Validation

You can perform an evaluation epoch over the validation set, outside of the training loop, using validate(). This might be useful if you want to collect new metrics from a model right at its initialization or after it has already been trained.

trainer.validate(model=model, dataloaders=val_dataloaders)

Testing

Once you’re done training, feel free to run the test set! (Only right before publishing your paper or pushing to production)

trainer.test(dataloaders=test_dataloaders)

Reproducibility

To ensure full reproducibility from run to run, you need to set seeds for pseudo-random generators and set the deterministic flag in the Trainer.

Example:

from pytorch_lightning import Trainer, seed_everything

seed_everything(42, workers=True)
# sets seeds for numpy, torch and python.random.
model = Model()
trainer = Trainer(deterministic=True)

By setting workers=True in seed_everything(), Lightning derives unique seeds across all dataloader workers and processes for torch, numpy and stdlib random number generators. When turned on, it ensures that e.g. data augmentations are not repeated across workers.
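The idea behind workers=True is that every (process rank, dataloader worker) pair gets its own reproducible random stream. The sketch below only illustrates that idea with Python's random module; the seed derivation (a string built from seed, rank and worker id) and the helper worker_rng are assumptions and not the algorithm Lightning uses.

```python
import random


def worker_rng(base_seed, global_rank, worker_id):
    """Hypothetical per-worker generator: reproducible, but distinct per (rank, worker)."""
    # str seeds are converted to an int deterministically (not via hash()), so this is stable across runs
    return random.Random(f"{base_seed}-{global_rank}-{worker_id}")


# the first "augmentation" value drawn by each of 2 ranks x 2 workers
first_draws = {
    (rank, worker): worker_rng(42, rank, worker).random()
    for rank in range(2)
    for worker in range(2)
}

# the same (seed, rank, worker) always reproduces the same stream
assert worker_rng(42, 1, 0).random() == first_draws[(1, 0)]
```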


Trainer flags

accelerator

Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "auto") as well as custom accelerator instances.

# CPU accelerator
trainer = Trainer(accelerator="cpu")

# Training with GPU Accelerator using 2 GPUs
trainer = Trainer(devices=2, accelerator="gpu")

# Training with TPU Accelerator using 8 tpu cores
trainer = Trainer(devices=8, accelerator="tpu")

# Training with GPU Accelerator using the DistributedDataParallel strategy
trainer = Trainer(devices=4, accelerator="gpu", strategy="ddp")

Note

The "auto" option recognizes the machine you are on, and selects the respective Accelerator.

# If your machine has GPUs, it will use the GPU Accelerator for training
trainer = Trainer(devices=2, accelerator="auto")

You can also modify hardware behavior by subclassing an existing accelerator to adjust for your needs.

Example:

class MyOwnAcc(CPUAccelerator):
    ...

Trainer(accelerator=MyOwnAcc())

Note

If the devices flag is not defined, it will assume devices to be "auto" and fetch the auto_device_count from the accelerator.

# This is part of the built-in `CUDAAccelerator`
class CUDAAccelerator(Accelerator):
    """Accelerator for GPU devices."""

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return torch.cuda.device_count()


# Training with GPU Accelerator using total number of gpus available on the system
Trainer(accelerator="gpu")

accumulate_grad_batches

Accumulates gradients every k batches, or according to the schedule given as a dict. The Trainer also calls optimizer.step() on the last batch of an epoch, even when the number of batches is not divisible by k.

# default used by the Trainer (no accumulation)
trainer = Trainer(accumulate_grad_batches=1)

Example:

# accumulate every 4 batches (effective batch size is batch*4)
trainer = Trainer(accumulate_grad_batches=4)

# no accumulation for epochs 0-4 (epochs are 0-indexed), accumulate 3 for epochs 5-9, accumulate 20 from epoch 10 onward
trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})
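The schedule dict maps a starting epoch to an accumulation factor. The sketch below assumes the keys are 0-indexed epoch numbers, and also shows after which batches optimizer.step() would run with k=4, including the final, incomplete group. The helpers accumulation_factor and step_batches are hypothetical.

```python
def accumulation_factor(schedule, epoch):
    """Hypothetical helper: accumulation factor in effect at a 0-indexed epoch."""
    factor = 1  # before the first key: no accumulation
    for start_epoch in sorted(schedule):
        if epoch >= start_epoch:
            factor = schedule[start_epoch]
    return factor


def step_batches(num_batches, k):
    """Hypothetical helper: 0-indexed batches after which optimizer.step() runs."""
    # step every k batches, plus once more for a final group smaller than k
    return [i for i in range(num_batches) if (i + 1) % k == 0 or i + 1 == num_batches]


schedule = {5: 3, 10: 20}
print([accumulation_factor(schedule, epoch) for epoch in range(12)])
print(step_batches(10, 4))
```

With 10 batches and k=4, the last step uses only 2 accumulated batches.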
auto_scale_batch_size

Automatically tries to find the largest batch size that fits into memory, before any training.
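The "binsearch" mode can be pictured as two phases: double the batch size until a trial fails, then binary-search between the last size that fit and the first that did not. The sketch below is a simplified, hypothetical version: fits() stands in for a real trial run that would catch out-of-memory errors, and find_max_batch_size, init_size and max_trials are illustrative names.

```python
def find_max_batch_size(fits, init_size=2, max_trials=25):
    """Hypothetical "binsearch"-style search for the largest batch size that fits."""
    low, high = 1, None  # largest size known to fit / smallest size known to fail
    size = init_size
    for _ in range(max_trials):
        if fits(size):
            low = size
        else:
            high = size
        if high is None:
            size *= 2  # phase 1: keep doubling until the first failure
        elif high - low <= 1:
            break  # low fits and low + 1 does not: done
        else:
            size = (low + high) // 2  # phase 2: bisect between low and high
    return low


# pretend that anything up to 100 samples fits in memory
print(find_max_batch_size(lambda batch_size: batch_size <= 100))
```

Doubling first keeps the number of expensive trial runs logarithmic in the final batch size.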

# default used by the Trainer (no scaling of batch size)
trainer = Trainer(auto_scale_batch_size=None)

# run batch size scaling, result overrides hparams.batch_size
trainer = Trainer(auto_scale_batch_size="binsearch")

# call tune to find the batch size
trainer.tune(model)

auto_lr_find

Runs a learning rate finder algorithm (see this paper) when calling trainer.tune(), to find an optimal initial learning rate.

# default used by the Trainer (no learning rate finder)
trainer = Trainer(auto_lr_find=False)
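The learning rate finder performs a short training run while increasing the learning rate on every step, records the loss, and suggests a starting value from that loss curve. A minimal sketch of an exponentially spaced schedule, assuming min_lr=1e-8, max_lr=1.0 and 100 steps (treat these numbers as illustrative), is shown below; exponential_lrs is a hypothetical helper, not Lightning's exact schedule.

```python
def exponential_lrs(min_lr=1e-8, max_lr=1.0, num_training=100):
    """Hypothetical helper: learning rates tried by an exponential LR range test."""
    ratio = max_lr / min_lr
    # each step multiplies the learning rate by the same constant factor
    return [min_lr * ratio ** (i / (num_training - 1)) for i in range(num_training)]


lrs = exponential_lrs()
print(lrs[0], lrs[len(lrs) // 2], lrs[-1])
```

Exponential spacing gives every order of magnitude between min_lr and max_lr roughly the same number of steps.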

Example:

# run learning rate finder, results override hparams.learning_rate
trainer = Trainer(auto_lr_find=True)

# call tune to find the lr
trainer.tune(model)

Example:

# run learning rate finder, results override hparams.my_lr_arg
trainer = Trainer(auto_lr_find='my_lr_arg')

# call tune to find the lr
trainer.tune(model)

benchmark

The value (True or False) to set torch.backends.cudnn.benchmark to. The value for torch.backends.cudnn.benchmark set in the current session will be used (False if not manually set). If deterministic is set to True, this will default to False. You can read more about the interaction of torch.backends.cudnn.benchmark and torch.backends.cudnn.deterministic here

Setting this flag to True can increase the speed of your system if your input sizes don’t change. However, if they do, then it might make your system slower. The CUDNN auto-tuner will try to find the best algorithm for the hardware when a new input size is encountered. This might also increase the memory usage. Read more about it here.

Example:

# Uses the current value of torch.backends.cudnn.benchmark (normally False)
trainer = Trainer(benchmark=None)  # default

# you can overwrite the value
trainer = Trainer(benchmark=True)

deterministic

This flag sets the torch.backends.cudnn.deterministic flag. It might make your system slower, but ensures reproducibility.

For more info check PyTorch docs.

Example:

# default used by the Trainer
trainer = Trainer(deterministic=False)

callbacks

Add a list of Callback instances. Callbacks run sequentially in the order defined here, with the exception of ModelCheckpoint callbacks, which run after all others to ensure all states are saved to the checkpoints.

# a list of callbacks
callbacks = [PrintCallback()]
trainer = Trainer(callbacks=callbacks)

Example:

from pytorch_lightning.callbacks import Callback

class PrintCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is started!")
    def on_train_end(self, trainer, pl_module):
        print("Training is done.")

Model-specific callbacks can also be added inside the LightningModule through configure_callbacks(). Callbacks returned in this hook will extend the list initially given to the Trainer argument, and replace the trainer callbacks should there be two or more of the same type. ModelCheckpoint callbacks always run last.
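The merge rule for configure_callbacks() can be approximated as follows: Trainer callbacks whose type also appears among the model's callbacks are replaced, the model's callbacks are appended, and ModelCheckpoint callbacks are moved to the end. This is a hypothetical sketch (merge_callbacks is an illustrative name) using minimal stand-in classes, not Lightning's implementation.

```python
class Callback:
    """Minimal stand-in for pytorch_lightning.callbacks.Callback."""


class ModelCheckpoint(Callback):
    """Stand-in for the checkpoint callback."""


class EarlyStopping(Callback):
    """Stand-in for the early-stopping callback."""


class PrintCallback(Callback):
    """Stand-in for the PrintCallback example above."""


def merge_callbacks(trainer_callbacks, model_callbacks):
    """Hypothetical approximation of how model callbacks combine with Trainer callbacks."""
    overridden = {type(cb) for cb in model_callbacks}
    # Trainer callbacks of a type also returned by configure_callbacks() are replaced
    merged = [cb for cb in trainer_callbacks if type(cb) not in overridden]
    merged.extend(model_callbacks)
    # ModelCheckpoint callbacks are moved to the end so they run last
    others = [cb for cb in merged if not isinstance(cb, ModelCheckpoint)]
    checkpoints = [cb for cb in merged if isinstance(cb, ModelCheckpoint)]
    return others + checkpoints


trainer_cbs = [ModelCheckpoint(), EarlyStopping(), PrintCallback()]
model_cbs = [EarlyStopping()]
order = [type(cb).__name__ for cb in merge_callbacks(trainer_cbs, model_cbs)]
print(order)
```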

check_val_every_n_epoch

Check val every n train epochs.
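Assuming validation runs at the end of every n-th training epoch, with epochs counted from 0 (as trainer.current_epoch does), the epochs that trigger a validation loop can be listed with this hypothetical helper:

```python
def validation_epochs(max_epochs, check_val_every_n_epoch):
    """Hypothetical helper: 0-indexed epochs whose end triggers the validation loop."""
    return [
        epoch
        for epoch in range(max_epochs)
        if (epoch + 1) % check_val_every_n_epoch == 0
    ]


print(validation_epochs(30, 10))
```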

Example:

# default used by the Trainer
trainer = Trainer(check_val_every_n_epoch=1)

# run val loop every 10 training epochs
trainer = Trainer(check_val_every_n_epoch=10)

default_root_dir

Default path for logs and weights when no logger or pytorch_lightning.callbacks.ModelCheckpoint callback is passed. On certain clusters you might want to separate where logs and checkpoints are stored; if you don’t, use this argument for convenience. Paths can be local paths or remote paths such as s3://bucket/path or hdfs://path/. Credentials will need to be set up to use remote filepaths.

# default used by the Trainer
trainer = Trainer(default_root_dir=os.getcwd())

devices

Number of devices to train on (int), which devices to train on (list or str), or "auto". It will be mapped to either gpus, tpu_cores, num_processes or ipus, based on the accelerator type ("cpu", "gpu", "tpu", "ipu", "auto").

# Training with CPU Accelerator using 2 processes
trainer = Trainer(devices=2, accelerator="cpu")

# Training with GPU Accelerator using GPUs 1 and 3
trainer = Trainer(devices=[1, 3], accelerator="gpu")

# Training with TPU Accelerator using 8 tpu cores
trainer = Trainer(devices=8, accelerator="tpu")

Tip

The "auto" option recognizes the devices to train on, depending on the Accelerator being used.

# If your machine has GPUs, it will use all the available GPUs for training
trainer = Trainer(devices="auto", accelerator="auto")

# Training with CPU Accelerator using 1 process
trainer = Trainer(devices="auto", accelerator="cpu")

# Training with TPU Accelerator using 8 tpu cores
trainer = Trainer(devices="auto", accelerator="tpu")

# Training with IPU Accelerator using 4 ipus
trainer = Trainer(devices="auto", accelerator="ipu")

Note

If the devices flag is not defined, it will assume devices to be "auto" and fetch the auto_device_count from the accelerator.

# This is part of the built-in `CUDAAccelerator`
class CUDAAccelerator(Accelerator):
    """Accelerator for GPU devices."""

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return torch.cuda.device_count()


# Training with GPU Accelerator using total number of gpus available on the system
Trainer(accelerator="gpu")

enable_checkpointing

By default, Lightning saves a checkpoint for you in your current working directory with the state of your last training epoch. Checkpoints capture the exact value of all parameters used by a model. To disable automatic checkpointing, set this to False.

# default used by Trainer, saves the most recent model to a single checkpoint after each epoch
trainer = Trainer(enable_checkpointing=True)

# turn off automatic checkpointing
trainer = Trainer(enable_checkpointing=False)

You can override the default behavior by initializing the ModelCheckpoint callback, and adding it to the callbacks list. See Saving and Loading Checkpoints for how to customize checkpointing.

from pytorch_lightning.callbacks import ModelCheckpoint

# Init ModelCheckpoint callback, monitoring 'val_loss'
checkpoint_callback = ModelCheckpoint(monitor="val_loss")

# Add your callback to the callbacks list
trainer = Trainer(callbacks=[checkpoint_callback])

fast_dev_run

Runs n batches if set to an int n, or 1 batch if set to True, to ensure your code executes without errors. This applies to fitting, validating, testing, and predicting. This flag is only recommended for debugging purposes and should not be used to limit the number of batches to run.

# default used by the Trainer
trainer = Trainer(fast_dev_run=False)

# runs only 1 training and 1 validation batch and the program ends
trainer = Trainer(fast_dev_run=True)
trainer.fit(...)

# runs 7 predict batches and program ends
trainer = Trainer(fast_dev_run=7)
trainer.predict(...)

This argument is different from limit_{train,val,test,predict}_batches because side effects are avoided to reduce the impact to subsequent runs. These are the changes enabled:

  • Sets Trainer(max_epochs=1).

  • Sets Trainer(max_steps=...) to 1 or the number passed.

  • Sets Trainer(num_sanity_val_steps=0).

  • Sets Trainer(val_check_interval=1.0).

  • Sets Trainer(check_val_every_n_epoch=1).

  • Disables all loggers.

  • Disables passing logged metrics to loggers.

  • The ModelCheckpoint callbacks will not trigger.

  • The EarlyStopping callbacks will not trigger.

  • Sets limit_{train,val,test,predict}_batches to 1 or the number passed.

  • Disables the Tuner.

  • If using the CLI, the configuration file is not saved.

gpus

Warning

gpus=x has been deprecated in v1.7 and will be removed in v2.0. Please use accelerator='gpu' and devices=x instead.


  • Number of GPUs to train on (int)

  • or which GPUs to train on (list)

  • can handle strings

# default used by the Trainer (ie: train on CPU)
trainer = Trainer(gpus=None)

# equivalent
trainer = Trainer(gpus=0)

Example:

# int: train on 2 gpus
trainer = Trainer(gpus=2)

# list: train on GPUs 1, 4 (by bus ordering)
trainer = Trainer(gpus=[1, 4])
trainer = Trainer(gpus='1, 4') # equivalent

# -1: train on all gpus
trainer = Trainer(gpus=-1)
trainer = Trainer(gpus='-1') # equivalent

# combine with num_nodes to train on multiple GPUs across nodes
# uses 8 gpus in total
trainer = Trainer(gpus=2, num_nodes=4)

# train only on GPUs 1 and 4 across nodes
trainer = Trainer(gpus=[1, 4], num_nodes=4)

gradient_clip_val

Gradient clipping value

  • 0 means don’t clip.
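Assuming clipping by total gradient norm (in the style of torch.nn.utils.clip_grad_norm_), gradient_clip_val acts as the maximum allowed combined L2 norm of all gradients. The dependency-free sketch below applies that rule to a flat list of gradient values; clip_by_global_norm and its eps parameter are hypothetical.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Hypothetical helper: scale all gradients so their combined L2 norm is at most max_norm."""
    if max_norm <= 0:
        return list(grads)  # 0 means don't clip
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef >= 1.0:
        return list(grads)  # already within the limit
    # one shared factor keeps the gradient direction unchanged
    return [g * clip_coef for g in grads]


print(clip_by_global_norm([3.0, 4.0], max_norm=1.0))  # total norm is 5.0
```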

# default used by the Trainer
trainer = Trainer(gradient_clip_val=0.0)

limit_train_batches

How much of training dataset to check. Useful when debugging or testing something that happens at the end of an epoch.

# default used by the Trainer
trainer = Trainer(limit_train_batches=1.0)

Example:

# run through only 25% of the training set each epoch
trainer = Trainer(limit_train_batches=0.25)

# run through only 10 batches of the training set each epoch
trainer = Trainer(limit_train_batches=10)

limit_test_batches

How much of test dataset to check.

# default used by the Trainer
trainer = Trainer(limit_test_batches=1.0)

# run through only 25% of the test set each epoch
trainer = Trainer(limit_test_batches=0.25)

# run for only 10 batches
trainer = Trainer(limit_test_batches=10)

In the case of multiple test dataloaders, the limit applies to each dataloader individually.

limit_val_batches

How much of validation dataset to check. Useful when debugging or testing something that happens at the end of an epoch.

# default used by the Trainer
trainer = Trainer(limit_val_batches=1.0)

# run through only 25% of the validation set each epoch
trainer = Trainer(limit_val_batches=0.25)

# run for only 10 batches
trainer = Trainer(limit_val_batches=10)

# disable validation
trainer = Trainer(limit_val_batches=0)

In the case of multiple validation dataloaders, the limit applies to each dataloader individually.
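The limit_*_batches flags accept either a float (a fraction of the dataloader) or an int (an absolute number of batches). The hypothetical helper below sketches that conversion for a single dataloader. Truncating fractional results with int() is an assumption about rounding, and the real Trainer may reject a fractional limit that resolves to zero batches.

```python
def resolve_limit(num_batches, limit):
    """Hypothetical helper: how many batches a limit_*_batches value allows."""
    if isinstance(limit, float):
        # fraction of the dataloader (1.0 = everything)
        return int(num_batches * limit)
    # an int is an absolute number of batches, capped at what the dataloader has
    return min(num_batches, limit)


# the limit is applied to each dataloader individually, e.g. two loaders of 400 and 10 batches
print([resolve_limit(n, 0.25) for n in (400, 10)], resolve_limit(400, 10))
```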

log_every_n_steps

How often to add logging rows (does not write to disk)

# default used by the Trainer
trainer = Trainer(log_every_n_steps=50)

logger

Logger (or iterable collection of loggers) for experiment tracking. A True value uses the default TensorBoardLogger shown below. False will disable logging.

from pytorch_lightning.loggers import TensorBoardLogger

# default logger used by trainer (if tensorboard is installed)
logger = TensorBoardLogger(save_dir=os.getcwd(), version=1, name="lightning_logs")
Trainer(logger=logger)

max_epochs

Stop training once this number of epochs is reached.

# default used by the Trainer
trainer = Trainer(max_epochs=1000)

If both max_epochs and max_steps aren’t specified, max_epochs will default to 1000. To enable infinite training, set max_epochs = -1.

min_epochs

Force training for at least this many epochs.

# default used by the Trainer
trainer = Trainer(min_epochs=1)

max_steps

Stop training after this number of global steps. Training stops as soon as either max_steps or max_epochs is reached, whichever comes first.

# Default (disabled)
trainer = Trainer(max_steps=-1)

# Stop after 100 steps
trainer = Trainer(max_steps=100)

If max_steps is not specified, max_epochs will be used instead (and max_epochs defaults to 1000 if max_epochs is not specified). To disable this default, set max_steps = -1.
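The interplay between max_epochs and max_steps boils down to "stop at whichever limit comes first". The hypothetical reached_max helper below sketches that rule with the documented defaults (-1 disables a limit, and max_epochs falls back to 1000 when neither is given). It additionally assumes that when only max_steps is set, the number of epochs is unbounded.

```python
def reached_max(epochs_done, global_step, max_epochs=None, max_steps=-1):
    """Hypothetical check: has either max_epochs or max_steps been reached?"""
    if max_epochs is None:
        # documented default: 1000 epochs when max_steps is not set either
        # (assumption: no epoch limit when only max_steps is set)
        max_epochs = 1000 if max_steps == -1 else -1
    steps_reached = max_steps != -1 and global_step >= max_steps
    epochs_reached = max_epochs != -1 and epochs_done >= max_epochs
    # training stops at whichever limit is hit first
    return steps_reached or epochs_reached


print(reached_max(1000, 5), reached_max(3, 100, max_steps=100))
```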

min_steps

Force training for at least this number of global steps. The Trainer keeps training until both min_steps and min_epochs are satisfied (whichever is reached later).

# Default (disabled)
trainer = Trainer(min_steps=None)

# Run at least for 100 steps (disable min_epochs)
trainer = Trainer(min_steps=100, min_epochs=0)

max_time

Set the maximum amount of time for training. Training will get interrupted mid-epoch. For customizable options use the Timer callback.

# Default (disabled)
trainer = Trainer(max_time=None)

# Stop after 12 hours of training or when reaching 10 epochs (string)
trainer = Trainer(max_time="00:12:00:00", max_epochs=10)

# Stop after 1 day and 5 hours (dict)
trainer = Trainer(max_time={"days": 1, "hours": 5})
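The string form of max_time follows a DD:HH:MM:SS layout, and the dict form uses datetime.timedelta keyword names. A hypothetical parser (parse_max_time is an illustrative name, not a Lightning function) that converts both forms into a timedelta:

```python
from datetime import timedelta


def parse_max_time(max_time):
    """Hypothetical parser for the two max_time forms shown above."""
    if isinstance(max_time, str):
        # "DD:HH:MM:SS"
        days, hours, minutes, seconds = (int(part) for part in max_time.split(":"))
        return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
    if isinstance(max_time, dict):
        # keys are timedelta keyword arguments such as "days" or "hours"
        return timedelta(**max_time)
    return max_time  # already a timedelta, or None to disable


print(parse_max_time("00:12:00:00"), parse_max_time({"days": 1, "hours": 5}))
```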

In case max_time is used together with min_steps or min_epochs, the min_* requirement always has precedence.

num_nodes

Number of GPU nodes for distributed training.

# default used by the Trainer
trainer = Trainer(num_nodes=1)

# to train on 8 nodes
trainer = Trainer(num_nodes=8)

num_processes

Warning

num_processes=x has been deprecated in v1.7 and will be removed in v2.0. Please use accelerator='cpu' and devices=x instead.


Number of processes to train with. Automatically set to the number of GPUs when using strategy="ddp". Set to a number greater than 1 when using accelerator="cpu" and strategy="ddp" to mimic distributed training on a machine without GPUs. This is useful for debugging, but will not provide any speedup, since single-process Torch already makes efficient use of multiple CPUs. While it would typically spawn subprocesses for training, setting num_nodes > 1 and keeping num_processes = 1 runs training in the main process.

# Simulate DDP for debugging on your GPU-less laptop
trainer = Trainer(accelerator="cpu", strategy="ddp", num_processes=2)

num_sanity_val_steps

The sanity check runs n validation batches before the training routine starts. This catches bugs in your validation code without waiting for the first validation check. The Trainer runs 2 steps by default. You can turn the check off or change the number of steps here.

# default used by the Trainer
trainer = Trainer(num_sanity_val_steps=2)

# turn it off
trainer = Trainer(num_sanity_val_steps=0)

# check all validation data
trainer = Trainer(num_sanity_val_steps=-1)

This option will reset the validation dataloader unless num_sanity_val_steps=0.

overfit_batches

Uses this much of the training and validation data: a float is a fraction of the data, and an int is a number of batches. If the training or validation dataloaders have shuffle=True, Lightning will automatically disable it.

Useful for quickly debugging or trying to overfit on purpose.
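The float-versus-int convention can be illustrated with batches_to_use, a hypothetical helper that resolves a fraction or a count into a number of batches. It assumes that fractions are truncated toward zero and that integer requests are capped at the number of available batches. Lightning's own parsing may differ in edge cases. Note also that for overfit_batches, the default of 0.0 means the feature is off, not "use zero batches".

```python
from typing import Union


# Hypothetical helper for illustration only; not part of the Lightning API.
def batches_to_use(total_batches: int, limit: Union[int, float]) -> int:
    """Resolve a float fraction or an int count into a number of batches."""
    if isinstance(limit, float):
        # Fraction of the available batches, truncated toward zero
        return int(total_batches * limit)
    # Absolute batch count, capped at what the dataloader provides
    return min(total_batches, limit)


print(batches_to_use(500, 0.01))  # 5
print(batches_to_use(500, 10))    # 10
```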

# default used by the Trainer
trainer = Trainer(overfit_batches=0.0)

# use only 1% of the train & val set
trainer = Trainer(overfit_batches=0.01)

# overfit on 10 of the same batches
trainer = Trainer(overfit_batches=10)

plugins

Plugins allow you to connect arbitrary backends, precision libraries, clusters, etc.

To define your own behavior, subclass the relevant class and pass it in. Here’s an example that links up your own ClusterEnvironment.

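from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import ClusterEnvironment


class MyCluster(ClusterEnvironment):
    # your_main_address, your_main_port and the_world_size are placeholders.
    # ClusterEnvironment declares further abstract methods (e.g. global_rank)
    # that a real subclass must also implement.
    @property
    def main_address(self):
        return your_main_address

    @property
    def main_port(self):
        return your_main_port

    def world_size(self):
        return the_world_size


trainer = Trainer(plugins=[MyCluster()])  # plus any other Trainer arguments

precision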

Lightning supports either double (64), float (32), bfloat16 (bf16), or half (16) precision training.

Half precision, or more precisely mixed precision, combines 32-bit and 16-bit floating-point types to reduce the memory footprint during model training. This can improve performance, with speedups of up to 3x on modern GPUs.

# default used by the Trainer
trainer = Trainer(precision=32)

# 16-bit precision
trainer = Trainer(precision=16, accelerator="gpu", devices=1)  # works only on CUDA

# bfloat16 precision
trainer = Trainer(precision="bf16")

# 64-bit precision
trainer = Trainer(precision=64)
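The memory argument can be put in numbers by computing how much storage a given number of values needs in each format. The one-billion figure below is arbitrary, and storage_gb is a hypothetical helper for this arithmetic only. Note that with mixed precision, PyTorch's autocast keeps the model weights in 32-bit and runs eligible operations, and stores their outputs, in 16-bit. The savings therefore apply mainly to those 16-bit tensors rather than to the weights.

```python
# Hypothetical helper for illustration only; not part of the Lightning API.
# Bytes used by one value in each format accepted by the `precision` flag
BYTES_PER_VALUE = {64: 8, 32: 4, 16: 2, "bf16": 2}


def storage_gb(num_values: int, precision) -> float:
    """GB (1 GB = 1e9 bytes) needed to hold num_values values at the given precision."""
    return num_values * BYTES_PER_VALUE[precision] / 1e9


# One billion values, e.g. a large activation buffer (arbitrary example size)
for precision in (64, 32, 16, "bf16"):
    print(precision, storage_gb(1_000_000_000, precision))
```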

Note

When running on TPUs, torch.bfloat16 will be used but tensor printing will still show torch.float32.

profiler

Profile individual steps during training to help identify bottlenecks.

See the profiler documentation for more details.

from pytorch_lightning.profilers import SimpleProfiler, AdvancedProfiler

# default used by the Trainer
trainer = Trainer(profiler=None)

# to profile standard training events, equivalent to `profiler=SimpleProfiler()`
trainer = Trainer(profiler="simple")

# advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()`
trainer = Trainer(profiler="advanced")
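To show the kind of information a "simple" profiler reports, here is a minimal timing profiler in plain Python. It accumulates wall-clock time and call counts for each named action. TinyProfiler is a hypothetical class that is much simpler than Lightning's SimpleProfiler. It only illustrates the idea of timing named events and ranking them to spot bottlenecks.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


# Hypothetical class for illustration only; not Lightning's SimpleProfiler.
class TinyProfiler:
    """Accumulate total wall-clock time and call counts per named action."""

    def __init__(self):
        self.totals = defaultdict(float)
        self.counts = defaultdict(int)

    @contextmanager
    def profile(self, action: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[action] += time.perf_counter() - start
            self.counts[action] += 1

    def summary(self) -> str:
        lines = []
        # Slowest actions first, as a bottleneck report would list them
        for action, total in sorted(self.totals.items(), key=lambda kv: -kv[1]):
            mean = total / self.counts[action]
            lines.append(f"{action}: {self.counts[action]} calls, mean {mean:.6f}s, total {total:.6f}s")
        return "\n".join(lines)


profiler = TinyProfiler()
for _ in range(3):
    with profiler.profile("training_step"):
        sum(i * i for i in range(10_000))
print(profiler.summary())
```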
enable_progress_bar

Whether to enable or disable the progress bar. Defaults to True.

# default used by the Trainer
trainer = Trainer(enable_progress_bar=True)

# disable progress bar
trainer = Trainer(enable_progress_bar=False)

reload_dataloaders_every_n_epochs

Set to a positive integer to reload dataloaders every n epochs from your currently used data source. The data source can be a LightningModule or a LightningDataModule.

# if 0 (default)
train_loader = model.train_dataloader()
# or if using data module: datamodule.train_dataloader()
for epoch in epochs:
    for batch in train_loader:
        ...

# if a positive integer
for epoch in epochs:
    if not epoch % reload_dataloaders_every_n_epochs:
        train_loader = model.train_dataloader()
        # or if using data module: datamodule.train_dataloader()
    for batch in train_loader:
        ...

The pseudocode applies also to the val_dataloader.
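The reload rule in the pseudocode can be turned into a small runnable simulation. epochs_with_reload is a hypothetical helper that lists the epochs at which a fresh dataloader would be created under that rule, assuming epochs are counted from 0.

```python
from typing import List


# Hypothetical helper for illustration only; not part of the Lightning API.
def epochs_with_reload(num_epochs: int, reload_every_n_epochs: int) -> List[int]:
    """Return the epochs at which train_dataloader() would be (re)created."""
    if reload_every_n_epochs == 0:
        # Default: the dataloader is built once, before the first epoch
        return [0]
    return [epoch for epoch in range(num_epochs) if epoch % reload_every_n_epochs == 0]


print(epochs_with_reload(6, 0))  # [0]
print(epochs_with_reload(6, 2))  # [0, 2, 4]
```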

replace_sampler_ddp

Enables automatic insertion of a DistributedSampler. In PyTorch, you must use it in distributed settings such as TPUs or multi-node training. The sampler makes sure each GPU sees the appropriate part of your data. By default, it uses shuffle=True for the train sampler and shuffle=False for the val/test samplers. If you already use a custom sampler, Lightning wraps it so that your sampler is sampled from in a distributed manner. To customize this behavior, set replace_sampler_ddp=False and add your own distributed sampler. If replace_sampler_ddp=True and a distributed sampler was already added, Lightning will not replace it.

# default used by the Trainer
trainer = Trainer(replace_sampler_ddp=True)

If you set it to False, you have to add your own distributed sampler:

# in your LightningModule or LightningDataModule
from torch.utils.data import DataLoader, DistributedSampler


def train_dataloader(self):
    # same settings the Trainer uses by default for the train sampler;
    # `dataset` is a placeholder for your own Dataset instance
    sampler = DistributedSampler(dataset, shuffle=True)
    dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
    return dataloader

Note

For iterable datasets, we don’t do this automatically.
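Conceptually, a distributed sampler gives each process a disjoint, equally sized share of the dataset indices. The function below is a simplified, hypothetical re-implementation of that partitioning without shuffling. As we understand torch.utils.data.distributed.DistributedSampler with drop_last=False, it pads the index list by wrapping around to the start so that the list divides evenly, then assigns every world_size-th index to each rank. The sketch only illustrates that idea and is not a replacement for the real sampler.

```python
from typing import List


# Hypothetical helper for illustration only; not the real DistributedSampler.
def shard_indices(dataset_len: int, world_size: int, rank: int) -> List[int]:
    """Indices that process `rank` (out of `world_size`) would read, without shuffling."""
    num_samples = -(-dataset_len // world_size)  # ceiling division
    total_size = num_samples * world_size
    # Pad by wrapping around to the start so every rank gets num_samples indices
    padded = [i % dataset_len for i in range(total_size)]
    # Each rank takes every world_size-th index, starting at its own rank
    return padded[rank:total_size:world_size]


for rank in range(2):
    print(rank, shard_indices(5, world_size=2, rank=rank))
# prints "0 [0, 2, 4]" and then "1 [1, 3, 0]"
```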

resume_from_checkpoint

Warning

resume_from_checkpoint is deprecated in v1.5 and will be removed in v2.0. Please pass trainer.fit(ckpt_path="some/path/to/my_checkpoint.ckpt") instead.


To resume training from a specific checkpoint, pass in the path here. If you resume from a mid-epoch checkpoint, training will start from the beginning of the next epoch.

# default used by the Trainer
trainer = Trainer(resume_from_checkpoint=None)

# resume from a specific checkpoint
trainer = Trainer(resume_from_checkpoint="some/path/to/my_checkpoint.ckpt")

strategy

Supports passing different training strategies with aliases (ddp, ddp_spawn, etc.) as well as custom strategies.

# Training with the DistributedDataParallel strategy on 4 GPUs
trainer = Trainer(strategy="ddp", accelerator="gpu", devices=4)

# Training with the DDP Spawn strategy using 4 cpu processes
trainer = Trainer(strategy="ddp_spawn", accelerator="cpu", devices=4)

Note

Additionally, you can pass your custom strategy to the strategy argument.

from pytorch_lightning.strategies import DDPStrategy


class CustomDDPStrategy(DDPStrategy):
    def configure_ddp(self):
        self._model = MyCustomDistributedDataParallel(
            self.model,
            device_ids=...,
        )


trainer = Trainer(strategy=CustomDDPStrategy(), accelerator="gpu", devices=2)

sync_batchnorm

Enable synchronization between batchnorm layers across all GPUs.

trainer = Trainer(sync_batchnorm=True)

track_grad_norm

  • no tracking (-1)

  • Otherwise tracks that norm (2 for 2-norm)
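To make the tracked quantity concrete: the p-norm of the gradients is (sum of |g|^p)^(1/p) taken over all gradient entries, and the "inf" norm is the largest absolute entry. The sketch below computes these values on plain Python lists. grad_p_norm is a hypothetical helper, whereas the Trainer works from your model's actual parameter gradients.

```python
import math
from typing import List, Union


# Hypothetical helper for illustration only; not part of the Lightning API.
def grad_p_norm(grads: List[List[float]], p: Union[int, float, str]) -> float:
    """Total p-norm over all parameters' gradient entries ("inf" = max absolute value)."""
    values = [abs(g) for param_grad in grads for g in param_grad]
    if p == "inf" or p == math.inf:
        return max(values)
    return sum(v ** p for v in values) ** (1.0 / p)


# Two "parameters" with gradients [3, 0] and [4]
print(grad_p_norm([[3.0, 0.0], [4.0]], 2))      # 5.0
print(grad_p_norm([[3.0, 0.0], [4.0]], "inf"))  # 4.0
```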

# default used by the Trainer
trainer = Trainer(track_grad_norm=-1)

# track the 2-norm
trainer = Trainer(track_grad_norm=2)

tpu_cores

Warning

tpu_cores=x has been deprecated in v1.7 and will be removed in v2.0. Please use accelerator='tpu' and devices=x instead.


  • How many TPU cores to train on (1 or 8).

  • Which TPU core to train on [1-8]

A single TPU v2 or v3 has 8 cores. A TPU pod has up to 2048 cores. A slice of a POD means you get as many cores as you request.

Your effective batch size is batch_size * total tpu cores.
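As a worked instance of this formula: each core processes its own batch of batch_size samples, so the samples per optimizer step scale with the number of cores. The per-core batch size of 32 below is an arbitrary example, and effective_batch_size is a hypothetical helper.

```python
# Hypothetical helper for illustration only; not part of the Lightning API.
def effective_batch_size(batch_size: int, tpu_cores: int) -> int:
    """Samples processed per optimizer step when each core gets its own batch."""
    return batch_size * tpu_cores


print(effective_batch_size(32, 8))  # 256
```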

This parameter can be either 1 or 8.

Example:

# your_trainer_file.py

# default used by the Trainer (i.e., train on CPU)
trainer = Trainer(tpu_cores=None)

# int: train on a single core
trainer = Trainer(tpu_cores=1)

# list: train on a single selected core
trainer = Trainer(tpu_cores=[2])

# int: train on all 8 cores
trainer = Trainer(tpu_cores=8)

# for 8+ cores must submit via xla script with
# a max of 8 cores specified. The XLA script
# will duplicate script onto each TPU in the POD
trainer = Trainer(tpu_cores=8)

To train on more than 8 cores (i.e., a POD), submit this script using the xla_dist script.

Example:

python -m torch_xla.distributed.xla_dist \
    --tpu=$TPU_POD_NAME \
    --conda-env=torch-xla-nightly \
    --env=XLA_USE_BF16=1 \
    -- python your_trainer_file.py

val_check_interval

How often within one training epoch to check the validation set. Can specify as float or int.

  • pass a float in the range [0.0, 1.0] to check after a fraction of the training epoch.

  • pass an int to check after a fixed number of training batches. An int value can only be higher than the number of training batches when check_val_every_n_epoch=None, which validates after every N training batches across epochs or iteration-based training.
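The hypothetical helper below shows the two modes in practice. It lists the training-batch counts within one epoch after which validation would run, using a simplified model of the rule: a float is turned into a batch interval with max(1, int(total * fraction)), and an int is used as the interval directly. It ignores check_val_every_n_epoch and the case where the interval counts batches across epochs.

```python
from typing import List, Union


# Hypothetical helper for illustration only; not part of the Lightning API.
def validation_points(total_train_batches: int, val_check_interval: Union[int, float]) -> List[int]:
    """Batch counts (1-based) within one epoch after which validation runs."""
    if isinstance(val_check_interval, float):
        # Fraction of the epoch, converted to a batch interval
        every = max(1, int(total_train_batches * val_check_interval))
    else:
        # Fixed number of training batches
        every = val_check_interval
    return [n for n in range(1, total_train_batches + 1) if n % every == 0]


print(validation_points(100, 0.25))  # [25, 50, 75, 100]
print(validation_points(100, 40))    # [40, 80]
```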

# default used by the Trainer
trainer = Trainer(val_check_interval=1.0)

# check validation set 4 times during a training epoch
trainer = Trainer(val_check_interval=0.25)

# check validation set every 1000 training batches in the current epoch
trainer = Trainer(val_check_interval=1000)

# check validation set every 1000 training batches across complete epochs or during iteration-based training
# use this when using an IterableDataset and your dataset has no length
# (i.e., production cases with streaming data)
trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None)

Here is the computation to estimate the total number of batches run within an epoch:

# Example values (hypothetical; substitute your own)
total_train_samples, train_batch_size = 10_000, 32
total_val_samples, val_batch_size = 2_000, 32
world_size = 1
val_check_interval = 0.25

# Find the total number of train batches
total_train_batches = total_train_samples // (train_batch_size * world_size)

# Compute how many times we will call validation during the training loop
val_check_batch = max(1, int(total_train_batches * val_check_interval))
val_checks_per_epoch = total_train_batches // val_check_batch

# Find the total number of validation batches
total_val_batches = total_val_samples // (val_batch_size * world_size)

# Total number of batches run: the full validation set runs at every check
total_fit_batches = total_train_batches + total_val_batches * val_checks_per_epoch

enable_model_summary

Whether to enable or disable the model summarization. Defaults to True.

# default used by the Trainer
trainer = Trainer(enable_model_summary=True)

# disable summarization
trainer = Trainer(enable_model_summary=False)

# enable custom summarization
from pytorch_lightning.callbacks import ModelSummary

trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])

inference_mode

Whether to use torch.inference_mode() or torch.no_grad() during evaluation (validate/test/predict).

# default used by the Trainer
trainer = Trainer(inference_mode=True)

# Use `torch.no_grad` instead
trainer = Trainer(inference_mode=False)

With torch.inference_mode() disabled (torch.no_grad() is used instead), you can re-enable gradients for your model layers if required.

class LitModel(LightningModule):
    def validation_step(self, batch, batch_idx):
        preds = self.layer1(batch)
        with torch.enable_grad():
            grad_preds = preds.requires_grad_()
            preds2 = self.layer2(grad_preds)


model = LitModel()
trainer = Trainer(inference_mode=False)
trainer.validate(model)

Trainer class API

Methods
init
Trainer.__init__(logger=True, enable_checkpointing=True, callbacks=None, default_root_dir=None, gradient_clip_val=None, gradient_clip_algorithm=None, num_nodes=1, num_processes=None, devices=None, gpus=None, auto_select_gpus=None, tpu_cores=None, ipus=None, enable_progress_bar=True, overfit_batches=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=None, max_epochs=None, min_epochs=None, max_steps=-1, min_steps=None, max_time=None, limit_train_batches=None, limit_val_batches=None, limit_test_batches=None, limit_predict_batches=None, val_check_interval=None, log_every_n_steps=50, accelerator=None, strategy=None, sync_batchnorm=False, precision=32, enable_model_summary=True, num_sanity_val_steps=2, resume_from_checkpoint=None, profiler=None, benchmark=None, deterministic=None, reload_dataloaders_every_n_epochs=0, auto_lr_find=False, replace_sampler_ddp=True, detect_anomaly=False, auto_scale_batch_size=False, plugins=None, amp_backend=None, amp_level=None, move_metrics_to_cpu=False, multiple_trainloader_mode='max_size_cycle', inference_mode=True)

Customize every aspect of training via flags.

Parameters
  • accelerator (Union[str, Accelerator, None]) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.

  • accumulate_grad_batches (Union[int, Dict[int, int], None]) – Accumulates grads every k batches or as set up in the dict. Default: None.

  • amp_backend (Optional[str]) –

    The mixed precision backend to use (“native” or “apex”). Default: 'native'.

    Deprecated since version v1.9: Setting amp_backend inside the Trainer is deprecated in v1.8.0 and will be removed in v2.0.0. This argument was only relevant for apex which is being removed.

  • amp_level (Optional[str]) –

    The optimization level to use (O1, O2, etc…). By default it will be set to “O2” if amp_backend is set to “apex”.

    Deprecated since version v1.8: Setting amp_level inside the Trainer is deprecated in v1.8.0 and will be removed in v2.0.0.

  • auto_lr_find (Union[bool, str]) – If set to True, trainer.tune() will run a learning rate finder that tries to optimize the initial learning rate for faster convergence. trainer.tune() will set the suggested learning rate in self.lr or self.learning_rate in the LightningModule. To use a different key, set a string with the key name instead of True. Default: False.

  • auto_scale_batch_size (Union[str, bool]) – If set to True, will initially run a batch size finder that tries to find the largest batch size that fits into memory. The result will be stored in self.batch_size in the LightningModule or LightningDataModule, depending on your setup. It can also be set to “power”, which estimates the batch size through a power search, or “binsearch”, which estimates it through a binary search. Default: False.

  • auto_select_gpus (Optional[bool]) –

    If enabled and gpus or devices is an integer, pick available gpus automatically. This is especially useful when GPUs are configured to be in “exclusive mode”, such that only one process at a time can access them. Default: False.

    Deprecated since version v1.9: auto_select_gpus has been deprecated in v1.9.0 and will be removed in v2.0.0. Please use the function find_usable_cuda_devices() instead.

  • benchmark (Optional[bool]) – The value (True or False) to set torch.backends.cudnn.benchmark to. The value for torch.backends.cudnn.benchmark set in the current session will be used (False if not manually set). If deterministic is set to True, this will default to False. Override to manually set a different value. Default: None.

  • callbacks (Union[List[Callback], Callback, None]) – Add a callback or list of callbacks. Default: None.

  • enable_checkpointing (bool) – If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks. Default: True.

  • check_val_every_n_epoch (Optional[int]) – Perform a validation loop after every N training epochs. If None, validation will be done solely based on the number of training batches, requiring val_check_interval to be an integer value. Default: 1.

  • default_root_dir (Union[str, Path, None]) – Default path for logs and weights when no logger/ckpt_callback is passed. Default: os.getcwd(). Can be a remote file path such as s3://mybucket/path or hdfs://path/.

  • detect_anomaly (bool) – Enable anomaly detection for the autograd engine. Default: False.

  • deterministic (Union[bool, Literal[‘warn’], None]) – If True, sets whether PyTorch operations must use deterministic algorithms. Set to "warn" to use deterministic algorithms whenever possible, throwing warnings on operations that don’t support deterministic mode (requires PyTorch 1.11+). If not set, defaults to False. Default: None.

  • devices (Union[List[int], str, int, None]) – Will be mapped to either gpus, tpu_cores, num_processes or ipus, based on the accelerator type.

  • fast_dev_run (Union[int, bool]) – Runs n batch(es) of train, val and test if set to n (int), or 1 batch if set to True, to find any bugs (i.e., a sort of unit test). Default: False.

  • gpus (Union[List[int], str, int, None]) –

    Number of GPUs to train on (int) or which GPUs to train on (list or str), applied per node. Default: None.

    Deprecated since version v1.7: gpus has been deprecated in v1.7 and will be removed in v2.0. Please use accelerator='gpu' and devices=x instead.

  • gradient_clip_val (Union[int, float, None]) – The value at which to clip gradients. Passing gradient_clip_val=None disables gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before clipping. Default: None.

  • gradient_clip_algorithm (Optional[str]) – The gradient clipping algorithm to use. Pass gradient_clip_algorithm="value" to clip by value, and gradient_clip_algorithm="norm" to clip by norm. By default it will be set to "norm".

  • limit_train_batches (Union[int, float, None]) – How much of training dataset to check (float = fraction, int = num_batches). Default: 1.0.

  • limit_val_batches (Union[int, float, None]) – How much of validation dataset to check (float = fraction, int = num_batches). Default: 1.0.

  • limit_test_batches (Union[int, float, None]) – How much of test dataset to check (float = fraction, int = num_batches). Default: 1.0.

  • limit_predict_batches (Union[int, float, None]) – How much of prediction dataset to check (float = fraction, int = num_batches). Default: 1.0.

  • logger (Union[Logger, Iterable[Logger], bool]) – Logger (or iterable collection of loggers) for experiment tracking. A True value uses the default TensorBoardLogger if it is installed, otherwise CSVLogger. False will disable logging. If multiple loggers are provided, local files (checkpoints, profiler traces, etc.) are saved in the log_dir of the first logger. Default: True.

  • log_every_n_steps (int) – How often to log within steps. Default: 50.

  • enable_progress_bar (bool) – Whether to enable the progress bar by default. Default: True.

  • profiler (Union[Profiler, str, None]) – To profile individual steps during training and assist in identifying bottlenecks. Default: None.

  • overfit_batches (Union[int, float]) – Overfit a fraction of training/validation data (float) or a set number of batches (int). Default: 0.0.

  • plugins (Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync, str, List[Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync, str]], None]) – Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins. Default: None.

  • precision (Union[Literal[64, 32, 16], Literal[‘64’, ‘32’, ‘16’, ‘bf16’]]) – Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16). Can be used on CPU, GPU, TPUs, HPUs or IPUs. Default: 32.

  • max_epochs (Optional[int]) – Stop training once this number of epochs is reached. Disabled by default (None). If both max_epochs and max_steps are not specified, defaults to max_epochs = 1000. To enable infinite training, set max_epochs = -1.

  • min_epochs (Optional[int]) – Force training for at least this many epochs. Disabled by default (None).

  • max_steps (int) – Stop training after this number of steps. Disabled by default (-1). If max_steps = -1 and max_epochs = None, will default to max_epochs = 1000. To enable infinite training, set max_epochs to -1.

  • min_steps (Optional[int]) – Force training for at least this number of steps. Disabled by default (None).

  • max_time (Union[str, timedelta, Dict[str, int], None]) – Stop training after this amount of time has passed. Disabled by default (None). The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a datetime.timedelta, or as a dictionary with keys that will be passed to datetime.timedelta.

  • num_nodes (int) – Number of GPU nodes for distributed training. Default: 1.

  • num_processes (Optional[int]) –

    Number of processes for distributed training with accelerator="cpu". Default: 1.

    Deprecated since version v1.7: num_processes has been deprecated in v1.7 and will be removed in v2.0. Please use accelerator='cpu' and devices=x instead.

  • num_sanity_val_steps (int) – Sanity check runs n validation batches before starting the training routine. Set it to -1 to run all batches in all validation dataloaders. Default: 2.

  • reload_dataloaders_every_n_epochs (int) – Set to a non-negative integer to reload dataloaders every n epochs. Default: 0.

  • replace_sampler_ddp (bool) – Explicitly enables or disables sampler replacement. If not specified, this will be toggled automatically when DDP is used. By default, it will add shuffle=True for the train sampler and shuffle=False for the val/test sampler. If you want to customize it, you can set replace_sampler_ddp=False and add your own distributed sampler.

  • resume_from_checkpoint (Union[str, Path, None]) –

    Path/URL of the checkpoint from which training is resumed. If there is no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint, training will start from the beginning of the next epoch.

    Deprecated since version v1.5: resume_from_checkpoint is deprecated in v1.5 and will be removed in v2.0. Please pass the path to Trainer.fit(..., ckpt_path=...) instead.

  • strategy (Union[str, Strategy, None]) – Supports different training strategies with aliases as well as custom strategies. Default: None.

  • sync_batchnorm (bool) – Synchronize batch norm layers between process groups/whole world. Default: False.

  • tpu_cores (Union[List[int], str, int, None]) –

    How many TPU cores to train on (1 or 8) / Single TPU to train on (1). Default: None.

    Deprecated since version v1.7: tpu_cores has been deprecated in v1.7 and will be removed in v2.0. Please use accelerator='tpu' and devices=x instead.

  • ipus (Optional[int]) –

    How many IPUs to train on. Default: None.

    Deprecated since version v1.7: ipus has been deprecated in v1.7 and will be removed in v2.0. Please use accelerator='ipu' and devices=x instead.

  • track_grad_norm (Union[int, float, str]) – -1 no tracking. Otherwise tracks that p-norm. May be set to ‘inf’ infinity-norm. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before logging them. Default: -1.

  • val_check_interval (Union[int, float, None]) – How often to check the validation set. Pass a float in the range [0.0, 1.0] to check after a fraction of the training epoch. Pass an int to check after a fixed number of training batches. An int value can only be higher than the number of training batches when check_val_every_n_epoch=None, which validates after every N training batches across epochs or during iteration-based training. Default: 1.0.

  • enable_model_summary (bool) – Whether to enable model summarization by default. Default: True.

  • move_metrics_to_cpu (bool) – Whether to force internal logged metrics to be moved to CPU. This can save some GPU memory, but can make training slower. Use with care. Default: False.

  • multiple_trainloader_mode (str) – How to loop over the datasets when there are multiple train loaders. In ‘max_size_cycle’ mode, the trainer ends one epoch when the largest dataset is traversed, and smaller datasets reload when running out of their data. In ‘min_size’ mode, all the datasets reload when reaching the minimum length of datasets. Default: "max_size_cycle".

  • inference_mode (bool) – Whether to use torch.inference_mode() or torch.no_grad() during evaluation (validate/test/predict).
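The difference between the two multiple_trainloader_mode settings can be mimicked with the standard library by treating each loader as a plain list of batches. combine_loaders is a hypothetical helper, not Lightning's implementation. In "max_size_cycle" mode, shorter loaders restart until the longest one is exhausted. In "min_size" mode, the epoch ends when the shortest loader runs out.

```python
from itertools import cycle
from typing import Dict, List


# Hypothetical helper for illustration only; not part of the Lightning API.
def combine_loaders(loaders: Dict[str, List[str]], mode: str) -> List[Dict[str, str]]:
    """Return one combined batch (a dict keyed by loader name) per training step."""
    if mode == "min_size":
        # The epoch ends when the shortest loader is exhausted
        num_steps = min(len(batches) for batches in loaders.values())
        iterators = {name: iter(batches) for name, batches in loaders.items()}
    elif mode == "max_size_cycle":
        # The epoch ends with the longest loader; shorter ones restart from the beginning
        num_steps = max(len(batches) for batches in loaders.values())
        iterators = {name: cycle(batches) for name, batches in loaders.items()}
    else:
        raise ValueError(f"unknown mode: {mode}")
    return [{name: next(it) for name, it in iterators.items()} for _ in range(num_steps)]


loaders = {"a": ["a0", "a1", "a2"], "b": ["b0"]}
print(combine_loaders(loaders, "max_size_cycle"))
print(combine_loaders(loaders, "min_size"))
```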

fit
Trainer.fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, ckpt_path=None)

Runs the full optimization routine.

Return type

None

validate
Trainer.validate(model=None, dataloaders=None, ckpt_path=None, verbose=True, datamodule=None)

Perform one evaluation epoch over the validation set.

Return type

List[Dict[str, float]]

Returns

List of dictionaries with metrics logged during the validation phase, e.g., in model- or callback hooks like validation_step(), validation_epoch_end(), etc. The length of the list corresponds to the number of validation dataloaders used.

test
Trainer.test(model=None, dataloaders=None, ckpt_path=None, verbose=True, datamodule=None)

Perform one evaluation epoch over the test set. It’s separated from fit to make sure you never run on your test set until you want to.

Return type

List[Dict[str, float]]

Returns

List of dictionaries with metrics logged during the test phase, e.g., in model- or callback hooks like test_step(), test_epoch_end(), etc. The length of the list corresponds to the number of test dataloaders used.

predict
Trainer.predict(model=None, dataloaders=None, datamodule=None, return_predictions=None, ckpt_path=None)

Run inference on your data. This will call the model forward function to compute predictions. Useful to perform distributed and batched predictions. Logging is disabled in the predict hooks.

Return type

Union[List[Any], List[List[Any]], None]

Returns

Returns a list of predictions, or one list of predictions per dataloader when multiple dataloaders are provided. None is returned if return_predictions is False.

See the Lightning inference section for more.

tune
Trainer.tune(model, train_dataloaders=None, val_dataloaders=None, dataloaders=None, datamodule=None, scale_batch_size_kwargs=None, lr_find_kwargs=None, method='fit')

Runs routines to tune hyperparameters before training.

Return type

_TunerResult

Properties
callback_metrics

The metrics available to callbacks. These are automatically set when you log via self.log.

def training_step(self, batch, batch_idx):
    self.log("a_val", 2)


callback_metrics = trainer.callback_metrics
assert callback_metrics["a_val"] == 2

current_epoch

The number of epochs run.

if trainer.current_epoch >= 10:
    ...

datamodule

The current datamodule, which is used by the trainer.

used_datamodule = trainer.datamodule

is_last_batch

Whether the trainer is executing the last batch of the current epoch.

if trainer.is_last_batch:
    ...

global_step

The number of optimizer steps taken (does not reset each epoch). This includes multiple optimizers and TBPTT steps (if enabled).
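Because global_step counts optimizer steps rather than batches, it grows more slowly when gradients are accumulated with accumulate_grad_batches. The sketch below estimates it for a single optimizer without TBPTT. It assumes that the optimizer also steps on the last batch of an epoch even if fewer than k batches were accumulated. optimizer_steps is a hypothetical helper, not a Trainer API.

```python
import math


# Hypothetical helper for illustration only; not part of the Lightning API.
def optimizer_steps(batches_per_epoch: int, epochs: int, accumulate_grad_batches: int = 1) -> int:
    """Estimated global_step after `epochs` full epochs (single optimizer, no TBPTT)."""
    # One step per k batches, plus a final step for any leftover batches in the epoch
    steps_per_epoch = math.ceil(batches_per_epoch / accumulate_grad_batches)
    return steps_per_epoch * epochs


print(optimizer_steps(100, 3))     # 300
print(optimizer_steps(100, 3, 4))  # 75
print(optimizer_steps(10, 1, 4))   # 3
```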

if trainer.global_step >= 100:
    ...

logger

The current logger being used. Here’s an example using TensorBoard:

logger = trainer.logger
tensorboard = logger.experiment

loggers

The list of loggers currently being used by the Trainer.

# List of Logger objects
loggers = trainer.loggers
for logger in loggers:
    logger.log_metrics({"foo": 1.0})

logged_metrics

The metrics sent to the logger (visualizer).

def training_step(self, batch, batch_idx):
    self.log("a_val", 2, logger=True)


logged_metrics = trainer.logged_metrics
assert logged_metrics["a_val"] == 2

log_dir

The directory for the current experiment. Use this to save images to, etc…

def training_step(self, batch, batch_idx):
    img = ...
    save_img(img, self.trainer.log_dir)

is_global_zero

Whether this process is the global zero in multi-node training.

def training_step(self, batch, batch_idx):
    if self.trainer.is_global_zero:
        print("in node 0, accelerator 0")

progress_bar_metrics

The metrics sent to the progress bar.

def training_step(self, batch, batch_idx):
    self.log("a_val", 2, prog_bar=True)


progress_bar_metrics = trainer.progress_bar_metrics
assert progress_bar_metrics["a_val"] == 2

predict_dataloaders

The current predict dataloaders of the trainer. Note that this property returns a list of predict dataloaders.

used_predict_dataloaders = trainer.predict_dataloaders

estimated_stepping_batches

Check out estimated_stepping_batches().

state

The current state of the Trainer, including the current function that is running, the stage of execution within that function, and the status of the Trainer.

# fn in ("fit", "validate", "test", "predict", "tune")
trainer.state.fn
# status in ("initializing", "running", "finished", "interrupted")
trainer.state.status
# stage in ("train", "sanity_check", "validate", "test", "predict", "tune")
trainer.state.stage

should_stop

If you want to terminate training during .fit, you can set trainer.should_stop=True to stop training as soon as possible. Note that the Trainer respects the min_steps and min_epochs arguments when deciding whether to stop. If these arguments are set and current_epoch or global_step has not yet reached these minimums, training continues until both conditions are met. An argument that is not set is not considered in the final decision.

# setting `trainer.should_stop` at any point of training will terminate it
class LitModel(LightningModule):
    def training_step(self, *args, **kwargs):
        self.trainer.should_stop = True


trainer = Trainer()
model = LitModel()
trainer.fit(model)

# setting `trainer.should_stop` will stop training only after at least 5 epochs have run
class LitModel(LightningModule):
    def training_step(self, *args, **kwargs):
        if self.current_epoch == 2:
            self.trainer.should_stop = True


trainer = Trainer(min_epochs=5, max_epochs=100)
model = LitModel()
trainer.fit(model)

# setting `trainer.should_stop` will stop training only after at least 5 steps have run
class LitModel(LightningModule):
    def training_step(self, *args, **kwargs):
        if self.global_step == 2:
            self.trainer.should_stop = True


trainer = Trainer(min_steps=5, max_epochs=100)
model = LitModel()
trainer.fit(model)

# setting `trainer.should_stop` at any point will not stop training until both min_steps and min_epochs are satisfied
class LitModel(LightningModule):
    def training_step(self, *args, **kwargs):
        if self.global_step == 7:
            self.trainer.should_stop = True


trainer = Trainer(min_steps=5, min_epochs=5, max_epochs=100)
model = LitModel()
trainer.fit(model)
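The should_stop rule can be written as a single predicate. will_stop is a hypothetical helper that models only how should_stop interacts with the min_* arguments. It treats epochs_done as the number of completed epochs and ignores the max_* limits.

```python
from typing import Optional


# Hypothetical helper for illustration only; not part of the Lightning API.
def will_stop(should_stop: bool, epochs_done: int, global_step: int,
              min_epochs: Optional[int] = None, min_steps: Optional[int] = None) -> bool:
    """True if a should_stop request is honored, given the min_* constraints."""
    if not should_stop:
        return False
    # A minimum that is not set is not considered in the decision
    epochs_ok = min_epochs is None or epochs_done >= min_epochs
    steps_ok = min_steps is None or global_step >= min_steps
    return epochs_ok and steps_ok


print(will_stop(True, epochs_done=2, global_step=40, min_epochs=5))  # False
print(will_stop(True, epochs_done=5, global_step=90, min_epochs=5))  # True
```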
train_dataloader

The current train dataloader of the trainer.

used_train_dataloader = trainer.train_dataloader

test_dataloaders

The current test dataloaders of the trainer. Note that this property returns a list of test dataloaders.

used_test_dataloaders = trainer.test_dataloaders

val_dataloaders

The current val dataloaders of the trainer. Note that this property returns a list of val dataloaders.

used_val_dataloaders = trainer.val_dataloaders

Fabric (Beta)

Fabric is the fast and lightweight way to scale PyTorch models without boilerplate code.

  • Easily switch from running on CPU to GPU (Apple Silicon, CUDA, …), TPU, multi-GPU or even multi-node training

  • State-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed) and mixed precision out of the box

  • Handles all the boilerplate device logic for you

  • Brings useful tools to help you build a trainer (callbacks, logging, checkpoints, …)

  • Designed with multi-billion parameter models in mind


  import torch
  import torch.nn as nn
  from torch.utils.data import DataLoader, Dataset

+ from lightning.fabric import Fabric

  class PyTorchModel(nn.Module):
      ...

  class PyTorchDataset(Dataset):
      ...

+ fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp")
+ fabric.launch()

- device = "cuda" if torch.cuda.is_available() else "cpu"
  model = PyTorchModel(...)
  optimizer = torch.optim.SGD(model.parameters())
+ model, optimizer = fabric.setup(model, optimizer)
  dataloader = DataLoader(PyTorchDataset(...), ...)
+ dataloader = fabric.setup_dataloaders(dataloader)
  model.train()

  for epoch in range(num_epochs):
      for batch in dataloader:
          input, target = batch
-         input, target = input.to(device), target.to(device)
          optimizer.zero_grad()
          output = model(input)
          loss = loss_fn(output, target)
-         loss.backward()
+         fabric.backward(loss)
          optimizer.step()
          lr_scheduler.step()

Note

Fabric is currently in Beta. Its API is subject to change based on feedback.


Why Fabric?

Fabric differentiates itself from a fully-fledged trainer like Lightning Trainer in these key aspects:

Fast to implement There is no need to restructure your code: Just change a few lines in the PyTorch script and you’ll be able to leverage Fabric features.

Maximum Flexibility. Write your own training and/or inference logic down to the individual optimizer calls. You aren’t forced to conform to a standardized epoch-based training loop like the one in the Lightning Trainer. You can do flexible iteration-based training, meta-learning, cross-validation and other types of optimization algorithms without digging into framework internals. This also makes it easy to adopt Fabric in existing PyTorch projects to speed up and scale your models without large refactors. Just remember: with great power comes great responsibility.

Maximum Control. The Lightning Trainer has many built-in features to make research simpler with less boilerplate, but debugging it requires some familiarity with the framework internals. In Fabric, everything is opt-in. Think of it as a toolbox: you take out the tools (Fabric functions) you need and leave the other ones behind. This makes it easier to develop and debug your PyTorch code as you gradually add more features to it. Fabric provides important tools to remove undesired boilerplate code (distributed, hardware, checkpoints, logging, …), but leaves the design and orchestration fully up to you.


Fundamentals


Build Your Own Trainer


Advanced Topics


Examples


API

accelerators

Accelerator

The Accelerator base class for Lightning PyTorch.

CPUAccelerator

Accelerator for CPU devices.

CUDAAccelerator

Accelerator for NVIDIA CUDA devices.

HPUAccelerator

Accelerator for HPU devices.

IPUAccelerator

Accelerator for IPUs.

TPUAccelerator

Accelerator for TPU devices.

callbacks

BackboneFinetuning

Finetune a backbone model based on a user-defined learning rate schedule.

BaseFinetuning

This class implements the base logic for writing your own Finetuning Callback.

BasePredictionWriter

Base class to implement how the predictions should be stored.

BatchSizeFinder

The BatchSizeFinder callback tries to find the largest batch size for a given model that does not give an out of memory (OOM) error.

Callback

Abstract base class used to build new callbacks.

DeviceStatsMonitor

Automatically monitors and logs device stats during the training, validation and testing stages.

EarlyStopping

Monitor a metric and stop training when it stops improving.

GradientAccumulationScheduler

Change gradient accumulation factor according to scheduling.

LambdaCallback

Create a simple callback on the fly using lambda functions.

LearningRateFinder

The LearningRateFinder callback enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate.

LearningRateMonitor

Automatically monitors and logs the learning rate of learning rate schedulers during training.

ModelCheckpoint

Save the model periodically by monitoring a quantity.

ModelPruning

Model pruning Callback, using PyTorch's prune utilities.

ModelSummary

Generates a summary of all layers in a LightningModule.

ProgressBarBase

The base class for progress bars in Lightning.

QuantizationAwareTraining

Quantization allows speeding up inference and decreasing memory requirements by performing computations and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision.

RichModelSummary

Generates a summary of all layers in a LightningModule with rich text formatting.

RichProgressBar

Create a progress bar with rich text formatting.

StochasticWeightAveraging

Implements the Stochastic Weight Averaging (SWA) Callback to average a model.

Timer

The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the Trainer if the given time limit for the training loop is reached.

TQDMProgressBar

This is the default progress bar used by Lightning.

cli

LightningCLI

Implementation of a configurable command line tool for pytorch-lightning.

LightningArgumentParser

Extension of jsonargparse's ArgumentParser for pytorch-lightning.

SaveConfigCallback

Saves a LightningCLI config to the log_dir when training starts.

core

CheckpointHooks

Hooks to be used with Checkpointing.

DataHooks

Hooks to be used for data related stuff.

ModelHooks

Hooks to be used in LightningModule.

LightningDataModule

A DataModule standardizes the training, val, test splits, data preparation and transforms.

LightningModule

HyperparametersMixin

LightningOptimizer

This class is used to wrap the user optimizers and properly handle the backward and optimizer_step logic across accelerators, AMP and accumulate_grad_batches.

ModelIO

loggers

logger

Abstract base class used to build new loggers.

comet

Comet Logger

csv_logs

CSV logger

mlflow

MLflow Logger

neptune

Neptune Logger

tensorboard

TensorBoard Logger

wandb

Weights and Biases Logger

loops

Base Classes

DataLoaderLoop

Base class to loop over all dataloaders.

Loop

Basic Loops interface.

Training

TrainingBatchLoop

Runs over a single batch of data.

TrainingEpochLoop

Runs over all batches in a dataloader (one epoch).

FitLoop

This Loop iterates over the epochs to run the training.

ManualOptimization

A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens entirely in the training_step() and therefore the user is responsible for back-propagating gradients and making calls to the optimizers.

OptimizerLoop

Runs over a sequence of optimizers.

Validation and Testing

EvaluationEpochLoop

This is the loop performing the evaluation.

EvaluationLoop

Loops over all dataloaders for evaluation.

Prediction

PredictionEpochLoop

Loop performing prediction on arbitrary sequentially used dataloaders.

PredictionLoop

Loop to run over dataloaders for prediction.

plugins

precision

ColossalAIPrecisionPlugin

Precision plugin for ColossalAI integration.

DeepSpeedPrecisionPlugin

Precision plugin for DeepSpeed integration.

DoublePrecisionPlugin

Plugin for training with double (torch.float64) precision.

FullyShardedNativeMixedPrecisionPlugin

Native AMP for Fully Sharded Training.

FullyShardedNativeNativeMixedPrecisionPlugin

Native AMP for Fully Sharded Native Training.

HPUPrecisionPlugin

Plugin that enables bfloat/half support on HPUs.

IPUPrecisionPlugin

Precision plugin for IPU integration.

MixedPrecisionPlugin

Plugin for Automatic Mixed Precision (AMP) training with torch.autocast.

PrecisionPlugin

Base class for all plugins handling the precision-specific parts of the training.

ShardedNativeMixedPrecisionPlugin

Native AMP for Sharded Training.

TPUBf16PrecisionPlugin

Plugin that enables bfloats on TPUs.

TPUPrecisionPlugin

Precision plugin for TPU integration.

environments

ClusterEnvironment

Specification of a cluster environment.

KubeflowEnvironment

Environment for distributed training using the PyTorchJob operator from Kubeflow.

LightningEnvironment

The default environment used by Lightning for a single node or free cluster (not managed).

LSFEnvironment

An environment for running on clusters managed by the LSF resource manager.

SLURMEnvironment

Cluster environment for training on a cluster managed by SLURM.

TorchElasticEnvironment

Environment for fault-tolerant and elastic training with torchelastic.

XLAEnvironment

Cluster environment for training on a TPU Pod with the PyTorch/XLA library.

io

AsyncCheckpointIO

AsyncCheckpointIO enables saving the checkpoints asynchronously in a thread.

CheckpointIO

Interface to save/load checkpoints as they are saved through the Strategy.

HPUCheckpointIO

CheckpointIO to save checkpoints for HPU training strategies.

TorchCheckpointIO

CheckpointIO that utilizes torch.save() and torch.load() to save and load checkpoints respectively, common for most use cases.

XLACheckpointIO

CheckpointIO that utilizes xm.save() to save checkpoints for TPU training strategies.

others

LayerSync

Abstract base class for creating plugins that wrap layers of a model with synchronization logic for multiprocessing.

NativeSyncBatchNorm

A plugin that wraps all batch normalization layers of a model with synchronization logic for multiprocessing.

profiler

AdvancedProfiler

This profiler uses Python's cProfile module to record more detailed information about time spent in each function call recorded during a given action.

PassThroughProfiler

This class should be used when you don't want the (small) overhead of profiling.

Profiler

If you wish to write a custom profiler, you should inherit from this class.

PyTorchProfiler

This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of different operators inside your model, both on the CPU and GPU.

SimpleProfiler

This profiler simply records the duration of actions (in seconds) and reports the mean duration of each action and the total time spent over the entire training run.

XLAProfiler

XLA Profiler will help you debug and optimize training workload performance for your models using Cloud TPU performance tools.

trainer

Trainer

Customize every aspect of training via flags.

strategies

BaguaStrategy

Strategy for training using the Bagua library, with advanced distributed training algorithms and system optimizations.

ColossalAIStrategy

ColossalAI strategy.

DDPFullyShardedNativeStrategy

Strategy for Fully Sharded Data Parallel provided by torch.distributed.

DDPFullyShardedStrategy

Plugin for Fully Sharded Data Parallel provided by FairScale.

DDPShardedStrategy

Optimizer and gradient sharded training provided by FairScale.

DDPSpawnShardedStrategy

Optimizer sharded training provided by FairScale.

DDPSpawnStrategy

Spawns processes using the torch.multiprocessing.spawn() method and joins processes after training finishes.

DDPStrategy

Strategy for multi-process single-device training on one or multiple nodes.

DataParallelStrategy

Implements data-parallel training in a single process, i.e., the model gets replicated to each device and each gets a split of the data.

DeepSpeedStrategy

Provides capabilities to run training using the DeepSpeed library, with training optimizations for billion-parameter models.

HivemindStrategy

Provides capabilities to train using the Hivemind Library, training collaboratively across the internet with unreliable machines.

HPUParallelStrategy

Strategy for distributed training on multiple HPU devices.

IPUStrategy

Plugin for training on IPU devices.

ParallelStrategy

Plugin for training with multiple processes in parallel.

SingleDeviceStrategy

Strategy that handles communication on a single device.

SingleHPUStrategy

Strategy for training on a single HPU device.

SingleTPUStrategy

Strategy for training on a single TPU device.

Strategy

Base class for all strategies that change the behaviour of the training, validation and test loops.

TPUSpawnStrategy

Strategy for training multiple TPU devices using the torch_xla.distributed.xla_multiprocessing.spawn() method.

tuner

Tuner

Tuner class to tune your model.

utilities

apply_func

Utilities used for collections.

argparse

Utilities for Argument Parsing within Lightning Components.

cloud_io

Utilities related to data saving/loading.

deepspeed

Utilities that can be used with DeepSpeed.

distributed

Utilities that can be used with distributed training.

finite_checks

Helper functions to detect NaN/Inf values.

memory

Utilities related to memory.

model_summary

optimizer

parsing

Utilities used for parameter parsing.

rank_zero

Utilities that can be used for calling functions on a particular rank.

seed

Utilities to help with reproducibility of models.

warnings

Warning-related utilities.

Add validation and test datasets

Build a Model

Configure hyperparameters from the CLI

Why use a CLI

When running deep learning experiments, there are a couple of good practices that are recommended to follow:

  • Separate configuration from source code

  • Guarantee reproducibility of experiments

Implementing a command line interface (CLI) makes it possible to execute an experiment from a shell terminal. By having a CLI, there is a clear separation between the Python source code and what hyperparameters are used for a particular experiment. If the CLI corresponds to a stable version of the code, reproducing an experiment can be achieved by installing the same version of the code plus dependencies and running with the same configuration (CLI arguments).
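As a minimal illustration of that separation (a plain standard-library sketch, not LightningCLI itself), the hyperparameters below live outside the training code and arrive through argparse. The flags --lr, --batch_size and --max_epochs and their defaults are hypothetical examples; in a real project they would mirror your model's arguments.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical hyperparameters, chosen only to illustrate the pattern
    parser = argparse.ArgumentParser(description="Train a model from a command line configuration")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--batch_size", type=int, default=32, help="samples per batch")
    parser.add_argument("--max_epochs", type=int, default=10, help="number of training epochs")
    return parser


def main(argv=None) -> dict:
    # argv=None makes argparse read sys.argv; a list is handy for testing
    args = build_parser().parse_args(argv)
    config = vars(args)  # a plain dict that can be logged or saved next to the results
    # ... build the model and trainer from `config` here ...
    return config
```

Called from an `if __name__ == "__main__":` guard in a hypothetical train.py, the same experiment could then be rerun with, for example, `python train.py --lr 0.01 --batch_size 64`.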


Basic use


Advanced use


Miscellaneous

Customize the progress bar

Lightning supports two different types of progress bars (tqdm and rich). TQDMProgressBar is used by default, but you can override it by passing a custom TQDMProgressBar or RichProgressBar to the callbacks argument of the Trainer.

You could also use the ProgressBarBase class to implement your own progress bar.


TQDMProgressBar

The TQDMProgressBar uses the tqdm library internally and is the default progress bar used by Lightning. It prints to stdout and shows up to four different bars:

  • sanity check progress: the progress during the sanity check run

  • main progress: shows training + validation progress combined. It also accounts for multiple validation runs during training when val_check_interval is used.

  • validation progress: only visible during validation; shows total progress over all validation datasets.

  • test progress: only active when testing; shows total progress over all test datasets.

For infinite datasets, the progress bar never ends.

You can change the refresh_rate (the number of batches between progress bar updates) of the TQDMProgressBar like this:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar

trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)])
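To make the meaning of refresh_rate concrete, here is a framework-free sketch of the update rule, assuming (as a simplification of the real callback) that the bar is redrawn whenever the number of processed batches is a multiple of refresh_rate, plus once at the final batch, and that a refresh rate of 0 disables the display. should_refresh and refresh_points are hypothetical helpers, not part of the Lightning API.

```python
def should_refresh(n_done: int, total: int, refresh_rate: int) -> bool:
    """Return True if the bar would be redrawn after `n_done` of `total` batches."""
    if refresh_rate <= 0:  # assumed: a refresh rate of 0 disables the bar
        return False
    return n_done % refresh_rate == 0 or n_done == total


def refresh_points(total: int, refresh_rate: int) -> list:
    # Batch counts (1-based) after which the bar would be redrawn
    return [n for n in range(1, total + 1) if should_refresh(n, total, refresh_rate)]
```

With 25 batches and refresh_rate=10, the bar would be redrawn after batches 10, 20 and 25, so a larger refresh_rate means fewer terminal writes.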

If you want to customize the default TQDMProgressBar used by Lightning, you can override specific methods of the callback class and pass your custom implementation to the Trainer.

class LitProgressBar(TQDMProgressBar):
    def init_validation_tqdm(self):
        bar = super().init_validation_tqdm()
        bar.set_description("running validation...")
        return bar


trainer = Trainer(callbacks=[LitProgressBar()])



RichProgressBar

Rich is a Python library for rich text and beautiful formatting in the terminal. To use the RichProgressBar as your progress bar, first install the package:

pip install rich

Then configure the callback and pass it to the Trainer:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar

trainer = Trainer(callbacks=[RichProgressBar()])

Customize the theme for your RichProgressBar like this:

from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme

# create your own theme!
progress_bar = RichProgressBar(
    theme=RichProgressBarTheme(
        description="green_yellow",
        progress_bar="green1",
        progress_bar_finished="green1",
        progress_bar_pulse="#6206E0",
        batch_progress="green_yellow",
        time="grey82",
        processing_speed="grey82",
        metrics="grey82",
    )
)

trainer = Trainer(callbacks=progress_bar)

You can customize the components used within RichProgressBar with ease by overriding the configure_columns() method.

from pytorch_lightning.callbacks import RichProgressBar
from rich.progress import TextColumn

custom_column = TextColumn("[progress.description]Custom Rich Progress Bar!")


class CustomRichProgressBar(RichProgressBar):
    def configure_columns(self, trainer):
        return [custom_column]


progress_bar = CustomRichProgressBar()
trainer = Trainer(callbacks=progress_bar)

If you want a new progress bar to be displayed at the end of every epoch, enable RichProgressBar.leave by passing leave=True:

from pytorch_lightning.callbacks import RichProgressBar

trainer = Trainer(callbacks=[RichProgressBar(leave=True)])


Note

The progress bar is enabled by default. To disable it, pass enable_progress_bar=False to the Trainer:

trainer = Trainer(enable_progress_bar=False)

Deploy models into production

Basics


Advanced

Effective Training Techniques

Lightning implements various techniques that can help make training smoother.


Accumulate Gradients

Gradient accumulation runs K small batches of size N before doing an optimizer step. The effect is a large effective batch size of K*N, where N is the batch size. Internally, the batches are not stacked up into one large forward pass; instead, the gradients of K batches are accumulated before optimizer.step() is called, so the effective batch size increases without extra memory overhead.
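The equivalence can be checked without any framework. The sketch below uses a hypothetical one-parameter linear model y = w*x with a mean-squared-error loss: it accumulates gradients over K equally sized micro-batches of size N, scaling each micro-batch loss by 1/K so the sum matches an average, and compares the result with the gradient of a single batch of size K*N. Only one micro-batch is processed at a time, which is where the memory saving comes from.

```python
def grad_mse(w: float, xs: list, ys: list) -> float:
    # d/dw of mean((w*x - y)^2) over one batch
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: list, ys: list, k: int) -> float:
    """Accumulate gradients over k equally sized micro-batches, each loss scaled by 1/k."""
    n = len(xs) // k  # micro-batch size N
    grad = 0.0
    for i in range(k):
        mx, my = xs[i * n:(i + 1) * n], ys[i * n:(i + 1) * n]
        grad += grad_mse(w, mx, my) / k  # like calling backward() on loss / k
    return grad  # a single optimizer step would use this after all k micro-batches
```

For 12 samples, accumulated_grad(w, xs, ys, k=4) (four micro-batches of 3) matches grad_mse(w, xs, ys) on the full batch up to floating-point rounding.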

Warning

When using distributed training, e.g. DDP with P devices, each device accumulates gradients independently: it stores the gradients after each loss.backward() and does not sync them across devices until optimizer.step() is called. So during accumulation, the effective batch size on each device remains N*K, but right before optimizer.step(), the gradient sync makes the effective batch size P*N*K. For DP, since the batch is split across devices, the final effective batch size is N*K.
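The arithmetic in the warning above can be written as a small helper. effective_batch_size is hypothetical and covers only the two cases described there: n is the batch size given to the DataLoader, k is accumulate_grad_batches and p is the number of devices.

```python
def effective_batch_size(n: int, k: int, p: int, strategy: str) -> int:
    """Number of samples contributing to one optimizer step."""
    if strategy == "ddp":  # each of the p processes loads its own batch of n samples
        return p * n * k
    if strategy == "dp":  # a single batch of n samples is split across the p devices
        return n * k
    raise ValueError(f"unsupported strategy in this sketch: {strategy!r}")
```

For example, n=32, k=4 and p=8 give 1024 samples per optimizer step with DDP but only 128 with DP.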

See also

Trainer

# DEFAULT (ie: no accumulated grads)
trainer = Trainer(accumulate_grad_batches=1)

# Accumulate gradients for 7 batches
trainer = Trainer(accumulate_grad_batches=7)

You can set different values for it at different epochs by passing a dictionary, where the key represents the epoch at which the value for gradient accumulation should be updated.

# Epochs 0-3: accumulate every 8 batches. Epochs 4-7: accumulate every 4 batches.
# From epoch 8 onwards: no accumulation. Note that the epoch keys are zero-indexed.
trainer = Trainer(accumulate_grad_batches={0: 8, 4: 4, 8: 1})
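One way to read such a dictionary: at a given zero-indexed epoch, the value of the largest key that is not greater than that epoch is in effect. The helper below is a hypothetical sketch of that lookup, not Lightning's implementation; the fallback of 1 for epochs before the first key is an assumption (the example schedule starts at key 0, so it never applies there).

```python
def accumulation_factor(scheduling: dict, epoch: int) -> int:
    """Return the gradient accumulation factor in effect at a zero-indexed epoch."""
    factor = 1  # assumed fallback: no accumulation before the first scheduled epoch
    for start_epoch in sorted(scheduling):
        if epoch >= start_epoch:
            factor = scheduling[start_epoch]
        else:
            break
    return factor
```

With {0: 8, 4: 4, 8: 1}, epochs 0 to 3 map to 8, epochs 4 to 7 map to 4, and every epoch from 8 onwards maps to 1.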

Alternatively, you can create a custom GradientAccumulationScheduler callback:

from pytorch_lightning.callbacks import GradientAccumulationScheduler


# Epochs 0-3: accumulate every 8 batches. Epochs 4-7: accumulate every 4 batches.
# From epoch 8 onwards: no accumulation. Note that the epoch keys are zero-indexed.
accumulator = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})
trainer = Trainer(callbacks=accumulator)

Gradient Clipping

Gradient clipping can be enabled to avoid exploding gradients. By default, this clips the gradient norm by calling torch.nn.utils.clip_grad_norm_(), computed over all model parameters together. If the Trainer’s gradient_clip_algorithm is set to 'value' ('norm' by default), torch.nn.utils.clip_grad_value_() is used instead, which clips the gradient values of each parameter individually.
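The difference between the two algorithms is easiest to see on plain numbers. This sketch mimics the behaviour of clip_grad_norm_ (one L2 norm over all gradients, every gradient rescaled by the same factor, never scaled up) and clip_grad_value_ (each element clamped independently). Gradients are modelled as a flat list of floats for simplicity, and the small eps stands in for the constant PyTorch adds to avoid division by zero.

```python
import math


def clip_by_global_norm(grads: list, max_norm: float, eps: float = 1e-6) -> list:
    # gradient_clip_algorithm="norm": one L2 norm computed over all gradients together
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))  # only ever shrinks gradients
    return [g * scale for g in grads]


def clip_by_value(grads: list, clip_value: float) -> list:
    # gradient_clip_algorithm="value": clamp each element to [-clip_value, clip_value]
    return [max(-clip_value, min(clip_value, g)) for g in grads]
```

Clipping [3.0, 4.0] to a norm of 0.5 gives roughly [0.3, 0.4], keeping the direction of the gradient vector, whereas value clipping of [3.0, -4.0, 0.2] at 0.5 gives [0.5, -0.5, 0.2], which changes it.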

Note

If using mixed precision, the gradient_clip_val does not need to be changed as the gradients are unscaled before applying the clipping function.

See also

Trainer

# DEFAULT (ie: don't clip)
trainer = Trainer(gradient_clip_val=0)

# clip gradients' global norm to <=0.5 using gradient_clip_algorithm='norm' by default
trainer = Trainer(gradient_clip_val=0.5)

# clip gradients' maximum magnitude to <=0.5
trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="value")

Read more about Configuring Gradient Clipping for advanced use-cases.


Stochastic Weight Averaging

Stochastic Weight Averaging (SWA) can make your models generalize better at virtually no additional cost. This can be used with both non-trained and trained models. The SWA procedure smooths the loss landscape thus making it harder to end up in a local minimum during optimization.

For a more detailed explanation of SWA and how it works, read this post by the PyTorch team.
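At its core, SWA keeps a running, equally weighted average of the weights visited by the optimizer, typically sampled once per epoch late in training. A framework-free sketch of that update follows, with weights represented as plain lists of floats; the incremental running-mean formula avg + (w - avg) / (n + 1) is assumed here to match the default averaging in torch.optim.swa_utils.AveragedModel.

```python
def swa_update(avg: list, weights: list, num_averaged: int) -> list:
    """Fold one more weight snapshot into a running average of `num_averaged` snapshots."""
    if num_averaged == 0:
        return list(weights)  # the first snapshot is the average
    return [a + (w - a) / (num_averaged + 1) for a, w in zip(avg, weights)]


def swa_average(snapshots: list) -> list:
    # Average a sequence of weight snapshots one at a time, as SWA does during training
    avg, n = [], 0
    for weights in snapshots:
        avg = swa_update(avg, weights, n)
        n += 1
    return avg
```

Only the running average and a counter are stored, so the memory cost is one extra copy of the weights regardless of how many snapshots are averaged.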

See also

The StochasticWeightAveraging callback

from pytorch_lightning.callbacks import StochasticWeightAveraging

# Enable Stochastic Weight Averaging using the callback
trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])

Batch Size Finder

Auto-scaling of batch size can be enabled to find the largest batch size that fits into memory. Large batch size often yields a better estimation of the gradients, but may also result in longer training time. Inspired by https://github.com/BlackHC/toma.

See also

Trainer

# DEFAULT (ie: don't scale batch size automatically)
trainer = Trainer(auto_scale_batch_size=None)

# Autoscale batch size: pass None (disabled), "power" or "binsearch"
trainer = Trainer(auto_scale_batch_size="binsearch")

# Find the batch size
trainer.tune(model)

Currently, this feature supports two modes: 'power' scaling and 'binsearch' scaling. In 'power' scaling, the batch size starts at 1 and keeps doubling until an out-of-memory (OOM) error is encountered. Setting the argument to 'binsearch' will initially also double the batch size until it encounters an OOM, after which it does a binary search to fine-tune the batch size. Note also that the batch size scaler cannot search for batch sizes larger than the size of the training dataset.
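The two modes can be simulated without a GPU. In the sketch below, fits(batch_size) is a hypothetical stand-in for "run a few training steps and check for an OOM error", and max_size plays the role of the training dataset size. The exact stepping rules of Lightning's implementation may differ, so treat this as an illustration of the search strategies only.

```python
def power_scaling(fits, max_size, init_size=1, max_trials=25):
    """Double the batch size until a trial fails; return the largest size that fit."""
    best = None
    size = min(init_size, max_size)
    for _ in range(max_trials):
        if not fits(size):
            break  # simulated OOM: keep the previous successful size
        best = size
        if size == max_size:
            break  # never search beyond the dataset size
        size = min(size * 2, max_size)
    return best


def binsearch_scaling(fits, max_size, init_size=1, max_trials=25):
    """Double until the first failure, then binary-search between last success and failure."""
    low, high = None, None  # largest size that fit / smallest size that failed
    size = min(init_size, max_size)
    for _ in range(max_trials):
        if fits(size):
            low = size
        else:
            high = size
        if low is None or low == max_size:
            break  # even the first size failed, or the dataset-size cap was reached
        if high is None:
            size = min(size * 2, max_size)  # doubling phase, as in "power" mode
        elif high - low > 1:
            size = (low + high) // 2  # binary search between success and failure
        else:
            break  # low and high are adjacent, so low is the answer
    return low
```

If memory allowed at most 100 samples per batch (fits = lambda bs: bs <= 100), power_scaling would stop at 64 while binsearch_scaling would refine the result to 100.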

Note

This feature expects a batch_size field to exist either as a model attribute (model.batch_size) or as a field in your hparams (model.hparams.batch_size). It works the same way with datamodules. The field must exist and will be updated with the result of this algorithm. Additionally, your train_dataloader() method must depend on this field for this feature to work, for example:

# using LightningModule
class LitModel(LightningModule):
    def __init__(self, batch_size):
        super().__init__()
        self.save_hyperparameters()
        # or
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(train_dataset, batch_size=self.batch_size)  # or self.hparams.batch_size


trainer = Trainer(...)
model = LitModel(batch_size=32)
trainer.tune(model)


# using LightningDataModule
class LitDataModule(LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.save_hyperparameters()
        # or
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(train_dataset, batch_size=self.batch_size)  # or self.hparams.batch_size


trainer = Trainer(...)
model = MyModel()
datamodule = LitDataModule(batch_size=32)
trainer.tune(model, datamodule=datamodule)

Note that the train_dataloader can be either part of the LightningModule or LightningDataModule as shown above. If both the LightningModule and the LightningDataModule contain a train_dataloader, the LightningDataModule takes precedence.

Warning

Due to the constraints listed above, this feature does NOT work when passing dataloaders directly to .fit().

The scaling algorithm has a number of parameters that the user can control by invoking the scale_batch_size() method:

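from pytorch_lightning.tuner.tuning import Tuner

# Use default in trainer construction
trainer = Trainer()
tuner = Tuner(trainer)

# Invoke method; arguments such as mode, steps_per_trial and max_trials control the search
new_batch_size = tuner.scale_batch_size(model, mode="power", steps_per_trial=3, max_trials=25)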

# Override old batch size (this is done automatically)
model.hparams.batch_size = new_batch_size

# Fit as normal
trainer.fit(model)

The algorithm, in short, works as follows:
  1. Dumping the current state of the model and trainer

  2. Iterate until convergence or until the maximum number of tries, max_trials (default 25), is reached:
    • Call the trainer's fit() method, which runs steps_per_trial (default 3) optimization steps. Each training step can trigger an OOM error if the tensors (training batch, weights, gradients, etc.) allocated during the steps have too large a memory footprint.

    • If an OOM error is encountered, decrease the batch size; otherwise, increase it. How much the batch size is increased or decreased is determined by the chosen strategy.

  3. The found batch size is saved to either model.batch_size or model.hparams.batch_size

  4. Restore the initial state of model and trainer

Warning

The batch size finder is not yet supported for DDP or any of its variants; support is coming soon.

Customizing Batch Size Finder
  1. You can also customize the BatchSizeFinder callback to run at different epochs. This feature is useful while fine-tuning models since you can’t always use the same batch size after unfreezing the backbone.

from pytorch_lightning.callbacks import BatchSizeFinder


class FineTuneBatchSizeFinder(BatchSizeFinder):
    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        return

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
            self.scale_batch_size(trainer, pl_module)


trainer = Trainer(callbacks=[FineTuneBatchSizeFinder(milestones=(5, 10))])
trainer.fit(...)
  2. Run the batch size finder for validate/test/predict.

from pytorch_lightning.callbacks import BatchSizeFinder


class EvalBatchSizeFinder(BatchSizeFinder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def on_fit_start(self, *args, **kwargs):
        return

    def on_test_start(self, trainer, pl_module):
        self.scale_batch_size(trainer, pl_module)


trainer = Trainer(callbacks=[EvalBatchSizeFinder()])
trainer.test(...)

Learning Rate Finder


For training deep neural networks, selecting a good learning rate is essential for both better performance and faster convergence. Even optimizers such as Adam, which adapt the learning rate, can benefit from a better initial choice.

To reduce the amount of guesswork in choosing a good initial learning rate, a learning rate finder can be used. As described in this paper, a learning rate finder does a small run where the learning rate is increased after each processed batch and the corresponding loss is logged. The result is an lr vs. loss plot that can be used as guidance for choosing an optimal initial learning rate.
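The range test can be sketched in a few lines of plain Python. The exponential schedule, the bounds 1e-8 and 1.0, and the "start of the steepest loss decrease" rule are illustrative assumptions (Lightning's own suggestion() may filter the curve further), and the loss values in the usage note are synthetic; in a real run each loss would come from training one batch at that learning rate.

```python
def exponential_lrs(min_lr: float, max_lr: float, num_steps: int) -> list:
    """Learning rates growing geometrically from min_lr to max_lr, one per batch (num_steps >= 2)."""
    ratio = max_lr / min_lr
    return [min_lr * ratio ** (i / (num_steps - 1)) for i in range(num_steps)]


def suggest_lr(lrs: list, losses: list) -> float:
    """Pick the lr where the loss drops most steeply, rather than where it is lowest."""
    slopes = [losses[i + 1] - losses[i] for i in range(len(losses) - 1)]
    steepest = min(range(len(slopes)), key=lambda i: slopes[i])
    return lrs[steepest]
```

With lrs = exponential_lrs(1e-8, 1.0, 9) (one point per decade) and synthetic losses [2.30, 2.30, 2.29, 2.10, 1.20, 0.90, 0.85, 1.50, 9.00], the suggestion is lrs[3] (about 1e-5), where the loss starts its sharpest fall, not lrs[6], where the loss is lowest.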

Warning

For the moment, this feature only works with models having a single optimizer.

Note

With DDP: Since all the processes run in isolation, only the process with global_rank=0 decides when to stop the learning rate finder and broadcasts its results to all other ranks. This means that, at the end of the LR finder, each process runs with the learning rate found on global_rank=0.

Using Lightning’s built-in LR finder

To enable the learning rate finder, your LightningModule needs to have a learning_rate or lr attribute (or a field in your hparams, i.e. hparams.learning_rate or hparams.lr). Then set Trainer(auto_lr_find=True) during trainer construction and call trainer.tune(model) to run the LR finder. The suggested learning rate will be written to the console and automatically set on your LightningModule, where it can be accessed via self.learning_rate or self.lr.

See also

trainer.tune.

class LitModel(LightningModule):
    def __init__(self, learning_rate):
        super().__init__()
        self.learning_rate = learning_rate
        self.model = Model(...)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.learning_rate)


model = LitModel(learning_rate=1e-3)

# finds the learning rate automatically
# and overwrites model.learning_rate (or model.lr / the hparams field) with it
trainer = Trainer(auto_lr_find=True)

trainer.tune(model)

If your model uses an attribute with a different name instead of self.lr or self.learning_rate, pass that name to auto_lr_find:

model = LitModel()

# to set to your own hparams.my_value
trainer = Trainer(auto_lr_find="my_value")

trainer.tune(model)
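The attribute lookup described in this section can be sketched in plain Python. find_lr_attr and set_suggested_lr are hypothetical helpers, and the search order (a custom name if given, otherwise lr and then learning_rate, each checked on the module before its hparams) is an assumption for illustration; modules are modelled with types.SimpleNamespace.

```python
from types import SimpleNamespace


def find_lr_attr(module, custom_name=None):
    """Return (owner, name) for the first learning-rate attribute found, or raise."""
    names = [custom_name] if custom_name else ["lr", "learning_rate"]
    hparams = getattr(module, "hparams", None)
    for name in names:
        if hasattr(module, name):
            return module, name
        if hparams is not None and hasattr(hparams, name):
            return hparams, name
    raise AttributeError(f"none of {names} found on the module or its hparams")


def set_suggested_lr(module, new_lr, custom_name=None):
    owner, name = find_lr_attr(module, custom_name)
    setattr(owner, name, new_lr)  # the module now uses the suggested value


# Example: a stand-in module whose learning rate lives in its hparams
model = SimpleNamespace(hparams=SimpleNamespace(lr=0.1))
set_suggested_lr(model, 0.005)
```

Failing loudly when no attribute is found mirrors the requirement above that the field must exist before the finder runs.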

You can also inspect the results of the learning rate finder or just play around with the parameters of the algorithm. This can be done by invoking the lr_find() method. A typical example of this would look like:

model = MyModelClass(hparams)
trainer = Trainer()

# Run learning rate finder
lr_finder = trainer.tuner.lr_find(model)

# Results can be found in
print(lr_finder.results)

# Plot with
fig = lr_finder.plot(suggest=True)
fig.show()

# Pick point based on plot, or get suggestion
new_lr = lr_finder.suggestion()

# update hparams of the model
model.hparams.lr = new_lr

# Fit model
trainer.fit(model)

The figure produced by lr_finder.plot() should look something like the figure below. It is recommended not to pick the learning rate that achieves the lowest loss, but instead something in the middle of the sharpest downward slope (red point). This is the point returned by lr_finder.suggestion().

Customizing Learning Rate Finder

You can also customize the LearningRateFinder callback to run at different epochs. This feature is useful while fine-tuning models.

from pytorch_lightning.callbacks import LearningRateFinder


class FineTuneLearningRateFinder(LearningRateFinder):
    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        return

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
            self.lr_find(trainer, pl_module)


trainer = Trainer(callbacks=[FineTuneLearningRateFinder(milestones=(5, 10))])
trainer.fit(...)
_images/lr_finder.png

Advanced GPU Optimizations

When training on single or multiple GPU machines, Lightning offers a host of advanced optimizations to improve throughput, memory efficiency, and model scaling. Refer to Advanced GPU Optimized Training for more details.


Sharing Datasets Across Process Boundaries

The LightningDataModule class provides an organized way to decouple data loading from training logic, with prepare_data() being used for downloading and pre-processing the dataset on a single process, and setup() loading the pre-processed data for each process individually:

class MNISTDataModule(pl.LightningDataModule):
    def prepare_data(self):
        MNIST(self.data_dir, download=True)

    def setup(self, stage: str):
        self.mnist = MNIST(self.data_dir)

    def train_dataloader(self):
        return DataLoader(self.mnist, batch_size=128)

However, for in-memory datasets, that means that each process will hold a (redundant) replica of the dataset in memory, which may be impractical when using many processes while utilizing datasets that nearly fit into CPU memory, as the memory consumption will scale up linearly with the number of processes. For example, when training Graph Neural Networks, a common strategy is to load the entire graph into CPU memory for fast access to the entire graph structure and its features, and to then perform neighbor sampling to obtain mini-batches that fit onto the GPU.

A simple way to prevent redundant dataset replicas is to rely on torch.multiprocessing to share the data automatically between spawned processes via shared memory. For this, all data pre-loading should be done on the main process inside DataModule.__init__(). As a result, all tensor-data will get automatically shared when using the DDPSpawnStrategy strategy.

Warning

torch.multiprocessing will send a handle of each individual tensor to other processes. In order to prevent any errors due to too many open file handles, try to reduce the number of tensors to share, e.g., by stacking your data into a single tensor.

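import torchvision.transforms as T


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str):
        super().__init__()
        # pre-load all data on the main process so that spawned processes can share it
        self.mnist = MNIST(data_dir, download=True, transform=T.ToTensor())

    def train_dataloader(self):
        return DataLoader(self.mnist, batch_size=128)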


model = Model(...)
datamodule = MNISTDataModule("data/MNIST")

trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp_spawn")
trainer.fit(model, datamodule=datamodule)

See the graph-level and node-level prediction examples in PyTorch Geometric for practical use-cases.

Find bottlenecks in your code

Track and Visualize Experiments (intermediate)

Audience: Users who want to track more complex outputs and use third-party experiment managers.


Track audio and other artifacts

To track other artifacts, such as histograms or model topology graphs, first select one of the many loggers supported by Lightning:

from pytorch_lightning import loggers as pl_loggers

tensorboard = pl_loggers.TensorBoardLogger(save_dir="")
trainer = Trainer(logger=tensorboard)

Then access the logger’s API directly:

def training_step(self, batch, batch_idx):
    tensorboard = self.logger.experiment  # the underlying SummaryWriter
    tensorboard.add_image(...)
    tensorboard.add_histogram(...)
    tensorboard.add_figure(...)

Comet.ml

To use Comet.ml first install the comet package:

pip install comet-ml

Configure the logger and pass it to the Trainer:

from pytorch_lightning.loggers import CometLogger

comet_logger = CometLogger(api_key="YOUR_COMET_API_KEY")
trainer = Trainer(logger=comet_logger)

Access the comet logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts

class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        comet = self.logger.experiment
        fake_images = torch.Tensor(32, 3, 28, 28)
        comet.add_image("generated_images", fake_images, 0)

Here’s the full documentation for the CometLogger.


MLflow

To use MLflow first install the MLflow package:

pip install mlflow

Configure the logger and pass it to the Trainer:

from pytorch_lightning.loggers import MLFlowLogger

mlf_logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")
trainer = Trainer(logger=mlf_logger)

Access the mlflow logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts

class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        mlf_logger = self.logger.experiment
        fake_images = torch.Tensor(32, 3, 28, 28)
        mlf_logger.add_image("generated_images", fake_images, 0)

Here’s the full documentation for the MLFlowLogger.


Neptune.ai

To use Neptune.ai first install the neptune package:

pip install neptune-client

or with conda:

conda install -c conda-forge neptune-client

Configure the logger and pass it to the Trainer:

from pytorch_lightning.loggers import NeptuneLogger

neptune_logger = NeptuneLogger(
    api_key="ANONYMOUS",  # replace with your own
    project="common/pytorch-lightning-integration",  # format "<WORKSPACE/PROJECT>"
)
trainer = Trainer(logger=neptune_logger)

Access the neptune logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts

class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        neptune_logger = self.logger.experiment["your/metadata/structure"]
        neptune_logger.log(metadata)

Here’s the full documentation for the NeptuneLogger.


Tensorboard

TensorBoard can be installed with:

pip install tensorboard

Configure the logger and pass it to the Trainer:

from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger()
trainer = Trainer(logger=logger)

Access the tensorboard logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts

class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        tensorboard_logger = self.logger.experiment
        fake_images = torch.Tensor(32, 3, 28, 28)
        # add_images expects a batch of images (NCHW); add_image expects a single CHW image
        tensorboard_logger.add_images("generated_images", fake_images, 0)

Here’s the full documentation for the TensorBoardLogger.


Weights and Biases

To use Weights and Biases (wandb) first install the wandb package:

pip install wandb

Configure the logger and pass it to the Trainer:

from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project="MNIST", log_model="all")
trainer = Trainer(logger=wandb_logger)

# log gradients and model topology
wandb_logger.watch(model)

Access the wandb logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts

import wandb


class MyModule(LightningModule):
    def any_lightning_module_function_or_hook(self):
        wandb_logger = self.logger.experiment  # the underlying wandb Run
        fake_images = torch.Tensor(32, 3, 28, 28)

        # Option 1
        wandb_logger.log({"generated_images": [wandb.Image(fake_images, caption="...")]})

        # Option 2 for specifically logging images (a method of the WandbLogger itself)
        self.logger.log_image(key="generated_images", images=[fake_images])

Here’s the full documentation for the WandbLogger. Demo in Google Colab with hyperparameter search and model logging.


Use multiple exp managers

To use multiple experiment managers at the same time, pass a list to the logger Trainer argument.

from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

logger1 = TensorBoardLogger()
logger2 = WandbLogger()
trainer = Trainer(logger=[logger1, logger2])

Access all loggers from any function (except the LightningModule init) to use their APIs for tracking advanced artifacts

class MyModule(LightningModule):
    def any_lightning_module_function_or_hook(self):
        tensorboard_logger = self.loggers[0].experiment  # TensorBoard SummaryWriter
        wandb_logger = self.loggers[1]  # WandbLogger

        fake_images = torch.Tensor(32, 3, 28, 28)

        tensorboard_logger.add_images("generated_images", fake_images, 0)
        wandb_logger.log_image(key="generated_images", images=[fake_images])

Track multiple metrics in the same chart

If your logger supports plotting multiple metrics on the same chart, pass in a dictionary to self.log.

self.log("performance", {"acc": acc, "recall": recall})

Track hyperparameters

To track hyperparameters, first call save_hyperparameters from the LightningModule init:

class MyLightningModule(LightningModule):
    def __init__(self, learning_rate, another_parameter, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()

If your logger supports tracking hyperparameters, the hyperparameters will automatically show up on the logger dashboard.



Track model topology

Multiple loggers support visualizing the model topology. Here’s an example that tracks the model topology using TensorBoard.

def any_lightning_module_function_or_hook(self):
    tensorboard_logger = self.logger.experiment  # a torch.utils.tensorboard SummaryWriter

    prototype_array = torch.Tensor(32, 1, 28, 28)
    tensorboard_logger.add_graph(self, prototype_array)


How to Organize PyTorch Into Lightning

To enable your code to work with Lightning, perform the following to organize PyTorch into Lightning.


1. Keep Your Computational Code

Keep your regular nn.Module architecture

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F


class LitModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x

2. Configure Training Logic

In the training_step of the LightningModule configure how your training routine behaves with a batch of training data:

class LitModel(pl.LightningModule):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.encoder(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

Note

If you need to fully own the training loop for complicated legacy projects, check out Own your loop.


3. Move Optimizer(s) and LR Scheduler(s)

Move your optimizers to the configure_optimizers() hook.

class LitModel(pl.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]
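With step_size=1 and StepLR's default gamma=0.1, the scheduler above multiplies the learning rate by 0.1 at every step, and Lightning steps schedulers once per epoch unless configured otherwise. The sketch below restates that schedule in closed form. It is not PyTorch's implementation, and the helper name `step_lr` is hypothetical.

```python
def step_lr(base_lr: float, epoch: int, step_size: int = 1, gamma: float = 0.1) -> float:
    """Closed-form learning rate of a StepLR schedule after `epoch` scheduler steps.

    The rate is multiplied by `gamma` once every `step_size` steps.
    """
    return base_lr * gamma ** (epoch // step_size)


# lr=1e-3 as in configure_optimizers above; it shrinks 10x per epoch with step_size=1
for epoch in range(4):
    print(epoch, step_lr(1e-3, epoch))
```

Such a steep decay is usually paired with a larger step_size in practice; the values above only mirror the example.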

4. Organize Validation Logic (optional)

If you need a validation loop, configure how your validation routine behaves with a batch of validation data:

class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.encoder(x)
        val_loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", val_loss)

Tip

trainer.validate() loads the best checkpoint automatically by default if checkpointing was enabled during fitting.


5. Organize Testing Logic (optional)

If you need a test loop, configure how your testing routine behaves with a batch of test data:

class LitModel(pl.LightningModule):
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.encoder(x)
        test_loss = F.cross_entropy(y_hat, y)
        self.log("test_loss", test_loss)

6. Configure Prediction Logic (optional)

If you need a prediction loop, configure how your prediction routine behaves with a batch of test data:

class LitModel(LightningModule):
    def predict_step(self, batch, batch_idx):
        x, y = batch
        pred = self.encoder(x)
        return pred

7. Remove any .cuda() or .to(device) Calls

Your LightningModule can automatically run on any hardware!

If you have any explicit calls to .cuda() or .to(device), you can remove them since Lightning makes sure that the data coming from DataLoader and all the Module instances initialized inside LightningModule.__init__ are moved to the respective devices automatically. If you still need to access the current device, you can use self.device anywhere in your LightningModule except in the __init__ and setup methods.

class LitModel(LightningModule):
    def training_step(self, batch, batch_idx):
        z = torch.randn(4, 5, device=self.device)
        ...

Hint: If you are initializing a Tensor within the LightningModule.__init__ method and want it to be moved to the device automatically, call register_buffer() to register it as a buffer. A buffer is moved together with the module but is not treated as a trainable parameter.

class LitModel(LightningModule):
    def __init__(self, num_features: int):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(num_features))

8. Use your own data

Regular PyTorch DataLoaders work with Lightning. For more modular and scalable datasets, check out LightningDataModule.


Good to know

Additionally, you can run only the validation loop using the validate() method.

model = LitModel()
trainer.validate(model)

Note

model.eval() and torch.no_grad() are called automatically for validation.

The test loop isn’t used within fit(), therefore, you would need to explicitly call test().

model = LitModel()
trainer.test(model)

Note

model.eval() and torch.no_grad() are called automatically for testing.

Tip

trainer.test() loads the best checkpoint automatically by default if checkpointing is enabled.

The predict loop will not be used until you call predict().

model = LitModel()
trainer.predict(model)

Note

model.eval() and torch.no_grad() are called automatically for prediction.

Tip

trainer.predict() loads the best checkpoint automatically by default if checkpointing is enabled.

Run on an on-prem cluster

Checkpointing


N-Bit Precision

Training on unreliable mixed GPUs across the internet

Audience: Users who do not have access to top-tier multi-gpu/multi-node servers and want to scale training across different GPU types, or across the internet.


Train 1 trillion+ parameter models

When training large models, fitting larger batch sizes, or trying to increase throughput using multi-GPU compute, Lightning provides advanced optimized distributed training strategies to support these cases and offer substantial improvements in memory usage.

Note that some of the extreme memory-saving configurations will affect the speed of training. In most cases, this speed/memory trade-off can be adjusted.

Some of these memory-efficient strategies rely on offloading onto other forms of memory, such as CPU RAM or NVMe. This means you can even see memory benefits on a single GPU, using a strategy such as DeepSpeed ZeRO Stage 3 Offload.

Check out this amazing video explaining model parallelism and how it works behind the scenes:

Choosing an Advanced Distributed GPU Strategy

If you would like to stick with PyTorch DDP, see DDP Optimizations.

Unlike DistributedDataParallel (DDP) where the maximum trainable model size and batch size do not change with respect to the number of GPUs, memory-optimized strategies can accommodate bigger models and larger batches as more GPUs are used. This means as you scale up the number of GPUs, you can reach the number of model parameters you’d like to train.

There are many considerations when choosing a strategy as described below. In addition, check out the visualization of various strategy benchmarks using minGPT here.

Pre-training vs Fine-tuning

When fine-tuning, we often use an order of magnitude less data compared to pre-training a model. This matters when choosing a distributed strategy: pre-training is usually compute-bound, so we cannot sacrifice as much throughput as when fine-tuning, where the data requirement is smaller.

Overall:

For example, when using 128 GPUs, you can pre-train large 10 to 20 Billion parameter models using DeepSpeed ZeRO Stage 2 without taking the performance hit of a more advanced, optimized multi-GPU strategy.

But for fine-tuning a model, you can reach 10 to 20 Billion parameter models using DeepSpeed ZeRO Stage 3 Offload on a single GPU. This does come with a significant throughput hit, which needs to be weighed accordingly.

When Shouldn’t I use an Optimized Distributed Strategy?

Sharding techniques help when model sizes are fairly large; roughly 500M+ parameters is where we’ve seen benefits. However, in the following cases, we recommend sticking to ordinary distributed strategies:

  • When your model is small (e.g. ResNet50, with around 25M parameters), unless you are using unusually large batch sizes or inputs.

  • When running on a slow network/interconnect: due to the high volume of distributed communication between devices, training might be much slower than expected, and it’s up to you to determine the tradeoff.


Colossal-AI

ColossalAIStrategy implements ZeRO-DP with chunk-based memory management. With this chunk mechanism, really large models can be trained with a small number of GPUs. It supports larger trainable model size and batch size than usual heterogeneous training by reducing CUDA memory fragments and CPU memory consumption. Also, it speeds up this kind of heterogeneous training by fully utilizing all kinds of resources.

When enabling chunk mechanism, a set of consecutive parameters are stored in a chunk, and then the chunk is sharded across different processes. This can reduce communication and data transmission frequency and fully utilize communication and PCI-E bandwidth, which makes training faster.
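The chunk idea can be made concrete with a simplified sketch. Consecutive parameters are packed into fixed-capacity chunks, and whole chunks are then distributed over processes, so communication happens once per chunk instead of once per parameter. The chunk capacity, the round-robin assignment, the parameter sizes, and the helper names are illustrative assumptions, not ColossalAI's actual algorithm.

```python
from typing import List


def build_chunks(param_sizes: List[int], chunk_capacity: int) -> List[List[int]]:
    """Pack consecutive parameter indices into chunks holding at most `chunk_capacity` elements.

    A parameter larger than the capacity gets a chunk of its own.
    """
    chunks: List[List[int]] = []
    current: List[int] = []
    used = 0
    for idx, size in enumerate(param_sizes):
        if current and used + size > chunk_capacity:
            chunks.append(current)
            current, used = [], 0
        current.append(idx)
        used += size
    if current:
        chunks.append(current)
    return chunks


def assign_chunks(num_chunks: int, world_size: int) -> List[List[int]]:
    """Shard whole chunks across processes in round-robin order."""
    return [list(range(rank, num_chunks, world_size)) for rank in range(world_size)]


sizes = [400, 300, 300, 800, 100, 100]  # hypothetical parameter element counts
chunks = build_chunks(sizes, chunk_capacity=1000)
print(chunks)                         # [[0, 1, 2], [3, 4, 5]]
print(assign_chunks(len(chunks), 2))  # [[0], [1]]
```

Here six parameters become two chunks, so a gather or scatter needs two transfers rather than six.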

Unlike traditional implementations, which adopt static memory partitioning, we implemented a dynamic heterogeneous memory management system named Gemini. During the first training step, the warmup phase samples the maximum non-model data memory (memory usage except parameters, gradients, and optimizer states). In later training, it uses the collected memory usage information to evict chunks dynamically. Gemini allows you to fit much larger models with limited GPU memory.

According to our benchmark results, we can train models with up to 24 billion parameters on 1 GPU. You can install colossalai by consulting how to download colossalai. Then, run this benchmark in Colossalai-PL/gpt.

Here is an example showing how to use ColossalAI:

from colossalai.nn.optimizer import HybridAdam
from pytorch_lightning import LightningModule, Trainer
from transformers import BertForSequenceClassification


class MyBert(LightningModule):
    ...

    def configure_sharded_model(self) -> None:
        # create your model here
        self.model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

    def configure_optimizers(self):
        # use the specified optimizer
        optimizer = HybridAdam(self.model.parameters(), self.lr)
        return optimizer

    ...


model = MyBert()
trainer = Trainer(accelerator="gpu", devices=1, precision=16, strategy="colossalai")
trainer.fit(model)

You can find more examples in the Colossalai-PL repository.

Note

  • The only accelerator which ColossalAI supports is "gpu". But CPU resources will be used when the placement policy is set to “auto” or “cpu”.

  • The only precision which ColossalAI allows is 16 (FP16).

  • It only supports a single optimizer, which must currently be colossalai.nn.optimizer.CPUAdam or colossalai.nn.optimizer.HybridAdam. You can set adamw_mode to False to use normal Adam. HybridAdam is highly optimized: it uses a fused CUDA kernel and a parallel CPU kernel. HybridAdam is recommended, since it updates parameters on both GPU and CPU.

  • Your model must be created using the configure_sharded_model() method.

  • ColossalaiStrategy doesn’t support gradient accumulation as of now.

Placement Policy

Placement policies can help users fully exploit their GPU-CPU heterogeneous memory space for better training efficiency. There are three options for the placement policy. They are “cpu”, “cuda” and “auto” respectively.

When the placement policy is set to “cpu”, all participating parameters will be offloaded into CPU memory immediately at the end of every auto-grad operation. In this way, the “cpu” placement policy uses the least CUDA memory. It is the best choice for users who want to greatly enlarge their model size or training batch size.

When using “cuda” option, all parameters are placed in the CUDA memory, no CPU resources will be used during the training. It is for users who get plenty of CUDA memory.

The third option, “auto”, enables Gemini. It monitors the consumption of CUDA memory during the warmup phase and collects the CUDA memory usage of all auto-grad operations. In later training steps, Gemini automatically manages the data transmission between GPU and CPU according to the collected CUDA memory usage information. It is the fastest option when there is enough CUDA memory.
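The three policies can be summarized with a toy model of how many parameter chunks stay resident in CUDA memory between operations. This is a simplified sketch that assumes a fixed GPU capacity and a single non-model memory peak measured during warmup. Gemini's real bookkeeping is per operator and more involved. The function name `chunks_on_gpu` and all numbers are hypothetical.

```python
def chunks_on_gpu(policy: str, num_chunks: int, chunk_gb: float,
                  gpu_capacity_gb: float, non_model_peak_gb: float) -> int:
    """Toy model: number of parameter chunks kept in CUDA memory under each policy.

    "cpu":  offload every chunk, using the least CUDA memory.
    "cuda": keep every chunk on the GPU; fails if they do not fit.
    "auto": keep as many chunks as fit beside the non-model memory peak
            measured during warmup, and evict the rest to CPU.
    """
    if policy == "cpu":
        return 0
    if policy == "cuda":
        if num_chunks * chunk_gb + non_model_peak_gb > gpu_capacity_gb:
            raise MemoryError("chunks do not fit in CUDA memory")
        return num_chunks
    if policy == "auto":
        budget_gb = gpu_capacity_gb - non_model_peak_gb
        return max(0, min(num_chunks, int(budget_gb // chunk_gb)))
    raise ValueError(f"unknown placement policy: {policy!r}")


# Hypothetical: 20 chunks of 1 GB, a 16 GB GPU and a 6 GB non-model peak
for policy in ("cpu", "auto"):
    print(policy, chunks_on_gpu(policy, 20, 1.0, 16.0, 6.0))
```

With these numbers, "cuda" would not fit at all, while "auto" keeps half of the chunks on the GPU and streams the rest from CPU.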

Here’s an example of changing the placement policy to “cpu”.

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import ColossalAIStrategy

model = MyModel()
my_strategy = ColossalAIStrategy(placement_policy="cpu")
trainer = Trainer(accelerator="gpu", devices=4, precision=16, strategy=my_strategy)
trainer.fit(model)

Sharded Training

The technique can be found within DeepSpeed ZeRO and ZeRO-2; however, the implementation is built from the ground up to be PyTorch compatible and standalone. Sharded Training allows you to maintain GPU scaling efficiency whilst drastically reducing memory overhead. In short, expect near-normal linear scaling (if your network allows) and significantly reduced memory usage when training large models.

Sharded Training still utilizes Data Parallel Training under the hood, except optimizer states and gradients are sharded across GPUs. This means the memory overhead per GPU is lower, as each GPU only has to maintain a partition of your optimizer state and gradients.
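A minimal sketch of what sharding optimizer state means for a single rank follows. The flattened parameters are split into contiguous, near-equal partitions, and each rank keeps optimizer state only for its own slice. The example assumes two Adam-style moment buffers per element. The helper names are hypothetical and the sketch ignores communication entirely, so it is not FairScale's implementation.

```python
from typing import List, Tuple


def partition_bounds(numel: int, world_size: int) -> List[Tuple[int, int]]:
    """Split `numel` flattened elements into `world_size` contiguous [start, end) slices."""
    base, remainder = divmod(numel, world_size)
    bounds, start = [], 0
    for rank in range(world_size):
        end = start + base + (1 if rank < remainder else 0)
        bounds.append((start, end))
        start = end
    return bounds


def optimizer_state_elements(numel: int, world_size: int, rank: int,
                             states_per_param: int = 2) -> int:
    """Optimizer-state elements held by one rank when the state is sharded."""
    start, end = partition_bounds(numel, world_size)[rank]
    return (end - start) * states_per_param


print(partition_bounds(10, 4))  # [(0, 3), (3, 6), (6, 8), (8, 10)]
# 1M parameters on 4 ranks: 500000 state elements per rank instead of 2000000
print(optimizer_state_elements(1_000_000, 4, rank=0))
```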

The benefits vary by model and parameter sizes, but we’ve recorded up to a 63% memory reduction per GPU allowing us to double our model sizes. Because of efficient communication, these benefits in multi-GPU setups are almost free and throughput scales well with multi-node setups.

It is highly recommended to use Sharded Training in multi-GPU environments where memory is limited, or where training larger models is beneficial (500M+ parameter models). A technical note: as batch size scales, storing activations for the backwards pass becomes the bottleneck in training. As a result, sharding optimizer state and gradients becomes less impactful.

# train using Sharded DDP
trainer = Trainer(strategy="ddp_sharded")

Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required.


Fully Sharded Training

PyTorch has its own version of FSDP, which is upstreamed from their fairscale project. It was introduced in the v1.11.0 release, but it is recommended to use it with PyTorch v1.12 or later, which is what Lightning supports.

Auto Wrapping

Model layers should be wrapped in FSDP in a nested way to save peak memory and enable communication and computation overlapping. The simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don’t have to wrap layers manually as in the case of manual wrapping.

Note

While initializing the optimizers inside configure_optimizers hook, make sure to use self.trainer.model.parameters(), else PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and your lightning_module.parameters() will return a generator with no params. This inconvenience will be addressed in the future.

model = BoringModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp_native", precision=16)
trainer.fit(model)

Read more here.

Manual Wrapping

Manual wrapping can be useful to explore complex sharding strategies by applying wrap selectively to some parts of the model. To activate parameter sharding with manual wrapping, you can wrap your model using the wrap function. Internally in Lightning, we enable a context manager around the configure_sharded_model function to make sure the wrap parameters are passed correctly.

When not using Fully Sharded these wrap functions are a no-op. This means once the changes have been made, there is no need to remove the changes for other strategies.

wrap simply wraps the module with the FullyShardedDataParallel class, using the correct parameters from the Lightning context manager.

Here’s an example that uses wrap to create your model:

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.distributed.fsdp.wrap import wrap


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear_layer = nn.Linear(32, 32)
        self.block = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))

    def configure_sharded_model(self):
        # modules are sharded across processes
        # as soon as they are wrapped with `wrap`.
        # During the forward/backward passes, weights get synced across processes
        # and de-allocated once computation is complete, saving memory.

        # Wraps the layer in a Fully Sharded Wrapper automatically
        linear_layer = wrap(self.linear_layer)

        for i, layer in enumerate(self.block):
            self.block[i] = wrap(layer)

        self.model = nn.Sequential(linear_layer, nn.ReLU(), self.block)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters())


model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp_native", precision=16)
trainer.fit(model)

You can customize the strategy configuration by adjusting the arguments of DDPFullyShardedNativeStrategy and passing it to the strategy argument of the Trainer.

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy


native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=True)
trainer = Trainer(strategy=native_fsdp, accelerator="gpu", devices=4)

Check out this tutorial to learn more about the native support.


Activation Checkpointing

Activation checkpointing reduces GPU memory usage by avoiding the storage of intermediate activation tensors in selected layers. The tradeoff is that computation cost for the backpropagation increases, as the dropped activations need to be recomputed.
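The trade-off can be illustrated on a chain of scalar layers in pure Python. The plain version stores every layer input for the backward pass. The checkpointed version stores only the first input of each segment and recomputes the rest, one segment at a time, during backward. The layer function tanh(w * x), the weights, and the segment length are illustrative assumptions. Real implementations, such as torch.utils.checkpoint, operate on tensors and the autograd graph.

```python
import math
from typing import List, Tuple

# Hypothetical weights for an 8-layer chain of scalar layers y = tanh(w * x)
WEIGHTS = [0.9, 1.1, 0.7, 1.3, 0.5, 1.2, 0.8, 1.0]


def layer(x: float, w: float) -> float:
    return math.tanh(w * x)


def layer_grad(x: float, w: float) -> float:
    """Derivative of tanh(w * x) with respect to x."""
    return w * (1.0 - math.tanh(w * x) ** 2)


def grad_plain(x0: float) -> Tuple[float, int]:
    """Backprop that stores every layer input; returns (d out/d x0, activations stored)."""
    inputs: List[float] = []
    x = x0
    for w in WEIGHTS:
        inputs.append(x)
        x = layer(x, w)
    grad = 1.0
    for x_in, w in zip(reversed(inputs), reversed(WEIGHTS)):
        grad *= layer_grad(x_in, w)
    return grad, len(inputs)


def grad_checkpointed(x0: float, segment: int) -> Tuple[float, int, int]:
    """Store only each segment's first input and recompute the rest during backward.

    Returns (d out/d x0, activations stored, layers re-evaluated).
    """
    checkpoints: List[float] = []
    x = x0
    for i, w in enumerate(WEIGHTS):
        if i % segment == 0:
            checkpoints.append(x)
        x = layer(x, w)
    grad, recomputed = 1.0, 0
    for s in reversed(range(len(checkpoints))):
        ws = WEIGHTS[s * segment:(s + 1) * segment]
        x, seg_inputs = checkpoints[s], []
        for w in ws:  # re-run this segment's forward pass
            seg_inputs.append(x)
            x = layer(x, w)
            recomputed += 1
        for x_in, w in zip(reversed(seg_inputs), reversed(ws)):
            grad *= layer_grad(x_in, w)
    return grad, len(checkpoints), recomputed


g_plain, stored_plain = grad_plain(0.5)
g_ckpt, stored_ckpt, extra = grad_checkpointed(0.5, segment=4)
print(stored_plain, stored_ckpt, extra)  # 8 2 8
print(abs(g_plain - g_ckpt) < 1e-12)     # True
```

Both versions return the same gradient. The checkpointed one keeps 2 stored activations instead of 8, plus one segment's activations at a time during backward, in exchange for re-running the 8 layer evaluations.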

Enable checkpointing on large layers (like Transformers) by providing the layer class/type to the strategy:

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

fsdp = DDPFullyShardedNativeStrategy(
    activation_checkpointing=MyTransformerBlock,  # your own layer class, or a list with multiple types
)
trainer = Trainer(strategy=fsdp, accelerator="gpu", devices=4)

DeepSpeed

Note

The DeepSpeed strategy is in beta and the API is subject to change. Please create an issue if you run into any problems.

DeepSpeed is a deep learning training optimization library, providing the means to train massive billion parameter models at scale. Using the DeepSpeed strategy, we were able to train model sizes of 10 Billion parameters and above, with a lot of useful information in this benchmark and the DeepSpeed docs. DeepSpeed also offers lower level training optimizations, and efficient optimizers such as 1-bit Adam. We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large billion parameter models).

Below is a summary of all the configurations of DeepSpeed.

  • DeepSpeed ZeRO Stage 1 - Shard optimizer states, remains at speed parity with DDP whilst providing memory improvement

  • DeepSpeed ZeRO Stage 2 - Shard optimizer states and gradients, remains at speed parity with DDP whilst providing even more memory improvement

  • DeepSpeed ZeRO Stage 2 Offload - Offload optimizer states and gradients to CPU. Increases distributed communication volume and GPU-CPU device transfer, but provides significant memory improvement

  • DeepSpeed ZeRO Stage 3 - Shard optimizer states, gradients, parameters and optionally activations. Increases distributed communication volume, but provides even more memory improvement

  • DeepSpeed ZeRO Stage 3 Offload - Offload optimizer states, gradients, parameters and optionally activations to CPU. Increases distributed communication volume and GPU-CPU device transfer, but even more significant memory improvement.

  • DeepSpeed Activation Checkpointing - Free activations after forward pass. Increases computation, but provides memory improvement for all stages.
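The non-offload stages can be compared using the per-parameter accounting from the ZeRO paper for mixed-precision Adam training. Each parameter costs 2 bytes for fp16 parameters, 2 bytes for fp16 gradients, and 12 bytes of fp32 optimizer state (master weights, momentum, variance). The sketch below applies that model, labels plain DDP as "stage 0", and ignores activations and communication buffers. Treat its output as a rough estimate, not what DeepSpeed will report. The helper name and the 1.5B-parameter model are hypothetical.

```python
def zero_gpu_memory_gb(num_params: float, world_size: int, stage: int) -> float:
    """Approximate per-GPU memory (GB) for model states under ZeRO stages 0-3.

    Bytes per parameter for mixed-precision Adam: 2 (fp16 parameters)
    + 2 (fp16 gradients) + 12 (fp32 master parameters, momentum, variance).
    Stage 0 stands for plain DDP. Activations and buffers are ignored.
    """
    params, grads, optim = 2.0, 2.0, 12.0
    n = world_size
    if stage == 0:    # everything replicated on every GPU
        per_param = params + grads + optim
    elif stage == 1:  # shard optimizer states
        per_param = params + grads + optim / n
    elif stage == 2:  # shard optimizer states and gradients
        per_param = params + (grads + optim) / n
    elif stage == 3:  # shard optimizer states, gradients and parameters
        per_param = (params + grads + optim) / n
    else:
        raise ValueError(f"unknown ZeRO stage: {stage}")
    return num_params * per_param / 1e9


# Hypothetical 1.5B-parameter model on 8 GPUs
for stage in range(4):
    print(f"stage {stage}: {zero_gpu_memory_gb(1.5e9, world_size=8, stage=stage):.2f} GB")
```

Only Stage 3 makes the per-GPU cost shrink in direct proportion to the number of GPUs, which is why it scales to the largest models.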

To use DeepSpeed, you first need to install DeepSpeed using the command below.

pip install deepspeed

If you run into an issue with the install or later in training, ensure that the CUDA version of the PyTorch you’ve installed matches your locally installed CUDA (you can see which one has been recognized by running nvcc --version).

Note

DeepSpeed currently only supports a single optimizer and a single scheduler within the training loop.

When saving a checkpoint, we rely on DeepSpeed, which saves a directory containing the model and various components.

DeepSpeed ZeRO Stage 1

DeepSpeed ZeRO Stage 1 partitions your optimizer states (Stage 1) across your GPUs to reduce memory.

It is recommended to skip Stage 1 and use Stage 2, which comes with larger memory improvements and still remains efficient. Stage 1 is useful to pair with certain optimizations such as Torch ORT.

from pytorch_lightning import Trainer

model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy="deepspeed_stage_1", precision=16)
trainer.fit(model)

DeepSpeed ZeRO Stage 2

DeepSpeed ZeRO Stage 2 partitions your optimizer states (Stage 1) and your gradients (Stage 2) across your GPUs to reduce memory. In most cases, this is more efficient or at parity with DDP, primarily due to the optimized custom communications written by the DeepSpeed team. As a result, benefits can also be seen on a single GPU. Do note that the default bucket sizes allocate around 3.6GB of VRAM to use during distributed communications, which can be tweaked when instantiating the strategy described in a few sections below.

from pytorch_lightning import Trainer

model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy="deepspeed_stage_2", precision=16)
trainer.fit(model)

This can also be done via the command line using a PyTorch Lightning script:

python train.py --strategy deepspeed_stage_2 --precision 16 --accelerator 'gpu' --devices 4

DeepSpeed ZeRO Stage 2 Offload

Below we show an example of running ZeRO-Offload. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption.

from pytorch_lightning import Trainer

model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy="deepspeed_stage_2_offload", precision=16)
trainer.fit(model)

This can also be done via the command line using a PyTorch Lightning script:

python train.py --strategy deepspeed_stage_2_offload --precision 16 --accelerator 'gpu' --devices 4

You can also modify the ZeRO-Offload parameters via the strategy as below.

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DeepSpeedStrategy

model = MyModel()
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=DeepSpeedStrategy(offload_optimizer=True, allgather_bucket_size=5e8, reduce_bucket_size=5e8),
    precision=16,
)
trainer.fit(model)

Note

We suggest tuning the allgather_bucket_size and reduce_bucket_size parameters to find optimal values for your model size. These parameters control the size of the buffers used when reducing gradients and gathering updated parameters. Smaller values use less memory, at the cost of speed.

DeepSpeed allocates a reduce buffer size multiplied by 1.5x so take that into consideration when tweaking the parameters.

The strategy sets a reasonable default of 2e8, which should work for most low VRAM GPUs (less than 7GB), allocating roughly 3.6GB of VRAM as buffer. Higher VRAM GPUs should aim for values around 5e8.

For even more speed benefit, DeepSpeed offers an optimized CPU version of Adam called DeepSpeedCPUAdam to run the offloaded computation, which is faster than the standard PyTorch implementation.

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from deepspeed.ops.adam import DeepSpeedCPUAdam


class MyModel(pl.LightningModule):
    ...

    def configure_optimizers(self):
        # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w)
        return DeepSpeedCPUAdam(self.parameters())


model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy="deepspeed_stage_2_offload", precision=16)
trainer.fit(model)

DeepSpeed ZeRO Stage 3

DeepSpeed ZeRO Stage 3 shards the optimizer states, gradients and the model parameters (also optionally activations). Sharding model parameters and activations comes with an increase in distributed communication, however allows you to scale your models massively from one GPU to multiple GPUs. The DeepSpeed team report the ability to fine-tune models with over 40B parameters on a single GPU and over 2 Trillion parameters on 512 GPUs. For more information we suggest checking the DeepSpeed ZeRO-3 Offload documentation.

We’ve run benchmarks for all these features and provided a simple example of how they work in Lightning, which you can see at minGPT.

To reach the highest memory efficiency or model size, you must:

  1. Use the DeepSpeed strategy with the stage 3 parameter

  2. Use CPU Offloading to offload weights to CPU, plus have a reasonable amount of CPU RAM to offload onto

  3. Use DeepSpeed Activation Checkpointing to shard activations

Below we describe how to enable all of these to see benefit. With all these improvements we reached 45 Billion parameters training a GPT model on 8 GPUs with ~1TB of CPU RAM available.

Also please have a look at our DeepSpeed ZeRO Stage 3 Tips which contains a lot of helpful information when configuring your own models.

Note

When saving a model using DeepSpeed and Stage 3, model states and optimizer states will be saved in separate sharded states (based on the world size). See Collating Single File Checkpoint for DeepSpeed ZeRO Stage 3 to obtain a single checkpoint file.

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from deepspeed.ops.adam import FusedAdam


class MyModel(pl.LightningModule):
    ...

    def configure_optimizers(self):
        return FusedAdam(self.parameters())


model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy="deepspeed_stage_3", precision=16)
trainer.fit(model)

trainer.test()
trainer.predict()

You can also use the Lightning Trainer to run predict or evaluate with DeepSpeed once the model has been trained.

import pytorch_lightning as pl
from pytorch_lightning import Trainer


class MyModel(pl.LightningModule):
    ...


model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy="deepspeed_stage_3", precision=16)
trainer.test(model, ckpt_path="my_saved_deepspeed_checkpoint.ckpt")

Shard Model Instantly to Reduce Initialization Time/Memory

When instantiating really large models, it is sometimes necessary to shard the model layers instantly.

This is the case if layers may not fit in a single machine's CPU or GPU memory, but would fit once sharded across multiple machines. We expose a hook, configure_sharded_model, in which layers are sharded instantly on a per-layer basis as soon as they are initialized, allowing you to shard models while they are being built.

This reduces the time taken to initialize very large models and ensures we do not run out of memory when instantiating them. For more information, refer to the DeepSpeed docs for Constructing Massive Models.

import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from deepspeed.ops.adam import FusedAdam


class MyModel(pl.LightningModule):
    ...

    def configure_sharded_model(self):
        # Created within sharded model context, modules are instantly sharded across processes
        # as soon as they are made.
        self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())

    def configure_optimizers(self):
        return FusedAdam(self.parameters())


model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy="deepspeed_stage_3", precision=16)
trainer.fit(model)

trainer.test()
trainer.predict()
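Conceptually, each parameter created inside configure_sharded_model is split into near-equal slices, one per process, so no single rank ever holds the full layer. The pure-Python sketch below shows only that bookkeeping for a flattened parameter. `shard_ranges` and `shard_layer` are hypothetical helpers. DeepSpeed's real partitioning, which also pads tensors and gathers them on demand, is more involved.

```python
from typing import List, Tuple


def shard_ranges(numel: int, world_size: int) -> List[Tuple[int, int]]:
    """Return the [start, end) slice of a flattened parameter owned by each rank.

    Every rank gets ceil(numel / world_size) slots; the last ranks may own fewer
    real elements (a real implementation would pad them).
    """
    per_rank = -(-numel // world_size)  # ceiling division
    ranges = []
    for rank in range(world_size):
        start = min(rank * per_rank, numel)
        end = min(start + per_rank, numel)
        ranges.append((start, end))
    return ranges


def shard_layer(layer_numels: List[int], world_size: int) -> List[List[Tuple[int, int]]]:
    """Shard every parameter of a freshly created layer, one entry per parameter."""
    return [shard_ranges(n, world_size) for n in layer_numels]


# nn.Linear(32, 32) has a 32x32 weight and a 32-element bias
weight_and_bias = [32 * 32, 32]
for numel, ranges in zip(weight_and_bias, shard_layer(weight_and_bias, 4)):
    print(numel, ranges)
```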
DeepSpeed ZeRO Stage 3 Offload

DeepSpeed ZeRO Stage 3 Offload moves optimizer states and gradients to the host CPU to reduce memory usage, as ZeRO Stage 2 does. In addition, it lets you offload the parameters for even more memory savings.

Note

When saving a model using DeepSpeed and Stage 3, model states and optimizer states will be saved in separate sharded states (based on the world size). See Collating Single File Checkpoint for DeepSpeed ZeRO Stage 3 to obtain a single checkpoint file.

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DeepSpeedStrategy

# Enable CPU Offloading
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy="deepspeed_stage_3_offload", precision=16)
trainer.fit(model)

# Enable CPU Offloading, and offload parameters to CPU
model = MyModel()
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=DeepSpeedStrategy(
        stage=3,
        offload_optimizer=True,
        offload_parameters=True,
    ),
    precision=16,
)
trainer.fit(model)
DeepSpeed Infinity (NVMe Offloading)

Additionally, DeepSpeed supports offloading to NVMe drives for even larger models, utilizing the large memory space of NVMe drives. DeepSpeed reports the ability to fine-tune models with over 1 trillion parameters using NVMe offloading on a single 8-GPU machine. The example below shows how to enable this, assuming the NVMe drive is mounted in a directory called /local_nvme.

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DeepSpeedStrategy

# Enable CPU Offloading
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy="deepspeed_stage_3_offload", precision=16)
trainer.fit(model)

# Offload optimizer states and parameters to NVMe
model = MyModel()
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=DeepSpeedStrategy(
        stage=3,
        offload_optimizer=True,
        offload_parameters=True,
        remote_device="nvme",
        offload_params_device="nvme",
        offload_optimizer_device="nvme",
        nvme_path="/local_nvme",
    ),
    precision=16,
)
trainer.fit(model)

When offloading to NVMe, you may notice that training is slow. Some parameters need to be tuned for the drives you are using, and running the aio_bench_perf_sweep.py script can help you find optimal values. See the issue for more information on how to interpret its output.
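If you prefer the Custom DeepSpeed Config route, the same offloading choices can be written as a raw DeepSpeed config dictionary. The helper below is a hypothetical sketch. It assumes DeepSpeed's `zero_optimization.offload_optimizer` and `zero_optimization.offload_param` sections with `device`, `pin_memory` and `nvme_path` keys. Check the DeepSpeed configuration docs for your installed version before relying on it.

```python
from typing import Any, Dict, Optional


def zero3_offload_config(device: str = "cpu", nvme_path: Optional[str] = None) -> Dict[str, Any]:
    """Build a ZeRO Stage 3 config offloading optimizer states and parameters (hypothetical helper)."""
    if device not in ("cpu", "nvme"):
        raise ValueError("device must be 'cpu' or 'nvme'")
    if device == "nvme" and not nvme_path:
        raise ValueError("nvme_path is required when offloading to NVMe")

    offload: Dict[str, Any] = {"device": device, "pin_memory": True}
    if device == "nvme":
        offload["nvme_path"] = nvme_path  # e.g. where the NVMe drive is mounted

    return {
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": dict(offload),  # optimizer states leave the GPU
            "offload_param": dict(offload),      # model parameters leave the GPU
        }
    }


cpu_config = zero3_offload_config()
nvme_config = zero3_offload_config("nvme", "/local_nvme")
# e.g. Trainer(accelerator="gpu", devices=4, strategy=DeepSpeedStrategy(config=nvme_config), precision=16)
```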

DeepSpeed Activation Checkpointing

Activation checkpointing frees activations from memory as soon as they are not needed during the forward pass. They are then recomputed for the backward pass as needed.

Activation checkpointing is very useful when you have intermediate layers that produce large activations.

This saves memory when training larger models, but requires using a checkpoint function to run modules, as shown below.

Warning

Do not wrap the entire model with activation checkpointing. This is not the intended usage of activation checkpointing and will lead to failures, as seen in this discussion.

import torch
import torch.nn as nn
import deepspeed
from pytorch_lightning import LightningModule, Trainer


class MyModel(LightningModule):
    ...

    def __init__(self):
        super().__init__()
        self.block_1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
        self.block_2 = torch.nn.Linear(32, 2)

    def forward(self, x):
        # Use the DeepSpeed checkpointing function instead of calling the module directly.
        # Checkpointing self.block_1 means its activations are deleted after use
        # and recalculated during the backward pass.
        x = deepspeed.checkpointing.checkpoint(self.block_1, x)
        return self.block_2(x)
import torch
import torch.nn as nn
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DeepSpeedStrategy


class MyModel(pl.LightningModule):
    ...

    def configure_sharded_model(self):
        self.block_1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
        self.block_2 = torch.nn.Linear(32, 2)

    def forward(self, x):
        # Use the DeepSpeed checkpointing function instead of calling the module directly
        x = deepspeed.checkpointing.checkpoint(self.block_1, x)
        return self.block_2(x)


model = MyModel()

trainer = Trainer(accelerator="gpu", devices=4, strategy="deepspeed_stage_3_offload", precision=16)

# Enable CPU Activation Checkpointing
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=DeepSpeedStrategy(
        stage=3,
        offload_optimizer=True,  # Enable CPU Offloading
        cpu_checkpointing=True,  # (Optional) offload activations to CPU
    ),
    precision=16,
)
trainer.fit(model)
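The memory/compute trade-off behind activation checkpointing can be seen on a toy problem. The pure-Python sketch below differentiates a chain of scalar "layers" in two ways. The first keeps every activation. The second keeps only the input of each segment in the forward pass and recomputes the rest during the backward pass. This illustrates the idea only; it is not how `deepspeed.checkpointing.checkpoint` is implemented, and the function names are hypothetical.

```python
import math
from typing import Callable, List, Tuple

# A "layer" is a scalar function together with its derivative
Layer = Tuple[Callable[[float], float], Callable[[float], float]]


def grad_without_checkpointing(layers: List[Layer], x: float) -> Tuple[float, float, int]:
    """Store every activation in the forward pass, then apply the chain rule backwards."""
    acts = [x]
    for f, _ in layers:
        acts.append(f(acts[-1]))
    grad = 1.0
    # acts[i] is the input of layers[i]
    for (_, df), a in zip(reversed(layers), reversed(acts[:-1])):
        grad *= df(a)
    return acts[-1], grad, len(acts)  # output, d(output)/dx, activations kept


def grad_with_checkpointing(layers: List[Layer], x: float, segment: int) -> Tuple[float, float, int]:
    """Keep only segment inputs in the forward pass; recompute the rest during backward."""
    checkpoints = []
    for start in range(0, len(layers), segment):
        checkpoints.append((start, x))  # the only activation kept for this segment
        for f, _ in layers[start:start + segment]:
            x = f(x)
    out, grad, peak = x, 1.0, len(checkpoints)
    for start, seg_input in reversed(checkpoints):
        seg = layers[start:start + segment]
        acts = [seg_input]
        for f, _ in seg[:-1]:  # recompute the inputs of the remaining layers in this segment
            acts.append(f(acts[-1]))
        peak = max(peak, len(checkpoints) + len(acts))
        for (_, df), a in zip(reversed(seg), reversed(acts)):
            grad *= df(a)
    return out, grad, peak


tanh_layer: Layer = (math.tanh, lambda v: 1.0 - math.tanh(v) ** 2)
layers = [tanh_layer] * 16
print(grad_without_checkpointing(layers, 0.5))
print(grad_with_checkpointing(layers, 0.5, segment=4))
```

Both functions return the same gradient. The checkpointed version holds at most 8 activations instead of 17, at the cost of running most layers' forward computation twice.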
DeepSpeed ZeRO Stage 3 Tips

Here is some helpful information when setting up DeepSpeed ZeRO Stage 3 with Lightning.

  • If you’re using Adam or AdamW, use FusedAdam or DeepSpeedCPUAdam (for CPU Offloading) rather than the default torch optimizers, as they come with large speed benefits

  • Treat your GPU/CPU memory as one large pool. In some cases, you may not want to offload certain things (like activations) to provide even more space to offload model parameters

  • When offloading to the CPU, increase the batch size, as GPU memory will be freed

  • We also support sharded checkpointing. By passing save_full_weights=False to the DeepSpeedStrategy, we’ll save shards of the model, which allows you to save extremely large models. However, to load the model and run test/validation/predict, you must use the Trainer object.

Collating Single File Checkpoint for DeepSpeed ZeRO Stage 3

After training using ZeRO Stage 3, you’ll notice that your checkpoints are a directory of sharded model and optimizer states. If you’d like to collate a single file from the checkpoint directory, use the function below, which also handles all the Lightning states when collating the file.

from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

# lightning deepspeed has saved a directory instead of a file
save_path = "lightning_logs/version_0/checkpoints/epoch=0-step=0.ckpt/"
output_path = "lightning_model.pt"
convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path)

Warning

This single file checkpoint does not include the optimizer/lr-scheduler states, so training cannot be restored via the trainer.fit(ckpt_path=) call. Keep the sharded checkpoint directory if you need this.

Custom DeepSpeed Config

In some cases you may want to define your own DeepSpeed config to access all of DeepSpeed's parameters. We’ve exposed most of the important parameters, but there may be others you want to enable, such as debugging parameters. DeepSpeed also supports custom DeepSpeed optimizers and schedulers defined within a config file.

Note

All strategy default parameters will be ignored when a config object is passed. All compatible arguments can be seen in the DeepSpeed docs.

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DeepSpeedStrategy

deepspeed_config = {
    "zero_allow_untested_optimizer": True,
    "optimizer": {
        "type": "OneBitAdam",
        "params": {
            "lr": 3e-5,
            "betas": [0.998, 0.999],
            "eps": 1e-5,
            "weight_decay": 1e-9,
            "cuda_aware": True,
        },
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "last_batch_iteration": -1,
            "warmup_min_lr": 0,
            "warmup_max_lr": 3e-5,
            "warmup_num_steps": 100,
        },
    },
    "zero_optimization": {
        "stage": 2,  # Enable Stage 2 ZeRO (Optimizer/Gradient state partitioning)
        "offload_optimizer": True,  # Enable Offloading optimizer state/calculation to the host CPU
        "contiguous_gradients": True,  # Reduce gradient fragmentation.
        "overlap_comm": True,  # Overlap reduce/backward operation of gradients for speed.
        "allgather_bucket_size": 2e8,  # Number of elements to all gather at once.
        "reduce_bucket_size": 2e8,  # Number of elements we reduce/allreduce at once.
    },
}

model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy=DeepSpeedStrategy(config=deepspeed_config), precision=16)
trainer.fit(model)

We also support passing the config as a JSON-formatted file:

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DeepSpeedStrategy

model = MyModel()
trainer = Trainer(
    accelerator="gpu", devices=4, strategy=DeepSpeedStrategy(config="/path/to/deepspeed_config.json"), precision=16
)
trainer.fit(model)

You can also point your PyTorch Lightning script at the config via an environment variable:

PL_DEEPSPEED_CONFIG_PATH=/path/to/deepspeed_config.json python train.py --strategy deepspeed
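If you keep your settings in Python, one way to produce that JSON file is to serialize a dict with the standard json module, as sketched below. The config values are illustrative (a subset of the dict shown earlier), and `write_deepspeed_config` is a hypothetical helper. You can pass the resulting path as `DeepSpeedStrategy(config=config_path)` or export it as PL_DEEPSPEED_CONFIG_PATH.

```python
import json
import os
import tempfile
from typing import Any, Dict


def write_deepspeed_config(config: Dict[str, Any], directory: str) -> str:
    """Serialize a DeepSpeed config dict to JSON and return the file path."""
    path = os.path.join(directory, "deepspeed_config.json")
    with open(path, "w") as f:
        json.dump(config, f, indent=2)
    return path


# Illustrative values only (a subset of the config shown earlier)
example_config = {
    "zero_optimization": {
        "stage": 2,
        "contiguous_gradients": True,
        "overlap_comm": True,
        "allgather_bucket_size": 2e8,
        "reduce_bucket_size": 2e8,
    }
}

out_dir = tempfile.mkdtemp()
config_path = write_deepspeed_config(example_config, out_dir)
print(config_path)  # e.g. pass to DeepSpeedStrategy(config=config_path)
```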

DDP Optimizations

When Using DDP Strategies, Set find_unused_parameters=False

By default, we have set find_unused_parameters=True for compatibility reasons that have been observed in the past (refer to the discussion for more details). When enabled, it can result in a performance hit and can be disabled in most cases. Read more about it here.

Tip

It applies to all DDP strategies that support find_unused_parameters as input.

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy=DDPStrategy(find_unused_parameters=False),
)
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPSpawnStrategy

trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy=DDPSpawnStrategy(find_unused_parameters=False),
)
DDP Static Graph

DDP static graph assumes that your model employs the same set of used/unused parameters in every iteration, so that it can deterministically know the flow of training and apply special optimizations during runtime.

Note

DDP static graph support requires PyTorch>=1.11.0

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

trainer = Trainer(devices=4, strategy=DDPStrategy(static_graph=True))
When Using DDP on a Multi-node Cluster, Set NCCL Parameters

NCCL is the NVIDIA Collective Communications Library that is used by PyTorch to handle communication across nodes and GPUs. Adjusting NCCL parameters has been reported to speed up training, as seen in this issue. It reports a 30% speed improvement when training the Transformer XLM-RoBERTa and a 15% improvement when training with Detectron2.

NCCL parameters can be adjusted via environment variables.

Note

AWS and GCP already set default values for these on their clusters, so tuning them is mainly useful for custom cluster setups.

export NCCL_NSOCKS_PERTHREAD=4
export NCCL_SOCKET_NTHREADS=2
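If you cannot change the job's shell environment, the same variables can be set from Python, as in the sketch below. It assumes they are set in the launching script before `trainer.fit()` initializes the distributed process group, so that processes started by Lightning inherit them. Using `setdefault` keeps any values your cluster already provides. `apply_nccl_settings` is a hypothetical helper.

```python
import os
from typing import Dict

# Values from the example above; tune them for your own network
NCCL_SETTINGS: Dict[str, str] = {
    "NCCL_NSOCKS_PERTHREAD": "4",
    "NCCL_SOCKET_NTHREADS": "2",
}


def apply_nccl_settings(settings: Dict[str, str]) -> Dict[str, str]:
    """Set NCCL variables unless already defined; return the effective values."""
    for key, value in settings.items():
        os.environ.setdefault(key, value)  # do not override cluster-provided values
    return {key: os.environ[key] for key in settings}


if __name__ == "__main__":
    print(apply_nccl_settings(NCCL_SETTINGS))
    # ... then build the Trainer and call trainer.fit(model)
```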
Gradients as Bucket View

Enabling gradient_as_bucket_view=True in the DDPStrategy makes gradients views pointing to different offsets of the allreduce communication buckets. See DistributedDataParallel for more information.

This can reduce peak memory usage, since the saved memory equals the total size of the gradients. It can also improve throughput, because gradients no longer need to be copied into the allreduce communication buckets.

Note

When gradient_as_bucket_view=True, you cannot call detach_() on gradients. If you hit such errors, refer to the zero_grad() function in torch/optim/optimizer.py for a solution (source).

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPStrategy(gradient_as_bucket_view=True))
trainer.fit(model)
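To see why views save a copy, consider this pure-Python analogy built on `array` and `memoryview`. Each "gradient" is a view at its own offset into one flat bucket. Writing a gradient therefore writes the bucket directly, and an in-place allreduce on the bucket is immediately visible through every view. This only mimics the memory layout; it is not how PyTorch's reducer is implemented, and the helper names are hypothetical.

```python
from array import array
from typing import Dict, List, Tuple


def make_bucket(param_sizes: Dict[str, int]) -> Tuple[array, Dict[str, memoryview]]:
    """Allocate one flat bucket and return per-parameter gradient views into it."""
    bucket = array("d", [0.0] * sum(param_sizes.values()))
    flat = memoryview(bucket)
    views, offset = {}, 0
    for name, size in param_sizes.items():
        views[name] = flat[offset:offset + size]  # a view into the bucket, not a copy
        offset += size
    return bucket, views


def allreduce_mean_inplace(buckets: List[array]) -> None:
    """Average buckets across simulated ranks, writing the result into each bucket."""
    for i in range(len(buckets[0])):
        mean = sum(b[i] for b in buckets) / len(buckets)
        for b in buckets:
            b[i] = mean


# Two simulated ranks, each with a 3-element weight and a 1-element bias
rank0_bucket, rank0_grads = make_bucket({"weight": 3, "bias": 1})
rank1_bucket, rank1_grads = make_bucket({"weight": 3, "bias": 1})
rank0_grads["weight"][0] = 1.0  # "backward" writes straight into the bucket
rank1_grads["weight"][0] = 3.0
allreduce_mean_inplace([rank0_bucket, rank1_bucket])
print(rank0_grads["weight"].tolist())  # [2.0, 0.0, 0.0], no copy-back needed
```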
DDP Communication Hooks

DDP communication hooks are an interface to control how gradients are communicated across workers, overriding the standard allreduce in DistributedDataParallel. This allows you to enable performance-improving communication hooks when using multiple nodes.

Enable FP16 Compress Hook for multi-node throughput improvement:

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPStrategy(ddp_comm_hook=default.fp16_compress_hook))
trainer.fit(model)
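The idea behind `fp16_compress_hook` is to divide each gradient bucket by the world size and cast it to half precision before the allreduce. This halves the bytes sent over the network, and the result is cast back afterwards. The pure-Python sketch below simulates that on plain lists, using `struct`'s half-precision `'e'` format. It illustrates the bandwidth/precision trade-off and is not PyTorch's implementation. The helper names are hypothetical.

```python
import struct
from typing import List


def to_fp16_bytes(values: List[float]) -> bytes:
    """Compress floats to IEEE half precision (2 bytes each)."""
    return struct.pack(f"<{len(values)}e", *values)


def from_fp16_bytes(payload: bytes) -> List[float]:
    """Decompress half-precision bytes back to Python floats."""
    return list(struct.unpack(f"<{len(payload) // 2}e", payload))


def fp16_compressed_allreduce(rank_grads: List[List[float]]) -> List[float]:
    """Simulate: each rank pre-divides by world size and sends fp16; contributions are summed."""
    world_size = len(rank_grads)
    payloads = [to_fp16_bytes([g / world_size for g in grads]) for grads in rank_grads]
    decoded = [from_fp16_bytes(p) for p in payloads]
    summed = [sum(col) for col in zip(*decoded)]  # the "allreduce" (sum)
    return from_fp16_bytes(to_fp16_bytes(summed))  # result is still fp16-precision


grads = [[0.1, 1.0, -2.5], [0.3, 3.0, 0.5]]
print(len(to_fp16_bytes(grads[0])), "bytes instead of", len(struct.pack("<3f", *grads[0])))
print(fp16_compressed_allreduce(grads))  # close to the exact mean [0.2, 2.0, -1.0]
```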

Enable PowerSGD for multi-node throughput improvement:

Note

PowerSGD typically requires extra memory of the same size as the model’s gradients to enable error feedback, which can compensate for biased compressed communication and improve accuracy (source).

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

model = MyModel()
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=DDPStrategy(
        ddp_comm_state=powerSGD.PowerSGDState(
            process_group=None,
            matrix_approximation_rank=1,
            start_powerSGD_iter=5000,
        ),
        ddp_comm_hook=powerSGD.powerSGD_hook,
    ),
)
trainer.fit(model)
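To see where that extra memory comes from, here is a heavily simplified, single-process sketch of one rank-1 PowerSGD compression step (matching `matrix_approximation_rank=1`). First, the gradient matrix M is combined with the error left over from the previous step. Then one power iteration approximates the result as an outer product p qᵀ. The residual is kept in an error buffer with the same shape as the gradient. Real PowerSGD also allreduces p and q across workers and orthogonalizes higher-rank factors. Both steps are omitted here, and the helper names are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def _matvec(m: Matrix, v: List[float]) -> List[float]:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def _matTvec(m: Matrix, u: List[float]) -> List[float]:
    return [sum(m[i][j] * u[i] for i in range(len(m))) for j in range(len(m[0]))]


def powersgd_rank1_step(grad: Matrix, error: Matrix, q: List[float]) -> Tuple[Matrix, Matrix, List[float]]:
    """Compress grad + error to rank 1; return (approximation, new error buffer, new q)."""
    # Error feedback: add back what previous compressions failed to transmit
    m = [[g + e for g, e in zip(g_row, e_row)] for g_row, e_row in zip(grad, error)]
    p = _matvec(m, q)                          # P = M Q
    norm = math.sqrt(sum(x * x for x in p)) or 1.0
    p = [x / norm for x in p]                  # orthogonalize (for rank 1: normalize)
    q_new = _matTvec(m, p)                     # Q = M^T P
    approx = [[pi * qj for qj in q_new] for pi in p]
    # The residual has the same shape as the gradient: this is the extra memory
    new_error = [[a - b for a, b in zip(m_row, ap_row)] for m_row, ap_row in zip(m, approx)]
    return approx, new_error, q_new


grad = [[1.0, 0.0], [0.0, 1.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]
approx, err, q = powersgd_rank1_step(grad, zeros, [1.0, 0.0])
print(approx, err)  # [[1.0, 0.0], [0.0, 0.0]] [[0.0, 0.0], [0.0, 1.0]]
```

The untransmitted part of the identity matrix stays in `err` and is added to the next step's gradient. Storing this buffer is what costs memory equal to the gradients.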

Combine hooks for accumulated benefit:

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import (
    default_hooks as default,
    powerSGD_hook as powerSGD,
)

model = MyModel()
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=DDPStrategy(
        ddp_comm_state=powerSGD.PowerSGDState(
            process_group=None,
            matrix_approximation_rank=1,
            start_powerSGD_iter=5000,
        ),
        ddp_comm_hook=powerSGD.powerSGD_hook,
        ddp_comm_wrapper=default.fp16_compress_wrapper,
    ),
)
trainer.fit(model)

When using Post-localSGD, you must also pass model_averaging_period to allow for model parameter averaging:

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD

model = MyModel()
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=DDPStrategy(
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=8,
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    ),
)
trainer.fit(model)
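The schedule controlled by `start_localSGD_iter` and `model_averaging_period` can be sketched with a toy, single-process simulation. Until the start iteration, every step averages gradients as plain DDP would. Afterwards, each worker steps locally, and parameters are averaged every `period` steps. Each simulated worker minimizes its own `0.5 * (w - target)**2`, a hypothetical stand-in for different data shards. In PyTorch the averaging is done by its model averager, not by this code.

```python
from typing import List


def simulate_post_local_sgd(
    targets: List[float], steps: int, lr: float, start_local_iter: int, period: int
) -> List[float]:
    """Return each simulated worker's parameter after `steps` iterations."""
    n = len(targets)
    params = [0.0] * n
    for step in range(1, steps + 1):
        # Gradient of 0.5 * (w - target)**2 on each worker's own data
        grads = [w - t for w, t in zip(params, targets)]
        if step <= start_local_iter:
            # Warm-up phase: gradients are averaged every step, like vanilla DDP
            g = sum(grads) / n
            params = [w - lr * g for w in params]
        else:
            # Local SGD phase: purely local steps...
            params = [w - lr * g for w, g in zip(params, grads)]
            # ...plus periodic model averaging
            if (step - start_local_iter) % period == 0:
                avg = sum(params) / n
                params = [avg] * n
    return params


print(simulate_post_local_sgd([1.0, 3.0], steps=8, lr=0.1, start_local_iter=4, period=2))
```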

Accelerator: GPU training

Accelerator: HPU training

Accelerator: IPU training

Accelerator: TPU training

Accelerator: Apple Silicon training

Transfer Learning

Audience: Users looking to use pretrained models with Lightning.


Use any PyTorch nn.Module

Any model that is a PyTorch nn.Module can be used with Lightning (because LightningModules are nn.Modules also).


Use a pretrained LightningModule

Let’s use the AutoEncoder as a feature extractor in a separate model.

import torch
from torch import nn
from pytorch_lightning import LightningModule


class Encoder(torch.nn.Module):
    ...


class Decoder(torch.nn.Module):
    ...


class AutoEncoder(LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()


class CIFAR10Classifier(LightningModule):
    def __init__(self):
        super().__init__()
        # init the pretrained LightningModule (PATH points to its saved checkpoint)
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
        self.feature_extractor.freeze()

        # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...

We used our pretrained Autoencoder (a LightningModule) for transfer learning!


Example: Imagenet (Computer Vision)

import torch
from torch import nn
import torchvision.models as models
from pytorch_lightning import LightningModule


class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()

        # init a pretrained resnet
        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        ...

Finetune

model = ImagenetTransferLearning()
trainer = Trainer()
trainer.fit(model)

And use it to predict your data of interest

model = ImagenetTransferLearning.load_from_checkpoint(PATH)
model.freeze()

x = some_images_from_cifar10()
predictions = model(x)

We took a model pretrained on ImageNet, finetuned it on CIFAR-10 and used it to predict on CIFAR-10. Outside academia, you would typically finetune on your own (often small) dataset and predict on your own data.


Example: BERT (NLP)

Lightning is completely agnostic to what’s used for transfer learning so long as it is a torch.nn.Module subclass.

Here’s a model that uses Hugging Face transformers.

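from torch import nn
from transformers import BertModel
from pytorch_lightning import LightningModule


class BertMNLIFinetuner(LightningModule):
    def __init__(self):
        super().__init__()

        self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
        # MNLI has 3 classes: entailment, neutral, contradiction
        self.W = nn.Linear(self.bert.config.hidden_size, 3)
        self.num_classes = 3

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        h, attn = outputs.last_hidden_state, outputs.attentions

        # classify using the representation of the [CLS] token
        h_cls = h[:, 0]
        logits = self.W(h_cls)
        return logits, attn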

Injecting 3rd Party Data Iterables

When training a model on a specific task, data loading and preprocessing might become a bottleneck. Lightning does not enforce a specific data loading approach nor does it try to control it. The only assumption Lightning makes is that the data is returned as an iterable of batches.

For PyTorch-based programs, these iterables are typically instances of DataLoader.

However, Lightning also supports other data types such as plain lists of batches, generators or other custom iterables.

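import torch
from pytorch_lightning import Trainer

# random list of batches (LitClassifier is assumed to be defined elsewhere)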
data = [(torch.rand(32, 3, 32, 32), torch.randint(0, 10, (32,))) for _ in range(100)]
model = LitClassifier()
trainer = Trainer()
trainer.fit(model, data)
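One caveat with plain generators: a generator object is exhausted after one pass, while training typically starts a new pass over the data every epoch. A small class whose `__iter__` returns a fresh generator avoids this. The sketch below uses plain Python lists as stand-in batches (real batches would be tensors). `batch_generator` and `RandomBatches` are hypothetical names.

```python
import random
from typing import Iterator, List, Tuple

Batch = Tuple[List[float], List[int]]


def batch_generator(num_batches: int, batch_size: int) -> Iterator[Batch]:
    """A one-shot generator: iterating it a second time yields nothing."""
    for _ in range(num_batches):
        features = [random.random() for _ in range(batch_size)]
        labels = [random.randint(0, 9) for _ in range(batch_size)]
        yield features, labels


class RandomBatches:
    """Re-iterable wrapper: every call to iter() starts a fresh pass."""

    def __init__(self, num_batches: int, batch_size: int) -> None:
        self.num_batches = num_batches
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[Batch]:
        return batch_generator(self.num_batches, self.batch_size)

    def __len__(self) -> int:
        return self.num_batches


gen = batch_generator(100, 32)
print(sum(1 for _ in gen), sum(1 for _ in gen))    # 100 0: the second "epoch" is empty
data = RandomBatches(100, 32)
print(sum(1 for _ in data), sum(1 for _ in data))  # 100 100
```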

Examples for custom iterables include NVIDIA DALI or FFCV for computer vision. Both libraries offer support for custom data loading and preprocessing (also hardware accelerated) and can be used with Lightning.

For example, taking the example from FFCV’s readme, we can use it with Lightning by just removing the hardcoded ToDevice(0) as Lightning takes care of GPU placement. In case you want to use some data transformations on GPUs, change the ToDevice(0) to ToDevice(self.trainer.local_rank) to correctly map to the desired GPU in your pipeline.

from ffcv.loader import Loader, OrderOption
from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Cutout
from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder


class CustomClassifier(LitClassifier):
    def train_dataloader(self):
        # Random resized crop
        decoder = RandomResizedCropRGBImageDecoder((224, 224))

        # Data decoding and augmentation
        image_pipeline = [decoder, Cutout(), ToTensor(), ToTorchImage()]
        label_pipeline = [IntDecoder(), ToTensor()]

        # Pipeline for each data field
        pipelines = {"image": image_pipeline, "label": label_pipeline}

        # Replaces the PyTorch data loader (`torch.utils.data.DataLoader`);
        # write_path, bs and num_workers are assumed to be defined elsewhere
        loader = Loader(
            write_path, batch_size=bs, num_workers=num_workers, order=OrderOption.RANDOM, pipelines=pipelines
        )

        return loader

When moving data to a specific device, you can always refer to self.trainer.local_rank to select the device used by the current process.

By just changing device_id=0 to device_id=self.trainer.local_rank we can also leverage DALI’s GPU decoding:

from nvidia.dali.pipeline import pipeline_def
import nvidia.dali.types as types
import nvidia.dali.fn as fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator
import os


class CustomLitClassifier(LitClassifier):
    def train_dataloader(self):
        # To run with different data, see documentation of nvidia.dali.fn.readers.file
        # points to https://github.com/NVIDIA/DALI_extra
        data_root_dir = os.environ["DALI_EXTRA_PATH"]
        images_dir = os.path.join(data_root_dir, "db", "single", "jpeg")

        @pipeline_def(num_threads=4, device_id=self.trainer.local_rank)
        def get_dali_pipeline():
            images, labels = fn.readers.file(file_root=images_dir, random_shuffle=True, name="Reader")
            # decode data on the GPU
            images = fn.decoders.image_random_crop(images, device="mixed", output_type=types.RGB)
            # the rest of processing happens on the GPU as well
            images = fn.resize(images, resize_x=256, resize_y=256)
            images = fn.crop_mirror_normalize(
                images,
                crop_h=224,
                crop_w=224,
                mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
                mirror=fn.random.coin_flip(),
            )
            return images, labels

        train_data = DALIGenericIterator(
            [get_dali_pipeline(batch_size=16)],
            ["data", "label"],
            reader_name="Reader",
        )

        return train_data

Limitations

Lightning works with all kinds of custom data iterables as shown above. There are, however, a few features that cannot be supported this way, because supporting them requires Lightning to know a lot about the internals of these iterables.

  • In a distributed multi-GPU setting (ddp), Lightning automatically replaces the DataLoader’s sampler with its distributed counterpart. This makes sure that each GPU sees a different part of the dataset. As sampling can be implemented in arbitrary ways with custom iterables, there is no way for Lightning to know how to replace the sampler.

  • When training fails for some reason, Lightning is able to extract all of the relevant data from the model, optimizers, trainer and dataloader to resume training at the exact batch where it crashed. This feature is called fault tolerance and is limited to PyTorch DataLoaders. Lightning needs to know a lot about sampling, fast-forwarding and random number handling to enable fault tolerance, meaning that it cannot be supported for arbitrary iterables.
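Because Lightning cannot inject a distributed sampler into a custom iterable, sharding is up to you. A minimal sketch is to mimic what PyTorch's `DistributedSampler` does without shuffling: pad the item list so it divides evenly, then give each rank every `world_size`-th element. This assumes you pass the rank and world size yourself; inside a LightningModule, these are available as `self.trainer.global_rank` and `self.trainer.world_size`. `shard_for_rank` is a hypothetical helper.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_for_rank(items: Sequence[T], rank: int, world_size: int) -> List[T]:
    """Return the subset of `items` that `rank` should iterate over."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    total = -(-len(items) // world_size) * world_size  # round up to a multiple of world_size
    padded = list(items)
    while len(padded) < total:
        # repeat items from the start so every rank gets the same number of elements
        padded += list(items)[: total - len(padded)]
    return padded[rank:total:world_size]


for rank in range(4):
    print(rank, shard_for_rank(list(range(10)), rank, 4))
```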

Use a pure PyTorch training loop

Accelerator

The Accelerator connects a Lightning Trainer to arbitrary hardware (CPUs, GPUs, TPUs, IPUs, MPS, …). Currently there are accelerators for:

  • CPU

  • GPU

  • TPU

  • IPU

  • HPU

  • MPS

The Accelerator is part of the Strategy which manages communication across multiple devices (distributed communication). Whenever the Trainer, the loops or any other component in Lightning needs to talk to hardware, it calls into the Strategy and the Strategy calls into the Accelerator.

Illustration of the Strategy as a composition of the Accelerator and several plugins

We expose Accelerators and Strategies mainly for expert users who want to extend Lightning to work with new hardware and distributed training or clusters.


Create a Custom Accelerator

Here is how you create a new Accelerator. Let’s pretend we want to integrate the fictional XPU accelerator and we have access to its hardware through a library xpulib.

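from typing import Any, Dict, Union

import torch
import xpulib  # fictional library used for illustration
from pytorch_lightning.accelerators import Accelerator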


class XPUAccelerator(Accelerator):
    """Experimental support for XPU, optimized for large-scale machine learning."""

    @staticmethod
    def parse_devices(devices: Any) -> Any:
        # Put parsing logic here how devices can be passed into the Trainer
        # via the `devices` argument
        return devices

    @staticmethod
    def get_parallel_devices(devices: Any) -> Any:
        # Here, convert the device indices to actual device objects
        return [torch.device("xpu", idx) for idx in devices]

    @staticmethod
    def auto_device_count() -> int:
        # Return a value for auto-device selection when `Trainer(devices="auto")`
        return xpulib.available_devices()

    @staticmethod
    def is_available() -> bool:
        return xpulib.is_available()

    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        # Return optional device statistics for loggers
        return {}

Finally, add the XPUAccelerator to the Trainer:

from pytorch_lightning import Trainer

accelerator = XPUAccelerator()
trainer = Trainer(accelerator=accelerator, devices=2)
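As written, the pass-through `parse_devices` stub would hand `devices=2` to `get_parallel_devices` as a plain int, which cannot be iterated, assuming the parsed value is forwarded unchanged. A hypothetical, slightly fuller parser might normalize the common ways of writing `devices` into a list of indices. The accepted forms below (an int count, a list, or a comma-separated string) are an assumption for illustration, not the rules Lightning applies to its built-in accelerators.

```python
from typing import Any, List


def parse_xpu_devices(devices: Any) -> List[int]:
    """Normalize `devices` into a list of device indices (hypothetical rules)."""
    if isinstance(devices, bool):  # bool is a subclass of int; reject it explicitly
        raise TypeError("devices must be an int, a list of ints, or a string")
    if isinstance(devices, int):
        if devices < 1:
            raise ValueError("devices must be >= 1")
        return list(range(devices))  # devices=2 -> the first two devices
    if isinstance(devices, str):
        return [int(part) for part in devices.split(",") if part.strip()]  # "0,2" -> [0, 2]
    if isinstance(devices, (list, tuple)):
        return [int(d) for d in devices]
    raise TypeError(f"unsupported devices value: {devices!r}")


print(parse_xpu_devices(2), parse_xpu_devices("0,2"), parse_xpu_devices([1, 3]))
```

Inside `XPUAccelerator.parse_devices`, you could then return `parse_xpu_devices(devices)` so that `get_parallel_devices` always receives a list.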

Learn more about Strategies and how they interact with the Accelerator.


Registering Accelerators

If you wish to switch to a custom accelerator from the CLI without code changes, you can implement the register_accelerators() class method to register your new accelerator under a shorthand name like so:

class XPUAccelerator(Accelerator):
    ...

    @classmethod
    def register_accelerators(cls, accelerator_registry):
        accelerator_registry.register(
            "xpu",
            cls,
            description="XPU Accelerator - optimized for large-scale machine learning.",
        )

Now, this is possible:

trainer = Trainer(accelerator="xpu")

Or if you are using the Lightning CLI, for example:

python train.py fit --trainer.accelerator=xpu --trainer.devices=2
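Conceptually, a registry like this is a name-to-class mapping that is consulted when the Trainer receives a string. The stand-alone sketch below shows that pattern with a hypothetical `SimpleRegistry` and a stand-in `XPUAccelerator`. It is not Lightning's registry implementation and omits its extra features.

```python
from typing import Any, Dict, List, Type


class SimpleRegistry:
    """Minimal name -> class registry (hypothetical illustration)."""

    def __init__(self) -> None:
        self._entries: Dict[str, Dict[str, Any]] = {}

    def register(self, name: str, cls: Type, description: str = "", **init_params: Any) -> None:
        if name in self._entries:
            raise ValueError(f"'{name}' is already registered")
        self._entries[name] = {"cls": cls, "description": description, "init_params": init_params}

    def get(self, name: str) -> Any:
        """Instantiate the class registered under `name`."""
        entry = self._entries[name]
        return entry["cls"](**entry["init_params"])

    def available(self) -> List[str]:
        return sorted(self._entries)


class XPUAccelerator:  # stand-in for the real Accelerator subclass
    @classmethod
    def register_accelerators(cls, registry: SimpleRegistry) -> None:
        registry.register("xpu", cls, description="XPU Accelerator - optimized for large-scale machine learning.")


registry = SimpleRegistry()
XPUAccelerator.register_accelerators(registry)
accelerator = registry.get("xpu")  # roughly what accelerator="xpu" resolves to
print(registry.available())  # ['xpu']
```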

Accelerator API

Accelerator

The Accelerator base class for Lightning PyTorch.

CPUAccelerator

Accelerator for CPU devices.

CUDAAccelerator

Accelerator for NVIDIA CUDA devices.

HPUAccelerator

Accelerator for HPU devices.

IPUAccelerator

Accelerator for IPUs.

MPSAccelerator

Accelerator for Metal Apple Silicon GPU devices.

TPUAccelerator

Accelerator for TPU devices.

Callback


A callback is a self-contained program that can be reused across projects.

Lightning has a callback system to execute them when needed. Callbacks should capture NON-ESSENTIAL logic that is NOT required for your lightning module to run.

Here’s the flow of how the callback hooks are executed:

An overall Lightning system should have:

  1. Trainer for all engineering

  2. LightningModule for all research code.

  3. Callbacks for non-essential code.


Example:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class MyPrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting")

    def on_train_end(self, trainer, pl_module):
        print("Training is ending")


trainer = Trainer(callbacks=[MyPrintingCallback()])

We successfully extended functionality without polluting our super clean lightning module research code.
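To make the division of labour concrete, here is a toy, framework-free sketch of how a trainer can dispatch hooks to a list of callbacks by name. `ToyTrainer`, `ToyCallback` and their hooks are hypothetical and far simpler than Lightning's Trainer. The point is only that the callback code runs without the model knowing about it.

```python
from typing import List


class ToyCallback:
    """Base class: every hook is a no-op unless overridden."""

    def on_train_start(self, trainer: "ToyTrainer") -> None: ...
    def on_train_epoch_end(self, trainer: "ToyTrainer") -> None: ...
    def on_train_end(self, trainer: "ToyTrainer") -> None: ...


class ToyTrainer:
    def __init__(self, callbacks: List[ToyCallback], max_epochs: int = 2) -> None:
        self.callbacks = callbacks
        self.max_epochs = max_epochs
        self.current_epoch = 0

    def _call_hook(self, name: str) -> None:
        for cb in self.callbacks:  # callbacks run in the order they were passed
            getattr(cb, name)(self)

    def fit(self) -> None:
        self._call_hook("on_train_start")
        for epoch in range(self.max_epochs):
            self.current_epoch = epoch
            # ... the model's training steps would run here ...
            self._call_hook("on_train_epoch_end")
        self._call_hook("on_train_end")


class RecordingCallback(ToyCallback):
    def __init__(self) -> None:
        self.events: List[str] = []

    def on_train_start(self, trainer: ToyTrainer) -> None:
        self.events.append("start")

    def on_train_epoch_end(self, trainer: ToyTrainer) -> None:
        self.events.append(f"epoch {trainer.current_epoch}")

    def on_train_end(self, trainer: ToyTrainer) -> None:
        self.events.append("end")


recorder = RecordingCallback()
ToyTrainer(callbacks=[recorder]).fit()
print(recorder.events)  # ['start', 'epoch 0', 'epoch 1', 'end']
```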


Examples

You can do pretty much anything with callbacks.


Built-in Callbacks

Lightning has a few built-in callbacks.

Note

For a richer collection of callbacks, check out our bolts library.

BackboneFinetuning

Finetune a backbone model based on a learning rate user-defined scheduling.

BaseFinetuning

This class implements the base logic for writing your own Finetuning Callback.

BasePredictionWriter

Base class to implement how the predictions should be stored.

BatchSizeFinder

The BatchSizeFinder callback tries to find the largest batch size for a given model that does not give an out of memory (OOM) error.

Callback

Abstract base class used to build new callbacks.

DeviceStatsMonitor

Automatically monitors and logs device stats during training, validation and testing stage.

EarlyStopping

Monitor a metric and stop training when it stops improving.

GradientAccumulationScheduler

Change gradient accumulation factor according to scheduling.

LambdaCallback

Create a simple callback on the fly using lambda functions.

LearningRateFinder

The LearningRateFinder callback enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate.

LearningRateMonitor

Automatically monitors and logs the learning rate of learning rate schedulers during training.

ModelCheckpoint

Save the model periodically by monitoring a quantity.

ModelPruning

Model pruning Callback, using PyTorch's prune utilities.

ModelSummary

Generates a summary of all layers in a LightningModule.

ProgressBarBase

The base class for progress bars in Lightning.

QuantizationAwareTraining

Quantization allows speeding up inference and decreasing memory requirements by performing computations and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision.

RichModelSummary

Generates a summary of all layers in a LightningModule with rich text formatting.

RichProgressBar

Create a progress bar with rich text formatting.

StochasticWeightAveraging

Implements the Stochastic Weight Averaging (SWA) Callback to average a model.

Timer

The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the Trainer if the given time limit for the training loop is reached.

TQDMProgressBar

This is the default progress bar used by Lightning.


Save Callback state

Some callbacks require internal state in order to function properly. You can optionally choose to persist your callback’s state as part of model checkpoint files using state_dict() and load_state_dict(). Note that the returned state must be able to be pickled.

When your callback is meant to be used only as a singleton callback then implementing the above two hooks is enough to persist state effectively. However, if passing multiple instances of the callback to the Trainer is supported, then the callback must define a state_key property in order for Lightning to be able to distinguish the different states when loading the callback state. This concept is best illustrated by the following example.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class Counter(Callback):
    def __init__(self, what="epochs", verbose=True):
        self.what = what
        self.verbose = verbose
        self.state = {"epochs": 0, "batches": 0}

    @property
    def state_key(self):
        # note: we do not include `verbose` here on purpose
        return self._generate_state_key(what=self.what)

    def on_train_epoch_end(self, *args, **kwargs):
        if self.what == "epochs":
            self.state["epochs"] += 1

    def on_train_batch_end(self, *args, **kwargs):
        if self.what == "batches":
            self.state["batches"] += 1

    def load_state_dict(self, state_dict):
        self.state.update(state_dict)

    def state_dict(self):
        return self.state.copy()


# two callbacks of the same type are being used
trainer = Trainer(callbacks=[Counter(what="epochs"), Counter(what="batches")])

A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information:

{
    "state_dict": ...,
    "callbacks": {
        "Counter{'what': 'batches'}": {"batches": 32, "epochs": 0},
        "Counter{'what': 'epochs'}": {"batches": 0, "epochs": 2},
        ...
    }
}

The implementation of state_key is essential here. Without it, Lightning could not disambiguate the states of these two callbacks, because the default state_key uses only the class name (here, Counter) as the key.
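To see why the key matters, here is a stdlib-only sketch that mimics collecting per-callback states into checkpoint["callbacks"]. SketchCallback, KeylessCounter, and collect_callback_states are hypothetical stand-ins written for illustration, and the key format (class name followed by the repr of the distinguishing arguments) is modeled on the checkpoint shown above rather than copied from Lightning's source. The sketch raises on a collision to make the problem visible; Lightning's own handling may differ.

```python
from typing import Any, Dict, Iterable


class SketchCallback:
    """Hypothetical stand-in for pytorch_lightning.Callback (no Lightning import)."""

    @property
    def state_key(self) -> str:
        # default: only the class name, so two instances of one class collide
        return self.__class__.__qualname__

    def _generate_state_key(self, **kwargs: Any) -> str:
        # class name followed by the repr of the distinguishing arguments
        return f"{self.__class__.__qualname__}{kwargs!r}"

    def state_dict(self) -> Dict[str, Any]:
        return {}


class Counter(SketchCallback):
    def __init__(self, what: str = "epochs") -> None:
        self.what = what
        self.state = {"epochs": 0, "batches": 0}

    @property
    def state_key(self) -> str:
        return self._generate_state_key(what=self.what)

    def state_dict(self) -> Dict[str, Any]:
        return self.state.copy()


class KeylessCounter(Counter):
    @property
    def state_key(self) -> str:
        # deliberately falls back to a class-name-only key
        return self.__class__.__qualname__


def collect_callback_states(callbacks: Iterable[SketchCallback]) -> Dict[str, Dict[str, Any]]:
    """Build a checkpoint["callbacks"]-style mapping, refusing ambiguous keys."""
    states: Dict[str, Dict[str, Any]] = {}
    for callback in callbacks:
        key = callback.state_key
        if key in states:
            raise ValueError(f"ambiguous state_key {key!r}: states would overwrite each other")
        states[key] = callback.state_dict()
    return states


states = collect_callback_states([Counter(what="epochs"), Counter(what="batches")])
print(sorted(states))  # ["Counter{'what': 'batches'}", "Counter{'what': 'epochs'}"]
```

Passing two KeylessCounter instances to collect_callback_states raises, because both report the same key "KeylessCounter".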


Best Practices

The following are best practices when using/designing callbacks.

  1. Callbacks should be isolated in their functionality.

  2. Your callback should not rely on the behavior of other callbacks in order to work properly.

  3. Do not manually call methods from the callback.

  4. Directly calling methods (e.g., on_validation_end) is strongly discouraged.

  5. Whenever possible, your callbacks should not depend on the order in which they are executed.


Entry Points

Lightning supports registering Trainer callbacks directly through Entry Points. Entry points allow an arbitrary package to include callbacks that the Lightning Trainer can automatically use, without you having to add them to the Trainer manually. This is useful in production environments where it is common to provide specialized monitoring and logging callbacks globally for every application.

Here is a callback factory function that returns two special callbacks:

factories.py
def my_custom_callbacks_factory():
    return [MyCallback1(), MyCallback2()]

If we make this factories.py file into an installable package, we can define an entry point for this factory function. Here is a minimal example of the setup.py file for the package my-package:

setup.py
from setuptools import setup

setup(
    name="my-package",
    version="0.0.1",
    install_requires=["pytorch-lightning"],
    entry_points={
        "pytorch_lightning.callbacks_factory": [
            # The format here must be [any name]=[module path]:[function name]
            "monitor_callbacks=factories:my_custom_callbacks_factory"
        ]
    },
)

The group name for the entry points is pytorch_lightning.callbacks_factory and it contains a list of strings that specify where to find the function within the package.

Now, if you pip install -e . this package, it will register the my_custom_callbacks_factory function and Lightning will automatically call it to collect the callbacks whenever you run the Trainer!
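The discovery step can be pictured with the standard library's importlib.metadata. This is a hedged sketch of the idea, not Lightning's actual implementation: installed_factory_entry_points and collect_external_callbacks are hypothetical helpers, and accepting a single returned callback as well as a list is an assumption.

```python
from importlib import metadata
from typing import Any, Iterable, List, Optional

GROUP = "pytorch_lightning.callbacks_factory"


def installed_factory_entry_points(group: str = GROUP) -> List[Any]:
    """Return the entry points that installed packages registered under `group`."""
    eps = metadata.entry_points()
    if hasattr(eps, "select"):  # Python 3.10+
        return list(eps.select(group=group))
    return list(eps.get(group, []))  # Python 3.8/3.9 return a plain dict


def collect_external_callbacks(entry_points: Optional[Iterable[Any]] = None) -> List[Any]:
    """Load each factory function and concatenate the callbacks it returns."""
    if entry_points is None:
        entry_points = installed_factory_entry_points()
    callbacks: List[Any] = []
    for entry_point in entry_points:
        # load() imports the module (e.g. "factories") and fetches the function
        factory = entry_point.load()
        result = factory()
        # accept either a single callback or a list/tuple of callbacks
        callbacks.extend(result if isinstance(result, (list, tuple)) else [result])
    return callbacks
```

With no package from the example installed, collect_external_callbacks() simply returns an empty list.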

To unregister the factory, simply uninstall the package with pip uninstall my-package.


Callback API

Here is the full API of methods available in the Callback base class.

The Callback class is the base for all the callbacks in Lightning, just like the LightningModule is the base for all models. It defines a public interface that each callback implementation must follow. The key ones are:

Properties
state_key
Callback.state_key

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.

Return type

str

Hooks
setup
Callback.setup(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune begins.

Return type

None

teardown
Callback.teardown(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune ends.

Return type

None

on_fit_start
Callback.on_fit_start(trainer, pl_module)

Called when fit begins.

Return type

None

on_fit_end
Callback.on_fit_end(trainer, pl_module)

Called when fit ends.

Return type

None

on_sanity_check_start
Callback.on_sanity_check_start(trainer, pl_module)

Called when the validation sanity check starts.

Return type

None

on_sanity_check_end
Callback.on_sanity_check_end(trainer, pl_module)

Called when the validation sanity check ends.

Return type

None

on_train_batch_start
Callback.on_train_batch_start(trainer, pl_module, batch, batch_idx)

Called when the train batch begins.

Return type

None

on_train_batch_end
Callback.on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Called when the train batch ends.

Note

Here, outputs["loss"] is the loss returned from training_step, normalized with respect to accumulate_grad_batches.

Return type

None

on_train_epoch_start
Callback.on_train_epoch_start(trainer, pl_module)

Called when the train epoch begins.

Return type

None

on_train_epoch_end
Callback.on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, either:

  1. Implement training_epoch_end in the LightningModule and access outputs via the module, or

  2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.

Return type

None
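Option 2 can be sketched without Lightning installed. EpochLossAverager is a hypothetical callback-like class (real code would subclass pytorch_lightning.Callback), and the SimpleNamespace trainer stand-in assumes the accumulation factor is readable as trainer.accumulate_grad_batches. Following the note on on_train_batch_end, the cached loss is multiplied back by that factor. The driver at the end plays the Trainer's role; in real code you would not call these hooks yourself.

```python
from types import SimpleNamespace
from typing import Any, List


class EpochLossAverager:
    """Hypothetical callback that caches batch losses and averages them per epoch."""

    def __init__(self) -> None:
        self._batch_losses: List[float] = []
        self.epoch_means: List[float] = []

    def on_train_epoch_start(self, trainer: Any, pl_module: Any) -> None:
        self._batch_losses = []  # reset the cache for the new epoch

    def on_train_batch_end(self, trainer: Any, pl_module: Any, outputs: Any, batch: Any, batch_idx: int) -> None:
        # outputs["loss"] is divided by accumulate_grad_batches, so undo that
        self._batch_losses.append(float(outputs["loss"]) * trainer.accumulate_grad_batches)

    def on_train_epoch_end(self, trainer: Any, pl_module: Any) -> None:
        # post-process everything cached during the epoch
        if self._batch_losses:
            self.epoch_means.append(sum(self._batch_losses) / len(self._batch_losses))


# Simulate one epoch of two batches (stand-in for the Trainer calling the hooks).
trainer = SimpleNamespace(accumulate_grad_batches=2)
averager = EpochLossAverager()
averager.on_train_epoch_start(trainer, None)
for batch_idx, reported_loss in enumerate([0.5, 0.25]):
    averager.on_train_batch_end(trainer, None, {"loss": reported_loss}, None, batch_idx)
averager.on_train_epoch_end(trainer, None)
print(averager.epoch_means)  # [0.75]
```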

on_validation_epoch_start
Callback.on_validation_epoch_start(trainer, pl_module)

Called when the val epoch begins.

Return type

None

on_validation_epoch_end
Callback.on_validation_epoch_end(trainer, pl_module)

Called when the val epoch ends.

Return type

None

on_test_epoch_start
Callback.on_test_epoch_start(trainer, pl_module)

Called when the test epoch begins.

Return type

None

on_test_epoch_end
Callback.on_test_epoch_end(trainer, pl_module)

Called when the test epoch ends.

Return type

None

on_predict_epoch_start
Callback.on_predict_epoch_start(trainer, pl_module)

Called when the predict epoch begins.

Return type

None

on_predict_epoch_end
Callback.on_predict_epoch_end(trainer, pl_module, outputs)

Called when the predict epoch ends.

Return type

None

on_validation_batch_start
Callback.on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

Called when the validation batch begins.

Return type

None

on_validation_batch_end
Callback.on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

Called when the validation batch ends.

Return type

None

on_test_batch_start
Callback.on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

Called when the test batch begins.

Return type

None

on_test_batch_end
Callback.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

Called when the test batch ends.

Return type

None

on_predict_batch_start
Callback.on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

Called when the predict batch begins.

Return type

None

on_predict_batch_end
Callback.on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

Called when the predict batch ends.

Return type

None

on_train_start
Callback.on_train_start(trainer, pl_module)

Called when training begins.

Return type

None

on_train_end
Callback.on_train_end(trainer, pl_module)

Called when training ends.

Return type

None

on_validation_start
Callback.on_validation_start(trainer, pl_module)

Called when the validation loop begins.

Return type

None

on_validation_end
Callback.on_validation_end(trainer, pl_module)

Called when the validation loop ends.

Return type

None

on_test_start
Callback.on_test_start(trainer, pl_module)

Called when testing begins.

Return type

None

on_test_end
Callback.on_test_end(trainer, pl_module)

Called when testing ends.

Return type

None

on_predict_start
Callback.on_predict_start(trainer, pl_module)

Called when prediction begins.

Return type

None

on_predict_end
Callback.on_predict_end(trainer, pl_module)

Called when prediction ends.

Return type

None

on_exception
Callback.on_exception(trainer, pl_module, exception)

Called when any trainer execution is interrupted by an exception.

Return type

None

state_dict
Callback.state_dict()

Called when saving a checkpoint. Implement this hook to generate the callback’s state_dict.

Return type

Dict[str, Any]

Returns

A dictionary containing callback state.

on_save_checkpoint
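Callback.on_save_checkpoint(trainer, pl_module, checkpoint)

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters

trainer (Trainer) – the current Trainer instance.

pl_module (LightningModule) – the current LightningModule instance.

checkpoint (Dict[str, Any]) – the checkpoint dictionary that will be saved.

Return type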

None

load_state_dict
Callback.load_state_dict(state_dict)

Called when loading a checkpoint. Implement this hook to reload the callback’s state from the given state_dict.

Parameters

state_dict (Dict[str, Any]) – the callback state returned by state_dict.

Return type

None

on_load_checkpoint
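Callback.on_load_checkpoint(trainer, pl_module, checkpoint)

Called when loading a model checkpoint. Use it to reload state.

Parameters

trainer (Trainer) – the current Trainer instance.

pl_module (LightningModule) – the current LightningModule instance.

checkpoint (Dict[str, Any]) – the full checkpoint dictionary that got loaded by the Trainer.

Return type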

None

on_before_backward
Callback.on_before_backward(trainer, pl_module, loss)

Called before loss.backward().

Return type

None

on_after_backward
Callback.on_after_backward(trainer, pl_module)

Called after loss.backward() and before optimizers are stepped.

Return type

None

on_before_optimizer_step
Callback.on_before_optimizer_step(trainer, pl_module, optimizer, opt_idx)

Called before optimizer.step().

Return type

None

on_before_zero_grad
Callback.on_before_zero_grad(trainer, pl_module, optimizer)

Called before optimizer.zero_grad().

Return type

None
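These hooks are called by the Trainer, not by you. As a rough mental model of when the fit-related ones fire, here is a stdlib-only toy loop. It is a simplified sketch, not Lightning's loop: it omits validation, the sanity check, and the backward/optimizer hooks, passes None in place of the real trainer and pl_module, and Recorder and toy_fit are hypothetical names.

```python
from typing import Any, List


class Recorder:
    """Hypothetical callback that records hook calls; real code subclasses pytorch_lightning.Callback."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def setup(self, trainer: Any, pl_module: Any, stage: str) -> None:
        self.calls.append(f"setup({stage})")

    def on_fit_start(self, trainer: Any, pl_module: Any) -> None:
        self.calls.append("on_fit_start")

    def on_train_start(self, trainer: Any, pl_module: Any) -> None:
        self.calls.append("on_train_start")

    def on_train_epoch_start(self, trainer: Any, pl_module: Any) -> None:
        self.calls.append("on_train_epoch_start")

    def on_train_batch_start(self, trainer: Any, pl_module: Any, batch: Any, batch_idx: int) -> None:
        self.calls.append(f"on_train_batch_start({batch_idx})")

    def on_train_batch_end(self, trainer: Any, pl_module: Any, outputs: Any, batch: Any, batch_idx: int) -> None:
        self.calls.append(f"on_train_batch_end({batch_idx})")

    def on_train_epoch_end(self, trainer: Any, pl_module: Any) -> None:
        self.calls.append("on_train_epoch_end")

    def on_train_end(self, trainer: Any, pl_module: Any) -> None:
        self.calls.append("on_train_end")

    def on_fit_end(self, trainer: Any, pl_module: Any) -> None:
        self.calls.append("on_fit_end")

    def teardown(self, trainer: Any, pl_module: Any, stage: str) -> None:
        self.calls.append(f"teardown({stage})")


def toy_fit(callbacks: List[Recorder], batches: List[List[float]], max_epochs: int = 1) -> None:
    """Simplified fit loop: no validation, sanity check, backward, or optimizer hooks."""
    trainer = pl_module = None  # placeholders for the real Trainer and LightningModule

    def call(hook: str, *args: Any) -> None:
        for callback in callbacks:  # each hook runs on every callback, in list order
            getattr(callback, hook)(trainer, pl_module, *args)

    call("setup", "fit")
    call("on_fit_start")
    call("on_train_start")
    for _ in range(max_epochs):
        call("on_train_epoch_start")
        for batch_idx, batch in enumerate(batches):
            call("on_train_batch_start", batch, batch_idx)
            outputs = {"loss": sum(batch)}  # stand-in for the training_step result
            call("on_train_batch_end", outputs, batch, batch_idx)
        call("on_train_epoch_end")
    call("on_train_end")
    call("on_fit_end")
    call("teardown", "fit")


recorder = Recorder()
toy_fit([recorder], batches=[[0.5, 0.25], [0.125]])
print(recorder.calls)
```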

Cloud-based checkpoints (advanced)

Cloud checkpoints

Lightning is integrated with the major remote file systems including local filesystems and several cloud storage providers such as S3 on AWS, GCS on Google Cloud, or ADL on Azure.

PyTorch Lightning uses fsspec internally to handle all filesystem operations.


Save a cloud checkpoint

To save to a remote filesystem, prepend a protocol like “s3://” to the default_root_dir used for writing and reading model data.

# `default_root_dir` is the default path used for logs and checkpoints
trainer = Trainer(default_root_dir="s3://my_bucket/data/")
trainer.fit(model)
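fsspec picks a filesystem implementation from the protocol prefix of the path. The sketch below shows only that routing idea with the standard library; split_protocol is a hypothetical helper (fsspec does its own, more thorough parsing), and treating unprefixed paths as local files is the assumption it encodes.

```python
from typing import Tuple


def split_protocol(path: str) -> Tuple[str, str]:
    """Hypothetical helper: 's3://bucket/key' -> ('s3', 'bucket/key').

    Paths without a '<protocol>://' prefix are treated as local files.
    """
    if "://" in path:
        protocol, rest = path.split("://", 1)
        return protocol, rest
    return "file", path


for root in ["s3://my_bucket/data/", "gs://my_bucket/data/", "/home/me/lightning_logs"]:
    print(root, "->", split_protocol(root))
```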

Resume training from a cloud checkpoint

To resume training from a cloud checkpoint, pass the checkpoint’s cloud URL as ckpt_path.

trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
trainer.fit(model, ckpt_path="s3://my_bucket/ckpts/classifier.ckpt")



Modularize your checkpoints

Checkpoints can also save the state of datamodules and callbacks.


Modify a checkpoint anywhere

When you need to change the components of a checkpoint before saving or loading, use the on_save_checkpoint() and on_load_checkpoint() of your LightningModule.

class LitModel(pl.LightningModule):
    def on_save_checkpoint(self, checkpoint):
        # store any picklable object in the checkpoint dictionary
        checkpoint["something_cool_i_want_to_save"] = self.my_cool_picklable_object

    def on_load_checkpoint(self, checkpoint):
        # restore it when the checkpoint is loaded
        self.my_cool_picklable_object = checkpoint["something_cool_i_want_to_save"]

Use the above approach when you need to couple this behavior to your LightningModule for reproducibility reasons. Otherwise, Callbacks also have the on_save_checkpoint() and on_load_checkpoint() which you should use instead:

class LitCallback(pl.Callback):
    # Callback hooks also receive the trainer and the LightningModule
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        checkpoint["something_cool_i_want_to_save"] = self.my_cool_picklable_object

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        self.my_cool_picklable_object = checkpoint["something_cool_i_want_to_save"]
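Whatever you add must survive serialization. This stdlib sketch round-trips a checkpoint dictionary through pickle, which is what torch.save uses by default, to show why the stored object has to be picklable. ToyCallback and save_then_load are hypothetical stand-ins; they only mirror the hook signatures of a Callback and pass None for the trainer and module.

```python
import io
import pickle
from typing import Any, Dict


class ToyCallback:
    """Hypothetical stand-in mirroring the Callback checkpoint hook signatures."""

    def __init__(self) -> None:
        self.history: Any = []  # any picklable object

    def on_save_checkpoint(self, trainer: Any, pl_module: Any, checkpoint: Dict[str, Any]) -> None:
        checkpoint["something_cool_i_want_to_save"] = self.history

    def on_load_checkpoint(self, trainer: Any, pl_module: Any, checkpoint: Dict[str, Any]) -> None:
        self.history = checkpoint["something_cool_i_want_to_save"]


def save_then_load(saver: ToyCallback, loader: ToyCallback) -> Dict[str, Any]:
    checkpoint: Dict[str, Any] = {"state_dict": {}}  # roughly what the Trainer builds
    saver.on_save_checkpoint(None, None, checkpoint)
    buffer = io.BytesIO()
    pickle.dump(checkpoint, buffer)  # stand-in for torch.save
    buffer.seek(0)
    restored = pickle.load(buffer)  # stand-in for torch.load
    loader.on_load_checkpoint(None, None, restored)
    return restored


before, after = ToyCallback(), ToyCallback()
before.history = [0.9, 0.7, 0.6]
save_then_load(before, after)
print(after.history)  # [0.9, 0.7, 0.6]
```

Storing something like a lambda instead would make the pickle.dump call fail, which is the same failure you would hit when Lightning writes the checkpoint.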

Console logging

Audience: Engineers looking to capture more visible logs.


Enable console logs

Lightning logs useful information about the training process and user warnings to the console. You can retrieve the Lightning console logger and change it to your liking. For example, adjust the logging level or redirect output for certain modules to log files:

import logging

# configure logging at the root level of Lightning
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

# configure logging on module level, redirect to file
logger = logging.getLogger("pytorch_lightning.core")
logger.addHandler(logging.FileHandler("core.log"))

Read more about custom Python logging here.
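Setting the level on "pytorch_lightning" is enough for its submodules because Python loggers form a dot-separated hierarchy, and a child with no level of its own inherits its parent's. The stdlib-only sketch below demonstrates that inheritance on the same logger names; it writes to an in-memory stream instead of core.log, and it assumes nothing else has set an explicit level on the child logger.

```python
import io
import logging

# Set the threshold once for the whole "pytorch_lightning" logger tree.
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

# A child logger with no level of its own (NOTSET) inherits its parent's level.
core_logger = logging.getLogger("pytorch_lightning.core")
print(logging.getLevelName(core_logger.getEffectiveLevel()))  # ERROR

# Send the child's records to an in-memory stream instead of a file.
buffer = io.StringIO()
core_logger.addHandler(logging.StreamHandler(buffer))

core_logger.warning("dropped: WARNING is below the inherited ERROR threshold")
core_logger.error("kept: ERROR passes the threshold")
print(buffer.getvalue(), end="")  # kept: ERROR passes the threshold
```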

Debug your model

Early Stopping

Stopping an Epoch Early

You can stop and skip the rest of the current epoch early by overriding on_train_batch_start() to return -1 when some condition is met.

If you do this repeatedly, for every epoch you had originally requested, then this will stop your entire training.
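The contract can be modeled with a toy loop. This is a stdlib sketch under stated assumptions, not Lightning's training loop: run_epoch and fit are hypothetical, the hook receives only (batch, batch_idx), and appending the index stands in for running training_step and the optimizer.

```python
from typing import Any, Callable, List

# Hypothetical hook type: (batch, batch_idx) -> optional int, like on_train_batch_start
BatchStartHook = Callable[[Any, int], Any]


def run_epoch(batches: List[Any], on_train_batch_start: BatchStartHook) -> List[int]:
    """Return the indices of the batches that were actually trained on."""
    trained = []
    for batch_idx, batch in enumerate(batches):
        if on_train_batch_start(batch, batch_idx) == -1:
            break  # skip the rest of this epoch
        trained.append(batch_idx)  # stand-in for training_step + optimizer step
    return trained


def fit(max_epochs: int, batches: List[Any], on_train_batch_start: BatchStartHook) -> List[List[int]]:
    return [run_epoch(batches, on_train_batch_start) for _ in range(max_epochs)]


batches = ["b0", "b1", "b2", "b3"]

# end each epoch once batch 2 is reached
print(fit(2, batches, lambda batch, idx: -1 if idx == 2 else None))  # [[0, 1], [0, 1]]

# returning -1 immediately in every epoch means nothing is trained at all
print(fit(2, batches, lambda batch, idx: -1))  # [[], []]
```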

EarlyStopping Callback

The EarlyStopping callback can be used to monitor a metric and stop the training when no improvement is observed.

To enable it:

  • Import EarlyStopping callback.

  • Log the metric you want to monitor using the log() method.

  • Init the callback, and set monitor to the logged metric of your choice.

  • Set mode according to whether the monitored metric should be minimized ("min") or maximized ("max").

  • Pass the EarlyStopping callback to the Trainer callbacks flag.

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


class LitModel(LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = ...
        self.log("val_loss", loss)


model = LitModel()
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model)

You can customize the callback’s behaviour by changing its parameters.

early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
trainer = Trainer(callbacks=[early_stop_callback])

Additional parameters that stop training at extreme points:

  • stopping_threshold: Stops training immediately once the monitored quantity reaches this threshold. It is useful when we know that going beyond a certain optimal value does not further benefit us.

  • divergence_threshold: Stops training as soon as the monitored quantity becomes worse than this threshold. When the metric reaches a value this bad, we believe the model cannot recover anymore and it is better to stop early and run with different initial conditions.

  • check_finite: When turned on, it stops training if the monitored metric becomes NaN or infinite.

  • check_on_train_epoch_end: When turned on, it checks the metric at the end of a training epoch. Use this only when you are monitoring a metric logged at the epoch level within training-specific hooks.
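The three extreme-point checks can be summarized as one decision function. This stdlib sketch is an interpretation of the descriptions above, not EarlyStopping's source: threshold_stop_reason is hypothetical, and the order of the checks (finiteness first) and the use of strict comparisons are assumptions.

```python
import math
from typing import Optional


def threshold_stop_reason(
    current: float,
    mode: str = "min",
    stopping_threshold: Optional[float] = None,
    divergence_threshold: Optional[float] = None,
    check_finite: bool = True,
) -> Optional[str]:
    """Return why training should stop, or None to keep going."""

    def better(a: float, b: float) -> bool:
        # "better" means lower in "min" mode and higher in "max" mode
        return a < b if mode == "min" else a > b

    if check_finite and not math.isfinite(current):
        return "monitored metric is NaN or infinite"
    if stopping_threshold is not None and better(current, stopping_threshold):
        return "reached stopping_threshold"
    if divergence_threshold is not None and better(divergence_threshold, current):
        return "crossed divergence_threshold"
    return None


print(threshold_stop_reason(0.05, stopping_threshold=0.1))  # reached stopping_threshold
print(threshold_stop_reason(5.0, divergence_threshold=3.0))  # crossed divergence_threshold
```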

In case you need early stopping in a different part of training, subclass EarlyStopping and change where it is called:

class MyEarlyStopping(EarlyStopping):
    def on_validation_end(self, trainer, pl_module):
        # override this to disable early stopping at the end of val loop
        pass

    def on_train_end(self, trainer, pl_module):
        # instead, do it at the end of training loop
        self._run_early_stopping_check(trainer)

Note

The EarlyStopping callback runs at the end of every validation epoch by default. However, the frequency of validation can be modified by setting various parameters in the Trainer, for example check_val_every_n_epoch and val_check_interval. Note that the patience parameter counts the number of validation checks with no improvement, not the number of training epochs. Therefore, with parameters check_val_every_n_epoch=10 and patience=3, the trainer will perform at least 40 training epochs before being stopped.
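The 40-epoch figure follows from counting validation checks rather than epochs. The stdlib sketch below reproduces that arithmetic with a hypothetical PatienceTracker that mimics the documented patience and min_delta semantics in a simplified form; it is not EarlyStopping's implementation. The first check always counts as an improvement over the initial best of infinity, so three further checks without improvement are needed.

```python
import math
from typing import Optional, Sequence


class PatienceTracker:
    """Hypothetical, simplified stand-in for EarlyStopping's patience logic."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0, mode: str = "min") -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = math.inf if mode == "min" else -math.inf
        self.wait_count = 0

    def _improved(self, current: float) -> bool:
        if self.mode == "min":
            return current < self.best - self.min_delta
        return current > self.best + self.min_delta

    def should_stop(self, current: float) -> bool:
        if self._improved(current):
            self.best = current
            self.wait_count = 0
            return False
        self.wait_count += 1  # one more validation check without improvement
        return self.wait_count >= self.patience


def epoch_at_stop(val_metrics: Sequence[float], check_val_every_n_epoch: int, patience: int) -> Optional[int]:
    """Return the training epoch at which a stop would trigger, or None."""
    tracker = PatienceTracker(patience=patience)
    for check_idx, metric in enumerate(val_metrics):
        epoch = (check_idx + 1) * check_val_every_n_epoch  # validation runs every n epochs
        if tracker.should_stop(metric):
            return epoch
    return None


# a val_loss that never improves after the first check
print(epoch_at_stop([1.0, 1.0, 1.0, 1.0], check_val_every_n_epoch=10, patience=3))  # 40
```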

Manage Experiments

To track other artifacts, such as histograms or model topology graphs, first select one of the many experiment managers (loggers) supported by Lightning:

from pytorch_lightning import loggers as pl_loggers

tensorboard = pl_loggers.TensorBoardLogger(save_dir="logs/")
trainer = Trainer(logger=tensorboard)

then access the logger’s API directly:

def training_step(self, batch, batch_idx):
    tensorboard = self.logger.experiment
    tensorboard.add_image(...)
    tensorboard.add_histogram(...)
    tensorboard.add_figure(...)

Comet.ml

To use Comet.ml first install the comet package:

pip install comet-ml

Configure the logger and pass it to the Trainer:

from pytorch_lightning.loggers import CometLogger

comet_logger = CometLogger(api_key="YOUR_COMET_API_KEY")
trainer = Trainer(logger=comet_logger)

Access the comet logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts

class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        comet = self.logger.experiment  # a comet_ml Experiment
        fake_image = torch.rand(3, 28, 28)  # a single channels-first image
        comet.log_image(fake_image.numpy(), name="generated_images", image_channels="first")

Here’s the full documentation for the CometLogger.


MLflow

To use MLflow first install the MLflow package:

pip install mlflow

Configure the logger and pass it to the Trainer:

from pytorch_lightning.loggers import MLFlowLogger

mlf_logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")
trainer = Trainer(logger=mlf_logger)

Access the mlflow logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts

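class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        mlf_client = self.logger.experiment  # an MlflowClient
        fake_image = torch.rand(28, 28, 3).numpy()  # a single HWC image with values in [0, 1)
        # log the image as an artifact of the logger's current run
        mlf_client.log_image(self.logger.run_id, fake_image, "generated_images.png")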

Here’s the full documentation for the MLFlowLogger.


Neptune.ai

To use Neptune.ai first install the neptune package:

pip install neptune-client

or with conda:

conda install -c conda-forge neptune-client

Configure the logger and pass it to the Trainer:

from pytorch_lightning.loggers import NeptuneLogger

neptune_logger = NeptuneLogger(
    api_key="ANONYMOUS",  # replace with your own
    project="common/pytorch-lightning-integration",  # format "<WORKSPACE/PROJECT>"
)
trainer = Trainer(logger=neptune_logger)

Access the neptune logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts

class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        neptune_logger = self.logger.experiment["your/metadata/structure"]
        neptune_logger.log(metadata)

Here’s the full documentation for the NeptuneLogger.


Tensorboard

TensorBoard can be installed with:

pip install tensorboard

Configure the logger and pass it to the Trainer:

from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir="logs/")
trainer = Trainer(logger=logger)

Access the tensorboard logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts

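import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        tensorboard_logger = self.logger.experiment
        fake_images = torch.rand(32, 3, 28, 28)  # a batch of 32 RGB 28x28 images
        # add_images (plural) accepts a batch in NCHW layout
        tensorboard_logger.add_images("generated_images", fake_images, 0)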

Here’s the full documentation for the TensorBoardLogger.


Weights and Biases

To use Weights and Biases (wandb) first install the wandb package:

pip install wandb

Configure the logger and pass it to the Trainer:

from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project="MNIST", log_model="all")
trainer = Trainer(logger=wandb_logger)

# log gradients and model topology
wandb_logger.watch(model)

Access the wandb logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts

import torch
import wandb
from pytorch_lightning import LightningModule


class MyModule(LightningModule):
    def any_lightning_module_function_or_hook(self):
        wandb_run = self.logger.experiment
        fake_images = torch.rand(32, 3, 28, 28)

        # Option 1: use the wandb run directly
        wandb_run.log({"generated_images": [wandb.Image(img, caption="...") for img in fake_images]})

        # Option 2 for specifically logging images: a method of the WandbLogger itself
        self.logger.log_image(key="generated_images", images=list(fake_images))

Here’s the full documentation for the WandbLogger. Demo in Google Colab with hyperparameter search and model logging.


Use multiple exp managers

To use multiple experiment managers at the same time, pass a list to the logger Trainer argument.

from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

logger1 = TensorBoardLogger(save_dir="logs/")
logger2 = WandbLogger()
trainer = Trainer(logger=[logger1, logger2])

Access all loggers from any function (except the LightningModule init) to use their APIs for tracking advanced artifacts

import torch
import wandb


class MyModule(LightningModule):
    def any_lightning_module_function_or_hook(self):
        # self.loggers is a list in the same order as passed to the Trainer
        tensorboard_logger = self.loggers[0].experiment
        wandb_run = self.loggers[1].experiment

        fake_images = torch.rand(32, 3, 28, 28)

        tensorboard_logger.add_images("generated_images", fake_images, 0)
        wandb_run.log({"generated_images": [wandb.Image(img) for img in fake_images]})
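Passing a list means every metric you log reaches each logger, and self.loggers keeps the order you passed. The stdlib sketch below models only that fan-out. InMemoryLogger and log_to_all are hypothetical, though log_metrics(metrics, step) mirrors the method Lightning's logger classes implement.

```python
from typing import Dict, List, Tuple


class InMemoryLogger:
    """Hypothetical logger that keeps every log_metrics call in a list."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.records: List[Tuple[int, Dict[str, float]]] = []

    def log_metrics(self, metrics: Dict[str, float], step: int) -> None:
        self.records.append((step, dict(metrics)))


def log_to_all(loggers: List[InMemoryLogger], metrics: Dict[str, float], step: int) -> None:
    # with logger=[...], each logged metric is forwarded to every logger in list order
    for logger in loggers:
        logger.log_metrics(metrics, step)


loggers = [InMemoryLogger("tensorboard"), InMemoryLogger("wandb")]
log_to_all(loggers, {"train_loss": 0.5}, step=10)
print([logger.records for logger in loggers])  # [[(10, {'train_loss': 0.5})], [(10, {'train_loss': 0.5})]]
```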

Fault-tolerant Training

Transfer Learning

Audience: Users looking to use pretrained models with Lightning.


Use any PyTorch nn.Module

Any model that is a PyTorch nn.Module can be used with Lightning (because LightningModules are nn.Modules also).


Use a pretrained LightningModule

Let’s use the AutoEncoder as a feature extractor in a separate model.

import torch
from torch import nn
from pytorch_lightning import LightningModule


class Encoder(torch.nn.Module):
    ...


class Decoder(torch.nn.Module):
    ...


class AutoEncoder(LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        # return the learned representation so the model can act as a feature extractor
        return self.encoder(x)


class CIFAR10Classifier(LightningModule):
    def __init__(self):
        super().__init__()
        # init the pretrained LightningModule
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
        self.feature_extractor.freeze()

        # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...

We used our pretrained Autoencoder (a LightningModule) for transfer learning!


Example: Imagenet (Computer Vision)

import torch
import torchvision.models as models
from pytorch_lightning import LightningModule
from torch import nn


class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()

        # init a pretrained resnet
        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        ...

Finetune

model = ImagenetTransferLearning()
trainer = Trainer()
trainer.fit(model)

And use it to predict your data of interest:

model = ImagenetTransferLearning.load_from_checkpoint(PATH)
model.freeze()

x = some_images_from_cifar10()
predictions = model(x)

We used a model pretrained on ImageNet, finetuned it on CIFAR-10, and predicted on CIFAR-10. In a non-academic setting, you would typically finetune on your own (often small) dataset and predict on that same kind of data.


Example: BERT (NLP)

Lightning is completely agnostic to what’s used for transfer learning so long as it is a torch.nn.Module subclass.

Here’s a model that uses Hugging Face transformers.

from pytorch_lightning import LightningModule
from torch import nn
from transformers import BertModel


class BertMNLIFinetuner(LightningModule):
    def __init__(self):
        super().__init__()

        self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
        self.W = nn.Linear(self.bert.config.hidden_size, 3)
        self.num_classes = 3

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        h, attn = outputs.last_hidden_state, outputs.attentions

        h_cls = h[:, 0]  # hidden state of the [CLS] token
        logits = self.W(h_cls)
        return logits, attn

Deploy models into production (intermediate)

Audience: Researchers and MLEs looking to use their models for predictions without Lightning dependencies.


Use PyTorch as normal

If you prefer to use PyTorch directly, feel free to use any Lightning checkpoint without Lightning.

import torch
from torch import nn


class MyModel(nn.Module):
    ...


model = MyModel()
checkpoint = torch.load("path/to/lightning/checkpoint.ckpt")
model.load_state_dict(checkpoint["state_dict"])
model.eval()

Extract nn.Module from Lightning checkpoints

You can also load the saved checkpoint and use it as a regular torch.nn.Module. You can extract all of your torch.nn.Module submodules and load their weights from the checkpoint that your LightningModule saved after training. To do this, we recommend copying the exact implementation from your LightningModule’s __init__ and forward methods.

class Encoder(nn.Module):
    ...


class Decoder(nn.Module):
    ...


class AutoEncoderProd(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.encoder(x)


class AutoEncoderSystem(LightningModule):
    def __init__(self):
        super().__init__()
        self.auto_encoder = AutoEncoderProd()

    def forward(self, x):
        return self.auto_encoder.encoder(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.auto_encoder.encoder(x)
        y_hat = self.auto_encoder.decoder(y_hat)
        loss = ...
        return loss


# train it
trainer = Trainer(devices=2, accelerator="gpu", strategy="ddp")
model = AutoEncoderSystem()
trainer.fit(model, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt")


# load the checkpoint saved by Lightning
checkpoint = torch.load("best_model.ckpt")

# if you called self.save_hyperparameters(), you can restore the hyperparameters too
hyper_parameters = checkpoint.get("hyper_parameters", {})

# create the PyTorch model
model = AutoEncoderProd(**hyper_parameters)

model_weights = checkpoint["state_dict"]

# update keys by dropping `auto_encoder.`
for key in list(model_weights):
    model_weights[key.replace("auto_encoder.", "")] = model_weights.pop(key)

model.load_state_dict(model_weights)
model.eval()
x = torch.randn(1, 64)

with torch.no_grad():
    y_hat = model(x)
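The key-renaming loop above uses str.replace, which would also rewrite "auto_encoder." if it appeared in the middle of a key. A slightly stricter variant, sketched below with the standard library, strips the prefix only from the start and drops keys that belong to other submodules. strip_prefix is a hypothetical helper, and the integer values stand in for tensors.

```python
from collections import OrderedDict
from typing import Dict, TypeVar

V = TypeVar("V")


def strip_prefix(state_dict: Dict[str, V], prefix: str) -> "OrderedDict[str, V]":
    """Keep only keys starting with `prefix`, with that prefix removed."""
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)  # keys of other submodules are dropped
    )


lightning_weights = {
    "auto_encoder.encoder.weight": 1,
    "auto_encoder.decoder.bias": 2,
    "metric_head.auto_encoder.scale": 3,  # prefix appears mid-key: left out, not mangled
}
print(dict(strip_prefix(lightning_weights, "auto_encoder.")))  # {'encoder.weight': 1, 'decoder.bias': 2}
```

With real weights, the result can be passed straight to model.load_state_dict(...).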

LightningDataModule

A datamodule is a shareable, reusable class that encapsulates all the steps needed to process data.


A datamodule encapsulates the five steps involved in data processing in PyTorch:

  1. Download / tokenize / process.

  2. Clean and (maybe) save to disk.

  3. Load inside Dataset.

  4. Apply transforms (rotate, tokenize, etc…).

  5. Wrap inside a DataLoader.


This class can then be shared and used anywhere:

from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule

model = LitClassifier()
trainer = Trainer()

imagenet = ImagenetDataModule()
trainer.fit(model, datamodule=imagenet)

cifar10 = CIFAR10DataModule()
trainer.fit(model, datamodule=cifar10)

Why do I need a DataModule?

In normal PyTorch code, the data cleaning/preparation is usually scattered across many files. This makes sharing and reusing the exact splits and transforms across projects impossible.

Datamodules are for you if you have ever asked questions like:

  • what splits did you use?

  • what transforms did you use?

  • what normalization did you use?

  • how did you prepare/tokenize the data?


What is a DataModule?

A DataModule is simply a collection of a train_dataloader(s), val_dataloader(s), test_dataloader(s) and predict_dataloader(s) along with the matching transforms and data processing/downloads steps required.

Here’s a simple PyTorch example:

# regular PyTorch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

test_data = MNIST(my_path, train=False, download=True)
predict_data = MNIST(my_path, train=False, download=True)
train_data = MNIST(my_path, train=True, download=True)
train_data, val_data = random_split(train_data, [55000, 5000])

train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
predict_loader = DataLoader(predict_data, batch_size=32)

The equivalent DataModule just organizes the same exact code, but makes it reusable across projects.

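import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):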
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: str):
        self.mnist_test = MNIST(self.data_dir, train=False)
        self.mnist_predict = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

    def teardown(self, stage: str):
        # Used to clean-up when the run is finished
        ...

But now, as the complexity of your processing grows (transforms, multiple-GPU training), you can let Lightning handle those details for you while making this dataset reusable so you can share with colleagues or use in different projects.

mnist = MNISTDataModule(my_path)
model = LitClassifier()

trainer = Trainer()
trainer.fit(model, mnist)

Here’s a more realistic, complex DataModule that shows how much more reusable the datamodule is.

import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)
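One detail matters for sharing: random_split draws from PyTorch's global random state unless you pass a generator, for example generator=torch.Generator().manual_seed(42), so the same DataModule could produce different splits on different runs. The stdlib analogue below illustrates the idea of a seeded split; seeded_split is a hypothetical helper that works on indices, not a replacement for random_split.

```python
import random
from typing import List, Sequence


def seeded_split(n_items: int, lengths: Sequence[int], seed: int = 42) -> List[List[int]]:
    """Shuffle indices 0..n_items-1 with a fixed seed and cut them into `lengths`."""
    if sum(lengths) != n_items:
        raise ValueError(f"lengths {list(lengths)} do not sum to {n_items}")
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # private RNG: global random state is untouched
    splits, start = [], 0
    for length in lengths:
        splits.append(indices[start : start + length])
        start += length
    return splits


train_idx, val_idx = seeded_split(60000, [55000, 5000])
print(len(train_idx), len(val_idx))  # 55000 5000
```

Calling seeded_split again with the same seed returns exactly the same two index lists, which is the property that makes the split reproducible across machines and colleagues.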

LightningDataModule API

To define a DataModule the following methods are used to create train/val/test/predict dataloaders:

prepare_data

Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures that prepare_data() is called only within a single process on CPU, so you can safely add your downloading logic within it. In multi-node training, the execution of this hook depends on prepare_data_per_node. setup() is called after prepare_data, and there is a barrier in between that ensures all processes proceed to setup only once the data is prepared and available for use.

  • download, i.e., download data to disk only once, from a single process

  • tokenize, since it is a one-time process that should not be repeated on every process

  • etc…

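import os

import pytorch_lightning as pl
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):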
    def prepare_data(self):
        # download
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

Warning

prepare_data is called from the main process. It is not recommended to assign state here (e.g. self.x = y): because the hook runs on a single process, any state assigned here won’t be available to the other processes.

setup

There are also data operations you might want to perform on every GPU. Use setup() to do things like:

  • count number of classes

  • build vocabulary

  • perform train/val/test splits

  • create datasets

  • apply transforms (defined explicitly in your datamodule)

  • etc…

import pytorch_lightning as pl
from torch.utils.data import random_split
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    # assumes self.data_dir and self.transform are set in __init__,
    # and that the data was already downloaded in prepare_data()
    def setup(self, stage: str):
        # Assign train/val split(s) for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test split(s) for use in dataloaders
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

For example, if you are working on an NLP task where you need to tokenize the text before using it, you can do something like the following:

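from pytorch_lightning import LightningDataModule


class LitDataModule(LightningDataModule):
    def prepare_data(self):
        # load_dataset is a placeholder for your own loading function
        dataset = load_dataset(...)
        train_dataset = ...
        val_dataset = ...
        # tokenize
        # save it to disk

    def setup(self, stage):
        # load it back here (load_dataset_from_disk is also a placeholder)
        dataset = load_dataset_from_disk(...)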
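To make the tokenize-then-reload pattern concrete, here is a plain-Python sketch with no Lightning dependency. The whitespace tokenizer, the JSON cache file and the class name ToyTextDataModule are illustrative assumptions; in a real LightningDataModule, prepare_data() and setup() would be called by the Trainer rather than by hand.

```python
import json
import os
import tempfile


class ToyTextDataModule:
    """Plain-Python stand-in for a LightningDataModule (hypothetical; no Lightning import)."""

    def __init__(self, texts, cache_dir):
        self.texts = texts
        self.cache_path = os.path.join(cache_dir, "tokenized.json")

    def prepare_data(self):
        # runs once: tokenize and save to disk, without assigning new state on self
        tokenized = [text.lower().split() for text in self.texts]
        with open(self.cache_path, "w") as f:
            json.dump(tokenized, f)

    def setup(self, stage):
        # runs on every process: load the tokenized data back and build state
        with open(self.cache_path) as f:
            tokenized = json.load(f)
        vocab = sorted({tok for sent in tokenized for tok in sent})
        self.vocab = {tok: idx for idx, tok in enumerate(vocab)}
        self.encoded = [[self.vocab[tok] for tok in sent] for sent in tokenized]


dm = ToyTextDataModule(["Hello world", "hello Lightning"], tempfile.mkdtemp())
dm.prepare_data()
dm.setup(stage="fit")
```

The expensive work happens once in prepare_data(), while the vocabulary and encoded data, which every process needs, are built in setup().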

This method expects a stage argument. It is used to separate setup logic for trainer.{fit,validate,test,predict}.

Note

setup is called from every process across all the nodes. Setting state here is recommended.

Note

teardown can be used to clean up the state. It is also called from every process across all the nodes.

train_dataloader

Use the train_dataloader() method to generate the training dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer fit() method uses.

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)

val_dataloader

Use the val_dataloader() method to generate the validation dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer fit() and validate() methods use.

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)

test_dataloader

Use the test_dataloader() method to generate the test dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer test() method uses.

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)

predict_dataloader

Use the predict_dataloader() method to generate the prediction dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer predict() method uses.

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=64)

transfer_batch_to_device
LightningDataModule.transfer_batch_to_device(batch, device, dataloader_idx)

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

The data types listed below (and any arbitrary nesting of them) are supported out of the box:

  • torch.Tensor or anything that implements .to(…)

  • list

  • dict

  • tuple

For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, …).

Note

This hook should only transfer the data, not modify it, and it should not move the data to any device other than the one passed in as an argument (unless you know what you are doing). To check the current state of execution of this hook, you can use self.trainer.training/testing/validating/predicting, so that you can add different logic depending on your requirements.

Note

This hook only runs on single-GPU training and DDP (not data-parallel). Data-parallel support will come in the near future.

Parameters
  • batch (Any) – A batch of data that needs to be transferred to a new device.

  • device (device) – The target device as defined in PyTorch.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Return type

Any

Returns

A reference to the data on the new device.

Example:

def transfer_batch_to_device(self, batch, device, dataloader_idx):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    elif dataloader_idx == 0:
        # skip device transfer for the first dataloader or anything you wish
        pass
    else:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch
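As a runnable sketch of the first branch in this example, the snippet below defines a hypothetical CustomBatch container and a standalone function with the same signature as the hook. It moves each tensor field with Tensor.to(); it runs on CPU when no GPU is present, but the same call works for any torch.device. The function is illustrative and is not Lightning's default implementation.

```python
from dataclasses import dataclass

import torch


@dataclass
class CustomBatch:
    """Hypothetical user-defined batch that Lightning would not know how to move on its own."""

    samples: torch.Tensor
    targets: torch.Tensor


def transfer_batch_to_device(batch, device, dataloader_idx):
    # move every tensor inside the custom structure to the target device
    if isinstance(batch, CustomBatch):
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    return batch


batch = CustomBatch(samples=torch.randn(4, 3), targets=torch.tensor([0, 1, 0, 1]))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
moved = transfer_batch_to_device(batch, device, dataloader_idx=0)
```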
Raises
  • MisconfigurationException – If using data-parallel, Trainer(strategy='dp').

  • MisconfigurationException – If using IPUs, Trainer(accelerator='ipu').

See also

  • move_data_to_device()

  • apply_to_collection()

on_before_batch_transfer
LightningDataModule.on_before_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

Note

To check the current state of execution of this hook, you can use self.trainer.training/testing/validating/predicting, so that you can add different logic depending on your requirements.

Note

This hook only runs on single-GPU training and DDP (not data-parallel). Data-parallel support will come in the near future.

Parameters
  • batch (Any) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Return type

Any

Returns

A batch of data

Example:

def on_before_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = transforms(batch['x'])
    return batch

See also

  • on_after_batch_transfer()

  • transfer_batch_to_device()

on_after_batch_transfer
LightningDataModule.on_after_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

Note

To check the current state of execution of this hook, you can use self.trainer.training/testing/validating/predicting, so that you can add different logic depending on your requirements.

Note

This hook only runs on single-GPU training and DDP (not data-parallel). Data-parallel support will come in the near future.

Parameters
  • batch (Any) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Return type

Any

Returns

A batch of data

Example:

def on_after_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = gpu_transforms(batch['x'])
    return batch
Raises
  • MisconfigurationException – If using data-parallel, Trainer(strategy='dp').

  • MisconfigurationException – If using IPUs, Trainer(accelerator='ipu').

See also

  • on_before_batch_transfer()

  • transfer_batch_to_device()
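For each batch, the three hooks run in a fixed order: on_before_batch_transfer, then transfer_batch_to_device, then on_after_batch_transfer. The plain-PyTorch sketch below reproduces that order with standalone functions. The apply_batch_hooks driver and the two augmentations (scaling on the host before the transfer, a flip on the device after it) are illustrative assumptions.

```python
import torch

calls = []  # records the order in which the hooks run


def on_before_batch_transfer(batch, dataloader_idx):
    calls.append("before")
    # runs before the transfer: e.g. scale uint8 images to [0, 1]
    batch["x"] = batch["x"].float() / 255.0
    return batch


def transfer_batch_to_device(batch, device, dataloader_idx):
    calls.append("transfer")
    return {key: value.to(device) for key, value in batch.items()}


def on_after_batch_transfer(batch, dataloader_idx):
    calls.append("after")
    # runs on the target device: e.g. a horizontal flip augmentation
    batch["x"] = torch.flip(batch["x"], dims=[-1])
    return batch


def apply_batch_hooks(batch, device, dataloader_idx=0):
    # hypothetical driver that calls the hooks in the same order Lightning does
    batch = on_before_batch_transfer(batch, dataloader_idx)
    batch = transfer_batch_to_device(batch, device, dataloader_idx)
    return on_after_batch_transfer(batch, dataloader_idx)


raw = {"x": torch.tensor([[0, 255]], dtype=torch.uint8), "y": torch.tensor([1])}
out = apply_batch_hooks(raw, torch.device("cpu"))
```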

load_state_dict
LightningDataModule.load_state_dict(state_dict)

Called when loading a checkpoint; implement this to reload the datamodule state from the given datamodule state_dict.

Parameters

state_dict (Dict[str, Any]) – the datamodule state returned by state_dict.

Return type

None

state_dict
LightningDataModule.state_dict()

Called when saving a checkpoint; implement this to generate and save the datamodule state.

Return type

Dict[str, Any]

Returns

A dictionary containing datamodule state.

teardown
LightningDataModule.teardown(stage)

Called at the end of fit (train + validate), validate, test, or predict.

Parameters

stage (str) – either 'fit', 'validate', 'test', or 'predict'

Return type

None

prepare_data_per_node

If set to True, prepare_data() is called on LOCAL_RANK=0 of every node. If set to False, it is only called on NODE_RANK=0, LOCAL_RANK=0.

from pytorch_lightning import LightningDataModule


class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True
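The rank selection that prepare_data_per_node controls can be written as a small predicate. This is a sketch of the rule stated above, not Lightning's source; the function name calls_prepare_data and the 2-node, 2-process layout are assumptions for illustration.

```python
def calls_prepare_data(prepare_data_per_node: bool, node_rank: int, local_rank: int) -> bool:
    """Return True if the process at (node_rank, local_rank) should run prepare_data()."""
    if prepare_data_per_node:
        # one process per node: LOCAL_RANK=0 on every node
        return local_rank == 0
    # one process in the whole job: NODE_RANK=0, LOCAL_RANK=0
    return node_rank == 0 and local_rank == 0


# hypothetical cluster: 2 nodes with 2 processes each, as (node_rank, local_rank) pairs
layout = [(node, local) for node in range(2) for local in range(2)]
per_node = [rank for rank in layout if calls_prepare_data(True, *rank)]
global_only = [rank for rank in layout if calls_prepare_data(False, *rank)]
```

As a rule of thumb, True suits nodes that each have their own local disk, while False avoids redundant downloads when all nodes share one filesystem.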

Using a DataModule

The recommended way to use a DataModule is simply:

import pytorch_lightning as pl

dm = MNISTDataModule()
model = Model()  # your LightningModule
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)
trainer.validate(datamodule=dm)
trainer.predict(datamodule=dm)

If you need information from the dataset to build your model, then run prepare_data and setup manually (Lightning ensures the method runs on the correct devices).

dm = MNISTDataModule()
dm.prepare_data()
dm.setup(stage="fit")

model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
trainer.fit(model, dm)

dm.setup(stage="test")
trainer.test(datamodule=dm)

You can access the datamodule currently used by a trainer via trainer.datamodule, and the dataloaders currently in use via trainer.train_dataloader, trainer.val_dataloaders and trainer.test_dataloaders.


DataModules without Lightning

You can of course use DataModules in plain PyTorch code as well.

# download, etc...
dm = MNISTDataModule()
dm.prepare_data()

# splits/transforms
dm.setup(stage="fit")

# use data
for batch in dm.train_dataloader():
    ...

for batch in dm.val_dataloader():
    ...

dm.teardown(stage="fit")

# lazy load test data
dm.setup(stage="test")
for batch in dm.test_dataloader():
    ...

dm.teardown(stage="test")

But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified structure.


Hyperparameters in DataModules

Like LightningModules, DataModules support hyperparameters with the same API.

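import pytorch_lightning as pl
from torch.utils.data import DataLoader


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./", batch_size: int = 32):
        super().__init__()
        # saves data_dir and batch_size to self.hparams
        self.save_hyperparameters()

    def train_dataloader(self):
        # access the saved hyperparameters (self.mnist_train is assumed to be created in setup)
        return DataLoader(self.mnist_train, batch_size=self.hparams.batch_size)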

Refer to save_hyperparameters in the LightningModule documentation for more details.


Save DataModule state

When a checkpoint is created, Lightning asks every DataModule for its state. If your DataModule defines the state_dict and load_state_dict methods, the checkpoint will automatically track and restore your DataModule.

import pytorch_lightning as pl


class LitDataModule(pl.LightningDataModule):
    # assumes self.current_train_batch_index is maintained elsewhere in the class
    def state_dict(self):
        # track whatever you want here
        state = {"current_train_batch_index": self.current_train_batch_index}
        return state

    def load_state_dict(self, state_dict):
        # restore the state based on what you tracked in (def state_dict)
        self.current_train_batch_index = state_dict["current_train_batch_index"]
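The plain-Python sketch below mimics the round trip the checkpoint performs: the datamodule's state_dict() is stored inside a checkpoint dictionary, serialized with pickle, and handed back to load_state_dict() on a fresh instance. The checkpoint key "datamodule_state" and the class ToyDataModule are hypothetical; Lightning manages the real checkpoint layout itself.

```python
import pickle


class ToyDataModule:
    """Hypothetical stand-in for a LightningDataModule that tracks its position in the data."""

    def __init__(self):
        self.current_train_batch_index = 0

    def state_dict(self):
        return {"current_train_batch_index": self.current_train_batch_index}

    def load_state_dict(self, state_dict):
        self.current_train_batch_index = state_dict["current_train_batch_index"]


dm = ToyDataModule()
dm.current_train_batch_index = 42  # pretend training advanced 42 batches

# saving: the datamodule state is stored alongside the rest of the checkpoint
checkpoint_bytes = pickle.dumps({"datamodule_state": dm.state_dict()})

# resuming: a fresh datamodule restores its state from the checkpoint
restored = ToyDataModule()
restored.load_state_dict(pickle.loads(checkpoint_bytes)["datamodule_state"])
```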


Loops

Loops let advanced users swap out the default gradient descent optimization loop at the core of Lightning with a different optimization paradigm.

The Lightning Trainer is built on top of the standard gradient descent optimization loop which works for 90%+ of machine learning use cases:

for i, batch in enumerate(dataloader):
    x, y = batch
    y_hat = model(x)
    loss = loss_function(y_hat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

However, some new research use cases such as meta-learning, active learning, recommendation systems, etc., require a different loop structure. For example, here is a simple loop that guides the weight updates with a loss from a special validation split:

for i, batch in enumerate(train_dataloader):
    x, y = batch
    y_hat = model(x)
    loss = loss_function(y_hat, y)
    optimizer.zero_grad()
    loss.backward()

    val_loss = 0
    for val_batch in val_dataloader:
        x, y = val_batch
        y_hat = model(x)
        val_loss += loss_function(y_hat, y)

    scale_gradients(model, 1 / val_loss)
    optimizer.step()
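For reference, here is a runnable plain-PyTorch version of that loop. The scale_gradients helper, the toy linear model and the random data are assumptions added for illustration. Because the validation loss is only used as a scalar scale factor, this sketch computes it under torch.no_grad().

```python
import torch
from torch import nn

torch.manual_seed(0)


def scale_gradients(model: nn.Module, factor: float) -> None:
    # hypothetical helper: multiply every existing gradient in place
    for param in model.parameters():
        if param.grad is not None:
            param.grad.mul_(factor)


model = nn.Linear(4, 1)
loss_function = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# toy data: lists of (x, y) batches standing in for dataloaders
train_dataloader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]
val_dataloader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(2)]

for i, batch in enumerate(train_dataloader):
    x, y = batch
    loss = loss_function(model(x), y)
    optimizer.zero_grad()
    loss.backward()

    # the validation loss only rescales the gradients, so no autograd graph is needed
    with torch.no_grad():
        val_loss = sum(loss_function(model(vx), vy) for vx, vy in val_dataloader)

    scale_gradients(model, 1 / val_loss.item())
    optimizer.step()
```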

With Lightning Loops, you can swap in a non-standard gradient descent optimization to get the same loop as above:

trainer = Trainer()
trainer.fit_loop.epoch_loop = MyGradientDescentLoop()

Think of this as swapping out the engine in a car!


Understanding the Default Trainer Loop

The Lightning Trainer automates the standard optimization loop which every PyTorch user is familiar with:

for i, batch in enumerate(dataloader):
    x, y = batch
    y_hat = model(x)
    loss = loss_function(y_hat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

The core research logic is simply shifted to the LightningModule:

for i, batch in enumerate(dataloader):
    # x, y = batch                      moved to training_step
    # y_hat = model(x)                  moved to training_step
    # loss = loss_function(y_hat, y)    moved to training_step
    loss = lightning_module.training_step(batch, i)

    # Lightning handles automatically:
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Under the hood, the above loop is implemented using the Loop API like so:

class DefaultLoop(Loop):
    def advance(self, batch, i):
        loss = lightning_module.training_step(batch, i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def run(self, dataloader):
        for i, batch in enumerate(dataloader):
            self.advance(batch, i)

Defining a loop within a class interface instead of hard-coding a raw Python for/while loop has several benefits:

  1. You can have full control over the data flow through loops.

  2. You can add new loops and nest as many of them as you want.

  3. If needed, the state of a loop can be saved and resumed.

  4. New hooks can be injected at any point.
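A pure-Python sketch of these benefits follows. It mimics the shape of the Loop API (a done property, advance(), run()) but is a self-contained toy, not Lightning's Loop class: an outer epoch loop nests an inner batch loop, and the outer loop exposes its progress through state_dict()/load_state_dict() so a run can be interrupted and resumed. All class names and the max_epochs/num_batches parameters are hypothetical.

```python
class ToyLoop:
    """Minimal stand-in for a Loop: run() calls advance() until done is True."""

    @property
    def done(self) -> bool:
        raise NotImplementedError

    def advance(self) -> None:
        raise NotImplementedError

    def run(self) -> None:
        while not self.done:
            self.advance()


class BatchLoop(ToyLoop):
    def __init__(self, num_batches: int, log: list):
        self.num_batches, self.log = num_batches, log
        self.epoch = 0
        self.batch_idx = 0

    @property
    def done(self) -> bool:
        return self.batch_idx >= self.num_batches

    def advance(self) -> None:
        self.log.append((self.epoch, self.batch_idx))  # placeholder for a training step
        self.batch_idx += 1


class EpochLoop(ToyLoop):
    def __init__(self, max_epochs: int, batch_loop: BatchLoop):
        self.max_epochs, self.batch_loop = max_epochs, batch_loop
        self.epoch = 0

    @property
    def done(self) -> bool:
        return self.epoch >= self.max_epochs

    def advance(self) -> None:
        self.batch_loop.epoch = self.epoch
        self.batch_loop.batch_idx = 0  # reset the child loop before each epoch
        self.batch_loop.run()  # nested loop
        self.epoch += 1

    def state_dict(self) -> dict:
        return {"epoch": self.epoch}

    def load_state_dict(self, state: dict) -> None:
        self.epoch = state["epoch"]


log = []
loop = EpochLoop(max_epochs=3, batch_loop=BatchLoop(num_batches=2, log=log))
loop.advance()  # simulate an interruption after the first epoch
saved = loop.state_dict()

resumed = EpochLoop(max_epochs=3, batch_loop=BatchLoop(num_batches=2, log=log))
resumed.load_state_dict(saved)  # continue from epoch 1 instead of starting over
resumed.run()
```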

Animation showing how to convert a standard training loop to a Lightning loop

Overriding the default Loops

The fastest way to get started with loops is to override the functionality of an existing loop. Lightning has three main loop classes: FitLoop for fitting (training and validating), EvaluationLoop for validating or testing, and PredictionLoop for predicting.

For simple changes that don’t require a custom loop, you can modify each of these loops.

Each loop has a series of methods that can be modified. For example, with the FitLoop:

from pytorch_lightning.loops import FitLoop


class MyLoop(FitLoop):
    def advance(self):
        """Advance from one iteration to the next."""

    def on_advance_end(self):
        """Do something at the end of an iteration."""

    def on_run_end(self):
        """Do something when the loop ends."""

A full list with all built-in loops and subloops can be found here.

To add your own modifications to a loop, simply subclass an existing loop class and override what you need. Here is a simple example of how to override a method with your own logic:

from pytorch_lightning.loops import FitLoop


class CustomFitLoop(FitLoop):
    def advance(self):
        """Put your custom logic here."""

Now simply attach the correct loop in the trainer directly:

trainer = Trainer(...)
trainer.fit_loop = CustomFitLoop()

# fit() now uses the new FitLoop!
trainer.fit(...)

# the equivalent for validate()
val_loop = CustomValLoop()
trainer = Trainer()
trainer.validate_loop = val_loop
trainer.validate(...)

Now your code is FULLY flexible and you can still leverage ALL the best parts of Lightning!

Animation showing how to replace a loop on the Trainer

Creating a New Loop From Scratch

You can also go wild and implement a full loop from scratch by sub-classing the Loop base class. You will need to override a minimum of two things:

from pytorch_lightning.loops import Loop


class MyFancyLoop(Loop):
    @property
    def done(self):
        """Provide a condition to stop the loop."""

    def advance(self):
        """
        Access your dataloader/s in whatever way you want.
        Do your fancy optimization things.
        Call the LightningModule methods at your leisure.
        """

Finally, attach it to the Trainer:

trainer = Trainer(...)
trainer.fit_loop = MyFancyLoop()

# fit() now uses your fancy loop!
trainer.fit(...)

But beware: loop customization gives you more power and full control over the Trainer, and with great power comes great responsibility. We recommend that you familiarize yourself with overriding the default loops before you start building a new loop from the ground up.


Loop API

Here is the full API of methods available in the Loop base class.

The Loop class is the base of all loops in the same way as the LightningModule is the base of all models. It defines a public interface that each loop implementation must follow. The key parts are:

Properties
done
Loop.done

Property indicating when the loop is finished.

Example:

@property
def done(self):
    return self.trainer.global_step >= self.trainer.max_steps
Return type

bool

skip (optional)
Loop.skip

Determine whether to return immediately from the call to run().

Example:

@property
def skip(self):
    return len(self.trainer.train_dataloader) == 0
Return type

bool

Methods
reset (optional)
abstract Loop.reset()

Resets the internal state of the loop at the beginning of each call to run.

Example:

def reset(self):
    # reset your internal state or add custom logic
    # if you expect run() to be called multiple times
    self.current_iteration = 0
    self.outputs = []
Return type

None

advance
abstract Loop.advance(*args, **kwargs)

Performs a single step.

Accepts all arguments passed to run.

Example:

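def advance(self, iterator):
    # assumes the iterator yields (batch_idx, batch) pairs, e.g. iter(enumerate(dataloader))
    batch_idx, batch = next(iterator)
    loss = self.trainer.lightning_module.training_step(batch, batch_idx)
    ...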
Return type

None

run (optional)
Loop.run(*args, **kwargs)

The main entry point to the loop.

Repeatedly checks the done condition and calls advance until done evaluates to True.

Override this if you wish to change the default behavior. The default implementation is:

Example:

def run(self, *args, **kwargs):
    if self.skip:
        return self.on_skip()

    self.reset()
    self.on_run_start(*args, **kwargs)

    while not self.done:
        self.advance(*args, **kwargs)

    output = self.on_run_end()
    return output
Return type

TypeVar(T)

Returns

The output of on_run_end (often outputs collected from each step of the loop)
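To see how these properties and methods fit together, here is a self-contained toy that copies the default run() shown above into a small base class and implements skip, reset, done, advance and on_run_end for a loop that sums the items of a list. It does not import Lightning; MiniLoop, SumLoop and the summing task are illustrative assumptions.

```python
class MiniLoop:
    """Toy base class whose run() mirrors the default Loop.run() shown above."""

    @property
    def skip(self) -> bool:
        return False

    def on_skip(self):
        return None

    def reset(self) -> None:
        pass

    def on_run_start(self, *args, **kwargs) -> None:
        pass

    def on_run_end(self):
        return None

    def run(self, *args, **kwargs):
        if self.skip:
            return self.on_skip()

        self.reset()
        self.on_run_start(*args, **kwargs)

        while not self.done:
            self.advance(*args, **kwargs)

        output = self.on_run_end()
        return output


class SumLoop(MiniLoop):
    def __init__(self, data):
        self.data = data

    @property
    def skip(self) -> bool:
        # return immediately when there is nothing to iterate over
        return len(self.data) == 0

    def reset(self) -> None:
        # called at the start of every run(), so the loop can be run repeatedly
        self.index = 0
        self.outputs = []

    @property
    def done(self) -> bool:
        return self.index >= len(self.data)

    def advance(self) -> None:
        # one step: process one item and collect its output
        self.outputs.append(self.data[self.index])
        self.index += 1

    def on_run_end(self):
        return sum(self.outputs)


loop = SumLoop([1, 2, 3])
first = loop.run()
second = loop.run()  # reset() makes the second run start from scratch
empty = SumLoop([]).run()
```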


Subloops

When you want to customize nested loops within loops, use the replace() method:

# This takes care of properly instantiating the new Loop and setting all references
trainer.fit_loop.replace(epoch_loop=MyEpochLoop)
# Trainer runs the fit loop with your new epoch loop!
trainer.fit(model)

Alternatively, for more fine-grained control, use the connect() method:

# Optional: stitch back the trainer arguments
epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps)
# Optional: connect children loops as they might have existing state
epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop)
# Instantiate and connect the loop.
trainer.fit_loop.connect(epoch_loop=epoch_loop)
trainer.fit(model)

More about the built-in loops and how they are composed is explained in the next section.

Animation showing how to connect a custom subloop

Built-in Loops

The training loop in Lightning is called the fit loop and is actually a combination of several loops. Here is what the structure looks like in plain Python:

# FitLoop
for epoch in range(max_epochs):
    # TrainingEpochLoop
    for batch_idx, batch in enumerate(train_dataloader):
        # TrainingBatchLoop
        for split_batch in tbptt_split(batch):
            # OptimizerLoop
            for optimizer_idx, opt in enumerate(optimizers):
                loss = lightning_module.training_step(split_batch, batch_idx, optimizer_idx)
                ...

        # ValidationEpochLoop
        for val_batch_idx, val_batch in enumerate(val_dataloader):
            lightning_module.validation_step(val_batch, val_batch_idx)
            ...

Each of these for-loops represents a class implementing the Loop interface.

Trainer entry points and associated loops

Built-in loop

Description

FitLoop

The FitLoop is the top-level loop where training starts. It simply counts the epochs and iterates from one to the next by calling TrainingEpochLoop.run() in its advance() method.

TrainingEpochLoop

The TrainingEpochLoop iterates over the dataloader that the user returns from their train_dataloader() method. Its main responsibilities are calling the *_epoch_start and *_epoch_end hooks, accumulating outputs if the user requests them in one of these hooks, and running validation at the requested interval. The validation is carried out by yet another loop, ValidationEpochLoop.

In the run() method, the training epoch loop could, in theory, simply call LightningModule.training_step directly and perform the optimization. However, Lightning has built-in support for automatic optimization with multiple optimizers and also supports TBPTT. For this reason, there are actually two more loops nested under TrainingEpochLoop.

TrainingBatchLoop

The responsibility of the TrainingBatchLoop is to split a batch given by the TrainingEpochLoop along the time-dimension and iterate over the list of splits. It also keeps track of the hidden state hiddens returned by the training step. By default, when truncated back-propagation through time (TBPTT) is turned off, this loop does not do anything except redirect the call to the OptimizerLoop. Read more about TBPTT.

OptimizerLoop

The OptimizerLoop iterates over one or multiple optimizers and for each one it calls the training_step() method with the batch, the current batch index and the optimizer index if multiple optimizers are requested. It is the leaf node in the tree of loops and performs the actual optimization (forward, zero grad, backward, optimizer step).

ManualOptimization

Substitutes the OptimizerLoop in case of manual optimization and implements the manual optimization step.

EvaluationLoop

The EvaluationLoop is the top-level loop where validation/testing starts. It simply iterates over each evaluation dataloader from one to the next by calling EvaluationEpochLoop.run() in its advance() method.

PredictionLoop

The PredictionLoop is the top-level loop where prediction starts. It simply iterates over each prediction dataloader from one to the next by calling PredictionEpochLoop.run() in its advance() method.
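The time-dimension splitting that TrainingBatchLoop performs when TBPTT is enabled can be sketched in plain Python. Here a batch is a nested list with shape [batch, time], tbptt_split is a hypothetical helper that cuts every sequence into consecutive chunks of split_size time steps, and a running sum stands in for the hidden state carried between splits. Lightning performs the real splitting itself, and details such as which tensors are split may differ.

```python
def tbptt_split(batch, split_size):
    """Split a [batch, time] nested list into chunks of at most split_size time steps."""
    seq_len = len(batch[0])
    splits = []
    for start in range(0, seq_len, split_size):
        # keep every sample, but only the time steps in [start, start + split_size)
        splits.append([sample[start:start + split_size] for sample in batch])
    return splits


batch = [[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]]  # 2 sequences, 5 time steps each

# toy "hidden state": a running sum per sequence, carried from one split to the next
# (in real TBPTT the hidden tensor is detached between splits, so gradients stop there)
hiddens = [0, 0]
for split_batch in tbptt_split(batch, split_size=2):
    hiddens = [h + sum(chunk) for h, chunk in zip(hiddens, split_batch)]
```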


Available Loops in Lightning Flash

Active Learning is a machine learning practice in which the user interacts with the learner in order to provide new labels when required.

You can find a real use case in Lightning Flash.

Flash implements the ActiveLearningLoop that you can use together with the ActiveLearningDataModule to label new data on the fly. To run the following demo, install Flash and BaaL first:

pip install lightning-flash[image] baal
import torch

import flash
from flash.core.classification import Probabilities
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")

# Implement the research use-case where we mask labels from labelled dataset.
datamodule = ActiveLearningDataModule(
    ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2),
    initial_num_labels=5,
    val_split=0.1,
)

# 2. Build the task
head = torch.nn.Sequential(
    torch.nn.Dropout(p=0.1),
    torch.nn.Linear(512, datamodule.num_classes),
)
model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, output=Probabilities())


# 3.1 Create the trainer
trainer = flash.Trainer(max_epochs=3)

# 3.2 Create the active learning loop and connect it to the trainer
active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1)
active_learning_loop.connect(trainer.fit_loop)
trainer.fit_loop = active_learning_loop

# 3.3 Finetune
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict what's on a few images! ants or bees?
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

Here is the Active Learning Loop example and the code for the active learning loop.


Advanced Examples

Ready-to-run loop examples and tutorials

Link to Example

Description

K-fold Cross Validation

KFold / Cross Validation is a machine learning practice in which the training dataset is partitioned into num_folds complementary subsets. Each cross-validation round fits a model using one fold for validation and the remaining folds for training. To reduce variability, once all rounds have been performed with the different folds, the trained models are ensembled and their predictions are averaged when estimating the model’s predictive performance on the test dataset.

Yielding Training Step

This loop enables you to write the training_step() hook as a Python generator for automatic optimization with multiple optimizers, i.e., you can yield loss values from it instead of returning them. This can enable more elegant and expressive implementations, as shown with a GAN in this example.
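The K-fold procedure described in the first row can be sketched in plain Python. kfold_indices and the mean-predicting "model" are illustrative assumptions standing in for real training; the sketch shows the partition into num_folds complementary subsets, one held-out fold per round, and averaging the per-fold models' predictions.

```python
def kfold_indices(num_samples, num_folds):
    """Yield (train_idx, val_idx) pairs; each sample lands in exactly one validation fold."""
    indices = list(range(num_samples))
    base, remainder = divmod(num_samples, num_folds)
    # the first `remainder` folds get one extra sample
    fold_sizes = [base + (1 if i < remainder else 0) for i in range(num_folds)]
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


def fit_mean_model(targets):
    # hypothetical "model": always predicts the mean of its training targets
    mean = sum(targets) / len(targets)
    return lambda x: mean


targets = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
folds = list(kfold_indices(len(targets), num_folds=3))
models = [fit_mean_model([targets[i] for i in train_idx]) for train_idx, _ in folds]

# ensemble: average the predictions of the per-fold models on a test input
ensemble_prediction = sum(model(None) for model in models) / len(models)
```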


Advanced Features

Next: Advanced loop features

Plugins

Plugins allow custom integrations to the internals of the Trainer such as custom precision, checkpointing or cluster environment implementation.

Under the hood, the Lightning Trainer uses plugins in the training routine; they are added automatically depending on the provided Trainer arguments.

There are three types of Plugins in Lightning with different responsibilities:

  • Precision Plugins

  • CheckpointIO Plugins

  • Cluster Environments

You can make the Trainer use one or multiple plugins by adding them to the plugins argument like so:

trainer = Trainer(plugins=[plugin1, plugin2, ...])

By default, the plugins get selected based on the rest of the Trainer settings such as the strategy.


Precision Plugins

We provide precision plugins so you can benefit from numerical representations with lower precision than 32-bit floating point, or with higher precision, such as 64-bit floating point.

# Training with 16-bit precision
trainer = Trainer(precision=16)
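To see what "lower" and "higher" precision mean numerically, Python's struct module can round-trip a value through the IEEE 754 half (16-bit), single (32-bit) and double (64-bit) formats. This only illustrates the number formats and involves no Lightning plugin; round_trip is a helper defined here for the example.

```python
import struct


def round_trip(value: float, fmt: str) -> float:
    # pack into the given IEEE 754 format and unpack again to see the stored value
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


value = 0.1
half = round_trip(value, "<e")  # 16-bit float
single = round_trip(value, "<f")  # 32-bit float
double = round_trip(value, "<d")  # 64-bit float (Python's own float)

for name, stored in [("16-bit", half), ("32-bit", single), ("64-bit", double)]:
    print(f"{name}: {stored!r} (error {abs(stored - value):.3e})")
```

The representation error shrinks as the bit width grows, which is the trade-off precision plugins let you choose: less memory and faster math at 16 bits versus more accuracy at 64 bits.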

The full list of built-in precision plugins is listed below.

ColossalAIPrecisionPlugin

Precision plugin for ColossalAI integration.

DeepSpeedPrecisionPlugin

Precision plugin for DeepSpeed integration.

DoublePrecisionPlugin

Plugin for training with double (torch.float64) precision.

FullyShardedNativeMixedPrecisionPlugin

Native AMP for Fully Sharded Training.

FullyShardedNativeNativeMixedPrecisionPlugin

Native AMP for Fully Sharded Native Training.

HPUPrecisionPlugin

Plugin that enables bfloat/half support on HPUs.

IPUPrecisionPlugin

Precision plugin for IPU integration.

MixedPrecisionPlugin

Plugin for Automatic Mixed Precision (AMP) training with torch.autocast.

PrecisionPlugin

Base class for all plugins handling the precision-specific parts of the training.

ShardedNativeMixedPrecisionPlugin

Native AMP for Sharded Training.

TPUBf16PrecisionPlugin

Plugin that enables bfloats on TPUs.

TPUPrecisionPlugin

Precision plugin for TPU integration.

More information regarding precision with Lightning can be found here.


CheckpointIO Plugins

As part of our commitment to extensibility, we have abstracted Lightning’s checkpointing logic into the CheckpointIO plugin. With this, you have the ability to customize the checkpointing logic to match the needs of your infrastructure.

Below is a list of built-in plugins for checkpointing.

AsyncCheckpointIO

AsyncCheckpointIO enables saving the checkpoints asynchronously in a thread.

CheckpointIO

Interface to save/load checkpoints as they are saved through the Strategy.

HPUCheckpointIO

CheckpointIO to save checkpoints for HPU training strategies.

TorchCheckpointIO

CheckpointIO that utilizes torch.save() and torch.load() to save and load checkpoints respectively, common for most use cases.

XLACheckpointIO

CheckpointIO that utilizes xm.save() to save checkpoints for TPU training strategies.

Learn more about custom checkpointing with Lightning here.
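To illustrate the idea of separating checkpoint I/O from the rest of training, here is a self-contained sketch. Its method names (save_checkpoint, load_checkpoint, remove_checkpoint) follow the general shape of a CheckpointIO plugin, but the classes are plain Python with pickle standing in for torch.save(). SimpleCheckpointIO, PickleCheckpointIO and ThreadedCheckpointIO are hypothetical, and the threaded wrapper only mimics the idea behind AsyncCheckpointIO.

```python
import os
import pickle
import tempfile
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict


class SimpleCheckpointIO(ABC):
    """Hypothetical interface: how checkpoints are written, read and deleted."""

    @abstractmethod
    def save_checkpoint(self, checkpoint: Dict[str, Any], path: str) -> None: ...

    @abstractmethod
    def load_checkpoint(self, path: str) -> Dict[str, Any]: ...

    @abstractmethod
    def remove_checkpoint(self, path: str) -> None: ...


class PickleCheckpointIO(SimpleCheckpointIO):
    def save_checkpoint(self, checkpoint, path):
        with open(path, "wb") as f:
            pickle.dump(checkpoint, f)

    def load_checkpoint(self, path):
        with open(path, "rb") as f:
            return pickle.load(f)

    def remove_checkpoint(self, path):
        if os.path.exists(path):
            os.remove(path)


class ThreadedCheckpointIO(SimpleCheckpointIO):
    """Wraps another CheckpointIO and performs saves in a background thread."""

    def __init__(self, inner: SimpleCheckpointIO):
        self.inner = inner
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.pending = []

    def save_checkpoint(self, checkpoint, path):
        # training can continue while the write happens in the background;
        # the caller must not mutate `checkpoint` until the save has finished
        self.pending.append(self.executor.submit(self.inner.save_checkpoint, checkpoint, path))

    def load_checkpoint(self, path):
        self.wait()  # make sure earlier writes have finished
        return self.inner.load_checkpoint(path)

    def remove_checkpoint(self, path):
        self.wait()
        self.inner.remove_checkpoint(path)

    def wait(self):
        for future in self.pending:
            future.result()  # re-raises any error from the background save
        self.pending.clear()


ckpt_path = os.path.join(tempfile.mkdtemp(), "last.ckpt")
checkpoint_io = ThreadedCheckpointIO(PickleCheckpointIO())
checkpoint_io.save_checkpoint({"epoch": 3, "state": [1, 2, 3]}, ckpt_path)
loaded = checkpoint_io.load_checkpoint(ckpt_path)
```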


Cluster Environments

You can define the interface of your own cluster environment based on the requirements of your infrastructure.

ClusterEnvironment

Specification of a cluster environment.

KubeflowEnvironment

Environment for distributed training using the PyTorchJob operator from Kubeflow

LightningEnvironment

The default environment used by Lightning for a single node or free cluster (not managed).

LSFEnvironment

An environment for running on clusters managed by the LSF resource manager.

SLURMEnvironment

Cluster environment for training on a cluster managed by SLURM.

TorchElasticEnvironment

Environment for fault-tolerant and elastic training with torchelastic

XLAEnvironment

Cluster environment for training on a TPU Pod with the PyTorch/XLA library.
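The job of a cluster environment is to tell the Trainer how the processes of a job are laid out. The sketch below shows that idea for a launcher that exports the WORLD_SIZE, RANK, LOCAL_RANK and LOCAL_WORLD_SIZE environment variables (as torchrun does), falling back to single-process defaults when they are absent. EnvVarClusterEnvironment and its methods are hypothetical and are not Lightning's ClusterEnvironment class.

```python
import os
from typing import Mapping, Optional


class EnvVarClusterEnvironment:
    """Hypothetical environment that reads the process layout from environment variables."""

    def __init__(self, env: Optional[Mapping[str, str]] = None):
        # defaults to the real process environment; a dict can be passed for testing
        self.env = os.environ if env is None else env

    def world_size(self) -> int:
        return int(self.env.get("WORLD_SIZE", 1))

    def global_rank(self) -> int:
        return int(self.env.get("RANK", 0))

    def local_rank(self) -> int:
        return int(self.env.get("LOCAL_RANK", 0))

    def node_rank(self) -> int:
        # assumes every node runs the same number of processes
        local_world_size = int(self.env.get("LOCAL_WORLD_SIZE", 1))
        return self.global_rank() // local_world_size


# process 3 of a job with 2 nodes x 2 processes, as a launcher might describe it
worker = EnvVarClusterEnvironment(
    {"WORLD_SIZE": "4", "RANK": "3", "LOCAL_RANK": "1", "LOCAL_WORLD_SIZE": "2"}
)
single = EnvVarClusterEnvironment({})  # nothing set: single-process defaults
```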

Deploy models into production (advanced)

Audience: Machine learning engineers optimizing models for enterprise-scale production environments.


Compile your model to ONNX

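ONNX (Open Neural Network Exchange) is an open format for representing machine learning models, and ONNX Runtime, developed by Microsoft, is an engine optimized for running them for inference. Exporting to ONNX makes the model independent of PyTorch, so it can run on any ONNX runtime.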

To export your model to ONNX format, call the to_onnx() method on your LightningModule with the filepath and input_sample.

import torch
from pytorch_lightning import LightningModule


class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))


# create the model
model = SimpleModel()
filepath = "model.onnx"
input_sample = torch.randn((1, 64))
model.to_onnx(filepath, input_sample, export_params=True)

You can also skip passing the input sample if the example_input_array property is specified in your LightningModule.

import torch
from pytorch_lightning import LightningModule


class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)
        # used by to_onnx() when no input_sample is passed
        self.example_input_array = torch.randn(7, 64)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))


# create the model
model = SimpleModel()
filepath = "model.onnx"
model.to_onnx(filepath, export_params=True)

Once you have the exported model, you can run it on your ONNX runtime in the following way:

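import numpy as np
import onnxruntime

ort_session = onnxruntime.InferenceSession(filepath)
input_name = ort_session.get_inputs()[0].name
# the exported model expects float32 inputs, while np.random.randn returns float64
ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}
ort_outs = ort_session.run(None, ort_inputs)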

Validate a Model Is Servable

Production ML Engineers would argue that a model shouldn’t be trained if it can’t be deployed reliably and in a fully automated manner.

To ease the transition from training to production, PyTorch Lightning provides a way for you to validate that a model can be served even before training starts.

To do so, your LightningModule needs to subclass ServableModule and implement its hooks, and you need to pass a ServableModuleValidator callback to the Trainer.

Below you can find an example of how the serving of a resnet18 can be validated.

import base64
from dataclasses import dataclass
from io import BytesIO
from os import path
from typing import Dict, Optional

import numpy as np
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image as PILImage

from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.serve import ServableModule, ServableModuleValidator
from pytorch_lightning.utilities.model_helpers import get_torchvision_model

DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")


class LitModule(LightningModule):
    def __init__(self, name: str = "resnet18"):
        super().__init__()
        self.model = get_torchvision_model(name, weights="DEFAULT")
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, 10)
        self.criterion = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)


class CIFAR10DataModule(LightningDataModule):
    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])

    def train_dataloader(self, *args, **kwargs):
        trainset = torchvision.datasets.CIFAR10(root=DATASETS_PATH, train=True, download=True, transform=self.transform)
        return torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=True, num_workers=0)

    def val_dataloader(self, *args, **kwargs):
        valset = torchvision.datasets.CIFAR10(root=DATASETS_PATH, train=False, download=True, transform=self.transform)
        return torch.utils.data.DataLoader(valset, batch_size=2, shuffle=False, num_workers=0)


@dataclass(unsafe_hash=True)
class Image:
    height: Optional[int] = None
    width: Optional[int] = None
    extension: str = "JPEG"
    mode: str = "RGB"
    channel_first: bool = False

    def deserialize(self, data: str) -> torch.Tensor:
        encoded_with_padding = (data + "===").encode("UTF-8")
        img = base64.b64decode(encoded_with_padding)
        buffer = BytesIO(img)
        img = PILImage.open(buffer, mode="r")
        if self.height and self.width:
            img = img.resize((self.width, self.height))
        arr = np.array(img)
        return T.ToTensor()(arr).unsqueeze(0)


class Top1:
    def serialize(self, tensor: torch.Tensor) -> int:
        return torch.nn.functional.softmax(tensor, dim=-1).argmax().item()


class ProductionReadyModel(LitModule, ServableModule):
    def configure_payload(self):
        # 1: Access the train dataloader and load a single sample.
        image, _ = self.trainer.train_dataloader.loaders.dataset[0]

        # 2: Convert the image into a PIL Image to bytes and encode it with base64
        pil_image = T.ToPILImage()(image)
        buffered = BytesIO()
        pil_image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("UTF-8")

        payload = {"body": {"x": img_str}}
        return payload

    def configure_serialization(self):
        return {"x": Image(224, 224).deserialize}, {"output": Top1().serialize}

    def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {"output": self.model(x)}

    def configure_response(self):
        return {"output": 7}


def cli_main():
    cli = LightningCLI(
        ProductionReadyModel,
        CIFAR10DataModule,
        seed_everything_default=42,
        save_config_kwargs={"overwrite": True},
        run=False,
        trainer_defaults={
            "callbacks": [ServableModuleValidator()],
            "max_epochs": 1,
            "limit_train_batches": 5,
            "limit_val_batches": 5,
        },
    )
    cli.trainer.fit(cli.model, cli.datamodule)


if __name__ == "__main__":
    cli_lightning_logo()
    cli_main()
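The payload round trip between configure_payload and Image.deserialize is plain base64. The stdlib-only sketch below uses hypothetical helpers, encode_payload and decode_payload, to show why deserialize appends "===". Python's default (non-strict) base64 decoder stops at the first complete padding sequence and ignores surplus "=" characters, so a string still decodes when a client has stripped its padding in transit.

```python
import base64


def encode_payload(raw: bytes) -> str:
    # Mirrors configure_payload: raw bytes -> base64 -> UTF-8 string.
    return base64.b64encode(raw).decode("UTF-8")


def decode_payload(data: str) -> bytes:
    # Mirrors Image.deserialize: append extra padding, then decode.
    # The lenient decoder ignores any "=" beyond what is needed.
    return base64.b64decode((data + "===").encode("UTF-8"))


raw = b"jpeg-bytes"  # 10 bytes, so the encoding ends with "=="
encoded = encode_payload(raw)
assert decode_payload(encoded) == raw
# A client that drops the trailing padding still round-trips.
assert decode_payload(encoded.rstrip("=")) == raw
```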

Deploy models into production (basic)

Audience: All users.


Load a checkpoint and predict

The easiest way to use a model for predictions is to load the weights using load_from_checkpoint found in the LightningModule.

model = LitModel.load_from_checkpoint("best_model.ckpt")
model.eval()
x = torch.randn(1, 64)

with torch.no_grad():
    y_hat = model(x)

Predict step with your LightningModule

Loading a checkpoint and predicting still leaves you with a lot of boilerplate around the predict epoch. The predict step in the LightningModule removes this boilerplate.

class MyModel(LightningModule):
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

And pass in any dataloader to the Lightning Trainer:

data_loader = DataLoader(...)
model = MyModel()
trainer = Trainer()
predictions = trainer.predict(model, data_loader)
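Conceptually, trainer.predict replaces a hand-written loop like the one below. This framework-free sketch uses a hypothetical run_predict_loop helper and is not Lightning's implementation. It calls predict_step once per batch and collects the return values in order. The real Trainer also handles device placement, eval mode, disabling gradients, callbacks, and distributed execution.

```python
from typing import Any, Callable, Iterable, List


def run_predict_loop(predict_step: Callable[[Any, int], Any], dataloader: Iterable[Any]) -> List[Any]:
    """Call predict_step(batch, batch_idx) for every batch and collect the outputs."""
    predictions = []
    for batch_idx, batch in enumerate(dataloader):
        predictions.append(predict_step(batch, batch_idx))
    return predictions


# Toy usage: plain lists stand in for batches, and the "model" doubles its inputs.
batches = [[1, 2], [3, 4], [5]]
outputs = run_predict_loop(lambda batch, batch_idx: [2 * x for x in batch], batches)
print(outputs)  # [[2, 4], [6, 8], [10]]
```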

Enable complicated predict logic

When you need to add complicated pre-processing or post-processing logic to your data, use the predict step. For example, here we do Monte Carlo Dropout for predictions:

class LitMCdropoutModel(pl.LightningModule):
    def __init__(self, model, mc_iteration):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout()
        self.mc_iteration = mc_iteration

    def predict_step(self, batch, batch_idx):
        # enable Monte Carlo Dropout
        self.dropout.train()

        # take the average of `self.mc_iteration` stochastic forward passes over the batch
        pred = [self.dropout(self.model(batch)).unsqueeze(0) for _ in range(self.mc_iteration)]
        pred = torch.vstack(pred).mean(dim=0)
        return pred
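The averaging in predict_step can be illustrated without PyTorch. The pure-Python sketch below uses two hypothetical helpers, dropout and mc_dropout_predict. It applies inverted dropout the way nn.Dropout does in training mode: each value is zeroed with probability p, and surviving values are scaled by 1 / (1 - p). The mean over many stochastic passes therefore approaches the deterministic value. The spread across passes, also returned here, is a common uncertainty estimate. That addition goes beyond what the example above computes.

```python
import random
from typing import List, Tuple


def dropout(values: List[float], p: float, rng: random.Random) -> List[float]:
    # Inverted dropout: zero each value with probability p, rescale survivors by 1 / (1 - p).
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in values]


def mc_dropout_predict(
    values: List[float], p: float, mc_iteration: int, seed: int = 0
) -> Tuple[List[float], List[float]]:
    """Average mc_iteration stochastic passes and report the per-element standard deviation."""
    rng = random.Random(seed)
    samples = [dropout(values, p, rng) for _ in range(mc_iteration)]
    columns = list(zip(*samples))
    means = [sum(col) / mc_iteration for col in columns]
    stds = [(sum((x - m) ** 2 for x in col) / mc_iteration) ** 0.5 for col, m in zip(columns, means)]
    return means, stds


means, stds = mc_dropout_predict([1.0, 2.0], p=0.5, mc_iteration=10_000)
print(means)  # each mean is close to the corresponding input value
```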

Enable distributed inference

By using the predict step in Lightning, you get distributed inference for free with the BasePredictionWriter callback.

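import os

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BasePredictionWriter
from pytorch_lightning.demos.boring_classes import BoringModel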


class CustomWriter(BasePredictionWriter):
    def __init__(self, output_dir, write_interval):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # this will create N (num processes) files in `output_dir`, each containing
        # the predictions of its respective rank
        torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))

        # optionally, you can also save `batch_indices` to get the information about the data index
        # from your prediction data
        torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))


# or you can set `write_interval="batch"` and override `write_on_batch_end` to save
# predictions at batch level
pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
model = BoringModel()
trainer.predict(model, return_predictions=False)
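Each rank predicts only on its own shard of the data. Saving batch_indices lets you restore the original order after loading the per-rank files. The pure-Python sketch below assumes the default behavior of torch.utils.data.DistributedSampler with shuffle=False and drop_last=False. In that mode, indices are padded by wrapping around so every rank gets the same count, and each rank then takes a strided slice. shard_indices and merge_predictions are hypothetical helpers. For simplicity, each prediction here corresponds to one dataset index rather than one batch.

```python
import math
from typing import Dict, List, Sequence


def shard_indices(dataset_len: int, rank: int, world_size: int) -> List[int]:
    """Indices one rank sees with DistributedSampler(shuffle=False, drop_last=False)."""
    num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size
    indices = list(range(dataset_len))
    padding = total_size - dataset_len
    if padding:
        # Pad by wrapping around so every rank gets the same number of samples.
        indices += (indices * math.ceil(padding / dataset_len))[:padding]
    # Each rank takes every world_size-th index, starting at its own rank.
    return indices[rank:total_size:world_size]


def merge_predictions(per_rank: Sequence[Dict[str, list]], dataset_len: int) -> list:
    """Reassemble per-rank predictions into dataset order, dropping padded duplicates."""
    merged = {}
    for shard in per_rank:
        for idx, pred in zip(shard["indices"], shard["predictions"]):
            merged.setdefault(idx, pred)  # padded indices repeat earlier samples
    return [merged[i] for i in range(dataset_len)]


# Simulate 3 ranks predicting "x * 10" on a 10-sample dataset.
world_size, n = 3, 10
outputs = []
for rank in range(world_size):
    idx = shard_indices(n, rank, world_size)
    outputs.append({"indices": idx, "predictions": [i * 10 for i in idx]})

print(merge_predictions(outputs, n))  # [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
```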

Pruning and Quantization

Pruning and quantization are techniques that compress a model for deployment, speeding up inference and saving energy without significant loss of accuracy.

Pruning

Warning

Pruning is in beta and subject to change.

Pruning is a technique which focuses on eliminating some of the model weights to reduce the model size and decrease inference requirements.

Pruning has been shown to achieve significant efficiency improvements while minimizing the drop in model performance (prediction quality). Model pruning is recommended for cloud endpoints, deploying models on edge devices, or mobile inference (among others).

To enable pruning during training in Lightning, simply pass in the ModelPruning callback to the Lightning Trainer. PyTorch’s native pruning implementation is used under the hood.

This callback supports multiple pruning functions: pass any torch.nn.utils.prune function name as a string to select which weights to prune (random_unstructured, random_structured, etc.), or implement your own by subclassing BasePruningMethod.

from pytorch_lightning.callbacks import ModelPruning

# set the amount to be the fraction of parameters to prune
trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=0.5)])
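With l1_unstructured, amount=0.5 masks the half of a parameter's entries that have the smallest absolute value. The pure-Python sketch below mirrors that rule for a flat list of weights, rounding amount * numel to get the number of entries to prune. l1_unstructured_mask is a hypothetical helper and not the torch.nn.utils.prune API. PyTorch similarly multiplies the original weight by a 0/1 mask, although it breaks ties between equal magnitudes differently.

```python
from typing import List


def l1_unstructured_mask(weights: List[float], amount: float) -> List[int]:
    """Return a 0/1 mask that prunes the round(amount * n) smallest-magnitude weights."""
    n_prune = round(amount * len(weights))
    # Sort indices by absolute value and mark the first n_prune for pruning.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(weights))]


weights = [0.9, -0.05, 0.3, -0.7, 0.01, 0.2]
mask = l1_unstructured_mask(weights, amount=0.5)
pruned_weights = [w * m for w, m in zip(weights, mask)]  # effective weights after pruning
print(mask)  # [1, 0, 1, 1, 0, 0]
```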

You can also perform iterative pruning, apply the lottery ticket hypothesis, and more!

def compute_amount(epoch):
    # the sum of all returned values needs to be smaller than 1
    if epoch == 10:
        return 0.5

    elif epoch == 50:
        return 0.25

    elif 75 < epoch < 99:
        return 0.01


# the amount can also be a callable
trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=compute_amount)])

Quantization

Warning

Quantization is in beta and subject to change.

Model quantization is another performance optimization technique. It speeds up inference and decreases memory requirements by performing computations and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than full 32-bit floating-point precision. This is particularly beneficial during model deployment.

There are two model quantization methods: Quantization Aware Training (QAT) and Post-training Quantization (PTQ). QAT mimics the effects of quantization during training. The computations are carried out in floating-point precision, but the subsequent quantization effect is taken into account. The weights and activations are quantized into lower precision only for inference, once training is completed. PTQ quantizes an already fine-tuned model without retraining: the weights and activations of ops are converted into lower precision to save memory and computation.

Quantization is useful when it is required to serve large models on machines with limited memory, or when there’s a need to switch between models and reducing the I/O time is important. For example, switching between monolingual speech recognition models across multiple languages.
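The arithmetic behind INT8 quantization is an affine mapping between a float range and an integer range. The pure-Python sketch below uses an asymmetric unsigned 8-bit scheme in the spirit of a min/max observer. The helper names (choose_qparams, quantize, dequantize) are hypothetical, and PyTorch's observers, such as the histogram observer, choose their parameters differently. The quantize-then-dequantize round trip is also what "fake quantization" computes in floating point during QAT.

```python
from typing import List, Tuple

QMIN, QMAX = 0, 255  # unsigned 8-bit integer range


def choose_qparams(values: List[float]) -> Tuple[float, int]:
    # Extend the range to include 0.0 so that zero is exactly representable.
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / (QMAX - QMIN) or 1.0  # guard against an all-zero input
    zero_point = QMIN - round(lo / scale)
    return scale, max(QMIN, min(QMAX, zero_point))


def quantize(x: float, scale: float, zero_point: int) -> int:
    # Map a float to the nearest integer grid point, clamped to the valid range.
    return max(QMIN, min(QMAX, round(x / scale) + zero_point))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    return (q - zero_point) * scale


values = [-1.0, -0.25, 0.0, 0.4, 2.0]
scale, zp = choose_qparams(values)
qs = [quantize(v, scale, zp) for v in values]
recovered = [dequantize(q, scale, zp) for q in qs]
# Every in-range value is recovered to within half a quantization step.
assert all(abs(v - r) <= scale / 2 + 1e-9 for v, r in zip(values, recovered))
print(qs)
```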

Quantization Aware Training

Lightning includes a QuantizationAwareTraining callback (using PyTorch’s native quantization, read more here), which allows creating fully quantized models (compatible with TorchScript).

from pytorch_lightning.callbacks import QuantizationAwareTraining


class RegressionModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_0 = nn.Linear(16, 64)
        self.layer_0a = torch.nn.ReLU()
        self.layer_1 = nn.Linear(64, 64)
        self.layer_1a = torch.nn.ReLU()
        self.layer_end = nn.Linear(64, 1)

    def forward(self, x):
        x = self.layer_0(x)
        x = self.layer_0a(x)
        x = self.layer_1(x)
        x = self.layer_1a(x)
        x = self.layer_end(x)
        return x


trainer = Trainer(callbacks=[QuantizationAwareTraining()])
qmodel = RegressionModel()
trainer.fit(qmodel, ...)

batch = next(iter(my_dataloader()))
qmodel(qmodel.quant(batch[0]))

tsmodel = qmodel.to_torchscript()
tsmodel(tsmodel.quant(batch[0]))

You can further customize the callback:

qcb = QuantizationAwareTraining(
    # specification of quant estimation quality
    observer_type="histogram",
    # specify which layers shall be merged together to increase efficiency
    modules_to_fuse=[(f"layer_{i}", f"layer_{i}a") for i in range(2)],
    # make your model compatible with all original input/outputs, in such case the model is wrapped in a shell with entry/exit layers.
    input_compatible=True,
)

batch = next(iter(my_dataloader()))
qmodel(batch[0])

Post-training Quantization

If you want to quantize a fine-tuned model with PTQ, we recommend the third-party Intel® Neural Compressor API, which provides a convenient tool for accelerating model inference on Intel CPUs and GPUs.

Remote Filesystems

PyTorch Lightning enables working with data from a variety of filesystems, including local filesystems and several cloud storage providers such as S3 on AWS, GCS on Google Cloud, or ADL on Azure.

This applies to saving and loading checkpoints, as well as to logging. To work with a different filesystem, prefix file paths with its protocol, such as “s3://”, when writing and reading data.

# `default_root_dir` is the default path used for logs and checkpoints
trainer = Trainer(default_root_dir="s3://my_bucket/data/")
trainer.fit(model)

You could pass custom paths to loggers for logging data.

from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir="s3://my_bucket/logs/")

trainer = Trainer(logger=logger)
trainer.fit(model)

You can also resume training from a checkpoint stored on a remote filesystem.

trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
trainer.fit(model, ckpt_path="s3://my_bucket/ckpts/classifier.ckpt")

PyTorch Lightning uses fsspec internally to handle all filesystem operations.

The most common filesystems supported by Lightning are:

  • Local filesystem: file:// - It’s the default and doesn’t need any protocol to be used. It’s installed by default in Lightning.

  • Amazon S3: s3:// - Amazon S3 remote binary store, using the library s3fs. Run pip install fsspec[s3] to install it.

  • Google Cloud Storage: gcs:// or gs:// - Google Cloud Storage, using gcsfs. Run pip install fsspec[gcs] to install it.

  • Microsoft Azure Storage: adl://, abfs:// or az:// - Microsoft Azure Storage, using adlfs. Run pip install fsspec[adl] to install it.

  • Hadoop File System: hdfs:// - Hadoop Distributed File System. This uses PyArrow as the backend. Run pip install fsspec[hdfs] to install it.
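To find out which of the optional packages above a given path needs, look at its protocol prefix. The helper below is a hypothetical, stdlib-only sketch built from the list above. It does not call fsspec, which resolves protocols itself through its own registry.

```python
from typing import Optional

# Protocol prefix -> pip extra to install, taken from the list above.
EXTRAS = {
    "s3": "fsspec[s3]",
    "gcs": "fsspec[gcs]",
    "gs": "fsspec[gcs]",
    "adl": "fsspec[adl]",
    "abfs": "fsspec[adl]",
    "az": "fsspec[adl]",
    "hdfs": "fsspec[hdfs]",
}


def protocol_of(path: str) -> str:
    # Paths without "://" are local files, which is the default protocol.
    return path.split("://", 1)[0] if "://" in path else "file"


def required_extra(path: str) -> Optional[str]:
    """Return the fsspec extra to install for `path`, or None for local paths."""
    protocol = protocol_of(path)
    if protocol == "file":
        return None
    if protocol not in EXTRAS:
        raise ValueError(f"unknown protocol: {protocol!r}")
    return EXTRAS[protocol]


print(required_extra("s3://my_bucket/data/"))  # fsspec[s3]
print(required_extra("lightning_logs/"))  # None
```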

You could learn more about the available filesystems with:

from fsspec.registry import known_implementations

print(known_implementations)

You could also look into CheckpointIO Plugin for more details on how to customize saving and loading checkpoints.

What is a Strategy?

A Strategy controls how the Trainer distributes the model during training, evaluation, and prediction. You select one by passing a strategy alias ("ddp", "ddp_spawn", "deepspeed", and so on) or a custom strategy instance to the strategy parameter of the Trainer.

The Strategy in PyTorch Lightning handles the following responsibilities:

  • Launch and teardown of training processes (if applicable).

  • Setup communication between processes (NCCL, GLOO, MPI, and so on).

  • Provide a unified communication interface for reduction, broadcast, and so on.

  • Owns the LightningModule

  • Handles/owns optimizers and schedulers.

Strategy is a composition of one Accelerator, one Precision Plugin, a CheckpointIO plugin and other optional plugins such as the ClusterEnvironment.

Illustration of the Strategy as a composition of the Accelerator and several plugins

We expose Strategies mainly for expert users that want to extend Lightning for new hardware support or new distributed backends (e.g. a backend not yet supported by PyTorch itself).


Selecting a Built-in Strategy

Built-in strategies can be selected in two ways.

  1. Pass the shorthand name to the strategy Trainer argument

  2. Import a Strategy from pytorch_lightning.strategies, instantiate it and pass it to the strategy Trainer argument

The latter allows you to configure further options on the specific strategy. Here are some examples:

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# Training with the DistributedDataParallel strategy on 4 GPUs
trainer = Trainer(strategy="ddp", accelerator="gpu", devices=4)

# Training with the DistributedDataParallel strategy on 4 GPUs, with options configured
trainer = Trainer(strategy=DDPStrategy(find_unused_parameters=False), accelerator="gpu", devices=4)

# Training with the DDP Spawn strategy using auto accelerator selection
trainer = Trainer(strategy="ddp_spawn", accelerator="auto", devices=4)

# Training with the DeepSpeed strategy on available GPUs
trainer = Trainer(strategy="deepspeed", accelerator="gpu", devices="auto")

# Training with the DDP strategy using 3 CPU processes
trainer = Trainer(strategy="ddp", accelerator="cpu", devices=3)

# Training with the DDP Spawn strategy on 8 TPU cores
trainer = Trainer(strategy="ddp_spawn", accelerator="tpu", devices=8)

# Training with the default IPU strategy on 8 IPUs
trainer = Trainer(accelerator="ipu", devices=8)

The below table lists all relevant strategies available in Lightning with their corresponding short-hand name:

Strategy Classes and Nicknames

Name | Class | Description
bagua | BaguaStrategy | Strategy for training using the Bagua library, with advanced distributed training algorithms and system optimizations.
collaborative | HivemindStrategy | Strategy for training collaboratively on local machines or unreliable GPUs across the internet.
colossalai | ColossalAIStrategy | Colossal-AI provides a collection of parallel components. It aims to let you write distributed deep learning models the same way you write a model on your laptop.
fsdp_native | DDPFullyShardedNativeStrategy | Strategy for Fully Sharded Data Parallel training.
ddp_spawn | DDPSpawnStrategy | Spawns processes using the torch.multiprocessing.spawn() method and joins processes after training finishes.
ddp | DDPStrategy | Strategy for multi-process single-device training on one or multiple nodes.
dp | DataParallelStrategy | Implements data-parallel training in a single process, i.e., the model gets replicated to each device and each gets a split of the data.
deepspeed | DeepSpeedStrategy | Provides capabilities to run training using the DeepSpeed library, with training optimizations for large billion-parameter models.
hpu_parallel | HPUParallelStrategy | Strategy for distributed training on multiple HPU devices.
hpu_single | SingleHPUStrategy | Strategy for training on a single HPU device.
ipu_strategy | IPUStrategy | Strategy for training on IPU devices.
tpu_spawn | TPUSpawnStrategy | Strategy for training on multiple TPU devices using the torch_xla.distributed.xla_multiprocessing.spawn() method.
single_tpu | SingleTPUStrategy | Strategy for training on a single TPU device.


Create a Custom Strategy

Every strategy in Lightning is a subclass of one of the main base classes: Strategy, SingleDeviceStrategy or ParallelStrategy.

Strategy base classes

As an expert user, you may choose to extend either an existing built-in Strategy or create a completely new one by subclassing the base classes.

from pytorch_lightning.strategies import DDPStrategy


class CustomDDPStrategy(DDPStrategy):
    def configure_ddp(self):
        self.model = MyCustomDistributedDataParallel(
            self.model,
            device_ids=...,
        )

    def setup(self, trainer):
        # you can access the accelerator and plugins directly
        self.accelerator.setup()
        self.precision_plugin.connect(...)

The custom strategy can then be passed into the Trainer directly via the strategy parameter.

# custom strategy
trainer = Trainer(strategy=CustomDDPStrategy())

Since the strategy also hosts the Accelerator and various plugins, you can customize all of them to work together as you like:

# custom strategy, with new accelerator and plugins
accelerator = MyAccelerator()
precision_plugin = MyPrecisionPlugin()
strategy = CustomDDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin)
trainer = Trainer(strategy=strategy)

Strategy Registry

Warning

The Strategy Registry is experimental and subject to change.

Lightning includes a registry that holds information about training strategies and allows for the registration of new custom strategies.

Strategies are assigned strings that identify them, such as “ddp”, “deepspeed_stage_2_offload”, and so on. The registry also stores the optional description and the parameters for initializing the Strategy that were defined during registration.

# Training with the DDP Strategy with `find_unused_parameters` as False
trainer = Trainer(strategy="ddp_find_unused_parameters_false", accelerator="gpu", devices=4)

# Training with DeepSpeed ZeRO Stage 3 and CPU Offload
trainer = Trainer(strategy="deepspeed_stage_3_offload", accelerator="gpu", devices=3)

# Training with the TPU Spawn Strategy with `debug` as True
trainer = Trainer(strategy="tpu_spawn_debug", accelerator="tpu", devices=8)

Additionally, you can pass your custom registered training strategies to the strategy argument.

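from pathlib import Path
from typing import Any, Dict, Optional, Union

from pytorch_lightning.plugins import CheckpointIO
from pytorch_lightning.strategies import DDPStrategy, StrategyRegistry


class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None
    ) -> None:
        ...

    def load_checkpoint(self, path: Union[str, Path], map_location: Optional[Any] = None) -> Dict[str, Any]:
        ...

    def remove_checkpoint(self, path: Union[str, Path]) -> None:
        ...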


custom_checkpoint_io = CustomCheckpointIO()

# Register the DDP Strategy with your custom CheckpointIO plugin
StrategyRegistry.register(
    "ddp_custom_checkpoint_io",
    DDPStrategy,
    description="DDP Strategy with custom checkpoint io plugin",
    checkpoint_io=custom_checkpoint_io,
)

trainer = Trainer(strategy="ddp_custom_checkpoint_io", accelerator="gpu", devices=2)
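Conceptually, a registry like this maps a name to a class plus the keyword arguments to instantiate it with. The sketch below is a simplified, hypothetical stand-in and not Lightning's StrategyRegistry. MiniRegistry and FakeDDPStrategy are illustrative names. It shows how a name such as "ddp_find_unused_parameters_false" can resolve to a preconfigured instance.

```python
from typing import Any, Dict


class MiniRegistry:
    """Toy registry: name -> (class, description, init kwargs)."""

    def __init__(self) -> None:
        self._entries: Dict[str, Dict[str, Any]] = {}

    def register(self, name: str, cls: type, description: str = "", **init_params: Any) -> None:
        if name in self._entries:
            raise ValueError(f"{name!r} is already registered")
        self._entries[name] = {"cls": cls, "description": description, "init_params": init_params}

    def get(self, name: str) -> Any:
        # Instantiate the registered class with the parameters stored at registration time.
        entry = self._entries[name]
        return entry["cls"](**entry["init_params"])


class FakeDDPStrategy:
    def __init__(self, find_unused_parameters: bool = True) -> None:
        self.find_unused_parameters = find_unused_parameters


registry = MiniRegistry()
registry.register(
    "ddp_find_unused_parameters_false",
    FakeDDPStrategy,
    description="DDP with find_unused_parameters=False",
    find_unused_parameters=False,
)
strategy = registry.get("ddp_find_unused_parameters_false")
print(strategy.find_unused_parameters)  # False
```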

Style Guide

The main goal of PyTorch Lightning is to improve readability and reproducibility. Imagine looking into any GitHub repo or a research project, finding a LightningModule, and knowing exactly where to look to find the things you care about.

The goal of this style guide is to encourage Lightning code to be structured similarly.


LightningModule

These are best practices for structuring your LightningModule class:

Systems vs Models

The main principle behind a LightningModule is that a full system should be self-contained. In Lightning, we differentiate between a system and a model.

A model is something like a resnet18, RNN, and so on.

A system defines how a collection of models interact with each other with user-defined training/evaluation logic. Examples of this are:

  • GANs

  • Seq2Seq

  • BERT

  • etc.

A LightningModule can define both a system and a model:

Here’s a LightningModule that defines a system. This structure is what we recommend as a best practice. Keeping the model separate from the system improves modularity, which makes testing easier, reduces dependencies on the system, and makes refactoring easier.

class Encoder(nn.Module):
    ...


class Decoder(nn.Module):
    ...


class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.encoder(x)


class AutoEncoderSystem(LightningModule):
    def __init__(self):
        super().__init__()
        self.auto_encoder = AutoEncoder()

For fast prototyping, it’s often useful to define all the computations in a LightningModule. For reusability and scalability, it might be better to pass in the relevant backbones.

Here’s a LightningModule that defines a model directly, although we do not recommend defining a model this way:

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        # layer sizes are illustrative
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

Self-contained

A Lightning module should be self-contained. To see how self-contained your model is, a good test is to ask yourself this question:

“Can someone drop this file into a Trainer without knowing anything about the internals?”

For example, we couple the optimizer with a model because the majority of models require a specific optimizer with a specific learning rate scheduler to work well.

Init

The first place where LightningModules tend to stop being self-contained is in the init. Try to define all the relevant sensible defaults in the init so that the user doesn’t have to guess.

Here’s an example where a user will have to go hunt through files to figure out how to init this LightningModule.

class LitModel(LightningModule):
    def __init__(self, params):
        super().__init__()
        self.lr = params.lr
        self.coef_x = params.coef_x

Models defined like this leave you with many questions, such as: what is coef_x? Is it a string? A float? What is the range? Instead, be explicit in your init:

class LitModel(LightningModule):
    def __init__(self, encoder: nn.Module, coef_x: float = 0.2, lr: float = 1e-3):
        ...

Now the user doesn’t have to guess. Instead, they know the value type, and the model has a sensible default where the user can see the value immediately.

Method Order

The only required methods in the LightningModule are:

  • init

  • training_step

  • configure_optimizers

However, if you decide to implement the rest of the optional methods, the recommended order is:

  • model/system definition (init)

  • if doing inference, define forward

  • training hooks

  • validation hooks

  • test hooks

  • predict hooks

  • configure_optimizers

  • any other hooks

In practice, the code looks like this:

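class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        ...

    def training_step(self, batch, batch_idx):
        ...

    def training_step_end(self, step_output):
        ...

    def training_epoch_end(self, outputs):
        ...

    def validation_step(self, batch, batch_idx):
        ...

    def validation_step_end(self, step_output):
        ...

    def validation_epoch_end(self, outputs):
        ...

    def test_step(self, batch, batch_idx):
        ...

    def test_step_end(self, step_output):
        ...

    def test_epoch_end(self, outputs):
        ...

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        ...

    def configure_optimizers(self):
        ...

    def any_extra_hook(self):
        ...

Forward vs training_step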

We recommend using forward() for inference/predictions and keeping training_step() independent.

def forward(self, x):
    embeddings = self.encoder(x)
    return embeddings


def training_step(self, batch, batch_idx):
    x, _ = batch
    z = self.encoder(x)
    pred = self.decoder(z)
    ...

Data

These are best practices for handling data.

DataLoaders

Lightning uses DataLoader to handle all the data flow through the system. Whenever you structure dataloaders, make sure to tune the number of workers for maximum efficiency.

Warning

Make sure not to use Trainer(strategy="ddp_spawn") with num_workers>0 in the DataLoader, or you will bottleneck your code.

DataModules

The LightningDataModule is designed as a way of decoupling data-related hooks from the LightningModule so you can develop dataset agnostic models. It makes it easy to hot swap different datasets with your model, so you can test it and benchmark it across domains. It also makes sharing and reusing the exact data splits and transforms across projects possible.

Check out the Managing Data document to understand data management within Lightning and its best practices. A datamodule keeps the answers to questions like these in one place:

  • What dataset splits were used?

  • How many samples does this dataset have overall and within each split?

  • Which transforms were used?

It’s for this reason that we recommend you use datamodules. This is especially important when collaborating because it will save your team a lot of time as well.

All they need to do is drop a datamodule into the Trainer and not worry about what was done to the data.

This is true for both academic and corporate settings where data cleaning and ad-hoc instructions slow down the progress of iterating through ideas.

Run on an on-prem cluster (advanced)


Run on a SLURM managed cluster

Lightning automates the details behind training on a SLURM-powered cluster. In contrast to the general purpose cluster above, the user does not start the jobs manually on each node. Instead, the user submits the job to SLURM, which schedules the resources and the time for which the job is allowed to run.


Design your training script

To train a model using multiple nodes, do the following:

  1. Design your LightningModule (no need to add anything specific here).

  2. Enable DDP in the trainer

    # train on 32 GPUs across 4 nodes
    trainer = Trainer(accelerator="gpu", devices=8, num_nodes=4, strategy="ddp")
    
  3. It’s a good idea to structure your training script like this:

    # train.py
    def main(args):
        model = YourLightningModule(args)
    
        trainer = Trainer(accelerator="gpu", devices=8, num_nodes=4, strategy="ddp")
    
        trainer.fit(model)
    
    
    if __name__ == "__main__":
        args = ...  # you can use your CLI parser of choice, or the `LightningCLI`
        # TRAIN
        main(args)
    
  4. Create the appropriate SLURM job:

    #!/bin/bash -l
    # (submit.sh)
    
    # SLURM SUBMIT SCRIPT
    #SBATCH --nodes=4             # This needs to match Trainer(num_nodes=...)
    #SBATCH --gres=gpu:8
    #SBATCH --ntasks-per-node=8   # This needs to match Trainer(devices=...)
    #SBATCH --mem=0
    #SBATCH --time=0-02:00:00
    
    # activate conda env
    source activate $1
    
    # debugging flags (optional)
    export NCCL_DEBUG=INFO
    export PYTHONFAULTHANDLER=1
    
    # on your cluster you might need these:
    # set the network interface
    # export NCCL_SOCKET_IFNAME=^docker0,lo
    
    # might need the latest CUDA
    # module load NCCL/2.4.7-1-cuda.10.0
    
    # run script from above
    srun python3 train.py
    
  5. If you want auto-resubmit (read below), add this line to the submit.sh script

    #SBATCH --signal=SIGUSR1@90
    
  6. Submit the SLURM job

    sbatch submit.sh
    

Enable auto wall-time resubmissions

When you use Lightning in a SLURM cluster, it automatically detects when it is about to run into the wall time and does the following:

  1. Saves a temporary checkpoint.

  2. Requeues the job.

  3. When the job starts, it loads the temporary checkpoint.

To get this behavior, make sure to add the correct signal to your SLURM script:

# 90 seconds before training ends
#SBATCH --signal=SIGUSR1@90

You can change this signal if your environment requires the use of a different one, for example:

#SBATCH --signal=SIGHUP@90

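Then, when you create your Trainer, pass the same signal as the requeue_signal option of the SLURMEnvironment plugin:

import signal

from pytorch_lightning.plugins.environments import SLURMEnvironment

trainer = Trainer(plugins=[SLURMEnvironment(requeue_signal=signal.SIGHUP)])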

If auto-resubmit is not desired, it can be turned off in the SLURMEnvironment plugin:

from pytorch_lightning.plugins.environments import SLURMEnvironment

trainer = Trainer(plugins=[SLURMEnvironment(auto_requeue=False)])
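The requeue mechanism builds on ordinary Unix signal handling. SLURM sends the configured signal shortly before the wall time, and a handler in the training process reacts to it. The stdlib-only sketch below is POSIX only and illustrates just that mechanism. The RequeueFlag class and the loop are hypothetical and much simpler than what SLURMEnvironment and the Trainer actually do. The handler only records that the signal arrived, and the training loop checks the flag at a safe point to save a checkpoint and stop.

```python
import signal


class RequeueFlag:
    """Signal handler object that records whether a requeue signal arrived."""

    def __init__(self) -> None:
        self.requested = False

    def __call__(self, signum, frame) -> None:
        # Keep the handler minimal: only record that the signal arrived.
        self.requested = True


flag = RequeueFlag()
signal.signal(signal.SIGUSR1, flag)  # SIGUSR1 is not available on Windows

saved_at_step = None
for step in range(1_000):
    if step == 3:
        # Simulate SLURM delivering the signal shortly before the wall time.
        signal.raise_signal(signal.SIGUSR1)
    if flag.requested:
        # A real job would save a checkpoint here, then requeue itself
        # (for example with `scontrol requeue $SLURM_JOB_ID`) and exit.
        saved_at_step = step
        break

signal.signal(signal.SIGUSR1, signal.SIG_DFL)  # restore the default handler
print(saved_at_step)
```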

Build your SLURM script

Instead of writing SLURM scripts by hand, you can use the SlurmCluster object from the test_tube library to generate them for you. The SlurmCluster can also run a grid search if you pass in a HyperOptArgumentParser.

Here is an example where you run a grid search of 9 combinations of hyperparameters. See also the multi-node examples here.

from test_tube import HyperOptArgumentParser
from test_tube.hpc import SlurmCluster

# grid search 3 values of learning rate and 3 values of number of layers for your net
# this generates 9 experiments (lr=1e-3, layers=16), (lr=1e-3, layers=32),
# (lr=1e-3, layers=64), ... (lr=1e-1, layers=64)
parser = HyperOptArgumentParser(strategy="grid_search", add_help=False)
parser.opt_list("--learning_rate", default=0.001, type=float, options=[1e-3, 1e-2, 1e-1], tunable=True)
parser.opt_list("--layers", default=1, type=int, options=[16, 32, 64], tunable=True)
hyperparams = parser.parse_args()

# Slurm cluster submits 9 jobs, each with a set of hyperparams
cluster = SlurmCluster(
    hyperparam_optimizer=hyperparams,
    log_path="/some/path/to/save",
)

# OPTIONAL FLAGS WHICH MAY BE CLUSTER DEPENDENT
# which interface your nodes use for communication
cluster.add_command("export NCCL_SOCKET_IFNAME=^docker0,lo")

# see the output of the NCCL connection process
# NCCL is how the nodes talk to each other
cluster.add_command("export NCCL_DEBUG=INFO")

# setting a main port here is a good idea.
cluster.add_command("export MASTER_PORT=%r" % PORT)

# ************** DON'T FORGET THIS ***************
# MUST load the latest NCCL version
cluster.load_modules(["NCCL/2.4.7-1-cuda.10.0"])

# configure cluster
cluster.per_experiment_nb_nodes = 12
cluster.per_experiment_nb_gpus = 8

cluster.add_slurm_cmd(cmd="ntasks-per-node", value=8, comment="1 task per gpu")

# submit a script with 9 combinations of hyper params
# (lr=1e-3, layers=16), (lr=1e-3, layers=32), (lr=1e-3, layers=64), ... (lr=1e-1, layers=64)
cluster.optimize_parallel_cluster_gpu(
    main, nb_trials=9, job_name="name_for_squeue"  # how many permutations of the grid search to run
)
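The 9 experiments are the Cartesian product of the two option lists. This stdlib-only sketch shows that enumeration, in the same order as the comments above. It is not test_tube's implementation.

```python
from itertools import product

learning_rates = [1e-3, 1e-2, 1e-1]
layers = [16, 32, 64]

# Each combination would become one SLURM job.
trials = [{"learning_rate": lr, "layers": n} for lr, n in product(learning_rates, layers)]

print(len(trials))  # 9
print(trials[0])  # {'learning_rate': 0.001, 'layers': 16}
print(trials[-1])  # {'learning_rate': 0.1, 'layers': 64}
```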

Troubleshooting

The Trainer is stuck initializing at startup, what is causing this?

You are seeing a message like this in the logs but nothing happens:

Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4

The most likely reasons and how to fix it:

  • You forgot to run the python train.py command with srun: Please have a look at the SLURM template script above, which includes the srun command at the bottom of the script.

  • The number of nodes or number of devices per node is configured incorrectly: There are two parameters in the SLURM submission script that determine how many processes will run your training: the #SBATCH --nodes=X and #SBATCH --ntasks-per-node=Y settings. These numbers must match what is configured in the Trainer in your code: Trainer(num_nodes=X, devices=Y). If you change the numbers, update them in BOTH places.
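To catch the second problem early, you could compare the SLURM allocation with your Trainer arguments at the top of your training script. The sketch below assumes the job was submitted with both --nodes and --ntasks-per-node, in which case SLURM exports SLURM_NNODES and SLURM_NTASKS_PER_NODE to the job environment. The helper name check_slurm_matches_trainer is hypothetical, not part of Lightning.

```python
import os
from typing import List, Mapping, Optional


def check_slurm_matches_trainer(
    num_nodes: int, devices: int, env: Optional[Mapping[str, str]] = None
) -> List[str]:
    """Hypothetical helper: list mismatches between the SLURM allocation and Trainer(num_nodes, devices).

    Assumes the job was submitted with --nodes and --ntasks-per-node, so SLURM
    exports SLURM_NNODES and SLURM_NTASKS_PER_NODE. Missing variables are skipped.
    """
    env = os.environ if env is None else env
    problems = []
    slurm_nodes = env.get("SLURM_NNODES")
    slurm_tasks = env.get("SLURM_NTASKS_PER_NODE")
    if slurm_nodes is not None and int(slurm_nodes) != num_nodes:
        problems.append(f"#SBATCH --nodes={slurm_nodes} but Trainer(num_nodes={num_nodes})")
    if slurm_tasks is not None and int(slurm_tasks) != devices:
        problems.append(f"#SBATCH --ntasks-per-node={slurm_tasks} but Trainer(devices={devices})")
    return problems


# Usage: run this before constructing the Trainer (prints nothing outside a SLURM job)
for problem in check_slurm_matches_trainer(num_nodes=4, devices=8):
    print("Config mismatch:", problem)
```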

Transfer Learning

Audience: Users looking to use pretrained models with Lightning.


Use any PyTorch nn.Module

Any model that is a PyTorch nn.Module can be used with Lightning (because LightningModules are nn.Modules also).


Use a pretrained LightningModule

Let’s use the AutoEncoder as a feature extractor in a separate model.

import torch
from torch import nn
from pytorch_lightning import LightningModule


class Encoder(torch.nn.Module):
    ...


class Decoder(torch.nn.Module):
    ...


class AutoEncoder(LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()


class CIFAR10Classifier(LightningModule):
    def __init__(self):
        super().__init__()
        # init the pretrained LightningModule
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
        self.feature_extractor.freeze()

        # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...

We used our pretrained Autoencoder (a LightningModule) for transfer learning!
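In plain PyTorch terms, freezing roughly amounts to disabling gradients for every parameter of the feature extractor and switching it to eval mode, so that only the new head is trained. The sketch below is a hypothetical stand-in: a tiny linear "encoder" replaces the real checkpoint so the example runs without any files, and the class name FrozenFeatureClassifier is illustrative only.

```python
import torch
from torch import nn


class FrozenFeatureClassifier(nn.Module):
    """Hypothetical stand-in: a frozen feature extractor followed by a trainable head."""

    def __init__(self, feature_extractor: nn.Module, feature_dim: int, num_classes: int):
        super().__init__()
        self.feature_extractor = feature_extractor
        # Roughly what freezing does: no gradients and eval mode for the extractor
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        self.feature_extractor.eval()
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        representations = self.feature_extractor(x)
        return self.classifier(representations)


# Stand-in "pretrained" encoder: CIFAR-sized image (3x32x32) -> 100-dim representation
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 100))
model = FrozenFeatureClassifier(encoder, feature_dim=100, num_classes=10)

# Only parameters that still require gradients (the classifier) go to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01)

logits = model(torch.randn(4, 3, 32, 32))
print(logits.shape)  # torch.Size([4, 10])
```

Note that calling model.train() later (as a training loop usually does) switches the extractor back to train mode as well; this is why the ImageNet example below calls eval() on the feature extractor inside forward.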


Example: Imagenet (Computer Vision)

import torch
import torchvision.models as models
from torch import nn
from pytorch_lightning import LightningModule


class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()

        # init a pretrained resnet
        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        ...

Finetune

from pytorch_lightning import Trainer

model = ImagenetTransferLearning()
trainer = Trainer()
trainer.fit(model)

And use it to predict your data of interest

model = ImagenetTransferLearning.load_from_checkpoint(PATH)
model.freeze()

x = some_images_from_cifar10()
predictions = model(x)

We used a model pretrained on ImageNet and finetuned it on CIFAR-10 to make predictions on CIFAR-10. Outside academia, you would typically finetune on your own (often small) dataset and predict on data of that same kind.


Example: BERT (NLP)

Lightning is completely agnostic to what’s used for transfer learning so long as it is a torch.nn.Module subclass.

Here’s a model that uses Hugging Face Transformers.

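from torch import nn
from pytorch_lightning import LightningModule
from transformers import BertModel


class BertMNLIFinetuner(LightningModule):
    def __init__(self):
        super().__init__()

        self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
        self.W = nn.Linear(self.bert.config.hidden_size, 3)
        self.num_classes = 3

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(
            input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=True
        )
        h = outputs.last_hidden_state  # (batch, seq_len, hidden_size)
        attn = outputs.attentions  # one attention tensor per layer

        h_cls = h[:, 0]  # hidden state of the [CLS] token
        logits = self.W(h_cls)
        return logits, attn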

Run on an on-prem cluster (intermediate)

Run with TorchDistributed

Torch Distributed Run provides helper functions to set up the distributed environment variables that the PyTorch distributed communication package needs on each node.

Once the script is set up as described in Training Script Setup, you can run the command below on each of your nodes to start multi-node training.

As with a custom cluster, you have to ensure that the nodes have network connectivity, with firewall rules that allow traffic on the specified MASTER_PORT.

Finally, you’ll need to decide which node you’d like to be the main node (MASTER_ADDR), and the ranks of each node (NODE_RANK).

For example:

  • MASTER_ADDR 10.10.10.16

  • MASTER_PORT 29500

  • NODE_RANK 0 for the first node, 1 for the second node

Run the below command with the appropriate variables set on each node.

# --nnodes is the number of nodes you'd like to run with
python -m torch.distributed.run \
    --nnodes=2 \
    --master_addr <MASTER_ADDR> \
    --master_port <MASTER_PORT> \
    --node_rank <NODE_RANK> \
    train.py (--arg1 ... train script args...)

Note

torch.distributed.run assumes that you’d like to spawn a process per GPU if GPU devices are found on the node. This can be adjusted with --nproc_per_node.
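Each process started by the launcher receives environment variables such as RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT. Assuming every node launches the same number of processes, the global rank typically works out to node_rank * nproc_per_node + local_rank. The small sketch below illustrates that arithmetic; the helper names are hypothetical.

```python
import os


def global_rank(node_rank: int, local_rank: int, nproc_per_node: int) -> int:
    # Assumes every node launches the same number of processes
    return node_rank * nproc_per_node + local_rank


def world_size(nnodes: int, nproc_per_node: int) -> int:
    # Total number of processes across all nodes
    return nnodes * nproc_per_node


# Inside train.py, the launcher's variables can be read directly (defaults for a plain run)
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
size = int(os.environ.get("WORLD_SIZE", 1))
print(f"rank={rank} local_rank={local_rank} world_size={size}")

# Example: 2 nodes x 4 processes each -> ranks 0-3 on node 0 and 4-7 on node 1
for node in range(2):
    print(node, [global_rank(node, lr, 4) for lr in range(4)])
```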

Tutorial 1: Introduction to PyTorch

  • Author: Phillip Lippe

  • License: CC BY-SA

  • Generated: 2022-05-12T13:44:14.531736

This tutorial will give a short introduction to PyTorch basics, and get you set up for writing your own neural networks. This notebook is part of a lecture series on Deep Learning at the University of Amsterdam. The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.



Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "torch>=1.8" "pytorch-lightning>=1.4" "ipython[notebook]" "setuptools==59.5.0" "torchmetrics>=0.7" "matplotlib"

Welcome to our PyTorch tutorial for the Deep Learning course 2020 at the University of Amsterdam! The following notebook is meant to give a short introduction to PyTorch basics, and get you set up for writing your own neural networks. PyTorch is an open source machine learning framework that allows you to write your own neural networks and optimize them efficiently. However, PyTorch is not the only framework of its kind. Alternatives to PyTorch include TensorFlow, JAX and Caffe. We chose to teach PyTorch at the University of Amsterdam because it is well established, has a huge developer community (it was originally developed by Facebook), is very flexible, and is especially popular in research. Many current papers publish their code in PyTorch, so it is good to be familiar with it. Meanwhile, TensorFlow (developed by Google) is usually known for being a production-grade deep learning library. Still, if you know one machine learning framework in depth, it is very easy to learn another one because many of them use the same concepts and ideas. For instance, TensorFlow’s version 2 was heavily inspired by the most popular features of PyTorch, making the frameworks even more similar. If you are already familiar with PyTorch and have created your own neural network projects, feel free to just skim this notebook.

We are of course not the first ones to create a PyTorch tutorial. There are many great tutorials online, including the “60-min blitz” on the official PyTorch website. Still, we chose to create our own tutorial, designed to give you the basics you particularly need for the practicals while still explaining how PyTorch works under the hood. Over the next few weeks, we will also keep exploring new PyTorch features in the series of Jupyter notebook tutorials about deep learning.

We will use a set of standard libraries that are often used in machine learning projects. If you are running this notebook on Google Colab, all libraries should be pre-installed. If you are running this notebook locally, make sure you have installed and activated our dl2020 environment.

[2]:
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data

%matplotlib inline
from matplotlib.colors import to_rgba
from matplotlib_inline.backend_inline import set_matplotlib_formats
from torch import Tensor
from tqdm.notebook import tqdm  # Progress bar

set_matplotlib_formats("svg", "pdf")

The Basics of PyTorch

We will start by reviewing the very basic concepts of PyTorch. As a prerequisite, we recommend being familiar with the numpy package, as most machine learning frameworks are based on very similar concepts. If you are not familiar with numpy yet, don’t worry: here is a tutorial to go through.

So, let’s start with importing PyTorch. The package is called torch, based on its original framework Torch. As a first step, we can check its version:

[3]:
print("Using torch", torch.__version__)
Using torch 1.9.1+cu111

At the time of writing this tutorial (mid-August 2021), the current stable version is 1.9. You should therefore see the output Using torch 1.9.0, possibly with an extension for the CUDA version on Colab. If you use the dl2020 environment, you should see Using torch 1.6.0, since the environment was provided in October 2020. We recommend updating PyTorch to the newest version. If you see a version number lower than 1.6, make sure you have installed the correct environment, or ask one of your TAs. If PyTorch 1.10 or newer is released during the course, don’t worry. The interface between PyTorch versions doesn’t change much, so all code should also run with newer versions.

As in every machine learning framework, PyTorch provides stochastic functions, such as random number generators. However, it is very good practice to set up your code to be reproducible, producing the exact same random numbers on every run. This is why we set a seed below.

[4]:
torch.manual_seed(42)  # Setting the seed
[4]:
<torch._C.Generator at 0x7fb4ad211bb0>
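To see what the seed buys you, the short sketch below draws random numbers twice after resetting the seed and once without resetting it. Resetting replays the same sequence; continuing without a reset gives new values because the generator state has advanced.

```python
import torch

torch.manual_seed(42)
a = torch.rand(3)

torch.manual_seed(42)  # Resetting the seed replays the same random sequence
b = torch.rand(3)

c = torch.rand(3)  # No reset: the generator has advanced, so the values differ

print(torch.equal(a, b))  # True
print(torch.equal(b, c))
```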
Tensors

Tensors are the PyTorch equivalent of Numpy arrays, with the addition of support for GPU acceleration (more on that later). The name “tensor” is a generalization of concepts you already know. For instance, a vector is a 1-D tensor, and a matrix a 2-D tensor. When working with neural networks, we will use tensors of various shapes and numbers of dimensions.

Most common functions you know from numpy can be used on tensors as well. In fact, since numpy arrays are so similar to tensors, we can convert most tensors to numpy arrays (and back), but we don’t need to do so very often.

Initialization

Let’s first look at different ways of creating a tensor. There are many possible options; the simplest one is to call Tensor, passing the desired shape as input arguments:

[5]:
x = Tensor(2, 3, 4)
print(x)
tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]])

The function torch.Tensor allocates memory for the desired tensor but does not initialize it, so the tensor may contain whatever values were already in that memory (here they happen to be zeros). To directly assign values to the tensor during initialization, there are many alternatives, including:

  • torch.zeros: Creates a tensor filled with zeros

  • torch.ones: Creates a tensor filled with ones

  • torch.rand: Creates a tensor with random values uniformly sampled between 0 and 1

  • torch.randn: Creates a tensor with random values sampled from a normal distribution with mean 0 and variance 1

  • torch.arange: Creates a tensor containing the values N,N+1,N+2,...,M-1 (the end value M is excluded)

  • torch.Tensor (input list): Creates a tensor from the list elements you provide
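A quick sketch of the initializers listed above. It also uses torch.tensor (lowercase), which infers the dtype from the data, whereas Tensor(...) always produces float32; the exact random values will of course differ from run to run unless you seed.

```python
import torch

zeros = torch.zeros(2, 3)  # shape [2, 3], all 0.0
ones = torch.ones(2, 3)  # shape [2, 3], all 1.0
uniform = torch.rand(2, 3)  # values sampled uniformly from [0, 1)
normal = torch.randn(1000)  # sample mean close to 0, sample variance close to 1
steps = torch.arange(2, 6)  # tensor([2, 3, 4, 5]): the end value 6 is excluded
from_list = torch.tensor([[1, 2], [3, 4]])  # dtype inferred from the data

print(steps)  # tensor([2, 3, 4, 5])
print(from_list.dtype)  # torch.int64
print(Tensor([[1, 2], [3, 4]]).dtype if "Tensor" in globals() else torch.Tensor([[1, 2]]).dtype)  # torch.float32
```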

[6]:
# Create a tensor from a (nested) list
x = Tensor([[1, 2], [3, 4]])
print(x)
tensor([[1., 2.],
        [3., 4.]])
[7]:
# Create a tensor with random values between 0 and 1 with the shape [2, 3, 4]
x = torch.rand(2, 3, 4)
print(x)
tensor([[[0.8823, 0.9150, 0.3829, 0.9593],
         [0.3904, 0.6009, 0.2566, 0.7936],
         [0.9408, 0.1332, 0.9346, 0.5936]],

        [[0.8694, 0.5677, 0.7411, 0.4294],
         [0.8854, 0.5739, 0.2666, 0.6274],
         [0.2696, 0.4414, 0.2969, 0.8317]]])

You can obtain the shape of a tensor in the same way as in numpy (x.shape), or using the .size method:

[8]:
shape = x.shape
print("Shape:", shape)

size = x.size()
print("Size:", size)

dim1, dim2, dim3 = x.size()
print("Size:", dim1, dim2, dim3)
Shape: torch.Size([2, 3, 4])
Size: torch.Size([2, 3, 4])
Size: 2 3 4
Tensor to Numpy, and Numpy to Tensor

Tensors can be converted to numpy arrays, and numpy arrays back to tensors. To transform a numpy array into a tensor, we can use the function torch.from_numpy:

[9]:
np_arr = np.array([[1, 2], [3, 4]])
tensor = torch.from_numpy(np_arr)

print("Numpy array:", np_arr)
print("PyTorch tensor:", tensor)
Numpy array: [[1 2]
 [3 4]]
PyTorch tensor: tensor([[1, 2],
        [3, 4]])

To transform a PyTorch tensor back to a numpy array, we can use the function .numpy() on tensors:

[10]:
tensor = torch.arange(4)
np_arr = tensor.numpy()

print("PyTorch tensor:", tensor)
print("Numpy array:", np_arr)
PyTorch tensor: tensor([0, 1, 2, 3])
Numpy array: [0 1 2 3]

Converting a tensor to numpy requires the tensor to be on the CPU, not the GPU (more on GPU support in a later section). If you have a tensor on the GPU, you need to call .cpu() on the tensor beforehand, which gives you a line like np_arr = tensor.cpu().numpy().
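Because .cpu() simply returns the tensor when it already lives on the CPU, the pattern below works unchanged on machines with and without a GPU. Tensors that require gradients additionally need .detach() before .numpy(), otherwise PyTorch raises an error.

```python
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tensor = torch.arange(4, device=device)

# .cpu() is a no-op for CPU tensors, so this line works on both CPU and GPU
np_arr = tensor.cpu().numpy()
print(type(np_arr), np_arr)

# Tensors that track gradients must be detached from the graph first
w = torch.ones(2, requires_grad=True)
w_np = w.detach().cpu().numpy()
print(w_np)
```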

Operations

Most operations that exist in numpy also exist in PyTorch. A full list of operations can be found in the PyTorch documentation, but we will review the most important ones here.

The simplest operation is to add two tensors:

[11]:
x1 = torch.rand(2, 3)
x2 = torch.rand(2, 3)
y = x1 + x2

print("X1", x1)
print("X2", x2)
print("Y", y)
X1 tensor([[0.1053, 0.2695, 0.3588],
        [0.1994, 0.5472, 0.0062]])
X2 tensor([[0.9516, 0.0753, 0.8860],
        [0.5832, 0.3376, 0.8090]])
Y tensor([[1.0569, 0.3448, 1.2448],
        [0.7826, 0.8848, 0.8151]])

Calling x1 + x2 creates a new tensor containing the sum of the two inputs. However, we can also use in-place operations that are applied directly to the memory of a tensor. This overwrites the values of x2, so its original values can no longer be accessed after the operation. An example is shown below:

[12]:
x1 = torch.rand(2, 3)
x2 = torch.rand(2, 3)
print("X1 (before)", x1)
print("X2 (before)", x2)

x2.add_(x1)
print("X1 (after)", x1)
print("X2 (after)", x2)
X1 (before) tensor([[0.5779, 0.9040, 0.5547],
        [0.3423, 0.6343, 0.3644]])
X2 (before) tensor([[0.7104, 0.9464, 0.7890],
        [0.2814, 0.7886, 0.5895]])
X1 (after) tensor([[0.5779, 0.9040, 0.5547],
        [0.3423, 0.6343, 0.3644]])
X2 (after) tensor([[1.2884, 1.8504, 1.3437],
        [0.6237, 1.4230, 0.9539]])

In-place operations are usually marked with an underscore suffix (e.g. “add_” instead of “add”).
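The sketch below contrasts the two variants: add returns a new tensor and leaves its input untouched, while add_ writes into the existing memory and returns a tensor that shares that memory (checked here via data_ptr()).

```python
import torch

x = torch.ones(3)
y = torch.full((3,), 2.0)

z = x.add(y)  # out-of-place: a new tensor, x keeps its values
unchanged = x.clone()  # snapshot of x after the out-of-place add
print(unchanged, z)

same = x.add_(y)  # in-place: x itself now holds the sum
print(x)
print(same.data_ptr() == x.data_ptr())  # True: no new memory was allocated
```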

Another common operation aims at changing the shape of a tensor. A tensor of size (2,3) can be re-organized to any other shape with the same number of elements (e.g. a tensor of size (6), or (3,2), …). In PyTorch, this operation is called view:

[13]:
x = torch.arange(6)
print("X", x)
X tensor([0, 1, 2, 3, 4, 5])
[14]:
x = x.view(2, 3)
print("X", x)
X tensor([[0, 1, 2],
        [3, 4, 5]])
[15]:
x = x.permute(1, 0)  # Swapping dimension 0 and 1
print("X", x)
X tensor([[0, 3],
        [1, 4],
        [2, 5]])
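One subtlety when combining the two operations above: permute does not move any data, it only changes how the existing memory is traversed, so the result is no longer contiguous and view can fail on it. A minimal sketch of the problem and two common fixes, reshape (which copies when necessary) and contiguous().view(...):

```python
import torch

x = torch.arange(6).view(2, 3)
xt = x.permute(1, 0)  # shape [3, 2], shares memory with x but is not contiguous
print(xt.is_contiguous())  # False

view_failed = False
try:
    xt.view(6)  # view needs a memory layout compatible with the new shape
except RuntimeError:
    view_failed = True
print("view failed:", view_failed)

flat = xt.reshape(6)  # reshape falls back to a copy when a view is impossible
flat2 = xt.contiguous().view(6)  # explicit copy into contiguous memory, then view
print(flat)  # tensor([0, 3, 1, 4, 2, 5])
```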

Other commonly used operations include matrix multiplications, which are essential for neural networks. Quite often, we have an input vector \mathbf{x}, which is transformed using a learned weight matrix \mathbf{W}. There are multiple ways and functions to perform matrix multiplication, some of which we list below:

  • torch.matmul: Performs the matrix product over two tensors, where the specific behavior depends on the dimensions. If both inputs are matrices (2-dimensional tensors), it performs the standard matrix product. For higher dimensional inputs, the function supports broadcasting (for details see the documentation). Can also be written as a @ b, similar to numpy.

  • torch.mm: Performs the matrix product over two matrices, but doesn’t support broadcasting (see documentation)

  • torch.bmm: Performs the matrix product with support for a batch dimension. If the first tensor T is of shape (b\times n\times m), and the second tensor R is of shape (b\times m\times p), the output O is of shape (b\times n\times p), and is calculated by performing b matrix multiplications of the submatrices of T and R: O_i = T_i @ R_i

  • torch.einsum: Performs matrix multiplications and more (i.e. sums of products) using the Einstein summation convention. Explanation of the Einstein sum can be found in assignment 1.
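To make the relationship between these functions concrete, the sketch below computes the same batched product four ways: torch.bmm, torch.matmul (via the @ operator, broadcasting over the batch dimension), torch.einsum, and an explicit loop implementing O_i = T_i @ R_i. The shapes are arbitrary illustrative choices.

```python
import torch

b, n, m, p = 4, 2, 3, 5
T = torch.randn(b, n, m)
R = torch.randn(b, m, p)

out_bmm = torch.bmm(T, R)  # (b, n, p)
out_matmul = T @ R  # matmul treats the leading dimension as a batch
out_einsum = torch.einsum("bnm,bmp->bnp", T, R)  # sum over the shared index m
out_loop = torch.stack([T[i] @ R[i] for i in range(b)])  # O_i = T_i @ R_i

print(out_bmm.shape)  # torch.Size([4, 2, 5])
print(torch.allclose(out_bmm, out_matmul, atol=1e-6))
print(torch.allclose(out_bmm, out_einsum, atol=1e-6))
```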

Usually, we use torch.matmul or torch.bmm. We can try a matrix multiplication with torch.matmul below.

[16]:
x = torch.arange(6)
x = x.view(2, 3)
print("X", x)
X tensor([[0, 1, 2],
        [3, 4, 5]])
[17]:
W = torch.arange(9).view(3, 3)  # We can also stack multiple operations in a single line
print("W", W)
W tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
[18]:
h = torch.matmul(x, W)  # Verify the result by calculating it by hand too!
print("h", h)
h tensor([[15, 18, 21],
        [42, 54, 66]])
Indexing

We often have the situation where we need to select a part of a tensor. Indexing works just like in numpy, so let’s try it:

[19]:
x = torch.arange(12).view(3, 4)
print("X", x)
X tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
[20]:
print(x[:, 1])  # Second column
tensor([1, 5, 9])
[21]:
print(x[0])  # First row
tensor([0, 1, 2, 3])
[22]:
print(x[:2, -1])  # First two rows, last column
tensor([3, 7])
[23]:
print(x[1:3, :])  # Middle two rows
tensor([[ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
Dynamic Computation Graph and Backpropagation

One of the main reasons for using PyTorch in Deep Learning projects is that we can automatically get gradients/derivatives of functions that we define. We will mainly use PyTorch for implementing neural networks, and they are just fancy functions. If we use weight matrices in our function that we want to learn, then those are called the parameters or simply the weights.

If our neural network would output a single scalar value, we would talk about taking the derivative, but you will see that quite often we will have multiple output variables (“values”); in that case we talk about gradients. It’s a more general term.

Given an input \mathbf{x}, we define our function by manipulating that input, usually by matrix-multiplications with weight matrices and additions with so-called bias vectors. As we manipulate our input, we are automatically creating a computational graph. This graph shows how to arrive at our output from our input. PyTorch is a define-by-run framework; this means that we can just do our manipulations, and PyTorch will keep track of that graph for us. Thus, we create a dynamic computation graph along the way.

So, to recap: the only thing we have to do is to compute the output, and then we can ask PyTorch to automatically get the gradients.

Note: Why do we want gradients? Consider that we have defined a function, a neural net, that is supposed to compute a certain output y for an input vector \mathbf{x}. We then define an error measure that tells us how wrong our network is; how bad it is in predicting output y from input \mathbf{x}. Based on this error measure, we can use the gradients to update the weights \mathbf{W} that were responsible for the output, so that the next time we present input \mathbf{x} to our network, the output will be closer to what we want.

The first thing we have to do is to specify which tensors require gradients. By default, when we create a tensor, it does not require gradients.

[24]:
x = torch.ones((3,))
print(x.requires_grad)
False

We can change this for an existing tensor using the function requires_grad_() (the underscore indicates that this is an in-place operation). Alternatively, when creating a tensor, you can pass the argument requires_grad=True to most of the initializers we have seen above.

[25]:
x.requires_grad_(True)
print(x.requires_grad)
True

In order to get familiar with the concept of a computation graph, we will create one for the following function:

y = \frac{1}{|x|}\sum_i \left[(x_i + 2)^2 + 3\right]

You could imagine that x are our parameters, and we want to optimize (either maximize or minimize) the output y. For this, we want to obtain the gradients \partial y / \partial \mathbf{x}. For our example, we’ll use \mathbf{x}=[0,1,2] as our input.

[26]:
x = torch.arange(3, dtype=torch.float32, requires_grad=True)  # Only float tensors can have gradients
print("X", x)
X tensor([0., 1., 2.], requires_grad=True)

Now let’s build the computation graph step by step. You can combine multiple operations in a single line, but we will separate them here to get a better understanding of how each operation is added to the computation graph.

[27]:
a = x + 2
b = a**2
c = b + 3
y = c.mean()
print("Y", y)
Y tensor(12.6667, grad_fn=<MeanBackward0>)

Using the statements above, we have created a computation graph that looks similar to the figure below:

[Figure: computation graph of y, built from x through the intermediate tensors a, b and c]

We calculate a based on the inputs x and the constant 2, b is a squared, and so on. The visualization is an abstraction of the dependencies between inputs and outputs of the operations we have applied. Each node of the computation graph has automatically defined a function for calculating the gradients with respect to its inputs, grad_fn. You can see this when we printed the output tensor y. This is why the computation graph is usually visualized in the reverse direction (arrows point from the result to the inputs). We can perform backpropagation on the computation graph by calling the function backward() on the last output, which effectively calculates the gradients for each tensor that has the property requires_grad=True:

[28]:
y.backward()

x.grad will now contain the gradient \partial y/ \partial \mathbf{x}, and this gradient indicates how a change in \mathbf{x} will affect output y given the current input \mathbf{x}=[0,1,2]:

[29]:
print(x.grad)
tensor([1.3333, 2.0000, 2.6667])

We can also verify these gradients by hand. We will calculate the gradients using the chain rule, in the same way as PyTorch did it:

\frac{\partial y}{\partial x_i} = \frac{\partial y}{\partial c_i}\frac{\partial c_i}{\partial b_i}\frac{\partial b_i}{\partial a_i}\frac{\partial a_i}{\partial x_i}

Note that we have simplified this equation to index notation, using the fact that all operations besides the mean do not combine the elements in the tensor. The partial derivatives are:

\frac{\partial a_i}{\partial x_i} = 1,\hspace{1cm}
\frac{\partial b_i}{\partial a_i} = 2\cdot a_i,\hspace{1cm}
\frac{\partial c_i}{\partial b_i} = 1,\hspace{1cm}
\frac{\partial y}{\partial c_i} = \frac{1}{3}

Hence, with the input being \mathbf{x}=[0,1,2], our gradients are \partial y/\partial \mathbf{x}=[4/3,2,8/3]. The previous code cell should have printed the same result.
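The hand derivation can also be checked programmatically. Multiplying the four partial derivatives gives \partial y/\partial x_i = \frac{1}{3}\cdot 1\cdot 2a_i\cdot 1 = \frac{2(x_i+2)}{3}; the sketch below compares this closed form with what autograd computes.

```python
import torch

x = torch.arange(3, dtype=torch.float32, requires_grad=True)
y = ((x + 2) ** 2 + 3).mean()
y.backward()

# Chain rule: dy/dx_i = (1/3) * 1 * (2 * a_i) * 1 with a_i = x_i + 2
analytic = 2 * (x.detach() + 2) / x.numel()
print(x.grad)  # tensor([1.3333, 2.0000, 2.6667])
print(torch.allclose(x.grad, analytic))  # True
```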

GPU support

A crucial feature of PyTorch is its support for GPUs, short for Graphics Processing Units. A GPU can perform many thousands of small operations in parallel, making it very well suited for performing large matrix operations in neural networks. When comparing GPUs to CPUs, we can list the following main differences (credit: Kevin Krewell, 2009)

[Figure: main differences between CPUs and GPUs]

CPUs and GPUs each have different advantages and disadvantages, which is why many computers contain both components and use them for different tasks. In case you are not familiar with GPUs, you can read up on more details in this NVIDIA blog post or here.

GPUs can accelerate the training of your network by up to a factor of 100, which is essential for large neural networks. PyTorch implements a lot of functionality for supporting GPUs (mostly those of NVIDIA, due to the libraries CUDA and cuDNN). First, let’s check whether you have a GPU available:

[30]:
gpu_avail = torch.cuda.is_available()
print(f"Is the GPU available? {gpu_avail}")
Is the GPU available? True

If you have a GPU on your computer but the command above returns False, make sure you have the correct CUDA-version installed. The dl2020 environment comes with the CUDA-toolkit 10.1, which is selected for the Lisa supercomputer. Please change it if necessary (CUDA 10.2 is currently common). On Google Colab, make sure that you have selected a GPU in your runtime setup (in the menu, check under Runtime -> Change runtime type).

By default, all tensors you create are stored on the CPU. We can push a tensor to the GPU by using the function .to(...), or .cuda(). However, it is often a good practice to define a device object in your code which points to the GPU if you have one, and otherwise to the CPU. Then, you can write your code with respect to this device object, and it allows you to run the same code on both a CPU-only system, and one with a GPU. Let’s try it below. We can specify the device as follows:

[31]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device", device)
Device cuda

Now let’s create a tensor and push it to the device:

[32]:
x = torch.zeros(2, 3)
x = x.to(device)
print("X", x)
X tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')

In case you have a GPU, you should now see the attribute device='cuda:0' printed next to your tensor. The zero next to cuda indicates that this is GPU device number 0, i.e. the first GPU on your computer. PyTorch also supports multi-GPU systems, but you will only need this once you have very big networks to train (if interested, see the PyTorch documentation). We can also compare the runtime of a large matrix multiplication on the CPU with the same operation on the GPU:

[33]:
x = torch.randn(5000, 5000)

# CPU version
start_time = time.time()
_ = torch.matmul(x, x)
end_time = time.time()
print(f"CPU time: {(end_time - start_time):6.5f}s")

# GPU version
if torch.cuda.is_available():
    x = x.to(device)
    # CUDA is asynchronous, so we need to use different timing functions
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    _ = torch.matmul(x, x)
    end.record()
    torch.cuda.synchronize()  # Waits for everything to finish running on the GPU
    print(f"GPU time: {0.001 * start.elapsed_time(end):6.5f}s")  # Milliseconds to seconds
CPU time: 0.27849s
GPU time: 0.02576s

Depending on the size of the operation and the CPU/GPU in your system, the speedup of this operation can be >50x. As matmul operations are very common in neural networks, we can already see the great benefit of training a NN on a GPU. The time estimate can be relatively noisy here because we ran the measurement only once. Feel free to repeat it several times, but this also takes longer to run.

When generating random numbers, the seed between CPU and GPU is not synchronized. Hence, we need to set the seed on the GPU separately to make our code reproducible. Note that due to different GPU architectures, running the same code on different GPUs does not guarantee the same random numbers. Still, we want our code to give us the same output every time we run it on the exact same hardware. Hence, we also set the seed on the GPU:

[34]:
# GPU operations have a separate seed we also want to set
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

# Additionally, some operations on a GPU are implemented non-deterministically for efficiency
# We want to ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
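Since Python's random module and NumPy keep their own generators, a common pattern is to seed everything in one place. The helper below is a hypothetical sketch combining the calls above with the standard library and NumPy seeds; Lightning users can alternatively call pytorch_lightning.seed_everything.

```python
import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    """Hypothetical helper: seed Python, NumPy and PyTorch (CPU and all GPUs)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels over potentially faster autotuned ones
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)
first = (random.random(), float(np.random.rand()), torch.rand(1).item())
set_seed(42)
second = (random.random(), float(np.random.rand()), torch.rand(1).item())
print(first == second)  # True
```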

Learning by example: Continuous XOR

If we want to build a neural network in PyTorch, we could specify all our parameters (weight matrices, bias vectors) using Tensors (with requires_grad=True), ask PyTorch to calculate the gradients and then adjust the parameters. But things can quickly get cumbersome if we have a lot of parameters. In PyTorch, there is a package called torch.nn that makes building neural networks more convenient.

We will introduce the libraries and all additional parts you might need to train a neural network in PyTorch, using a simple classifier on a simple yet well-known example: XOR. Given two binary inputs x_1 and x_2, the label to predict is 1 if either x_1 or x_2 is 1 while the other is 0; otherwise, the label is 0. The example became famous because a single neuron, i.e. a linear classifier, cannot learn this simple function. Hence, we will learn how to build a small neural network that can learn this function. To make it a little bit more interesting, we move the XOR into continuous space and introduce some Gaussian noise on the binary inputs. Our desired separation of an XOR dataset could look as follows:

[Figure: desired separation of a continuous XOR dataset]

The model

The package torch.nn defines a series of useful classes like linear network layers, activation functions, loss functions etc. A full list can be found here. In case you need a certain network layer, check the documentation of the package first before writing the layer yourself, as the package likely contains the code for it already. We import it below:

[ ]:
import torch.nn as nn

In addition to torch.nn, there is also torch.nn.functional. It contains functions that are used in network layers. This is in contrast to torch.nn, which defines them as nn.Modules (more on it below), and torch.nn actually uses a lot of functionality from torch.nn.functional. Hence, the functional package is useful in many situations, and so we import it here as well.

[ ]:
import torch.nn.functional as F

nn.Module

In PyTorch, a neural network is built up out of modules. Modules can contain other modules, and a neural network is considered to be a module itself as well. The basic template of a module is as follows:

[35]:
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Some init for my module

    def forward(self, x):
        # Function for performing the calculation of the module.
        pass

The forward function is where the computation of the module takes place, and it is executed when you call the module (module = MyModule(); module(x)). In the init function, we usually create the parameters of the module using nn.Parameter, or define other modules that are used in the forward function. The backward calculation is done automatically, but could be overridden as well if wanted.
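As a concrete instance of this template, the hypothetical module below owns two learnable tensors registered through nn.Parameter. Calling the module runs forward, and autograd derives the backward pass without any extra code.

```python
import torch
import torch.nn as nn


class ScaleShift(nn.Module):
    """Hypothetical example module: y = x * scale + shift with learnable scale and shift."""

    def __init__(self, num_features):
        super().__init__()
        # nn.Parameter registers each tensor as a learnable parameter of the module
        self.scale = nn.Parameter(torch.ones(num_features))
        self.shift = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return x * self.scale + self.shift


module = ScaleShift(3)
out = module(torch.tensor([[1.0, 2.0, 3.0]]))  # calling the module runs forward()
out.sum().backward()  # gradients are derived automatically by autograd

names = [name for name, _ in module.named_parameters()]
print(names)  # ['scale', 'shift']
print(module.scale.grad)  # tensor([1., 2., 3.])
```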

Simple classifier

We can now make use of the pre-defined modules in the torch.nn package and define our own small neural network. We will use a minimal network with an input layer, one hidden layer with tanh as activation function, and an output layer. In other words, our network should look something like this:

[Figure: network with an input layer, one tanh hidden layer and an output layer]

The input neurons are shown in blue, which represent the coordinates x_1 and x_2 of a data point. The hidden neurons including a tanh activation are shown in white, and the output neuron in red. In PyTorch, we can define this as follows:

[36]:
class SimpleClassifier(nn.Module):
    def __init__(self, num_inputs, num_hidden, num_outputs):
        super().__init__()
        # Initialize the modules we need to build the network
        self.linear1 = nn.Linear(num_inputs, num_hidden)
        self.act_fn = nn.Tanh()
        self.linear2 = nn.Linear(num_hidden, num_outputs)

    def forward(self, x):
        # Perform the calculation of the model to determine the prediction
        x = self.linear1(x)
        x = self.act_fn(x)
        x = self.linear2(x)
        return x

For the examples in this notebook, we will use a tiny neural network with two input neurons and four hidden neurons. As we perform binary classification, we will use a single output neuron. Note that we do not apply a sigmoid on the output yet. This is because other functions, especially the loss, are more efficient and precise to calculate on the original outputs instead of the sigmoid output. We will discuss the detailed reason later.

[37]:
model = SimpleClassifier(num_inputs=2, num_hidden=4, num_outputs=1)
# Printing a module shows all its submodules
print(model)
SimpleClassifier(
  (linear1): Linear(in_features=2, out_features=4, bias=True)
  (act_fn): Tanh()
  (linear2): Linear(in_features=4, out_features=1, bias=True)
)

Printing the model lists all submodules it contains. The parameters of a module can be obtained with its parameters() function, or with named_parameters() to also get the name of each parameter object. For our small neural network, we have the following parameters:

[38]:
for name, param in model.named_parameters():
    print(f"Parameter {name}, shape {param.shape}")
Parameter linear1.weight, shape torch.Size([4, 2])
Parameter linear1.bias, shape torch.Size([4])
Parameter linear2.weight, shape torch.Size([1, 4])
Parameter linear2.bias, shape torch.Size([1])

Each linear layer has a weight matrix of the shape [output, input], and a bias of the shape [output]. The tanh activation function does not have any parameters. Note that parameters are only registered for nn.Module objects that are direct object attributes, i.e. self.a = .... If you define a list of modules, the parameters of those are not registered for the outer module and can cause some issues when you try to optimize your module. There are alternatives, like nn.ModuleList, nn.ModuleDict and nn.Sequential, that allow you to have different data structures of modules. We will use them in a few later tutorials and explain them there.
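The registration pitfall can be made concrete with a short sketch. The two classes below are hypothetical and differ only in whether their layers are stored in a plain Python list or in an nn.ModuleList; counting the registered parameters shows that the plain list hides all of them, so an optimizer built from module.parameters() would receive nothing to train:

```python
import torch
from torch import nn


class WithPlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # A Python list hides the layers from nn.Module's registration
        self.layers = [nn.Linear(2, 4), nn.Linear(4, 1)]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every contained module as a submodule
        self.layers = nn.ModuleList([nn.Linear(2, 4), nn.Linear(4, 1)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def count_params(module):
    return sum(p.numel() for p in module.parameters())


plain, registered = WithPlainList(), WithModuleList()
print(count_params(plain))  # 0 -> an optimizer would receive no parameters
print(count_params(registered))  # 17 = (2*4 + 4) + (4*1 + 1)
```

Both modules compute the same kind of output; only the second one exposes its weights to optimizers, state_dict(), and .to(device).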

The data

PyTorch also provides a few functionalities to load the training and test data efficiently, summarized in the package torch.utils.data.

[ ]:
import torch.utils.data as data

The data package defines two classes which are the standard interface for handling data in PyTorch: data.Dataset, and data.DataLoader. The dataset class provides a uniform interface to access the training/test data, while the data loader makes sure to efficiently load and stack the data points from the dataset into batches during training.

The dataset class

The dataset class summarizes the basic functionality of a dataset in a natural way. To define a dataset in PyTorch, we simply specify two functions: __getitem__, and __len__. The get-item function has to return the i-th data point in the dataset, while the len function returns the size of the dataset. For the XOR dataset, we can define the dataset class as follows:

[39]:
class XORDataset(data.Dataset):
    def __init__(self, size, std=0.1):
        """
        Inputs:
            size - Number of data points we want to generate
            std - Standard deviation of the noise (see generate_continuous_xor function)
        """
        super().__init__()
        self.size = size
        self.std = std
        self.generate_continuous_xor()

    def generate_continuous_xor(self):
        # Each data point in the XOR dataset has two variables, x and y, that can be either 0 or 1
        # The label is their XOR combination, i.e. 1 if only x or only y is 1 while the other is 0.
        # If x=y, the label is 0.
        data = torch.randint(low=0, high=2, size=(self.size, 2), dtype=torch.float32)
        label = (data.sum(dim=1) == 1).to(torch.long)
        # To make it slightly more challenging, we add a bit of Gaussian noise to the data points.
        data += self.std * torch.randn(data.shape)

        self.data = data
        self.label = label

    def __len__(self):
        # Number of data points we have. Alternatively self.data.shape[0], or self.label.shape[0]
        return self.size

    def __getitem__(self, idx):
        # Return the idx-th data point of the dataset
        # If we have multiple things to return (data point and label), we can return them as tuple
        data_point = self.data[idx]
        data_label = self.label[idx]
        return data_point, data_label

Let’s try to create such a dataset and inspect it:

[40]:
dataset = XORDataset(size=200)
print("Size of dataset:", len(dataset))
print("Data point 0:", dataset[0])
Size of dataset: 200
Data point 0: (tensor([0.9632, 0.1117]), tensor(1))

To get a better feeling for the dataset, we visualize the samples below.

[41]:
def visualize_samples(data, label):
    if isinstance(data, Tensor):
        data = data.cpu().numpy()
    if isinstance(label, Tensor):
        label = label.cpu().numpy()
    data_0 = data[label == 0]
    data_1 = data[label == 1]

    plt.figure(figsize=(4, 4))
    plt.scatter(data_0[:, 0], data_0[:, 1], edgecolor="#333", label="Class 0")
    plt.scatter(data_1[:, 0], data_1[:, 1], edgecolor="#333", label="Class 1")
    plt.title("Dataset samples")
    plt.ylabel(r"$x_2$")
    plt.xlabel(r"$x_1$")
    plt.legend()
[42]:
visualize_samples(dataset.data, dataset.label)
plt.show()
(Figure: scatter plot "Dataset samples" of Class 0 and Class 1 points, with x_1 on the horizontal axis and x_2 on the vertical axis)
The data loader class

The class torch.utils.data.DataLoader represents a Python iterable over a dataset with support for automatic batching, multi-process data loading and many more features. The data loader communicates with the dataset using the function __getitem__, and stacks its outputs as tensors over the first dimension to form a batch. In contrast to the dataset class, we usually don’t have to define our own data loader class, but can just create an object of it with the dataset as input. Additionally, we can configure our data loader with the following input arguments (only a selection, see full list here):

  • batch_size: Number of samples to stack per batch

  • shuffle: If True, the data is returned in a random order. This is important during training for introducing stochasticity.

  • num_workers: Number of subprocesses to use for data loading. The default, 0, means that the data will be loaded in the main process, which can slow down training for datasets where loading a data point takes a considerable amount of time (e.g. large images). More workers are recommended for those, but can cause issues on Windows computers. For tiny datasets like ours, 0 workers are usually faster.

  • pin_memory: If True, the data loader will copy tensors into CUDA pinned (page-locked) host memory before returning them. This can speed up the transfer of large data points to the GPU. It is usually good practice for the training set, but not necessarily for validation and test, where it can be left off to save pinned memory.

  • drop_last: If True, the last batch is dropped in case it is smaller than the specified batch size. This occurs when the dataset size is not a multiple of the batch size. Only potentially helpful during training to keep a consistent batch size.
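To see what batch_size and drop_last do in practice, the sketch below uses a hypothetical stand-in dataset (a data.TensorDataset of 200 random points, not the XOR dataset) and records the size of every batch. Since 200 = 3 · 64 + 8, the last batch has only 8 samples unless it is dropped:

```python
import torch
from torch.utils import data

# Hypothetical stand-in dataset: 200 random 2-D points with binary labels
toy_dataset = data.TensorDataset(torch.randn(200, 2), torch.randint(0, 2, (200,)))

loader_keep = data.DataLoader(toy_dataset, batch_size=64, shuffle=True, drop_last=False)
loader_drop = data.DataLoader(toy_dataset, batch_size=64, shuffle=True, drop_last=True)

keep_sizes = [inputs.shape[0] for inputs, _ in loader_keep]
drop_sizes = [inputs.shape[0] for inputs, _ in loader_drop]
print(keep_sizes)  # [64, 64, 64, 8]
print(drop_sizes)  # [64, 64, 64]
```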

Let’s create a simple data loader below:

[43]:
data_loader = data.DataLoader(dataset, batch_size=8, shuffle=True)
[44]:
# next(iter(...)) fetches the first batch of the data loader
# If shuffle is True, this will return a different batch every time we run this cell
# For iterating over the whole dataset, we can simply use "for batch in data_loader: ..."
data_inputs, data_labels = next(iter(data_loader))

# The shape of the outputs is [batch_size, d_1,...,d_N] where d_1,...,d_N are the
# dimensions of the data point returned from the dataset class
print("Data inputs", data_inputs.shape, "\n", data_inputs)
print("Data labels", data_labels.shape, "\n", data_labels)
Data inputs torch.Size([8, 2])
 tensor([[-0.0890,  0.8608],
        [ 1.0905, -0.0128],
        [ 0.7967,  0.2268],
        [-0.0688,  0.0371],
        [ 0.8732, -0.2240],
        [-0.0559, -0.0282],
        [ 0.9277,  0.0978],
        [ 1.0150,  0.9689]])
Data labels torch.Size([8])
 tensor([1, 1, 1, 0, 1, 0, 1, 0])
Optimization

After defining the model and the dataset, it is time to prepare the optimization of the model. During training, we will perform the following steps:

  1. Get a batch from the data loader

  2. Obtain the predictions from the model for the batch

  3. Calculate the loss based on the difference between predictions and labels

  4. Backpropagation: calculate the gradients for every parameter with respect to the loss

  5. Update the parameters of the model in the direction of the gradients

We have seen how we can do steps 1, 2, and 4 in PyTorch. Now, we will look at steps 3 and 5.

Loss modules

We can calculate the loss for a batch by simply performing a few tensor operations as those are automatically added to the computation graph. For instance, for binary classification, we can use Binary Cross Entropy (BCE) which is defined as follows:

\mathcal{L}_{BCE} = -\sum_i \left[ y_i \log x_i + (1 - y_i) \log (1 - x_i) \right]

where y are our labels, and x our predictions, both in the range of [0,1]. However, PyTorch already provides a list of predefined loss functions which we can use (see here for a full list). For instance, for BCE, PyTorch has two modules: nn.BCELoss(), nn.BCEWithLogitsLoss(). While nn.BCELoss expects the inputs x to be in the range [0,1], i.e. the output of a sigmoid, nn.BCEWithLogitsLoss combines a sigmoid layer and the BCE loss in a single class. This version is numerically more stable than using a plain Sigmoid followed by a BCE loss because of the logarithms applied in the loss function. Hence, it is advised to use loss functions applied on “logits” where possible (remember not to apply a sigmoid on the output of the model in this case!). For our model defined above, we therefore use the module nn.BCEWithLogitsLoss.
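The following sketch compares the BCE formula written by hand with both PyTorch modules on random logits and labels (arbitrary illustrative values). Note that the modules average over the batch by default (reduction="mean"), so the manual version uses .mean() instead of the sum in the formula. The second part shows where the numerical difference appears: for a very large logit, the sigmoid rounds to exactly 1.0 in float32, and nn.BCELoss, which clamps its log terms at -100, reports 100 instead of the true loss of about 200:

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(8)  # Raw model outputs, no sigmoid applied
labels = torch.randint(0, 2, (8,)).float()

# Three ways of computing the same (mean-reduced) BCE loss
probs = torch.sigmoid(logits)
manual = -(labels * torch.log(probs) + (1 - labels) * torch.log(1 - probs)).mean()
bce = nn.BCELoss()(probs, labels)
bce_logits = nn.BCEWithLogitsLoss()(logits, labels)
print(manual.item(), bce.item(), bce_logits.item())  # All (nearly) identical

# Extreme logit: sigmoid(200) == 1.0 in float32, so log(1 - 1.0) = -inf gets clamped
big_logit, target = torch.tensor([200.0]), torch.tensor([0.0])
naive_loss = nn.BCELoss()(torch.sigmoid(big_logit), target).item()
stable_loss = nn.BCEWithLogitsLoss()(big_logit, target).item()
print(naive_loss)  # 100.0
print(stable_loss)  # 200.0
```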

[45]:
loss_module = nn.BCEWithLogitsLoss()
Stochastic Gradient Descent

For updating the parameters, PyTorch provides the package torch.optim, which implements most popular optimizers. We will discuss the specific optimizers and their differences later in the course, but will for now use the simplest of them: torch.optim.SGD. Stochastic Gradient Descent updates parameters by multiplying the gradients with a small constant, called learning rate, and subtracting those from the parameters (hence minimizing the loss). Therefore, we slowly move towards the direction of minimizing the loss. A good default value of the learning rate for a small network like ours is 0.1.
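The update rule θ ← θ − lr · ∇θL can be checked directly. In the sketch below, two identical copies of a small linear layer take one step each: one with torch.optim.SGD, one with the rule written by hand. The mean squared error loss and the random data are only stand-ins chosen for brevity; plain SGD without momentum or weight decay is assumed:

```python
import torch
from torch import nn

torch.manual_seed(0)
model_a = nn.Linear(2, 1)
model_b = nn.Linear(2, 1)
model_b.load_state_dict(model_a.state_dict())  # Start from identical weights

x, y = torch.randn(16, 2), torch.randn(16, 1)
lr = 0.1

# (a) Built-in optimizer
optimizer = torch.optim.SGD(model_a.parameters(), lr=lr)
optimizer.zero_grad()
nn.functional.mse_loss(model_a(x), y).backward()
optimizer.step()

# (b) The same update written by hand: param <- param - lr * grad
nn.functional.mse_loss(model_b(x), y).backward()
with torch.no_grad():  # The update itself must not be tracked by autograd
    for param in model_b.parameters():
        param -= lr * param.grad

same = all(torch.allclose(pa, pb) for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True
```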

[46]:
# Input to the optimizer are the parameters of the model: model.parameters()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

The optimizer provides two useful functions: optimizer.step(), and optimizer.zero_grad(). The step function updates the parameters based on the gradients as explained above. The function optimizer.zero_grad() resets the gradients of all parameters (to zero, or to None, which is the default since PyTorch 2.0). While this function seems less relevant at first, it is a crucial pre-step before performing backpropagation. If we called the backward function on the loss while the parameter gradients are non-zero from the previous batch, the new gradients would be added to the previous ones instead of overwriting them. This is done because a parameter might occur multiple times in a computation graph, and we need to sum the gradients in this case instead of replacing them. Hence, remember to call optimizer.zero_grad() before calculating the gradients of a batch.
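A tiny example with a single scalar parameter w makes the accumulation visible. For the loss w², the gradient at w = 3 is 2w = 6. Here we reset the gradient by hand by setting w.grad to None, which is what optimizer.zero_grad() does for every parameter in recent PyTorch versions. The last step shows why accumulation exists: w appears twice in the graph, and its gradient is the sum of both contributions:

```python
import torch

w = torch.tensor(3.0, requires_grad=True)
grads = []

(w ** 2).backward()  # d(w^2)/dw = 2w = 6
grads.append(w.grad.item())

(w ** 2).backward()  # No reset: the new gradient is ADDED to the old one
grads.append(w.grad.item())

w.grad = None  # Manual reset, as optimizer.zero_grad() would do
(w ** 2).backward()
grads.append(w.grad.item())

w.grad = None
(w * 2 + w * 5).backward()  # w used twice: gradient is 2 + 5
grads.append(w.grad.item())

print(grads)  # [6.0, 12.0, 6.0, 7.0]
```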

Training

Finally, we are ready to train our model. As a first step, we create a slightly larger dataset and specify a data loader with a larger batch size.

[47]:
train_dataset = XORDataset(size=1000)
train_data_loader = data.DataLoader(train_dataset, batch_size=128, shuffle=True)

Now, we can write a small training function. Remember our five steps: load a batch, obtain the predictions, calculate the loss, backpropagate, and update. Additionally, we have to push all data and model parameters to the device of our choice (GPU if available). For the tiny neural network we have, communicating the data to the GPU actually takes much more time than we could save from running the operation on GPU. For large networks, the communication time is significantly smaller than the actual runtime making a GPU crucial in these cases. Still, to practice, we will push the data to GPU here.

[48]:
# Push model to device. This only has to be done once
model.to(device)
[48]:
SimpleClassifier(
  (linear1): Linear(in_features=2, out_features=4, bias=True)
  (act_fn): Tanh()
  (linear2): Linear(in_features=4, out_features=1, bias=True)
)

In addition, we set our model to training mode. This is done by calling model.train(). There exist certain modules that need to perform a different forward step during training than during testing (e.g. BatchNorm and Dropout), and we can switch between them using model.train() and model.eval().
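As an illustration of such a module, the sketch below runs nn.Dropout (with p = 0.5, an arbitrary choice) on a vector of ones in both modes. In training mode, random entries are zeroed and the remaining ones are scaled by 1 / (1 - p) = 2; in evaluation mode, dropout simply returns its input:

```python
import torch
from torch import nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 10)

dropout.train()  # Training mode: zero random entries, scale the rest by 1 / (1 - p)
train_out = dropout(x)
print(train_out)  # A mix of 0.0 and 2.0 values

dropout.eval()  # Evaluation mode: dropout becomes the identity
eval_out = dropout(x)
print(eval_out)  # All ones
```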

[49]:
def train_model(model, optimizer, data_loader, loss_module, num_epochs=100):
    # Set model to train mode
    model.train()

    # Training loop
    for epoch in tqdm(range(num_epochs)):
        for data_inputs, data_labels in data_loader:

            # Step 1: Move input data to device (only strictly necessary if we use GPU)
            data_inputs = data_inputs.to(device)
            data_labels = data_labels.to(device)

            # Step 2: Run the model on the input data
            preds = model(data_inputs)
            preds = preds.squeeze(dim=1)  # Output is [Batch size, 1], but we want [Batch size]

            # Step 3: Calculate the loss
            loss = loss_module(preds, data_labels.float())

            # Step 4: Perform backpropagation
            # Before calculating the gradients, we need to ensure that they are all zero.
            # The gradients would not be overwritten, but actually added to the existing ones.
            optimizer.zero_grad()
            # Perform backpropagation
            loss.backward()

            # Step 5: Update the parameters
            optimizer.step()
[50]:
train_model(model, optimizer, train_data_loader, loss_module)
Saving a model

After we finish training a model, we save it to disk so that we can load the same weights at a later time. For this, we extract the so-called state_dict from the model, which contains all learnable parameters (and persistent buffers, if the model has any). For our simple model, the state dict contains the following entries:

[51]:
state_dict = model.state_dict()
print(state_dict)
OrderedDict([('linear1.weight', tensor([[-2.0209, -2.3101],
        [ 1.3066, -1.8463],
        [-1.5089, -0.6550],
        [-0.7824, -0.9385]], device='cuda:0')), ('linear1.bias', tensor([ 0.7382, -0.9136,  1.4607, -0.1769], device='cuda:0')), ('linear2.weight', tensor([[-2.5543,  1.9722,  2.1591, -0.4553]], device='cuda:0')), ('linear2.bias', tensor([-1.0225], device='cuda:0'))])

To save the state dictionary, we can use torch.save:

[52]:
# torch.save(object, filename). For the filename, any extension can be used
torch.save(state_dict, "our_model.tar")

To load a model from a state dict, we use the function torch.load to load the state dict from the disk, and the module function load_state_dict to overwrite our parameters with the new values:

[53]:
# Load state dict from the disk (make sure it is the same name as above)
state_dict = torch.load("our_model.tar")

# Create a new model and load the state
new_model = SimpleClassifier(num_inputs=2, num_hidden=4, num_outputs=1)
new_model.load_state_dict(state_dict)

# Verify that the parameters are the same
print("Original model\n", model.state_dict())
print("\nLoaded model\n", new_model.state_dict())
Original model
 OrderedDict([('linear1.weight', tensor([[-2.0209, -2.3101],
        [ 1.3066, -1.8463],
        [-1.5089, -0.6550],
        [-0.7824, -0.9385]], device='cuda:0')), ('linear1.bias', tensor([ 0.7382, -0.9136,  1.4607, -0.1769], device='cuda:0')), ('linear2.weight', tensor([[-2.5543,  1.9722,  2.1591, -0.4553]], device='cuda:0')), ('linear2.bias', tensor([-1.0225], device='cuda:0'))])

Loaded model
 OrderedDict([('linear1.weight', tensor([[-2.0209, -2.3101],
        [ 1.3066, -1.8463],
        [-1.5089, -0.6550],
        [-0.7824, -0.9385]])), ('linear1.bias', tensor([ 0.7382, -0.9136,  1.4607, -0.1769])), ('linear2.weight', tensor([[-2.5543,  1.9722,  2.1591, -0.4553]])), ('linear2.bias', tensor([-1.0225]))])

A detailed tutorial on saving and loading models in PyTorch can be found here.
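One practical detail: the state dict printed above holds tensors on device='cuda:0', and torch.load restores tensors to the device they were saved from by default, which fails on a machine without a GPU. Passing map_location remaps all stored tensors. The sketch below assumes a throwaway file name (demo_state_dict.pt) in the system's temporary directory and uses a small linear layer as a stand-in model:

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(2, 1)
if torch.cuda.is_available():
    model = model.to("cuda")  # Mimic a model trained on GPU

path = os.path.join(tempfile.gettempdir(), "demo_state_dict.pt")  # Hypothetical file name
torch.save(model.state_dict(), path)

# map_location remaps every stored tensor, so this also works on a CPU-only machine
state_dict = torch.load(path, map_location="cpu")
print({name: t.device for name, t in state_dict.items()})  # Every tensor on cpu

cpu_model = nn.Linear(2, 1)
cpu_model.load_state_dict(state_dict)
```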

Evaluation

Once we have trained a model, it is time to evaluate it on a held-out test set. As our dataset consists of randomly generated data points, we first need to create a test set with a corresponding data loader.

[54]:
test_dataset = XORDataset(size=500)
# drop_last -> Don't drop the last batch although it is smaller than 128
test_data_loader = data.DataLoader(test_dataset, batch_size=128, shuffle=False, drop_last=False)

As metric, we will use accuracy which is calculated as follows:

acc = \frac{\#\text{correct predictions}}{\#\text{all predictions}} = \frac{TP+TN}{TP+TN+FP+FN}

where TP are the true positives, TN the true negatives, FP the false positives, and FN the false negatives.
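As a worked example of the formula, the sketch below counts TP, TN, FP, and FN for eight hypothetical binarized predictions and labels. The result matches the simpler fraction of matching labels, which is the counting shortcut that eval_model uses below:

```python
import torch

# Hypothetical binarized predictions and ground-truth labels
pred_labels = torch.tensor([1, 0, 1, 1, 0, 0, 1, 0])
labels = torch.tensor([1, 0, 0, 1, 0, 1, 1, 0])

tp = ((pred_labels == 1) & (labels == 1)).sum().item()
tn = ((pred_labels == 0) & (labels == 0)).sum().item()
fp = ((pred_labels == 1) & (labels == 0)).sum().item()
fn = ((pred_labels == 0) & (labels == 1)).sum().item()

acc = (tp + tn) / (tp + tn + fp + fn)
print(tp, tn, fp, fn, acc)  # 3 3 1 1 0.75

# Same result as the fraction of matching labels
match_acc = (pred_labels == labels).float().mean().item()
print(match_acc)  # 0.75
```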

When evaluating the model, we don’t need to keep track of the computation graph as we don’t intend to calculate the gradients. This reduces the required memory and speeds up the model. In PyTorch, we can deactivate the computation graph using with torch.no_grad(): .... Remember to additionally set the model to eval mode.
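The effect of torch.no_grad() can be observed on any layer with learnable parameters. In the sketch below (a stand-in linear layer and random inputs), the output computed inside the context has no grad_fn, i.e. no graph was recorded for it, while the values themselves are unchanged:

```python
import torch
from torch import nn

layer = nn.Linear(2, 1)
x = torch.randn(4, 2)

out_tracked = layer(x)
with torch.no_grad():
    out_untracked = layer(x)

print(out_tracked.requires_grad, out_tracked.grad_fn is not None)  # True True
print(out_untracked.requires_grad, out_untracked.grad_fn)  # False None
```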

[55]:
def eval_model(model, data_loader):
    model.eval()  # Set model to eval mode
    true_preds, num_preds = 0.0, 0.0

    with torch.no_grad():  # Deactivate gradients for the following code
        for data_inputs, data_labels in data_loader:

            # Determine prediction of model on dev set
            data_inputs, data_labels = data_inputs.to(device), data_labels.to(device)
            preds = model(data_inputs)
            preds = preds.squeeze(dim=1)
            preds = torch.sigmoid(preds)  # Sigmoid to map predictions between 0 and 1
            pred_labels = (preds >= 0.5).long()  # Binarize predictions to 0 and 1

            # Keep records of predictions for the accuracy metric (true_preds=TP+TN, num_preds=TP+TN+FP+FN)
            true_preds += (pred_labels == data_labels).sum()
            num_preds += data_labels.shape[0]

    acc = true_preds / num_preds
    print(f"Accuracy of the model: {100.0*acc:4.2f}%")
[56]:
eval_model(model, test_data_loader)
Accuracy of the model: 100.00%

If we trained our model correctly, we should see a score close to 100% accuracy. However, this is only possible because of our simple task, and unfortunately, we usually don’t get such high scores on test sets of more complex tasks.

Visualizing classification boundaries

To visualize what our model has learned, we can perform a prediction for every data point in a range of [-0.5, 1.5], and visualize the predicted class as in the sample figure at the beginning of this section. This shows where the model has created decision boundaries, and which points would be classified as 0, and which as 1. We therefore get a background image out of blue (class 0) and orange (class 1). In regions where the model is uncertain, we will see a blurry overlap. The specific code is less relevant than the output figure, which should hopefully show us a clear separation of classes:

[57]:
@torch.no_grad()  # Decorator, same effect as "with torch.no_grad(): ..." over the whole function.
def visualize_classification(model, data, label):
    if isinstance(data, Tensor):
        data = data.cpu().numpy()
    if isinstance(label, Tensor):
        label = label.cpu().numpy()
    data_0 = data[label == 0]
    data_1 = data[label == 1]

    plt.figure(figsize=(4, 4))
    plt.scatter(data_0[:, 0], data_0[:, 1], edgecolor="#333", label="Class 0")
    plt.scatter(data_1[:, 0], data_1[:, 1], edgecolor="#333", label="Class 1")
    plt.title("Dataset samples")
    plt.ylabel(r"$x_2$")
    plt.xlabel(r"$x_1$")
    plt.legend()

    # Let's make use of a lot of operations we have learned above
    model.to(device)
    c0 = Tensor(to_rgba("C0")).to(device)
    c1 = Tensor(to_rgba("C1")).to(device)
    x1 = torch.arange(-0.5, 1.5, step=0.01, device=device)
    x2 = torch.arange(-0.5, 1.5, step=0.01, device=device)
    # torch.meshgrid uses "ij" indexing (unlike numpy's default "xy"): rows follow x_2 and columns x_1, as imshow expects
    xx2, xx1 = torch.meshgrid(x2, x1)
    model_inputs = torch.stack([xx1, xx2], dim=-1)
    preds = model(model_inputs)
    preds = torch.sigmoid(preds)
    # Specifying "None" in a dimension creates a new one
    output_image = (1 - preds) * c0[None, None] + preds * c1[None, None]
    # Convert to numpy array. This only works for tensors on CPU, hence first push to CPU
    output_image = output_image.cpu().numpy()
    plt.imshow(output_image, origin="lower", extent=(-0.5, 1.5, -0.5, 1.5))
    plt.grid(False)


visualize_classification(model, dataset.data, dataset.label)
plt.show()
(Figure: "Dataset samples" scatter plot on top of the model's predicted class regions, blue for class 0 and orange for class 1)

The decision boundaries might not look exactly as in the figure in the preamble of this section. This can be caused by running it on CPU or a different GPU architecture. Nevertheless, the result on the accuracy metric should be approximately the same.

Additional features we didn’t get to discuss yet

Finally, you are all set to start with your own PyTorch project! In summary, we have looked at how we can build neural networks in PyTorch, and train and test them on data. However, there is still much more to PyTorch we haven’t discussed yet. In the coming series of Jupyter notebooks, we will discover more and more functionalities of PyTorch, so that you also become familiar with PyTorch concepts beyond the basics. If you are already interested in learning more about PyTorch, we recommend the official tutorial website that contains many tutorials on various topics. Especially logging with Tensorboard (tutorial here) is a good practice that we will explore from Tutorial 5 on.

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Great thanks from the entire PyTorch Lightning Team for your interest!


Tutorial 2: Activation Functions

  • Author: Phillip Lippe

  • License: CC BY-SA

  • Generated: 2021-09-16T14:32:18.973374

In this tutorial, we will take a closer look at (popular) activation functions and investigate their effect on optimization properties in neural networks. Activation functions are a crucial part of deep learning models as they add the non-linearity to neural networks. There is a great variety of activation functions in the literature, and some are more beneficial than others. The goal of this tutorial is to show the importance of choosing a good activation function (and how to do so), and what problems might occur if we don’t. This notebook is part of a lecture series on Deep Learning at the University of Amsterdam. The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.



Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
# ! pip install --quiet "torchmetrics>=0.3" "torch>=1.6, <1.9" "pytorch-lightning>=1.3" "torchvision" "seaborn" "matplotlib"

Before we start, we import our standard libraries and set up basic functions:

[2]:
import json
import math
import os
import urllib.request
import warnings
from urllib.error import HTTPError

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision

# %matplotlib inline
from IPython.display import set_matplotlib_formats
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from tqdm.notebook import tqdm

set_matplotlib_formats("svg", "pdf")  # For export
sns.set()
/tmp/ipykernel_749/3776275675.py:24: DeprecationWarning: `set_matplotlib_formats` is deprecated since IPython 7.23, directly use `matplotlib_inline.backend_inline.set_matplotlib_formats()`
  set_matplotlib_formats("svg", "pdf")  # For export

We will define a function to set a seed on all libraries we might interact with in this tutorial (here numpy and torch). This allows us to make our training reproducible. However, note that in contrast to the CPU, the same seed on different GPU architectures can give different results. All models here have been trained on an NVIDIA GTX1080Ti.

Additionally, the following cell defines two paths: DATASET_PATH and CHECKPOINT_PATH. The dataset path is the directory where we will download datasets used in the notebooks. It is recommended to store all datasets from PyTorch in one shared directory to prevent duplicate downloads. The checkpoint path is the directory where we will store trained model weights and additional files. The needed files will be automatically downloaded. Both paths default to directories relative to the current directory (which also works on Google Colab) and can be overridden with the environment variables PATH_DATASETS and PATH_CHECKPOINT.

[3]:
# Path to the folder where the datasets are/should be downloaded (e.g. MNIST)
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/Activation_Functions/")


# Function for setting the seed
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():  # GPU operation have separate seed
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


set_seed(42)

# Additionally, some operations on a GPU are implemented non-deterministically for efficiency
# We want to ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Fetching the device that will be used throughout this notebook
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda:0")
print("Using device", device)
Using device cuda:0
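To check that seeding makes random draws reproducible, the sketch below seeds numpy and torch, draws some numbers, reseeds, and draws again. The helper set_cpu_seed is a hypothetical, CPU-only variant of set_seed above, named differently so it does not replace the original in the notebook:

```python
import numpy as np
import torch


def set_cpu_seed(seed):
    # Hypothetical CPU-only variant of set_seed (no CUDA seeding)
    np.random.seed(seed)
    torch.manual_seed(seed)


set_cpu_seed(42)
first = (torch.randn(3), np.random.rand(3))
set_cpu_seed(42)
second = (torch.randn(3), np.random.rand(3))

print(torch.equal(first[0], second[0]), np.array_equal(first[1], second[1]))  # True True
```

For stricter guarantees on GPU, PyTorch 1.8 and newer also provide torch.use_deterministic_algorithms(True), which raises an error whenever an operation has no deterministic implementation.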

The following cell downloads all pretrained models we will use in this notebook. The files are stored on a separate repository to reduce the size of the notebook repository, especially for building the documentation on ReadTheDocs. In case the download below fails, you can download the models from a Google Drive folder. Please let me (Phillip) know if an error occurs so it can be fixed for all students.

[4]:
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/"
# Files to download
pretrained_files = [
    "FashionMNIST_elu.config",
    "FashionMNIST_elu.tar",
    "FashionMNIST_leakyrelu.config",
    "FashionMNIST_leakyrelu.tar",
    "FashionMNIST_relu.config",
    "FashionMNIST_relu.tar",
    "FashionMNIST_sigmoid.config",
    "FashionMNIST_sigmoid.tar",
    "FashionMNIST_swish.config",
    "FashionMNIST_swish.tar",
    "FashionMNIST_tanh.config",
    "FashionMNIST_tanh.tar",
]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n",
                e,
            )
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/FashionMNIST_elu.config...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/FashionMNIST_elu.tar...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/FashionMNIST_leakyrelu.config...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/FashionMNIST_leakyrelu.tar...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/FashionMNIST_relu.config...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/FashionMNIST_relu.tar...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/FashionMNIST_sigmoid.config...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/FashionMNIST_sigmoid.tar...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/FashionMNIST_swish.config...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/FashionMNIST_swish.tar...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/FashionMNIST_tanh.config...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial3/FashionMNIST_tanh.tar...

Common activation functions

As a first step, we will implement some common activation functions ourselves. Of course, most of them can also be found in the torch.nn package (see the documentation for an overview). However, we’ll write our own functions here for better understanding and insight.

To make comparing various activation functions easier, we start by defining a base class from which all our future modules will inherit:

[5]:
class ActivationFunction(nn.Module):
    def __init__(self):
        super().__init__()
        self.name = self.__class__.__name__
        self.config = {"name": self.name}

Every activation function will be an nn.Module so that we can integrate them nicely in a network. We will use the config dictionary to store adjustable parameters for some activation functions.

Next, we implement two of the “oldest” activation functions that are still commonly used for various tasks: sigmoid and tanh. Both the sigmoid and tanh activations can also be found as PyTorch functions (torch.sigmoid, torch.tanh) or as modules (nn.Sigmoid, nn.Tanh). Here, we implement them by hand:

[6]:
class Sigmoid(ActivationFunction):
    def forward(self, x):
        return 1 / (1 + torch.exp(-x))


class Tanh(ActivationFunction):
    def forward(self, x):
        x_exp, neg_x_exp = torch.exp(x), torch.exp(-x)
        return (x_exp - neg_x_exp) / (x_exp + neg_x_exp)
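A quick sanity check of the hand-written formulas against torch.sigmoid and torch.tanh is shown below; the functions manual_sigmoid and manual_tanh simply repeat the formulas of the Sigmoid and Tanh modules above. On a moderate range the results agree, but for large inputs the direct tanh formula breaks down in float32: exp(100) overflows to inf, and inf / inf gives nan, whereas the built-in torch.tanh returns 1. This is one practical reason to prefer the library versions outside of a tutorial:

```python
import torch


def manual_sigmoid(x):
    # Same formula as the Sigmoid module above
    return 1 / (1 + torch.exp(-x))


def manual_tanh(x):
    # Same formula as the Tanh module above: (e^x - e^-x) / (e^x + e^-x)
    x_exp, neg_x_exp = torch.exp(x), torch.exp(-x)
    return (x_exp - neg_x_exp) / (x_exp + neg_x_exp)


x = torch.linspace(-5, 5, steps=101)
print(torch.allclose(manual_sigmoid(x), torch.sigmoid(x), atol=1e-6))  # True
print(torch.allclose(manual_tanh(x), torch.tanh(x), atol=1e-6))  # True

# For large |x|, exp overflows to inf in float32 and inf / inf = nan
big = torch.tensor([100.0])
print(manual_tanh(big), torch.tanh(big))  # tensor([nan]) tensor([1.])
```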

Another popular activation function that has allowed the training of deeper networks is the Rectified Linear Unit (ReLU). Despite its simplicity as a piecewise linear function, ReLU has one major benefit compared to sigmoid and tanh: a strong, stable gradient for a large range of values. Based on this idea, many variations of ReLU have been proposed, of which we will implement the following three: LeakyReLU, ELU, and Swish. LeakyReLU replaces the zero output for negative inputs with a small slope, so that gradients can also flow through this part of the input. Similarly, ELU replaces the negative part with an exponential decay. The third and most recently proposed activation function is Swish, which is the result of a large experiment with the purpose of finding the “optimal” activation function. Compared to the other activation functions, Swish is both smooth and non-monotonic (i.e. its gradient changes sign). This has been shown to prevent dead neurons as in standard ReLU activation, especially for deep networks. If interested, a more detailed discussion of the benefits of Swish can be found in this paper [1].

Let’s implement the four activation functions below:

[7]:
class ReLU(ActivationFunction):
    def forward(self, x):
        return x * (x > 0).float()


class LeakyReLU(ActivationFunction):
    def __init__(self, alpha=0.1):
        super().__init__()
        self.config["alpha"] = alpha

    def forward(self, x):
        return torch.where(x > 0, x, self.config["alpha"] * x)


class ELU(ActivationFunction):
    def forward(self, x):
        return torch.where(x > 0, x, torch.exp(x) - 1)


class Swish(ActivationFunction):
    def forward(self, x):
        return x * torch.sigmoid(x)
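The ReLU variants can be checked the same way against torch.nn.functional. The mapping below is our own reading of the formulas: LeakyReLU with alpha=0.1 corresponds to F.leaky_relu with negative_slope=0.1, ELU with the exp(x) - 1 branch corresponds to F.elu with its default alpha=1.0, and Swish written as x * sigmoid(x) is what PyTorch calls SiLU (F.silu). The base class is again an assumed minimal stand-in.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ActivationFunction(nn.Module):
    # Assumed stand-in for the notebook's base class
    def __init__(self):
        super().__init__()
        self.name = self.__class__.__name__
        self.config = {"name": self.name}


class ReLU(ActivationFunction):
    def forward(self, x):
        return x * (x > 0).float()


class LeakyReLU(ActivationFunction):
    def __init__(self, alpha=0.1):
        super().__init__()
        self.config["alpha"] = alpha

    def forward(self, x):
        return torch.where(x > 0, x, self.config["alpha"] * x)


class ELU(ActivationFunction):
    def forward(self, x):
        return torch.where(x > 0, x, torch.exp(x) - 1)


class Swish(ActivationFunction):
    def forward(self, x):
        return x * torch.sigmoid(x)


# Pairs of (hand-written module, built-in reference from torch.nn.functional)
pairs = {
    "relu": (ReLU(), F.relu),
    "leakyrelu": (LeakyReLU(alpha=0.1), lambda t: F.leaky_relu(t, negative_slope=0.1)),
    "elu": (ELU(), F.elu),  # F.elu defaults to alpha=1.0, i.e. exp(x) - 1 for x <= 0
    "swish": (Swish(), F.silu),  # Swish with beta = 1 is SiLU in PyTorch
}

x = torch.linspace(-5, 5, 1001)
for name, (ours, reference) in pairs.items():
    print(name, torch.allclose(ours(x), reference(x), atol=1e-6))
```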

For later usage, we summarize all our activation functions in a dictionary mapping the name to the class object. In case you implement a new activation function by yourself, add it here to include it in future comparisons as well:

[8]:
act_fn_by_name = {"sigmoid": Sigmoid, "tanh": Tanh, "relu": ReLU, "leakyrelu": LeakyReLU, "elu": ELU, "swish": Swish}
Visualizing activation functions

To get an idea of what each activation function actually does, we visualize them below. Besides the activation value itself, the gradient of the function is an important aspect, as it is crucial for optimizing the neural network. PyTorch allows us to compute the gradients simply by calling the backward function:

[9]:
def get_grads(act_fn, x):
    """Computes the gradients of an activation function at specified positions.

    Args:
        act_fn: An object of the class "ActivationFunction" with an implemented forward pass.
        x: 1D input tensor.
    Returns:
        A tensor with the same size as x containing the gradients of act_fn at x.
    """
    x = x.clone().requires_grad_()  # Mark the input as tensor for which we want to store gradients
    out = act_fn(x)
    out.sum().backward()  # Summing results in an equal gradient flow to each element in x
    return x.grad  # Accessing the gradients of x by "x.grad"
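The autograd-based gradients can be verified against the closed-form derivatives, sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) and tanh'(x) = 1 - tanh(x)^2. The sketch below reuses the logic of get_grads with PyTorch's built-in functions (rather than the classes above) and shows that the sigmoid gradient peaks at 0.25 at x = 0, while the tanh gradient reaches 1 there; the 0.25 maximum matters for the gradient-flow analysis later in this notebook.

```python
import torch


def get_grads(act_fn, x):
    # Same logic as get_grads above; act_fn may be any element-wise callable on tensors
    x = x.clone().requires_grad_()
    act_fn(x).sum().backward()  # for element-wise functions, d(sum)/d(x_i) is the derivative at x_i
    return x.grad


x = torch.tensor([-2.0, 0.0, 2.0])
sig_grads = get_grads(torch.sigmoid, x)
tanh_grads = get_grads(torch.tanh, x)

# Closed-form derivatives: sigmoid' = s * (1 - s), tanh' = 1 - tanh^2
s = torch.sigmoid(x)
print(torch.allclose(sig_grads, s * (1 - s)), torch.allclose(tanh_grads, 1 - torch.tanh(x) ** 2))
print(sig_grads[1].item(), tanh_grads[1].item())  # 0.25 and 1.0 at x = 0

# On a dense grid, the largest sigmoid gradient is (approximately) 0.25, reached around x = 0
grid = torch.linspace(-5, 5, 1001)
print(get_grads(torch.sigmoid, grid).max().item())
```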

Now we can visualize all our activation functions including their gradients:

[10]:
def vis_act_fn(act_fn, ax, x):
    # Run activation function
    y = act_fn(x)
    y_grads = get_grads(act_fn, x)
    # Push x, y and gradients back to cpu for plotting
    x, y, y_grads = x.cpu().numpy(), y.cpu().numpy(), y_grads.cpu().numpy()
    # Plotting
    ax.plot(x, y, linewidth=2, label="ActFn")
    ax.plot(x, y_grads, linewidth=2, label="Gradient")
    ax.set_title(act_fn.name)
    ax.legend()
    ax.set_ylim(-1.5, x.max())


# Add activation functions if wanted
act_fns = [act_fn() for act_fn in act_fn_by_name.values()]
x = torch.linspace(-5, 5, 1000)  # Range on which we want to visualize the activation functions
# Plotting
cols = 2
rows = math.ceil(len(act_fns) / float(cols))
fig, ax = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
for i, act_fn in enumerate(act_fns):
    vis_act_fn(act_fn, ax[divmod(i, cols)], x)
fig.subplots_adjust(hspace=0.3)
plt.show()

Analysing the effect of activation functions

After implementing and visualizing the activation functions, we now want to gain insight into their effect. To do so, we train a simple neural network on FashionMNIST and examine various aspects of the model, including its performance and gradient flow.

Setup

First, let’s set up a neural network. The chosen network treats the images as 1D tensors and passes them through a sequence of linear layers with the specified activation function in between. Feel free to experiment with other network architectures.

[11]:
class BaseNetwork(nn.Module):
    def __init__(self, act_fn, input_size=784, num_classes=10, hidden_sizes=[512, 256, 256, 128]):
        """
        Args:
            act_fn: Object of the activation function that should be used as non-linearity in the network.
            input_size: Size of the input images in pixels
            num_classes: Number of classes we want to predict
            hidden_sizes: A list of integers specifying the hidden layer sizes in the NN
        """
        super().__init__()

        # Create the network based on the specified hidden sizes
        layers = []
        layer_sizes = [input_size] + hidden_sizes
        layer_size_last = layer_sizes[0]
        for layer_size in layer_sizes[1:]:
            layers += [nn.Linear(layer_size_last, layer_size), act_fn]
            layer_size_last = layer_size
        layers += [nn.Linear(layer_sizes[-1], num_classes)]
        # nn.Sequential summarizes a list of modules into a single module, applying them in sequence
        self.layers = nn.Sequential(*layers)

        # We store all hyperparameters in a dictionary for saving and loading of the model
        self.config = {
            "act_fn": act_fn.config,
            "input_size": input_size,
            "num_classes": num_classes,
            "hidden_sizes": hidden_sizes,
        }

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Reshape images to a flat vector
        out = self.layers(x)
        return out
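To make the architecture concrete, the sketch below rebuilds the same layer stack as BaseNetwork with the default sizes (using nn.ReLU as a stand-in activation) and counts its parameters. Each nn.Linear(n_in, n_out) contributes n_in * n_out weights plus n_out biases, so the five linear layers give 401,920 + 131,328 + 65,792 + 32,896 + 1,290 = 633,226 parameters. The helper name build_layers is hypothetical; it mirrors the loop in __init__ above.

```python
import torch.nn as nn


def build_layers(act_fn, input_size=784, num_classes=10, hidden_sizes=(512, 256, 256, 128)):
    # Mirrors BaseNetwork's loop: each hidden Linear layer is followed by the same activation module
    layers = []
    layer_size_last = input_size
    for layer_size in hidden_sizes:
        layers += [nn.Linear(layer_size_last, layer_size), act_fn]
        layer_size_last = layer_size
    layers += [nn.Linear(layer_size_last, num_classes)]
    return nn.Sequential(*layers)


net = build_layers(nn.ReLU())
num_params = sum(p.numel() for p in net.parameters())
print(num_params)  # 633226
print(len(net))  # 9 entries: 5 Linear layers and 4 references to the activation module
```

Because the same activation object sits at every position in nn.Sequential, it is created only once. This is harmless for stateless activations like the ones in this notebook, but an activation with its own parameters would then share them across layers. The sketch uses a tuple as the default for hidden_sizes to sidestep Python's mutable-default pitfall.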

We also add functions for loading and saving the model. The hyperparameters are stored in a configuration file (a simple JSON file):

[12]:
def _get_config_file(model_path, model_name):
    # Name of the file for storing hyperparameter details
    return os.path.join(model_path, model_name + ".config")


def _get_model_file(model_path, model_name):
    # Name of the file for storing network parameters
    return os.path.join(model_path, model_name + ".tar")


def load_model(model_path, model_name, net=None):
    """Loads a saved model from disk.

    Args:
        model_path: Path of the checkpoint directory
        model_name: Name of the model (str)
        net: (Optional) If given, the state dict is loaded into this model. Otherwise, a new model is created.
    """
    config_file, model_file = _get_config_file(model_path, model_name), _get_model_file(model_path, model_name)
    assert os.path.isfile(
        config_file
    ), f'Could not find the config file "{config_file}". Are you sure this is the correct path and you have your model config stored here?'
    assert os.path.isfile(
        model_file
    ), f'Could not find the model file "{model_file}". Are you sure this is the correct path and you have your model stored here?'
    with open(config_file) as f:
        config_dict = json.load(f)
    if net is None:
        act_fn_name = config_dict["act_fn"].pop("name").lower()
        act_fn = act_fn_by_name[act_fn_name](**config_dict.pop("act_fn"))
        net = BaseNetwork(act_fn=act_fn, **config_dict)
    net.load_state_dict(torch.load(model_file, map_location=device))
    return net


def save_model(model, model_path, model_name):
    """Given a model, we save the state_dict and hyperparameters.

    Args:
        model: Network object to save parameters from
        model_path: Path of the checkpoint directory
        model_name: Name of the model (str)
    """
    config_dict = model.config
    os.makedirs(model_path, exist_ok=True)
    config_file, model_file = _get_config_file(model_path, model_name), _get_model_file(model_path, model_name)
    with open(config_file, "w") as f:
        json.dump(config_dict, f)
    torch.save(model.state_dict(), model_file)
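Splitting a model into a JSON config and a state dict means the architecture has to be rebuilt from the config before the weights can be loaded. A minimal round trip of that pattern is sketched below; TinyNet, save and load are hypothetical, simplified stand-ins for BaseNetwork, save_model and load_model, and the files are written to a temporary directory.

```python
import json
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical small model that, like BaseNetwork, keeps its hyperparameters in self.config
    def __init__(self, act_fn_name="relu", hidden_size=8):
        super().__init__()
        act = {"relu": nn.ReLU, "tanh": nn.Tanh}[act_fn_name]()
        self.layers = nn.Sequential(nn.Linear(4, hidden_size), act, nn.Linear(hidden_size, 2))
        self.config = {"act_fn_name": act_fn_name, "hidden_size": hidden_size}

    def forward(self, x):
        return self.layers(x)


def save(model, path, name):
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, name + ".config"), "w") as f:
        json.dump(model.config, f)  # hyperparameters as human-readable JSON
    torch.save(model.state_dict(), os.path.join(path, name + ".tar"))  # weights only


def load(path, name):
    with open(os.path.join(path, name + ".config")) as f:
        config = json.load(f)
    model = TinyNet(**config)  # first rebuild the architecture from the config ...
    state_dict = torch.load(os.path.join(path, name + ".tar"), map_location="cpu")
    model.load_state_dict(state_dict)  # ... then fill in the saved weights
    return model


torch.manual_seed(0)
original = TinyNet(act_fn_name="tanh", hidden_size=16)
with tempfile.TemporaryDirectory() as tmp:
    save(original, tmp, "tiny")
    restored = load(tmp, "tiny")

x = torch.randn(3, 4)
print(restored.config)  # {'act_fn_name': 'tanh', 'hidden_size': 16}
print(torch.equal(original(x), restored(x)))  # True: same architecture and weights
```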

We also set up the dataset we want to train on, namely FashionMNIST. FashionMNIST is a more complex version of MNIST that contains grayscale images of clothes instead of digits. The 10 classes include trousers, coats, shoes, bags and more. To load this dataset, we make use of yet another PyTorch package, namely torchvision. The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision. We will use the package in many of the notebooks in this course to simplify our dataset handling.

Let’s load the dataset below, and visualize a few images to get an impression of the data.

[13]:

# Transformations applied on each image => first make them a tensor, then normalize them in the range -1 to 1
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# Loading the training dataset. We need to split it into a training and validation part
train_dataset = FashionMNIST(root=DATASET_PATH, train=True, transform=transform, download=True)
train_set, val_set = torch.utils.data.random_split(train_dataset, [50000, 10000])
# Loading the test set
test_set = FashionMNIST(root=DATASET_PATH, train=False, transform=transform, download=True)
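Two details of the dataset cell can be checked in isolation. ToTensor scales uint8 pixel values from [0, 255] to [0, 1], and Normalize((0.5,), (0.5,)) then computes (x - 0.5) / 0.5, which maps [0, 1] to [-1, 1]. Also, random_split draws from PyTorch's global random state unless a generator is passed, so the train/validation split is only reproducible if a seed was set beforehand. The sketch below uses plain tensors and a range of 60,000 indices as hypothetical stand-ins for the FashionMNIST images; passing a seeded torch.Generator is an optional alternative, not what the notebook does.

```python
import torch
from torch.utils.data import random_split

# ToTensor: uint8 [0, 255] -> float [0, 1]; Normalize((0.5,), (0.5,)): (x - 0.5) / 0.5
pixels = torch.tensor([0, 128, 255], dtype=torch.uint8)
scaled = pixels.float() / 255.0
normalized = (scaled - 0.5) / 0.5
print(normalized)  # black maps to -1, white maps to 1

# A seeded generator makes the split independent of the global random state
gen = torch.Generator().manual_seed(42)
train_part, val_part = random_split(range(60000), [50000, 10000], generator=gen)
print(len(train_part), len(val_part))  # 50000 10000
```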

We define a set of data loaders that we can use for various purposes later. Note that for actually training a model, we will use different data loaders with a lower batch size.

[14]:
train_loader = data.DataLoader(train_set, batch_size=1024, shuffle=True, drop_last=False)
val_loader = data.DataLoader(val_set, batch_size=1024, shuffle=False, drop_last=False)
test_loader = data.DataLoader(test_set, batch_size=1024, shuffle=False, drop_last=False)
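The batch size and drop_last together determine how many batches a loader yields per epoch. The sketch below uses a range of 50,000 indices as a hypothetical stand-in for train_set: with batch size 1024 and drop_last=False the last batch is partial, while the training loader in train_model (batch size 256, drop_last=True) skips the final incomplete batch.

```python
import math

from torch.utils.data import DataLoader

train_indices = range(50000)  # hypothetical stand-in for train_set; only its length matters

eval_loader = DataLoader(train_indices, batch_size=1024, shuffle=False, drop_last=False)
train_loader_256 = DataLoader(train_indices, batch_size=256, shuffle=True, drop_last=True)

print(len(eval_loader), math.ceil(50000 / 1024))  # 49 batches; the last holds 50000 - 48 * 1024 = 848 items
print(len(train_loader_256), 50000 // 256)  # 195 batches; 80 shuffled items are skipped each epoch
```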
[15]:
exmp_imgs = [train_set[i][0] for i in range(16)]
# Organize the images into a grid for nicer visualization
img_grid = torchvision.utils.make_grid(torch.stack(exmp_imgs, dim=0), nrow=4, normalize=True, pad_value=0.5)
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(8, 8))
plt.title("FashionMNIST examples")
plt.imshow(img_grid)
plt.axis("off")
plt.show()
plt.close()
Visualizing the gradient flow after initialization

As mentioned previously, one important aspect of activation functions is how they propagate gradients through the network. Imagine we have a very deep neural network with more than 50 layers. The gradients for the input layer, i.e. the very first layer, have passed through the activation function more than 50 times, but we still want them to be of a reasonable size. If the gradient through the activation function is (in expectation) considerably smaller than 1, our gradients will vanish by the time they reach the input layer. If the gradient through the activation function is larger than 1, the gradients increase exponentially and might explode.
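The compounding effect is easy to quantify: if each of 50 layers scaled the gradient by the same factor, the input-layer gradient would be scaled by that factor to the 50th power, roughly 7.9e-31 for 0.25 (the maximum sigmoid gradient), 5.2e-3 for 0.9, and 117 for 1.1. The sketch below computes these numbers and, as a simplified illustration without any weights, backpropagates through 50 chained sigmoids, whose gradient is by the chain rule a product of 50 factors that are each at most 0.25.

```python
import torch

depth = 50
# If every layer scaled the gradient by the same factor, the input layer would see factor ** depth
for factor in (0.25, 0.9, 1.1):
    print(f"{factor} ** {depth} = {factor ** depth:.2e}")

# Simplified illustration: 50 chained sigmoids without weights (float64 leaves ample room for tiny values)
x = torch.tensor(0.0, dtype=torch.float64, requires_grad=True)
y = x
for _ in range(depth):
    y = torch.sigmoid(y)
y.backward()
# The resulting gradient is a product of 50 sigmoid derivatives, each at most 0.25
print(f"gradient after {depth} sigmoids: {x.grad.item():.2e}")
```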

To get a feeling for how each activation function influences the gradients, we can look at a freshly initialized network and measure the gradients of each weight parameter for a batch of 256 images:

[16]:
def visualize_gradients(net, color="C0"):
    """
    Args:
        net: Object of class BaseNetwork
        color: Color in which we want to visualize the histogram (for easier separation of activation functions)
    """
    net.eval()
    small_loader = data.DataLoader(train_set, batch_size=256, shuffle=False)
    imgs, labels = next(iter(small_loader))
    imgs, labels = imgs.to(device), labels.to(device)

    # Pass one batch through the network, and calculate the gradients for the weights
    net.zero_grad()
    preds = net(imgs)
    loss = F.cross_entropy(preds, labels)
    loss.backward()
    # We limit our visualization to the weight parameters and exclude the bias to reduce the number of plots
    grads = {
        name: params.grad.data.view(-1).cpu().clone().numpy()
        for name, params in net.named_parameters()
        if "weight" in name
    }
    net.zero_grad()

    # Plotting
    columns = len(grads)
    fig, ax = plt.subplots(1, columns, figsize=(columns * 3.5, 2.5))
    fig_index = 0
    for key in grads:
        key_ax = ax[fig_index % columns]
        sns.histplot(data=grads[key], bins=30, ax=key_ax, color=color, kde=True)
        key_ax.set_title(str(key))
        key_ax.set_xlabel("Grad magnitude")
        fig_index += 1
    fig.suptitle(
        f"Gradient magnitude distribution for activation function {net.config['act_fn']['name']}", fontsize=14, y=1.05
    )
    fig.subplots_adjust(wspace=0.45)
    plt.show()
    plt.close()
[17]:
# Seaborn prints warnings if histogram has small values. We can ignore them for now
warnings.filterwarnings("ignore")
# Create a plot for every activation function
for i, act_fn_name in enumerate(act_fn_by_name):
    # Setting the seed ensures that we have the same weight initialization for each activation function
    set_seed(42)
    act_fn = act_fn_by_name[act_fn_name]()
    net_actfn = BaseNetwork(act_fn=act_fn).to(device)
    visualize_gradients(net_actfn, color=f"C{i}")

The sigmoid activation function shows clearly undesirable behavior. While the gradients for the output layer are very large (up to 0.1), the input layer has the smallest gradients of all activation functions, only around 1e-5. This is due to the small maximum gradient of sigmoid (1/4), and it makes finding a single learning rate that suits all layers impossible in this setup. All the other activation functions have similar gradient magnitudes across all layers. Interestingly, the ReLU activation has a spike around 0, which is caused by its zero part on the left and by dead neurons (we will take a closer look at this later on).

Note that, in addition to the activation function, the initialization of the weight parameters can be crucial. By default, PyTorch initializes the weights of linear layers with kaiming_uniform_ using a=sqrt(5), which amounts to sampling them uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)], where fan_in is the number of input features. In the next tutorial, we will take a closer look at initialization, but for now we assume that this default initialization works reasonably well for all activation functions.
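The default initialization can be inspected directly. The sketch below creates a layer with the same shape as the first layer of BaseNetwork and checks that its weights lie within ±1/sqrt(fan_in) and that their variance is close to that of a uniform distribution on this interval, 1/(3 * fan_in). This only confirms the default bounds; it says nothing about whether they suit a particular activation function.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(784, 512)  # same shape as the first layer of BaseNetwork
fan_in = layer.in_features
bound = 1 / math.sqrt(fan_in)  # 1/28 for fan_in = 784

w = layer.weight.detach()
# Small tolerance for float32 rounding at the interval boundary
print(w.abs().max().item() <= bound + 1e-6)
# A uniform distribution on [-b, b] has variance b^2 / 3 = 1 / (3 * fan_in)
print(w.var().item(), 1 / (3 * fan_in))
```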

Training a model

Next, we want to train our model with different activation functions on FashionMNIST and compare the resulting performance. After all, our final goal is to achieve the best possible performance on a dataset of our choice. Therefore, we write a training loop in the next cell that includes a validation after every epoch and a final test of the best model:

[18]:
def train_model(net, model_name, max_epochs=50, patience=7, batch_size=256, overwrite=False):
    """Train a model on the training set of FashionMNIST.

    Args:
        net: Object of BaseNetwork
        model_name: (str) Name of the model, used for creating the checkpoint names
        max_epochs: Number of epochs we want to (maximally) train for
        patience: If the performance on the validation set has not improved for #patience epochs, we stop training early
        batch_size: Size of batches used in training
        overwrite: Determines how to handle the case when there already exists a checkpoint. If True, it will be overwritten. Otherwise, we skip training.
    """
    file_exists = os.path.isfile(_get_model_file(CHECKPOINT_PATH, model_name))
    if file_exists and not overwrite:
        print("Model file already exists. Skipping training...")
    else:
        if file_exists:
            print("Model file exists, but will be overwritten...")

        # Defining optimizer, loss and data loader
        optimizer = optim.SGD(net.parameters(), lr=1e-2, momentum=0.9)  # Default parameters, feel free to change
        loss_module = nn.CrossEntropyLoss()
        train_loader_local = data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True
        )

        val_scores = []
        best_val_epoch = -1
        for epoch in range(max_epochs):
            ############
            # Training #
            ############
            net.train()
            true_preds, count = 0.0, 0
            for imgs, labels in tqdm(train_loader_local, desc=f"Epoch {epoch+1}", leave=False):
                imgs, labels = imgs.to(device), labels.to(device)  # To GPU
                optimizer.zero_grad()  # Zero-grad can be placed anywhere before "loss.backward()"
                preds = net(imgs)
                loss = loss_module(preds, labels)
                loss.backward()
                optimizer.step()
                # Record statistics during training
                true_preds += (preds.argmax(dim=-1) == labels).sum().item()  # .item() keeps the counter a Python number
                count += labels.shape[0]
            train_acc = true_preds / count

            ##############
            # Validation #
            ##############
            val_acc = test_model(net, val_loader)
            val_scores.append(val_acc)
            print(
                f"[Epoch {epoch+1:2d}] Training accuracy: {train_acc*100.0:05.2f}%, Validation accuracy: {val_acc*100.0:05.2f}%"
            )

            if len(val_scores) == 1 or val_acc > val_scores[best_val_epoch]:
                print("\t   (New best performance, saving model...)")
                save_model(net, CHECKPOINT_PATH, model_name)
                best_val_epoch = epoch
            elif best_val_epoch <= epoch - patience:
                print(f"Early stopping due to no improvement over the last {patience} epochs")
                break

        # Plot a curve of the validation accuracy
        plt.plot([i for i in range(1, len(val_scores) + 1)], val_scores)
        plt.xlabel("Epochs")
        plt.ylabel("Validation accuracy")
        plt.title(f"Validation performance of {model_name}")
        plt.show()
        plt.close()

    load_model(CHECKPOINT_PATH, model_name, net=net)
    test_acc = test_model(net, test_loader)
    print((f" Test accuracy: {test_acc*100.0:4.2f}% ").center(50, "=") + "\n")
    return test_acc


def test_model(net, data_loader):
    """Test a model on a specified dataset.

    Args:
        net: Trained model of type BaseNetwork
        data_loader: DataLoader object of the dataset to test on (validation or test)
    """
    net.eval()
    true_preds, count = 0.0, 0
    for imgs, labels in data_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        with torch.no_grad():
            preds = net(imgs).argmax(dim=-1)
            true_preds += (preds == labels).sum().item()
            count += labels.shape[0]
    test_acc = true_preds / count
    return test_acc
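The early-stopping bookkeeping in train_model can be replayed without any training. The hypothetical helper below copies its logic and runs it on a made-up list of validation accuracies. It shows two consequences of the implementation: an accuracy that only ties the best one does not count as an improvement (the comparison is strict), and training stops once the best epoch lies patience epochs in the past, even if a later epoch would have been better.

```python
def simulate_early_stopping(val_accuracies, patience):
    """Replays train_model's early-stopping bookkeeping on a precomputed list of validation accuracies.

    Returns (best_epoch, last_epoch), both 0-based; last_epoch is the epoch after which training stops.
    """
    val_scores = []
    best_val_epoch = -1
    for epoch, val_acc in enumerate(val_accuracies):
        val_scores.append(val_acc)
        if len(val_scores) == 1 or val_acc > val_scores[best_val_epoch]:
            best_val_epoch = epoch  # new best model; train_model would save a checkpoint here
        elif best_val_epoch <= epoch - patience:
            return best_val_epoch, epoch  # no strict improvement for `patience` epochs
    return best_val_epoch, len(val_scores) - 1  # all epochs used (max_epochs reached)


# Hypothetical accuracy curve: peaks at epoch 2, ties it at epoch 5, and would improve at epoch 10
curve = [0.80, 0.85, 0.86, 0.85, 0.855, 0.86, 0.859, 0.858, 0.857, 0.856, 0.90]
print(simulate_early_stopping(curve, patience=7))  # (2, 9)
```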

We train one model for each activation function. We recommend using the pretrained models to save time if you are running this notebook on CPU.

[19]:
for act_fn_name in act_fn_by_name:
    print(f"Training BaseNetwork with {act_fn_name} activation...")
    set_seed(42)
    act_fn = act_fn_by_name[act_fn_name]()
    net_actfn = BaseNetwork(act_fn=act_fn).to(device)
    train_model(net_actfn, f"FashionMNIST_{act_fn_name}", overwrite=False)
Training BaseNetwork with sigmoid activation...
Model file already exists. Skipping training...
============= Test accuracy: 10.00% ==============

Training BaseNetwork with tanh activation...
Model file already exists. Skipping training...
============= Test accuracy: 87.59% ==============

Training BaseNetwork with relu activation...
Model file already exists. Skipping training...
============= Test accuracy: 88.62% ==============

Training BaseNetwork with leakyrelu activation...
Model file already exists. Skipping training...
============= Test accuracy: 88.92% ==============

Training BaseNetwork with elu activation...
Model file already exists. Skipping training...
============= Test accuracy: 87.27% ==============

Training BaseNetwork with swish activation...
Model file already exists. Skipping training...
============= Test accuracy: 88.73% ==============

Unsurprisingly, the model using the sigmoid activation function fails to train and does not improve upon random performance (with 10 classes, random guessing yields 1/10 = 10% accuracy).

All the other activation functions reach similar performance. To draw a more reliable conclusion, we would have to train the models with multiple seeds and compare the averages. However, the “optimal” activation function also depends on many other factors (hidden sizes, number of layers, type of layers, task, dataset, optimizer, learning rate, etc.), so a thorough grid search would not be useful in our case. In the literature, the ReLU variants we experiment with here have all been shown to work well with deep networks, with small gains for specific activation functions in specific networks.

Visualizing the activation distribution

After training the models, we can look at the actual activation values that occur inside the model. For instance, how many neurons are set to zero in ReLU? Where do we find most values in Tanh? To answer these questions, we can write a simple function that takes a trained model, applies it to a batch of images, and plots the histogram of the activations inside the network:

[20]:
def visualize_activations(net, color="C0"):
    activations = {}

    net.eval()
    small_loader = data.DataLoader(train_set, batch_size=1024)
    imgs, labels = next(iter(small_loader))
    with torch.no_grad():
        layer_index = 0
        imgs = imgs.to(device)
        imgs = imgs.view(imgs.size(0), -1)
        # We need to manually loop through the layers to save all activations
        for layer_index, layer in enumerate(net.layers[:-1]):
            imgs = layer(imgs)
            activations[layer_index] = imgs.view(-1).cpu().numpy()

    # Plotting
    columns = 4
    rows = math.ceil(len(activations) / columns)
    fig, ax = plt.subplots(rows, columns, figsize=(columns * 2.7, rows * 2.5))
    fig_index = 0
    for key in activations:
        key_ax = ax[fig_index // columns][fig_index % columns]
        sns.histplot(data=activations[key], bins=50, ax=key_ax, color=color, kde=True, stat="density")
        key_ax.set_title(f"Layer {key} - {net.layers[key].__class__.__name__}")
        fig_index += 1
    fig.suptitle(f"Activation distribution for activation function {net.config['act_fn']['name']}", fontsize=14)
    fig.subplots_adjust(hspace=0.4, wspace=0.4)
    plt.show()
    plt.close()
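Instead of looping through net.layers by hand, activations can also be collected with forward hooks (nn.Module.register_forward_hook). One subtlety is that BaseNetwork reuses a single activation module at every position, so a hook registered on that module fires once per hidden layer; appending the outputs to a list keeps them in layer order. The small network below is a hypothetical stand-in for a trained BaseNetwork.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
act = nn.ReLU()  # one activation module shared by all hidden layers, as in BaseNetwork
net = nn.Sequential(nn.Linear(784, 64), act, nn.Linear(64, 32), act, nn.Linear(32, 10))

captured = []  # filled in call order, so entry i belongs to the i-th hidden layer


def save_output(module, inputs, output):
    captured.append(output.detach().view(-1))


handle = act.register_forward_hook(save_output)
with torch.no_grad():
    net(torch.randn(128, 784))
handle.remove()  # detach the hook so later forward passes are not recorded

print([t.numel() for t in captured])  # [128 * 64, 128 * 32]
print(all((t >= 0).all().item() for t in captured))  # ReLU outputs are never negative
```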
[21]:
for i, act_fn_name in enumerate(act_fn_by_name):
    net_actfn = load_model(model_path=CHECKPOINT_PATH, model_name=f"FashionMNIST_{act_fn_name}").to(device)
    visualize_activations(net_actfn, color=f"C{i}")

As the model with sigmoid activation was not able to train properly, the activations are also less informative and all gathered around 0.5 (the activation at input 0).

The tanh shows a more diverse behavior. In the input layer, a larger share of neurons have activations close to -1 and 1, where the gradients are close to zero, whereas the activations in the two subsequent layers are closer to zero. This is probably because the input layers look for specific features in the input image, and the subsequent layers combine those features. The activations of the last layer are again more biased towards the extremes because the classification layer can be seen as a weighted average of those values (the gradients push the activations towards those extremes).

The ReLU has a strong peak at 0, as we initially expected. The effect of having no gradients for negative values is that the outputs of the linear layers do not follow a Gaussian-like distribution but have a longer tail towards positive values. The LeakyReLU shows very similar behavior, while ELU again follows a more Gaussian-like distribution. The Swish activation seems to lie in between, although it is worth noting that Swish produces significantly higher values than the other activation functions (up to 20).

Since all activation functions behave slightly differently while reaching similar performance in our simple network, it becomes apparent that the choice of the “optimal” activation function depends on many factors and is not the same for all possible networks.

Finding dead neurons in ReLU networks

One known drawback of the ReLU activation is the occurrence of “dead neurons”, i.e. neurons with no gradient for any training input. The problem with dead neurons is that, since no gradient flows through them, the parameters of the preceding linear layer that feed into such a neuron cannot be trained to produce output values other than zero. For a neuron to die, its output in the linear layer before the ReLU has to be negative for all input images. Given the large number of neurons in a neural network, this is not unlikely to happen.
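A tiny example makes the mechanism concrete. Below, a hypothetical linear layer has one neuron whose pre-activation is negative for every (strictly positive) input. After the ReLU, this neuron outputs zero for the whole batch, so the gradients of its weights and bias are exactly zero and gradient descent cannot revive it, while the other neuron receives normal gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 2)
with torch.no_grad():
    # Neuron 0 responds positively to positive inputs; neuron 1 is negative for all of them
    layer.weight.copy_(torch.tensor([[1.0, 1.0, 1.0], [-1.0, -1.0, -1.0]]))
    layer.bias.zero_()

x = torch.rand(16, 3) + 0.1  # strictly positive inputs
out = torch.relu(layer(x)).sum()
out.backward()

print(layer.weight.grad)  # row 1 is all zeros: the dead neuron receives no gradient
print(layer.bias.grad)  # 16 for neuron 0 (one per sample), 0 for neuron 1
```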

To get a better understanding of how much of a problem this is, and when we need to be careful, we will measure how many dead neurons different networks have. For this, we implement a function which runs the network on the whole training set and records whether a neuron is exactly 0 for all data points or not:

[22]:
@torch.no_grad()
def measure_number_dead_neurons(net):
    """Function to measure the number of dead neurons in a trained neural network.

    For each neuron, we create a boolean flag initially set to True. If the neuron has a non-zero activation at any time, we
    set this flag to False. After running through the whole training set, only dead neurons will still be True.
    """
    neurons_dead = [
        torch.ones(layer.weight.shape[0], device=device, dtype=torch.bool)
        for layer in net.layers[:-1]
        if isinstance(layer, nn.Linear)
    ]  # Same shapes as hidden size in BaseNetwork

    net.eval()
    for imgs, labels in tqdm(train_loader, leave=False):  # Run through whole training set
        layer_index = 0
        imgs = imgs.to(device)
        imgs = imgs.view(imgs.size(0), -1)
        for layer in net.layers[:-1]:
            imgs = layer(imgs)
            if isinstance(layer, ActivationFunction):
                # Are all activations == 0 in the batch, and we did not record the opposite in the last batches?
                neurons_dead[layer_index] = torch.logical_and(neurons_dead[layer_index], (imgs == 0).all(dim=0))
                layer_index += 1
    number_neurons_dead = [t.sum().item() for t in neurons_dead]
    print("Number of dead neurons:", number_neurons_dead)
    print(
        "In percentage:",
        ", ".join(
            [f"{(100.0 * num_dead / tens.shape[0]):4.2f}%" for tens, num_dead in zip(neurons_dead, number_neurons_dead)]
        ),
    )

First, we can measure the number of dead neurons for an untrained network:

[23]:
set_seed(42)
net_relu = BaseNetwork(act_fn=ReLU()).to(device)
measure_number_dead_neurons(net_relu)
Number of dead neurons: [0, 0, 3, 10]
In percentage: 0.00%, 0.00%, 1.17%, 7.81%

We see that only a small number of neurons are dead, but that their number increases with the depth of the layer. However, such a small number of dead neurons is not a problem, because the input to later layers changes as the weights of previous layers are updated. Therefore, dead neurons in later layers can potentially become “alive”/active again.

What does this look like for a trained network (with the same initialization)?

[24]:
net_relu = load_model(model_path=CHECKPOINT_PATH, model_name="FashionMNIST_relu").to(device)
measure_number_dead_neurons(net_relu)
Number of dead neurons: [0, 0, 0, 3]
In percentage: 0.00%, 0.00%, 0.00%, 2.34%

The number of dead neurons indeed decreased in the later layers. However, it should be noted that dead neurons are especially problematic in the input layer. As the input does not change over epochs (the training set is kept as it is), training the network cannot turn those neurons back on. Still, the input data usually has a sufficiently high standard deviation to reduce the risk of dead neurons.

Finally, we check how the number of dead neurons behaves with increasing layer depth. For instance, let’s take the following 10-layer neural network:

[25]:
set_seed(42)
net_relu = BaseNetwork(
    act_fn=ReLU(),
    hidden_sizes=[256, 256, 256, 256, 256, 128, 128, 128, 128, 128],
).to(device)
measure_number_dead_neurons(net_relu)
Number of dead neurons: [0, 0, 7, 27, 89, 60, 58, 61, 72, 56]
In percentage: 0.00%, 0.00%, 2.73%, 10.55%, 34.77%, 46.88%, 45.31%, 47.66%, 56.25%, 43.75%

The number of dead neurons is significantly higher than before, which harms the gradient flow, especially in the first iterations. For instance, more than 56% of the neurons in the second-to-last layer are dead, which creates a considerable bottleneck. Hence, it is advisable to use other nonlinearities like Swish for very deep networks.
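The difference between ReLU and LeakyReLU with respect to dead neurons can be illustrated on the same 10-layer architecture. The sketch below uses random Gaussian inputs instead of FashionMNIST and freshly initialized layers, so its counts will not match the numbers above. The point is only that a LeakyReLU output is zero only when its pre-activation is exactly zero, so in practice no unit is dead, while ReLU can produce dead units in deeper layers. count_dead_units is a hypothetical helper, and a finite sample of inputs can only approximate “dead for all inputs”.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def count_dead_units(act, hidden_sizes, n_samples=1024, in_features=784, seed=42):
    """Counts, per hidden layer, the units whose activation is exactly 0 for every sample."""
    torch.manual_seed(seed)
    x = torch.randn(n_samples, in_features)  # hypothetical stand-in for real images
    dead = []
    for size in hidden_sizes:
        layer = nn.Linear(x.shape[1], size)  # PyTorch default initialization
        x = act(layer(x))
        dead.append(int((x == 0).all(dim=0).sum()))
    return dead


sizes = [256, 256, 256, 256, 256, 128, 128, 128, 128, 128]
relu_dead = count_dead_units(F.relu, sizes)
leaky_dead = count_dead_units(lambda t: F.leaky_relu(t, negative_slope=0.1), sizes)
print("ReLU:", relu_dead)
print("LeakyReLU:", leaky_dead)  # negative pre-activations still give non-zero outputs
```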

Conclusion

In this notebook, we have reviewed a set of six activation functions (sigmoid, tanh, ReLU, LeakyReLU, ELU, and Swish) in neural networks, and discussed how they influence the gradient distribution across layers. Sigmoid tends to fail in deep neural networks because its highest gradient is 0.25, which leads to vanishing gradients in early layers. All ReLU-based activation functions have been shown to perform well and, apart from the original ReLU, do not suffer from dead neurons. When implementing your own neural network, it is recommended to start with a ReLU-based network and select the specific activation function based on the properties of the network.

References

[1] Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. “Searching for activation functions.” arXiv preprint arXiv:1710.05941 (2017).

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Many thanks from the entire PyTorch Lightning team for your interest!


Tutorial 3: Initialization and Optimization

  • Author: Phillip Lippe

  • License: CC BY-SA

  • Generated: 2021-12-04T16:52:46.401516

In this tutorial, we will review techniques for optimization and initialization of neural networks. When increasing the depth of neural networks, there are various challenges we face. Most importantly, we need to have a stable gradient flow through the network, as otherwise, we might encounter vanishing or exploding gradients. This is why we will take a closer look at the following concepts: initialization and optimization. This notebook is part of a lecture series on Deep Learning at the University of Amsterdam. The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.



Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "seaborn" "torchmetrics>=0.3" "torchvision" "pytorch-lightning>=1.3" "torch>=1.6, <1.9" "matplotlib"

In the first half of the notebook, we will review different initialization techniques, going step by step from the simplest initialization to methods that are nowadays used in very deep networks. In the second half, we focus on optimization, comparing the optimizers SGD, SGD with Momentum, and Adam.

Let’s start with importing our standard libraries:

[2]:
import copy
import json
import math
import os
import urllib.request
from urllib.error import HTTPError

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

%matplotlib inline
from IPython.display import set_matplotlib_formats
from matplotlib import cm
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from tqdm.notebook import tqdm

set_matplotlib_formats("svg", "pdf")  # For export
sns.set()
/tmp/ipykernel_875/1682095326.py:24: DeprecationWarning: `set_matplotlib_formats` is deprecated since IPython 7.23, directly use `matplotlib_inline.backend_inline.set_matplotlib_formats()`
  set_matplotlib_formats("svg", "pdf")  # For export

Instead of the set_seed function from the previous tutorial, we can use PyTorch Lightning’s built-in function pl.seed_everything. We will reuse the path variables DATASET_PATH and CHECKPOINT_PATH from the previous tutorial. Adjust the paths if necessary.

[3]:
# Path to the folder where the datasets are/should be downloaded (e.g. MNIST)
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/InitOptim/")

# Seed everything
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Fetching the device that will be used throughout this notebook
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda:0")
print("Using device", device)
Global seed set to 42
Using device cuda:0

In the last part of the notebook, we will train models using three different optimizers. The pretrained models for those are downloaded below.

[4]:
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/"
# Files to download
pretrained_files = [
    "FashionMNIST_SGD.config",
    "FashionMNIST_SGD_results.json",
    "FashionMNIST_SGD.tar",
    "FashionMNIST_SGDMom.config",
    "FashionMNIST_SGDMom_results.json",
    "FashionMNIST_SGDMom.tar",
    "FashionMNIST_Adam.config",
    "FashionMNIST_Adam_results.json",
    "FashionMNIST_Adam.tar",
]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n",
                e,
            )
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_SGD.config...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_SGD_results.json...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_SGD.tar...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_SGDMom.config...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_SGDMom_results.json...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_SGDMom.tar...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_Adam.config...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_Adam_results.json...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial4/FashionMNIST_Adam.tar...

Preparation

Throughout this notebook, we will use a deep fully connected network, similar to the one in the previous tutorial. We will again apply the network to FashionMNIST, so you can relate the results to those of the previous tutorial. We start by loading the FashionMNIST dataset:

[5]:

# Transformations applied on each image => first make them a tensor, then normalize them with mean 0 and std 1
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.2861,), (0.3530,))])

# Loading the training dataset. We need to split it into a training and validation part
train_dataset = FashionMNIST(root=DATASET_PATH, train=True, transform=transform, download=True)
train_set, val_set = torch.utils.data.random_split(train_dataset, [50000, 10000])

# Loading the test set
test_set = FashionMNIST(root=DATASET_PATH, train=False, transform=transform, download=True)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /__w/1/s/.datasets/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting /__w/1/s/.datasets/FashionMNIST/raw/train-images-idx3-ubyte.gz to /__w/1/s/.datasets/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /__w/1/s/.datasets/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting /__w/1/s/.datasets/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /__w/1/s/.datasets/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /__w/1/s/.datasets/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /__w/1/s/.datasets/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /__w/1/s/.datasets/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /__w/1/s/.datasets/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /__w/1/s/.datasets/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /__w/1/s/.datasets/FashionMNIST/raw

Processing...
Done!
/usr/local/lib/python3.9/dist-packages/torchvision/datasets/mnist.py:502: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /pytorch/torch/csrc/utils/tensor_numpy.cpp:143.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

We define a set of data loaders that we can use for various purposes later. Note that for actually training a model, we will use different data loaders with a smaller batch size.

[6]:
train_loader = data.DataLoader(train_set, batch_size=1024, shuffle=True, drop_last=False)
val_loader = data.DataLoader(val_set, batch_size=1024, shuffle=False, drop_last=False)
test_loader = data.DataLoader(test_set, batch_size=1024, shuffle=False, drop_last=False)

In comparison to the previous tutorial, we have changed the parameters of the normalization transformation transforms.Normalize. The normalization is now designed to give us an expected mean of 0 and a standard deviation of 1 across pixels. This will be particularly relevant for the discussion about initialization below, which is why we change it here. It should be noted that in most classification tasks, both normalization techniques (scaling to the range -1 to 1, or to mean 0 and standard deviation 1) have been shown to work well. We can calculate the normalization parameters by determining the mean and standard deviation on the original images:

[7]:
print("Mean", (train_dataset.data.float() / 255.0).mean().item())
print("Std", (train_dataset.data.float() / 255.0).std().item())
Mean 0.28604060411453247
Std 0.3530242443084717

We can verify the transformation by looking at the statistics of a single batch:

[8]:
imgs, _ = next(iter(train_loader))
print(f"Mean: {imgs.mean().item():5.3f}")
print(f"Standard deviation: {imgs.std().item():5.3f}")
print(f"Maximum: {imgs.max().item():5.3f}")
print(f"Minimum: {imgs.min().item():5.3f}")
Mean: 0.009
Standard deviation: 1.012
Maximum: 2.022
Minimum: -0.810

Note that the maximum and minimum are no longer 1 and -1, but shifted towards positive values. This is because FashionMNIST, like MNIST, contains many black (zero-valued) background pixels, which keep the mean pixel value low (about 0.29).
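To see where the asymmetric range comes from, the sketch below applies the same per-channel formula as transforms.Normalize, (x - mean) / std, to the darkest and brightest possible pixel values, using the mean and standard deviation passed to the transform above.

```python
MEAN, STD = 0.2861, 0.3530  # values used in transforms.Normalize above


def normalize(pixel: float) -> float:
    # transforms.Normalize computes (x - mean) / std per channel
    return (pixel - MEAN) / STD


black = normalize(0.0)  # background pixels
white = normalize(1.0)  # brightest possible pixel
print(f"black -> {black:.3f}, white -> {white:.3f}")  # black -> -0.810, white -> 2.022
```

These two values match the minimum and maximum measured on the batch above.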

Next, we create a fully connected neural network. We use the same setup as in the previous tutorial.

[9]:
class BaseNetwork(nn.Module):
    def __init__(self, act_fn, input_size=784, num_classes=10, hidden_sizes=[512, 256, 256, 128]):
        """
        Args:
            act_fn: Object of the activation function that should be used as non-linearity in the network.
            input_size: Size of the input images in pixels
            num_classes: Number of classes we want to predict
            hidden_sizes: A list of integers specifying the hidden layer sizes in the NN
        """
        super().__init__()

        # Create the network based on the specified hidden sizes
        layers = []
        layer_sizes = [input_size] + hidden_sizes
        for layer_index in range(1, len(layer_sizes)):
            layers += [nn.Linear(layer_sizes[layer_index - 1], layer_sizes[layer_index]), act_fn]
        layers += [nn.Linear(layer_sizes[-1], num_classes)]
        # A module list registers a list of modules as submodules (e.g. for parameters)
        self.layers = nn.ModuleList(layers)

        self.config = {
            "act_fn": act_fn.__class__.__name__,
            "input_size": input_size,
            "num_classes": num_classes,
            "hidden_sizes": hidden_sizes,
        }

    def forward(self, x):
        x = x.view(x.size(0), -1)
        for layer in self.layers:
            x = layer(x)
        return x

For the activation functions, we make use of PyTorch’s torch.nn library instead of implementing them ourselves. However, we also define an Identity activation function. Although this activation function would significantly limit the network’s modeling capabilities, we will use it in the first steps of our discussion about initialization for simplicity.

[10]:
class Identity(nn.Module):
    def forward(self, x):
        return x


act_fn_by_name = {"tanh": nn.Tanh, "relu": nn.ReLU, "identity": Identity}

Finally, we define a few plotting functions that we will use for our discussions. These functions help us to (1) visualize the weight/parameter distribution inside a network, (2) visualize the gradients that the parameters at different layers receive, and (3) the activations, i.e. the output of the linear layers. The detailed code is not important, but feel free to take a closer look if interested.

[11]:
##############################################################


def plot_dists(val_dict, color="C0", xlabel=None, stat="count", use_kde=True):
    columns = len(val_dict)
    fig, ax = plt.subplots(1, columns, figsize=(columns * 3, 2.5))
    fig_index = 0
    for key in sorted(val_dict.keys()):
        key_ax = ax[fig_index % columns]
        sns.histplot(
            val_dict[key],
            ax=key_ax,
            color=color,
            bins=50,
            stat=stat,
            kde=use_kde and ((val_dict[key].max() - val_dict[key].min()) > 1e-8),
        )  # Only plot kde if there is variance
        hidden_dim_str = (
            r"(%i $\to$ %i)" % (val_dict[key].shape[1], val_dict[key].shape[0]) if len(val_dict[key].shape) > 1 else ""
        )
        key_ax.set_title(f"{key} {hidden_dim_str}")
        if xlabel is not None:
            key_ax.set_xlabel(xlabel)
        fig_index += 1
    fig.subplots_adjust(wspace=0.4)
    return fig


##############################################################


def visualize_weight_distribution(model, color="C0"):
    weights = {}
    for name, param in model.named_parameters():
        if name.endswith(".bias"):
            continue
        key_name = f"Layer {name.split('.')[1]}"
        weights[key_name] = param.detach().view(-1).cpu().numpy()

    # Plotting
    fig = plot_dists(weights, color=color, xlabel="Weight vals")
    fig.suptitle("Weight distribution", fontsize=14, y=1.05)
    plt.show()
    plt.close()


##############################################################


def visualize_gradients(model, color="C0", print_variance=False):
    """
    Args:
        model: Object of class BaseNetwork
        color: Color in which we want to visualize the histogram (for easier separation of activation functions)
        print_variance: If True, print the variance of the gradients of each weight matrix
    """
    model.eval()
    small_loader = data.DataLoader(train_set, batch_size=1024, shuffle=False)
    imgs, labels = next(iter(small_loader))
    imgs, labels = imgs.to(device), labels.to(device)

    # Pass one batch through the network, and calculate the gradients for the weights
    model.zero_grad()
    preds = model(imgs)
    loss = F.cross_entropy(preds, labels)  # Same as nn.CrossEntropyLoss, but as a function instead of module
    loss.backward()
    # We limit our visualization to the weight parameters and exclude the bias to reduce the number of plots
    grads = {
        name: params.grad.view(-1).cpu().clone().numpy()
        for name, params in model.named_parameters()
        if "weight" in name
    }
    model.zero_grad()

    # Plotting
    fig = plot_dists(grads, color=color, xlabel="Grad magnitude")
    fig.suptitle("Gradient distribution", fontsize=14, y=1.05)
    plt.show()
    plt.close()

    if print_variance:
        for key in sorted(grads.keys()):
            print(f"{key} - Variance: {np.var(grads[key])}")


##############################################################


def visualize_activations(model, color="C0", print_variance=False):
    model.eval()
    small_loader = data.DataLoader(train_set, batch_size=1024, shuffle=False)
    imgs, labels = next(iter(small_loader))
    imgs, labels = imgs.to(device), labels.to(device)

    # Pass one batch through the network layer by layer and record the output of each linear layer
    feats = imgs.view(imgs.shape[0], -1)
    activations = {}
    with torch.no_grad():
        for layer_index, layer in enumerate(model.layers):
            feats = layer(feats)
            if isinstance(layer, nn.Linear):
                activations[f"Layer {layer_index}"] = feats.view(-1).detach().cpu().numpy()

    # Plotting
    fig = plot_dists(activations, color=color, stat="density", xlabel="Activation vals")
    fig.suptitle("Activation distribution", fontsize=14, y=1.05)
    plt.show()
    plt.close()

    if print_variance:
        for key in sorted(activations.keys()):
            print(f"{key} - Variance: {np.var(activations[key])}")


##############################################################

Initialization

Before starting our discussion about initialization, it should be noted that there exist many very good blog posts about the topic of neural network initialization (for example deeplearning.ai, or a more math-focused blog post). In case something remains unclear after this tutorial, we recommend skimming through these blog posts as well.

When initializing a neural network, there are a few properties we would like to have. First, the variance of the input should be propagated through the model to the last layer, so that we have a similar standard deviation for the output neurons. If the variance vanishes the deeper we go in our model, it becomes much harder to optimize the model, as the input to the next layer is basically a single constant value. Similarly, if the variance increases, it is likely to explode (i.e. head to infinity) the deeper we design our model. The second property we look out for in initialization techniques is a gradient distribution with equal variance across layers. If the first layer receives much smaller gradients than the last layer, we will have difficulties in choosing an appropriate learning rate.

As a starting point for finding a good method, we will analyze different initializations based on our network with no activation function (i.e. an identity), which makes it a purely linear network. We do this because initializations depend on the specific activation function used in the network, and we can adjust the initialization schemes later on for our specific choice.

[12]:
model = BaseNetwork(act_fn=Identity()).to(device)

Constant initialization

The first initialization we can consider is to initialize all weights with the same constant value. Intuitively, setting all weights to zero is not a good idea as the propagated gradient will be zero. However, what happens if we set all weights to a value slightly larger or smaller than 0? To find out, we can implement a function for setting all parameters below and visualize the gradients.

[13]:
def const_init(model, fill=0.0):
    for name, param in model.named_parameters():
        param.data.fill_(fill)


const_init(model, fill=0.005)
visualize_gradients(model)
visualize_activations(model, print_variance=True)
[Figure: Gradient distribution]
[Figure: Activation distribution]
Layer 0 - Variance: 2.0582761764526367
Layer 2 - Variance: 13.489120483398438
Layer 4 - Variance: 22.100574493408203
Layer 6 - Variance: 36.20957946777344
Layer 8 - Variance: 14.831440925598145

As we can see, only the first and the last layer have diverse gradient distributions, while the other three layers have the same gradient for all weights (note that this value is not equal to 0, but often very close to it). Having the same gradient for parameters that have been initialized with the same values means that these parameters will always keep identical values. This would make these layers useless and reduce the effective number of parameters in each of them to 1. Thus, we cannot use a constant initialization to train our networks.
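The symmetry argument can be reproduced without PyTorch. The sketch below uses NumPy and a hypothetical three-layer linear network with tiny layer sizes, a squared-error loss on random targets, and hand-written backpropagation (all illustrative choices, not the notebook’s setup). With constant weights, every entry of the middle layer’s gradient is identical, the first layer’s gradient has identical rows, and the last layer’s gradient has identical columns, which mirrors the observation above.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d_in, d_hidden, d_out = 8, 5, 4, 3  # hypothetical tiny sizes

# Constant initialization for all weight matrices (biases omitted for simplicity)
W1 = np.full((d_hidden, d_in), 0.005)
W2 = np.full((d_hidden, d_hidden), 0.005)
W3 = np.full((d_out, d_hidden), 0.005)

x = rng.normal(size=(batch, d_in))
t = rng.normal(size=(batch, d_out))  # random regression targets

# Forward pass through a linear (identity-activation) network
h1 = x @ W1.T
h2 = h1 @ W2.T
y = h2 @ W3.T

# Backward pass for the loss 0.5 * sum((y - t) ** 2)
g_y = y - t
grad_W3 = g_y.T @ h2
g_h2 = g_y @ W3
grad_W2 = g_h2.T @ h1
g_h1 = g_h2 @ W2
grad_W1 = g_h1.T @ x

print("spread of middle-layer gradient:", grad_W2.max() - grad_W2.min())
print("first-layer rows identical:", np.allclose(grad_W1, grad_W1[0]))
print("last-layer columns identical:", np.allclose(grad_W3, grad_W3[:, :1]))
```

Identical rows in the first layer’s gradient mean every hidden neuron receives the same update, so the neurons can never become different from each other.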

Constant variance

From the experiment above, we have seen that a constant value does not work. So instead, how about we initialize the parameters by randomly sampling from a distribution like a Gaussian? The most intuitive way would be to choose one variance that is used for all layers in the network. Let’s implement it below, and visualize the activation distribution across layers.

[14]:
def var_init(model, std=0.01):
    for name, param in model.named_parameters():
        param.data.normal_(mean=0.0, std=std)


var_init(model, std=0.01)
visualize_activations(model, print_variance=True)
[Figure: Activation distribution]
Layer 0 - Variance: 0.0768686905503273
Layer 2 - Variance: 0.00374085595831275
Layer 4 - Variance: 0.00021300435764715075
Layer 6 - Variance: 0.000116668117698282
Layer 8 - Variance: 8.082647400442511e-05

The variance of the activation becomes smaller and smaller across layers, and almost vanishes in the last layer. Alternatively, we could use a higher standard deviation:

[15]:
var_init(model, std=0.1)
visualize_activations(model, print_variance=True)
[Figure: Activation distribution]
Layer 0 - Variance: 8.08608341217041
Layer 2 - Variance: 41.400367736816406
Layer 4 - Variance: 104.29255676269531
Layer 6 - Variance: 270.63995361328125
Layer 8 - Variance: 288.26495361328125

With a higher standard deviation, the activations are likely to explode. You can play around with the specific standard deviation values, but it will be hard to find one that gives us a good activation distribution across layers, and any value we find would be very specific to our model. If we changed the hidden sizes or the number of layers, we would have to search all over again, which is neither efficient nor recommended.
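A back-of-the-envelope calculation explains both runs. For an identity network with all weights and biases drawn from N(0, σ²), as var_init does, each output sums d_x independent products w·x plus one bias term, so its variance is roughly d_x σ² Var(x) + σ² (a bias-free version of this rule is derived below). The sketch assumes an input variance of 1 (from our normalization) and the BaseNetwork layer sizes. It computes an expectation over random weights, so it only roughly tracks the variances printed for a single sampled network.

```python
def predicted_variances(std, layer_sizes=(784, 512, 256, 256, 128, 10), input_var=1.0):
    """Expected activation variance after each linear layer of an identity network.

    Assumes weights and biases are i.i.d. N(0, std^2) (as in var_init) and
    zero-mean inputs with variance `input_var`.
    """
    variances = []
    var = input_var
    for d_in in layer_sizes[:-1]:
        # d_in independent products w * x, each with variance std^2 * var, plus the bias
        var = d_in * std**2 * var + std**2
        variances.append(var)
    return variances


for std in (0.01, 0.1):
    preds = predicted_variances(std)
    print(f"std={std}: " + ", ".join(f"{v:.3g}" for v in preds))
```

The per-layer multiplier d_x σ² is below 1 for σ = 0.01 (e.g. 784 · 0.0001 ≈ 0.08) and above 1 for σ = 0.1 (784 · 0.01 ≈ 7.8), which is why the variance shrinks in one run and grows in the other.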

How to find appropriate initialization values

From our experiments above, we have seen that we need to sample the weights from a distribution, but are not sure which one exactly. As a next step, we will try to find the optimal initialization from the perspective of the activation distribution. For this, we state two requirements:

  1. The mean of the activations should be zero

  2. The variance of the activations should stay the same across every layer

Suppose we want to design an initialization for the following layer: y=Wx+b with y\in\mathbb{R}^{d_y}, x\in\mathbb{R}^{d_x}. Our goal is that the variance of each element of y is the same as the input, i.e. \text{Var}(y_i)=\text{Var}(x_i)=\sigma_x^{2}, and that the mean is zero. We assume x to also have a mean of zero, because, in deep neural networks, y would be the input of another layer. This requires the bias and weight to have an expectation of 0. In fact, since b is a single element per output neuron and is constant across different inputs, we simply set it to 0.

Next, we need to calculate the variance with which we need to initialize the weight parameters. Along the way, we will need the following variance rule: given two independent variables, the variance of their product is \text{Var}(X\cdot Y) = \mathbb{E}(Y)^2\text{Var}(X) + \mathbb{E}(X)^2\text{Var}(Y) + \text{Var}(X)\text{Var}(Y) = \mathbb{E}(Y^2)\mathbb{E}(X^2)-\mathbb{E}(Y)^2\mathbb{E}(X)^2 (X and Y are not referring to x and y, but to any random variables).
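The product rule can be checked by simulation. The sketch below draws two independent Gaussian variables with arbitrary, non-zero example means (an illustrative choice) and compares the empirical variance of their product with both forms of the rule, then confirms that for zero means it reduces to Var(X)·Var(Y), the case used in the derivation.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1_000_000

# Two independent random variables with non-zero means (arbitrary example values)
X = rng.normal(loc=1.0, scale=2.0, size=n)   # E[X] = 1,    Var(X) = 4
Y = rng.normal(loc=-0.5, scale=0.5, size=n)  # E[Y] = -0.5, Var(Y) = 0.25

empirical = np.var(X * Y)
# First form of the rule: E[Y]^2 Var(X) + E[X]^2 Var(Y) + Var(X) Var(Y)
exact = (-0.5) ** 2 * 4.0 + 1.0**2 * 0.25 + 4.0 * 0.25
# Second form of the rule: E[Y^2] E[X^2] - E[Y]^2 E[X]^2
second_form = (0.25 + 0.25) * (4.0 + 1.0) - 0.25 * 1.0
print(f"empirical={empirical:.4f}, rule={exact:.4f}, second form={second_form:.4f}")

# With zero means, the rule reduces to Var(X) * Var(Y)
X0, Y0 = X - 1.0, Y + 0.5
print(f"zero-mean case: {np.var(X0 * Y0):.4f} vs {4.0 * 0.25:.4f}")
```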

The needed variance of the weights, \text{Var}(w_{ij}), is calculated as follows:

\begin{split}
    y_i & = \sum_{j} w_{ij}x_{j}\hspace{10mm}\text{Calculation of a single output neuron without bias}\\
    \text{Var}(y_i) = \sigma_x^{2} & = \text{Var}\left(\sum_{j} w_{ij}x_{j}\right)\\
    & = \sum_{j} \text{Var}(w_{ij}x_{j}) \hspace{10mm}\text{Inputs and weights are independent of each other}\\
    & = \sum_{j} \text{Var}(w_{ij})\cdot\text{Var}(x_{j}) \hspace{10mm}\text{Variance rule (see above) with expectations being zero}\\
    & = d_x \cdot \text{Var}(w_{ij})\cdot\text{Var}(x_{j}) \hspace{10mm}\text{Variance equal for all $d_x$ elements}\\
    & = \sigma_x^{2} \cdot d_x \cdot \text{Var}(w_{ij})\\
    \Rightarrow \text{Var}(w_{ij}) = \sigma_{W}^2 & = \frac{1}{d_x}\\
\end{split}

Thus, we should initialize the weight distribution with a variance of the inverse of the input dimension d_x. Let’s implement it below and check whether this holds:

[16]:
def equal_var_init(model):
    for name, param in model.named_parameters():
        if name.endswith(".bias"):
            param.data.fill_(0)
        else:
            param.data.normal_(std=1.0 / math.sqrt(param.shape[1]))


equal_var_init(model)
visualize_weight_distribution(model)
visualize_activations(model, print_variance=True)
[Figure: Weight distribution]
[Figure: Activation distribution]
Layer 0 - Variance: 1.0374585390090942
Layer 2 - Variance: 1.0698715448379517
Layer 4 - Variance: 1.1412183046340942
Layer 6 - Variance: 1.0962424278259277
Layer 8 - Variance: 1.0699418783187866

As expected, the variance indeed stays constant across layers. Note that our initialization does not restrict us to a normal distribution, but allows any other distribution with a mean of 0 and variance of 1/d_x. In practice, a uniform distribution is often used for initialization. A small benefit of using a uniform instead of a normal distribution is that we can exclude the chance of initializing very large positive or negative weights.
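For the uniform case, the bound follows from the fact that U[-a, a] has variance a²/3, so a = √(3/d_x) gives variance 1/d_x. The sketch below checks this with NumPy for the first layer of BaseNetwork (784 inputs, 512 outputs); the helper name uniform_bound is hypothetical.

```python
import math

import numpy as np


def uniform_bound(d_x: int) -> float:
    # U[-a, a] has variance a^2 / 3, so a = sqrt(3 / d_x) gives variance 1 / d_x
    return math.sqrt(3.0 / d_x)


rng = np.random.default_rng(0)
d_x, d_y = 784, 512
a = uniform_bound(d_x)
W = rng.uniform(-a, a, size=(d_y, d_x))
print(f"bound={a:.4f}, empirical var={W.var():.6f}, target={1 / d_x:.6f}")
print("largest |w|:", np.abs(W).max())  # never exceeds the bound
```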

Besides the variance of the activations, another variance we would like to stabilize is that of the gradients. This ensures a stable optimization for deep networks. It turns out that we can do the same calculation as above starting from \Delta x=W^\top\Delta y, and come to the conclusion that we should initialize our layers with a variance of 1/d_y, where d_y is the number of output neurons. You can do the calculation as an exercise, or check a thorough explanation in this blog post. As a compromise between both constraints, Glorot and Bengio (2010) proposed to use the harmonic mean of both values. This leads us to the well-known Xavier initialization:

W\sim \mathcal{N}\left(0,\frac{2}{d_x+d_y}\right)

If we use a uniform distribution, we would initialize the weights with:

W\sim U\left[-\frac{\sqrt{6}}{\sqrt{d_x+d_y}}, \frac{\sqrt{6}}{\sqrt{d_x+d_y}}\right]
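A quick numeric check of both formulas, using the first layer of BaseNetwork (d_x = 784, d_y = 512) as an example: the harmonic mean of 1/d_x and 1/d_y equals 2/(d_x+d_y), and the uniform bound above yields exactly that variance. The helper harmonic_mean is written here for illustration.

```python
import math


def harmonic_mean(a: float, b: float) -> float:
    return 2.0 / (1.0 / a + 1.0 / b)


d_x, d_y = 784, 512  # example: first layer of BaseNetwork

var_forward = 1.0 / d_x   # keeps the activation variance constant
var_backward = 1.0 / d_y  # keeps the gradient variance constant
var_xavier = harmonic_mean(var_forward, var_backward)
print(f"harmonic mean: {var_xavier:.6e}, 2/(d_x+d_y): {2.0 / (d_x + d_y):.6e}")

# Uniform bound from the formula above; U[-a, a] has variance a^2 / 3
bound = math.sqrt(6) / math.sqrt(d_x + d_y)
print(f"bound: {bound:.4f}, variance of U[-bound, bound]: {bound**2 / 3:.6e}")
```

For comparison, PyTorch’s built-in torch.nn.init.xavier_uniform_ samples from the same bound, multiplied by an optional gain argument (1 by default).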

Let’s briefly implement it and validate its effectiveness:

[17]:
def xavier_init(model):
    for name, param in model.named_parameters():
        if name.endswith(".bias"):
            param.data.fill_(0)
        else:
            bound = math.sqrt(6) / math.sqrt(param.shape[0] + param.shape[1])
            param.data.uniform_(-bound, bound)


xavier_init(model)
visualize_gradients(model, print_variance=True)
visualize_activations(model, print_variance=True)
[Figure: Gradient distribution]
layers.0.weight - Variance: 0.000457785528851673
layers.2.weight - Variance: 0.0006751694600097835
layers.4.weight - Variance: 0.0008508111932314932
layers.6.weight - Variance: 0.001484374050050974
layers.8.weight - Variance: 0.011529149487614632
[Figure: Activation distribution]
Layer 0 - Variance: 1.1692266464233398
Layer 2 - Variance: 1.520001769065857
Layer 4 - Variance: 1.585775375366211
Layer 6 - Variance: 1.9146416187286377
Layer 8 - Variance: 3.2868599891662598

We see that the Xavier initialization balances the variance of gradients and activations. Note that the significantly higher variance for the output layer is due to the large difference between its input and output dimensions (128 vs 10). However, so far we have assumed the activation function to be linear. So what happens if we add a non-linearity? In a tanh-based network, a common assumption is that for small values during the initial steps in training, \tanh behaves like a linear function, so that we don’t have to adjust our calculation. We can check if that is the case for us as well:

[18]:
model = BaseNetwork(act_fn=nn.Tanh()).to(device)
xavier_init(model)
visualize_gradients(model, print_variance=True)
visualize_activations(model, print_variance=True)
[Figure: Gradient distribution]
layers.0.weight - Variance: 2.4351327738258988e-05
layers.2.weight - Variance: 3.7693978811148554e-05
layers.4.weight - Variance: 5.152593075763434e-05
layers.6.weight - Variance: 6.856555410195142e-05
layers.8.weight - Variance: 0.0004877124447375536
[Figure: Activation distribution]
Layer 0 - Variance: 1.2570290565490723
Layer 2 - Variance: 0.5786585807800293
Layer 4 - Variance: 0.2740468978881836
Layer 6 - Variance: 0.2201044261455536
Layer 8 - Variance: 0.3423171937465668

Although the variance decreases over depth, it is apparent that the activation distribution becomes more concentrated on low values. Therefore, our variance would stabilize around 0.25 if we went even deeper. Hence, we can conclude that the Xavier initialization works well for Tanh networks. But what about ReLU networks? Here, we cannot use the previous assumption of the non-linearity becoming linear for small values. The ReLU activation function sets (in expectation) half of the inputs to 0, so that the expectation of its output, which is the input to the next layer, is no longer zero. However, as long as the expectation of W is zero and b=0, the expectation of the linear layer’s output is still zero. The part where the calculation of the ReLU initialization differs from the identity is when determining \text{Var}(w_{ij}x_{j}):

\text{Var}(w_{ij}x_{j})=\underbrace{\mathbb{E}[w_{ij}^2]}_{=\text{Var}(w_{ij})}\mathbb{E}[x_{j}^2]-\underbrace{\mathbb{E}[w_{ij}]^2}_{=0}\mathbb{E}[x_{j}]^2=\text{Var}(w_{ij})\mathbb{E}[x_{j}^2]

If we assume now that x is the output of a ReLU activation (from a previous layer, x=max(0,\tilde{y})), we can calculate the expectation as follows:

\begin{split}
    \mathbb{E}[x^2] & =\mathbb{E}[\max(0,\tilde{y})^2]\\
                    & =\frac{1}{2}\mathbb{E}[{\tilde{y}}^2]\hspace{2cm}\tilde{y}\text{ is zero-centered and symmetric}\\
                    & =\frac{1}{2}\text{Var}(\tilde{y})
\end{split}
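The identity \mathbb{E}[\max(0,\tilde{y})^2]=\frac{1}{2}\text{Var}(\tilde{y}) can be checked by sampling. The sketch below assumes \tilde{y} is Gaussian with an arbitrary standard deviation of 1.5 (any zero-centered symmetric distribution would do) and also shows that the ReLU output itself has a clearly non-zero mean.

```python
import numpy as np

rng = np.random.default_rng(0)
y_tilde = rng.normal(loc=0.0, scale=1.5, size=1_000_000)  # zero-centered, symmetric
x = np.maximum(0.0, y_tilde)  # ReLU

second_moment = np.mean(x**2)
half_var = 0.5 * np.var(y_tilde)
print(f"E[x^2] = {second_moment:.4f}, Var(y)/2 = {half_var:.4f}, exact = {0.5 * 1.5**2:.4f}")
print(f"mean of ReLU output: {x.mean():.4f}")  # non-zero, unlike y_tilde
```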

Thus, we see that we have an additional factor of 1/2 in the equation, so that our desired weight variance becomes 2/d_x. This gives us the Kaiming initialization (see He, K. et al. (2015)). Note that the Kaiming initialization does not use the harmonic mean between input and output size. In their paper (Section 2.2, Backward Propagation, last paragraph), the authors argue that using either d_x or d_y leads to stable gradients throughout the network, and that the difference between the two choices only depends on the overall input and output size of the network. Hence, we can use only the input dimension d_x here:

[19]:
def kaiming_init(model):
    for name, param in model.named_parameters():
        if name.endswith(".bias"):
            param.data.fill_(0)
        elif name.startswith("layers.0"):  # The first layer does not have ReLU applied on its input
            param.data.normal_(0, 1 / math.sqrt(param.shape[1]))
        else:
            param.data.normal_(0, math.sqrt(2) / math.sqrt(param.shape[1]))


model = BaseNetwork(act_fn=nn.ReLU()).to(device)
kaiming_init(model)
visualize_gradients(model, print_variance=True)
visualize_activations(model, print_variance=True)
[Figure: Gradient distribution]
layers.0.weight - Variance: 4.737732888315804e-05
layers.2.weight - Variance: 5.9308793424861506e-05
layers.4.weight - Variance: 7.343693141592667e-05
layers.6.weight - Variance: 0.00016474377480335534
layers.8.weight - Variance: 0.0029673215467482805
[Figure: Activation distribution]
Layer 0 - Variance: 1.037200689315796
Layer 2 - Variance: 1.0582876205444336
Layer 4 - Variance: 1.0638010501861572
Layer 6 - Variance: 1.3167966604232788
Layer 8 - Variance: 0.6909096837043762

The variance stays stable across layers. We can conclude that the Kaiming initialization indeed works well for ReLU-based networks. For LeakyReLU and similar activations, the factor of 2 in the variance has to be adjusted slightly, since the negative half of the inputs is no longer set to zero. PyTorch provides the function torch.nn.init.calculate_gain (link) to compute the corresponding factor for many activation functions; note that it returns a gain on the standard deviation (e.g. √2 for ReLU) rather than on the variance.
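A short sketch of the built-in utilities, assuming a PyTorch version that provides torch.nn.init.calculate_gain and torch.nn.init.kaiming_normal_ (both long-standing parts of torch.nn.init). The negative slope of 0.1 and the 512-to-256 layer are illustrative choices. With mode="fan_in" and nonlinearity="relu", kaiming_normal_ uses the standard deviation √2/√d_x, the same value as our kaiming_init for hidden layers.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# Gains returned by PyTorch are factors on the standard deviation
relu_gain = nn.init.calculate_gain("relu")              # sqrt(2)
leaky_gain = nn.init.calculate_gain("leaky_relu", 0.1)  # sqrt(2 / (1 + 0.1**2))
tanh_gain = nn.init.calculate_gain("tanh")              # 5 / 3
print(relu_gain, leaky_gain, tanh_gain)

# Built-in Kaiming initialization for a 512 -> 256 ReLU layer
layer = nn.Linear(512, 256)
nn.init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="relu")
nn.init.zeros_(layer.bias)

empirical_std = layer.weight.std().item()
expected_std = math.sqrt(2) / math.sqrt(512)  # same formula as kaiming_init above
print(f"empirical std={empirical_std:.4f}, expected={expected_std:.4f}")
```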

Optimization

Besides initialization, selecting a suitable optimization algorithm can be an important choice for deep neural networks. Before taking a closer look at them, we should define code for training the models. Most of the following code is copied from the previous tutorial, and only slightly altered to fit our needs.

[20]:
def _get_config_file(model_path, model_name):
    return os.path.join(model_path, model_name + ".config")


def _get_model_file(model_path, model_name):
    return os.path.join(model_path, model_name + ".tar")


def _get_result_file(model_path, model_name):
    return os.path.join(model_path, model_name + "_results.json")


def load_model(model_path, model_name, net=None):
    config_file = _get_config_file(model_path, model_name)
    model_file = _get_model_file(model_path, model_name)
    assert os.path.isfile(
        config_file
    ), f'Could not find the config file "{config_file}". Are you sure this is the correct path and you have your model config stored here?'
    assert os.path.isfile(
        model_file
    ), f'Could not find the model file "{model_file}". Are you sure this is the correct path and you have your model stored here?'
    with open(config_file) as f:
        config_dict = json.load(f)
    if net is None:
        act_fn_name = config_dict["act_fn"].pop("name").lower()
        assert (
            act_fn_name in act_fn_by_name
        ), f'Unknown activation function "{act_fn_name}". Please add it to the "act_fn_by_name" dict.'
        act_fn = act_fn_by_name[act_fn_name]()
        net = BaseNetwork(act_fn=act_fn, **config_dict)
    net.load_state_dict(torch.load(model_file, map_location=device))  # map_location allows loading GPU checkpoints on CPU
    return net


def save_model(model, model_path, model_name):
    config_dict = model.config
    os.makedirs(model_path, exist_ok=True)
    config_file = _get_config_file(model_path, model_name)
    model_file = _get_model_file(model_path, model_name)
    with open(config_file, "w") as f:
        json.dump(config_dict, f)
    torch.save(model.state_dict(), model_file)


def train_model(net, model_name, optim_func, max_epochs=50, batch_size=256, overwrite=False):
    """Train a model on the training set of FashionMNIST.

    Args:
        net: Object of BaseNetwork
        model_name: (str) Name of the model, used for creating the checkpoint names
        optim_func: Function that takes the model parameters and returns an optimizer object
        max_epochs: Number of epochs we want to (maximally) train for
        batch_size: Size of batches used in training
        overwrite: Determines how to handle the case when there already exists a checkpoint. If True, it will be overwritten. Otherwise, we skip training.
    """
    file_exists = os.path.isfile(_get_model_file(CHECKPOINT_PATH, model_name))
    if file_exists and not overwrite:
        print(f'Model file of "{model_name}" already exists. Skipping training...')
        with open(_get_result_file(CHECKPOINT_PATH, model_name)) as f:
            results = json.load(f)
    else:
        if file_exists:
            print("Model file exists, but will be overwritten...")

        # Defining optimizer, loss and data loader
        optimizer = optim_func(net.parameters())
        loss_module = nn.CrossEntropyLoss()
        train_loader_local = data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True
        )

        results = None
        val_scores = []
        train_losses, train_scores = [], []
        best_val_epoch = -1
        for epoch in range(max_epochs):
            train_acc, val_acc, epoch_losses = epoch_iteration(
                net, loss_module, optimizer, train_loader_local, val_loader, epoch
            )
            train_scores.append(train_acc)
            val_scores.append(val_acc)
            train_losses += epoch_losses

            if len(val_scores) == 1 or val_acc > val_scores[best_val_epoch]:
                print("\t   (New best performance, saving model...)")
                save_model(net, CHECKPOINT_PATH, model_name)
                best_val_epoch = epoch

    if results is None:
        load_model(CHECKPOINT_PATH, model_name, net=net)
        test_acc = test_model(net, test_loader)
        results = {
            "test_acc": test_acc,
            "val_scores": val_scores,
            "train_losses": train_losses,
            "train_scores": train_scores,
        }
        with open(_get_result_file(CHECKPOINT_PATH, model_name), "w") as f:
            json.dump(results, f)

    # Plot the training and validation accuracy curves
    sns.set()
    plt.plot(range(1, len(results["train_scores"]) + 1), results["train_scores"], label="Train")
    plt.plot(range(1, len(results["val_scores"]) + 1), results["val_scores"], label="Val")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.ylim(min(results["val_scores"]), max(results["train_scores"]) * 1.01)
    plt.title(f"Validation performance of {model_name}")
    plt.legend()
    plt.show()
    plt.close()

    print((f" Test accuracy: {results['test_acc']*100.0:4.2f}% ").center(50, "=") + "\n")
    return results


def epoch_iteration(net, loss_module, optimizer, train_loader_local, val_loader, epoch):
    ############
    # Training #
    ############
    net.train()
    true_preds, count = 0.0, 0
    epoch_losses = []
    t = tqdm(train_loader_local, leave=False)
    for imgs, labels in t:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        preds = net(imgs)
        loss = loss_module(preds, labels)
        loss.backward()
        optimizer.step()
        # Record statistics during training
        true_preds += (preds.argmax(dim=-1) == labels).sum().item()
        count += labels.shape[0]
        t.set_description(f"Epoch {epoch+1}: loss={loss.item():4.2f}")
        epoch_losses.append(loss.item())
    train_acc = true_preds / count

    ##############
    # Validation #
    ##############
    val_acc = test_model(net, val_loader)
    print(
        f"[Epoch {epoch+1:2d}] Training accuracy: {train_acc*100.0:05.2f}%, Validation accuracy: {val_acc*100.0:05.2f}%"
    )
    return train_acc, val_acc, epoch_losses


def test_model(net, data_loader):
    """Test a model on a specified dataset.

    Args:
        net: Trained model of type BaseNetwork
        data_loader: DataLoader object of the dataset to test on (validation or test)
    """
    net.eval()
    true_preds, count = 0.0, 0
    for imgs, labels in data_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        with torch.no_grad():
            preds = net(imgs).argmax(dim=-1)
            true_preds += (preds == labels).sum().item()
            count += labels.shape[0]
    test_acc = true_preds / count
    return test_acc

First, we need to understand what an optimizer actually does. The optimizer is responsible for updating the network’s parameters given the gradients. Hence, we effectively implement a function w^{(t)} = f(w^{(t-1)}, g^{(t)}, ...) with w being the parameters, and g^{(t)} = \nabla_{w^{(t-1)}} \mathcal{L}^{(t)} the gradients at time step t. A common additional parameter to this function is the learning rate, here denoted by \eta. Usually, the learning rate can be seen as the “step size” of the update. A higher learning rate means that we change the weights more along the negative gradient direction, while a smaller one means we take shorter steps.
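As a toy illustration of the learning rate as step size (the loss L(w) = w^2 is an assumption chosen for simplicity, not part of the tutorial), each update w^{(t)} = w^{(t-1)} - \eta \cdot 2w^{(t-1)} multiplies w by (1 - 2\eta). A small \eta converges slowly, a moderate one converges quickly, and \eta > 1 overshoots the minimum and diverges:

```python
def gradient_descent(lr, w0=1.0, steps=20):
    """Plain gradient descent on the toy loss L(w) = w**2, whose gradient is 2 * w."""
    w = w0
    for _ in range(steps):
        grad = 2 * w       # g^(t)
        w = w - lr * grad  # w^(t) = w^(t-1) - lr * g^(t)
    return w


for lr in [0.01, 0.1, 0.4, 1.1]:
    print(f"lr={lr}: w after 20 steps = {gradient_descent(lr):.3e}")
```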

As most optimizers only differ in the implementation of f, we can define a template for an optimizer in PyTorch below. It takes as input the parameters of a model and a learning rate. The function zero_grad sets the gradients of all parameters to zero, which we have to do before calling loss.backward(). Finally, the step() function tells the optimizer to update all weights based on their gradients. The template is set up below:

[21]:
class OptimizerTemplate:
    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr

    def zero_grad(self):
        # Set gradients of all parameters to zero
        for p in self.params:
            if p.grad is not None:
                p.grad.detach_()  # Important for second-order optimizers
                p.grad.zero_()

    @torch.no_grad()
    def step(self):
        # Apply update step to all parameters
        for p in self.params:
            if p.grad is None:  # We skip parameters without any gradients
                continue
            self.update_param(p)

    def update_param(self, p):
        # To be implemented in optimizer-specific classes
        raise NotImplementedError
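The reason zero_grad must be called is that PyTorch adds new gradients to whatever is already stored in p.grad instead of overwriting it. The minimal sketch below, using a single scalar parameter and the illustrative loss 3w, shows the gradient doubling when backward() is called twice without a reset.

```python
import torch

w = torch.nn.Parameter(torch.tensor(1.0))

loss = 3 * w
loss.backward()
first = w.grad.item()        # dL/dw = 3

loss = 3 * w
loss.backward()              # no reset in between: gradients are summed
accumulated = w.grad.item()  # 3 + 3 = 6

w.grad.zero_()               # what OptimizerTemplate.zero_grad does per parameter
loss = 3 * w
loss.backward()
after_reset = w.grad.item()  # back to 3

print(first, accumulated, after_reset)
```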

The first optimizer we are going to implement is the standard Stochastic Gradient Descent (SGD). SGD updates the parameters using the following equation:

\begin{split}
    w^{(t)} & = w^{(t-1)} - \eta \cdot g^{(t)}
\end{split}

Our implementation of SGD is just as simple as the equation:

[22]:
class SGD(OptimizerTemplate):
    def __init__(self, params, lr):
        super().__init__(params, lr)

    def update_param(self, p):
        p_update = -self.lr * p.grad
        p.add_(p_update)  # In-place update => saves memory and does not create computation graph
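As a sanity check, this update should match PyTorch’s built-in torch.optim.SGD when momentum and weight decay are disabled (their defaults). The sketch below applies one step of each to the same random parameters and gradient; the seed and tensor sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
w_init = torch.randn(4)
grad = torch.randn(4)
lr = 0.1

# Manual update as in SGD.update_param: w <- w - lr * g
w_manual = w_init.clone()
w_manual.add_(-lr * grad)

# Same step with PyTorch's optimizer (momentum=0 and weight_decay=0 by default)
p = torch.nn.Parameter(w_init.clone())
opt = torch.optim.SGD([p], lr=lr)
p.grad = grad.clone()  # normally filled in by loss.backward()
opt.step()

print(torch.allclose(w_manual, p.detach()))  # True
```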

In the lecture, we also discussed the concept of momentum, which replaces the gradient in the update by an exponential average of all past gradients, including the current one:

\begin{split}
    m^{(t)} & = \beta_1 m^{(t-1)} + (1 - \beta_1)\cdot g^{(t)}\\
    w^{(t)} & = w^{(t-1)} - \eta \cdot m^{(t)}\\
\end{split}

Let’s also implement it below:

[23]:
class SGDMomentum(OptimizerTemplate):
    def __init__(self, params, lr, momentum=0.0):
        super().__init__(params, lr)
        self.momentum = momentum  # Corresponds to beta_1 in the equation above
        self.param_momentum = {p: torch.zeros_like(p.data) for p in self.params}  # Dict to store m_t

    def update_param(self, p):
        self.param_momentum[p] = (1 - self.momentum) * p.grad + self.momentum * self.param_momentum[p]
        p_update = -self.lr * self.param_momentum[p]
        p.add_(p_update)
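To see why an exponential average helps, the sketch below (plain Python, with \beta_1 = 0.9 as an illustrative value) feeds the momentum update two gradient sequences: one that always points in the same direction and one that flips sign at every step. The consistent gradient accumulates, while the oscillating one largely cancels out. Also note that with m^{(0)} = 0, a constant gradient of 1 gives m^{(t)} = 1 - \beta_1^t after t steps, so the average is biased towards zero early on.

```python
def ema(grads, beta=0.9):
    """Exponential moving average m_t = beta * m_{t-1} + (1 - beta) * g_t, starting from m_0 = 0."""
    m = 0.0
    for g in grads:
        m = beta * m + (1 - beta) * g
    return m


constant = [1.0] * 100                           # gradient always points the same way
alternating = [(-1.0) ** t for t in range(100)]  # gradient flips sign every step

print(f"constant:    m = {ema(constant):.4f}")
print(f"alternating: m = {ema(alternating):.4f}")
print(f"constant, first 10 steps: m = {ema([1.0] * 10):.4f}, 1 - 0.9**10 = {1 - 0.9**10:.4f}")
```

Be aware that torch.optim.SGD with momentum=0.9 uses a different convention by default (dampening=0), namely buf = 0.9 · buf + g without the (1 - \beta_1) factor, so at the same learning rate its steady-state step is roughly 1/(1 - \beta_1) = 10 times larger than in this formulation.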

Finally, we arrive at Adam. Adam combines the idea of momentum with an adaptive learning rate, which is based on an exponential average of the squared gradients (i.e., a per-parameter estimate of the squared gradient magnitude). Furthermore, we add a bias correction for the momentum and the adaptive learning rate, which matters mostly in the first iterations:

\begin{split}
    m^{(t)} & = \beta_1 m^{(t-1)} + (1 - \beta_1)\cdot g^{(t)}\\
    v^{(t)} & = \beta_2 v^{(t-1)} + (1 - \beta_2)\cdot \left(g^{(t)}\right)^2\\
    \hat{m}^{(t)} & = \frac{m^{(t)}}{1-\beta^{t}_1}, \hat{v}^{(t)} = \frac{v^{(t)}}{1-\beta^{t}_2}\\
    w^{(t)} & = w^{(t-1)} - \frac{\eta}{\sqrt{\hat{v}^{(t)}} + \epsilon}\circ \hat{m}^{(t)}\\
\end{split}

Epsilon is a small constant that improves numerical stability when the gradient norms are very small. Remember that the adaptive learning rate does not replace the learning rate hyperparameter \eta, but rather acts as an extra per-parameter factor that ensures the updates of different parameters have a similar magnitude.

[24]:
class Adam(OptimizerTemplate):
    def __init__(self, params, lr, beta1=0.9, beta2=0.999, eps=1e-8):
        super().__init__(params, lr)
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.param_step = {p: 0 for p in self.params}  # Remembers "t" for each parameter for bias correction
        self.param_momentum = {p: torch.zeros_like(p.data) for p in self.params}
        self.param_2nd_momentum = {p: torch.zeros_like(p.data) for p in self.params}

    def update_param(self, p):
        self.param_step[p] += 1

        self.param_momentum[p] = (1 - self.beta1) * p.grad + self.beta1 * self.param_momentum[p]
        self.param_2nd_momentum[p] = (1 - self.beta2) * (p.grad) ** 2 + self.beta2 * self.param_2nd_momentum[p]

        bias_correction_1 = 1 - self.beta1 ** self.param_step[p]
        bias_correction_2 = 1 - self.beta2 ** self.param_step[p]

        p_2nd_mom = self.param_2nd_momentum[p] / bias_correction_2
        p_mom = self.param_momentum[p] / bias_correction_1
        p_lr = self.lr / (torch.sqrt(p_2nd_mom) + self.eps)
        p_update = -p_lr * p_mom

        p.add_(p_update)
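PyTorch’s torch.optim.Adam implements the same bias-corrected update, so, assuming default options (no weight decay, amsgrad=False), a functional re-implementation of the equations above should reproduce it up to floating-point error. The sketch below feeds both the same 20 random gradients; the seed, shapes, and step count are arbitrary, and manual_adam is a hypothetical helper, not part of the tutorial.

```python
import torch

torch.manual_seed(0)


def manual_adam(w0, grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Functional version of the Adam equations (bias-corrected m and v)."""
    w = w0.clone()
    m = torch.zeros_like(w)
    v = torch.zeros_like(w)
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        w = w - lr * m_hat / (torch.sqrt(v_hat) + eps)
    return w


w0 = torch.randn(5)
grads = [torch.randn(5) for _ in range(20)]

# Reference: torch.optim.Adam fed the same gradient sequence
p = torch.nn.Parameter(w0.clone())
opt = torch.optim.Adam([p], lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
for g in grads:
    opt.zero_grad()
    p.grad = g.clone()
    opt.step()

w_manual = manual_adam(w0, grads)
print(torch.allclose(w_manual, p.detach(), atol=1e-6))
```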
Comparing optimizers on model training

Now that we have implemented three optimizers (SGD, SGD with momentum, and Adam), we can start to analyze and compare them. First, we test how well they can optimize a neural network on the FashionMNIST dataset. We again use our fully connected network, this time with a ReLU activation and the Kaiming initialization, which we found earlier to work well for ReLU-based networks. Note that the model is over-parameterized for this task, and we could achieve similar performance with a much smaller network (for example, with hidden sizes 100, 100, 100). However, our main interest is in how well the optimizer can train deep neural networks, hence the over-parameterization.
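To put “over-parameterized” into numbers, the sketch below counts the weights and biases of a fully connected network. It assumes, as for BaseNetwork in the previous tutorial, flattened 28x28 FashionMNIST inputs (784 features) and 10 output classes; the helper mlp_param_count is hypothetical and not part of the tutorial code.

```python
def mlp_param_count(hidden_sizes, input_size=784, num_classes=10):
    """Weights + biases of a fully connected network input -> hidden_sizes -> num_classes."""
    sizes = [input_size] + list(hidden_sizes) + [num_classes]
    return sum(d_in * d_out + d_out for d_in, d_out in zip(sizes[:-1], sizes[1:]))


print(mlp_param_count([512, 256, 256, 128]))  # -> 633226
print(mlp_param_count([100, 100, 100]))       # -> 99710
```

Under these assumptions, the network used for the comparison has roughly six times as many parameters as the smaller alternative.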

[25]:
base_model = BaseNetwork(act_fn=nn.ReLU(), hidden_sizes=[512, 256, 256, 128])
kaiming_init(base_model)

For a fair comparison, we train the exact same model, starting from the same initial weights and seed, with each of the three optimizers below. Feel free to change the hyperparameters if you want (however, you will then have to train your own model).
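The fairness of the comparison relies on copy.deepcopy: each optimizer gets its own copy of base_model with identical starting weights but independent parameter storage, so training one copy cannot affect the others. The sketch below demonstrates this on a small stand-in nn.Sequential model (the layer sizes are arbitrary).

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(42)
base = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))  # stand-in for base_model

model_a = copy.deepcopy(base)
model_b = copy.deepcopy(base)

# Identical starting weights...
same_init = all(torch.equal(pa, pb) for pa, pb in zip(model_a.parameters(), model_b.parameters()))

# ...but separate storage: changing one copy leaves the others untouched
with torch.no_grad():
    model_a[0].weight.add_(1.0)
b_unchanged = torch.equal(model_b[0].weight, base[0].weight)
a_changed = not torch.equal(model_a[0].weight, base[0].weight)
print(same_init, b_unchanged, a_changed)  # True True True
```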

[26]:
SGD_model = copy.deepcopy(base_model).to(device)
SGD_results = train_model(
    SGD_model, "FashionMNIST_SGD", lambda params: SGD(params, lr=1e-1), max_epochs=40, batch_size=256
)
Model file of "FashionMNIST_SGD" already exists. Skipping training...
[Figure: Validation performance of FashionMNIST_SGD (training and validation accuracy per epoch)]
============= Test accuracy: 89.09% ==============

[27]:
SGDMom_model = copy.deepcopy(base_model).to(device)
SGDMom_results = train_model(
    SGDMom_model,
    "FashionMNIST_SGDMom",
    lambda params: SGDMomentum(params, lr=1e-1, momentum=0.9),
    max_epochs=40,
    batch_size=256,
)
Model file of "FashionMNIST_SGDMom" already exists. Skipping training...
[Figure: Validation performance of FashionMNIST_SGDMom (training and validation accuracy per epoch)]
============= Test accuracy: 88.83% ==============

[28]:
Adam_model = copy.deepcopy(base_model).to(device)
Adam_results = train_model(
    Adam_model, "FashionMNIST_Adam", lambda params: Adam(params, lr=1e-3), max_epochs=40, batch_size=256
)
Model file of "FashionMNIST_Adam" already exists. Skipping training...
[Figure: Validation performance of FashionMNIST_Adam (training and validation accuracy per epoch)]
============= Test accuracy: 89.46% ==============

The result is that all optimizers perform similarly well with the given model. The differences are too small to draw any significant conclusion. However, keep in mind that this can also be attributed to the initialization we chose. With a worse initialization (e.g., constant initialization), Adam usually proves to be more robust because of its adaptive learning rate. To show the specific benefits of the optimizers, we continue by looking at some possible loss surfaces on which momentum and adaptive learning rates are crucial.
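One concrete reason for Adam’s robustness to badly scaled gradients: on the first step, \hat{m}^{(1)} = g and \hat{v}^{(1)} = g^2, so the update is roughly \eta \cdot g/|g|, independent of the gradient’s magnitude, whereas the SGD update scales linearly with g. The sketch below checks this with PyTorch’s built-in optimizers, using an arbitrary learning rate of 1e-3 and a hypothetical helper first_step_update.

```python
import torch


def first_step_update(opt_cls, grad_value, **opt_kwargs):
    """Magnitude of the first update for a scalar parameter whose gradient is grad_value."""
    p = torch.nn.Parameter(torch.zeros(1))
    opt = opt_cls([p], **opt_kwargs)
    p.grad = torch.full((1,), grad_value)
    opt.step()
    return abs(p.item())  # the parameter started at 0, so its value is the update


steps = {}
for g in [1e-3, 1.0, 1e3]:
    steps[g] = (
        first_step_update(torch.optim.SGD, g, lr=1e-3),
        first_step_update(torch.optim.Adam, g, lr=1e-3),
    )
    print(f"grad={g:g}: SGD step={steps[g][0]:.2e}, Adam step={steps[g][1]:.2e}")
```

SGD’s first step spans six orders of magnitude across these gradient scales, while Adam’s stays close to the learning rate. This only describes the first step; later steps depend on the gradient history.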

Pathological curvatures

A pathological curvature is a type of surface that resembles a ravine and is particularly tricky for plain SGD optimization. Put simply, pathological curvatures typically have a steep gradient in one direction with an optimum at the center, while in a second direction we have a slower gradient towards a (global) optimum. Let’s first create an example surface of this kind and visualize it:

[29]:
def pathological_curve_loss(w1, w2):
    # Example of a pathological curvature. There are many more possible, feel free to experiment here!
    x1_loss = torch.tanh(w1) ** 2 + 0.01 * torch.abs(w1)
    x2_loss = torch.sigmoid(w2)
    return x1_loss + x2_loss
[30]:
def plot_curve(
    curve_fn, x_range=(-5, 5), y_range=(-5, 5), plot_3d=False, cmap=cm.viridis, title="Pathological curvature"
):
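    fig = plt.figure()
    # fig.gca(projection=...) is deprecated since Matplotlib 3.4; add_subplot accepts the projection directly
    ax = fig.add_subplot(projection="3d") if plot_3d else fig.add_subplot()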

    x = torch.arange(x_range[0], x_range[1], (x_range[1] - x_range[0]) / 100.0)
    y = torch.arange(y_range[0], y_range[1], (y_range[1] - y_range[0]) / 100.0)
    x, y = torch.meshgrid([x, y])
    z = curve_fn(x, y)
    x, y, z = x.numpy(), y.numpy(), z.numpy()

    if plot_3d:
        ax.plot_surface(x, y, z, cmap=cmap, linewidth=1, color="#000", antialiased=False)
        ax.set_zlabel("loss")
    else:
        ax.imshow(z.T[::-1], cmap=cmap, extent=(x_range[0], x_range[1], y_range[0], y_range[1]))
    plt.title(title)
    ax.set_xlabel(r"$w_1$")
    ax.set_ylabel(r"$w_2$")
    plt.tight_layout()
    return ax


sns.reset_orig()
_ = plot_curve(pathological_curve_loss, plot_3d=True)
plt.show()
[Figure: Pathological curvature (3D surface plot of the loss over w_1 and w_2)]

In terms of optimization, you can imagine that w_1 and w_2 are weight parameters, and the curvature represents the loss surface over the space of w_1 and w_2. Note that in typical networks, we have many, many more parameters than two, and such curvatures can occur in multi-dimensional spaces as well.

Ideally, our optimization algorithm would find the center of the ravine and focus on optimizing the parameters in the direction of w_2. However, if we are at a point along the ridges, the gradient is much greater in w_1 than in w_2, and we might end up jumping from one side to the other. Because of the large gradients, we would have to reduce our learning rate, which slows down learning significantly.
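To quantify this, the sketch below uses autograd to evaluate the gradient of pathological_curve_loss (copied from above) at (w_1, w_2) = (1, 5), an arbitrarily chosen point on the side of the ravine. The gradient along w_1 is roughly two orders of magnitude larger than along w_2, so a learning rate small enough to keep w_1 from bouncing between the walls makes progress along w_2 very slow.

```python
import torch


def pathological_curve_loss(w1, w2):
    # Same surface as defined above
    x1_loss = torch.tanh(w1) ** 2 + 0.01 * torch.abs(w1)
    x2_loss = torch.sigmoid(w2)
    return x1_loss + x2_loss


w = torch.tensor([1.0, 5.0], requires_grad=True)  # a point on the ravine wall
pathological_curve_loss(w[0], w[1]).backward()
g1, g2 = w.grad.tolist()
print(f"dL/dw1 = {g1:.4f}, dL/dw2 = {g2:.4f}, ratio = {g1 / g2:.1f}")
```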

To test our algorithms, we can implement a simple function to train two parameters on such a surface:

[31]:
def train_curve(optimizer_func, curve_func=pathological_curve_loss, num_updates=100, init=[5, 5]):
    """
    Args:
        optimizer_func: Constructor of the optimizer to use. Should only take a parameter list
        curve_func: Loss function (e.g. pathological curvature)
        num_updates: Number of updates/steps to take when optimizing
        init: Initial values of parameters. Must be a list/tuple with two elements representing w_1 and w_2
    Returns:
        Numpy array of shape [num_updates, 3] with [t,:2] being the parameter values at step t, and [t,2] the loss at t.
    """
    weights = nn.Parameter(torch.FloatTensor(init), requires_grad=True)
    optim = optimizer_func([weights])

    list_points = []
    for _ in range(num_updates):
        loss = curve_func(weights[0], weights[1])
        list_points.append(torch.cat([weights.data.detach(), loss.unsqueeze(dim=0).detach()], dim=0))
        optim.zero_grad()
        loss.backward()
        optim.step()
    points = torch.stack(list_points, dim=0).numpy()
    return points

Next, let’s apply the different optimizers to our curvature. Note that we set a much higher learning rate for the optimization algorithms than you would in a standard neural network. This is because we only have 2 parameters instead of tens of thousands or even millions.

[32]:
SGD_points = train_curve(lambda params: SGD(params, lr=10))
SGDMom_points = train_curve(lambda params: SGDMomentum(params, lr=10, momentum=0.9))
Adam_points = train_curve(lambda params: Adam(params, lr=1))

To best understand how the different algorithms behave, we visualize the update steps as a line plot through the loss surface. We stick with a 2D representation for readability.

[33]:
all_points = np.concatenate([SGD_points, SGDMom_points, Adam_points], axis=0)
ax = plot_curve(
    pathological_curve_loss,
    x_range=(-np.absolute(all_points[:, 0]).max(), np.absolute(all_points[:, 0]).max()),
    y_range=(all_points[:, 1].min(), all_points[:, 1].max()),
    plot_3d=False,
)
ax.plot(SGD_points[:, 0], SGD_points[:, 1], color="red", marker="o", zorder=1, label="SGD")
ax.plot(SGDMom_points[:, 0], SGDMom_points[:, 1], color="blue", marker="o", zorder=2, label="SGDMom")
ax.plot(Adam_points[:, 0], Adam_points[:, 1], color="grey", marker="o", zorder=3, label="Adam")
plt.legend()
plt.show()
[Figure: Pathological curvature (2D), with the update trajectories of SGD, SGDMom, and Adam]

We can clearly see that SGD is not able to find the center of the ravine and has trouble converging due to the steep gradients in w_1. In contrast, Adam and SGD with momentum converge nicely, as the alternating gradient direction in w_1 cancels itself out. On such surfaces, it is crucial to use momentum.

Steep optima

A second type of challenging loss surface is one with steep optima: most of the surface has very small gradients, while the regions around the optima have very large gradients. For instance, take the following loss surface:

[34]:
def bivar_gaussian(w1, w2, x_mean=0.0, y_mean=0.0, x_sig=1.0, y_sig=1.0):
    norm = 1 / (2 * np.pi * x_sig * y_sig)
    x_exp = (-1 * (w1 - x_mean) ** 2) / (2 * x_sig ** 2)
    y_exp = (-1 * (w2 - y_mean) ** 2) / (2 * y_sig ** 2)
    return norm * torch.exp(x_exp + y_exp)


def comb_func(w1, w2):
    z = -bivar_gaussian(w1, w2, x_mean=1.0, y_mean=-0.5, x_sig=0.2, y_sig=0.2)
    z -= bivar_gaussian(w1, w2, x_mean=-1.0, y_mean=0.5, x_sig=0.2, y_sig=0.2)
    z -= bivar_gaussian(w1, w2, x_mean=-0.5, y_mean=-0.8, x_sig=0.2, y_sig=0.2)
    return z


_ = plot_curve(comb_func, x_range=(-2, 2), y_range=(-2, 2), plot_3d=True, title="Steep optima")
[Figure: Steep optima (3D surface plot of the loss over w_1 and w_2)]

Most of the loss surface has little to no gradient. However, close to the optima, we have very steep gradients. To reach a minimum when starting in a region with low gradients, we expect an adaptive learning rate to be crucial. To verify this hypothesis, we can run our three optimizers on the surface:

[35]:
SGD_points = train_curve(lambda params: SGD(params, lr=0.5), comb_func, init=[0, 0])
SGDMom_points = train_curve(lambda params: SGDMomentum(params, lr=1, momentum=0.9), comb_func, init=[0, 0])
Adam_points = train_curve(lambda params: Adam(params, lr=0.2), comb_func, init=[0, 0])

all_points = np.concatenate([SGD_points, SGDMom_points, Adam_points], axis=0)
ax = plot_curve(comb_func, x_range=(-2, 2), y_range=(-2, 2), plot_3d=False, title="Steep optima")
ax.plot(SGD_points[:, 0], SGD_points[:, 1], color="red", marker="o", zorder=3, label="SGD", alpha=0.7)
ax.plot(SGDMom_points[:, 0], SGDMom_points[:, 1], color="blue", marker="o", zorder=2, label="SGDMom", alpha=0.7)
ax.plot(Adam_points[:, 0], Adam_points[:, 1], color="grey", marker="o", zorder=1, label="Adam", alpha=0.7)
ax.set_xlim(-2, 2)
ax.set_ylim(-2, 2)
plt.legend()
plt.show()
[Figure: Steep optima (2D), with the update trajectories of SGD, SGDMom, and Adam]

SGD first takes very small steps until it touches the border of an optimum. After first reaching a point around (-0.75, -0.5), the gradient direction changes and pushes the parameters to (0.8, 0.5), from which SGD cannot recover (except with many, many steps). SGD with momentum has a similar problem, except that it keeps moving in the direction in which it touched the optimum: the gradients at that time step are so much larger than at any other point that they overpower the momentum m_t. Finally, Adam is able to converge to the optimum, showing the importance of adaptive learning rates.
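To quantify how different the gradients are on this surface, the sketch below (with comb_func copied from above, using math.pi in place of np.pi) compares the gradient norm at the starting point (0, 0) used for all three optimizers with the norm one standard deviation (0.2) away from the center of the optimum at (-0.5, -0.8). The choice of points is illustrative.

```python
import math

import torch


def bivar_gaussian(w1, w2, x_mean=0.0, y_mean=0.0, x_sig=1.0, y_sig=1.0):
    # Same as above
    norm = 1 / (2 * math.pi * x_sig * y_sig)
    x_exp = (-1 * (w1 - x_mean) ** 2) / (2 * x_sig**2)
    y_exp = (-1 * (w2 - y_mean) ** 2) / (2 * y_sig**2)
    return norm * torch.exp(x_exp + y_exp)


def comb_func(w1, w2):
    z = -bivar_gaussian(w1, w2, x_mean=1.0, y_mean=-0.5, x_sig=0.2, y_sig=0.2)
    z -= bivar_gaussian(w1, w2, x_mean=-1.0, y_mean=0.5, x_sig=0.2, y_sig=0.2)
    z -= bivar_gaussian(w1, w2, x_mean=-0.5, y_mean=-0.8, x_sig=0.2, y_sig=0.2)
    return z


def grad_norm(w1, w2):
    """Euclidean norm of the gradient of comb_func at (w1, w2)."""
    w = torch.tensor([w1, w2], requires_grad=True)
    comb_func(w[0], w[1]).backward()
    return w.grad.norm().item()


start = grad_norm(0.0, 0.0)        # initial point used for all three optimizers
rim = grad_norm(-0.5 + 0.2, -0.8)  # one standard deviation from an optimum's center
print(f"|grad| at (0, 0): {start:.2e}, near optimum: {rim:.2e}, ratio: {rim / start:.0f}")
```

The gradient at the start is several thousand times smaller than near the optimum, so a fixed learning rate that is large enough to leave (0, 0) quickly is far too large once the parameters reach the rim.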

What optimizer to take

After seeing the results on optimization, what is our conclusion? Should we always use Adam and never look at SGD anymore? The short answer: no. Many papers report that in certain situations, SGD (with momentum) generalizes better, whereas Adam often tends to overfit [5,6]. This is related to the idea of finding wider optima. For instance, see the illustration of different optima below (credit: Keskar et al., 2017 [4]):

[Figure: flat vs. sharp minima of the training loss (black) and test loss (dotted red); credit: Keskar et al., 2017]

The black line represents the training loss surface, while the dotted red line is the test loss. Finding sharp, narrow minima can help reach the minimal training loss. However, this does not mean that such minima also minimize the test loss, and flat minima in particular have been shown to generalize better. You can imagine that the test dataset has a slightly shifted loss surface because it contains different examples than the training set. A small shift can have a significant influence on sharp minima, while flat minima are generally more robust to this change.
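A toy version of this argument, assuming quadratic minima L(w) = \frac{a}{2}(w - c)^2 with curvature a: if the test loss surface is the training surface shifted by \delta, the test loss at the training optimum is \frac{a}{2}\delta^2, so it grows linearly with the sharpness a. The curvatures 100 and 1 and the shift 0.1 below are arbitrary illustrative values.

```python
def quadratic_loss(w, curvature, center=0.0):
    """Toy 1D loss with its minimum (value 0) at `center`; larger curvature = sharper minimum."""
    return 0.5 * curvature * (w - center) ** 2


shift = 0.1   # assumed offset between the training and test loss surfaces
w_star = 0.0  # minimum of the training loss for both toy surfaces
test_losses = {}
for name, curvature in [("sharp", 100.0), ("flat", 1.0)]:
    # Evaluate the shifted (test) surface at the training optimum
    test_losses[name] = quadratic_loss(w_star, curvature, center=shift)
    print(f"{name} minimum: test loss at the training optimum = {test_losses[name]:.4f}")
```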

In the next tutorial, we will see that some network types can still be better optimized with SGD and learning rate scheduling than Adam. Nevertheless, Adam is the most commonly used optimizer in Deep Learning as it usually performs better than other optimizers, especially for deep networks.
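For reference, here is a minimal sketch of SGD with momentum combined with a step-wise learning rate schedule via torch.optim.lr_scheduler.MultiStepLR. The placeholder model, milestones, and decay factor are arbitrary and not necessarily what the next tutorial uses.

```python
import torch

model = torch.nn.Linear(10, 2)  # placeholder model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# Multiply the learning rate by 0.1 after epochs 2 and 4 (arbitrary milestones)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2, 4], gamma=0.1)

lrs = []
for epoch in range(6):
    lrs.append(optimizer.param_groups[0]["lr"])
    # ... a full epoch of training (forward, loss.backward(), optimizer.step()) goes here ...
    optimizer.step()  # no gradients in this sketch, so the parameters are left unchanged
    scheduler.step()  # update the learning rate once per epoch
print(lrs)
```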

Conclusion

In this tutorial, we have looked at initialization and optimization techniques for neural networks. We have seen that a good initialization has to balance the preservation of the gradient variance as well as the activation variance. This can be achieved with the Xavier initialization for tanh-based networks, and the Kaiming initialization for ReLU-based networks. In optimization, concepts like momentum and adaptive learning rate can help with challenging loss surfaces but don’t guarantee an increase in performance for neural networks.

References

[1] Glorot, Xavier, and Yoshua Bengio. “Understanding the difficulty of training deep feedforward neural networks.” Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.

[2] He, Kaiming, et al. “Delving deep into rectifiers: Surpassing human-level performance on imagenet classification.” Proceedings of the IEEE international conference on computer vision. 2015.

[3] Kingma, Diederik P., and Jimmy Ba. “Adam: A Method for Stochastic Optimization.” Proceedings of the third international conference for learning representations (ICLR). 2015.

[4] Keskar, Nitish Shirish, et al. “On large-batch training for deep learning: Generalization gap and sharp minima.” Proceedings of the fifth international conference for learning representations (ICLR). 2017.

[5] Wilson, Ashia C., et al. “The Marginal Value of Adaptive Gradient Methods in Machine Learning.” Advances in neural information processing systems. 2017.

[6] Ruder, Sebastian. “An overview of gradient descent optimization algorithms.” arXiv preprint. 2017.

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time, you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Great thanks from the entire PyTorch Lightning team for your interest!


Tutorial 4: Inception, ResNet and DenseNet

  • Author: Phillip Lippe

  • License: CC BY-SA

  • Generated: 2022-05-12T13:44:17.079573

In this tutorial, we will implement and discuss variants of modern CNN architectures. Many different architectures have been proposed over the past few years. Some of the most impactful ones, which are still relevant today, are the following: the GoogleNet/Inception architecture (winner of ILSVRC 2014), ResNet (winner of ILSVRC 2015), and DenseNet (best paper award at CVPR 2017). All of them were state-of-the-art models when they were proposed, and their core ideas are the foundations for most current state-of-the-art architectures. Thus, it is important to understand these architectures in detail and learn how to implement them. This notebook is part of a lecture series on Deep Learning at the University of Amsterdam. The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.




Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "setuptools==59.5.0" "seaborn" "ipython[notebook]" "tabulate" "matplotlib" "torchvision" "torchmetrics>=0.7" "pytorch-lightning>=1.4" "torch>=1.8"

Let’s start by importing our standard libraries.

[2]:
import os
import urllib.request
from types import SimpleNamespace
from urllib.error import HTTPError

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import seaborn as sns
import tabulate
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision

%matplotlib inline
from IPython.display import HTML, display
from matplotlib_inline.backend_inline import set_matplotlib_formats
from PIL import Image
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torchvision import transforms
from torchvision.datasets import CIFAR10

set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.reset_orig()


We set the random seed with pl.seed_everything and define the path variables DATASET_PATH and CHECKPOINT_PATH, as in the previous tutorials. Adjust the paths if necessary.

[3]:
# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/ConvNets")


# Function for setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
Global seed set to 42

We also provide pretrained models and TensorBoard logs for this tutorial (more on these later), which we download below.

[4]:
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial5/"
# Files to download
pretrained_files = [
    "GoogleNet.ckpt",
    "ResNet.ckpt",
    "ResNetPreAct.ckpt",
    "DenseNet.ckpt",
    "tensorboards/GoogleNet/events.out.tfevents.googlenet",
    "tensorboards/ResNet/events.out.tfevents.resnet",
    "tensorboards/ResNetPreAct/events.out.tfevents.resnetpreact",
    "tensorboards/DenseNet/events.out.tfevents.densenet",
]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if "/" in file_name:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n",
                e,
            )
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial5/GoogleNet.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial5/ResNet.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial5/ResNetPreAct.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial5/DenseNet.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial5/tensorboards/GoogleNet/events.out.tfevents.googlenet...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial5/tensorboards/ResNet/events.out.tfevents.resnet...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial5/tensorboards/ResNetPreAct/events.out.tfevents.resnetpreact...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial5/tensorboards/DenseNet/events.out.tfevents.densenet...

Throughout this tutorial, we will train and evaluate the models on the CIFAR10 dataset. This allows you to compare the results obtained here with the model you have implemented in the first assignment. As we have learned from the previous tutorial about initialization, it is important to have the data preprocessed with a zero mean. Therefore, as a first step, we will calculate the mean and standard deviation of the CIFAR dataset:

[5]:
train_dataset = CIFAR10(root=DATASET_PATH, train=True, download=True)
DATA_MEANS = (train_dataset.data / 255.0).mean(axis=(0, 1, 2))
DATA_STD = (train_dataset.data / 255.0).std(axis=(0, 1, 2))
print("Data mean", DATA_MEANS)
print("Data std", DATA_STD)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /__w/1/s/.datasets/cifar-10-python.tar.gz
Extracting /__w/1/s/.datasets/cifar-10-python.tar.gz to /__w/1/s/.datasets
Data mean [0.49139968 0.48215841 0.44653091]
Data std [0.24703223 0.24348513 0.26158784]

We will use this information to define a transforms.Normalize module which will normalize our data accordingly. Additionally, we will use data augmentation during training. This reduces the risk of overfitting and helps CNNs to generalize better. Specifically, we will apply two random augmentations.

First, we will flip each image horizontally with a probability of 50% (transforms.RandomHorizontalFlip). The object class usually does not change when flipping an image, and we don’t expect any image information to be dependent on the horizontal orientation. This would be different, however, if we tried to detect digits or letters in an image, as those have a certain orientation.

The second augmentation we use is called transforms.RandomResizedCrop. This transformation crops a random region of the image (here covering 80% to 100% of the original area, with an aspect ratio between 0.9 and 1.1) and resizes it back to the original size of 32x32 pixels. Therefore, the actual pixel values change while the content or overall semantics of the image stay the same.
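To make the scale and ratio parameters more concrete, here is a minimal pure-Python sketch of how such a crop box can be sampled. It follows our understanding of torchvision’s approach (sample a target area and a log-uniform aspect ratio, retry if the box does not fit, otherwise fall back to a central crop), but the helper sample_crop_box is hypothetical and not torchvision’s actual code.

```python
import math
import random


def sample_crop_box(height, width, scale=(0.8, 1.0), ratio=(0.9, 1.1), attempts=10, rng=random):
    """Hypothetical helper: return (top, left, h, w) of a random crop box."""
    area = height * width
    log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
    for _ in range(attempts):
        # Fraction of the original area to keep, e.g. 80%-100%
        target_area = area * rng.uniform(*scale)
        # Aspect ratio (w / h), sampled uniformly in log space
        aspect = math.exp(rng.uniform(*log_ratio))
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            # Random position of the box inside the image (bounds inclusive)
            top = rng.randint(0, height - h)
            left = rng.randint(0, width - w)
            return top, left, h, w
    # Simplified fallback for a square image whose ratio 1 lies inside `ratio`: keep the whole image
    return 0, 0, height, width


rng = random.Random(0)
boxes = [sample_crop_box(32, 32, rng=rng) for _ in range(1000)]
print(min(h for _, _, h, _ in boxes), max(h for _, _, h, _ in boxes))
```

Each sampled box is then resized back to 32x32, so smaller boxes correspond to a slight zoom-in on the image.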

We will randomly split the training dataset into a training and a validation set. The validation set will be used to select the best model checkpoint during training, which acts as a form of early stopping. After finishing the training, we test the models on the CIFAR test set.

[6]:
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(DATA_MEANS, DATA_STD)])
# For training, we add some augmentation. Networks are too powerful and would overfit.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(DATA_MEANS, DATA_STD),
    ]
)
# Loading the training dataset. We need to split it into a training and validation part
# We need to do a little trick because the validation set should not use the augmentation.
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
pl.seed_everything(42)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000])
pl.seed_everything(42)
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000])

# Loading the test set
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = data.DataLoader(val_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
Files already downloaded and verified
Files already downloaded and verified
Global seed set to 42
Global seed set to 42
Files already downloaded and verified
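The trick in the cell above relies on random_split producing the same permutation whenever the seed is reset to the same value, so that the augmented training subset and the non-augmented validation subset never share an image. The sketch below checks this on a toy dataset. Instead of resetting the global seed, it passes an explicit torch.Generator to random_split; this is an alternative we choose for illustration, not what the cell above does.

```python
import torch
from torch.utils.data import random_split

toy_dataset = list(range(100))  # stands in for the 50,000 CIFAR10 training images


def split(seed):
    # A fresh generator with the same seed yields the same permutation
    generator = torch.Generator().manual_seed(seed)
    return random_split(toy_dataset, [90, 10], generator=generator)


train_part, _ = split(42)
_, val_part = split(42)

train_idx = set(train_part.indices)
val_idx = set(val_part.indices)
print(len(train_idx & val_idx))  # → 0, the subsets do not overlap
```

If the two calls used different seeds, some validation images would also appear (augmented) in the training subset.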

To verify that our normalization works, we can print out the mean and standard deviation of a single batch. The mean should be close to 0 and the standard deviation close to 1 for each channel:

[7]:
imgs, _ = next(iter(train_loader))
print("Batch mean", imgs.mean(dim=[0, 2, 3]))
print("Batch std", imgs.std(dim=[0, 2, 3]))
Batch mean tensor([0.0231, 0.0006, 0.0005])
Batch std tensor([0.9865, 0.9849, 0.9868])

Finally, let’s visualize a few images from the training set, and what they look like after random data augmentation:

[8]:
NUM_IMAGES = 4
images = [train_dataset[idx][0] for idx in range(NUM_IMAGES)]
orig_images = [Image.fromarray(train_dataset.data[idx]) for idx in range(NUM_IMAGES)]
orig_images = [test_transform(img) for img in orig_images]

img_grid = torchvision.utils.make_grid(torch.stack(images + orig_images, dim=0), nrow=4, normalize=True, pad_value=0.5)
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(8, 8))
plt.title("Augmentation examples on CIFAR10")
plt.imshow(img_grid)
plt.axis("off")
plt.show()
plt.close()
[Figure: Augmentation examples on CIFAR10. Top row: augmented training images; bottom row: the same images without augmentation.]

PyTorch Lightning

In this notebook and in many following ones, we will make use of the library PyTorch Lightning. PyTorch Lightning is a framework that simplifies the code needed to train, evaluate, and test a model in PyTorch. It also handles logging to TensorBoard, a visualization toolkit for ML experiments, and saves model checkpoints automatically with minimal code overhead on our side. This is extremely helpful for us, as we want to focus on implementing different model architectures and spend little time on other code overhead. Note that at the time of writing/teaching, the framework was at version 1.3. Future versions might have a slightly changed interface and thus might not work perfectly with the code (we will try to keep it up-to-date as much as possible).

Now, we will take the first step in PyTorch Lightning, and continue to explore the framework in our other tutorials. PyTorch Lightning comes with a lot of useful functions, such as one for setting the seed as we have seen before:

[9]:
# Setting the seed
pl.seed_everything(42)
Global seed set to 42
[9]:
42

Thus, in the future, we don’t have to define our own set_seed function anymore.

In PyTorch Lightning, we define subclasses of pl.LightningModule (which itself inherits from nn.Module) that organize our code into 5 main sections:

  1. Initialization (__init__), where we create all necessary parameters/models

  2. Optimizers (configure_optimizers) where we create the optimizers, learning rate scheduler, etc.

  3. Training loop (training_step) where we only have to define the loss calculation for a single batch (the loop of optimizer.zero_grad(), loss.backward() and optimizer.step(), as well as any logging/saving operation, is done in the background)

  4. Validation loop (validation_step) where similarly to the training, we only have to define what should happen per step

  5. Test loop (test_step) which is the same as validation, only on a test set.
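To make clear what “done in the background” means, here is a rough plain-PyTorch sketch of the loop that the Trainer runs around training_step. The toy linear-regression model and random data are assumptions for illustration, and the sketch omits everything else the Trainer handles, such as validation, logging, checkpointing, and device placement.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Toy data and model (hypothetical stand-ins for a real dataset and LightningModule)
inputs = torch.randn(256, 4)
targets = inputs @ torch.tensor([1.0, -2.0, 0.5, 3.0]) + 0.1 * torch.randn(256)
model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)


def training_step(batch):
    # The only part we write ourselves in Lightning: the loss for one batch
    x, y = batch
    return nn.functional.mse_loss(model(x).squeeze(-1), y)


losses = []
for epoch in range(20):
    model.train()
    for start in range(0, len(inputs), 32):
        batch = (inputs[start : start + 32], targets[start : start + 32])
        loss = training_step(batch)
        # The boilerplate that Lightning runs for us:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    losses.append(loss.item())  # loss of the last batch in this epoch

print(f"first epoch loss {losses[0]:.3f}, last epoch loss {losses[-1]:.3f}")
```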

Therefore, we don’t abstract away the PyTorch code, but rather organize it and define some default operations that are commonly used. If you need to change something else in your training/validation/test loop, there are many hooks you can override (see the docs for details).

Now we can look at an example of what a Lightning Module for training a CNN looks like:

[10]:
class CIFARModule(pl.LightningModule):
    def __init__(self, model_name, model_hparams, optimizer_name, optimizer_hparams):
        """
        Inputs:
            model_name - Name of the model/CNN to run. Used for creating the model (see function below)
            model_hparams - Hyperparameters for the model, as dictionary.
            optimizer_name - Name of the optimizer to use. Currently supported: Adam, SGD
            optimizer_hparams - Hyperparameters for the optimizer, as dictionary. This includes learning rate, weight decay, etc.
        """
        super().__init__()
        # Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
        self.save_hyperparameters()
        # Create model
        self.model = create_model(model_name, model_hparams)
        # Create loss module
        self.loss_module = nn.CrossEntropyLoss()
        # Example input for visualizing the graph in Tensorboard
        self.example_input_array = torch.zeros((1, 3, 32, 32), dtype=torch.float32)

    def forward(self, imgs):
        # Forward function that is run when visualizing the graph
        return self.model(imgs)

    def configure_optimizers(self):
        # We will support Adam or SGD as optimizers.
        if self.hparams.optimizer_name == "Adam":
            # AdamW is Adam with a correct implementation of weight decay (see here
            # for details: https://arxiv.org/pdf/1711.05101.pdf)
            optimizer = optim.AdamW(self.parameters(), **self.hparams.optimizer_hparams)
        elif self.hparams.optimizer_name == "SGD":
            optimizer = optim.SGD(self.parameters(), **self.hparams.optimizer_hparams)
        else:
            raise ValueError(f'Unknown optimizer: "{self.hparams.optimizer_name}"')

        # We will reduce the learning rate by 0.1 after 100 and 150 epochs
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        # "batch" is the output of the training data loader.
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        # Logs the accuracy per epoch to tensorboard (weighted average over batches)
        self.log("train_acc", acc, on_step=False, on_epoch=True)
        self.log("train_loss", loss)
        return loss  # Return tensor to call ".backward" on

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        acc = (labels == preds).float().mean()
        # By default logs it per epoch (weighted average over batches)
        self.log("val_acc", acc)

    def test_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        acc = (labels == preds).float().mean()
        # By default logs it per epoch (weighted average over batches), and returns it afterwards
        self.log("test_acc", acc)

We see that the code is organized and clear, which helps if someone else tries to understand your code.
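The comments in CIFARModule state that epoch-level metrics are a weighted average over batches. Assuming the weights are the batch sizes (by default, Lightning tries to infer the batch size from the batch), the epoch value equals the accuracy over all samples rather than the plain mean of the per-batch accuracies. In this sketch the numbers of correct predictions are hypothetical; only the batch sizes follow from our validation loader (5,000 images with batch size 128 and drop_last=False give 39 full batches and a final batch of 8).

```python
# Batch sizes of our validation loader: 39 full batches of 128 plus one of 8
batch_sizes = [128] * 39 + [8]
# Hypothetical number of correct predictions per batch
correct = [115] * 39 + [2]

batch_accs = [c / n for c, n in zip(correct, batch_sizes)]

plain_mean = sum(batch_accs) / len(batch_accs)
weighted_mean = sum(acc * n for acc, n in zip(batch_accs, batch_sizes)) / sum(batch_sizes)
overall_acc = sum(correct) / sum(batch_sizes)

print(f"plain mean {plain_mean:.4f}, weighted mean {weighted_mean:.4f}, overall {overall_acc:.4f}")
```

With the plain mean, the small last batch would count as much as a full one; weighting by batch size avoids this.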

Another important part of PyTorch Lightning is the concept of callbacks. Callbacks are self-contained components that contain the non-essential logic of your Lightning Module. They are usually called after finishing a training epoch, but can also influence other parts of your training loop. For instance, we will use the following two pre-defined callbacks: LearningRateMonitor and ModelCheckpoint. The learning rate monitor adds the current learning rate to our TensorBoard, which helps to verify that our learning rate scheduler works correctly. The model checkpoint callback allows you to customize the saving routine of your checkpoints, for instance, how many checkpoints to keep, when to save, and which metric to monitor. We import them below:

[11]:
# Callbacks
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
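Conceptually, ModelCheckpoint(monitor="val_acc", mode="max"), as used in train_model below, keeps track of the best value of the monitored metric and saves a checkpoint whenever it improves. The following minimal sketch illustrates that idea with a hypothetical BestCheckpoint class and a hypothetical loop that calls an epoch-end hook. It is not Lightning’s implementation, and “saving” is simulated by recording the epoch number.

```python
class BestCheckpoint:
    """Hypothetical callback: remember the epoch with the best monitored metric."""

    def __init__(self, monitor, mode="max"):
        self.monitor = monitor
        self.mode = mode
        self.best_score = None
        self.best_epoch = None

    def _improved(self, score):
        if self.best_score is None:
            return True
        return score > self.best_score if self.mode == "max" else score < self.best_score

    def on_validation_epoch_end(self, epoch, metrics):
        score = metrics[self.monitor]
        if self._improved(score):
            self.best_score = score
            self.best_epoch = epoch  # a real callback would write a checkpoint file here


# Hypothetical validation accuracies over five epochs
val_accs = [0.61, 0.70, 0.74, 0.72, 0.73]
callbacks = [BestCheckpoint(monitor="val_acc", mode="max")]
for epoch, acc in enumerate(val_accs):
    for cb in callbacks:  # the loop calls every callback's hook at the end of each epoch
        cb.on_validation_epoch_end(epoch, {"val_acc": acc})

print(callbacks[0].best_epoch, callbacks[0].best_score)  # → 2 0.74
```

Keeping this logic in a separate object means the training loop stays unchanged, and the same callback can be reused for every model.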

To allow running multiple different models with the same Lightning module, we define a function below that maps a model name to the model class. At this stage, the dictionary model_dict is empty, but we will fill it throughout the notebook with our new models.

[12]:
model_dict = {}


def create_model(model_name, model_hparams):
    if model_name in model_dict:
        return model_dict[model_name](**model_hparams)
    else:
        raise ValueError(f'Unknown model name "{model_name}". Available models are: {str(model_dict.keys())}')

Similarly, to use the activation function as another hyperparameter in our model, we define a “name to function” dict below:

[13]:
act_fn_by_name = {"tanh": nn.Tanh, "relu": nn.ReLU, "leakyrelu": nn.LeakyReLU, "gelu": nn.GELU}

If we passed the classes or objects directly as arguments to the Lightning module, we could not take advantage of PyTorch Lightning’s automatic hyperparameter saving and loading. Plain values such as strings are easy to store in the exported YAML file, while classes are not.
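A quick way to see the difference is to serialize both variants. This sketch uses the standard-library json module as a stand-in for the YAML file Lightning writes (an assumption for illustration), and a dummy class in place of nn.ReLU so that it runs without PyTorch.

```python
import json


class ReLU:  # dummy stand-in for nn.ReLU
    pass


act_fn_by_name = {"relu": ReLU}

# Storing the name works: it is a plain string
as_name = json.dumps({"act_fn_name": "relu"})
print(as_name)

# Storing the class itself fails
try:
    json.dumps({"act_fn": ReLU})
    class_serializable = True
except TypeError:
    class_serializable = False
print(class_serializable)  # → False

# After loading, the name is mapped back to the class via the lookup dict
restored = act_fn_by_name[json.loads(as_name)["act_fn_name"]]
print(restored is ReLU)  # → True
```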

Besides the Lightning module, the second most important module in PyTorch Lightning is the Trainer. The trainer is responsible for executing the training steps defined in the Lightning module and completes the framework. Similar to the Lightning module, you can override any key part that you don’t want to be automated, but the default settings usually reflect best practice. For a full overview, see the documentation. The most important functions we use below are:

  • trainer.fit: Takes as input a Lightning module, a training data loader, and an (optional) validation data loader. This function trains the given module on the training data with occasional validation (by default once per epoch, which can be changed).

  • trainer.test: Takes as input a model and the data loader(s) on which we want to test. It returns the logged test metrics as a list with one dictionary per data loader.

For training and testing, we don’t have to worry about things like setting the model to eval mode (model.eval()) as this is all done automatically. See below how we define a training function for our models:

[14]:
def train_model(model_name, save_name=None, **kwargs):
    """
    Inputs:
        model_name - Name of the model you want to run. Is used to look up the class in "model_dict"
        save_name (optional) - If specified, this name will be used for creating the checkpoint and logging directory.
    """
    if save_name is None:
        save_name = model_name

    # Create a PyTorch Lightning trainer with the checkpointing and learning rate monitor callbacks
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, save_name),  # Where to save models
        # We run on a single GPU (if possible)
        gpus=1 if str(device) == "cuda:0" else 0,
        # How many epochs to train for if no patience is set
        max_epochs=180,
        callbacks=[
            # Save the best checkpoint based on the maximum val_acc recorded. Saves only weights and not optimizer
            ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
            # Log learning rate every epoch
            LearningRateMonitor("epoch"),
        ],
        # In case your notebook crashes due to the progress bar, consider increasing the refresh rate
        progress_bar_refresh_rate=1,
    )
    trainer.logger._log_graph = True  # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, save_name + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        # Automatically loads the model with the saved hyperparameters
        model = CIFARModule.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)  # To be reproducible
        model = CIFARModule(model_name=model_name, **kwargs)
        trainer.fit(model, train_loader, val_loader)
        model = CIFARModule.load_from_checkpoint(
            trainer.checkpoint_callback.best_model_path
        )  # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    # test_step logs under the key "test_acc", so both results use that key
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}

    return model, result
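The scheduler in configure_optimizers multiplies the learning rate by gamma=0.1 at epochs 100 and 150, and train_model runs for at most 180 epochs. Assuming the scheduler is stepped once per epoch (the default interval for schedulers returned from configure_optimizers like this), the learning rate used in epoch e (counting from 0) is the initial learning rate times 0.1^k, where k is the number of milestones that are less than or equal to e. A small sketch of that closed form, with the hypothetical helper multistep_lr:

```python
from bisect import bisect_right


def multistep_lr(base_lr, epoch, milestones=(100, 150), gamma=0.1):
    # Number of milestones already reached determines how often lr was decayed
    return base_lr * gamma ** bisect_right(milestones, epoch)


for epoch in (0, 99, 100, 149, 150, 179):
    print(epoch, multistep_lr(1e-3, epoch))
```

With lr=1e-3 as used for GoogleNet below, the last 30 epochs therefore train with a learning rate of about 1e-5.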

Finally, we can focus on the Convolutional Neural Networks we want to implement today: GoogleNet, ResNet, and DenseNet.

Inception

GoogleNet, proposed in 2014, won the ImageNet Challenge thanks to its use of Inception modules. In this tutorial, we will mainly focus on the concept of Inception rather than on the specifics of GoogleNet, since many follow-up works have built on Inception (Inception-v2, Inception-v3, Inception-v4, Inception-ResNet,…). The follow-up works mainly focus on increasing efficiency and enabling very deep Inception networks. However, for a fundamental understanding, it is sufficient to look at the original Inception block.

An Inception block applies four convolution blocks separately on the same feature map: a 1x1, 3x3, and 5x5 convolution, and a max pool operation. This allows the network to look at the same data with different receptive fields. Of course, learning only 5x5 convolutions would theoretically be more powerful, since a 5x5 kernel can represent any 3x3 or 1x1 kernel by setting its outer weights to zero. However, this is not only more computation- and memory-heavy but also tends to overfit more easily. The overall Inception block looks as follows (figure credit - Szegedy et al.):

[Figure: the Inception block]

The additional 1x1 convolutions before the 3x3 and 5x5 convolutions are used for dimensionality reduction. This is especially crucial as the feature maps of all branches are merged afterward, and we don’t want any explosion of feature size. As 5x5 convolutions are 25 times more expensive than 1x1 convolutions, we can save a lot of computation and parameters by reducing the dimensionality before the large convolutions.
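To put numbers on this, consider the 5x5 branch of the first Inception block in the GoogleNet below (64 input channels, 8 output channels, reduction to 16 channels). Counting only convolution weights and ignoring biases, a k×k convolution from c_in to c_out channels has k·k·c_in·c_out weights. The helper conv_weights is ours, for illustration:

```python
def conv_weights(k, c_in, c_out):
    # Weights of a k x k convolution (biases ignored)
    return k * k * c_in * c_out


direct = conv_weights(5, 64, 8)  # 5x5 directly on the 64 input channels
reduced = conv_weights(1, 64, 16) + conv_weights(5, 16, 8)  # 1x1 down to 16 channels, then 5x5

print(direct, reduced, round(direct / reduced, 2))  # → 12800 4224 3.03
```

Since all branches keep the spatial resolution, the number of multiply-adds per output pixel equals the weight count, so the reduction saves roughly the same factor in computation.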

We can now try to implement the Inception Block ourselves:

[15]:
class InceptionBlock(nn.Module):
    def __init__(self, c_in, c_red: dict, c_out: dict, act_fn):
        """
        Inputs:
            c_in - Number of input feature maps from the previous layers
            c_red - Dictionary with keys "3x3" and "5x5" specifying the output of the dimensionality reducing 1x1 convolutions
            c_out - Dictionary with keys "1x1", "3x3", "5x5", and "max"
            act_fn - Activation class constructor (e.g. nn.ReLU)
        """
        super().__init__()

        # 1x1 convolution branch
        self.conv_1x1 = nn.Sequential(
            nn.Conv2d(c_in, c_out["1x1"], kernel_size=1), nn.BatchNorm2d(c_out["1x1"]), act_fn()
        )

        # 3x3 convolution branch
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(c_in, c_red["3x3"], kernel_size=1),
            nn.BatchNorm2d(c_red["3x3"]),
            act_fn(),
            nn.Conv2d(c_red["3x3"], c_out["3x3"], kernel_size=3, padding=1),
            nn.BatchNorm2d(c_out["3x3"]),
            act_fn(),
        )

        # 5x5 convolution branch
        self.conv_5x5 = nn.Sequential(
            nn.Conv2d(c_in, c_red["5x5"], kernel_size=1),
            nn.BatchNorm2d(c_red["5x5"]),
            act_fn(),
            nn.Conv2d(c_red["5x5"], c_out["5x5"], kernel_size=5, padding=2),
            nn.BatchNorm2d(c_out["5x5"]),
            act_fn(),
        )

        # Max-pool branch
        self.max_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
            nn.Conv2d(c_in, c_out["max"], kernel_size=1),
            nn.BatchNorm2d(c_out["max"]),
            act_fn(),
        )

    def forward(self, x):
        x_1x1 = self.conv_1x1(x)
        x_3x3 = self.conv_3x3(x)
        x_5x5 = self.conv_5x5(x)
        x_max = self.max_pool(x)
        x_out = torch.cat([x_1x1, x_3x3, x_5x5, x_max], dim=1)
        return x_out
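Concatenating the four branches along the channel dimension only works if all of them keep the same height and width. Using the standard output-size formula for convolution and pooling layers (dilation 1), floor((H + 2·padding − kernel) / stride) + 1, we can check the padding choices above for a 32x32 input with the hypothetical helper out_size:

```python
def out_size(size, kernel, stride=1, padding=0):
    # Output height/width of a convolution or pooling layer (dilation 1, floor rounding)
    return (size + 2 * padding - kernel) // stride + 1


branches = {
    "1x1 conv": out_size(32, kernel=1),
    "3x3 conv (padding=1)": out_size(32, kernel=3, padding=1),
    "5x5 conv (padding=2)": out_size(32, kernel=5, padding=2),
    "3x3 max pool (stride=1, padding=1)": out_size(32, kernel=3, stride=1, padding=1),
}
print(branches)  # every branch keeps the 32x32 resolution

# The stride-2 max pools between Inception stages in GoogleNet halve the resolution instead
print(out_size(32, kernel=3, stride=2, padding=1), out_size(16, kernel=3, stride=2, padding=1))  # → 16 8
```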

The GoogleNet architecture consists of stacking multiple Inception blocks with occasional max pooling to reduce the height and width of the feature maps. The original GoogleNet was designed for image sizes of ImageNet (224x224 pixels) and had almost 7 million parameters. As we train on CIFAR10 with image sizes of 32x32, we don’t require such a heavy architecture, and instead apply a reduced version. The number of channels for dimensionality reduction and output per filter (1x1, 3x3, 5x5, and max pooling) needs to be specified manually and can be changed if you are interested. The general intuition is to have the most filters for the 3x3 convolutions, as they are powerful enough to take the context into account while requiring roughly a third of the parameters of the 5x5 convolution (9 vs. 25 weights per input-output channel pair).

[16]:
class GoogleNet(nn.Module):
    def __init__(self, num_classes=10, act_fn_name="relu", **kwargs):
        super().__init__()
        self.hparams = SimpleNamespace(
            num_classes=num_classes, act_fn_name=act_fn_name, act_fn=act_fn_by_name[act_fn_name]
        )
        self._create_network()
        self._init_params()

    def _create_network(self):
        # A first convolution on the original image to scale up the channel size
        self.input_net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), self.hparams.act_fn()
        )
        # Stacking inception blocks
        self.inception_blocks = nn.Sequential(
            InceptionBlock(
                64,
                c_red={"3x3": 32, "5x5": 16},
                c_out={"1x1": 16, "3x3": 32, "5x5": 8, "max": 8},
                act_fn=self.hparams.act_fn,
            ),
            InceptionBlock(
                64,
                c_red={"3x3": 32, "5x5": 16},
                c_out={"1x1": 24, "3x3": 48, "5x5": 12, "max": 12},
                act_fn=self.hparams.act_fn,
            ),
            nn.MaxPool2d(3, stride=2, padding=1),  # 32x32 => 16x16
            InceptionBlock(
                96,
                c_red={"3x3": 32, "5x5": 16},
                c_out={"1x1": 24, "3x3": 48, "5x5": 12, "max": 12},
                act_fn=self.hparams.act_fn,
            ),
            InceptionBlock(
                96,
                c_red={"3x3": 32, "5x5": 16},
                c_out={"1x1": 16, "3x3": 48, "5x5": 16, "max": 16},
                act_fn=self.hparams.act_fn,
            ),
            InceptionBlock(
                96,
                c_red={"3x3": 32, "5x5": 16},
                c_out={"1x1": 16, "3x3": 48, "5x5": 16, "max": 16},
                act_fn=self.hparams.act_fn,
            ),
            InceptionBlock(
                96,
                c_red={"3x3": 32, "5x5": 16},
                c_out={"1x1": 32, "3x3": 48, "5x5": 24, "max": 24},
                act_fn=self.hparams.act_fn,
            ),
            nn.MaxPool2d(3, stride=2, padding=1),  # 16x16 => 8x8
            InceptionBlock(
                128,
                c_red={"3x3": 48, "5x5": 16},
                c_out={"1x1": 32, "3x3": 64, "5x5": 16, "max": 16},
                act_fn=self.hparams.act_fn,
            ),
            InceptionBlock(
                128,
                c_red={"3x3": 48, "5x5": 16},
                c_out={"1x1": 32, "3x3": 64, "5x5": 16, "max": 16},
                act_fn=self.hparams.act_fn,
            ),
        )
        # Mapping to classification output
        self.output_net = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(128, self.hparams.num_classes)
        )

    def _init_params(self):
        # Based on our discussion in Tutorial 4, we should initialize the
        # convolutions according to the activation function
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity=self.hparams.act_fn_name)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.input_net(x)
        x = self.inception_blocks(x)
        x = self.output_net(x)
        return x
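Because each Inception block concatenates its branches, its number of output channels is the sum of its four c_out values, and this sum has to match the c_in of the next block (max pooling does not change the channel count). The pure-Python check below copies the c_in values and c_out dictionaries from _create_network and verifies that the chain is consistent up to the final nn.Linear(128, ...):

```python
# (c_in, c_out) of each Inception block, copied from GoogleNet._create_network
blocks = [
    (64, {"1x1": 16, "3x3": 32, "5x5": 8, "max": 8}),
    (64, {"1x1": 24, "3x3": 48, "5x5": 12, "max": 12}),
    (96, {"1x1": 24, "3x3": 48, "5x5": 12, "max": 12}),
    (96, {"1x1": 16, "3x3": 48, "5x5": 16, "max": 16}),
    (96, {"1x1": 16, "3x3": 48, "5x5": 16, "max": 16}),
    (96, {"1x1": 32, "3x3": 48, "5x5": 24, "max": 24}),
    (128, {"1x1": 32, "3x3": 64, "5x5": 16, "max": 16}),
    (128, {"1x1": 32, "3x3": 64, "5x5": 16, "max": 16}),
]

channels = 64  # output channels of input_net
for i, (c_in, c_out) in enumerate(blocks):
    assert c_in == channels, f"block {i} expects {c_in} channels but receives {channels}"
    channels = sum(c_out.values())  # branches are concatenated along dim=1

print(channels)  # → 128, matching nn.Linear(128, num_classes) in output_net
```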

Now, we can integrate our model to the model dictionary we defined above:

[17]:
model_dict["GoogleNet"] = GoogleNet

The training of the model is handled by PyTorch Lightning, and we just have to define the command to start. Note that we train for almost 200 epochs (180), which takes about an hour on Lisa’s default GPUs (GTX1080Ti). We recommend using the saved models, and training your own model only if you are interested.

[18]:
googlenet_model, googlenet_results = train_model(
    model_name="GoogleNet",
    model_hparams={"num_classes": 10, "act_fn_name": "relu"},
    optimizer_name="Adam",
    optimizer_hparams={"lr": 1e-3, "weight_decay": 1e-4},
)
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:96: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=1)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/ConvNets/GoogleNet/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model at saved_models/ConvNets/GoogleNet.ckpt, loading...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

We will compare the results later in the notebook, but we can already print them here for a first glance:

[19]:
print("GoogleNet Results", googlenet_results)
GoogleNet Results {'test': 0.8970000147819519, 'val': 0.9039999842643738}
TensorBoard log

A nice extra of PyTorch Lightning is the automatic logging to TensorBoard. To give you a better intuition of what TensorBoard can be used for, we can look at the board that PyTorch Lightning generated when training the GoogleNet. TensorBoard provides an inline functionality for Jupyter notebooks, and we use it here:

[20]:
# Import tensorboard
%load_ext tensorboard
[21]:
# Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH!
%tensorboard --logdir saved_models/ConvNets/tensorboards/GoogleNet/

[Figure: TensorBoard view of the GoogleNet training run]

TensorBoard is organized in multiple tabs. The main tab is the scalar tab, where we can track how single numbers develop during training. For example, we have plotted the training loss, accuracy, learning rate, etc. If we look at the training or validation accuracy, we can clearly see the impact of the learning rate scheduler: reducing the learning rate gives our model a noticeable increase in training performance. Similarly, when looking at the training loss, we see a sudden decrease at this point. However, the much higher accuracy on the training set compared to the validation set indicates that our model was overfitting, which is hard to avoid for such large networks.

Another interesting tab in TensorBoard is the graph tab. It shows us the network architecture organized by building blocks from the input to the output. It basically shows the operations taken in the forward step of CIFARModule. Double-click on a module to open it. Feel free to explore the architecture from a different perspective. The graph visualization can often help you to validate that your model is actually doing what it is supposed to do, and you don’t miss any layers in the computation graph.

ResNet

The ResNet paper is one of the most cited AI papers, and has been the foundation for neural networks with more than 1,000 layers. Despite its simplicity, the idea of residual connections is highly effective as it supports stable gradient propagation through the network. Instead of modeling x_{l+1}=F(x_{l}), we model x_{l+1}=x_{l}+F(x_{l}) where F is a non-linear mapping (usually a sequence of NN modules like convolutions, activation functions, and normalizations). If we do backpropagation on such residual connections, we obtain:

\frac{\partial x_{l+1}}{\partial x_{l}} = \mathbf{I} + \frac{\partial F(x_{l})}{\partial x_{l}}

The identity term lets the gradient pass through each block largely unchanged, which keeps gradient propagation stable and less affected by F itself. There have been many variants of ResNet proposed, which mostly concern the function F, or operations applied on the sum. In this tutorial, we look at two of them: the original ResNet block, and the Pre-Activation ResNet block. We visually compare the blocks below (figure credit - He et al.):

[Figure: original ResNet block and pre-activation ResNet block]

The original ResNet block applies a non-linear activation function, usually ReLU, after the skip connection. In contrast, the pre-activation ResNet block applies the non-linearity at the beginning of F. Both have their advantages and disadvantages. For very deep networks, however, the pre-activation ResNet has been shown to perform better, as the gradient flow is guaranteed to contain the identity matrix as calculated above, and is not harmed by any non-linear activation applied to it. For comparison, in this notebook, we implement both ResNet types as shallow networks.
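To see the effect of the identity term in the Jacobian above numerically, the sketch below backpropagates through 20 layers of a toy mapping F(x) = tanh(Wx) with small random weights, once as a plain stack x_{l+1} = F(x_{l}) and once with skip connections x_{l+1} = x_{l} + F(x_{l}). The depth, width, and weight scale are arbitrary assumptions chosen to make the contrast visible; this is not one of the ResNets from this tutorial.

```python
import torch

torch.manual_seed(0)
depth, width = 20, 16
# Small random weights, shared by both variants so that only the skip connection differs
weights = [0.05 * torch.randn(width, width) for _ in range(depth)]


def input_grad_norm(residual):
    x0 = torch.randn(width, requires_grad=True)
    x = x0
    for w in weights:
        f = torch.tanh(w @ x)  # F(x_l)
        x = x + f if residual else f
    x.sum().backward()
    return x0.grad.norm().item()  # size of the gradient that reaches the input


plain = input_grad_norm(residual=False)
res = input_grad_norm(residual=True)
print(f"plain: {plain:.2e}, residual: {res:.2e}")
```

In the plain stack, every layer multiplies the gradient by a small Jacobian, so it shrinks by orders of magnitude; with skip connections, each Jacobian is the identity plus a small term, and the gradient at the input stays of order one.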

Let’s start with the original ResNet block. The visualization above already shows what layers are included in F. One special case we have to handle is when we want to reduce the image dimensions in terms of width and height. The basic ResNet block requires F(x_{l}) to be of the same shape as x_{l}. Thus, we need to change the dimensionality of x_{l} as well before adding to F(x_{l}). The original implementation used an identity mapping with stride 2 and padded additional feature dimensions with 0. However, the more common implementation is to use a 1x1 convolution with stride 2 as it allows us to change the feature dimensionality while being efficient in parameter and computation cost. The code for the ResNet block is relatively simple, and shown below:

[22]:


class ResNetBlock(nn.Module):
    def __init__(self, c_in, act_fn, subsample=False, c_out=-1):
        """
        Inputs:
            c_in - Number of input features
            act_fn - Activation class constructor (e.g. nn.ReLU)
            subsample - If True, we want to apply a stride inside the block and reduce the output shape by 2 in height and width
            c_out - Number of output features. Note that this is only relevant if subsample is True, as otherwise, c_out = c_in
        """
        super().__init__()
        if not subsample:
            c_out = c_in

        # Network representing F
        self.net = nn.Sequential(
            nn.Conv2d(
                c_in, c_out, kernel_size=3, padding=1, stride=1 if not subsample else 2, bias=False
            ),  # No bias needed as the Batch Norm handles it
            nn.BatchNorm2d(c_out),
            act_fn(),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
        )

        # 1x1 convolution with stride 2 means we take the upper left value, and transform it to new output size
        self.downsample = nn.Conv2d(c_in, c_out, kernel_size=1, stride=2) if subsample else None
        self.act_fn = act_fn()

    def forward(self, x):
        z = self.net(x)
        if self.downsample is not None:
            x = self.downsample(x)
        out = z + x
        out = self.act_fn(out)
        return out
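The addition z + x in forward only works if, for a subsampling block, the path through F and the shortcut produce the same output shape. The first convolution of F (3x3, padding 1, stride 2) determines the resolution, as the second one keeps it unchanged. A quick check with standalone layers, where the channel numbers 16 and 32 are arbitrary choices:

```python
import torch
from torch import nn

c_in, c_out = 16, 32
main_path = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2, bias=False)  # first conv of F
shortcut = nn.Conv2d(c_in, c_out, kernel_size=1, stride=2)  # downsample on the skip connection

shapes = {}
for size in (32, 31, 8):  # even and odd spatial sizes
    x = torch.randn(1, c_in, size, size)
    shapes[size] = (tuple(main_path(x).shape), tuple(shortcut(x).shape))
    print(size, shapes[size])
```

Both paths map a size of H to floor((H − 1) / 2) + 1, so the shapes agree for even and odd inputs alike.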

The second block we implement is the pre-activation ResNet block. For this, we have to change the order of layers in self.net and not apply an activation function on the output. Additionally, the downsampling operation has to apply a non-linearity as well, because the input x_l has not been processed by a non-linearity yet. Hence, the block looks as follows:

[23]:
class PreActResNetBlock(nn.Module):
    def __init__(self, c_in, act_fn, subsample=False, c_out=-1):
        """
        Inputs:
            c_in - Number of input features
            act_fn - Activation class constructor (e.g. nn.ReLU)
            subsample - If True, we want to apply a stride inside the block and reduce the output shape by 2 in height and width
            c_out - Number of output features. Note that this is only relevant if subsample is True, as otherwise, c_out = c_in
        """
        super().__init__()
        if not subsample:
            c_out = c_in

        # Network representing F
        self.net = nn.Sequential(
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=1 if not subsample else 2, bias=False),
            nn.BatchNorm2d(c_out),
            act_fn(),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1, bias=False),
        )

        # The 1x1 convolution must apply the non-linearity itself, since x has not passed through one on the skip connection
        self.downsample = (
            nn.Sequential(nn.BatchNorm2d(c_in), act_fn(), nn.Conv2d(c_in, c_out, kernel_size=1, stride=2, bias=False))
            if subsample
            else None
        )

    def forward(self, x):
        z = self.net(x)
        if self.downsample is not None:
            x = self.downsample(x)
        out = z + x
        return out

Similarly to the model selection, we define a dictionary to create a mapping from string to block class. We will use the string name as hyperparameter value in our model to choose between the ResNet blocks. Feel free to implement any other ResNet block type and add it here as well.

[24]:
resnet_blocks_by_name = {"ResNetBlock": ResNetBlock, "PreActResNetBlock": PreActResNetBlock}

The overall ResNet architecture consists of stacking multiple ResNet blocks, some of which downsample the input. When talking about ResNet blocks in the whole network, we usually group them by the same output shape. Hence, if we say the ResNet has [3,3,3] blocks, it means that we have three groups of three ResNet blocks each, where subsampling takes place in the fourth and seventh block. The ResNet with [3,3,3] blocks on CIFAR10 is visualized below.


The three groups operate on the resolutions 32\times32, 16\times16 and 8\times8 respectively. The blocks in orange denote ResNet blocks with downsampling. The same notation is used by many other implementations such as in the torchvision library from PyTorch. Thus, our code looks as follows:

[25]:
class ResNet(nn.Module):
    def __init__(
        self,
        num_classes=10,
        num_blocks=[3, 3, 3],
        c_hidden=[16, 32, 64],
        act_fn_name="relu",
        block_name="ResNetBlock",
        **kwargs,
    ):
        """
        Inputs:
            num_classes - Number of classification outputs (10 for CIFAR10)
            num_blocks - List with the number of ResNet blocks to use. The first block of each group uses downsampling, except the first.
            c_hidden - List with the hidden dimensionalities in the different blocks. Usually multiplied by 2 the deeper we go.
            act_fn_name - Name of the activation function to use, looked up in "act_fn_by_name"
            block_name - Name of the ResNet block, looked up in "resnet_blocks_by_name"
        """
        super().__init__()
        assert block_name in resnet_blocks_by_name
        self.hparams = SimpleNamespace(
            num_classes=num_classes,
            c_hidden=c_hidden,
            num_blocks=num_blocks,
            act_fn_name=act_fn_name,
            act_fn=act_fn_by_name[act_fn_name],
            block_class=resnet_blocks_by_name[block_name],
        )
        self._create_network()
        self._init_params()

    def _create_network(self):
        c_hidden = self.hparams.c_hidden

        # A first convolution on the original image to scale up the channel size
        if self.hparams.block_class == PreActResNetBlock:  # => Don't apply non-linearity on output
            self.input_net = nn.Sequential(nn.Conv2d(3, c_hidden[0], kernel_size=3, padding=1, bias=False))
        else:
            self.input_net = nn.Sequential(
                nn.Conv2d(3, c_hidden[0], kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_hidden[0]),
                self.hparams.act_fn(),
            )

        # Creating the ResNet blocks
        blocks = []
        for block_idx, block_count in enumerate(self.hparams.num_blocks):
            for bc in range(block_count):
                # Subsample the first block of each group, except the very first one.
                subsample = bc == 0 and block_idx > 0
                blocks.append(
                    self.hparams.block_class(
                        c_in=c_hidden[block_idx if not subsample else (block_idx - 1)],
                        act_fn=self.hparams.act_fn,
                        subsample=subsample,
                        c_out=c_hidden[block_idx],
                    )
                )
        self.blocks = nn.Sequential(*blocks)

        # Mapping to classification output
        self.output_net = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(c_hidden[-1], self.hparams.num_classes)
        )

    def _init_params(self):
        # Based on our discussion in Tutorial 4, we should initialize the convolutions according to the activation function
        # Fan-out focuses on the gradient distribution, and is commonly used in ResNets
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity=self.hparams.act_fn_name)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.input_net(x)
        x = self.blocks(x)
        x = self.output_net(x)
        return x
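The block layout that the loop in _create_network produces can be traced without building any layers. The function below is a hypothetical helper that mirrors that loop and assumes 32x32 CIFAR10 inputs (input_net keeps the resolution, since it uses padding=1 and stride 1). It shows that the fourth and seventh blocks subsample and that the three groups run at 32x32, 16x16 and 8x8.

```python
def resnet_layout(num_blocks=(3, 3, 3), c_hidden=(16, 32, 64), input_res=32):
    """Hypothetical helper mirroring the loop in ResNet._create_network.

    Returns one (c_in, c_out, subsample, output_resolution) tuple per block.
    """
    layout = []
    res = input_res
    for block_idx, block_count in enumerate(num_blocks):
        for bc in range(block_count):
            # Same rule as above: subsample the first block of every group except the first
            subsample = bc == 0 and block_idx > 0
            c_in = c_hidden[block_idx if not subsample else block_idx - 1]
            if subsample:
                res //= 2  # stride 2 halves height and width
            layout.append((c_in, c_hidden[block_idx], subsample, res))
    return layout


for i, (c_in, c_out, sub, res) in enumerate(resnet_layout(), start=1):
    print(f"block {i}: {c_in:>2} -> {c_out:>2} channels, subsample={sub}, {res}x{res}")
```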

We also need to add the new ResNet class to our model dictionary:

[26]:
model_dict["ResNet"] = ResNet

Finally, we can train our ResNet models. One difference from the GoogleNet training is that we explicitly use SGD with momentum as the optimizer instead of Adam. Adam often leads to slightly worse accuracy on plain, shallow ResNets. It is not entirely clear why Adam performs worse in this context, but one possible explanation is related to ResNet’s loss surface. ResNet has been shown to produce smoother loss surfaces than networks without skip connections (see Li et al., 2018 for details). A possible visualization of the loss surface with and without skip connections is shown below (figure credit - Li et al.):


The x and y axes show a projection of the parameter space, and the z axis shows the loss values achieved by different parameter values. On smooth surfaces like the one on the right, we might not require an adaptive learning rate as provided by Adam. Instead, Adam can get stuck in local optima while SGD finds the wider minima that tend to generalize better. Answering this question in detail, however, would require a separate tutorial. For now, we conclude: for ResNet architectures, consider the optimizer to be an important hyperparameter, and try training with both Adam and SGD. Let’s train the model below with SGD:

[27]:
resnet_model, resnet_results = train_model(
    model_name="ResNet",
    model_hparams={"num_classes": 10, "c_hidden": [16, 32, 64], "num_blocks": [3, 3, 3], "act_fn_name": "relu"},
    optimizer_name="SGD",
    optimizer_hparams={"lr": 0.1, "momentum": 0.9, "weight_decay": 1e-4},
)
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:96: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=1)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/ConvNets/ResNet/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model at saved_models/ConvNets/ResNet.ckpt, loading...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
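To make the “adaptive learning rate” mentioned above more concrete, the sketch below writes out the first update step of both optimizers in NumPy, following the update rules documented for torch.optim.SGD (momentum buffer initialized with the first gradient, weight decay added to the gradient) and torch.optim.Adam (with bias correction; weight decay omitted here for clarity). The two gradient values are made up and differ by four orders of magnitude. The example only illustrates step sizes; it says nothing about the generalization claim itself.

```python
import numpy as np

theta = np.array([1.0, 1.0])
grad = np.array([1e-3, 10.0])  # illustrative gradients on very different scales

# SGD with momentum, using the hyperparameters from the cell above
lr_sgd, momentum, weight_decay = 0.1, 0.9, 1e-4
d_p = grad + weight_decay * theta
buf = d_p.copy()  # first step; later steps use buf = momentum * buf + d_p
sgd_step = lr_sgd * buf

# Adam, first step with bias correction (weight decay omitted for clarity)
lr_adam, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
m = (1 - beta1) * grad
v = (1 - beta2) * grad**2
m_hat = m / (1 - beta1)
v_hat = v / (1 - beta2)
adam_step = lr_adam * m_hat / (np.sqrt(v_hat) + eps)

print("SGD step: ", sgd_step)  # proportional to the gradient
print("Adam step:", adam_step)  # roughly lr_adam for both coordinates
```

Adam divides by a running estimate of the gradient magnitude, so each parameter moves by roughly the learning rate regardless of its gradient scale. SGD's step is proportional to the gradient, which matters less when the loss surface is smooth.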

Let’s also train the pre-activation ResNet as comparison:

[28]:
resnetpreact_model, resnetpreact_results = train_model(
    model_name="ResNet",
    model_hparams={
        "num_classes": 10,
        "c_hidden": [16, 32, 64],
        "num_blocks": [3, 3, 3],
        "act_fn_name": "relu",
        "block_name": "PreActResNetBlock",
    },
    optimizer_name="SGD",
    optimizer_hparams={"lr": 0.1, "momentum": 0.9, "weight_decay": 1e-4},
    save_name="ResNetPreAct",
)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/ConvNets/ResNetPreAct/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model at saved_models/ConvNets/ResNetPreAct.ckpt, loading...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Tensorboard log

Similarly to our GoogleNet model, we also have a TensorBoard log for the ResNet model. We can open it below.

[29]:
# Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH! Feel free to change "ResNet" to "ResNetPreAct"
%tensorboard --logdir ../saved_models/tutorial5/tensorboards/ResNet/


Feel free to explore the TensorBoard yourself, including the computation graph. In general, we can see that with SGD, the ResNet has a higher training loss than the GoogleNet in the first stage of the training. After the learning rate is reduced, however, the model achieves even higher validation accuracies. We compare the precise scores at the end of the notebook.

DenseNet

DenseNet is another architecture for enabling very deep neural networks and takes a slightly different perspective on residual connections. Instead of modeling the difference between layers, DenseNet considers residual connections as a possible way to reuse features across layers, removing any need to learn redundant feature maps. If we go deeper into the network, the model learns abstract features to recognize patterns. However, some complex patterns consist of a combination of abstract features (e.g. hand, face, etc.) and low-level features (e.g. edges, basic color, etc.). To find these low-level features in the deep layers, standard CNNs have to learn to copy such feature maps, which wastes a lot of parameter complexity. DenseNet provides an efficient way of reusing features by having each convolution depend on all previous input features, while adding only a small number of filters to them. See the figure below for an illustration (figure credit - Hu et al.):


The last layer, called the transition layer, is responsible for reducing the dimensionality of the feature maps in height, width, and channel size. Although transition layers technically break the identity backpropagation, there are only a few of them in a network, so they do not affect the gradient flow much.

We split the implementation of the layers in DenseNet into three parts: a DenseLayer, a DenseBlock, and a TransitionLayer. The module DenseLayer implements a single layer inside a dense block. It applies a 1x1 convolution for dimensionality reduction, followed by a 3x3 convolution. The output channels are concatenated to the original ones and returned. Note that we apply Batch Normalization as the first layer of each block. This allows slightly different activations of the same features for different layers, depending on what is needed. Overall, we can implement it as follows:

[30]:
class DenseLayer(nn.Module):
    def __init__(self, c_in, bn_size, growth_rate, act_fn):
        """
        Inputs:
            c_in - Number of input channels
            bn_size - Bottleneck size (factor of growth rate) for the output of the 1x1 convolution. Typically between 2 and 4.
            growth_rate - Number of output channels of the 3x3 convolution
            act_fn - Activation class constructor (e.g. nn.ReLU)
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, bn_size * growth_rate, kernel_size=1, bias=False),
            nn.BatchNorm2d(bn_size * growth_rate),
            act_fn(),
            nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False),
        )

    def forward(self, x):
        out = self.net(x)
        out = torch.cat([out, x], dim=1)
        return out
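Feature reuse comes purely from the concatenation: the input channels are passed through unchanged and each layer only appends growth_rate new ones. The NumPy sketch below assumes arrays without a batch dimension, so the channel axis is 0 rather than the dim=1 used in DenseLayer.forward. It also uses a random 1x1 channel mixing as a hypothetical stand-in for self.net. Stacking six such layers reproduces the channel count that DenseBlock relies on.

```python
import numpy as np

c_in, growth_rate, height, width = 32, 16, 8, 8
rng = np.random.default_rng(0)
x = rng.standard_normal((c_in, height, width))

# Hypothetical stand-in for self.net(x): any map producing growth_rate channels
w = rng.standard_normal((growth_rate, c_in))
new_features = np.einsum("oc,chw->ohw", w, x)

# Same ordering as torch.cat([out, x], dim=1): new features first, input after
out = np.concatenate([new_features, x], axis=0)
print(out.shape)  # (48, 8, 8)

# Stacking layers as in DenseBlock: each one sees all previous feature maps
h = x
for layer_idx in range(6):
    w = rng.standard_normal((growth_rate, h.shape[0]))
    h = np.concatenate([np.einsum("oc,chw->ohw", w, h), h], axis=0)
print(h.shape)  # (32 + 6 * 16, 8, 8) = (128, 8, 8)
```

Since every layer prepends its new features, the original input always ends up unchanged in the last c_in channels.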

The module DenseBlock summarizes multiple dense layers applied in sequence. Each dense layer takes as input the original input concatenated with all previous layers’ feature maps:

[31]:
class DenseBlock(nn.Module):
    def __init__(self, c_in, num_layers, bn_size, growth_rate, act_fn):
        """
        Inputs:
            c_in - Number of input channels
            num_layers - Number of dense layers to apply in the block
            bn_size - Bottleneck size to use in the dense layers
            growth_rate - Growth rate to use in the dense layers
            act_fn - Activation function to use in the dense layers
        """
        super().__init__()
        layers = []
        for layer_idx in range(num_layers):
            # Input channels are original plus the feature maps from previous layers
            layer_c_in = c_in + layer_idx * growth_rate
            layers.append(DenseLayer(c_in=layer_c_in, bn_size=bn_size, growth_rate=growth_rate, act_fn=act_fn))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out

Finally, the TransitionLayer takes as input the final output of a dense block and reduces its channel dimensionality using a 1x1 convolution. To reduce the height and width dimension, we take a slightly different approach than in ResNet and apply an average pooling with kernel size 2 and stride 2. This is because we don’t have an additional connection to the output that would consider the full 2x2 patch instead of a single value. Besides, it is more parameter efficient than using a 3x3 convolution with stride 2. Thus, the layer is implemented as follows:

[32]:
class TransitionLayer(nn.Module):
    def __init__(self, c_in, c_out, act_fn):
        super().__init__()
        self.transition = nn.Sequential(
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_out, kernel_size=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2),  # Average the output for each 2x2 pixel group
        )

    def forward(self, x):
        return self.transition(x)
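The two ingredients of the transition layer can be illustrated in NumPy. First, 2x2 average pooling with stride 2, implemented here (as an illustrative choice, not how nn.AvgPool2d is implemented internally) by reshaping each non-overlapping 2x2 patch into its own axes. Second, the weight count of the 1x1 convolution compared with the 3x3 stride-2 convolution mentioned above, for a hypothetical 128-channel input whose channels are halved.

```python
import numpy as np


def avg_pool_2x2(x):
    """Average every non-overlapping 2x2 patch of a (C, H, W) array with even H and W."""
    c, h, w = x.shape
    # Index [c, i, a, j, b] of the reshaped array is x[c, 2 * i + a, 2 * j + b]
    return x.reshape(c, h // 2, 2, w // 2, 2).mean(axis=(2, 4))


x = np.arange(16, dtype=float).reshape(1, 4, 4)
pooled = avg_pool_2x2(x)
print(pooled)  # averages of the four 2x2 patches: 2.5, 4.5, 10.5, 12.5

# Weight counts for halving the channels of a 128-channel feature map (no bias)
c_in, c_out = 128, 64
print("1x1 conv + avg pool:", c_in * c_out)  # 8192
print("3x3 conv, stride 2: ", 9 * c_in * c_out)  # 73728
```

The pooling step has no parameters, so the transition layer's cost is dominated by the 1x1 convolution, nine times cheaper than a 3x3 alternative.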

Now we can put everything together and create our DenseNet. To specify the number of layers, we use a similar notation as in ResNets and pass on a list of ints representing the number of layers per block. After each dense block except the last one, we apply a transition layer to reduce the dimensionality by 2.

[33]:
class DenseNet(nn.Module):
    def __init__(
        self, num_classes=10, num_layers=[6, 6, 6, 6], bn_size=2, growth_rate=16, act_fn_name="relu", **kwargs
    ):
        super().__init__()
        self.hparams = SimpleNamespace(
            num_classes=num_classes,
            num_layers=num_layers,
            bn_size=bn_size,
            growth_rate=growth_rate,
            act_fn_name=act_fn_name,
            act_fn=act_fn_by_name[act_fn_name],
        )
        self._create_network()
        self._init_params()

    def _create_network(self):
        c_hidden = self.hparams.growth_rate * self.hparams.bn_size  # The start number of hidden channels

        # A first convolution on the original image to scale up the channel size
        self.input_net = nn.Sequential(
            # No batch norm or activation function as done inside the Dense layers
            nn.Conv2d(3, c_hidden, kernel_size=3, padding=1)
        )

        # Creating the dense blocks, eventually including transition layers
        blocks = []
        for block_idx, num_layers in enumerate(self.hparams.num_layers):
            blocks.append(
                DenseBlock(
                    c_in=c_hidden,
                    num_layers=num_layers,
                    bn_size=self.hparams.bn_size,
                    growth_rate=self.hparams.growth_rate,
                    act_fn=self.hparams.act_fn,
                )
            )
            c_hidden = c_hidden + num_layers * self.hparams.growth_rate  # Overall output of the dense block
            if block_idx < len(self.hparams.num_layers) - 1:  # Don't apply transition layer on last block
                blocks.append(TransitionLayer(c_in=c_hidden, c_out=c_hidden // 2, act_fn=self.hparams.act_fn))
                c_hidden = c_hidden // 2

        self.blocks = nn.Sequential(*blocks)

        # Mapping to classification output
        self.output_net = nn.Sequential(
            nn.BatchNorm2d(c_hidden),  # The features have not passed a non-linearity until here.
            self.hparams.act_fn(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(c_hidden, self.hparams.num_classes),
        )

    def _init_params(self):
        # Based on our discussion in Tutorial 4, we should initialize the
        # convolutions according to the activation function
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity=self.hparams.act_fn_name)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.input_net(x)
        x = self.blocks(x)
        x = self.output_net(x)
        return x
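The channel and resolution bookkeeping in _create_network can be traced by hand, together with the number of trainable parameters. The function below is a hypothetical helper that assumes 3x32x32 CIFAR10 inputs and counts BatchNorm layers as weight plus bias only (their running statistics are buffers, not parameters). With the default hyperparameters, it arrives at the 239,146 parameters listed for DenseNet in the comparison table at the end of this notebook.

```python
def densenet_summary(num_layers=(6, 6, 6, 6), bn_size=2, growth_rate=16, num_classes=10):
    """Hypothetical helper mirroring DenseNet._create_network for 3x32x32 inputs.

    Returns (per-block (block, channels, resolution) trace, final channels, parameter count).
    """
    c = growth_rate * bn_size  # channels after the input convolution
    res = 32
    params = 3 * c * 9 + c  # input 3x3 convolution, with bias
    trace = []
    for block_idx, n in enumerate(num_layers):
        for layer_idx in range(n):
            c_layer = c + layer_idx * growth_rate  # DenseLayer input channels
            c_mid = bn_size * growth_rate  # output of the 1x1 bottleneck convolution
            params += 2 * c_layer + c_layer * c_mid  # BatchNorm2d + 1x1 conv (no bias)
            params += 2 * c_mid + c_mid * growth_rate * 9  # BatchNorm2d + 3x3 conv (no bias)
        c += n * growth_rate  # dense block output channels
        trace.append((block_idx + 1, c, res))
        if block_idx < len(num_layers) - 1:
            params += 2 * c + c * (c // 2)  # TransitionLayer: BatchNorm2d + 1x1 conv
            c //= 2
            res //= 2  # AvgPool2d(kernel_size=2, stride=2)
    params += 2 * c + c * num_classes + num_classes  # output BatchNorm2d + nn.Linear
    return trace, c, params


trace, c_final, n_params = densenet_summary()
for block, channels, res in trace:
    print(f"dense block {block}: {channels} channels at {res}x{res}")
print(f"final channels: {c_final}, parameters: {n_params:,}")  # final channels: 184, parameters: 239,146
```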

Let’s also add the DenseNet to our model dictionary:

[34]:
model_dict["DenseNet"] = DenseNet

Lastly, we train our network. In contrast to ResNet, DenseNet does not show any issues with Adam, and hence we train it with this optimizer. The other hyperparameters are chosen to result in a network with a similar parameter size as the ResNet and GoogleNet. Commonly, when designing very deep networks, DenseNet is more parameter efficient than ResNet while achieving a similar or even better performance.

[35]:
densenet_model, densenet_results = train_model(
    model_name="DenseNet",
    model_hparams={
        "num_classes": 10,
        "num_layers": [6, 6, 6, 6],
        "bn_size": 2,
        "growth_rate": 16,
        "act_fn_name": "relu",
    },
    optimizer_name="Adam",
    optimizer_hparams={"lr": 1e-3, "weight_decay": 1e-4},
)
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:96: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=1)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/ConvNets/DenseNet/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model at saved_models/ConvNets/DenseNet.ckpt, loading...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Tensorboard log

Finally, we also have another TensorBoard for the DenseNet training. We take a look at it below:

[36]:
# Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH!
%tensorboard --logdir ../saved_models/tutorial5/tensorboards/DenseNet/


The overall course of the validation accuracy and training loss resembles that of GoogleNet, which may also be related to both networks being trained with Adam. Feel free to explore the training metrics yourself.

Conclusion and Comparison

After discussing each model separately, and training all of them, we can finally compare them. First, let’s organize the results of all models in a table:

[37]:
%%html
<!-- Some HTML code to increase font size in the following table -->
<style>
th {font-size: 120%;}
td {font-size: 120%;}
</style>
[38]:
all_models = [
    ("GoogleNet", googlenet_results, googlenet_model),
    ("ResNet", resnet_results, resnet_model),
    ("ResNetPreAct", resnetpreact_results, resnetpreact_model),
    ("DenseNet", densenet_results, densenet_model),
]
table = [
    [
        model_name,
        f"{100.0*model_results['val']:4.2f}%",
        f"{100.0*model_results['test']:4.2f}%",
        f"{sum(np.prod(p.shape) for p in model.parameters()):,}",
    ]
    for model_name, model_results, model in all_models
]
display(
    HTML(
        tabulate.tabulate(table, tablefmt="html", headers=["Model", "Val Accuracy", "Test Accuracy", "Num Parameters"])
    )
)
Model          Val Accuracy   Test Accuracy   Num Parameters
GoogleNet      90.40%         89.70%          260,650
ResNet         91.84%         91.06%          272,378
ResNetPreAct   91.80%         91.07%          272,250
DenseNet       90.72%         90.23%          239,146
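The ResNet parameter counts in the table can be reproduced from the layer definitions above. The helpers below are hypothetical and assume the default ResNet configuration ([3,3,3] blocks, [16,32,64] channels, 10 classes), counting BatchNorm layers as weight plus bias only. The 128-parameter gap between the two variants comes from where the BatchNorm layers and biases sit: the original variant has an extra BatchNorm after the input convolution and a bias in its downsampling convolutions.

```python
def conv_params(c_in, c_out, k, bias=False):
    # Weights of a k x k convolution, plus one bias per output channel if used
    return c_in * c_out * k * k + (c_out if bias else 0)


def bn_params(c):
    # BatchNorm2d has a weight and a bias per channel; running stats are buffers
    return 2 * c


def resnet_params(pre_act, num_blocks=(3, 3, 3), c_hidden=(16, 32, 64), num_classes=10):
    """Hypothetical helper: count trainable parameters of the ResNet defined above."""
    total = conv_params(3, c_hidden[0], 3)  # input convolution, no bias
    if not pre_act:
        total += bn_params(c_hidden[0])  # only the original variant has BatchNorm here
    for block_idx, count in enumerate(num_blocks):
        for bc in range(count):
            subsample = bc == 0 and block_idx > 0
            c_out = c_hidden[block_idx]
            c_in = c_hidden[block_idx - 1] if subsample else c_out
            total += conv_params(c_in, c_out, 3) + conv_params(c_out, c_out, 3)
            if pre_act:
                total += bn_params(c_in) + bn_params(c_out)  # BatchNorm before each conv
            else:
                total += 2 * bn_params(c_out)  # BatchNorm after each conv
            if subsample:
                if pre_act:
                    # BatchNorm + 1x1 convolution without bias
                    total += bn_params(c_in) + conv_params(c_in, c_out, 1)
                else:
                    # 1x1 convolution with bias (nn.Conv2d default)
                    total += conv_params(c_in, c_out, 1, bias=True)
    total += c_hidden[-1] * num_classes + num_classes  # final nn.Linear
    return total


print(f"ResNet:       {resnet_params(pre_act=False):,}")  # 272,378
print(f"ResNetPreAct: {resnet_params(pre_act=True):,}")  # 272,250
```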

First of all, we see that all models perform reasonably well. Simple models such as those you implemented in the practical achieve considerably lower performance, which can be attributed not only to their lower number of parameters but also to the architecture design choices. GoogleNet obtains the lowest performance on the validation and test set, although it is very close to DenseNet. A proper hyperparameter search over all the channel sizes in GoogleNet would likely improve the accuracy of the model to a similar level, but this is also expensive given the large number of hyperparameters. ResNet outperforms both DenseNet and GoogleNet by more than 1% on the validation set, while the difference between its two versions, original and pre-activation, is minor. We can conclude that for shallow networks, the placement of the activation function does not seem to be crucial, although papers have reported the contrary for very deep networks (e.g. He et al.).

In general, we can conclude that ResNet is a simple but powerful architecture. If we applied the models to more complex tasks with larger images and more layers inside the networks, we would likely see a bigger gap between GoogleNet and skip-connection architectures like ResNet and DenseNet. A comparison with deeper models on CIFAR10 can be found here, for example. Interestingly, DenseNet outperforms the original ResNet in their setup but comes in close behind the Pre-Activation ResNet. The best model, a Dual Path Network (Chen et al.), is actually a combination of ResNet and DenseNet, showing that both offer different advantages.

Which model should I choose for my task?

We have reviewed four different models. So, which one should we choose if we are given a new task? Usually, starting with a ResNet is a good idea, given its superior performance on the CIFAR dataset and its simple implementation. Besides, for the parameter counts we have chosen here, ResNet is the fastest, as DenseNet and GoogleNet have many more layers that are applied in sequence in our primitive implementation. However, if you have a really difficult task, such as semantic segmentation on HD images, more complex variants of ResNet and DenseNet are recommended.

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Great thanks from the entire PyTorch Lightning Team for your interest!

Pytorch Lightning

Tutorial 5: Transformers and Multi-Head Attention

  • Author: Phillip Lippe

  • License: CC BY-SA

  • Generated: 2022-04-09T16:34:55.714521

In this tutorial, we will discuss one of the most impactful architectures of the last 2 years: the Transformer model. Since the paper Attention Is All You Need by Vaswani et al. was published in 2017, the Transformer architecture has continued to beat benchmarks in many domains, most importantly in Natural Language Processing. Transformers with an enormous number of parameters can generate long, convincing essays, and have opened up new application fields of AI. As the hype around the Transformer architecture does not seem likely to end in the coming years, it is important to understand how it works and to have implemented it yourself, which we will do in this notebook. This notebook is part of a lecture series on Deep Learning at the University of Amsterdam. The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.



Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "pytorch-lightning>=1.3" "torchvision" "seaborn" "torchmetrics>=0.3" "matplotlib" "torch>=1.6, <1.9" "ipython[notebook]"

Despite the huge success of Transformers in NLP, we will not include the NLP domain in our notebook here. There are many courses at the University of Amsterdam that focus on Natural Language Processing and take a closer look at the application of the Transformer architecture in NLP (NLP2, Advanced Topics in Computational Semantics). Furthermore, and most importantly, there is so much more to the Transformer architecture. NLP is the domain the Transformer architecture has been originally proposed for and had the greatest impact on, but it also accelerated research in other domains, recently even Computer Vision. Thus, we focus here on what makes the Transformer and self-attention so powerful in general. In a second notebook, we will look at Vision Transformers, i.e. Transformers for image classification (link to notebook).

Below, we import our standard libraries.

[2]:
# Standard libraries
import math
import os
import urllib.request
from functools import partial
from urllib.error import HTTPError

# Plotting
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# PyTorch Lightning
import pytorch_lightning as pl
import seaborn as sns

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

# Torchvision
import torchvision
from IPython.display import set_matplotlib_formats
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision import transforms
from torchvision.datasets import CIFAR100
from tqdm.notebook import tqdm

plt.set_cmap("cividis")
%matplotlib inline
set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.reset_orig()

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/Transformers/")

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
/usr/lib/python3.9/site-packages/apex/pyprof/__init__.py:5: FutureWarning: pyprof will be removed by the end of June, 2022
  warnings.warn("pyprof will be removed by the end of June, 2022", FutureWarning)
/tmp/ipykernel_1570/2689201066.py:34: DeprecationWarning: `set_matplotlib_formats` is deprecated since IPython 7.23, directly use `matplotlib_inline.backend_inline.set_matplotlib_formats()`
  set_matplotlib_formats("svg", "pdf")  # For export
Global seed set to 42
Device: cuda:0

Two pre-trained models are downloaded below. Make sure to have adjusted your CHECKPOINT_PATH before running this code if not already done.

[3]:
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial6/"
# Files to download
pretrained_files = ["ReverseTask.ckpt", "SetAnomalyTask.ckpt"]

# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if "/" in file_name:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print("Downloading %s..." % file_url)
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file manually,"
                " or contact the author with the full output including the following error:\n",
                e,
            )
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial6/ReverseTask.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial6/SetAnomalyTask.ckpt...

The Transformer architecture

In the first part of this notebook, we will implement the Transformer architecture by hand. As the architecture is so popular, there already exists a PyTorch module nn.Transformer (documentation) and a tutorial on how to use it for next token prediction. However, we will implement it here ourselves, to understand it down to the smallest details.

There are of course many more tutorials out there about attention and Transformers. Below, we list a few that are worth exploring if you are interested in the topic and want yet another perspective after this one:

What is Attention?

The attention mechanism describes a recent group of layers in neural networks that has attracted a lot of interest in the past few years, especially in sequence tasks. There are many different possible definitions of “attention” in the literature, but the one we will use here is the following: the attention mechanism describes a weighted average of (sequence) elements with the weights dynamically computed based on an input query and elements’ keys. So what exactly does this mean? The goal is to take an average over the features of multiple elements. However, instead of weighting each element equally, we want to weight them depending on their actual values. In other words, we want to dynamically decide which inputs we want to “attend” to more than others. In particular, an attention mechanism usually has four parts we need to specify:

  • Query: The query is a feature vector that describes what we are looking for in the sequence, i.e. what we might want to pay attention to.

  • Keys: For each input element, we have a key which is again a feature vector. This feature vector roughly describes what the element is “offering”, or when it might be important. The keys should be designed such that we can identify the elements we want to pay attention to based on the query.

  • Values: For each input element, we also have a value vector. This feature vector is the one we want to average over.

  • Score function: To rate which elements we want to pay attention to, we need to specify a score function f_{attn}. The score function takes the query and a key as input and outputs the score/attention weight of the query-key pair. It is usually implemented by simple similarity metrics like a dot product, or a small MLP.

The weights of the average are calculated by a softmax over all score function outputs. Hence, value vectors whose corresponding key is most similar to the query receive a higher weight. If we try to describe it with pseudo-math, we can write:

\alpha_i = \frac{\exp\left(f_{attn}\left(\text{key}_i, \text{query}\right)\right)}{\sum_j \exp\left(f_{attn}\left(\text{key}_j, \text{query}\right)\right)}, \hspace{5mm} \text{out} = \sum_i \alpha_i \cdot \text{value}_i
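The pseudo-math above translates into a few lines of PyTorch. The following is a minimal sketch for a single query, assuming either a dot product or a small MLP as the score function f_{attn}; the names attend, dot_score and MLPScore, as well as the MLP layout, are hypothetical choices for illustration.

```python
import torch
import torch.nn as nn


def dot_score(keys, query):
    # f_attn(key_i, query) = key_i . query, computed for all i at once -> [N]
    return keys @ query


class MLPScore(nn.Module):
    # Hypothetical small MLP score function: f_attn(key_i, query) = w^T tanh(W [key_i; query])
    def __init__(self, dim, hidden=16):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2 * dim, hidden), nn.Tanh(), nn.Linear(hidden, 1))

    def forward(self, keys, query):
        # Pair the query with every key: [N, 2 * dim] -> one score per key, [N]
        pairs = torch.cat([keys, query.expand_as(keys)], dim=-1)
        return self.net(pairs).squeeze(-1)


def attend(query, keys, values, score_fn):
    # alpha_i = softmax_i(f_attn(key_i, query)); out = sum_i alpha_i * value_i
    alpha = torch.softmax(score_fn(keys, query), dim=0)
    return alpha @ values, alpha


torch.manual_seed(0)
keys, values = torch.randn(5, 4), torch.randn(5, 3)  # 5 elements, key dim 4, value dim 3
query = torch.randn(4)
out_dot, alpha_dot = attend(query, keys, values, dot_score)
out_mlp, alpha_mlp = attend(query, keys, values, MLPScore(dim=4))
print(alpha_dot.sum().item(), out_dot.shape)  # weights sum to ~1; output has the value dimension
```

Only the score function changes between the two variants; the softmax and the weighted average stay the same.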

Visually, we can show the attention over a sequence of words as follows:

[Figure: attention over a sequence of words]

For every word, we have one key and one value vector. The query is compared to all keys with a score function (in this case the dot product) to determine the weights. The softmax is not visualized for simplicity. Finally, the value vectors of all words are averaged using the attention weights.

Most attention mechanisms differ in terms of what queries they use, how the key and value vectors are defined, and what score function is used. The attention applied inside the Transformer architecture is called self-attention. In self-attention, each sequence element provides a key, value, and query. For each element, we perform an attention layer where, based on its query, we check the similarity to all sequence elements’ keys and return a different, averaged value vector for each element. We will now go into more detail by looking at the specific attention mechanism used in the Transformer: the scaled dot product attention.

Scaled Dot Product Attention

The core concept behind self-attention is the scaled dot product attention. Our goal is to have an attention mechanism with which any element in a sequence can attend to any other while still being efficient to compute. The dot product attention takes as input a set of queries Q\in\mathbb{R}^{T\times d_k}, keys K\in\mathbb{R}^{T\times d_k} and values V\in\mathbb{R}^{T\times d_v}, where T is the sequence length, and d_k and d_v are the hidden dimensionalities for queries/keys and values respectively. For simplicity, we neglect the batch dimension for now. The attention value from element i to j is based on the similarity of the query Q_i and key K_j, using the dot product as the similarity metric. In math, we calculate the dot product attention as follows:

\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

The matrix multiplication QK^T performs the dot product for every possible pair of queries and keys, resulting in a matrix of shape T\times T. Each row contains the attention logits of a specific element i with respect to all other elements in the sequence. We apply a softmax to each row and multiply the result with the value matrix V to obtain a weighted mean for each element (the weights being determined by the attention). The computation graph visualized below offers another perspective on this attention mechanism (figure credit - Vaswani et al., 2017).

[Figure: computation graph of the scaled dot product attention (Vaswani et al., 2017)]

One aspect we haven’t discussed yet is the scaling factor of 1/\sqrt{d_k}. This scaling factor is crucial to maintain an appropriate variance of attention values after initialization. Remember that we initialize our layers with the intention of having equal variance throughout the model, and hence, Q and K might also have a variance close to 1. However, a dot product over two vectors whose entries are independent with zero mean and variance \sigma^2 results in a scalar whose variance is d_k times the variance \sigma^4 of a single product q_i\cdot k_i:

q_i \sim \mathcal{N}(0,\sigma^2), k_i \sim \mathcal{N}(0,\sigma^2) \to \text{Var}\left(\sum_{i=1}^{d_k} q_i\cdot k_i\right) = \sigma^4\cdot d_k

If we do not scale the variance back down to \sigma^4 (roughly 1 when \sigma^2\approx 1) by dividing by \sqrt{d_k}, the softmax over the logits will already saturate to 1 for one random element and 0 for all others. The gradients through the softmax will then be close to zero, so that we can’t learn the parameters appropriately.
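The variance argument can be checked empirically. The sketch below assumes unit-variance queries and keys (\sigma^2 = 1) and an arbitrary d_k = 64; to illustrate the softmax saturation, it treats groups of 10 independent dot products as the logits of one attention row, which is a simplification of real attention where a row shares one query.

```python
import math

import torch

torch.manual_seed(0)
d_k, n_samples = 64, 100_000

# Queries and keys with unit-variance entries (sigma^2 = 1)
q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)

logits = (q * k).sum(dim=-1)  # one dot product per sample
scaled = logits / math.sqrt(d_k)
print(f"Var of raw dot products:    {logits.var().item():.2f}")  # close to d_k
print(f"Var of scaled dot products: {scaled.var().item():.2f}")  # close to 1

# Group the logits into "attention rows" of 10 elements and look at the largest softmax weight
rows_raw = logits.view(-1, 10)
rows_scaled = scaled.view(-1, 10)
max_raw = torch.softmax(rows_raw, dim=-1).max(dim=-1).values.mean()
max_scaled = torch.softmax(rows_scaled, dim=-1).max(dim=-1).values.mean()
print(f"Mean max weight without scaling: {max_raw:.2f}, with scaling: {max_scaled:.2f}")
```

Without the scaling, most of each row’s weight collapses onto a single element, which is the saturation described above.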

The block Mask (opt.) in the diagram above represents the optional masking of specific entries in the attention matrix. This is used, for instance, if we stack multiple sequences with different lengths into a batch. To still benefit from parallelization in PyTorch, we pad the sentences to the same length and mask out the padding tokens during the calculation of the attention values. This is usually done by setting the respective attention logits to a very low value.
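As a concrete illustration of padding masks, here is a minimal sketch with a hypothetical toy batch of two sequences of lengths 3 and 5, padded to length 5. The mask convention (1 for real tokens, 0 for padding) and the fill value of -9e15 mirror the function defined in the next cell, but the mask construction from sequence lengths is our own assumption.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
lengths = torch.tensor([3, 5])  # true lengths of the two sequences (toy values)
T, d_k = 5, 4                   # padded length and key dimension
q = torch.randn(2, T, d_k)
k = torch.randn(2, T, d_k)
v = torch.randn(2, T, d_k)

# key_mask[b, j] = 1 if position j is a real token of sequence b, 0 if it is padding
key_mask = (torch.arange(T)[None, :] < lengths[:, None]).long()  # [Batch, T]
mask = key_mask[:, None, :]  # [Batch, 1, T], broadcast over the query dimension

logits = q @ k.transpose(-2, -1) / d_k**0.5           # [Batch, T, T]
logits = logits.masked_fill(mask == 0, -9e15)         # very low logit for padded keys
attention = F.softmax(logits, dim=-1)
print(attention[0])  # columns 3 and 4 (padding of the first sequence) receive zero weight
```

Rows belonging to padded query positions still produce outputs; these are usually ignored downstream, e.g. by excluding them from the loss.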

Having discussed the details of the scaled dot product attention block, we can now write a function that computes the output features given the triple of queries, keys, and values:

[4]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

Note that the code above supports any number of additional dimensions in front of the sequence length, so we can also use it for batches. However, for a better understanding, let’s generate a few random queries, keys, and value vectors, and calculate the attention outputs:

[5]:
seq_len, d_k = 3, 2
pl.seed_everything(42)
q = torch.randn(seq_len, d_k)
k = torch.randn(seq_len, d_k)
v = torch.randn(seq_len, d_k)
values, attention = scaled_dot_product(q, k, v)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("Values\n", values)
print("Attention\n", attention)
Global seed set to 42
Q
 tensor([[ 0.3367,  0.1288],
        [ 0.2345,  0.2303],
        [-1.1229, -0.1863]])
K
 tensor([[ 2.2082, -0.6380],
        [ 0.4617,  0.2674],
        [ 0.5349,  0.8094]])
V
 tensor([[ 1.1103, -1.6898],
        [-0.9890,  0.9580],
        [ 1.3221,  0.8172]])
Values
 tensor([[ 0.5698, -0.1520],
        [ 0.5379, -0.0265],
        [ 0.2246,  0.5556]])
Attention
 tensor([[0.4028, 0.2886, 0.3086],
        [0.3538, 0.3069, 0.3393],
        [0.1303, 0.4630, 0.4067]])

Before continuing, make sure you can follow the calculation of the specific values here, and also check it by hand. It is important to fully understand how the scaled dot product attention is calculated.
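To support the check by hand, the sketch below recomputes the first row of the output step by step in plain Python, using the rounded values printed above. Because the inputs are rounded to four decimals, the result can differ from the printed tensors in the last digit.

```python
import math

# Values copied from the printed output above (rounded to 4 decimals)
q1 = [0.3367, 0.1288]  # first query
K = [[2.2082, -0.6380], [0.4617, 0.2674], [0.5349, 0.8094]]
V = [[1.1103, -1.6898], [-0.9890, 0.9580], [1.3221, 0.8172]]
d_k = 2

# Step 1: scaled dot products of the first query with every key
logits = [sum(qi * ki for qi, ki in zip(q1, key)) / math.sqrt(d_k) for key in K]
# Step 2: softmax over the three logits
exps = [math.exp(l) for l in logits]
weights = [e / sum(exps) for e in exps]
# Step 3: weighted average of the value vectors
out = [sum(w * value[j] for w, value in zip(weights, V)) for j in range(2)]
print([round(w, 4) for w in weights])  # compare with the first row of "Attention"
print([round(o, 4) for o in out])      # compare with the first row of "Values"
```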

Multi-Head Attention

The scaled dot product attention allows a network to attend over a sequence. However, there are often multiple different aspects a sequence element wants to attend to, and a single weighted average cannot capture all of them. This is why we extend the attention mechanism to multiple heads, i.e. multiple different query-key-value triplets on the same features. Specifically, given a query, key, and value matrix, we transform those into h sub-queries, sub-keys, and sub-values, which we pass through the scaled dot product attention independently. Afterward, we concatenate the heads and combine them with a final weight matrix. Mathematically, we can express this operation as:

\begin{split}
    \text{Multihead}(Q,K,V) & = \text{Concat}(\text{head}_1,...,\text{head}_h)W^{O}\\
    \text{where } \text{head}_i & = \text{Attention}(QW_i^Q,KW_i^K, VW_i^V)
\end{split}

We refer to this as a Multi-Head Attention layer with the learnable parameters W_{1...h}^{Q}\in\mathbb{R}^{D\times d_k}, W_{1...h}^{K}\in\mathbb{R}^{D\times d_k}, W_{1...h}^{V}\in\mathbb{R}^{D\times d_v}, and W^{O}\in\mathbb{R}^{h\cdot d_v\times d_{out}} (D being the input dimensionality). Expressed as a computational graph, we can visualize it as below (figure credit - Vaswani et al., 2017).

[Figure: computation graph of Multi-Head Attention (Vaswani et al., 2017)]
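The per-head matrices W_i^Q (and likewise W_i^K and W_i^V) do not have to be separate layers: stacking them side by side gives a single D\times (h\cdot d_k) matrix, and reshaping its output recovers each head. The sketch below checks this with arbitrary toy sizes; it illustrates the idea behind using one stacked projection layer, not the exact memory layout of the module implemented next.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D, h, d_k = 16, 4, 8  # input dim, number of heads, per-head dim (toy sizes)
x = torch.randn(3, D)  # 3 sequence elements

# One linear layer holding W_1^Q ... W_h^Q side by side: D -> h * d_k
stacked = nn.Linear(D, h * d_k, bias=False)
q_all = stacked(x).view(3, h, d_k)  # [SeqLen, Head, d_k]

# The same projections, one head at a time: head i uses rows i*d_k:(i+1)*d_k of the weight
for i in range(h):
    W_i = stacked.weight[i * d_k:(i + 1) * d_k]  # [d_k, D]; nn.Linear stores the transposed matrix
    q_i = x @ W_i.T                              # x W_i^Q
    assert torch.allclose(q_all[:, i], q_i, atol=1e-5)
print("stacked projection matches per-head projections")
```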

How do we apply a Multi-Head Attention layer in a neural network, where we don’t have an arbitrary query, key, and value vector as input? Looking at the computation graph above, a simple but effective implementation is to set the current feature map in a NN, X\in\mathbb{R}^{B\times T\times d_{\text{model}}}, as Q, K and V (B being the batch size, T the sequence length, d_{\text{model}} the hidden dimensionality of X). The subsequent weight matrices W^{Q}, W^{K}, and W^{V} transform X into the corresponding feature vectors that represent the queries, keys, and values of the input. Using this approach, we can implement the Multi-Head Attention module below.

[6]:
class MultiheadAttention(nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Stack all weight matrices 1...h together for efficiency
        # Note that in many implementations you see "bias=False" which is optional
        self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, _ = x.size()
        qkv = self.qkv_proj(x)

        # Separate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        # Determine value outputs
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        values = values.permute(0, 2, 1, 3)  # [Batch, SeqLen, Head, Dims]
        # Merge the heads; use self.embed_dim, since input_dim may differ from embed_dim
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(values)

        if return_attention:
            return o, attention
        else:
            return o
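The forward pass above hands 4D tensors of shape [Batch, Head, SeqLen, Dims] to scaled_dot_product, relying on torch.matmul treating all leading dimensions as batch dimensions. The sketch below repeats that function so it runs on its own, and checks with arbitrary sizes that one batched call matches running every (batch, head) pair separately.

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product(q, k, v, mask=None):
    # Same computation as in the earlier cell; leading dimensions act as batch dimensions
    d_k = q.size(-1)
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    return torch.matmul(attention, v), attention


torch.manual_seed(0)
B, H, T, d = 2, 4, 5, 8
q, k, v = (torch.randn(B, H, T, d) for _ in range(3))

# One call over [Batch, Head, SeqLen, Dims] ...
values, attention = scaled_dot_product(q, k, v)

# ... matches calling the function on each [SeqLen, Dims] slice separately
for b in range(B):
    for h in range(H):
        v_bh, a_bh = scaled_dot_product(q[b, h], k[b, h], v[b, h])
        assert torch.allclose(values[b, h], v_bh, atol=1e-5)
        assert torch.allclose(attention[b, h], a_bh, atol=1e-5)
print(values.shape, attention.shape)  # torch.Size([2, 4, 5, 8]) torch.Size([2, 4, 5, 5])
```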

One crucial characteristic of multi-head attention is that it is permutation-equivariant with respect to its inputs. This means that if we switch two input elements in the sequence, e.g. X_1\leftrightarrow X_2 (neglecting the batch dimension for now), the output is exactly the same except that elements 1 and 2 are switched. Hence, multi-head attention actually looks at the input not as a sequence, but as a set of elements. This property makes the multi-head attention block and the Transformer architecture so powerful and widely applicable! But what if the order of the input is actually important for solving the task, like in language modeling? The answer is to encode the position in the input features, which we will take a closer look at later (topic Positional encodings below).
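Permutation equivariance is easy to verify numerically. This sketch uses PyTorch’s built-in nn.MultiheadAttention as a stand-in for the module implemented above (an assumption for self-containedness; the same property holds for any self-attention layer without positional information), swaps the first two sequence elements, and compares the outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Built-in multi-head attention, standing in for the hand-written module
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
mha.eval()  # disable dropout (its default is 0.0 anyway)

x = torch.randn(1, 5, 16)             # [Batch, SeqLen, Dims]
perm = torch.tensor([1, 0, 2, 3, 4])  # swap the first two elements

out, _ = mha(x, x, x)                 # self-attention: Q = K = V = x
xp = x[:, perm]
out_perm, _ = mha(xp, xp, xp)

# Permuting the input only permutes the output in the same way
print(torch.allclose(out_perm, out[:, perm], atol=1e-5))
```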

Before moving on to creating the Transformer architecture, we can compare the self-attention operation with other common layers for sequence data: convolutions and recurrent neural networks. Below you can find a table by Vaswani et al. (2017) on the complexity per layer, the number of sequential operations, and the maximum path length. The complexity is measured by the upper bound of the number of operations to perform, while the maximum path length represents the maximum number of steps a forward or backward signal has to traverse to reach any other position. The shorter this path, the better gradient signals can backpropagate for long-range dependencies. Let’s take a look at the table below:

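[Table: complexity per layer, sequential operations, and maximum path length for self-attention, recurrent, and convolutional layers (Vaswani et al., 2017)]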

n is the sequence length, d is the representation dimension and k is the kernel size of convolutions. In contrast to recurrent networks, the self-attention layer can parallelize all its operations, making it much faster to execute for smaller sequence lengths. However, when the sequence length exceeds the hidden dimensionality, self-attention becomes more expensive than RNNs. One way of reducing the computational cost for long sequences is to restrict the self-attention to a neighborhood of inputs to attend over, denoted by r. Nevertheless, there has recently been a lot of work on more efficient Transformer architectures that still allow long dependencies; you can find an overview in the paper by Tay et al. (2020) if interested.
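The crossover point follows from the per-layer complexities in that table: O(n^2\cdot d) for self-attention versus O(n\cdot d^2) for recurrent layers, so their ratio is n/d. The sketch below plugs in example values (d = 512, a window of r = 64, and a kernel size of 3 are arbitrary choices) and drops all constant factors, so it only compares orders of magnitude.

```python
def ops_per_layer(n, d, k=3, r=None):
    """Upper-bound operation counts per layer from the complexity table (constants dropped)."""
    counts = {
        "self-attention": n * n * d,
        "recurrent": n * d * d,
        "convolutional": k * n * d * d,
    }
    if r is not None:
        counts["self-attention (restricted)"] = r * n * d
    return counts


d = 512  # hidden dimensionality (example value)
for n in (128, 1024):
    c = ops_per_layer(n, d, r=64)
    cheaper = "self-attention" if c["self-attention"] < c["recurrent"] else "recurrent"
    print(f"n={n:5d}: attention {c['self-attention']:.2e}, recurrent {c['recurrent']:.2e} -> {cheaper} is cheaper")
```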

Transformer Encoder

Next, we will look at how to apply the multi-head attention block inside the Transformer architecture. Originally, the Transformer model was designed for machine translation. Hence, it has an encoder-decoder structure, where the encoder takes as input the sentence in the original language and generates an attention-based representation. The decoder, on the other hand, attends over the encoded information and generates the translated sentence in an autoregressive manner, as in a standard RNN. While this structure is extremely useful for Sequence-to-Sequence tasks that require autoregressive decoding, we will focus here on the encoder part. Many advances have been made using pure encoder-based Transformer models (if interested, models include the BERT family, the Vision Transformer, and more). If you have understood the encoder architecture, the decoder is a very small step to implement as well. The full Transformer architecture looks as follows (figure credit - Vaswani et al., 2017):

[Figure: full Transformer encoder-decoder architecture (Vaswani et al., 2017)]

The encoder consists of N identical blocks that are applied in sequence. The input x is first passed through a Multi-Head Attention block as implemented above. The output is added to the original input using a residual connection, and Layer Normalization is then applied to the sum. Overall, it calculates \text{LayerNorm}(x+\text{Multihead}(x,x,x)) (x being the Q, K and V input to the attention layer). The residual connection is crucial in the Transformer architecture for two reasons:

  1. Similar to ResNets, Transformers are designed to be very deep. Some models contain more than 24 blocks in the encoder. Hence, the residual connections are crucial for enabling a smooth gradient flow through the model.

  2. Without the residual connection, the information about the original sequence is lost. Remember that the Multi-Head Attention layer ignores the position of elements in a sequence and can only learn it based on the input features. Removing the residual connections would mean that this information is lost after the first attention layer (after initialization), and with randomly initialized query and key vectors, the output vector for position i has no relation to its original input. All outputs of the attention are likely to represent similar/same information, and there is no chance for the model to distinguish which information came from which input element. An alternative to the residual connection would be to fix at least one head to focus on its original input, but this is very inefficient and does not have the benefit of the improved gradient flow.

The Layer Normalization also plays an important role in the Transformer architecture as it enables faster training and provides slight regularization. Additionally, it ensures that the features are of similar magnitude across the elements in the sequence. We are not using Batch Normalization because it depends on the batch size, which is often small with Transformers (they require a lot of GPU memory), and BatchNorm has been shown to perform particularly poorly in language, as the features of words tend to have a much higher variance (there are many very rare words which need to be considered for a good distribution estimate).
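The difference between the two normalizations can be made concrete: LayerNorm computes its statistics per sequence element over the feature dimension, so a sample’s output does not depend on the rest of the batch, while BatchNorm in training mode normalizes with batch statistics. A small sketch with arbitrary toy sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 6, 32)  # [Batch, SeqLen, Dims], toy sizes

ln = nn.LayerNorm(32)
bn = nn.BatchNorm1d(32)    # expects [Batch, Dims, SeqLen]
ln.train()
bn.train()

# LayerNorm: normalizing the first sample alone or inside the batch gives the same result
ln_full = ln(x)
ln_single = ln(x[:1])
print(torch.allclose(ln_full[:1], ln_single, atol=1e-6))  # True

# BatchNorm (training mode): statistics are taken over the batch, so sample 0 changes
bn_full = bn(x.transpose(1, 2)).transpose(1, 2)
bn_single = bn(x[:1].transpose(1, 2)).transpose(1, 2)
print(torch.allclose(bn_full[:1], bn_single, atol=1e-6))  # False in general
```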

In addition to the Multi-Head Attention, a small fully connected feed-forward network is added to the model, which is applied to each position separately and identically. Specifically, the model uses a Linear\toReLU\toLinear MLP. The full transformation including the residual connection can be expressed as:

\begin{split}
    \text{FFN}(x) & = \max(0, xW_1+b_1)W_2 + b_2\\
    x & = \text{LayerNorm}(x + \text{FFN}(x))
\end{split}

This MLP adds extra complexity to the model and allows transformations on each sequence element separately. You can think of it as allowing the model to “post-process” the new information added by the previous Multi-Head Attention and prepare it for the next attention block. Usually, the inner dimensionality of the MLP is 2-8\times larger than d_{\text{model}}, i.e. the dimensionality of the original input x. The general advantage of a wider layer over a narrow, multi-layer MLP is faster, parallelizable execution.
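“Applied to each position separately and identically” falls out of how nn.Linear works: it acts on the last dimension only, so one call on a [Batch, SeqLen, d_model] tensor is the same as looping over positions. The sketch below checks the FFN plus residual and LayerNorm from the equation above, with toy sizes and a 4\times inner dimension chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 16, 64  # inner dimension 4x larger, within the usual 2-8x range
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
norm = nn.LayerNorm(d_model)

x = torch.randn(2, 7, d_model)  # [Batch, SeqLen, d_model]

# x = LayerNorm(x + FFN(x)), computed for all positions at once ...
y = norm(x + ffn(x))

# ... equals applying the same FFN to each position on its own
y_loop = torch.stack([norm(x[:, t] + ffn(x[:, t])) for t in range(x.size(1))], dim=1)
print(torch.allclose(y, y_loop, atol=1e-5))
```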

Finally, after looking at all parts of the encoder architecture, we can start implementing it below. We start by implementing a single encoder block. In addition to the layers described above, we add dropout layers in the MLP and on the outputs of the MLP and Multi-Head Attention for regularization.

[7]:
class EncoderBlock(nn.Module):
    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):
        """
        Args:
            input_dim: Dimensionality of the input
            num_heads: Number of heads to use in the attention block
            dim_feedforward: Dimensionality of the hidden layer in the MLP
            dropout: Dropout probability to use in the dropout layers
        """
        super().__init__()

        # Attention layer
        self.self_attn = MultiheadAttention(input_dim, input_dim, num_heads)

        # Two-layer MLP
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, input_dim),
        )

        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Attention part
        attn_out = self.self_attn(x, mask=mask)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)

        # MLP part
        linear_out = self.linear_net(x)
        x = x + self.dropout(linear_out)
        x = self.norm2(x)

        return x

Based on this block, we can implement a module for the full Transformer encoder. In addition to a forward function that iterates through the sequence of encoder blocks, we also provide a function called get_attention_maps. The idea of this function is to return the attention probabilities for all Multi-Head Attention blocks in the encoder. This helps us in understanding, and in a sense, explaining the model. However, the attention probabilities should be interpreted with a grain of salt, as they do not necessarily reflect the true interpretation of the model (there is a series of papers about this, including Attention is not Explanation and Attention is not not Explanation).

[8]:
class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask=mask)
        return x

    def get_attention_maps(self, x, mask=None):
        attention_maps = []
        for layer in self.layers:
            _, attn_map = layer.self_attn(x, mask=mask, return_attention=True)
            attention_maps.append(attn_map)
            x = layer(x, mask=mask)  # apply the same mask as in forward()
        return attention_maps
Positional encoding

We have discussed before that the Multi-Head Attention block is permutation-equivariant and cannot distinguish whether an input comes before another one in the sequence or not. In tasks like language understanding, however, the position is important for interpreting the input words. The position information can therefore be added via the input features. We could learn an embedding for every possible position, but this would not generalize to variable input sequence lengths. Hence, the better option is to use feature patterns that the network can identify from the features and potentially generalize to longer sequences. The specific pattern chosen by Vaswani et al. consists of sine and cosine functions of different frequencies, as follows:

PE_{(pos,i)} = \begin{cases}
    \sin\left(\frac{pos}{10000^{i/d_{\text{model}}}}\right) & \text{if}\hspace{3mm} i \text{ mod } 2=0\\
    \cos\left(\frac{pos}{10000^{(i-1)/d_{\text{model}}}}\right) & \text{otherwise}\\
\end{cases}

PE_{(pos,i)} represents the position encoding at position pos in the sequence and hidden dimension i. These values, concatenated for all hidden dimensions, are added to the original input features (in the Transformer visualization above, see “Positional encoding”) and constitute the position information. We distinguish between even (i \text{ mod } 2=0) and odd (i \text{ mod } 2=1) hidden dimensions, where we apply a sine or cosine respectively. The intuition behind this encoding is that you can represent PE_{(pos+k,:)} as a linear function of PE_{(pos,:)}, which might allow the model to easily attend to relative positions. The wavelengths in different dimensions range from 2\pi to 10000\cdot 2\pi.
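The linear relation between PE_{(pos+k,:)} and PE_{(pos,:)} can be checked numerically: each sine/cosine pair shares one frequency w, and shifting the position by k rotates that pair by the angle w\cdot k, independently of pos. The sketch below transcribes the formula directly in NumPy (the function name and toy sizes are our own) and builds the block-diagonal rotation matrix for a fixed offset k.

```python
import numpy as np


def positional_encoding(max_len, d_model):
    # Direct transcription of the PE formula: sine on even i, cosine on odd i
    pe = np.zeros((max_len, d_model))
    for pos in range(max_len):
        for i in range(d_model):
            if i % 2 == 0:
                pe[pos, i] = np.sin(pos / 10000 ** (i / d_model))
            else:
                pe[pos, i] = np.cos(pos / 10000 ** ((i - 1) / d_model))
    return pe


d_model, k = 16, 5
pe = positional_encoding(50, d_model)

# For a (sin, cos) pair with frequency w, shifting by k is a fixed 2x2 rotation:
# [sin(w(p+k)), cos(w(p+k))] = [[cos wk, sin wk], [-sin wk, cos wk]] @ [sin(wp), cos(wp)]
M = np.zeros((d_model, d_model))
for j in range(0, d_model, 2):
    w = 1 / 10000 ** (j / d_model)
    c, s = np.cos(w * k), np.sin(w * k)
    M[j:j + 2, j:j + 2] = [[c, s], [-s, c]]

# One matrix M, which depends only on k, maps PE(pos) to PE(pos + k) for every position
pred = pe[:-k] @ M.T
print(np.allclose(pred, pe[k:]))
```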

The positional encoding is implemented below. The code is taken from the PyTorch tutorial about Transformers on NLP and adjusted for our purposes.

[9]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        """
        Args:
            d_model: Hidden dimensionality of the input.
            max_len: Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        # register_buffer => Tensor which is not a parameter, but should be part of the module's state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return x
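The div_term line in PositionalEncoding is the frequency 1/10000^{i/d_{\text{model}}} from the formula, written as \exp(-i\cdot\ln(10000)/d_{\text{model}}) for the even indices i. A quick sketch, assuming d_model = 48 as in the visualization that follows, confirms that the two forms agree.

```python
import math

import torch

d_model = 48
i = torch.arange(0, d_model, 2).float()  # even dimension indices 0, 2, 4, ...

# Form used in PositionalEncoding: exp(i * (-ln(10000) / d_model))
div_term = torch.exp(i * (-math.log(10000.0) / d_model))

# Form from the PE formula: 1 / 10000^(i / d_model)
direct = 1.0 / (10000.0 ** (i / d_model))

print(torch.allclose(div_term, direct, rtol=1e-5, atol=1e-6))
```

In the module, position * div_term then broadcasts [max_len, 1] against [d_model / 2] to give every (position, frequency) product at once.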

To understand the positional encoding, we can visualize it below. We will generate an image of the positional encoding over hidden dimensionality and position in a sequence. Each pixel, therefore, represents the change of the input feature we perform to encode the specific position. Let’s do it below.

[10]:
encod_block = PositionalEncoding(d_model=48, max_len=96)
pe = encod_block.pe.squeeze().T.cpu().numpy()

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 3))
pos = ax.imshow(pe, cmap="RdGy", extent=(1, pe.shape[1] + 1, pe.shape[0] + 1, 1))
fig.colorbar(pos, ax=ax)
ax.set_xlabel("Position in sequence")
ax.set_ylabel("Hidden dimension")
ax.set_title("Positional encoding over hidden dimensions")
ax.set_xticks([1] + [i * 10 for i in range(1, 1 + pe.shape[1] // 10)])
ax.set_yticks([1] + [i * 10 for i in range(1, 1 + pe.shape[0] // 10)])
plt.show()
[Plot: Positional encoding over hidden dimensions]

You can clearly see the sine and cosine waves with different wavelengths that encode the position in the hidden dimensions. Specifically, we can look at the sine/cosine wave for each hidden dimension separately, to get a better intuition of the pattern. Below we visualize the positional encoding for the hidden dimensions 1, 2, 3 and 4.

[11]:
sns.set_theme()
fig, ax = plt.subplots(2, 2, figsize=(12, 4))
ax = [a for a_list in ax for a in a_list]
for i in range(len(ax)):
    ax[i].plot(np.arange(1, 17), pe[i, :16], color="C%i" % i, marker="o", markersize=6, markeredgecolor="black")
    ax[i].set_title("Encoding in hidden dimension %i" % (i + 1))
    ax[i].set_xlabel("Position in sequence", fontsize=10)
    ax[i].set_ylabel("Positional encoding", fontsize=10)
    ax[i].set_xticks(np.arange(1, 17))
    ax[i].tick_params(axis="both", which="major", labelsize=10)
    ax[i].tick_params(axis="both", which="minor", labelsize=8)
    ax[i].set_ylim(-1.2, 1.2)
fig.subplots_adjust(hspace=0.8)
sns.reset_orig()
plt.show()
[Plot: Encoding in hidden dimensions 1 to 4]

As we can see, the patterns of hidden dimensions 1 and 2 only differ in their starting angle. Their wavelength is 2\pi, hence the pattern repeats roughly every 6 positions. Hidden dimensions 3 and 4 have a wavelength of 2\pi\cdot 10000^{2/48}\approx 9.2, about 1.5 times longer.
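The wavelengths can be read off the formula: hidden dimensions 2j+1 and 2j+2 (counting from 1, as in the plots) share the frequency 1/10000^{2j/d_{\text{model}}}, so their wavelength is 2\pi\cdot 10000^{2j/d_{\text{model}}}. A short sketch for the d_model = 48 used above:

```python
import math

d_model = 48  # as in the visualization above
# Hidden dims 2j+1 and 2j+2 (1-based) share frequency 1 / 10000^(2j / d_model),
# hence wavelength 2*pi * 10000^(2j / d_model)
for j in range(3):
    wavelength = 2 * math.pi * 10000 ** (2 * j / d_model)
    print(f"hidden dims {2 * j + 1} and {2 * j + 2}: wavelength {wavelength:.2f}")
```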

Learning rate warm-up

One commonly used technique for training a Transformer is learning rate warm-up. This means that we gradually increase the learning rate from 0 to our originally specified learning rate in the first few iterations. Thus, we start learning slowly instead of taking very large steps from the beginning. In fact, training a deep Transformer without learning rate warm-up can make the model diverge and achieve much worse performance on training and testing. Take for instance the following plot by Liu et al. (2019) comparing Adam-vanilla (i.e. Adam without warm-up) vs Adam with a warm-up:

[Figure: Adam-vanilla vs. Adam with warm-up (Liu et al., 2019)]

Clearly, the warm-up is a crucial hyperparameter in the Transformer architecture. Why is it so important? There are currently two common explanations. Firstly, the bias correction factors in Adam can lead to a higher variance in the adaptive learning rate during the first iterations. Improved optimizers like RAdam have been shown to overcome this issue, not requiring warm-up for training Transformers. Secondly, the iteratively applied Layer Normalization across layers can lead to very high gradients during the first iterations, which can be solved by using Pre-Layer Normalization (similar to Pre-Activation ResNet), or by replacing Layer Normalization with other techniques (Adaptive Normalization, Power Normalization).
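For comparison with the Post-LN block used in this tutorial (LayerNorm applied after each residual sum), a Pre-LN block applies LayerNorm before each sublayer and leaves the residual path unnormalized: x + \text{Attn}(\text{LN}(x)) and x + \text{FFN}(\text{LN}(x)). The sketch below is a hypothetical minimal version built on PyTorch’s nn.MultiheadAttention; recent PyTorch versions also expose this variant through the norm_first=True argument of nn.TransformerEncoderLayer.

```python
import torch
import torch.nn as nn


class PreLNEncoderBlock(nn.Module):
    """Hypothetical Pre-LN encoder block: LayerNorm before each sublayer,
    unnormalized residual path."""

    def __init__(self, dim, num_heads, dim_feedforward, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim_feedforward), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim_feedforward, dim)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.dropout(self.attn(h, h, h, need_weights=False)[0])  # x + Attn(LN(x))
        x = x + self.dropout(self.ffn(self.norm2(x)))                    # x + FFN(LN(x))
        return x


block = PreLNEncoderBlock(dim=32, num_heads=4, dim_feedforward=64)
out = block(torch.randn(2, 10, 32))
print(out.shape)  # torch.Size([2, 10, 32])
```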

Nevertheless, many applications and papers still use the original Transformer architecture with Adam, because warm-up is a simple yet effective way of solving the gradient problem in the first iterations. There are many different schedulers we could use. For instance, the original Transformer paper used a linear warm-up followed by an inverse square-root decay of the learning rate. However, currently the most popular scheduler is the cosine warm-up scheduler, which combines warm-up with a cosine-shaped learning rate decay. We implement it below and visualize the learning rate factor over iterations.
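For reference, the schedule from Vaswani et al. (2017) is lr = d_{\text{model}}^{-0.5}\cdot\min(\text{step}^{-0.5}, \text{step}\cdot\text{warmup}^{-1.5}), with d_model = 512 and warmup = 4000 for their base model. The sketch below (function name is our own) evaluates it and confirms that the learning rate peaks exactly at the end of the warm-up, where both terms of the min are equal.

```python
def transformer_lr(step, d_model=512, warmup=4000):
    # Vaswani et al. (2017): lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first step
    return d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)


lrs = [transformer_lr(s) for s in range(1, 20001)]
peak_step = 1 + max(range(len(lrs)), key=lrs.__getitem__)
print(peak_step, max(lrs))  # the peak is reached at step == warmup
```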

[12]:
class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup, max_iters):
        self.warmup = warmup
        self.max_num_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
        if epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor
[13]:
# Needed for initializing the lr scheduler
p = nn.Parameter(torch.empty(4, 4))
optimizer = optim.Adam([p], lr=1e-3)
lr_scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup=100, max_iters=2000)

# Plotting
epochs = list(range(2000))
sns.set()
plt.figure(figsize=(8, 3))
plt.plot(epochs, [lr_scheduler.get_lr_factor(e) for e in epochs])
plt.ylabel("Learning rate factor")
plt.xlabel("Iterations (in batches)")
plt.title("Cosine Warm-up Learning Rate Scheduler")
plt.show()
sns.reset_orig()
[Plot: Cosine Warm-up Learning Rate Scheduler]

In the first 100 iterations, we increase the learning rate factor from 0 to 1, whereas for all later iterations, we decay it using the cosine wave. Ready-made implementations of this scheduler can be found in the popular Hugging Face Transformers library.
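The same schedule can also be expressed with PyTorch’s built-in torch.optim.lr_scheduler.LambdaLR, which multiplies each base learning rate by a user-supplied factor function. In this sketch, the function cosine_warmup_factor (a hypothetical name) reproduces the factor of CosineWarmupScheduler.get_lr_factor, and the dummy parameter only serves to construct the optimizer.

```python
import math

import torch
from torch import nn, optim


def cosine_warmup_factor(step, warmup=100, max_iters=2000):
    # Same factor as CosineWarmupScheduler.get_lr_factor
    factor = 0.5 * (1 + math.cos(math.pi * step / max_iters))
    if step <= warmup:
        factor *= step / warmup
    return factor


p = nn.Parameter(torch.zeros(4, 4))  # dummy parameter
optimizer = optim.Adam([p], lr=1e-3)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_warmup_factor)

lrs = []
for step in range(2000):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()  # parameter update (p has no gradient here, so Adam skips it)
    scheduler.step()  # advance the schedule by one iteration
print(lrs[0], lrs[100], lrs[1999])
```

As in the tutorial, scheduler.step() is called once per iteration rather than once per epoch.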

PyTorch Lightning Module

Finally, we can embed the Transformer architecture into a PyTorch Lightning module. From Tutorial 5, you know that PyTorch Lightning simplifies our training and test code and structures the code nicely in separate functions. We will implement a template for a classifier based on the Transformer encoder that produces one prediction per sequence element. If we needed a classifier over the whole sequence, the common approach would be to add an additional [CLS] token to the sequence, representing the classifier token. However, here we focus on tasks where we have an output per element.
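To illustrate the [CLS] approach mentioned above, here is a hypothetical sketch (not part of this tutorial’s models): a learnable vector is prepended to every sequence, and only its output encoding is passed to the classifier. It uses PyTorch’s built-in nn.TransformerEncoder for brevity and omits the positional encoding; class name and sizes are illustrative.

```python
import torch
import torch.nn as nn


class CLSSequenceClassifier(nn.Module):
    """Hypothetical sequence-level classifier using a learnable [CLS] token."""

    def __init__(self, model_dim, num_heads, num_layers, num_classes):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        layer = nn.TransformerEncoderLayer(model_dim, num_heads, dim_feedforward=2 * model_dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.head = nn.Linear(model_dim, num_classes)

    def forward(self, x):  # x: [Batch, SeqLen, model_dim]
        cls = self.cls_token.expand(x.size(0), -1, -1)  # one [CLS] vector per sequence
        h = self.encoder(torch.cat([cls, x], dim=1))     # [Batch, 1 + SeqLen, model_dim]
        return self.head(h[:, 0])                         # predict from the [CLS] position


model = CLSSequenceClassifier(model_dim=32, num_heads=4, num_layers=2, num_classes=10)
logits = model(torch.randn(8, 16, 32))
print(logits.shape)  # torch.Size([8, 10])
```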

In addition to the Transformer architecture, we add a small input network (mapping input dimensions to model dimensions), the positional encoding, and an output network (transforming output encodings into predictions). We also add the learning rate scheduler, which takes a step each iteration instead of once per epoch; this is needed for the warm-up and the smooth cosine decay. The training, validation, and test steps are left empty for now and will be filled in by our task-specific models.

[14]:
class TransformerPredictor(pl.LightningModule):
    def __init__(
        self,
        input_dim,
        model_dim,
        num_classes,
        num_heads,
        num_layers,
        lr,
        warmup,
        max_iters,
        dropout=0.0,
        input_dropout=0.0,
    ):
        """
        Args:
            input_dim: Hidden dimensionality of the input
            model_dim: Hidden dimensionality to use inside the Transformer
            num_classes: Number of classes to predict per sequence element
            num_heads: Number of heads to use in the Multi-Head Attention blocks
            num_layers: Number of encoder blocks to use.
            lr: Learning rate in the optimizer
            warmup: Number of warmup steps. Usually between 50 and 500
            max_iters: Number of maximum iterations the model is trained for. This is needed for the CosineWarmup scheduler
            dropout: Dropout to apply inside the model
            input_dropout: Dropout to apply on the input features
        """
        super().__init__()
        self.save_hyperparameters()
        self._create_model()

    def _create_model(self):
        # Input dim -> Model dim
        self.input_net = nn.Sequential(
            nn.Dropout(self.hparams.input_dropout), nn.Linear(self.hparams.input_dim, self.hparams.model_dim)
        )
        # Positional encoding for sequences
        self.positional_encoding = PositionalEncoding(d_model=self.hparams.model_dim)
        # Transformer
        self.transformer = TransformerEncoder(
            num_layers=self.hparams.num_layers,
            input_dim=self.hparams.model_dim,
            dim_feedforward=2 * self.hparams.model_dim,
            num_heads=self.hparams.num_heads,
            dropout=self.hparams.dropout,
        )
        # Output classifier per sequence element
        self.output_net = nn.Sequential(
            nn.Linear(self.hparams.model_dim, self.hparams.model_dim),
            nn.LayerNorm(self.hparams.model_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(self.hparams.dropout),
            nn.Linear(self.hparams.model_dim, self.hparams.num_classes),
        )

    def forward(self, x, mask=None, add_positional_encoding=True):
        """
        Args:
            x: Input features of shape [Batch, SeqLen, input_dim]
            mask: Mask to apply on the attention outputs (optional)
            add_positional_encoding: If True, we add the positional encoding to the input.
                                      Might not be desired for some tasks.
        """
        x = self.input_net(x)
        if add_positional_encoding:
            x = self.positional_encoding(x)
        x = self.transformer(x, mask=mask)
        x = self.output_net(x)
        return x

    @torch.no_grad()
    def get_attention_maps(self, x, mask=None, add_positional_encoding=True):
        """Function for extracting the attention matrices of the whole Transformer for a single batch.

        Input arguments same as the forward pass.
        """
        x = self.input_net(x)
        if add_positional_encoding:
            x = self.positional_encoding(x)
        attention_maps = self.transformer.get_attention_maps(x, mask=mask)
        return attention_maps

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr)

        # We don't return the lr scheduler because we need to apply it per iteration, not per epoch
        self.lr_scheduler = CosineWarmupScheduler(
            optimizer, warmup=self.hparams.warmup, max_iters=self.hparams.max_iters
        )
        return optimizer

    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)
        self.lr_scheduler.step()  # Step per iteration

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        raise NotImplementedError

Experiments

Having finished the implementation of the Transformer architecture, we can start experimenting and apply it to various tasks. In this notebook, we focus on two tasks: parallel Sequence-to-Sequence and set anomaly detection. The two tasks highlight different properties of the Transformer architecture, and we go through them below.

Sequence to Sequence

A Sequence-to-Sequence task is one where both the input and the output are sequences, not necessarily of the same length. Popular tasks in this domain include machine translation and summarization. For these, we usually have a Transformer encoder for interpreting the input sequence and a decoder for generating the output autoregressively. Here, however, we go back to a much simpler example task and use only the encoder. Given a sequence of N numbers between 0 and M, the task is to reverse the input sequence. In NumPy notation, if our input is x, the output should be x[::-1]. Although this task sounds very simple, RNNs can struggle with it because it requires long-term dependencies: the first output depends on the last input, and vice versa. Transformers are designed to model such dependencies directly, so we expect them to perform very well.

First, let’s create a dataset class below.

[15]:
class ReverseDataset(data.Dataset):
    def __init__(self, num_categories, seq_len, size):
        super().__init__()
        self.num_categories = num_categories
        self.seq_len = seq_len
        self.size = size

        self.data = torch.randint(self.num_categories, size=(self.size, self.seq_len))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        inp_data = self.data[idx]
        labels = torch.flip(inp_data, dims=(0,))
        return inp_data, labels

We create an arbitrary number of random sequences of numbers between 0 and num_categories-1. The label is simply the tensor flipped over the sequence dimension. We can create the corresponding data loaders below.

[16]:
dataset = partial(ReverseDataset, 10, 16)
train_loader = data.DataLoader(dataset(50000), batch_size=128, shuffle=True, drop_last=True, pin_memory=True)
val_loader = data.DataLoader(dataset(1000), batch_size=128)
test_loader = data.DataLoader(dataset(10000), batch_size=128)

Let’s look at an arbitrary sample of the dataset:

[17]:
inp_data, labels = train_loader.dataset[0]
print("Input data:", inp_data)
print("Labels:    ", labels)
Input data: tensor([9, 6, 2, 0, 6, 2, 7, 9, 7, 3, 3, 4, 3, 7, 0, 9])
Labels:     tensor([9, 0, 7, 3, 4, 3, 3, 7, 9, 7, 2, 6, 0, 2, 6, 9])

During training, we pass the input sequence through the Transformer encoder and predict the output for each input token. We use the standard cross-entropy loss for this. Every number is represented as a one-hot vector. Representing the categories as single scalars would severely reduce the expressiveness of the model, since 0 and 1 are no more closely related than 0 and 9 in our example. An alternative to a one-hot vector is a learned embedding vector, as provided by the PyTorch module nn.Embedding. However, using a one-hot vector followed by a linear layer, as in our case, has the same effect as an embedding layer: self.input_net maps each one-hot vector to a dense vector, and the linear layer's weight matrix holds one embedding vector per category.
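To make the one-hot/embedding equivalence concrete, here is a small check with arbitrary example tokens. It assumes a bias-free linear layer; the input_net in this notebook also has a bias, which only adds the same constant vector to every category's embedding. Note that nn.Linear stores its weight with shape [model_dim, input_dim], so the embedding of category c is column c of that matrix.

```python
import torch
import torch.nn.functional as F
from torch import nn

num_categories, model_dim = 10, 32
tokens = torch.tensor([[9, 6, 2, 0], [3, 3, 4, 7]])  # [Batch, SeqLen], arbitrary example tokens

# One-hot input followed by a linear layer (no bias): each token selects one column of the weight
linear = nn.Linear(num_categories, model_dim, bias=False)
one_hot = F.one_hot(tokens, num_classes=num_categories).float()
dense_from_linear = linear(one_hot)  # [Batch, SeqLen, model_dim]

# Embedding table filled with the same vectors: row c of weight.T is the embedding of category c
embedding = nn.Embedding(num_categories, model_dim)
with torch.no_grad():
    embedding.weight.copy_(linear.weight.t())
dense_from_embedding = embedding(tokens)  # [Batch, SeqLen, model_dim]

print(torch.allclose(dense_from_linear, dense_from_embedding))
```

The embedding lookup avoids materializing the one-hot tensor, which matters for large vocabularies; for 10 categories, either choice is fine.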

To implement the training dynamics, we create a new class that inherits from TransformerPredictor and overrides the training, validation, and test step functions.

[18]:
class ReversePredictor(TransformerPredictor):
    def _calculate_loss(self, batch, mode="train"):
        # Fetch data and transform categories to one-hot vectors
        inp_data, labels = batch
        inp_data = F.one_hot(inp_data, num_classes=self.hparams.num_classes).float()

        # Perform prediction and calculate loss and accuracy
        preds = self.forward(inp_data, add_positional_encoding=True)
        loss = F.cross_entropy(preds.view(-1, preds.size(-1)), labels.view(-1))
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        # Logging
        self.log("%s_loss" % mode, loss)
        self.log("%s_acc" % mode, acc)
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, _ = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        _ = self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        _ = self._calculate_loss(batch, mode="test")

Finally, we can create a training function similar to the one we have seen in Tutorial 5 for PyTorch Lightning. We create a pl.Trainer object, run it for N epochs, log to TensorBoard, and save the best model based on the validation accuracy. Afterward, we test the model on the test set. An additional parameter we pass to the trainer here is gradient_clip_val. It clips the norm of the gradients of all parameters before each optimizer step, which prevents the model from diverging when we obtain very high gradients at, for instance, sharp regions of the loss surface (see, e.g., the DeepAI glossary entry on gradient clipping). For Transformers, gradient clipping can help to further stabilize training during the first few iterations, and also afterward. In plain PyTorch, you can apply gradient clipping via torch.nn.utils.clip_grad_norm_(...) (see documentation). The clip value is usually between 0.5 and 10, depending on how aggressively you want to clip large gradients. With this in place, let’s implement the training function:

[19]:
def train_reverse(**kwargs):
    # Create a PyTorch Lightning trainer with a checkpoint callback that keeps the best model (by val_acc)
    root_dir = os.path.join(CHECKPOINT_PATH, "ReverseTask")
    os.makedirs(root_dir, exist_ok=True)
    trainer = pl.Trainer(
        default_root_dir=root_dir,
        callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")],
        gpus=1 if str(device).startswith("cuda") else 0,
        max_epochs=10,
        gradient_clip_val=5,
        progress_bar_refresh_rate=1,
    )
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "ReverseTask.ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        model = ReversePredictor.load_from_checkpoint(pretrained_filename)
    else:
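        model = ReversePredictor(max_iters=trainer.max_epochs * len(train_loader), **kwargs)
        trainer.fit(model, train_loader, val_loader)
        # Load the best checkpoint (highest val_acc) before testing, as in train_anomaly
        model = ReversePredictor.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)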

    # Test best model on validation and test set
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    result = {"test_acc": test_result[0]["test_acc"], "val_acc": val_result[0]["test_acc"]}

    model = model.to(device)
    return model, result
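For comparison with gradient_clip_val, here is a minimal sketch of the same norm-based clipping in a plain PyTorch training step. The tiny model and the deliberately large random inputs are illustrative assumptions; only torch.nn.utils.clip_grad_norm_ comes from the text above.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
optimizer = optim.Adam(model.parameters(), lr=5e-4)
max_norm = 5.0  # plays the role of gradient_clip_val=5 in train_reverse

x = torch.randn(128, 16) * 100  # deliberately large inputs to provoke large gradients
y = torch.randint(10, (128,))

optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
# Rescales all gradients in place so that their joint L2 norm is at most max_norm;
# the return value is the total norm measured before clipping
total_norm_before = nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
total_norm_after = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
optimizer.step()  # the update uses the clipped gradients

print(float(total_norm_before), float(total_norm_after))
```

Because all gradients are scaled by the same factor, clipping by norm preserves the update direction and only shortens the step.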

Finally, we can train the model. In this setup, we use a single encoder block and a single head in the Multi-Head Attention. We choose this because the task is simple, and in this case the attention can actually be interpreted as an “explanation” of the predictions (in contrast to the deep Transformers in the papers discussed above).

[20]:
reverse_model, reverse_result = train_reverse(
    input_dim=train_loader.dataset.num_categories,
    model_dim=32,
    num_heads=1,
    num_classes=train_loader.dataset.num_categories,
    num_layers=1,
    dropout=0.0,
    lr=5e-4,
    warmup=50,
)
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:96: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=1)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/Transformers/ReverseTask/lightning_logs
Found pretrained model, loading...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

The PyTorch Lightning warning about the number of workers can be ignored for now. Since the dataset is so simple and __getitem__ takes negligible time, we don’t need subprocesses to provide the data (in fact, more workers can slow down training due to the communication overhead among processes/threads). First, let’s print the results:

[21]:
print("Val accuracy:  %4.2f%%" % (100.0 * reverse_result["val_acc"]))
print("Test accuracy: %4.2f%%" % (100.0 * reverse_result["test_acc"]))
Val accuracy:  100.00%
Test accuracy: 100.00%

As expected, the Transformer solves the task correctly. But what does the attention in the Multi-Head Attention block look like for an arbitrary input? Let’s visualize it below.

[22]:
data_input, labels = next(iter(val_loader))
inp_data = F.one_hot(data_input, num_classes=reverse_model.hparams.num_classes).float()
inp_data = inp_data.to(device)
attention_maps = reverse_model.get_attention_maps(inp_data)

The object attention_maps is a list of length N where N is the number of layers. Each element is a tensor of shape [Batch, Heads, SeqLen, SeqLen], which we can verify below.

[23]:
attention_maps[0].shape
[23]:
torch.Size([128, 1, 16, 16])

Next, we write a plotting function that takes as input the sequences, the attention maps, and an index indicating which batch element we want to visualize. In the resulting grid of plots, rows correspond to layers and columns to heads. Remember that the softmax has been applied to each row of an attention map separately.

[24]:
def plot_attention_maps(input_data, attn_maps, idx=0):
    if input_data is not None:
        input_data = input_data[idx].detach().cpu().numpy()
    else:
        input_data = np.arange(attn_maps[0][idx].shape[-1])
    attn_maps = [m[idx].detach().cpu().numpy() for m in attn_maps]

    num_heads = attn_maps[0].shape[0]
    num_layers = len(attn_maps)
    seq_len = input_data.shape[0]
    fig_size = 4 if num_heads == 1 else 3
    fig, ax = plt.subplots(num_layers, num_heads, figsize=(num_heads * fig_size, num_layers * fig_size))
    if num_layers == 1:
        ax = [ax]
    if num_heads == 1:
        ax = [[a] for a in ax]
    for row in range(num_layers):
        for column in range(num_heads):
            ax[row][column].imshow(attn_maps[row][column], origin="lower", vmin=0)
            ax[row][column].set_xticks(list(range(seq_len)))
            ax[row][column].set_xticklabels(input_data.tolist())
            ax[row][column].set_yticks(list(range(seq_len)))
            ax[row][column].set_yticklabels(input_data.tolist())
            ax[row][column].set_title("Layer %i, Head %i" % (row + 1, column + 1))
    fig.subplots_adjust(hspace=0.5)
    plt.show()

Finally, we can plot the attention map of our trained Transformer on the reverse task:

[25]:
plot_attention_maps(data_input, attention_maps, idx=0)
Figure: attention map of the trained Transformer on the reverse task for the first validation sequence (panel “Layer 1, Head 1”).

The model has learned to attend to the token at the flipped index of its own position. Hence, it actually does what we intended it to do. However, it also pays some attention to values close to the flipped index. This is because the model doesn’t need perfect, hard attention to solve this problem and is fine with this approximate, noisy attention map. The attention on nearby indices is caused by the similarity of their positional encodings, a property we intended when designing the positional encoding.
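To see why neighboring positions look alike to the attention, we can compare positional encodings directly. This sketch assumes the standard sine/cosine encoding of Vaswani et al. (2017), re-implemented here as a hypothetical standalone function rather than the notebook's PositionalEncoding module, with the sizes of the reverse task (16 positions, model dimension 32).

```python
import math

import torch


def sinusoidal_encoding(max_len, d_model):
    # Standard sine/cosine positional encoding (assumed to match the notebook's PositionalEncoding)
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


pe = sinusoidal_encoding(max_len=16, d_model=32)
similarity = pe @ pe.T  # dot product between the encodings of every pair of positions

print(similarity[0, :4])
print(similarity[0, 8])
```

Each dot product reduces to a sum of cos((p - q) * frequency) terms, so it depends only on the distance between positions and is larger for nearby positions than for distant ones, which matches the blurred band in the attention map.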

Set Anomaly Detection

Besides sequences, sets are another data structure that is relevant for many applications. In contrast to sequences, the elements of a set are unordered. RNNs can only be applied to sets by assuming an order in the data, which biases the model towards an order that does not exist. Vinyals et al. (2015) and other papers have shown that the assumed order can have a significant impact on the model’s performance, so we should avoid using RNNs on sets. Ideally, our model should be permutation-equivariant or permutation-invariant, so that reordering the elements of a set either permutes the outputs in the same way or leaves them unchanged.

Transformers offer the perfect architecture for this, as Multi-Head Attention (without positional encodings) is permutation-equivariant: it outputs the same values no matter in which order we feed in the inputs, only permuted in the same way as the inputs. The task we are looking at for sets is Set Anomaly Detection, which means that we try to find the element(s) in a set that do not fit the others. A common setup in the research community is anomaly detection on a set of images, where N-1 images belong to the same category or share the same high-level features while one belongs to another category. Note that a category does not necessarily have to correspond to a class in a standard classification problem, but could be a combination of multiple features. For instance, on a face dataset, this could be people with glasses, male, beard, etc. An example of distinguishing different animals can be seen below. The first four images show foxes, while the last represents a different animal. We want to recognize that the last image shows a different animal, but it is not relevant which class of animal it is.

Figure: four images of foxes followed by one image of a different animal (the anomaly).
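The permutation-equivariance of attention can be checked numerically. This sketch uses PyTorch's built-in nn.MultiheadAttention instead of the notebook's own attention implementation, with arbitrary sizes; no positional encoding is added, just as for sets.

```python
import torch
from torch import nn

torch.manual_seed(0)
embed_dim, num_heads, set_size = 16, 4, 10
attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()

x = torch.randn(2, set_size, embed_dim)  # [Batch, SetSize, embed_dim], no positional encoding
perm = torch.randperm(set_size)  # a random reordering of the set elements

with torch.no_grad():
    out, _ = attn(x, x, x)  # self-attention on the original order
    x_perm = x[:, perm]
    out_perm, _ = attn(x_perm, x_perm, x_perm)  # self-attention on the shuffled set

# Shuffling the inputs shuffles the outputs in exactly the same way
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))
```

Pooling the outputs over the set dimension (e.g. summing or averaging) would turn this equivariance into invariance; the anomaly task below instead keeps one output per element.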

In this tutorial, we use the CIFAR100 dataset. Like CIFAR10, CIFAR100 has a resolution of 32x32, but it contains 100 classes with 600 images each. The larger number of classes requires the model to attend to specific features in the images instead of coarse features as in CIFAR10, which makes the task harder. We show the model a set of 9 images of one class and 1 image from another class. The task is to find the image that is from a different class than the other images. Using the raw images directly as input to the Transformer is not a good idea: unlike a CNN, it is not translation invariant, and it would first need to learn to detect image features from high-dimensional input. Instead, we use a pre-trained ResNet34 model from the torchvision package to obtain high-level, low-dimensional features of the images. The ResNet model has been pre-trained on the ImageNet dataset, which contains over 1 million images of 1k classes at varying resolutions. During training and testing, these images are usually scaled to a resolution of 224x224, so we rescale our CIFAR images to this resolution as well. Below, we load the dataset and prepare the data for processing by the ResNet model.

[26]:
# ImageNet statistics
DATA_MEANS = np.array([0.485, 0.456, 0.406])
DATA_STD = np.array([0.229, 0.224, 0.225])
# As torch tensors for later preprocessing
TORCH_DATA_MEANS = torch.from_numpy(DATA_MEANS).view(1, 3, 1, 1)
TORCH_DATA_STD = torch.from_numpy(DATA_STD).view(1, 3, 1, 1)

# Resize to 224x224, and normalize to ImageNet statistic
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(DATA_MEANS, DATA_STD)]
)
# Loading the training dataset.
train_set = CIFAR100(root=DATASET_PATH, train=True, transform=transform, download=True)

# Loading the test set
test_set = CIFAR100(root=DATASET_PATH, train=False, transform=transform, download=True)
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /__w/1/s/.datasets/cifar-100-python.tar.gz
Extracting /__w/1/s/.datasets/cifar-100-python.tar.gz to /__w/1/s/.datasets
Files already downloaded and verified

Next, we run the pre-trained ResNet model on the images and extract the features before the classification layer. These are the most high-level features and should sufficiently describe the images. CIFAR100 has some similarity to ImageNet, so we do not retrain the ResNet model in any form. However, if you want the best performance and have a very large dataset, it would be better to add the ResNet to the computation graph during training and finetune its parameters as well. As we don’t have a large enough dataset and want to train our model efficiently, we extract the features beforehand. Let’s load and prepare the model below.

[27]:
os.environ["TORCH_HOME"] = CHECKPOINT_PATH
pretrained_model = torchvision.models.resnet34(pretrained=True)
# Remove classification layer
# In some models, it is called "fc", others have "classifier"
# Setting both to an empty sequential represents an identity map of the final features.
pretrained_model.fc = nn.Sequential()
pretrained_model.classifier = nn.Sequential()
# To GPU
pretrained_model = pretrained_model.to(device)

# Only eval, no gradient required
pretrained_model.eval()
for p in pretrained_model.parameters():
    p.requires_grad = False
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to saved_models/Transformers/hub/checkpoints/resnet34-333f7ec4.pth
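The trick of replacing the classification layer with an empty nn.Sequential() works because an empty Sequential returns its input unchanged. This sketch demonstrates it on a hypothetical toy backbone (TinyBackbone is illustrative only and mimics the fc attribute of torchvision's ResNets, so no weights need to be downloaded); nn.Identity() would have the same effect.

```python
import torch
from torch import nn


class TinyBackbone(nn.Module):
    # Hypothetical stand-in for a torchvision classifier whose last layer is called `fc`
    def __init__(self, feat_dim=8, num_classes=5):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(3, feat_dim), nn.ReLU())
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


model = TinyBackbone().eval()
x = torch.randn(4, 3)
with torch.no_grad():
    logits = model(x)  # [4, num_classes]: class scores
    model.fc = nn.Sequential()  # an empty Sequential passes its input through unchanged
    feats = model(x)  # [4, feat_dim]: the features right before the removed classifier
    expected = model.features(x)

print(logits.shape, feats.shape)
```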

We now write an extraction function for the features below. This cell requires access to a GPU, as the model is rather deep and the images relatively large. The GPUs on Google Colab are sufficient, but running this cell can take 2-3 minutes. Once it has run, the features are saved to disk so they don’t have to be recalculated every time you run the notebook. However, this requires more than 150MB of free disk space, so it is recommended to run this cell only on a local computer with enough free disk space and a GPU (Google Colab is fine for this). If you do not have a GPU, you can download the features from the Google Drive folder.

[28]:
@torch.no_grad()
def extract_features(dataset, save_file):
    if not os.path.isfile(save_file):
        data_loader = data.DataLoader(dataset, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
        extracted_features = []
        for imgs, _ in tqdm(data_loader):
            imgs = imgs.to(device)
            feats = pretrained_model(imgs)
            extracted_features.append(feats)
        extracted_features = torch.cat(extracted_features, dim=0)
        extracted_features = extracted_features.detach().cpu()
        torch.save(extracted_features, save_file)
    else:
        extracted_features = torch.load(save_file)
    return extracted_features


train_feat_file = os.path.join(CHECKPOINT_PATH, "train_set_features.tar")
train_set_feats = extract_features(train_set, train_feat_file)

test_feat_file = os.path.join(CHECKPOINT_PATH, "test_set_features.tar")
test_feats = extract_features(test_set, test_feat_file)

Let’s verify the feature shapes below. The training set should have 50k images and the test set 10k images. The feature dimension of ResNet34 is 512. If you experiment with other models, you will likely see a different feature dimension.

[29]:
print("Train:", train_set_feats.shape)
print("Test: ", test_feats.shape)
Train: torch.Size([50000, 512])
Test:  torch.Size([10000, 512])

As usual, we want to create a validation set to detect when we should stop training. In this case, we split the training set into 90% training and 10% validation. The difficulty here is that we need to ensure that the validation set has the same number of images for each of the 100 labels. Otherwise, we have a class imbalance, which is problematic for creating the image sets. Hence, we take 10% of the images of each class and move them into the validation set. The code below does exactly this.

[30]:
# Split train into train+val
# Get labels from train set
labels = train_set.targets

# Get indices of images per class
labels = torch.LongTensor(labels)
num_labels = labels.max() + 1
sorted_indices = torch.argsort(labels).reshape(num_labels, -1)  # [classes, num_imgs per class]

# Determine number of validation images per class
num_val_exmps = sorted_indices.shape[1] // 10

# Get image indices for validation and training
val_indices = sorted_indices[:, :num_val_exmps].reshape(-1)
train_indices = sorted_indices[:, num_val_exmps:].reshape(-1)

# Group corresponding image features and labels
train_feats, train_labels = train_set_feats[train_indices], labels[train_indices]
val_feats, val_labels = train_set_feats[val_indices], labels[val_indices]
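The argsort-and-reshape step above relies on every class having the same number of images, which holds for CIFAR100 (500 training images per class). A minimal sketch on hypothetical toy labels (3 classes with 10 examples each, in shuffled order) shows that it produces a stratified split:

```python
import torch

# Hypothetical toy labels: 3 classes with 10 examples each, in random order
torch.manual_seed(0)
labels = torch.arange(3).repeat_interleave(10)[torch.randperm(30)]

num_labels = int(labels.max()) + 1
# Sorting groups the indices by class; reshaping gives one row of indices per class
sorted_indices = torch.argsort(labels).reshape(num_labels, -1)
num_val_exmps = sorted_indices.shape[1] // 10  # 10% of each class

val_indices = sorted_indices[:, :num_val_exmps].reshape(-1)
train_indices = sorted_indices[:, num_val_exmps:].reshape(-1)

print(torch.bincount(labels[val_indices]))  # one validation example per class
print(torch.bincount(labels[train_indices]))  # nine training examples per class
```

With unequal class sizes, the reshape would either fail or mix classes within a row, so a per-class loop would be needed instead.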

Now we can prepare a dataset class for the set anomaly task. We define an epoch as a pass in which each image has been the “anomaly” exactly once. Hence, the length of the dataset is the number of images in it. For the training set, each time we access an item with __getitem__, we sample a random class that differs from the class of the image at index idx. In a second step, we sample N-1 images of this sampled class. The set of 10 images is finally returned. The randomness in __getitem__ lets us see a slightly different set in each iteration. However, we can’t use the same strategy for the test set, as we want the test dataset to be the same every time we iterate over it. Hence, we sample the sets in the __init__ method and return those in __getitem__. The code below implements exactly this behavior.

[31]:
class SetAnomalyDataset(data.Dataset):
    def __init__(self, img_feats, labels, set_size=10, train=True):
        """
        Args:
            img_feats: Tensor of shape [num_imgs, img_dim]. Represents the high-level features.
            labels: Tensor of shape [num_imgs], containing the class labels for the images
            set_size: Number of elements in a set. N-1 are sampled from one class, and one from another one.
            train: If True, a new set will be sampled every time __getitem__ is called.
        """
        super().__init__()
        self.img_feats = img_feats
        self.labels = labels
        self.set_size = set_size - 1  # The set size is here the size of correct images
        self.train = train

        # Tensors with indices of the images per class
        self.num_labels = labels.max() + 1
        self.img_idx_by_label = torch.argsort(self.labels).reshape(self.num_labels, -1)

        if not train:
            self.test_sets = self._create_test_sets()

    def _create_test_sets(self):
        # Pre-generates the sets for each image for the test set
        test_sets = []
        num_imgs = self.img_feats.shape[0]
        np.random.seed(42)
        test_sets = [self.sample_img_set(self.labels[idx]) for idx in range(num_imgs)]
        test_sets = torch.stack(test_sets, dim=0)
        return test_sets

    def sample_img_set(self, anomaly_label):
        """Samples a new set of images, given the label of the anomaly.

        The sampled images come from a different class than anomaly_label
        """
        # Sample class from 0,...,num_classes-1 while skipping anomaly_label as class
        set_label = np.random.randint(self.num_labels - 1)
        if set_label >= anomaly_label:
            set_label += 1

        # Sample images from the class determined above
        img_indices = np.random.choice(self.img_idx_by_label.shape[1], size=self.set_size, replace=False)
        img_indices = self.img_idx_by_label[set_label, img_indices]
        return img_indices

    def __len__(self):
        return self.img_feats.shape[0]

    def __getitem__(self, idx):
        anomaly = self.img_feats[idx]
        if self.train:  # If train => sample
            img_indices = self.sample_img_set(self.labels[idx])
        else:  # If test => use pre-generated ones
            img_indices = self.test_sets[idx]

        # Concatenate images. The anomaly is always the last image for simplicity
        img_set = torch.cat([self.img_feats[img_indices], anomaly[None]], dim=0)
        indices = torch.cat([img_indices, torch.LongTensor([idx])], dim=0)
        label = img_set.shape[0] - 1

        # We return the indices of the images for visualization purpose. "Label" is the index of the anomaly
        return img_set, indices, label
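sample_img_set picks the class of the normal images with a small trick: it draws from num_labels - 1 values and shifts every value at or above anomaly_label up by one. This yields a uniform distribution over all classes except the anomaly's, without rejection sampling. A standalone sketch (sample_other_label is a hypothetical name, and it uses NumPy's Generator API rather than the global np.random.randint of the class above):

```python
import numpy as np


def sample_other_label(anomaly_label, num_labels, rng):
    # Draw uniformly from num_labels - 1 values, then skip over anomaly_label
    set_label = rng.integers(num_labels - 1)
    if set_label >= anomaly_label:
        set_label += 1
    return set_label


rng = np.random.default_rng(0)
draws = np.array([sample_other_label(3, 5, rng) for _ in range(10_000)])
counts = np.bincount(draws, minlength=5)
print(counts)  # class 3 never appears; the other four classes appear about equally often
```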

Next, we set up our datasets and data loaders below. Here, we use a set size of 10, i.e. 9 images from one category plus 1 anomaly. Feel free to change it if you want to experiment with different sizes.

[32]:
SET_SIZE = 10
test_labels = torch.LongTensor(test_set.targets)

train_anom_dataset = SetAnomalyDataset(train_feats, train_labels, set_size=SET_SIZE, train=True)
val_anom_dataset = SetAnomalyDataset(val_feats, val_labels, set_size=SET_SIZE, train=False)
test_anom_dataset = SetAnomalyDataset(test_feats, test_labels, set_size=SET_SIZE, train=False)

train_anom_loader = data.DataLoader(
    train_anom_dataset, batch_size=64, shuffle=True, drop_last=True, num_workers=4, pin_memory=True
)
val_anom_loader = data.DataLoader(val_anom_dataset, batch_size=64, shuffle=False, drop_last=False, num_workers=4)
test_anom_loader = data.DataLoader(test_anom_dataset, batch_size=64, shuffle=False, drop_last=False, num_workers=4)

To understand the dataset a little better, we can plot a few sets from the test dataset below. Each row shows a different input set, where the first 9 images are from the same class and the last one is the anomaly.

[33]:
def visualize_exmp(indices, orig_dataset):
    images = [orig_dataset[idx][0] for idx in indices.reshape(-1)]
    images = torch.stack(images, dim=0)
    images = images * TORCH_DATA_STD + TORCH_DATA_MEANS

    img_grid = torchvision.utils.make_grid(images, nrow=SET_SIZE, normalize=True, pad_value=0.5, padding=16)
    img_grid = img_grid.permute(1, 2, 0)

    plt.figure(figsize=(12, 8))
    plt.title("Anomaly examples on CIFAR100")
    plt.imshow(img_grid)
    plt.axis("off")
    plt.show()
    plt.close()


_, indices, _ = next(iter(test_anom_loader))
visualize_exmp(indices[:4], test_set)
Figure: “Anomaly examples on CIFAR100”, four test sets shown as rows (nine images of one class followed by the anomaly).

We can already see that the task might be easier for some sets than for others. Difficulties arise especially if the anomaly comes from a different but visually similar class (e.g. train vs. bus, flower vs. worm, etc.).

After having prepared the data, we can look closer at the model. Here, we classify the whole set. For the prediction to be permutation-equivariant, we output one logit for each image. Over these logits, we apply a softmax and train the anomaly image to have the highest score/probability. This is a bit different from a standard classification layer, as the softmax is applied over images, not over output classes in the classical sense. However, if we swap the positions of two images, we effectively swap their positions in the output softmax. Hence, the prediction is equivariant with respect to the input. We implement this idea below in a subclass of the Transformer Lightning module.

[34]:
class AnomalyPredictor(TransformerPredictor):
    def _calculate_loss(self, batch, mode="train"):
        img_sets, _, labels = batch
        # No positional encodings as it is a set, not a sequence!
        preds = self.forward(img_sets, add_positional_encoding=False)
        preds = preds.squeeze(dim=-1)  # Shape: [Batch_size, set_size]
        loss = F.cross_entropy(preds, labels)  # Softmax/CE over set dimension
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        self.log("%s_loss" % mode, loss)
        self.log("%s_acc" % mode, acc, on_step=False, on_epoch=True)
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, _ = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        _ = self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        _ = self._calculate_loss(batch, mode="test")
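The set-level softmax in _calculate_loss can be illustrated in isolation. In this sketch, a hypothetical per-element nn.Linear scorer stands in for the Transformer plus output_net with num_classes=1 (unlike the Transformer, it ignores the other set elements). It shows that shuffling a set shuffles the logits identically and leaves the loss unchanged once the label follows the anomaly to its new position.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
batch_size, set_size, feat_dim = 4, 10, 16
scorer = nn.Linear(feat_dim, 1)  # hypothetical stand-in: one logit per set element

img_sets = torch.randn(batch_size, set_size, feat_dim)
labels = torch.full((batch_size,), set_size - 1, dtype=torch.long)  # anomaly is the last element

logits = scorer(img_sets).squeeze(dim=-1)  # [Batch, SetSize]
loss = F.cross_entropy(logits, labels)  # softmax/CE over the set dimension

# Shuffle the set; the anomaly moves to position j where perm[j] == set_size - 1
perm = torch.randperm(set_size)
new_pos = int((perm == set_size - 1).nonzero())
new_labels = torch.full((batch_size,), new_pos, dtype=torch.long)
logits_perm = scorer(img_sets[:, perm]).squeeze(dim=-1)
loss_perm = F.cross_entropy(logits_perm, new_labels)

print(loss.item(), loss_perm.item())
```

The Transformer keeps this property while also letting each element's logit depend on the rest of the set, as long as no positional encoding is added.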

Finally, we write our training function below. It has essentially the same structure as the one for the reverse task, except that we also evaluate on the training set, so little further explanation is needed here.

[35]:
def train_anomaly(**kwargs):
    # Create a PyTorch Lightning trainer with a checkpoint callback that keeps the best model (by val_acc)
    root_dir = os.path.join(CHECKPOINT_PATH, "SetAnomalyTask")
    os.makedirs(root_dir, exist_ok=True)
    trainer = pl.Trainer(
        default_root_dir=root_dir,
        callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")],
        gpus=1 if str(device).startswith("cuda") else 0,
        max_epochs=100,
        gradient_clip_val=2,
        progress_bar_refresh_rate=1,
    )
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "SetAnomalyTask.ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        model = AnomalyPredictor.load_from_checkpoint(pretrained_filename)
    else:
        model = AnomalyPredictor(max_iters=trainer.max_epochs * len(train_anom_loader), **kwargs)
        trainer.fit(model, train_anom_loader, val_anom_loader)
        model = AnomalyPredictor.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Test best model on validation and test set
    train_result = trainer.test(model, dataloaders=train_anom_loader, verbose=False)
    val_result = trainer.test(model, dataloaders=val_anom_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_anom_loader, verbose=False)
    result = {
        "test_acc": test_result[0]["test_acc"],
        "val_acc": val_result[0]["test_acc"],
        "train_acc": train_result[0]["test_acc"],
    }

    model = model.to(device)
    return model, result

Let’s finally train our model. We use 4 layers with 4 attention heads each. The hidden dimensionality of the model is 256, and we apply a dropout of 0.1 throughout the model for regularization. Note that we also apply dropout to the input features, which makes the model more robust against image noise and helps it generalize better. Again, we use warmup to slowly start the model training.

[36]:
anomaly_model, anomaly_result = train_anomaly(
    input_dim=train_anom_dataset.img_feats.shape[-1],
    model_dim=256,
    num_heads=4,
    num_classes=1,
    num_layers=4,
    dropout=0.1,
    input_dropout=0.1,
    lr=5e-4,
    warmup=100,
)
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:96: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=1)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/Transformers/SetAnomalyTask/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model, loading...
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:486: PossibleUserWarning: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test/predict dataloaders.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

We can print the achieved accuracy below.

[37]:
print("Train accuracy: %4.2f%%" % (100.0 * anomaly_result["train_acc"]))
print("Val accuracy:   %4.2f%%" % (100.0 * anomaly_result["val_acc"]))
print("Test accuracy:  %4.2f%%" % (100.0 * anomaly_result["test_acc"]))
Train accuracy: 96.33%
Val accuracy:   95.92%
Test accuracy:  94.41%

With ~96% validation and ~94% test accuracy, the model generalizes quite well. Note that you might see slightly different scores depending on which computer/device you run this notebook on. This is because, despite setting the seed before generating the test dataset, the generated set is not identical across platforms and NumPy versions. Nevertheless, we can conclude that the model performs quite well and can solve the task for most sets. Before trying to interpret the model, let’s verify that our model is permutation-equivariant, i.e. that it assigns the same prediction to each element regardless of the order of the input set. For this, we sample a batch from the test set and run it through the model to obtain the probabilities.

[38]:
inp_data, indices, labels = next(iter(test_anom_loader))
inp_data = inp_data.to(device)

anomaly_model.eval()

with torch.no_grad():
    preds = anomaly_model.forward(inp_data, add_positional_encoding=False)
    preds = F.softmax(preds.squeeze(dim=-1), dim=-1)

    # Permute the input data
    permut = np.random.permutation(inp_data.shape[1])
    perm_inp_data = inp_data[:, permut]
    perm_preds = anomaly_model.forward(perm_inp_data, add_positional_encoding=False)
    perm_preds = F.softmax(perm_preds.squeeze(dim=-1), dim=-1)

assert (preds[:, permut] - perm_preds).abs().max() < 1e-5, "Predictions are not permutation equivariant"

print("Preds\n", preds[0, permut].cpu().numpy())
print("Permuted preds\n", perm_preds[0].cpu().numpy())
Preds
 [2.7690839e-05 1.8979506e-05 1.7386024e-05 2.7842490e-05 1.6142623e-05
 1.7020535e-05 5.7293695e-05 9.9977750e-01 2.1364667e-05 1.8681461e-05]
Permuted preds
 [2.7690839e-05 1.8979506e-05 1.7386024e-05 2.7842490e-05 1.6142623e-05
 1.7020551e-05 5.7293695e-05 9.9977750e-01 2.1364667e-05 1.8681461e-05]

You can see that the predictions are almost exactly the same and only differ because of slight numerical differences inside the network operations (e.g. in the sixth entry).

To interpret the model a little more, we can plot the attention maps inside the model. This will give us an idea of what information the model is sharing/communicating between images, and what each head might represent. First, we need to extract the attention maps for the test batch above, and determine the discrete predictions for simplicity.

[39]:
attention_maps = anomaly_model.get_attention_maps(inp_data, add_positional_encoding=False)
predictions = preds.argmax(dim=-1)

Below we write a plot function which plots the images in the input set, the prediction of the model, and the attention maps of the different heads in each layer of the transformer. Feel free to explore the attention maps for different input examples as well.

[40]:
def visualize_prediction(idx):
    visualize_exmp(indices[idx : idx + 1], test_set)
    print("Prediction:", predictions[idx].item())
    plot_attention_maps(input_data=None, attn_maps=attention_maps, idx=idx)


visualize_prediction(0)
[Figure: the images of the input set]
Prediction: 9
[Figure: attention maps of all heads in all layers for this set]

Depending on the random seed, you might see a slightly different input set. For the version on the website, we compare 9 tree images with a volcano. We see that multiple heads, for instance Layer 2 Head 1, Layer 2 Head 3, and Layer 3 Head 1, focus on the last image. Additionally, the heads in Layer 4 all seem to ignore the last image and assign a very low attention probability to it. This suggests that the model has indeed recognized that the image doesn’t fit the setting, and hence predicted it to be the anomaly. Layer 3 Heads 2-4 seem to take a slightly weighted average of all images. That might indicate that the model extracts the “average” information of all images to compare it to the features of each individual image.

Let’s try to find where the model actually makes a mistake. We can do this by identifying the sets where the model predicts something other than 9, since in the dataset we ensured that the anomaly is always at the last position in the set.

[41]:
mistakes = torch.where(predictions != 9)[0].cpu().numpy()
print("Indices with mistake:", mistakes)
Indices with mistake: [49]

As our model achieves ~94% test accuracy, we only expect a few mistakes in a batch of 64 sets. Still, let’s visualize one of them, for example the last one:

[42]:
visualize_prediction(mistakes[-1])
print("Probabilities:")
for i, p in enumerate(preds[mistakes[-1]].cpu().numpy()):
    print("Image %i: %4.2f%%" % (i, 100.0 * p))
[Figure: the images of the misclassified input set]
Prediction: 7
[Figure: attention maps of all heads in all layers for the misclassified set]
Probabilities:
Image 0: 0.07%
Image 1: 0.11%
Image 2: 0.07%
Image 3: 0.11%
Image 4: 0.17%
Image 5: 23.27%
Image 6: 0.16%
Image 7: 48.91%
Image 8: 0.10%
Image 9: 27.03%

In this example, the model spreads its probability mass over several images: it assigns ~49% to image 7, ~23% to image 5, and only ~27% to the actual anomaly, image 9. The exact images depend on the random seed, but mistakes like this can occur when the anomaly resembles the rest of the set, for instance a building photographed at a similar angle as the palms in the set, or when one of the regular images is unusual, such as a palm with a different color palette. Nevertheless, in general, the model performs quite well.

Conclusion

In this tutorial, we took a closer look at the Multi-Head Attention layer, which uses a scaled dot product between queries and keys to find correlations and similarities between input elements. The Transformer architecture is based on the Multi-Head Attention layer and applies multiple of them in a ResNet-like block. The Transformer is a very important, recent architecture that can be applied to many tasks and datasets. Although it is best known for its success in NLP, there is much more to it. We have seen its application to sequence-to-sequence tasks and set anomaly detection. Because it is permutation-equivariant when no positional encodings are provided, it generalizes to many settings. Hence, it is important to know the architecture, but also its possible issues, such as the problematic gradients during the first iterations, which learning rate warm-up addresses. If you are interested in continuing with the study of the Transformer architecture, please have a look at the blog posts listed at the beginning of the tutorial notebook.

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time, you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Many thanks from the entire PyTorch Lightning team for your interest!


Tutorial 6: Basics of Graph Neural Networks

  • Author: Phillip Lippe

  • License: CC BY-SA

  • Generated: 2022-05-12T13:44:19.646158

In this tutorial, we will discuss the application of neural networks on graphs. Graph Neural Networks (GNNs) have recently gained increasing popularity in both applications and research, including domains such as social networks, knowledge graphs, recommender systems, and bioinformatics. While the theory and math behind GNNs might seem complicated at first, the implementation of those models is quite simple and helps in understanding the methodology. Therefore, we will discuss the implementation of basic network layers of a GNN, namely graph convolutions and attention layers. Finally, we will apply a GNN on semi-supervised node classification and molecule categorization. This notebook is part of a lecture series on Deep Learning at the University of Amsterdam. The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.



Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "torch-scatter" "ipython[notebook]" "torchmetrics>=0.7" "torch>=1.8" "torch-cluster" "torch-spline-conv" "setuptools==59.5.0" "torch-sparse<0.6.13" "torch-geometric==2.0.2" "pytorch-lightning>=1.4"

We start by importing our standard libraries below.

[2]:
# Standard libraries
import os

# For downloading pre-trained models
import urllib.request
from urllib.error import HTTPError

# PyTorch Lightning
import pytorch_lightning as pl

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# PyTorch geometric
import torch_geometric
import torch_geometric.data as geom_data
import torch_geometric.nn as geom_nn

# PL callbacks
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import Tensor

AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64
# Path to the folder where the datasets are/should be downloaded
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/GNNs/")

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
Global seed set to 42

We also download a few pre-trained models below.

[3]:
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/"
# Files to download
pretrained_files = ["NodeLevelMLP.ckpt", "NodeLevelGNN.ckpt", "GraphLevelGraphConv.ckpt"]

# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if "/" in file_name:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print("Downloading %s..." % file_url)
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file from the GDrive folder,"
                " or contact the author with the full output including the following error:\n",
                e,
            )
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/NodeLevelMLP.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/NodeLevelGNN.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/GraphLevelGraphConv.ckpt...

Graph Neural Networks

Graph representation

Before starting the discussion of specific neural network operations on graphs, we should consider how to represent a graph. Mathematically, a graph \mathcal{G} is defined as a tuple of a set of nodes/vertices V, and a set of edges/links E: \mathcal{G}=(V,E). Each edge is a pair of two vertices, and represents a connection between them. For instance, let’s look at the following graph:

[Figure: example graph with the vertices 1 to 4 and the edges (1,2), (2,3), (2,4), (3,4)]

The vertices are V=\{1,2,3,4\}, and the edges E=\{(1,2), (2,3), (2,4), (3,4)\}. Note that for simplicity, we assume the graph to be undirected and hence don’t add mirrored pairs like (2,1). In applications, vertices and edges often have specific attributes, and edges can even be directed. The question is how we can represent this diversity in a way that is efficient for matrix operations. Usually, for the edges, we decide between two variants: an adjacency matrix, or a list of paired vertex indices.

The adjacency matrix A is a square matrix whose elements indicate whether pairs of vertices are adjacent, i.e. connected, or not. In the simplest case, A_{ij} is 1 if there is a connection from node i to j, and otherwise 0. If we have edge attributes or different categories of edges in a graph, this information can be added to the matrix as well. For an undirected graph, keep in mind that A is a symmetric matrix (A_{ij}=A_{ji}). For the example graph above, we have the following adjacency matrix:

A = \begin{bmatrix}
    0 & 1 & 0 & 0\\
    1 & 0 & 1 & 1\\
    0 & 1 & 0 & 1\\
    0 & 1 & 1 & 0
\end{bmatrix}
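As a small sketch of how the two representations relate, the snippet below builds the dense adjacency matrix from the edge list E given above. It assumes plain PyTorch and the 1-based vertex labels of the text; because the graph is undirected, each edge sets both A_{ij} and A_{ji}.

```python
import torch

# Edge list of the example graph (1-based labels as in the text, no mirrored pairs)
edges = [(1, 2), (2, 3), (2, 4), (3, 4)]
num_nodes = 4

# Dense adjacency matrix; for an undirected graph we set both A[i, j] and A[j, i]
A = torch.zeros(num_nodes, num_nodes)
for i, j in edges:
    A[i - 1, j - 1] = 1.0  # shift to 0-based indices
    A[j - 1, i - 1] = 1.0

print(A)
print("Symmetric:", torch.equal(A, A.T))
```

The row sums of A give the number of neighbors of each node (here 1, 3, 2, and 2).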

While expressing a graph as a list of edges is more efficient in terms of memory and (possibly) computation, using an adjacency matrix is more intuitive and simpler to implement. In our implementations below, we will rely on the adjacency matrix to keep the code simple. However, common libraries use edge lists, which we will discuss in more detail later. Alternatively, we could use the list of edges to define a sparse adjacency matrix, which we can work with as if it were a dense matrix while enabling more memory-efficient operations. PyTorch supports this with the sub-package torch.sparse, which is, however, still in a beta stage (the API might change in the future).
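To illustrate the sparse alternative, here is a minimal sketch using `torch.sparse_coo_tensor` and `torch.sparse.mm` on the example graph. The 0-based edge tensor and the node feature values are illustrative choices, not part of the tutorial. The sketch sums the neighbor features of every node without materializing a dense matrix, and checks the result against the dense computation.

```python
import torch

# 0-based undirected edges of the example graph
edges = torch.tensor([[0, 1], [1, 2], [1, 3], [2, 3]])
# Store both directions, giving a [2, num_directed_edges] index tensor
indices = torch.cat([edges, edges.flip(1)], dim=0).t()
values = torch.ones(indices.size(1))
A_sparse = torch.sparse_coo_tensor(indices, values, size=(4, 4)).coalesce()

# Hypothetical node features: two features per node
node_feats = torch.arange(8, dtype=torch.float32).view(4, 2)

# Sparse-dense matrix product: row i sums the features of node i's neighbours
summed_sparse = torch.sparse.mm(A_sparse, node_feats)
summed_dense = A_sparse.to_dense() @ node_feats

print(summed_sparse)
print("Matches dense:", torch.allclose(summed_sparse, summed_dense))
```

Only the 8 non-zero entries are stored, which is what makes this format attractive for large, sparsely connected graphs.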

Graph Convolutions

Graph Convolutional Networks were introduced by Kipf et al. in 2016 at the University of Amsterdam. Thomas Kipf also wrote a great blog post about this topic, which is recommended if you want to read about GCNs from a different perspective. GCNs are similar to convolutions in images in the sense that the “filter” parameters are typically shared over all locations in the graph. At the same time, GCNs rely on message passing methods, which means that vertices exchange information with their neighbors and send “messages” to each other. Before looking at the math, we can try to visually understand how GCNs work. The first step is that each node creates a feature vector that represents the message it wants to send to all its neighbors. In the second step, the messages are sent to the neighbors, so that a node receives one message per adjacent node. Below we have visualized the two steps for our example graph.

[Figure: the two message passing steps of a GCN on the example graph]

If we want to formulate that in more mathematical terms, we need to first decide how to combine all the messages a node receives. As the number of messages varies across nodes, we need an operation that works for any number. Hence, the usual way to go is to sum or take the mean. Given the previous features of nodes H^{(l)}, the GCN layer is defined as follows:

H^{(l+1)} = \sigma\left(\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2}H^{(l)}W^{(l)}\right)

W^{(l)} are the weight parameters with which we transform the input features into messages (H^{(l)}W^{(l)}). To the adjacency matrix A we add the identity matrix so that each node also sends its own message to itself: \hat{A}=A+I. Finally, to take the average instead of summing, we calculate the matrix \hat{D}, which is a diagonal matrix with \hat{D}_{ii} denoting the number of neighbors node i has, including itself (i.e. the row sum of \hat{A}). Strictly speaking, the symmetric normalization \hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2} is closely related to, but not identical with, the plain mean \hat{D}^{-1}\hat{A}. \sigma represents an arbitrary activation function, and not necessarily the sigmoid (usually a ReLU-based activation function is used in GNNs).
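The snippet below is a small numerical check, using the example graph from above, of how the symmetric normalization \hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2} compares to dividing each row by the number of neighbors, \hat{D}^{-1}\hat{A}. It is a sketch in plain PyTorch; the variable names are our own.

```python
import torch

# Adjacency matrix of the example graph (without self-connections)
A = torch.tensor([[0., 1, 0, 0], [1, 0, 1, 1], [0, 1, 0, 1], [0, 1, 1, 0]])
A_hat = A + torch.eye(4)            # add self-connections: A_hat = A + I
deg = A_hat.sum(dim=-1)             # diagonal of D_hat: neighbours incl. the node itself
D_inv_sqrt = torch.diag(deg.pow(-0.5))

A_sym = D_inv_sqrt @ A_hat @ D_inv_sqrt  # symmetric normalization from the GCN formula
A_mean = torch.diag(1.0 / deg) @ A_hat   # plain mean: divide each row by the node's degree

print(A_sym.sum(dim=-1))   # row sums are not all equal to 1
print(A_mean.sum(dim=-1))  # every row sums to 1
```

Only the mean variant produces rows that sum to one, i.e. an exact average, while the symmetric variant keeps the matrix symmetric. The `GCNLayer` below divides by the number of neighbors, which corresponds to the mean variant.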

When implementing the GCN layer in PyTorch, we can take advantage of the flexible operations on tensors. Instead of defining a matrix \hat{D}, we can simply divide the summed messages by the number of neighbors afterward. Furthermore, we replace the weight matrix with a linear layer, which also allows us to add a bias. Written as a PyTorch module, the GCN layer is defined as follows:

[4]:
class GCNLayer(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        self.projection = nn.Linear(c_in, c_out)

    def forward(self, node_feats, adj_matrix):
        """
        Args:
            node_feats: Tensor with node features of shape [batch_size, num_nodes, c_in]
            adj_matrix: Batch of adjacency matrices of the graph. If there is an edge from i to j,
                         adj_matrix[b,i,j]=1 else 0. Supports directed edges by non-symmetric matrices.
                         Assumes to already have added the identity connections.
                         Shape: [batch_size, num_nodes, num_nodes]
        """
        # Num neighbours = number of incoming edges
        num_neighbours = adj_matrix.sum(dim=-1, keepdims=True)
        node_feats = self.projection(node_feats)
        node_feats = torch.bmm(adj_matrix, node_feats)
        node_feats = node_feats / num_neighbours
        return node_feats

To further understand the GCN layer, we can apply it to our example graph above. First, let’s specify some node features and the adjacency matrix with added self-connections:

[5]:
node_feats = torch.arange(8, dtype=torch.float32).view(1, 4, 2)
adj_matrix = Tensor([[[1, 1, 0, 0], [1, 1, 1, 1], [0, 1, 1, 1], [0, 1, 1, 1]]])

print("Node features:\n", node_feats)
print("\nAdjacency matrix:\n", adj_matrix)
Node features:
 tensor([[[0., 1.],
         [2., 3.],
         [4., 5.],
         [6., 7.]]])

Adjacency matrix:
 tensor([[[1., 1., 0., 0.],
         [1., 1., 1., 1.],
         [0., 1., 1., 1.],
         [0., 1., 1., 1.]]])

Next, let’s apply a GCN layer to it. For simplicity, we initialize the linear weight matrix as an identity matrix so that the input features are equal to the messages. This makes it easier for us to verify the message passing operation.

[6]:
layer = GCNLayer(c_in=2, c_out=2)
layer.projection.weight.data = Tensor([[1.0, 0.0], [0.0, 1.0]])
layer.projection.bias.data = Tensor([0.0, 0.0])

with torch.no_grad():
    out_feats = layer(node_feats, adj_matrix)

print("Adjacency matrix", adj_matrix)
print("Input features", node_feats)
print("Output features", out_feats)
Adjacency matrix tensor([[[1., 1., 0., 0.],
         [1., 1., 1., 1.],
         [0., 1., 1., 1.],
         [0., 1., 1., 1.]]])
Input features tensor([[[0., 1.],
         [2., 3.],
         [4., 5.],
         [6., 7.]]])
Output features tensor([[[1., 2.],
         [3., 4.],
         [4., 5.],
         [4., 5.]]])

As we can see, the first node’s output values are the average of itself and the second node. Similarly, we can verify all other nodes. However, in a GNN, we would also want to allow feature exchange between nodes beyond their immediate neighbors. This can be achieved by applying multiple GCN layers, which gives us the final layout of a GNN. The GNN can be built up from a sequence of GCN layers and non-linearities such as ReLU. For a visualization, see below (figure credit - Thomas Kipf, 2016).

[Figure: a GNN built from a sequence of GCN layers and non-linearities]
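To see why stacking layers extends the information exchange beyond direct neighbors, the sketch below composes the mean aggregation of the example graph twice. It ignores the weight matrices and non-linearities, since they do not change which nodes can contribute to a node's output at all; the variable names are hypothetical.

```python
import torch

# Adjacency matrix with self-connections, as in the GCN example above
A_hat = torch.tensor([[1., 1, 0, 0], [1, 1, 1, 1], [0, 1, 1, 1], [0, 1, 1, 1]])
# Row-normalized matrix: the mean aggregation performed by one GCN layer (weights ignored)
P = A_hat / A_hat.sum(dim=-1, keepdim=True)

# Applying the aggregation twice mimics two stacked layers
two_layer_agg = P @ P

# Indices of the nodes whose features can reach node 1 (index 0)
reach_1 = (P[0] > 0).nonzero().flatten().tolist()
reach_2 = (two_layer_agg[0] > 0).nonzero().flatten().tolist()
print("1 layer: ", reach_1)  # [0, 1]
print("2 layers:", reach_2)  # [0, 1, 2, 3]
```

After one layer, the first node only sees itself and the second node; after two layers, information from all four nodes reaches it.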

However, one issue we can see from looking at the example above is that the output features for nodes 3 and 4 are the same because they have the same adjacent nodes (including themselves). Therefore, GCN layers can make the network forget node-specific information if we just take a mean over all messages. Multiple possible improvements have been proposed. While the simplest option might be using residual connections, the more common approach is to either weigh the self-connections higher or define a separate weight matrix for the self-connections. Alternatively, we can use a well-known concept: attention.
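One of the remedies mentioned above, a separate weight matrix for the self-connections, can be sketched as follows. The class `GCNLayerWithSelf` is a hypothetical illustration, not a layer from this tutorial or from any library: it adds a projection of the node's own features to the mean of the neighbor messages. With identity weights and the adjacency matrix (including self-connections) from the example above, nodes 3 and 4 still aggregate the same neighborhood, yet their outputs now differ because of their own features.

```python
import torch
import torch.nn as nn


class GCNLayerWithSelf(nn.Module):
    """Hypothetical GCN variant with a separate weight matrix for the self-connection."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.proj_self = nn.Linear(c_in, c_out)   # transforms the node's own features
        self.proj_neigh = nn.Linear(c_in, c_out)  # transforms the messages to average

    def forward(self, node_feats, adj_matrix):
        # node_feats: [batch_size, num_nodes, c_in]; adj_matrix: [batch_size, num_nodes, num_nodes]
        # Assumes every node has at least one entry in its row (e.g. a self-connection)
        num_neighbours = adj_matrix.sum(dim=-1, keepdim=True)
        neigh_mean = torch.bmm(adj_matrix, self.proj_neigh(node_feats)) / num_neighbours
        return self.proj_self(node_feats) + neigh_mean


layer = GCNLayerWithSelf(c_in=2, c_out=2)
with torch.no_grad():
    # Identity weights and zero biases, as in the GCN example above
    for lin in (layer.proj_self, layer.proj_neigh):
        lin.weight.copy_(torch.eye(2))
        lin.bias.zero_()

node_feats = torch.arange(8, dtype=torch.float32).view(1, 4, 2)
adj_matrix = torch.tensor([[[1., 1, 0, 0], [1, 1, 1, 1], [0, 1, 1, 1], [0, 1, 1, 1]]])
with torch.no_grad():
    out_feats = layer(node_feats, adj_matrix)
print(out_feats)  # rows 2 and 3 (nodes 3 and 4) are no longer identical
```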

Graph Attention

Attention describes a weighted average of multiple elements with the weights dynamically computed based on an input query and elements’ keys (if you don’t know what attention is, it is recommended to at least go through the very first section called What is Attention?). This concept can be applied to graphs as well, for instance in the Graph Attention Network (GAT, proposed by Velickovic et al., 2017). Similarly to the GCN, the graph attention layer creates a message for each node using a linear layer/weight matrix. For the attention part, it uses the message from the node itself as a query, and the messages to average as both keys and values (note that this also includes the message to itself). The score function f_{attn} is implemented as a one-layer MLP which maps the query and key to a single value. The MLP looks as follows (figure credit - Velickovic et al.):

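[Figure: the attention MLP, which maps the concatenated messages \mathbf{W}h_i and \mathbf{W}h_j to the attention weight \alpha_{ij}]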

h_i and h_j are the original features of nodes i and j, respectively, and \mathbf{W}h_i and \mathbf{W}h_j are the corresponding messages of the layer, with \mathbf{W} as weight matrix. \mathbf{a} is the weight matrix of the MLP, which has the shape [1,2\times d_{\text{message}}], and \alpha_{ij} is the final attention weight from node i to j. The calculation can be described as follows:

\alpha_{ij} = \frac{\exp\left(\text{LeakyReLU}\left(\mathbf{a}\left[\mathbf{W}h_i||\mathbf{W}h_j\right]\right)\right)}{\sum_{k\in\mathcal{N}_i} \exp\left(\text{LeakyReLU}\left(\mathbf{a}\left[\mathbf{W}h_i||\mathbf{W}h_k\right]\right)\right)}

The operator || represents the concatenation, and \mathcal{N}_i the indices of the neighbors of node i. Note that in contrast to usual practice, we apply a non-linearity (here LeakyReLU) before the softmax over elements. Although it seems like a minor change at first, it is crucial for the attention to depend on the original input. Specifically, let’s remove the non-linearity for a second and try to simplify the expression, where d=2\times d_{\text{message}} denotes the number of columns of \mathbf{a}, so that \mathbf{a}_{:,:d/2} acts on \mathbf{W}h_i and \mathbf{a}_{:,d/2:} on \mathbf{W}h_j:

\begin{split}
    \alpha_{ij} & = \frac{\exp\left(\mathbf{a}\left[\mathbf{W}h_i||\mathbf{W}h_j\right]\right)}{\sum_{k\in\mathcal{N}_i} \exp\left(\mathbf{a}\left[\mathbf{W}h_i||\mathbf{W}h_k\right]\right)}\\[5pt]
    & = \frac{\exp\left(\mathbf{a}_{:,:d/2}\mathbf{W}h_i+\mathbf{a}_{:,d/2:}\mathbf{W}h_j\right)}{\sum_{k\in\mathcal{N}_i} \exp\left(\mathbf{a}_{:,:d/2}\mathbf{W}h_i+\mathbf{a}_{:,d/2:}\mathbf{W}h_k\right)}\\[5pt]
    & = \frac{\exp\left(\mathbf{a}_{:,:d/2}\mathbf{W}h_i\right)\cdot\exp\left(\mathbf{a}_{:,d/2:}\mathbf{W}h_j\right)}{\sum_{k\in\mathcal{N}_i} \exp\left(\mathbf{a}_{:,:d/2}\mathbf{W}h_i\right)\cdot\exp\left(\mathbf{a}_{:,d/2:}\mathbf{W}h_k\right)}\\[5pt]
    & = \frac{\exp\left(\mathbf{a}_{:,d/2:}\mathbf{W}h_j\right)}{\sum_{k\in\mathcal{N}_i} \exp\left(\mathbf{a}_{:,d/2:}\mathbf{W}h_k\right)}\\
\end{split}

We can see that without the non-linearity, the term depending on h_i cancels out, so the attention becomes independent of the node itself. Hence, we would have the same issue as with the GCN, namely creating the same output features for nodes with the same neighbors. This is why the LeakyReLU is crucial: it adds a dependency on h_i to the attention.
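The cancellation can also be checked numerically. The sketch below assumes one-dimensional messages and a hand-picked attention vector \mathbf{a}=[1, 1] (both hypothetical values). Without the LeakyReLU, two different query nodes with the same neighbors obtain identical attention weights; with the LeakyReLU, they differ. Note that the difference only appears when some logits fall on the negative side of the LeakyReLU: if all logits share the same sign, the function is linear on them and the term with h_i still cancels.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy setup: 1-dim messages, so a = [a_query, a_key]
a = torch.tensor([1.0, 1.0])
# Messages W h_k of the (shared) neighbours k
neigh_msgs = torch.tensor([-1.0, 0.0, 1.0, 2.0])


def attention_weights(query_msg, use_leaky_relu):
    # logit_k = a [W h_i || W h_k] = a_query * W h_i + a_key * W h_k
    logits = a[0] * query_msg + a[1] * neigh_msgs
    if use_leaky_relu:
        logits = F.leaky_relu(logits, negative_slope=0.2)
    # Softmax over the neighbours k
    return F.softmax(logits, dim=0)


# Messages W h_i of two different nodes that share the same neighbours
q1, q2 = torch.tensor(0.0), torch.tensor(-1.5)
print(attention_weights(q1, False), attention_weights(q2, False))  # equal up to rounding
print(attention_weights(q1, True), attention_weights(q2, True))    # different
```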

Once we obtain all attention factors, we can calculate the output features for each node by performing the weighted average:

h_i'=\sigma\left(\sum_{j\in\mathcal{N}_i}\alpha_{ij}\mathbf{W}h_j\right)

\sigma is yet another non-linearity, as in the GCN layer. Visually, we can represent the full message passing in an attention layer as follows (figure credit - Velickovic et al.):

[Figure: message passing in a multi-head graph attention layer, with the heads shown as green, blue, and purple arrows]

To increase the expressiveness of the graph attention network, Velickovic et al. proposed to extend it to multiple heads, similar to the Multi-Head Attention block in Transformers. This results in N attention layers being applied in parallel. In the image above, they are visualized as arrows in three different colors (green, blue, and purple), whose outputs are afterward concatenated. The average is only applied for the very final prediction layer in a network.
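The two ways of combining heads differ only in the final step. The sketch below uses random per-head outputs with hypothetical sizes to show the resulting shapes; the `GATLayer` implementation below performs the same reshape or mean depending on `concat_heads`.

```python
import torch

torch.manual_seed(0)
batch_size, num_nodes, num_heads, c_head = 1, 4, 3, 5
# Hypothetical per-head outputs of a graph attention layer: [batch, nodes, heads, features per head]
head_outputs = torch.randn(batch_size, num_nodes, num_heads, c_head)

# Hidden layers: concatenate the heads along the feature dimension
concat = head_outputs.reshape(batch_size, num_nodes, num_heads * c_head)
# Final prediction layer: average over the heads instead
averaged = head_outputs.mean(dim=2)

print(concat.shape)    # torch.Size([1, 4, 15])
print(averaged.shape)  # torch.Size([1, 4, 5])
```

Concatenation grows the feature dimension with the number of heads, which is why the `GATLayer` below divides `c_out` by `num_heads` when `concat_heads=True`.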

After having discussed the graph attention layer in detail, we can implement it below:

[7]:
class GATLayer(nn.Module):
    def __init__(self, c_in, c_out, num_heads=1, concat_heads=True, alpha=0.2):
        """
        Args:
            c_in: Dimensionality of input features
            c_out: Dimensionality of output features
            num_heads: Number of heads, i.e. attention mechanisms to apply in parallel. The
                        output features are equally split up over the heads if concat_heads=True.
            concat_heads: If True, the output of the different heads is concatenated instead of averaged.
            alpha: Negative slope of the LeakyReLU activation.
        """
        super().__init__()
        self.num_heads = num_heads
        self.concat_heads = concat_heads
        if self.concat_heads:
            assert c_out % num_heads == 0, "Number of output features must be a multiple of the count of heads."
            c_out = c_out // num_heads

        # Sub-modules and parameters needed in the layer
        self.projection = nn.Linear(c_in, c_out * num_heads)
        self.a = nn.Parameter(Tensor(num_heads, 2 * c_out))  # One per head
        self.leakyrelu = nn.LeakyReLU(alpha)

        # Initialization from the original implementation
        nn.init.xavier_uniform_(self.projection.weight.data, gain=1.414)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, node_feats, adj_matrix, print_attn_probs=False):
        """
        Args:
            node_feats: Input features of the node. Shape: [batch_size, num_nodes, c_in]
            adj_matrix: Adjacency matrix including self-connections. Shape: [batch_size, num_nodes, num_nodes]
            print_attn_probs: If True, the attention weights are printed during the forward pass
                               (for debugging purposes)
        """
        batch_size, num_nodes = node_feats.size(0), node_feats.size(1)

        # Apply linear layer and sort nodes by head
        node_feats = self.projection(node_feats)
        node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)

        # We need to calculate the attention logits for every edge in the adjacency matrix
        # Doing this on all possible combinations of nodes is very expensive
        # => Create a tensor of [W*h_i||W*h_j] with i and j being the indices of all edges
        # Returns indices where the adjacency matrix is not 0 => edges
        edges = adj_matrix.nonzero(as_tuple=False)
        node_feats_flat = node_feats.view(batch_size * num_nodes, self.num_heads, -1)
        edge_indices_row = edges[:, 0] * num_nodes + edges[:, 1]
        edge_indices_col = edges[:, 0] * num_nodes + edges[:, 2]
        a_input = torch.cat(
            [
                torch.index_select(input=node_feats_flat, index=edge_indices_row, dim=0),
                torch.index_select(input=node_feats_flat, index=edge_indices_col, dim=0),
            ],
            dim=-1,
        )  # Index select returns a tensor with node_feats_flat being indexed at the desired positions

        # Calculate attention MLP output (independent for each head)
        attn_logits = torch.einsum("bhc,hc->bh", a_input, self.a)
        attn_logits = self.leakyrelu(attn_logits)

        # Map list of attention values back into a matrix
        attn_matrix = attn_logits.new_zeros(adj_matrix.shape + (self.num_heads,)).fill_(-9e15)
        attn_matrix[adj_matrix[..., None].repeat(1, 1, 1, self.num_heads) == 1] = attn_logits.reshape(-1)

        # Weighted average of attention
        attn_probs = F.softmax(attn_matrix, dim=2)
        if print_attn_probs:
            print("Attention probs\n", attn_probs.permute(0, 3, 1, 2))
        node_feats = torch.einsum("bijh,bjhc->bihc", attn_probs, node_feats)

        # If heads should be concatenated, we can do this by reshaping. Otherwise, take mean
        if self.concat_heads:
            node_feats = node_feats.reshape(batch_size, num_nodes, -1)
        else:
            node_feats = node_feats.mean(dim=2)

        return node_feats

Again, we can apply the graph attention layer on our example graph above to understand the dynamics better. As before, the linear projection is initialized as an identity matrix, but we set \mathbf{a} to arbitrary values to obtain different attention values. We use two heads to show the parallel, independent attention mechanisms working in the layer.

[8]:
layer = GATLayer(2, 2, num_heads=2)
layer.projection.weight.data = Tensor([[1.0, 0.0], [0.0, 1.0]])
layer.projection.bias.data = Tensor([0.0, 0.0])
layer.a.data = Tensor([[-0.2, 0.3], [0.1, -0.1]])

with torch.no_grad():
    out_feats = layer(node_feats, adj_matrix, print_attn_probs=True)

print("Adjacency matrix", adj_matrix)
print("Input features", node_feats)
print("Output features", out_feats)
Attention probs
 tensor([[[[0.3543, 0.6457, 0.0000, 0.0000],
          [0.1096, 0.1450, 0.2642, 0.4813],
          [0.0000, 0.1858, 0.2885, 0.5257],
          [0.0000, 0.2391, 0.2696, 0.4913]],

         [[0.5100, 0.4900, 0.0000, 0.0000],
          [0.2975, 0.2436, 0.2340, 0.2249],
          [0.0000, 0.3838, 0.3142, 0.3019],
          [0.0000, 0.4018, 0.3289, 0.2693]]]])
Adjacency matrix tensor([[[1., 1., 0., 0.],
         [1., 1., 1., 1.],
         [0., 1., 1., 1.],
         [0., 1., 1., 1.]]])
Input features tensor([[[0., 1.],
         [2., 3.],
         [4., 5.],
         [6., 7.]]])
Output features tensor([[[1.2913, 1.9800],
         [4.2344, 3.7725],
         [4.6798, 4.8362],
         [4.5043, 4.7351]]])

We recommend that you try to calculate the attention matrix for at least one head and one node yourself. The entries are 0 wherever no edge exists between i and j. For the others, we see a diverse set of attention probabilities. Moreover, the output features of nodes 3 and 4 are now different, although they have the same neighbors.
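To check the attention values above by hand, here is a minimal plain-Python sketch (no PyTorch needed) that recomputes the attention probabilities and the output feature of the first node for one head. It assumes the layer's LeakyReLU slope is 0.2 and that each of the two heads sees one feature (two output features split over two heads); under these assumptions it reproduces the first rows of the printed matrices. The helper names `leaky_relu`, `softmax`, and `attention_for_node` are illustrative and not part of the tutorial's code.

```python
import math


def leaky_relu(value, alpha=0.2):
    # LeakyReLU: identity for positive inputs, slope alpha otherwise (alpha=0.2 assumed)
    return value if value > 0 else alpha * value


def softmax(logits):
    # Numerically stable softmax over a list of floats
    largest = max(logits)
    exps = [math.exp(v - largest) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Values from the cell above: identity projection, two heads with one feature each,
# so head 0 sees feature 0 of every node and head 1 sees feature 1
node_feats = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]
a = [[-0.2, 0.3], [0.1, -0.1]]  # per head: [weight for node i, weight for neighbor j]
adj = [[1, 1, 0, 0], [1, 1, 1, 1], [0, 1, 1, 1], [0, 1, 1, 1]]


def attention_for_node(i, head):
    neighbors = [j for j, connected in enumerate(adj[i]) if connected]
    h = [feats[head] for feats in node_feats]
    # Attention logit e_ij = LeakyReLU(a^T [h_i || h_j]) for every neighbor j
    logits = [leaky_relu(a[head][0] * h[i] + a[head][1] * h[j]) for j in neighbors]
    probs = softmax(logits)
    # Output feature: attention-weighted sum of the neighbors' features
    out = sum(p * h[j] for p, j in zip(probs, neighbors))
    return neighbors, probs, out


neighbors, probs, out = attention_for_node(0, head=0)
print(neighbors, [round(p, 4) for p in probs], round(out, 4))  # prints [0, 1] [0.3543, 0.6457] 1.2913
```

Calling `attention_for_node(0, head=1)` gives the second head's row (about 0.5100 and 0.4900) and its output feature of about 1.9800.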

PyTorch Geometric

We mentioned before that implementing graph networks with an adjacency matrix is simple and straightforward but can be computationally expensive for large graphs. Many real-world graphs have over 200k nodes, for which a dense adjacency matrix (N × N entries) quickly becomes too large to store and process. There are many possible optimizations when implementing GNNs, and luckily, there exist packages that provide such layers. The most popular packages for PyTorch are PyTorch Geometric and the Deep Graph Library (the latter actually being framework-agnostic). Which one to use depends on your project and personal taste. In this tutorial, we will look at PyTorch Geometric as part of the PyTorch family.

PyTorch Geometric provides a set of common graph layers, including the GCN and GAT layers we implemented above. Additionally, similar to PyTorch’s torchvision, it provides common graph datasets and transformations on them to simplify training. Compared to our implementation above, PyTorch Geometric uses a list of index pairs to represent the edges. The details of this library will be explored further in our experiments.
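As an illustration of the edge-list representation, the following sketch converts the dense adjacency matrix of the example graph into two parallel lists of source and target indices, i.e. the 2 × E layout that PyTorch Geometric calls `edge_index`. The helper `dense_to_edge_index` is hypothetical; for a 2D PyTorch tensor, `adj.nonzero().t()` yields the same ordering.

```python
# Dense adjacency matrix of the example graph used above (self-connections included)
adj = [
    [1, 1, 0, 0],
    [1, 1, 1, 1],
    [0, 1, 1, 1],
    [0, 1, 1, 1],
]


def dense_to_edge_index(adj):
    """Return the edges as two parallel lists [sources, targets] (a 2 x E layout)."""
    sources, targets = [], []
    for i, row in enumerate(adj):
        for j, value in enumerate(row):
            if value != 0:  # keep only entries that represent an edge
                sources.append(i)
                targets.append(j)
    return [sources, targets]


edge_index = dense_to_edge_index(adj)
print(edge_index)
# Storage now grows with the number of edges E instead of N * N matrix entries
print(len(edge_index[0]), "edges vs", len(adj) * len(adj), "matrix entries")
```

The saving is small for four nodes, but for a graph like Cora (loaded below), a dense matrix would need 2708 × 2708 ≈ 7.3 million entries, whereas its edge list holds 10556 edges.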

In our tasks below, we want to be able to pick from a multitude of graph layers. Thus, we again define a dictionary to access them by name:

[9]:
gnn_layer_by_name = {"GCN": geom_nn.GCNConv, "GAT": geom_nn.GATConv, "GraphConv": geom_nn.GraphConv}

In addition to GCN and GAT, we added the layer geom_nn.GraphConv (documentation). GraphConv is a GCN with a separate weight matrix for the self-connections. Mathematically, this is:

\mathbf{x}_i^{(l+1)} = \mathbf{W}^{(l + 1)}_1 \mathbf{x}_i^{(l)} + \mathbf{W}^{(l + 1)}_2 \sum_{j \in \mathcal{N}_i} \mathbf{x}_j^{(l)}

In this formula, the neighbors’ messages are summed instead of averaged. However, PyTorch Geometric provides the argument aggr to switch between summing, averaging, and max pooling.
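The update rule above and the effect of the `aggr` argument can be sketched in plain Python with scalar node features. The weights `w_self` and `w_neigh` stand in for the weight matrices W_1 and W_2 and are arbitrary values chosen for illustration. The neighbor lists come from the example graph used earlier, without self-connections. This is a simplified model of the formula, not PyTorch Geometric's implementation, which operates on feature vectors with learned weight matrices.

```python
def graph_conv_update(x, neighbors, w_self, w_neigh, aggr="add"):
    """One GraphConv-style update with scalar features and scalar weights.

    x: list of node features (one float per node)
    neighbors: dict mapping node -> list of neighbor indices (self excluded)
    w_self, w_neigh: hypothetical stand-ins for W_1 and W_2
    aggr: "add", "mean" or "max", mirroring the aggregation choices described above
    """
    out = []
    for i, x_i in enumerate(x):
        msgs = [x[j] for j in neighbors[i]]
        if not msgs:
            agg = 0.0  # isolated node: no incoming messages
        elif aggr == "add":
            agg = sum(msgs)
        elif aggr == "mean":
            agg = sum(msgs) / len(msgs)
        elif aggr == "max":
            agg = max(msgs)
        else:
            raise ValueError(f"Unknown aggregation: {aggr}")
        # Separate weight for the node itself and for the aggregated neighbor messages
        out.append(w_self * x_i + w_neigh * agg)
    return out


x = [0.0, 2.0, 4.0, 6.0]
neighbors = {0: [1], 1: [0, 2, 3], 2: [1, 3], 3: [1, 2]}
for aggr in ("add", "mean", "max"):
    print(aggr, graph_conv_update(x, neighbors, w_self=1.0, w_neigh=0.5, aggr=aggr))
```

With summing, high-degree nodes such as node 1 receive larger updates (7.0 here versus about 3.67 with averaging), which is one reason the choice of aggregation matters.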

Experiments on graph structures

Tasks on graph-structured data can be grouped into three categories: node-level, edge-level, and graph-level. The level describes what we want to classify or regress: individual nodes, edges, or entire graphs. We will discuss all three types in more detail below.

Node-level tasks: Semi-supervised node classification

Node-level tasks aim to classify nodes in a graph. Usually, we are given a single, large graph with more than 1000 nodes, of which only a subset is labeled. We learn to classify those labeled examples during training and try to generalize to the unlabeled nodes.

A popular example that we will use in this tutorial is the Cora dataset, a citation network among papers. Cora consists of 2708 scientific publications, with links between them representing the citation of one paper by another. The task is to classify each publication into one of seven classes. Each publication is represented by a binary bag-of-words vector. This means that we have a vector of 1433 elements for each publication, where a 1 at feature i indicates that the i-th word of a predefined dictionary occurs in the article. Binary bag-of-words representations are commonly used when we need very simple encodings and already have an intuition of what words to expect in a network. There exist much better approaches, but we will leave those for the NLP courses to discuss.
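To make the binary bag-of-words encoding concrete, here is a small sketch with a hypothetical six-word dictionary (Cora's real dictionary has 1433 entries). The helper `binary_bag_of_words` is illustrative. Words that occur several times still produce a single 1, since the encoding only records presence.

```python
def binary_bag_of_words(text, vocabulary):
    """Return a 0/1 vector where position i is 1 if vocabulary[i] occurs in text."""
    words = set(text.lower().split())  # presence only, so counts are discarded
    return [1 if word in words else 0 for word in vocabulary]


# Hypothetical dictionary; the real one is fixed in advance for the whole dataset
vocabulary = ["graph", "neural", "network", "protein", "kernel", "bayesian"]
doc = "A graph neural network for citation graph classification"
print(binary_bag_of_words(doc, vocabulary))  # prints [1, 1, 1, 0, 0, 0]
```

Words outside the dictionary ("citation", "classification") are simply ignored, which is one of the limitations of this simple encoding.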

We will load the dataset below:

[10]:
cora_dataset = torch_geometric.datasets.Planetoid(root=DATASET_PATH, name="Cora")
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!

Let’s look at how PyTorch Geometric represents the graph data. Note that although we have a single graph, PyTorch Geometric returns a dataset for compatibility with other datasets.

[11]:
cora_dataset[0]
[11]:
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

The graph is represented by a Data object (documentation), whose attributes we can access like those of a standard Python namespace. The edge_index tensor is the list of edges in the graph and, for undirected graphs, contains each edge in both directions (so Cora’s 10556 entries correspond to 5278 undirected edges). The train_mask, val_mask, and test_mask are boolean masks that indicate which nodes we should use for training, validation, and testing. The x tensor is the feature tensor of our 2708 publications, and y holds the labels of all nodes.
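The masks restrict evaluation to a subset of nodes. As a plain-Python sketch with toy values, accuracy over a mask can be computed as follows. The Lightning module defined later does the same with tensor indexing (`x[mask]`). The helper `masked_accuracy` and the data are hypothetical.

```python
def masked_accuracy(predictions, labels, mask):
    """Accuracy computed only over positions where mask is True."""
    selected = [(p, y) for p, y, m in zip(predictions, labels, mask) if m]
    if not selected:
        raise ValueError("mask selects no nodes")
    correct = sum(1 for p, y in selected if p == y)
    return correct / len(selected)


# Toy example with 6 nodes; in Cora each mask has length 2708
predictions = [0, 1, 2, 1, 0, 2]
labels = [0, 1, 2, 1, 2, 0]
train_mask = [True, True, True, False, False, False]
test_mask = [False, False, False, True, True, True]
print(masked_accuracy(predictions, labels, train_mask))  # prints 1.0
print(masked_accuracy(predictions, labels, test_mask))
```

Because the same graph and predictions are shared by all splits, only the mask decides which nodes contribute to each metric.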

After having seen the data, we can implement a simple graph neural network. The GNN applies a sequence of graph layers (GCN, GAT, or GraphConv), ReLU as the activation function, and dropout for regularization. See below for the specific implementation.

[12]:
class GNNModel(nn.Module):
    def __init__(
        self,
        c_in,
        c_hidden,
        c_out,
        num_layers=2,
        layer_name="GCN",
        dp_rate=0.1,
        **kwargs,
    ):
        """
        Args:
            c_in: Dimension of input features
            c_hidden: Dimension of hidden features
            c_out: Dimension of the output features. Usually number of classes in classification
            num_layers: Number of "hidden" graph layers
            layer_name: String of the graph layer to use
            dp_rate: Dropout rate to apply throughout the network
            kwargs: Additional arguments for the graph layer (e.g. number of heads for GAT)
        """
        super().__init__()
        gnn_layer = gnn_layer_by_name[layer_name]

        layers = []
        in_channels, out_channels = c_in, c_hidden
        for l_idx in range(num_layers - 1):
            layers += [
                gnn_layer(in_channels=in_channels, out_channels=out_channels, **kwargs),
                nn.ReLU(inplace=True),
                nn.Dropout(dp_rate),
            ]
            in_channels = c_hidden
        layers += [gnn_layer(in_channels=in_channels, out_channels=c_out, **kwargs)]
        self.layers = nn.ModuleList(layers)

    def forward(self, x, edge_index):
        """
        Args:
            x: Input features per node
            edge_index: List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
        """
        for layer in self.layers:
            # For graph layers, we need to add the "edge_index" tensor as additional input
            # All PyTorch Geometric graph layer inherit the class "MessagePassing", hence
            # we can simply check the class type.
            if isinstance(layer, geom_nn.MessagePassing):
                x = layer(x, edge_index)
            else:
                x = layer(x)
        return x

A good practice in node-level tasks is to create an MLP baseline that is applied to each node independently. This way, we can verify whether adding the graph information to the model actually improves the prediction. It might also be that the features of each node are already expressive enough to clearly point towards a specific class. To check this, we implement a simple MLP below.

[13]:
class MLPModel(nn.Module):
    def __init__(self, c_in, c_hidden, c_out, num_layers=2, dp_rate=0.1):
        """
        Args:
            c_in: Dimension of input features
            c_hidden: Dimension of hidden features
            c_out: Dimension of the output features. Usually number of classes in classification
            num_layers: Number of hidden layers
            dp_rate: Dropout rate to apply throughout the network
        """
        super().__init__()
        layers = []
        in_channels, out_channels = c_in, c_hidden
        for l_idx in range(num_layers - 1):
            layers += [nn.Linear(in_channels, out_channels), nn.ReLU(inplace=True), nn.Dropout(dp_rate)]
            in_channels = c_hidden
        layers += [nn.Linear(in_channels, c_out)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x, *args, **kwargs):
        """
        Args:
            x: Input features per node
        """
        return self.layers(x)

Finally, we can merge the models into a PyTorch Lightning module which handles the training, validation, and testing for us.

[14]:
class NodeLevelGNN(pl.LightningModule):
    def __init__(self, model_name, **model_kwargs):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()

        if model_name == "MLP":
            self.model = MLPModel(**model_kwargs)
        else:
            self.model = GNNModel(**model_kwargs)
        self.loss_module = nn.CrossEntropyLoss()

    def forward(self, data, mode="train"):
        x, edge_index = data.x, data.edge_index
        x = self.model(x, edge_index)

        # Only calculate the loss on the nodes corresponding to the mask
        if mode == "train":
            mask = data.train_mask
        elif mode == "val":
            mask = data.val_mask
        elif mode == "test":
            mask = data.test_mask
        else:
            assert False, "Unknown forward mode: %s" % mode

        loss = self.loss_module(x[mask], data.y[mask])
        acc = (x[mask].argmax(dim=-1) == data.y[mask]).sum().float() / mask.sum()
        return loss, acc

    def configure_optimizers(self):
        # We use SGD here, but Adam works as well
        optimizer = optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=2e-3)
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, acc = self.forward(batch, mode="train")
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode="val")
        self.log("val_acc", acc)

    def test_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode="test")
        self.log("test_acc", acc)

In addition to the Lightning module, we define a training function below. As we have a single graph, we use a batch size of 1 for the data loader and share the same data loader for the training, validation, and test sets (the mask is picked inside the Lightning module). Besides, we set the argument progress_bar_refresh_rate to zero, as the progress bar usually shows the progress per epoch, but an epoch consists of only a single step. This argument is deprecated since Lightning v1.5; as the warning in the output below suggests, you can pass enable_progress_bar=False to the Trainer instead. If you have downloaded the pre-trained models at the beginning of the tutorial, we load those instead of training from scratch. Finally, we test the model and return the results.

[15]:
def train_node_classifier(model_name, dataset, **model_kwargs):
    pl.seed_everything(42)
    node_data_loader = geom_data.DataLoader(dataset, batch_size=1)

    # Create a PyTorch Lightning trainer
    root_dir = os.path.join(CHECKPOINT_PATH, "NodeLevel" + model_name)
    os.makedirs(root_dir, exist_ok=True)
    trainer = pl.Trainer(
        default_root_dir=root_dir,
        callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")],
        gpus=AVAIL_GPUS,
        max_epochs=200,
        progress_bar_refresh_rate=0,
    )  # 0 because epoch size is 1
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "NodeLevel%s.ckpt" % model_name)
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        model = NodeLevelGNN.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)
        model = NodeLevelGNN(
            model_name=model_name, c_in=dataset.num_node_features, c_out=dataset.num_classes, **model_kwargs
        )
        trainer.fit(model, node_data_loader, node_data_loader)
        model = NodeLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Test best model on the test set
    test_result = trainer.test(model, dataloaders=node_data_loader, verbose=False)
    batch = next(iter(node_data_loader))
    batch = batch.to(model.device)
    _, train_acc = model.forward(batch, mode="train")
    _, val_acc = model.forward(batch, mode="val")
    result = {"train": train_acc, "val": val_acc, "test": test_result[0]["test_acc"]}
    return model, result

Now, we can train our models. First, let’s train the simple MLP:

[16]:
# Small function for printing the test scores
def print_results(result_dict):
    if "train" in result_dict:
        print("Train accuracy: %4.2f%%" % (100.0 * result_dict["train"]))
    if "val" in result_dict:
        print("Val accuracy:   %4.2f%%" % (100.0 * result_dict["val"]))
    print("Test accuracy:  %4.2f%%" % (100.0 * result_dict["test"]))
[17]:
node_mlp_model, node_mlp_result = train_node_classifier(
    model_name="MLP", dataset=cora_dataset, c_hidden=16, num_layers=2, dp_rate=0.1
)

print_results(node_mlp_result)
Global seed set to 42
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead
  warnings.warn(out)
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:96: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=0)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/GNNs/NodeLevelMLP/lightning_logs
Found pretrained model, loading...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Train accuracy: 97.14%
Val accuracy:   54.60%
Test accuracy:  60.60%
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:72: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2708. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
  warning_cache.warn(

Although the MLP can overfit on the training dataset because of the high-dimensional input features, it does not perform too well on the test set. Let’s see if we can beat this score with our graph networks:

[18]:
node_gnn_model, node_gnn_result = train_node_classifier(
    model_name="GNN", layer_name="GCN", dataset=cora_dataset, c_hidden=16, num_layers=2, dp_rate=0.1
)
print_results(node_gnn_result)
Global seed set to 42
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/GNNs/NodeLevelGNN/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model, loading...
Train accuracy: 100.00%
Val accuracy:   78.60%
Test accuracy:  82.40%

As we had hoped, the GNN model outperforms the MLP by quite a margin. This shows that using the graph information indeed improves our predictions and lets us generalize better.

The hyperparameters in the model have been chosen to create a relatively small network. This is because the first layer, with an input dimension of 1433, can be expensive to compute for large graphs. In general, GNNs can become expensive for very big graphs. This is why such GNNs either have a small hidden size or use a special batching strategy where we sample a connected subgraph of the big, original graph.
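The subgraph-sampling idea mentioned above can be sketched as a breadth-first search that collects all nodes within `num_hops` of some target nodes and keeps the edges among them. This is a hypothetical, simplified helper that takes every neighbor rather than sampling a fixed number per hop. PyTorch Geometric provides its own utilities for this, such as `torch_geometric.utils.k_hop_subgraph`.

```python
from collections import deque


def k_hop_subgraph(start_nodes, edge_index, num_hops):
    """Collect all nodes within num_hops of start_nodes and the edges among them.

    edge_index: [sources, targets] lists, as in PyTorch Geometric's 2 x E layout.
    """
    adjacency = {}
    for src, dst in zip(edge_index[0], edge_index[1]):
        adjacency.setdefault(src, []).append(dst)

    # Breadth-first search that stops expanding once num_hops is reached
    visited = set(start_nodes)
    frontier = deque((node, 0) for node in start_nodes)
    while frontier:
        node, depth = frontier.popleft()
        if depth == num_hops:
            continue
        for neighbor in adjacency.get(node, []):
            if neighbor not in visited:
                visited.add(neighbor)
                frontier.append((neighbor, depth + 1))

    # Keep only edges whose endpoints both lie inside the sampled node set
    sub_edges = [(s, t) for s, t in zip(edge_index[0], edge_index[1]) if s in visited and t in visited]
    return sorted(visited), sub_edges


# A path graph 0-1-2-3-4 with mirrored edges (undirected)
edge_index = [[0, 1, 1, 2, 2, 3, 3, 4], [1, 0, 2, 1, 3, 2, 4, 3]]
nodes, edges = k_hop_subgraph([0], edge_index, num_hops=2)
print(nodes, edges)
```

A two-layer GNN only needs the 2-hop neighborhood of a node to compute that node's output, so training on such subgraphs keeps memory bounded by the neighborhood size rather than the full graph.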

Graph-level tasks: Graph classification

Finally, in this part of the tutorial, we will take a closer look at how to apply GNNs to the task of graph classification. The goal is to classify an entire graph instead of single nodes or edges. Therefore, we are given a dataset of multiple graphs that we need to classify based on some structural graph properties. The most common task for graph classification is molecular property prediction, in which molecules are represented as graphs. Each atom corresponds to a node, and the edges in the graph are the bonds between atoms. For example, look at the figure below.

[Figure: a small molecule (left) and its graph representation (right)]

On the left, we have an arbitrary, small molecule with different atoms, whereas the right part of the image shows the graph representation. The atom types are abstracted as node features (e.g. a one-hot vector), and the different bond types are used as edge features. For simplicity, we will neglect the edge attributes in this tutorial, but you can include them by using methods like the Relational Graph Convolution, which uses a different weight matrix for each edge type.

The dataset we will use below is called the MUTAG dataset. It is a common small benchmark for graph classification algorithms and contains 188 graphs with 18 nodes and 20 edges on average per graph. The graph nodes have 7 different labels/atom types, and the binary graph labels represent “their mutagenic effect on a specific gram negative bacterium” (the specific meaning of the labels is not too important here). The dataset is part of a large collection of different graph classification datasets, known as the TUDatasets, which is directly accessible via torch_geometric.datasets.TUDataset (documentation) in PyTorch Geometric. We can load the dataset below.

[19]:
tu_dataset = torch_geometric.datasets.TUDataset(root=DATASET_PATH, name="MUTAG")
Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip
Extracting /__w/1/s/.datasets/MUTAG/MUTAG.zip
Processing...
Done!

Let’s look at some statistics for the dataset:

[20]:
print("Data object:", tu_dataset.data)
print("Length:", len(tu_dataset))
print("Average label: %4.2f" % (tu_dataset.data.y.float().mean().item()))
Data object: Data(x=[3371, 7], edge_index=[2, 7442], edge_attr=[7442, 4], y=[188])
Length: 188
Average label: 0.66

The first line shows how the dataset stores different graphs. The nodes, edges, and labels of all graphs are each concatenated into one tensor, and the dataset stores the indices at which to split the tensors correspondingly. The length of the dataset is the number of graphs we have, and the “average label” denotes the fraction of graphs with label 1. As long as this fraction is close to 0.5, we have a relatively balanced dataset. Graph datasets are quite often very imbalanced, hence checking the class balance is always a good idea.
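To illustrate how concatenated storage works, this sketch recovers per-graph chunks from a single list using boundary indices, which is the idea behind the split indices mentioned above. It also recomputes the “average label” as a class-balance check. The helper `split_by_slices` and the toy data are hypothetical; PyTorch Geometric's internal bookkeeping differs in detail.

```python
def split_by_slices(concatenated, slices):
    """Recover per-graph chunks from one concatenated list using boundary indices.

    slices has length num_graphs + 1; graph g owns concatenated[slices[g]:slices[g + 1]].
    """
    return [concatenated[slices[g]:slices[g + 1]] for g in range(len(slices) - 1)]


# Three toy graphs with 2, 3 and 1 nodes, stored back to back (node "features" are just names)
node_features = ["a0", "a1", "b0", "b1", "b2", "c0"]
node_slices = [0, 2, 5, 6]
print(split_by_slices(node_features, node_slices))

# Class balance check: the "average label" is the fraction of graphs with label 1
labels = [1, 0, 1]
print(sum(labels) / len(labels))
```

Storing one flat tensor plus boundaries avoids padding graphs of different sizes to a common shape.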

Next, we will split our dataset into a training and a test part. Note that we do not use a separate validation set this time because of the small size of the dataset; instead, the test set also serves for validation. Therefore, our model might overfit slightly to the test set due to the noise of the evaluation, but we still get an estimate of the performance on unseen data.

[21]:
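torch.manual_seed(42)
# Dataset.shuffle() returns a shuffled copy instead of shuffling in place, so keep its result
shuffled_dataset = tu_dataset.shuffle()
train_dataset = shuffled_dataset[:150]
test_dataset = shuffled_dataset[150:]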

When using a data loader, we encounter a problem with batching N graphs. Each graph in the batch can have a different number of nodes and edges, and hence we would require a lot of padding to obtain a single tensor. PyTorch Geometric uses a different, more efficient approach: we can view the N graphs in a batch as a single large graph with concatenated node and edge lists. As there is no edge between the N graphs, running GNN layers on the large graph gives us the same output as running the GNN on each graph separately. This batching strategy is visualized below (figure credit - PyTorch Geometric team, tutorial here).

[Figure: N graphs batched into one large graph with a block-diagonal adjacency matrix (credit: PyTorch Geometric team)]

The adjacency matrix is zero for any pair of nodes that come from two different graphs, and otherwise follows the adjacency matrix of the individual graph. Luckily, this strategy is already implemented in PyTorch Geometric, and hence we can use the corresponding data loader:

[22]:
graph_train_loader = geom_data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
graph_val_loader = geom_data.DataLoader(test_dataset, batch_size=BATCH_SIZE)  # Reuses the test set; use a separate split for larger datasets
graph_test_loader = geom_data.DataLoader(test_dataset, batch_size=BATCH_SIZE)
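The batching strategy described above can be sketched in plain Python: shift each graph's node indices by the number of nodes in the graphs before it, concatenate the edge lists, and record which graph every node belongs to. The helper `collate_graphs` and its dict-based graph format are hypothetical; PyTorch Geometric's `DataLoader` performs this merging for us and stores the results in the `batch` and `ptr` fields.

```python
def collate_graphs(graphs):
    """Merge several small graphs into one disconnected graph.

    Each graph is a dict with "num_nodes" and "edge_index" ([sources, targets]).
    Returns the merged edge_index, a batch vector (graph id per node) and ptr offsets.
    """
    sources, targets, batch, ptr = [], [], [], [0]
    offset = 0
    for graph_id, graph in enumerate(graphs):
        # Shift node ids so they do not collide with nodes of earlier graphs
        sources += [s + offset for s in graph["edge_index"][0]]
        targets += [t + offset for t in graph["edge_index"][1]]
        batch += [graph_id] * graph["num_nodes"]
        offset += graph["num_nodes"]
        ptr.append(offset)  # start index of the next graph's nodes
    return [sources, targets], batch, ptr


g1 = {"num_nodes": 2, "edge_index": [[0, 1], [1, 0]]}
g2 = {"num_nodes": 3, "edge_index": [[0, 1, 1, 2], [1, 0, 2, 1]]}
edge_index, batch, ptr = collate_graphs([g1, g2])
print(edge_index)  # prints [[0, 1, 2, 3, 3, 4], [1, 0, 3, 2, 4, 3]]
print(batch)  # prints [0, 0, 1, 1, 1]
print(ptr)  # prints [0, 2, 5]
```

Since `ptr` holds one boundary per graph plus a leading 0, a batch of N graphs has a `ptr` of length N + 1, which matches `ptr=[39]` for the 38 graphs printed below.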

Let’s load a batch below to see the batching in action:

[23]:
batch = next(iter(graph_test_loader))
print("Batch:", batch)
print("Labels:", batch.y[:10])
print("Batch indices:", batch.batch[:40])
Batch: DataBatch(edge_index=[2, 1512], x=[687, 7], edge_attr=[1512, 4], y=[38], batch=[687], ptr=[39])
Labels: tensor([1, 1, 1, 0, 0, 0, 1, 1, 1, 0])
Batch indices: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2])

We have 38 graphs stacked together for the test dataset. The batch indices, stored in batch, show that the first 12 nodes belong to the first graph, the next 22 to the second graph, and so on. These indices are important for performing the final prediction. To perform a prediction over a whole graph, we usually perform a pooling operation over all nodes after running the GNN model. In this case, we will use average pooling. Hence, we need to know which nodes should be included in which average pool. Using this pooling, we can already create our graph network below. Specifically, we reuse our class GNNModel from before and simply add an average pooling step and a single linear layer for the graph prediction task.

[24]:
class GraphGNNModel(nn.Module):
    def __init__(self, c_in, c_hidden, c_out, dp_rate_linear=0.5, **kwargs):
        """
        Args:
            c_in: Dimension of input features
            c_hidden: Dimension of hidden features
            c_out: Dimension of output features (usually number of classes)
            dp_rate_linear: Dropout rate before the linear layer (usually much higher than inside the GNN)
            kwargs: Additional arguments for the GNNModel object
        """
        super().__init__()
        self.GNN = GNNModel(c_in=c_in, c_hidden=c_hidden, c_out=c_hidden, **kwargs)  # Not our prediction output yet!
        self.head = nn.Sequential(nn.Dropout(dp_rate_linear), nn.Linear(c_hidden, c_out))

    def forward(self, x, edge_index, batch_idx):
        """
        Args:
            x: Input features per node
            edge_index: List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
            batch_idx: Index of batch element for each node
        """
        x = self.GNN(x, edge_index)
        x = geom_nn.global_mean_pool(x, batch_idx)  # Average pooling
        x = self.head(x)
        return x
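Using the batch vector, average pooling per graph amounts to summing node values per graph id and dividing by the node counts. The sketch below does this for scalar node values; `geom_nn.global_mean_pool` applies the same idea to every feature dimension. The function `mean_pool_per_graph` is a hypothetical, simplified stand-in and assumes every graph in the batch has at least one node.

```python
def mean_pool_per_graph(node_values, batch):
    """Average scalar node values per graph; batch[k] is the graph id of node k."""
    num_graphs = max(batch) + 1
    sums = [0.0] * num_graphs
    counts = [0] * num_graphs
    for value, graph_id in zip(node_values, batch):
        sums[graph_id] += value
        counts[graph_id] += 1
    # Assumes every graph id occurs at least once (no empty graphs)
    return [s / c for s, c in zip(sums, counts)]


node_values = [1.0, 3.0, 2.0, 4.0, 6.0]  # one scalar per node after the GNN
batch = [0, 0, 1, 1, 1]  # first 2 nodes -> graph 0, next 3 -> graph 1
print(mean_pool_per_graph(node_values, batch))  # prints [2.0, 4.0]
```

Averaging (rather than summing) makes the pooled representation independent of graph size, which suits datasets like MUTAG where graphs have different numbers of atoms.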

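Finally, we can create a PyTorch Lightning module to handle the training. It is similar to the modules we have seen before and does nothing surprising in terms of training. As we have a binary classification task, we use the binary cross-entropy loss on logits (nn.BCEWithLogitsLoss) when the model has a single output, and fall back to the cross-entropy loss for more than two classes.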

[25]:
class GraphLevelGNN(pl.LightningModule):
    def __init__(self, **model_kwargs):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()

        self.model = GraphGNNModel(**model_kwargs)
        self.loss_module = nn.BCEWithLogitsLoss() if self.hparams.c_out == 1 else nn.CrossEntropyLoss()

    def forward(self, data, mode="train"):
        x, edge_index, batch_idx = data.x, data.edge_index, data.batch
        x = self.model(x, edge_index, batch_idx)
        x = x.squeeze(dim=-1)

        if self.hparams.c_out == 1:
            preds = (x > 0).float()
            data.y = data.y.float()
        else:
            preds = x.argmax(dim=-1)
        loss = self.loss_module(x, data.y)
        acc = (preds == data.y).sum().float() / preds.shape[0]
        return loss, acc

    def configure_optimizers(self):
        # High lr because of small dataset and small model
        optimizer = optim.AdamW(self.parameters(), lr=1e-2, weight_decay=0.0)
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, acc = self.forward(batch, mode="train")
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode="val")
        self.log("val_acc", acc)

    def test_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode="test")
        self.log("test_acc", acc)
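The module above thresholds raw logits at 0 instead of applying a sigmoid, because sigmoid(x) > 0.5 exactly when x > 0. The sketch below checks this and compares a naive binary cross-entropy on probabilities with a commonly used numerically stable formula on logits, max(x, 0) - x·y + log(1 + exp(-|x|)), which is the kind of computation `nn.BCEWithLogitsLoss` performs in one step. The helper names and the logit values are illustrative.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bce_with_logits(logit, target):
    """Binary cross-entropy computed directly from a raw logit in a numerically stable way."""
    # max(x, 0) - x * y + log(1 + exp(-|x|)) equals -[y log s(x) + (1 - y) log(1 - s(x))]
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


logits = [2.0, -0.5, 0.0]  # raw model outputs (hypothetical values)
targets = [1.0, 0.0, 1.0]  # binary labels as floats, like data.y.float()
for x, y in zip(logits, targets):
    p = sigmoid(x)
    naive = -(y * math.log(p) + (1 - y) * math.log(1 - p))
    pred = 1.0 if x > 0 else 0.0  # same rule as preds = (x > 0).float() in the module
    print(x, y, round(bce_with_logits(x, y), 4), round(naive, 4), pred, p > 0.5)
```

The two losses agree, but the logit form never evaluates log(0), which the naive version would hit once the sigmoid saturates for large |x|.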

Below we train the model on our dataset. It resembles the typical training functions we have seen so far.

[26]:
def train_graph_classifier(model_name, **model_kwargs):
    pl.seed_everything(42)

    # Create a PyTorch Lightning trainer with a checkpoint callback
    root_dir = os.path.join(CHECKPOINT_PATH, "GraphLevel" + model_name)
    os.makedirs(root_dir, exist_ok=True)
    trainer = pl.Trainer(
        default_root_dir=root_dir,
        callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")],
        gpus=AVAIL_GPUS,
        max_epochs=500,
        progress_bar_refresh_rate=0,
    )
    trainer.logger._default_hp_metric = None

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "GraphLevel%s.ckpt" % model_name)
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        model = GraphLevelGNN.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)
        model = GraphLevelGNN(
            c_in=tu_dataset.num_node_features,
            c_out=1 if tu_dataset.num_classes == 2 else tu_dataset.num_classes,
            **model_kwargs,
        )
        trainer.fit(model, graph_train_loader, graph_val_loader)
        model = GraphLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Test best model on the training and test set
    train_result = trainer.test(model, dataloaders=graph_train_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=graph_test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "train": train_result[0]["test_acc"]}
    return model, result

Finally, let’s perform the training and testing. Feel free to experiment with different GNN layers, hyperparameters, etc.

[27]:
model, result = train_graph_classifier(
    model_name="GraphConv", c_hidden=256, layer_name="GraphConv", num_layers=3, dp_rate_linear=0.5, dp_rate=0.0
)
Global seed set to 42
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/GNNs/GraphLevelGraphConv/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:487: PossibleUserWarning: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test/predict dataloaders.
  rank_zero_warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:72: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
  warning_cache.warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model, loading...
[28]:
print("Train performance: %4.2f%%" % (100.0 * result["train"]))
print("Test performance:  %4.2f%%" % (100.0 * result["test"]))
Train performance: 92.67%
Test performance:  92.11%

The test performance shows that we obtain quite good scores on an unseen part of the dataset. Note that, as we have also been using the test set for validation, we might have overfitted slightly to this set. Nevertheless, the experiment shows that GNNs can indeed be powerful for predicting the properties of graphs and/or molecules.

Conclusion

In this tutorial, we have seen the application of neural networks to graph structures. We looked at how a graph can be represented (adjacency matrix or edge list), and discussed the implementation of common graph layers: GCN and GAT. The implementations showed the practical side of the layers, which is often easier than the theory. Finally, we experimented with different tasks, on node-, edge- and graph-level. Overall, we have seen that including graph information in the predictions can be crucial for achieving high performance. There are a lot of applications that benefit from GNNs, and the importance of these networks will likely increase over the next years.

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Many thanks from the entire PyTorch Lightning team for your interest!


Tutorial 7: Deep Energy-Based Generative Models

  • Author: Phillip Lippe

  • License: CC BY-SA

  • Generated: 2021-09-16T14:32:29.871712

In this tutorial, we will look at energy-based deep learning models and focus on their application as generative models. Energy models were a popular tool before the deep learning boom around 2012. However, in recent years, energy-based models have gained increasing attention because of improved training methods and tricks. Although they are still at a research stage, they have been shown to outperform strong Generative Adversarial Networks, which had been the state of the art in image generation, in certain cases (blog post about strong energy-based models, blog post about the power of GANs). Hence, it is important to be aware of energy-based models, and as the theory can be abstract at times, we will illustrate the idea of energy-based models with many examples. This notebook is part of a lecture series on Deep Learning at the University of Amsterdam. The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.



Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
# ! pip install --quiet "torchvision" "torch>=1.6, <1.9" "tensorboard" "matplotlib" "pytorch-lightning>=1.3" "torchmetrics>=0.3"

First, let’s import our standard libraries below.

[2]:
# Standard libraries
import os
import random
import urllib.request
from urllib.error import HTTPError

# Plotting
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# PyTorch Lightning
import pytorch_lightning as pl

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

# Torchvision
import torchvision

# %matplotlib inline
from IPython.display import set_matplotlib_formats
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torchvision import transforms
from torchvision.datasets import MNIST

set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = os.environ.get("PATH_DATASETS", "data")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/tutorial8")

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
/tmp/ipykernel_1940/3480345581.py:30: DeprecationWarning: `set_matplotlib_formats` is deprecated since IPython 7.23, directly use `matplotlib_inline.backend_inline.set_matplotlib_formats()`
  set_matplotlib_formats("svg", "pdf")  # For export
Global seed set to 42

We also have pre-trained models that we download below.

[3]:
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial8/"
# Files to download
pretrained_files = ["MNIST.ckpt", "tensorboards/events.out.tfevents.MNIST"]

# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if "/" in file_name:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print("Downloading %s..." % file_url)
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the files manually,"
                " or contact the author with the full output including the following error:\n",
                e,
            )
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial8/MNIST.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial8/tensorboards/events.out.tfevents.MNIST...

Energy Models

In the first part of this tutorial, we will review the theory of energy-based models (the same theory has been discussed in Lecture 8). While most of the previous models had the goal of classification or regression, energy-based models are motivated from a different perspective: density estimation. Given a dataset with a lot of elements, we want to estimate the probability distribution over the whole data space. As an example, if we model images from CIFAR10, our goal would be to have a probability distribution over all possible images of size 32\times32\times3, in which images that look realistic and belong to one of the 10 CIFAR classes have a high likelihood. Simple methods like interpolation between images don’t work because images are extremely high-dimensional (especially for large HD images). Hence, we turn to deep learning methods that have performed well on complex data.

However, how do we predict a probability distribution p(\mathbf{x}) over so many dimensions using a simple neural network? The problem is that we cannot just predict a score between 0 and 1, because a probability distribution over data needs to fulfill two properties:

  1. The probability distribution needs to assign any possible value of \mathbf{x} a non-negative value: p(\mathbf{x}) \geq 0.

  2. The probability density must sum/integrate to 1 over all possible inputs: \int_{\mathbf{x}} p(\mathbf{x}) d\mathbf{x} = 1.

Luckily, there are many approaches for this, and one of them is energy-based models. The fundamental idea of energy-based models is that you can turn any function that predicts values larger than zero into a probability distribution by dividing by its volume. Imagine we have a neural network with a single output neuron, like in regression. We can call this network E_{\theta}(\mathbf{x}), where \theta are the parameters of the network, and \mathbf{x} the input data (e.g. an image). The output of E_{\theta} is a scalar value between -\infty and \infty. Now, we can use basic probability theory to normalize the scores of all possible inputs:

q_{\theta}(\mathbf{x}) = \frac{\exp\left(-E_{\theta}(\mathbf{x})\right)}{Z_{\theta}} \hspace{5mm}\text{where}\hspace{5mm}
Z_{\theta} = \begin{cases}
    \int_{\mathbf{x}}\exp\left(-E_{\theta}(\mathbf{x})\right) d\mathbf{x} & \text{if }x\text{ is continuous}\\
    \sum_{\mathbf{x}}\exp\left(-E_{\theta}(\mathbf{x})\right) & \text{if }x\text{ is discrete}
\end{cases}

The \exp-function ensures that we assign a probability greater than zero to any possible input. We use a negative sign in front of E because we call E_{\theta} the energy function: data points with high likelihood have a low energy, while data points with low likelihood have a high energy. Z_{\theta} is the normalization term that ensures that the density integrates/sums to 1. We can show this by integrating over q_{\theta}(\mathbf{x}):

\int_{\mathbf{x}}q_{\theta}(\mathbf{x})d\mathbf{x} =
\int_{\mathbf{x}}\frac{\exp\left(-E_{\theta}(\mathbf{x})\right)}{\int_{\mathbf{\tilde{x}}}\exp\left(-E_{\theta}(\mathbf{\tilde{x}})\right) d\mathbf{\tilde{x}}}d\mathbf{x} =
\frac{\int_{\mathbf{x}}\exp\left(-E_{\theta}(\mathbf{x})\right)d\mathbf{x}}{\int_{\mathbf{\tilde{x}}}\exp\left(-E_{\theta}(\mathbf{\tilde{x}})\right) d\mathbf{\tilde{x}}} = 1

Note that we call the probability distribution q_{\theta}(\mathbf{x}) because it is the distribution learned by the model, which is trained to be as close as possible to the true, unknown distribution p(\mathbf{x}).
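For a one-dimensional input, the normalization can still be checked numerically. The sketch below assumes a hypothetical energy E(x) = x^2/2, approximates the integral for Z_{\theta} with a Riemann sum on a dense grid, and confirms that the normalized density integrates to 1. For this particular energy, Z_{\theta} = \sqrt{2\pi} is known in closed form, which gives a reference value.

```python
import math

import torch


def energy(x):
    # Hypothetical 1D energy: exp(-E) is an unnormalized Gaussian bump centred at 0
    return 0.5 * x ** 2


dx = 1e-3
x = torch.arange(-10.0, 10.0, dx, dtype=torch.float64)  # dense grid standing in for the integral
unnormalized = torch.exp(-energy(x))                     # exp(-E_theta(x)) > 0 everywhere
Z = unnormalized.sum() * dx                              # Riemann-sum approximation of Z_theta
q = unnormalized / Z                                     # normalized density q_theta(x)

print(Z.item(), math.sqrt(2 * math.pi))  # both close to 2.5066
print((q.sum() * dx).item())             # close to 1.0
```

The same recipe only works because the input is one-dimensional: the grid has 20,000 points here, and a grid of comparable resolution over d dimensions would need 20,000^d points.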

The main benefit of this formulation of the probability distribution is its great flexibility, as we can choose E_{\theta} in whatever way we like, without any constraints. Nevertheless, the equation above reveals a fundamental issue: how do we calculate Z_{\theta}? We cannot compute Z_{\theta} analytically for high-dimensional inputs and/or larger neural networks, yet the exact likelihood requires it. Although we can’t determine the exact likelihood of a point, there exist methods with which we can train energy-based models. Thus, we will look next at “Contrastive Divergence” for training the model.
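To see why Z_{\theta} is the bottleneck, the sketch below computes it exactly by enumeration for tiny 3x3 binary images. The energy function is a hypothetical stand-in chosen so that the exact answer is also known in closed form; the point is the size of the sum, which has 2^n terms for n binary pixels.

```python
import itertools
import math

import torch


def energy(x):
    # Hypothetical energy: each "on" pixel lowers the energy by 0.5
    return -0.5 * x.sum(dim=-1).float()


n_pixels = 9  # a tiny 3x3 binary image
states = torch.tensor(list(itertools.product([0, 1], repeat=n_pixels)))  # all 2^9 images
log_Z = torch.logsumexp(-energy(states), dim=0)  # exact log Z_theta by brute force
q = torch.exp(-energy(states) - log_Z)            # normalized probabilities

# This energy factorizes over pixels, so Z = (1 + e^0.5)^9 serves as a reference
print(log_Z.item(), n_pixels * math.log(1 + math.exp(0.5)))
print(states.shape[0], q.sum().item())  # 512 images, probabilities summing to ~1

# Binarized 28x28 MNIST would already need 2^784 terms in the sum
print(len(str(2 ** 784)), "decimal digits in the number of states")
```

Real-valued pixels and a neural network energy make the situation strictly worse, since the sum becomes a high-dimensional integral with no factorized form.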

Contrastive Divergence

When we train a generative model, it is usually done by maximum likelihood estimation. In other words, we try to maximize the likelihood of the examples in the training set. As the exact likelihood of a point cannot be determined due to the unknown normalization constant Z_{\theta}, we need to train energy-based models slightly differently. We cannot just maximize the un-normalized probability \exp(-E_{\theta}(\mathbf{x}_{\text{train}})), because there is no guarantee that Z_{\theta} stays constant, or that \mathbf{x}_{\text{train}} becomes more likely than the others. However, if we base our training on comparing the likelihood of points, we can create a stable objective. Namely, we can rewrite our maximum likelihood objective so that we maximize the probability of \mathbf{x}_{\text{train}} compared to a randomly sampled data point of our model:

\begin{split}
    \nabla_{\theta}\mathcal{L}_{\text{MLE}}(\mathbf{\theta};p) & = -\mathbb{E}_{p(\mathbf{x})}\left[\nabla_{\theta}\log q_{\theta}(\mathbf{x})\right]\\[5pt]
    & = \mathbb{E}_{p(\mathbf{x})}\left[\nabla_{\theta}E_{\theta}(\mathbf{x})\right] - \mathbb{E}_{q_{\theta}(\mathbf{x})}\left[\nabla_{\theta}E_{\theta}(\mathbf{x})\right]
\end{split}

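Note that the loss is still an objective we want to minimize. Thus, we try to minimize the energy for data points from the dataset, while maximizing the energy for randomly sampled data points from our model (how we sample will be explained below). Although this objective sounds intuitive, how is it actually derived from our original distribution q_{\theta}(\mathbf{x})? Since \log q_{\theta}(\mathbf{x}) = -E_{\theta}(\mathbf{x}) - \log Z_{\theta}, the only difficult term is the gradient of \log Z_{\theta}, which equals -\mathbb{E}_{q_{\theta}(\mathbf{x})}\left[\nabla_{\theta}E_{\theta}(\mathbf{x})\right]. The trick is that we approximate this expectation, and with it the contribution of Z_{\theta}, by a single Monte-Carlo sample from the model. This gives us the exact same objective as written above.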
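To make the derivation concrete, the sketch below checks it numerically on a hypothetical toy model whose data space has only five discrete states, so Z_{\theta} can be computed exactly. The energy is assumed to be linear in \theta, i.e. E_{\theta}(\mathbf{x}) = \phi(\mathbf{x})^\top\theta with random features \phi, purely for illustration. Autograd on the exact negative log-likelihood should give the same gradient as the contrastive form \mathbb{E}_{p}[\nabla_{\theta}E_{\theta}] - \mathbb{E}_{q_{\theta}}[\nabla_{\theta}E_{\theta}].

```python
import torch

torch.manual_seed(0)
num_states, num_feats = 5, 3
feats = torch.randn(num_states, num_feats)   # phi(x) for each discrete state x (hypothetical)
theta = torch.randn(num_feats, requires_grad=True)
p = torch.tensor([0.4, 0.3, 0.1, 0.1, 0.1])  # empirical data distribution p(x)

energies = feats @ theta                                # E_theta(x) for every state, shape [5]
log_q = -energies - torch.logsumexp(-energies, dim=0)   # exact log q_theta(x), Z by enumeration
nll = -(p * log_q).sum()                                # maximum likelihood loss
(grad_autograd,) = torch.autograd.grad(nll, theta)

# Contrastive form: E_p[grad E] - E_q[grad E], where grad_theta E_theta(x) = phi(x) here
q = torch.softmax(-energies.detach(), dim=0)
grad_contrastive = p @ feats - q @ feats

print(torch.allclose(grad_autograd, grad_contrastive, atol=1e-5))  # True
```

In a real energy-based model the second expectation cannot be enumerated, which is why it is replaced by model samples obtained with MCMC.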

Visually, we can look at the objective as follows (figure credit - Stefano Ermon and Aditya Grover):

[Figure: contrastive divergence objective (credit: Stefano Ermon and Aditya Grover)]

f_{\theta} represents \exp(-E_{\theta}(\mathbf{x})) in our case. The point on the right, called “correct answer”, represents a data point from the dataset (i.e. x_{\text{train}}), and the left point, “wrong answer”, a sample from our model (i.e. x_{\text{sample}}). Thus, we try to “pull up” the probability of the data points in the dataset, while “pushing down” randomly sampled points. The two forces for pulling and pushing are in balance iff q_{\theta}(\mathbf{x})=p(\mathbf{x}).

Sampling from Energy-Based Models

To sample from an energy-based model, we can apply Markov Chain Monte Carlo (MCMC) using Langevin dynamics. The idea of the algorithm is to start from a random point and slowly move in the direction of higher probability using the gradients of E_{\theta}. However, gradient steps alone are not enough to fully capture the probability distribution: we also need to add noise \omega to the current sample at each gradient step. Under certain conditions, such as performing the gradient steps an infinite number of times, we would obtain an exact sample from our modeled distribution. As this is not practically possible, we usually limit the chain to K steps, where K is a hyperparameter that needs to be tuned. Overall, the sampling procedure can be summarized in the following algorithm:

[Figure: MCMC sampling algorithm with Langevin dynamics]
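As a minimal sketch of Langevin dynamics, the code below samples from a toy one-dimensional energy using the textbook update x \leftarrow x - \epsilon \nabla_x E(x) + \sqrt{2\epsilon}\,\xi with \xi \sim \mathcal{N}(0, 1). This noise scaling is an assumption of the sketch: the Sampler used later in this notebook instead uses a fixed, much smaller noise standard deviation, a larger step size and gradient clipping, which are heuristics tuned for MNIST. For the hypothetical energy E(x) = (x - 2)^2/2, \exp(-E) is proportional to a Gaussian with mean 2 and variance 1, so the chains should end up with roughly those statistics.

```python
import torch


def energy(x):
    # Hypothetical toy energy: exp(-E) is proportional to N(2, 1)
    return 0.5 * (x - 2.0) ** 2


def langevin_sample(energy_fn, x, steps=1000, eps=0.01):
    x = x.clone()
    for _ in range(steps):
        x.requires_grad_(True)
        # Gradient of the energy with respect to the input, not the parameters
        (grad,) = torch.autograd.grad(energy_fn(x).sum(), x)
        with torch.no_grad():
            # Step towards lower energy, then add Gaussian noise
            x = x - eps * grad + (2 * eps) ** 0.5 * torch.randn_like(x)
    return x.detach()


torch.manual_seed(0)
start = torch.rand(5000) * 2 - 1  # 5000 independent chains started from uniform noise in [-1, 1]
samples = langevin_sample(energy, start)
print(samples.mean().item(), samples.var().item())  # roughly 2.0 and 1.0
```

Dropping the noise term turns the loop into plain gradient descent, and every chain would collapse onto the single minimum at x = 2 instead of spreading out according to the distribution.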

Applications of Energy-based models beyond generation

Modeling the probability distribution for sampling new data is not the only application of energy-based models. Any application that requires us to compare two elements is much simpler to learn, because we just need to pick the one with the lower energy. A couple of examples are shown below (figure credit - Stefano Ermon and Aditya Grover). A classification setup like object recognition or sequence labeling can be considered an energy-based task, as we just need to find the Y that minimizes the output E(X, Y) (and hence maximizes the probability). Similarly, a popular application of energy-based models is image denoising. Given an image X with a lot of noise, we try to minimize the energy by finding the true input image Y.

[Figure: applications of energy-based models beyond generation (credit: Stefano Ermon and Aditya Grover)]
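As a small illustration of the classification view, the sketch below uses a hypothetical energy E(X, Y) defined as the squared distance between the input and a per-class prototype vector (this choice of energy is an assumption for illustration only). Prediction is simply the label with the lowest energy, and because Z only rescales all candidates equally, normalizing with a softmax over -E does not change which label wins.

```python
import torch


def energy(x, y, prototypes):
    # Hypothetical energy E(X, Y): squared distance between input x and the prototype of class y
    return ((x - prototypes[y]) ** 2).sum(dim=-1)


torch.manual_seed(0)
num_classes, dim = 10, 4
prototypes = torch.randn(num_classes, dim)
x = prototypes[7] + 0.01 * torch.randn(dim)  # an input lying very close to class 7's prototype

energies = energy(x, torch.arange(num_classes), prototypes)  # E(X, Y) for every candidate Y
pred = torch.argmin(energies)                                 # lowest energy = prediction
probs = torch.softmax(-energies, dim=0)                       # normalized with Z over the 10 labels
print(pred.item(), torch.argmax(probs).item())
```

Here Z is a sum over only ten labels, which is why such comparison tasks avoid the intractable normalization that makes generative modeling hard.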

Nonetheless, we will focus on generative modeling here, as the next couple of lectures will discuss more generative deep learning approaches.

Image generation

As an example of energy-based models, we will train a model on image generation. Specifically, we will look at how we can generate MNIST digits with a very simple CNN model. However, it should be noted that energy models are not easy to train and often diverge if the hyperparameters are not well tuned. We will rely on training tricks proposed in the paper Implicit Generation and Generalization in Energy-Based Models by Yilun Du and Igor Mordatch (blog). The important part of this notebook, though, is to see how the theory above can actually be used in a model.

Dataset

First, we can load the MNIST dataset below. Note that we need to normalize the images to the range between -1 and 1 instead of to mean 0 and std 1, because during sampling we have to limit the input space (the sampler clamps every pixel to a fixed range). Scaling to between -1 and 1 makes this easy to implement.

[4]:
# Transformations applied on each image => make them a tensor and normalize between -1 and 1
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Loading the training dataset. We need to split it into a training and validation part
train_set = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)

# Loading the test set
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
# Note that for actually training a model, we will use different data loaders
# with a lower batch size.
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, num_workers=4, pin_memory=True)
test_loader = data.DataLoader(test_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)
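
transforms.Normalize computes (x - mean) / std per channel, so with mean and std both 0.5 it maps the [0, 1] output of ToTensor onto [-1, 1]. The sketch below reproduces that arithmetic with plain tensors (rather than calling torchvision) and also shows the inverse mapping, which is what passing range=(-1, 1) to make_grid later undoes for visualization.

```python
import torch

pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])  # ToTensor() output lies in [0, 1]
mean, std = 0.5, 0.5
normalized = (pixels - mean) / std            # same formula as Normalize((0.5,), (0.5,))
restored = normalized * std + mean            # inverse mapping back to [0, 1]
print(normalized)  # maps 0 -> -1, 0.5 -> 0 and 1 -> 1
```
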
CNN Model

First, we implement our CNN model. The MNIST images are of size 28x28, hence we only need a small model. As an example, we will apply several convolutions with stride 2 that downscale the images. If you are interested, you can also use a deeper model such as a small ResNet, but for simplicity, we will stick with the tiny network.

It is good practice to use a smooth activation function like Swish (available as nn.SiLU in PyTorch) instead of ReLU in the energy model. This is because we will rely on the gradients with respect to the input image, which should not be sparse.
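The sparsity argument can be seen directly by differentiating each activation with respect to its input. In the sketch below, ReLU passes back exactly zero gradient for every negative input, while SiLU (Swish) passes back a nonzero gradient at every point of this grid; its derivative vanishes only at an isolated point near x \approx -1.28.

```python
import torch
import torch.nn as nn

x = torch.tensor([-3.0, -2.0, -1.0, -0.5, 0.5, 1.0, 2.0, 3.0], requires_grad=True)


def input_grad(activation, x):
    # Gradient of sum(activation(x)) w.r.t. x, i.e. the elementwise derivative
    (grad,) = torch.autograd.grad(activation(x).sum(), x)
    return grad


relu_grad = input_grad(nn.ReLU(), x)
silu_grad = input_grad(nn.SiLU(), x)
print(relu_grad)  # zero for every negative input
print(silu_grad)  # nonzero everywhere on this grid
```

For MCMC on images this matters because every pixel that ends up in a ReLU's flat region receives no signal about how to lower the energy.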

[5]:
class CNNModel(nn.Module):
    def __init__(self, hidden_features=32, out_dim=1, **kwargs):
        super().__init__()
        # We increase the hidden dimension over layers. Here pre-calculated for simplicity.
        c_hid1 = hidden_features // 2
        c_hid2 = hidden_features
        c_hid3 = hidden_features * 2

        # Series of convolutions and Swish activation functions
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(1, c_hid1, kernel_size=5, stride=2, padding=4),  # [16x16] - Larger padding to get 32x32 image
            nn.SiLU(),
            nn.Conv2d(c_hid1, c_hid2, kernel_size=3, stride=2, padding=1),  # [8x8]
            nn.SiLU(),
            nn.Conv2d(c_hid2, c_hid3, kernel_size=3, stride=2, padding=1),  # [4x4]
            nn.SiLU(),
            nn.Conv2d(c_hid3, c_hid3, kernel_size=3, stride=2, padding=1),  # [2x2]
            nn.SiLU(),
            nn.Flatten(),
            nn.Linear(c_hid3 * 4, c_hid3),
            nn.SiLU(),
            nn.Linear(c_hid3, out_dim),
        )

    def forward(self, x):
        x = self.cnn_layers(x).squeeze(dim=-1)
        return x

In the rest of the notebook, the output of the model will actually not represent E_{\theta}(\mathbf{x}), but -E_{\theta}(\mathbf{x}). This is a standard implementation practice for energy-based models, as some authors also write the energy probability density as q_{\theta}(\mathbf{x}) = \frac{\exp\left(f_{\theta}(\mathbf{x})\right)}{Z_{\theta}}. In that case, the model would actually represent f_{\theta}(\mathbf{x}). In the training loss and the sampler, we need to be careful not to mix up the signs.

Sampling buffer

In the next part, we look at the training with sampled elements. To use the contrastive divergence objective, we need to generate samples during training. Previous work has shown that due to the high dimensionality of images, we need a lot of iterations inside the MCMC sampling to obtain reasonable samples. However, there is a training trick that significantly reduces the sampling cost: using a sampling buffer. The idea is that we store the samples of the last couple of batches in a buffer, and re-use those as the starting point of the MCMC algorithm for the next batches. This reduces the sampling cost because the model requires a significantly lower number of steps to converge to reasonable samples. However, to not solely rely on previous samples and allow novel samples as well, we re-initialize 5% of our samples from scratch (random noise between -1 and 1).
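The buffer bookkeeping can be shown without any model. The sketch below is a hypothetical, simplified ToyReplayBuffer (not the notebook's Sampler class): it draws the number of fresh noise images from a binomial distribution with a 5% re-initialization probability, fills the rest of the batch from stored samples, and prepends new samples while discarding the oldest ones beyond max_len. A simple scaling stands in for the MCMC refinement step, which is an assumption made to keep the example self-contained.

```python
import random

import numpy as np
import torch


class ToyReplayBuffer:
    """Hypothetical, model-free sketch of a sampling buffer for MCMC starting points."""

    def __init__(self, img_shape, sample_size, max_len=8192, reinit_prob=0.05):
        self.img_shape = img_shape
        self.sample_size = sample_size
        self.max_len = max_len
        self.reinit_prob = reinit_prob
        # Fill the buffer with uniform noise in [-1, 1], one image per entry
        self.examples = [torch.rand((1,) + img_shape) * 2 - 1 for _ in range(sample_size)]

    def starting_batch(self):
        # Number of fresh noise images ~ Binomial(sample_size, reinit_prob), i.e. ~5% on average
        n_new = int(np.random.binomial(self.sample_size, self.reinit_prob))
        fresh = torch.rand((n_new,) + self.img_shape) * 2 - 1
        # The rest are drawn (with replacement) from previously stored samples
        old = random.choices(self.examples, k=self.sample_size - n_new)
        return torch.cat([fresh] + old, dim=0)

    def push(self, batch):
        # Prepend the newest samples and drop the oldest ones beyond max_len
        self.examples = list(batch.chunk(batch.shape[0], dim=0)) + self.examples
        self.examples = self.examples[: self.max_len]


buf = ToyReplayBuffer(img_shape=(1, 4, 4), sample_size=16, max_len=40)
for _ in range(5):
    batch = buf.starting_batch()
    batch = (0.9 * batch).clamp(-1.0, 1.0)  # placeholder for the MCMC refinement step
    buf.push(batch)
print(batch.shape, len(buf.examples))
```

After five iterations the buffer would hold 16 + 5 * 16 = 96 images, so truncation to max_len keeps only the 40 most recent ones.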

Below, we implement the sampling buffer. The function sample_new_exmps returns a new batch of “fake” images. We refer to those as fake images because they have been generated, but are not actually part of the dataset. As mentioned before, we initialize 5% of them randomly (in expectation, since the number is drawn from a binomial distribution), and the remaining 95% are randomly picked from our buffer. On this initial batch, we perform MCMC for 60 iterations to improve the image quality and come closer to samples from q_{\theta}(\mathbf{x}). The function generate_samples implements the MCMC for images. Note that the hyperparameters step_size, steps, and the noise standard deviation \sigma are specifically set for MNIST, and need to be tuned if you want to use a different dataset.

[6]:
class Sampler:
    def __init__(self, model, img_shape, sample_size, max_len=8192):
        """
        Args:
            model: Neural network to use for modeling E_theta
            img_shape: Shape of the images to model
            sample_size: Batch size of the samples
            max_len: Maximum number of data points to keep in the buffer
        """
        super().__init__()
        self.model = model
        self.img_shape = img_shape
        self.sample_size = sample_size
        self.max_len = max_len
        self.examples = [(torch.rand((1,) + img_shape) * 2 - 1) for _ in range(self.sample_size)]

    def sample_new_exmps(self, steps=60, step_size=10):
        """Function for getting a new batch of "fake" images.

        Args:
            steps: Number of iterations in the MCMC algorithm
            step_size: Learning rate nu in the algorithm above
        """
        # Choose 95% of the batch from the buffer, 5% generate from scratch
        n_new = np.random.binomial(self.sample_size, 0.05)
        rand_imgs = torch.rand((n_new,) + self.img_shape) * 2 - 1
        old_imgs = torch.cat(random.choices(self.examples, k=self.sample_size - n_new), dim=0)
        inp_imgs = torch.cat([rand_imgs, old_imgs], dim=0).detach().to(device)

        # Perform MCMC sampling
        inp_imgs = Sampler.generate_samples(self.model, inp_imgs, steps=steps, step_size=step_size)

        # Add new images to the buffer and remove old ones if needed
        self.examples = list(inp_imgs.to(torch.device("cpu")).chunk(self.sample_size, dim=0)) + self.examples
        self.examples = self.examples[: self.max_len]
        return inp_imgs

    @staticmethod
    def generate_samples(model, inp_imgs, steps=60, step_size=10, return_img_per_step=False):
        """Function for sampling images for a given model.

        Args:
            model: Neural network to use for modeling E_theta
            inp_imgs: Images to start from for sampling. If you want to generate new images, enter noise between -1 and 1.
            steps: Number of iterations in the MCMC algorithm.
            step_size: Learning rate nu in the algorithm above
            return_img_per_step: If True, we return the sample at every iteration of the MCMC
        """
        # Before MCMC: set model parameters to "requires_grad=False"
        # because we are only interested in the gradients of the input.
        is_training = model.training
        model.eval()
        for p in model.parameters():
            p.requires_grad = False
        inp_imgs.requires_grad = True

        # Enable gradient calculation if not already the case
        had_gradients_enabled = torch.is_grad_enabled()
        torch.set_grad_enabled(True)

        # We use a buffer tensor in which we generate noise each loop iteration.
        # More efficient than creating a new tensor every iteration.
        noise = torch.randn(inp_imgs.shape, device=inp_imgs.device)

        # List for storing generations at each step (for later analysis)
        imgs_per_step = []

        # Loop over K (steps)
        for _ in range(steps):
            # Part 1: Add noise to the input.
            noise.normal_(0, 0.005)
            inp_imgs.data.add_(noise.data)
            inp_imgs.data.clamp_(min=-1.0, max=1.0)

            # Part 2: calculate gradients for the current input.
            out_imgs = -model(inp_imgs)
            out_imgs.sum().backward()
            inp_imgs.grad.data.clamp_(-0.03, 0.03)  # For stabilizing and preventing too high gradients

            # Apply gradients to our current samples
            inp_imgs.data.add_(-step_size * inp_imgs.grad.data)
            inp_imgs.grad.detach_()
            inp_imgs.grad.zero_()
            inp_imgs.data.clamp_(min=-1.0, max=1.0)

            if return_img_per_step:
                imgs_per_step.append(inp_imgs.clone().detach())

        # Reactivate gradients for parameters for training
        for p in model.parameters():
            p.requires_grad = True
        model.train(is_training)

        # Reset gradient calculation to setting before this function
        torch.set_grad_enabled(had_gradients_enabled)

        if return_img_per_step:
            return torch.stack(imgs_per_step, dim=0)
        else:
            return inp_imgs

The idea of the buffer becomes a bit clearer in the following algorithm.

Training algorithm

With the sampling buffer ready, we can complete our training algorithm. Below is a summary of the full training algorithm of an energy model for image modeling:

[Figure: training algorithm of an energy-based model for image modeling]

The first few statements in each training iteration concern the sampling of the real and fake data, as we have seen above with the sample buffer. Next, we calculate the contrastive divergence objective using our energy model E_{\theta}. However, we need one additional training trick: a regularization loss on the output of E_{\theta}. Since the output of the network is unconstrained, and adding a constant bias to it does not change the contrastive divergence loss, we need some other way to keep the output values in a reasonable range. Without the regularization loss, the output values fluctuate over a very large range. With it, we ensure that the values for the real data are around 0, and those for the fake data are likely slightly lower (for noise or outliers, the score can still be significantly lower). As the regularization loss is less important than the contrastive divergence, we weight it with a factor \alpha, which is usually considerably smaller than 1. Finally, we perform an update step with an optimizer on the combined loss and add the new samples to the buffer.
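The claim that a constant bias leaves the contrastive divergence loss unchanged is easy to check. The sketch below reuses the two loss formulas from the training step of the module that follows, but feeds them random stand-in tensors instead of real model outputs (an assumption made to keep it self-contained). Shifting every output by the same constant leaves the contrastive divergence term unchanged, while the regularization term grows.

```python
import torch

torch.manual_seed(0)
alpha = 0.1
real_out = torch.randn(128)        # stand-in for model outputs (-E) on real images
fake_out = torch.randn(128) - 1.0  # stand-in for model outputs (-E) on sampled images


def losses(real_out, fake_out, alpha):
    reg_loss = alpha * (real_out ** 2 + fake_out ** 2).mean()
    cdiv_loss = fake_out.mean() - real_out.mean()
    return reg_loss, cdiv_loss


reg0, cdiv0 = losses(real_out, fake_out, alpha)
reg1, cdiv1 = losses(real_out + 100.0, fake_out + 100.0, alpha)  # shift every output by a constant
print(cdiv0.item(), cdiv1.item())  # (almost) identical: CD ignores the shift
print(reg0.item(), reg1.item())    # much larger after the shift: the regularizer penalizes it
```

Without the regularizer, nothing in the loss would stop the outputs from drifting by an arbitrary offset during training.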

Below, we put this training dynamic into a PyTorch Lightning module:

[7]:
class DeepEnergyModel(pl.LightningModule):
    def __init__(self, img_shape, batch_size, alpha=0.1, lr=1e-4, beta1=0.0, **CNN_args):
        super().__init__()
        self.save_hyperparameters()

        self.cnn = CNNModel(**CNN_args)
        self.sampler = Sampler(self.cnn, img_shape=img_shape, sample_size=batch_size)
        self.example_input_array = torch.zeros(1, *img_shape)

    def forward(self, x):
        z = self.cnn(x)
        return z

    def configure_optimizers(self):
        # Energy models can have issues with momentum as the loss surfaces changes with its parameters.
        # Hence, we set it to 0 by default.
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, betas=(self.hparams.beta1, 0.999))
        scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.97)  # Exponential decay over epochs
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        # We add minimal noise to the original images to prevent the model from focusing on purely "clean" inputs
        real_imgs, _ = batch
        small_noise = torch.randn_like(real_imgs) * 0.005
        real_imgs.add_(small_noise).clamp_(min=-1.0, max=1.0)

        # Obtain samples
        fake_imgs = self.sampler.sample_new_exmps(steps=60, step_size=10)

        # Predict energy score for all images
        inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)
        real_out, fake_out = self.cnn(inp_imgs).chunk(2, dim=0)

        # Calculate losses
        reg_loss = self.hparams.alpha * (real_out ** 2 + fake_out ** 2).mean()
        cdiv_loss = fake_out.mean() - real_out.mean()
        loss = reg_loss + cdiv_loss

        # Logging
        self.log("loss", loss)
        self.log("loss_regularization", reg_loss)
        self.log("loss_contrastive_divergence", cdiv_loss)
        self.log("metrics_avg_real", real_out.mean())
        self.log("metrics_avg_fake", fake_out.mean())
        return loss

    def validation_step(self, batch, batch_idx):
        # For validating, we calculate the contrastive divergence between purely random images and unseen examples
        # Note that the validation/test step of energy-based models depends on what we are interested in the model
        real_imgs, _ = batch
        fake_imgs = torch.rand_like(real_imgs) * 2 - 1

        inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)
        real_out, fake_out = self.cnn(inp_imgs).chunk(2, dim=0)

        cdiv = fake_out.mean() - real_out.mean()
        self.log("val_contrastive_divergence", cdiv)
        self.log("val_fake_out", fake_out.mean())
        self.log("val_real_out", real_out.mean())

We do not implement a test step because energy-based generative models are usually not evaluated on a test set. The validation step, however, gives an idea of how the energy/likelihood of random images differs from that of unseen examples from the dataset.

Callbacks

To track the performance of our model during training, we will make extensive use of PyTorch Lightning’s callback framework. Remember that callbacks can be used for running small functions at any point of the training, for instance after finishing an epoch. Here, we will use three different callbacks we define ourselves.

The first callback, called GenerateCallback, is used for logging image generations during training. After every N epochs (usually N=5 to reduce output to TensorBoard), we take a small batch of random images and perform many MCMC iterations until the model’s generation converges. Compared to training, which uses 60 iterations, we use 256 here because (1) we only have to do it once per logging step, whereas training has to do it every iteration, and (2) we do not start from the buffer here, but from scratch. It is implemented as follows:

[8]:
class GenerateCallback(pl.Callback):
    def __init__(self, batch_size=8, vis_steps=8, num_steps=256, every_n_epochs=5):
        super().__init__()
        self.batch_size = batch_size  # Number of images to generate
        self.vis_steps = vis_steps  # Number of steps within generation to visualize
        self.num_steps = num_steps  # Number of steps to take during generation
        # Only save those images every N epochs (otherwise tensorboard gets quite large)
        self.every_n_epochs = every_n_epochs

    def on_epoch_end(self, trainer, pl_module):
        # Skip for all other epochs
        if trainer.current_epoch % self.every_n_epochs == 0:
            # Generate images
            imgs_per_step = self.generate_imgs(pl_module)
            # Plot and add to tensorboard
            for i in range(imgs_per_step.shape[1]):
                step_size = self.num_steps // self.vis_steps
                imgs_to_plot = imgs_per_step[step_size - 1 :: step_size, i]
                grid = torchvision.utils.make_grid(
                    imgs_to_plot, nrow=imgs_to_plot.shape[0], normalize=True, range=(-1, 1)
                )
                trainer.logger.experiment.add_image("generation_%i" % i, grid, global_step=trainer.current_epoch)

    def generate_imgs(self, pl_module):
        pl_module.eval()
        start_imgs = torch.rand((self.batch_size,) + pl_module.hparams["img_shape"]).to(pl_module.device)
        start_imgs = start_imgs * 2 - 1
        imgs_per_step = Sampler.generate_samples(
            pl_module.cnn, start_imgs, steps=self.num_steps, step_size=10, return_img_per_step=True
        )
        pl_module.train()
        return imgs_per_step
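GenerateCallback keeps only vis_steps of the num_steps intermediate samples via imgs_per_step[step_size - 1 :: step_size]. A quick check with plain Python indices, using the callback's default values, shows which MCMC steps end up in the plot. Since generate_samples appends one image after every step, index 255 is the result of the final step.

```python
num_steps, vis_steps = 256, 8
step_size = num_steps // vis_steps                 # 32
kept = list(range(num_steps))[step_size - 1 :: step_size]
print(kept)  # [31, 63, 95, 127, 159, 191, 223, 255]
```
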

The second callback is called SamplerCallback, and simply adds a randomly picked subset of images in the sampling buffer to the TensorBoard. This helps to understand what images are currently shown to the model as “fake”.

[9]:
class SamplerCallback(pl.Callback):
    def __init__(self, num_imgs=32, every_n_epochs=5):
        super().__init__()
        self.num_imgs = num_imgs  # Number of images to plot
        # Only save those images every N epochs (otherwise tensorboard gets quite large)
        self.every_n_epochs = every_n_epochs

    def on_epoch_end(self, trainer, pl_module):
        if trainer.current_epoch % self.every_n_epochs == 0:
            exmp_imgs = torch.cat(random.choices(pl_module.sampler.examples, k=self.num_imgs), dim=0)
            grid = torchvision.utils.make_grid(exmp_imgs, nrow=4, normalize=True, range=(-1, 1))
            trainer.logger.experiment.add_image("sampler", grid, global_step=trainer.current_epoch)

Finally, our last callback is OutlierCallback. This callback evaluates the model by recording the (negative) energy assigned to random noise. While our training loss is almost constant across iterations, this score is likely to reflect the model’s progress in detecting “outliers”.

[10]:
class OutlierCallback(pl.Callback):
    def __init__(self, batch_size=1024):
        super().__init__()
        self.batch_size = batch_size

    def on_epoch_end(self, trainer, pl_module):
        with torch.no_grad():
            pl_module.eval()
            rand_imgs = torch.rand((self.batch_size,) + pl_module.hparams["img_shape"]).to(pl_module.device)
            rand_imgs = rand_imgs * 2 - 1.0
            rand_out = pl_module.cnn(rand_imgs).mean()
            pl_module.train()

        trainer.logger.experiment.add_scalar("rand_out", rand_out, global_step=trainer.current_epoch)
Running the model

Finally, we can put everything together to create our final training function. The function is very similar to any other PyTorch Lightning training function we have seen so far. The only difference is that we do not test the model on a test set, because we will analyse the model afterward by checking its predictions and its ability to perform outlier detection.

[11]:
def train_model(**kwargs):
    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "MNIST"),
        gpus=1 if str(device).startswith("cuda") else 0,
        max_epochs=60,
        gradient_clip_val=0.1,
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode="min", monitor="val_contrastive_divergence"),
            GenerateCallback(every_n_epochs=5),
            SamplerCallback(every_n_epochs=5),
            OutlierCallback(),
            LearningRateMonitor("epoch"),
        ],
        progress_bar_refresh_rate=1,
    )
    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "MNIST.ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        model = DeepEnergyModel.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)
        model = DeepEnergyModel(**kwargs)
        trainer.fit(model, train_loader, test_loader)
        model = DeepEnergyModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    # No testing as we are more interested in other properties
    return model
[12]:
model = train_model(img_shape=(1, 28, 28), batch_size=train_loader.batch_size, lr=1e-4, beta1=0.0)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Found pretrained model, loading...

Analysis

In the last part of the notebook, we take the trained energy-based generative model and analyse its properties.

TensorBoard

The first thing we can look at is the TensorBoard generated during training. It helps us understand the training dynamics even better and can reveal potential issues. Let’s load the TensorBoard below:

[13]:
# Uncomment the following two lines to open a tensorboard in the notebook.
# Adjust the path to your CHECKPOINT_PATH if needed.
# %load_ext tensorboard
# %tensorboard --logdir ../saved_models/tutorial8/tensorboards/

[Figure: TensorBoard screenshot of the training metrics]

We see that the contrastive divergence as well as the regularization converge quickly to 0. However, the training continues although the loss is always close to zero. This is because our “training” data changes with the model through sampling. The progress of training is best measured by looking at the samples across iterations, and at the score for random images, which decreases constantly over time.

Image Generation

Another way of evaluating generative models is by sampling a few generated images. Generative models need to be good at generating realistic images, as this truly shows that they have modeled the true data distribution. Thus, let’s sample a few images from the model below:

[14]:
model.to(device)
pl.seed_everything(43)
callback = GenerateCallback(batch_size=4, vis_steps=8, num_steps=256)
imgs_per_step = callback.generate_imgs(model)
imgs_per_step = imgs_per_step.cpu()
Global seed set to 43

A characteristic of sampling with energy-based models is that it requires an iterative MCMC algorithm. To gain insight into how the images change over the iterations, we also plot a few intermediate samples of the MCMC chain:
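To make the iterative nature of this sampling concrete, here is a toy one-dimensional sketch. It assumes a Langevin-style update (perturb the sample with Gaussian noise, then take a gradient-ascent step on the model's score) and replaces the CNN with a hypothetical hand-written score f(x) = -(x - 3)^2, whose most probable point is x = 3. The names `toy_score_grad` and `langevin_chain` are illustrative only.

```python
import random


def toy_score_grad(x: float) -> float:
    """Gradient of the hypothetical toy score f(x) = -(x - 3)^2 (higher score = more probable)."""
    return -2.0 * (x - 3.0)


def langevin_chain(x0: float, num_steps: int = 256, step_size: float = 0.05,
                   noise_std: float = 0.05, seed: int = 0) -> list:
    """Run a toy MCMC chain and return every intermediate sample, starting with x0."""
    rng = random.Random(seed)
    x = x0
    samples = [x]
    for _ in range(num_steps):
        x += rng.gauss(0.0, noise_std)      # 1) add a small amount of Gaussian noise
        x += step_size * toy_score_grad(x)  # 2) move uphill on the score (downhill in energy)
        samples.append(x)
    return samples


chain = langevin_chain(x0=-1.0)
print(f"start: {chain[0]:.2f}, after 32 steps: {chain[32]:.2f}, after 256 steps: {chain[-1]:.2f}")
```

As with the image samples, the first few dozen steps do most of the work and later steps mainly refine the sample; `step_size` and `noise_std` control how far each update moves the sample.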

[15]:
for i in range(imgs_per_step.shape[1]):
    step_size = callback.num_steps // callback.vis_steps
    imgs_to_plot = imgs_per_step[step_size - 1 :: step_size, i]
    imgs_to_plot = torch.cat([imgs_per_step[0:1, i], imgs_to_plot], dim=0)
    grid = torchvision.utils.make_grid(
        imgs_to_plot, nrow=imgs_to_plot.shape[0], normalize=True, value_range=(-1, 1), pad_value=0.5, padding=2
    )
    grid = grid.permute(1, 2, 0)
    plt.figure(figsize=(8, 8))
    plt.imshow(grid)
    plt.xlabel("Generation iteration")
    plt.xticks(
        [(imgs_per_step.shape[-1] + 2) * (0.5 + j) for j in range(callback.vis_steps + 1)],
        labels=[1] + list(range(step_size, imgs_per_step.shape[0] + 1, step_size)),
    )
    plt.yticks([])
    plt.show()
[Figures: intermediate MCMC samples at generation iterations 1, 32, 64, ..., 256, one plot for each of the four generated images]

We see that although the sampling algorithm starts from noise in the very first step, it obtains reasonable shapes after only 32 steps. Over the next 200 steps, the shapes become clearer and change towards realistic digits. The specific samples can differ when you run the code on Colab, hence the following description is specific to the plots shown on the website. The first row shows an 8, where unnecessary white parts are removed over the iterations. The transformation across iterations is best seen in the second sample, which creates the digit 2. After 32 iterations, this sample only vaguely resembles a digit, but it is transformed more and more into a typical image of the digit 2.
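The column positions in these plots follow from the slicing in the plotting cell. Assuming `imgs_per_step` holds one entry per MCMC step (so 256 entries for `num_steps=256`), this plain-Python sketch reproduces which steps are shown, which is why the second column corresponds to step 32:

```python
num_steps, vis_steps = 256, 8       # values passed to GenerateCallback above
step_size = num_steps // vis_steps  # 32

# Same selection as imgs_per_step[0:1] followed by imgs_per_step[step_size - 1 :: step_size]
shown_indices = [0] + list(range(num_steps))[step_size - 1 :: step_size]
# Same tick labels as in the plotting cell (1-based step numbers)
tick_labels = [1] + list(range(step_size, num_steps + 1, step_size))

print(shown_indices)  # [0, 31, 63, 95, 127, 159, 191, 223, 255]
print(tick_labels)    # [1, 32, 64, 96, 128, 160, 192, 224, 256]
```

Each 0-based index maps to the 1-based step number in the label directly below it.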

Out-of-distribution detection

A very common and strong application of energy-based models is out-of-distribution detection (sometimes referred to as “anomaly” detection). As more and more deep learning models are deployed in production, a crucial aspect of these models is to know what they don’t know. Deep learning models are often overconfident, meaning that they sometimes classify even random images with 100% probability. Clearly, this is not something that we want to see in applications. Energy-based models can help with this problem because they are trained to detect images that do not fit the training dataset distribution. Thus, in those applications, you could train an energy-based model along with the classifier, and only output predictions if the energy-based model assigns an (unnormalized) probability higher than \delta to the image. You can actually combine classifiers and energy-based objectives in a single model, as proposed in this paper.
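A minimal sketch of this gating idea, assuming the energy-based model's output is treated as an unnormalized log-probability, so that exp(score) is the unnormalized probability compared against \delta. The function name and the threshold value are hypothetical; in practice \delta would have to be tuned on held-out data.

```python
import math
from typing import List, Optional, Sequence


def gated_predictions(class_preds: Sequence[int], ebm_scores: Sequence[float],
                      delta: float) -> List[Optional[int]]:
    """Keep a classifier prediction only if exp(score) > delta; otherwise abstain (None).

    The comparison is done in log space (score > log(delta)) to avoid overflow in exp().
    """
    log_delta = math.log(delta)
    return [pred if score > log_delta else None for pred, score in zip(class_preds, ebm_scores)]


# Hypothetical scores resembling the averages this notebook reports for training and random images
print(gated_predictions([7, 3], [-0.01, -17.88], delta=math.exp(-5)))  # [7, None]
```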

In this part of the analysis, we want to test the out-of-distribution capability of our energy-based model. Remember that a lower output of the model denotes a lower probability. Thus, we hope to see low scores if we feed random noise into the model:

[16]:
with torch.no_grad():
    rand_imgs = torch.rand((128,) + model.hparams.img_shape).to(model.device)
    rand_imgs = rand_imgs * 2 - 1.0
    rand_out = model.cnn(rand_imgs).mean()
    print("Average score for random images: %4.2f" % (rand_out.item()))
Average score for random images: -17.88

As we hoped, the model assigns very low probability to those noisy images. As another reference, let’s look at predictions for a batch of images from the training set:

[17]:
with torch.no_grad():
    train_imgs, _ = next(iter(train_loader))
    train_imgs = train_imgs.to(model.device)
    train_out = model.cnn(train_imgs).mean()
    print("Average score for training images: %4.2f" % (train_out.item()))
Average score for training images: -0.00

The scores are close to 0 because of the regularization objective that was added to the training. So the model can clearly distinguish between noise and real digits. However, what happens if we change the images only a little? Which changes lead to a very low score?

[18]:
@torch.no_grad()
def compare_images(img1, img2):
    imgs = torch.stack([img1, img2], dim=0).to(model.device)
    score1, score2 = model.cnn(imgs).cpu().chunk(2, dim=0)
    grid = torchvision.utils.make_grid(
        [img1.cpu(), img2.cpu()], nrow=2, normalize=True, value_range=(-1, 1), pad_value=0.5, padding=2
    )
    grid = grid.permute(1, 2, 0)
    plt.figure(figsize=(4, 4))
    plt.imshow(grid)
    plt.xticks([(img1.shape[2] + 2) * (0.5 + j) for j in range(2)], labels=["Original image", "Transformed image"])
    plt.yticks([])
    plt.show()
    print("Score original image: %4.2f" % score1)
    print("Score transformed image: %4.2f" % score2)

We use a random test image for this. Feel free to change it to experiment with the model yourself.

[19]:
test_imgs, _ = next(iter(test_loader))
exmp_img = test_imgs[0].to(model.device)

The first transformation is to add some random noise to the image:

[20]:
img_noisy = exmp_img + torch.randn_like(exmp_img) * 0.3
img_noisy.clamp_(min=-1.0, max=1.0)
compare_images(exmp_img, img_noisy)
[Figure: original image and transformed (noisy) image side by side]
Score original image: 0.03
Score transformed image: -0.07

We can see that the score drops considerably. Hence, the model can detect random Gaussian noise on the image. This is to be expected, as the “fake” samples are initially pure noise images.

Next, we flip the image both vertically and horizontally (i.e., rotate it by 180 degrees) and check how this influences the score:

[21]:
img_flipped = exmp_img.flip(dims=(1, 2))
compare_images(exmp_img, img_flipped)
[Figure: original image and transformed (flipped) image side by side]
Score original image: 0.03
Score transformed image: -0.00

For digits that can only be read in their original orientation, such as the 7, the score drops. However, it drops only slightly. This is likely because of the small size of our model. Keep in mind that generative modeling is a much harder task than classification, as we do not only need to distinguish between classes but also learn all details and characteristics of the digits. With a deeper model, this could eventually be captured better (but at the cost of greater training instability).

Finally, we check what happens if we reduce the digit significantly in size:

[22]:
img_tiny = torch.zeros_like(exmp_img) - 1
img_tiny[:, exmp_img.shape[1] // 2 :, exmp_img.shape[2] // 2 :] = exmp_img[:, ::2, ::2]
compare_images(exmp_img, img_tiny)
[Figure: original image and transformed (shrunken) image side by side]
Score original image: 0.03
Score transformed image: -0.02

The score drops again, but not by a large margin, even though digits in the MNIST dataset are usually much larger.
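The indexing used to shrink the digit takes every second pixel (`[:, ::2, ::2]`) and writes the result into the bottom-right quadrant of an all-background image. A plain-Python sketch on a single-channel 4x4 grid, a hypothetical stand-in for the 1x28x28 tensor (it assumes even height and width, as for MNIST):

```python
from typing import List


def shrink_to_corner(img: List[List[float]], background: float = -1.0) -> List[List[float]]:
    """Halve an HxW grid by keeping every second pixel and place it in the bottom-right quadrant."""
    h, w = len(img), len(img[0])
    out = [[background] * w for _ in range(h)]  # -1 is the background value in the [-1, 1] range
    small = [row[::2] for row in img[::2]]      # keep rows/columns 0, 2, 4, ...
    for i, row in enumerate(small):
        for j, value in enumerate(row):
            out[h // 2 + i][w // 2 + j] = value
    return out


grid = [[float(4 * r + c) for c in range(4)] for r in range(4)]
for row in shrink_to_corner(grid):
    print(row)
# [-1.0, -1.0, -1.0, -1.0]
# [-1.0, -1.0, -1.0, -1.0]
# [-1.0, -1.0, 0.0, 2.0]
# [-1.0, -1.0, 8.0, 10.0]
```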

Overall, we can conclude that our model is good for detecting Gaussian noise and smaller transformations to existing digits. Nonetheless, to obtain a very good out-of-distribution model, we would need to train deeper models and for more iterations.

Instability

Finally, we should discuss the possible instabilities of energy-based models, in particular for the example of image generation that we have implemented in this notebook. During the hyperparameter search for this notebook, several models diverged. Divergence in energy-based models means that the models assign a high probability to examples of the training set, which is a good thing. However, at the same time, the sampling algorithm fails and only generates noise images that obtain minimal probability scores. This happens because the model has created many local maxima into which the generated noise images fall. The energy surface over which we calculate the gradients to reach data points with high probability has “diverged” and is not useful for our MCMC sampling.

Besides finding the optimal hyperparameters, a common trick in energy-based models is to reload stable checkpoints. If we detect that the model is diverging, we stop the training and load the model from one epoch earlier, when it had not diverged yet. Afterward, we continue training and hope that, with a different seed, the model does not diverge again. Nevertheless, this should be considered the “last hope” for stabilizing the models, and careful hyperparameter tuning is the better way to do so. Sensitive hyperparameters include step_size, steps and the noise standard deviation in the sampler, and the learning rate and feature dimensionality in the CNN model.
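The checkpoint-reloading trick can be sketched in a framework-agnostic way. The helper `train_with_rollback`, its arguments, and the toy divergence criterion below are hypothetical (this is not a PyTorch Lightning API); in a real setup, `state` would hold the model and optimizer state dicts, and the retried epoch would run with a different random seed.

```python
import copy
import math
from typing import Any, Callable, Dict, Tuple

State = Dict[str, Any]  # hypothetical stand-in for model/optimizer state dicts


def train_with_rollback(
    state: State,
    train_one_epoch: Callable[[State, int, int], State],
    is_diverged: Callable[[State], bool],
    num_epochs: int,
    max_restarts: int = 3,
) -> Tuple[State, int]:
    """Train epoch by epoch; on divergence, reload the last stable state and retry that epoch."""
    last_stable = copy.deepcopy(state)  # the "stable checkpoint"
    restarts, epoch = 0, 0
    while epoch < num_epochs:
        # The restart counter doubles as a seed so that a retried epoch behaves differently
        state = train_one_epoch(state, epoch, restarts)
        if is_diverged(state):
            if restarts >= max_restarts:
                raise RuntimeError("Still diverging after reloading; tune the hyperparameters instead.")
            restarts += 1
            state = copy.deepcopy(last_stable)  # reload the stable checkpoint
            continue                            # retry the same epoch
        last_stable = copy.deepcopy(state)
        epoch += 1
    return state, restarts


# Toy run: the first attempt at epoch 2 "diverges" (NaN parameter), the retry succeeds
def toy_epoch(state: State, epoch: int, seed: int) -> State:
    diverges = epoch == 2 and seed == 0
    return {"w": math.nan if diverges else state["w"] + 1.0}


final_state, num_restarts = train_with_rollback(
    {"w": 0.0}, toy_epoch, lambda s: not math.isfinite(s["w"]), num_epochs=5
)
print(final_state, num_restarts)  # {'w': 5.0} 1
```

Capping the number of restarts reflects the point above: repeated reloading is a last resort, not a substitute for tuning.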

Conclusion

In this tutorial, we have discussed energy-based models for generative modeling. The concept relies on the idea that any strictly positive function can be turned into a probability distribution by normalizing over the whole input space. As this normalization is intractable to calculate for high-dimensional data like images, we train the model using contrastive divergence and sample via MCMC. While the idea allows us to turn any neural network into an energy-based model, we have seen that multiple training tricks are needed to stabilize the training. Furthermore, the training time of these models is relatively long as, during every training iteration, we need to sample new “fake” images, even with a sampling buffer. In the next lectures and assignment, we will see different generative models (e.g. VAE, GAN, NF) that allow us to do generative modeling more stably, but at the cost of more parameters.
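The normalization idea can be made concrete for a tiny discrete input space, where the normalizing constant is a short sum. This sketch assumes the convention q(x) = exp(-E(x)) for the strictly positive function; the last lines illustrate why the same sum is out of reach for images by counting the possible 28x28 images with 256 gray values per pixel (a discretized stand-in for the continuous pixel values used in this notebook).

```python
import math
from typing import List


def energies_to_probs(energies: List[float]) -> List[float]:
    """Turn energies into a distribution: p(x) = exp(-E(x)) / Z with Z = sum over all x."""
    unnormalized = [math.exp(-e) for e in energies]  # strictly positive
    z = sum(unnormalized)                             # normalizing constant
    return [q / z for q in unnormalized]


probs = energies_to_probs([0.0, 1.0, 2.0])
print([round(p, 3) for p in probs])  # lower energy -> higher probability

# For 28x28 images with 256 gray values, Z would need one term per possible image
num_digits = 28 * 28 * math.log10(256)
print(f"~10^{num_digits:.0f} terms")  # ~10^1888 terms
```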

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Great thanks from the entire PyTorch Lightning Team for your interest!


Tutorial 8: Deep Autoencoders

  • Author: Phillip Lippe

  • License: CC BY-SA

  • Generated: 2023-01-05T11:32:28.944067

In this tutorial, we will take a closer look at autoencoders (AE). Autoencoders are trained to encode input data such as images into a smaller feature vector, and afterward reconstruct it with a second neural network, called a decoder. The feature vector is called the “bottleneck” of the network, as we aim to compress the input data into a smaller number of features. This property is useful in many applications, in particular for compressing data or comparing images on a metric beyond pixel-level comparisons. Besides learning about the autoencoder framework, we will also see the “deconvolution” (or transposed convolution) operator in action for scaling up feature maps in height and width. Such deconvolution networks are necessary wherever we start from a small feature vector and need to output an image of full size (e.g. in VAEs, GANs, or super-resolution applications). This notebook is part of a lecture series on Deep Learning at the University of Amsterdam. The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.



Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "pytorch-lightning>=1.4, <1.9" "torchmetrics>=0.7, <0.12" "torchvision" "matplotlib" "seaborn" "torch>=1.8.1, <1.14.0" "ipython[notebook]>=8.0.0, <8.9.0" "setuptools==65.6.3"
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[2]:
import os
import urllib.request
from urllib.error import HTTPError

import matplotlib
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from matplotlib_inline.backend_inline import set_matplotlib_formats
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import CIFAR10
from tqdm.notebook import tqdm

%matplotlib inline
set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.reset_orig()
sns.set()

# Tensorboard extension (for visualization purposes later)
%load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = os.environ.get("PATH_DATASETS", "data")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/tutorial9")

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
Global seed set to 42
Device: cuda:0

We have 4 pretrained models that we have to download. Remember to adjust the variables DATASET_PATH and CHECKPOINT_PATH if needed.

[3]:
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial9/"
# Files to download
pretrained_files = ["cifar10_64.ckpt", "cifar10_128.ckpt", "cifar10_256.ckpt", "cifar10_384.ckpt"]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print("Downloading %s..." % file_url)
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the files manually,"
                " or contact the author with the full output including the following error:\n",
                e,
            )
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial9/cifar10_64.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial9/cifar10_128.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial9/cifar10_256.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial9/cifar10_384.ckpt...

In this tutorial, we work with the CIFAR10 dataset. In CIFAR10, each image has 3 color channels and is 32x32 pixels large. As autoencoders do not have the constraint of modeling images probabilistically, we can work on more complex image data (i.e. 3 color channels instead of black-and-white) much more easily than with VAEs. In case you have downloaded CIFAR10 already in a different directory, make sure to set DATASET_PATH accordingly to prevent another download.
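To put the bottleneck into numbers: a CIFAR10 image holds 3 x 32 x 32 values, while the pretrained checkpoints downloaded above use latent sizes of 64, 128, 256 and 384 (inferred from their file names). A quick calculation of the resulting compression factors:

```python
image_values = 3 * 32 * 32  # channels x height x width = 3072
ratios = {latent_dim: image_values / latent_dim for latent_dim in (64, 128, 256, 384)}

for latent_dim, ratio in ratios.items():
    print(f"latent_dim={latent_dim:3d}: {image_values} -> {latent_dim} values ({ratio:.0f}x compression)")
```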

In contrast to previous tutorials on CIFAR10 like Tutorial 5 (CNN classification), we do not normalize the data explicitly to a mean of 0 and a standard deviation of 1, but roughly approximate this by scaling the data to the range between -1 and 1. This is because limiting the range will make our task of predicting/reconstructing images easier.
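The effect of `transforms.Normalize((0.5,), (0.5,))` can be checked by hand. `ToTensor` produces values in [0, 1], and the normalization computes (x - mean) / std per value; the single-element tuples are broadcast over all three color channels. A plain-Python sketch of the mapping and its inverse (the helper names are illustrative):

```python
def to_model_range(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """What Normalize((0.5,), (0.5,)) does to each value after ToTensor: (x - mean) / std."""
    return (x - mean) / std


def to_display_range(y: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Inverse mapping, e.g. for plotting reconstructions: y * std + mean."""
    return y * std + mean


for pixel in [0.0, 0.5, 1.0]:  # ToTensor output lies in [0, 1]
    print(pixel, "->", to_model_range(pixel))  # -1.0, 0.0 and 1.0
```

The resulting range [-1, 1] is also the output range of a Tanh, which is why the decoder can end in one.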

[4]:
# Transformations applied on each image => convert to a tensor in [0, 1], then scale to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Loading the training dataset. We need to split it into a training and validation part
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=transform, download=True)
pl.seed_everything(42)
train_set, val_set = torch.utils.data.random_split(train_dataset, [45000, 5000])

# Loading the test set
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
train_loader = data.DataLoader(train_set, batch_size=256, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = data.DataLoader(val_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)


def get_train_images(num):
    return torch.stack([train_dataset[i][0] for i in range(num)], dim=0)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /__w/8/s/.datasets/cifar-10-python.tar.gz
Extracting /__w/8/s/.datasets/cifar-10-python.tar.gz to /__w/8/s/.datasets
Global seed set to 42
Files already downloaded and verified

Building the autoencoder

In general, an autoencoder consists of an encoder that maps the input x to a lower-dimensional feature vector z, and a decoder that reconstructs the input \hat{x} from z. We train the model by comparing x to \hat{x} and optimizing the parameters to increase the similarity between x and \hat{x}. See below for a small illustration of the autoencoder framework.

[Figure: illustration of the autoencoder framework (encoder, bottleneck z, decoder)]

We start by implementing the encoder. The encoder effectively consists of a deep convolutional network, where we scale down the image layer-by-layer using strided convolutions. After downscaling the image three times, we flatten the features and apply a linear layer. The latent representation z is therefore a vector of size d, which can be flexibly selected.
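The input size of the final linear layer, `2 * 16 * c_hid`, follows from the convolution output-size formula in the PyTorch `nn.Conv2d` documentation. A plain-Python sketch that traces the spatial size through the encoder's five convolutions, assuming `base_channel_size=32` as used for the CIFAR10 models in this notebook:

```python
def conv2d_out_size(size: int, kernel_size: int = 3, stride: int = 1,
                    padding: int = 1, dilation: int = 1) -> int:
    """Spatial output size of nn.Conv2d (formula from the PyTorch documentation)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


c_hid = 32  # base_channel_size used for the CIFAR10 models in this notebook
size = 32   # CIFAR10 height/width
sizes = []
for stride in [2, 1, 2, 1, 2]:  # strides of the five convolutions in the encoder
    size = conv2d_out_size(size, stride=stride)
    sizes.append(size)

flat_features = 2 * c_hid * size * size  # channels x height x width before nn.Flatten
print(sizes)          # [16, 16, 8, 8, 4]
print(flat_features)  # 1024, i.e. 2 * 16 * c_hid
```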

[5]:
class Encoder(nn.Module):
    def __init__(self, num_input_channels: int, base_channel_size: int, latent_dim: int, act_fn: object = nn.GELU):
        """
        Args:
           num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3
           base_channel_size : Number of channels we use in the first convolutional layers. Deeper layers use twice this number.
           latent_dim : Dimensionality of latent representation z
           act_fn : Activation function used throughout the encoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            nn.Conv2d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 16x16 => 8x8
            act_fn(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 8x8 => 4x4
            act_fn(),
            nn.Flatten(),  # Image grid to single feature vector
            nn.Linear(2 * 16 * c_hid, latent_dim),
        )

    def forward(self, x):
        return self.net(x)

Note that we do not apply Batch Normalization here. This is because we want the encoding of each image to be independent of all the other images. Otherwise, we might introduce correlations into the encoding or decoding that we do not want to have. In some implementations, you can still see Batch Normalization being used, because it can also serve as a form of regularization. Nevertheless, the better practice is to use other normalization techniques, such as Instance Normalization or Layer Normalization, if normalization is necessary. Given the small size of the model, we can neglect normalization for now.
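The independence argument can be illustrated with a toy example. The sketch below uses plain-Python stand-ins, not the actual PyTorch layers: there are no learned affine parameters, and only the training-mode batch statistics of Batch Normalization are modeled. Normalizing each feature across the batch makes the output for one image depend on the other images in the batch, whereas normalizing each sample over its own features does not.

```python
import statistics
from typing import List

EPS = 1e-5  # small constant for numerical stability (the default eps of PyTorch's BatchNorm/LayerNorm)
Batch = List[List[float]]  # rows = samples, columns = features


def batch_normalize(batch: Batch) -> Batch:
    """Normalize every feature across the batch (BatchNorm-style, training mode, no affine)."""
    out = [list(row) for row in batch]
    for j in range(len(batch[0])):
        column = [row[j] for row in batch]
        mean = statistics.fmean(column)
        var = statistics.pvariance(column, mu=mean)  # biased variance, as used for normalization
        for i in range(len(batch)):
            out[i][j] = (batch[i][j] - mean) / (var + EPS) ** 0.5
    return out


def layer_normalize(batch: Batch) -> Batch:
    """Normalize every sample over its own features (LayerNorm-style, no affine)."""
    out = []
    for row in batch:
        mean = statistics.fmean(row)
        var = statistics.pvariance(row, mu=mean)
        out.append([(v - mean) / (var + EPS) ** 0.5 for v in row])
    return out


batch_a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
batch_b = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [70.0, 80.0, 90.0]]  # only the last image differs

# The first image is identical in both batches, but its BatchNorm-style output changes...
print(batch_normalize(batch_a)[0], batch_normalize(batch_b)[0])
# ...while its LayerNorm-style output does not
print(layer_normalize(batch_a)[0] == layer_normalize(batch_b)[0])  # True
```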

The decoder is a mirrored, flipped version of the encoder. The only difference is that we replace strided convolutions by transposed convolutions (i.e. deconvolutions) to upscale the features. Transposed convolutions can be imagined as adding the stride to the input instead of the output, and can thus upscale the input. For an illustration of a nn.ConvTranspose2d layer with kernel size 3, stride 2, and padding 1, see below (figure credit - Vincent Dumoulin and Francesco Visin):

[Figure: nn.ConvTranspose2d with kernel size 3, stride 2, and padding 1 applied to a 3x3 input]

You see that for an input of size 3\times3, we obtain an output of 5\times5. However, to truly have a reverse operation of the convolution, we need to ensure that the layer scales the input shape by a factor of 2 (e.g. 4\times4\to8\times8). For this, we can specify the parameter output_padding which adds additional values to the output shape. Note that we do not perform zero-padding with this, but rather increase the output shape for calculation.
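The shape arithmetic can be checked with a one-dimensional sketch, since height and width are handled independently. Every input element scatters a copy of the kernel into the output, starting at position `input_index * stride - padding` (this is the sense in which the stride is "added to the input"), and the output length follows the size formula in the PyTorch `nn.ConvTranspose2d` documentation. This is a plain-Python illustration (single channel, no bias, dilation 1), not the library implementation.

```python
from typing import List


def conv_transpose1d(x: List[float], w: List[float], stride: int = 2,
                     padding: int = 1, output_padding: int = 0) -> List[float]:
    """Single-channel 1D transposed convolution via scatter-add (no bias, dilation 1)."""
    # Output length; matches the nn.ConvTranspose size formula for dilation 1
    out_len = (len(x) - 1) * stride - 2 * padding + len(w) + output_padding
    out = [0.0] * out_len
    for i, xi in enumerate(x):
        for j, wj in enumerate(w):
            o = i * stride + j - padding  # the stride is applied to the input index
            if 0 <= o < out_len:          # positions removed by padding are dropped
                out[o] += xi * wj
    return out


ones = [1.0, 1.0, 1.0]
print(conv_transpose1d(ones, ones))                    # 3 -> 5: [1.0, 2.0, 1.0, 2.0, 1.0]
print(conv_transpose1d(ones, ones, output_padding=1))  # 3 -> 6: [1.0, 2.0, 1.0, 2.0, 1.0, 1.0]

# Sizes through the decoder's three transposed convolutions
# (kernel 3, stride 2, padding 1, output_padding 1)
size, decoder_sizes = 4, []
for _ in range(3):
    size = len(conv_transpose1d([0.0] * size, ones, output_padding=1))
    decoder_sizes.append(size)
print(decoder_sizes)  # [8, 16, 32]
```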

Overall, the decoder can be implemented as follows:

[6]:
class Decoder(nn.Module):
    def __init__(self, num_input_channels: int, base_channel_size: int, latent_dim: int, act_fn: object = nn.GELU):
        """
        Args:
           num_input_channels : Number of channels of the image to reconstruct. For CIFAR, this parameter is 3
           base_channel_size : Number of channels we use in the last convolutional layers. Earlier layers use twice this number.
           latent_dim : Dimensionality of latent representation z
           act_fn : Activation function used throughout the decoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.linear = nn.Sequential(nn.Linear(latent_dim, 2 * 16 * c_hid), act_fn())
        self.net = nn.Sequential(
            nn.ConvTranspose2d(
                2 * c_hid, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2
            ),  # 4x4 => 8x8
            act_fn(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 8x8 => 16x16
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(
                c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2
            ),  # 16x16 => 32x32
            nn.Tanh(),  # The input images are scaled between -1 and 1, hence the output has to be bounded as well
        )

    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        x = self.net(x)
        return x

The encoder and decoder networks we chose here are relatively simple. Usually, more complex networks are applied, especially when using a ResNet-based architecture. For example, see VQ-VAE and NVAE (although the papers discuss architectures for VAEs, they can equally be applied to standard autoencoders).

In a final step, we add the encoder and decoder together into the autoencoder architecture. We define the autoencoder as PyTorch Lightning Module to simplify the needed training code:

[7]:
class Autoencoder(pl.LightningModule):
    def __init__(
        self,
        base_channel_size: int,
        latent_dim: int,
        encoder_class: object = Encoder,
        decoder_class: object = Decoder,
        num_input_channels: int = 3,
        width: int = 32,
        height: int = 32,
    ):
        super().__init__()
        # Saving hyperparameters of autoencoder
        self.save_hyperparameters()
        # Creating encoder and decoder
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
        # Example input array needed for visualizing the graph of the network
        self.example_input_array = torch.zeros(2, num_input_channels, height, width)  # (N, C, H, W)

    def forward(self, x):
        """The forward function takes in an image and returns the reconstructed image."""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def _get_reconstruction_loss(self, batch):
        """Given a batch of images, this function returns the reconstruction loss (MSE in our case)"""
        x, _ = batch  # We do not need the labels
        x_hat = self.forward(x)
        loss = F.mse_loss(x, x_hat, reduction="none")
        loss = loss.sum(dim=[1, 2, 3]).mean(dim=[0])
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # Using a scheduler is optional but can be helpful.
        # The scheduler reduces the LR if the validation performance hasn't improved for the last N epochs
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def training_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("val_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("test_loss", loss)

For the loss function, we use the mean squared error (MSE). The mean squared error pushes the network to pay special attention to those pixel values whose estimate is far off. Predicting 127 instead of 128 is not important when reconstructing, but confusing 0 with 128 is much worse. Note that in contrast to VAEs, we do not predict the probability per pixel value, but instead use a distance measure. This saves a lot of parameters and simplifies training. To get a better intuition per pixel, we report the summed squared error averaged over the batch dimension (any other mean/sum leads to the same result/parameters).
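The reduction used in `_get_reconstruction_loss` (sum over channels, height and width, then mean over the batch) and the way squared errors weight large mistakes can be mirrored in plain Python. This is an illustrative re-implementation on flattened toy images, not the notebook's code path:

```python
from typing import List

Image = List[float]  # a flattened image: all channel/height/width values in one list


def reconstruction_loss(x: List[Image], x_hat: List[Image]) -> float:
    """Squared error summed over all pixels of each image, then averaged over the batch."""
    per_image = [sum((a - b) ** 2 for a, b in zip(img, rec)) for img, rec in zip(x, x_hat)]
    return sum(per_image) / len(per_image)


# Squared error on the 0-255 scale: small misses cost little, large misses dominate
print((128 - 127) ** 2)  # 1
print((128 - 0) ** 2)    # 16384

batch = [[0.0, 0.5, 1.0], [1.0, 1.0, 1.0]]
recon = [[0.0, 0.5, 0.0], [1.0, 0.5, 1.0]]
print(reconstruction_loss(batch, recon))  # (1.0 + 0.25) / 2 = 0.625
```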

However, MSE also has some considerable disadvantages. Usually, MSE leads to blurry images where small noise/high-frequency patterns are removed, as those cause a very low error. To ensure that realistic images are reconstructed, one could combine Generative Adversarial Networks (lecture 10) with autoencoders as done in several works (e.g. see here, here or these slides). Additionally, comparing two images using MSE does not necessarily reflect their visual similarity. For instance, suppose the autoencoder reconstructs an image shifted by one pixel to the right and bottom. Although the images are almost identical, we can get a higher loss than predicting a constant pixel value for half of the image (see code below). An example solution for this issue is to use a separate, pre-trained CNN and take the distance of visual features in lower layers as a distance measure instead of the original pixel-level comparison.

[8]:
def compare_imgs(img1, img2, title_prefix=""):
    # Calculate MSE loss between both images
    loss = F.mse_loss(img1, img2, reduction="sum")
    # Plot images for visual comparison
    grid = torchvision.utils.make_grid(torch.stack([img1, img2], dim=0), nrow=2, normalize=True, value_range=(-1, 1))
    grid = grid.permute(1, 2, 0)
    plt.figure(figsize=(4, 2))
    plt.title(f"{title_prefix} Loss: {loss.item():4.2f}")
    plt.imshow(grid)
    plt.axis("off")
    plt.show()


for i in range(2):
    # Load example image
    img, _ = train_dataset[i]
    img_mean = img.mean(dim=[1, 2], keepdim=True)

    # Shift image by one pixel
    SHIFT = 1
    img_shifted = torch.roll(img, shifts=SHIFT, dims=1)
    img_shifted = torch.roll(img_shifted, shifts=SHIFT, dims=2)
    img_shifted[:, :1, :] = img_mean
    img_shifted[:, :, :1] = img_mean
    compare_imgs(img, img_shifted, "Shifted -")

    # Set the top half of the image to the mean pixel value
    img_masked = img.clone()
    img_masked[:, : img_masked.shape[1] // 2, :] = img_mean
    compare_imgs(img, img_masked, "Masked -")
[Figures: original vs. shifted and original vs. masked images, each titled with its MSE loss, for the first two training images]
Training the model

During the training, we want to keep track of the learning progress by seeing reconstructions made by our model. For this, we implement a callback object in PyTorch Lightning which will add reconstructions every N epochs to our tensorboard:

[9]:
class GenerateCallback(pl.Callback):
    def __init__(self, input_imgs, every_n_epochs=1):
        super().__init__()
        self.input_imgs = input_imgs  # Images to reconstruct during training
        # Only save those images every N epochs (otherwise tensorboard gets quite large)
        self.every_n_epochs = every_n_epochs

    def on_train_epoch_end(self, trainer, pl_module):
        if trainer.current_epoch % self.every_n_epochs == 0:
            # Reconstruct images
            input_imgs = self.input_imgs.to(pl_module.device)
            with torch.no_grad():
                pl_module.eval()
                reconst_imgs = pl_module(input_imgs)
                pl_module.train()
            # Plot and add to tensorboard
            imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0, 1)
            grid = torchvision.utils.make_grid(imgs, nrow=2, normalize=True, value_range=(-1, 1))
            trainer.logger.experiment.add_image("Reconstructions", grid, global_step=trainer.global_step)

We will now write a training function that allows us to train the autoencoder with different latent dimensionalities and returns both the test and validation score. We provide pre-trained models and recommend using them, especially when you work on a computer without a GPU. Of course, feel free to train your own models on Lisa.

[10]:
def train_cifar(latent_dim):
    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "cifar10_%i" % latent_dim),
        accelerator="gpu" if str(device).startswith("cuda") else "cpu",
        devices=1,
        max_epochs=500,
        callbacks=[
            ModelCheckpoint(save_weights_only=True),
            GenerateCallback(get_train_images(8), every_n_epochs=10),
            LearningRateMonitor("epoch"),
        ],
    )
    trainer.logger._log_graph = True  # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "cifar10_%i.ckpt" % latent_dim)
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        model = Autoencoder.load_from_checkpoint(pretrained_filename)
    else:
        model = Autoencoder(base_channel_size=32, latent_dim=latent_dim)
        trainer.fit(model, train_loader, val_loader)
    # Evaluate the model on the validation and test set
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    result = {"test": test_result, "val": val_result}
    return model, result
Comparing latent dimensionality

When training an autoencoder, we need to choose a dimensionality for the latent representation z. The higher the latent dimensionality, the better we expect the reconstruction to be. However, the idea of autoencoders is to compress data. Hence, we are also interested in keeping the dimensionality low. To find the best tradeoff, we can train multiple models with different latent dimensionalities. The original input has 32\times 32\times 3 = 3072 values (1024 pixels with 3 color channels each). Keeping this in mind, a reasonable choice for the latent dimensionality might be between 64 and 384:

[11]:
model_dict = {}
for latent_dim in [64, 128, 256, 384]:
    model_ld, result_ld = train_cifar(latent_dim)
    model_dict[latent_dim] = {"model": model_ld, "result": result_ld}
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/tutorial9/cifar10_64/lightning_logs
Found pretrained model, loading...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/tutorial9/cifar10_128/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model, loading...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/tutorial9/cifar10_256/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model, loading...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/tutorial9/cifar10_384/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model, loading...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
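To put the candidate sizes into perspective, the following plain-Python arithmetic computes how many input values each latent dimension has to summarize. It is our addition, not part of the original notebook, and assumes the 32\times 32\times 3 = 3072 input values of a CIFAR10 image mentioned above:

```python
# Number of input values of a CIFAR10 image: height x width x RGB channels
INPUT_DIM = 32 * 32 * 3


def compression_factor(latent_dim, input_dim=INPUT_DIM):
    """How many input values are summarized by each latent dimension."""
    return input_dim / latent_dim


for latent_dim in [64, 128, 256, 384]:
    print(f"latent_dim={latent_dim:>3}: {compression_factor(latent_dim):.0f}x compression")
# latent_dim= 64: 48x compression
# latent_dim=128: 24x compression
# latent_dim=256: 12x compression
# latent_dim=384: 8x compression
```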

After training the models, we can plot the reconstruction loss over the latent dimensionality to get an intuition of how these two properties are related:

[12]:
latent_dims = sorted(k for k in model_dict)
val_scores = [model_dict[k]["result"]["val"][0]["test_loss"] for k in latent_dims]

fig = plt.figure(figsize=(6, 4))
plt.plot(
    latent_dims, val_scores, "--", color="#000", marker="*", markeredgecolor="#000", markerfacecolor="y", markersize=16
)
plt.xscale("log")
plt.xticks(latent_dims, labels=latent_dims)
plt.title("Reconstruction error over latent dimensionality", fontsize=14)
plt.xlabel("Latent dimensionality")
plt.ylabel("Reconstruction error")
plt.minorticks_off()
plt.ylim(0, 100)
plt.show()
[Figure: reconstruction error over latent dimensionality]

As we initially expected, the reconstruction loss goes down with increasing latent dimensionality. For our model and setup, the two properties seem to be exponentially (or double exponentially) correlated. To understand what these differences in reconstruction error mean, we can visualize example reconstructions of the four models:

[13]:
def visualize_reconstructions(model, input_imgs):
    # Reconstruct images
    model.eval()
    with torch.no_grad():
        reconst_imgs = model(input_imgs.to(model.device))
    reconst_imgs = reconst_imgs.cpu()

    # Plotting
    imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0, 1)
    grid = torchvision.utils.make_grid(imgs, nrow=4, normalize=True, value_range=(-1, 1))
    grid = grid.permute(1, 2, 0)
    plt.figure(figsize=(7, 4.5))
    plt.title("Reconstructed from %i latents" % (model.hparams.latent_dim))
    plt.imshow(grid)
    plt.axis("off")
    plt.show()
[14]:
input_imgs = get_train_images(4)
for latent_dim in model_dict:
    visualize_reconstructions(model_dict[latent_dim]["model"], input_imgs)
[Figures: four input images and their reconstructions from 64, 128, 256 and 384 latents]

Clearly, the model with the smallest latent dimensionality can only save information about the rough shape and color of the object. Its reconstructions are extremely blurry, and it is hard to recognize the original object in them. With 128 features, we can recognize some shapes again, although the picture remains blurry. The two models with the highest dimensionalities reconstruct the images quite well. The difference between 256 and 384 is marginal at first sight but can be noticed when comparing, for instance, the backgrounds of the first image (the 384-dimensional model reconstructs more of the background pattern than the 256-dimensional one).

Out-of-distribution images

Before continuing with applications of autoencoders, we can explore some of the limitations of our autoencoder. For example, what happens if we try to reconstruct an image that is clearly out of the distribution of our dataset? We expect the decoder to have learned some common patterns in the dataset, and it might therefore fail in particular on images that do not follow these patterns.

The first experiment we can try is to reconstruct noise. We therefore create two images whose pixel values are sampled uniformly from the normalized pixel range [-1, 1], and visualize the reconstruction of the model (feel free to test different latent dimensionalities):

[15]:
rand_imgs = torch.rand(2, 3, 32, 32) * 2 - 1
visualize_reconstructions(model_dict[256]["model"], rand_imgs)
[Figure: two uniform-noise images and their reconstructions from 256 latents]

The reconstruction of the noise is quite poor, and seems to introduce some rough patterns. As the input does not follow the patterns of the CIFAR dataset, the model has issues reconstructing it accurately.

We can also check how well the model can reconstruct other manually-coded patterns:

[16]:
plain_imgs = torch.zeros(4, 3, 32, 32)

# Single color channel
plain_imgs[1, 0] = 1
# Checkerboard pattern
plain_imgs[2, :, :16, :16] = 1
plain_imgs[2, :, 16:, 16:] = -1
# Color progression
xx, yy = torch.meshgrid(torch.linspace(-1, 1, 32), torch.linspace(-1, 1, 32), indexing="ij")
plain_imgs[3, 0, :, :] = xx
plain_imgs[3, 1, :, :] = yy

visualize_reconstructions(model_dict[256]["model"], plain_imgs)
[Figure: reconstructions of a constant image, a single-color-channel image, a checkerboard pattern and a color progression]

The plain, constant images are reconstructed relatively well, although the single color channel contains some noticeable noise. Neither the hard borders of the checkerboard pattern nor the color progression are reconstructed as sharply as intended, because such patterns hardly occur in the real-world pictures of CIFAR.

In general, autoencoders tend to fail at reconstructing high-frequency noise (i.e. sudden, big changes across few pixels) due to the choice of MSE as loss function (see our previous discussion about loss functions in autoencoders). Small misalignments in the decoder can lead to huge losses, so the model settles for the expected value/mean in these regions. For low-frequency noise, a misalignment of a few pixels does not result in a big difference to the original image. However, the larger the latent dimensionality becomes, the more of this high-frequency noise can be accurately reconstructed.
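The tendency of MSE to settle for the mean can be illustrated with a toy 1D example. This is our own construction, not from the original notebook. Suppose the decoder cannot tell whether a sharp edge lies at position 4, 5 or 6, each equally likely. Committing to one sharp edge is penalized whenever the guess is off by one pixel. The blurry average of all three candidates achieves a lower expected squared error:

```python
import numpy as np


def step_edge(position, length=10):
    """1D signal that is 0 before `position` and 1 from `position` onwards."""
    return (np.arange(length) >= position).astype(float)


# Possible ground-truth signals: the edge is equally likely at position 4, 5 or 6
targets = np.stack([step_edge(p) for p in (4, 5, 6)])

sharp_guess = step_edge(5)  # commit to a single sharp edge
blurry_guess = targets.mean(axis=0)  # average over the possible edge positions


def expected_sse(prediction):
    """Summed squared error, averaged over the equally likely targets."""
    return float(((targets - prediction) ** 2).sum(axis=1).mean())


print(f"sharp:  {expected_sse(sharp_guess):.3f}")  # sharp:  0.667
print(f"blurry: {expected_sse(blurry_guess):.3f}")  # blurry: 0.444
```

The per-pixel mean is the prediction that minimizes the expected squared error. This is why an MSE-trained decoder produces blurry output wherever it is uncertain about the exact position of a detail.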

Generating new images

Variational autoencoders are a generative version of autoencoders because we regularize their latent space to follow a Gaussian distribution. In vanilla autoencoders, however, we do not have any restrictions on the latent vector. So what happens if we input a randomly sampled latent vector into the decoder? Let’s find out below:

[17]:
model = model_dict[256]["model"]
latent_vectors = torch.randn(8, model.hparams.latent_dim, device=model.device)
with torch.no_grad():
    imgs = model.decoder(latent_vectors)
    imgs = imgs.cpu()

grid = torchvision.utils.make_grid(imgs, nrow=4, normalize=True, value_range=(-1, 1), pad_value=0.5)
grid = grid.permute(1, 2, 0)
plt.figure(figsize=(8, 5))
plt.imshow(grid)
plt.axis("off")
plt.show()
[Figure: eight images decoded from randomly sampled latent vectors]

As we can see, the generated images look more like art than realistic images. The autoencoder was allowed to structure the latent space in whichever way best suits reconstruction, so it has no incentive to map every possible latent vector to a realistic image. Furthermore, the distribution in latent space is unknown to us and doesn’t necessarily follow a multivariate normal distribution. Thus, we can conclude that vanilla autoencoders are indeed not generative.

Finding visually similar images

One application of autoencoders is to build an image-based search engine that retrieves visually similar images. This can be done by representing all images by their latent representations and finding the closest K images in this latent space. The first step towards such a search engine is to encode all images into z. In the following, we will use the training set as a search corpus, and the test set as queries to the system.

(Warning: the following cells can be computationally heavy for a weak CPU-only system. If you do not have a strong computer and are not on Google Colab, you might want to skip the execution of the following cells and rely on the results shown in the filled notebook)

[18]:
# We use the following model throughout this section.
# If you want to try a different latent dimensionality, change it here!
model = model_dict[128]["model"]
[19]:
def embed_imgs(model, data_loader):
    # Encode all images in the data_loader using model, and return both images and encodings
    img_list, embed_list = [], []
    model.eval()
    for imgs, _ in tqdm(data_loader, desc="Encoding images", leave=False):
        with torch.no_grad():
            z = model.encoder(imgs.to(model.device))
        img_list.append(imgs)
        embed_list.append(z.cpu())  # Keep encodings on the CPU, next to the images
    return (torch.cat(img_list, dim=0), torch.cat(embed_list, dim=0))


train_img_embeds = embed_imgs(model, train_loader)
test_img_embeds = embed_imgs(model, test_loader)

After encoding all images, we just need to write a function that finds the closest K images and returns (or plots) those:

[20]:
def find_similar_images(query_img, query_z, key_embeds, K=8):
    # Find the closest K images. We use the Euclidean distance here, but other metrics like the cosine distance can also be used.
    dist = torch.cdist(query_z[None, :], key_embeds[1], p=2)
    dist = dist.squeeze(dim=0)
    dist, indices = torch.sort(dist)
    # Plot K closest images
    imgs_to_display = torch.cat([query_img[None], key_embeds[0][indices[:K]]], dim=0)
    grid = torchvision.utils.make_grid(imgs_to_display, nrow=K + 1, normalize=True, value_range=(-1, 1))
    grid = grid.permute(1, 2, 0)
    plt.figure(figsize=(12, 3))
    plt.imshow(grid)
    plt.axis("off")
    plt.show()
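Without the plotting, the retrieval in find_similar_images is a k-nearest-neighbor search under Euclidean distance. Below is a minimal NumPy sketch of just that step. The helper name top_k_neighbors is hypothetical, and the toy 2D keys stand in for the latent codes:

```python
import numpy as np


def top_k_neighbors(query, keys, k=8):
    """Return the indices and Euclidean distances of the k rows of `keys` closest to `query`."""
    dists = np.sqrt(((keys - query[None, :]) ** 2).sum(axis=1))
    order = np.argsort(dists, kind="stable")[:k]
    return order, dists[order]


# Toy "latent codes" of four key images and one query
keys = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [0.0, 2.0]])
query = np.array([0.1, 0.0])

indices, distances = top_k_neighbors(query, keys, k=2)
print(indices)  # [0 1]
```

Sorting all distances, as the notebook does with torch.sort, is simple. For a very large corpus, a partial selection such as np.argpartition or torch.topk could avoid the full sort.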
[21]:
# Plot the closest images for the first N test images as example
for i in range(8):
    find_similar_images(test_img_embeds[0][i], test_img_embeds[1][i], key_embeds=train_img_embeds)
[Figures: for each of the first eight test images, the query image followed by its eight closest training images in latent space]

Based on our autoencoder, we see that we are able to retrieve many images similar to the test input. In particular, in row 4, we can spot that some test images might not be as different from the training set as we thought (same poster, just different scaling/color scaling). We also see that although we haven’t given the model any labels, it can cluster different classes in different parts of the latent space (airplane + ship, animals, etc.). This is why autoencoders can also be used as a pre-training strategy for deep networks, especially when we have a large set of unlabeled images (often the case). However, it should be noted that the background still plays a big role in autoencoders, while it doesn’t for classification. Hence, we don’t get “perfect” clusters and need to finetune such models for classification.

Tensorboard clustering

Another way of exploring the similarity of images in the latent space is via dimensionality-reduction methods like PCA or t-SNE. Luckily, TensorBoard provides a nice interface for this, which we use in the following:
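The TensorBoard projector computes such projections for us. To illustrate what a PCA projection of latent codes does, here is a minimal NumPy sketch, not TensorBoard's implementation. It centers the embeddings and projects them onto their two leading principal directions, obtained via a singular value decomposition. The random toy embeddings are an assumption standing in for real encodings:

```python
import numpy as np


def pca_2d(embeddings):
    """Project (N, D) embeddings onto their two leading principal components."""
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


rng = np.random.default_rng(0)
# Toy embeddings whose variance is concentrated in the first two dimensions
scales = np.array([10.0, 3.0] + [0.1] * 14)
toy_embeds = rng.normal(size=(500, 16)) * scales

proj = pca_2d(toy_embeds)
print(proj.shape)  # (500, 2)
```

PCA is linear and preserves the directions of largest global variance. t-SNE, in contrast, is a non-linear method that focuses on preserving local neighborhoods, which often makes class clusters easier to spot.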

[22]:
# We use the following model throughout this section.
# If you want to try a different latent dimensionality, change it here!
model = model_dict[128]["model"]
[23]:
# Create a summary writer
writer = SummaryWriter("tensorboard/")

The function add_embedding allows us to add high-dimensional feature vectors to TensorBoard, on which we can perform clustering. We need to provide the feature vectors, additional metadata such as the labels, and the original images, so that we can identify a specific image in the clustering.

[24]:
# In case you obtain the following error in the next cell, execute the import statements and last line in this cell
# AttributeError: module 'tensorflow._api.v2.io.gfile' has no attribute 'get_filesystem'

# import tensorflow as tf
# import tensorboard as tb
# tf.io.gfile = tb.compat.tensorflow_stub.io.gfile
[25]:
# Note: the embedding projector in tensorboard is computationally heavy.
# Reduce the image amount below if your computer struggles with visualizing all 10k points
NUM_IMGS = len(test_set)

writer.add_embedding(
    test_img_embeds[1][:NUM_IMGS],  # Encodings per image
    metadata=[test_set[i][1] for i in range(NUM_IMGS)],  # Adding the labels per image to the plot
    label_img=(test_img_embeds[0][:NUM_IMGS] + 1) / 2.0,  # Adding the original images (rescaled to [0, 1]) to the plot
)

Finally, we can run tensorboard to explore similarities among images:

[26]:
# Run the next line to start TensorBoard inside the notebook
%tensorboard --logdir tensorboard/

You should be able to see something similar to the following image. If the projector stays empty, try starting TensorBoard outside of the Jupyter notebook.

[Figure: TensorBoard embedding projector showing the clustered CIFAR10 test-set encodings]

Overall, we can see that the model indeed clustered visually similar images together. Especially the background color seems to be a crucial factor in the encoding. This is related to the chosen loss function, here the pixel-level Mean Squared Error: the background is responsible for more than half of the pixels in an average image, hence the model learns to focus on it. Nevertheless, we can see that the encodings also separate a couple of classes in the latent space, although the model hasn’t seen any labels. This shows again that autoencoding can also be used as a “pre-training”/transfer learning task before classification.

[27]:
# Closing the summary writer
writer.close()

Conclusion

In this tutorial, we have implemented our own autoencoder on small RGB images and explored various properties of the model. In contrast to variational autoencoders, vanilla AEs are not generative and can be trained with a plain MSE loss, which often makes them easier to train. Both versions of AE can be used for dimensionality reduction, as we have seen when finding visually similar images beyond pixel distances. Although autoencoders have received less interest in the research community than their more “theoretically” challenging counterpart, the VAE, they still find use in many applications like denoising and compression. Hence, AEs are an essential tool that every Deep Learning engineer/researcher should be familiar with.

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Great thanks from the entire PyTorch Lightning Team for your interest!


Tutorial 9: Normalizing Flows for Image Modeling

  • Author: Phillip Lippe

  • License: CC BY-SA

  • Generated: 2022-05-12T13:44:24.674574

In this tutorial, we will take a closer look at complex, deep normalizing flows. The most popular current application of deep normalizing flows is to model datasets of images. As for other generative models, images are a good domain to start working on because (1) CNNs are widely studied and strong models exist, (2) images are high-dimensional and complex, and (3) images are discrete integers. In this tutorial, we will review current advances in normalizing flows for image modeling, and get hands-on experience with coding normalizing flows. Note that normalizing flows are commonly parameter-heavy and therefore computationally expensive. We will use relatively simple and shallow flows to save computational cost and allow you to run the notebook on CPU, but keep in mind that a simple way to improve the scores of the flows we study here is to make them deeper. This notebook is part of a lecture series on Deep Learning at the University of Amsterdam. The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.



Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "ipython[notebook]" "seaborn" "tabulate" "torchvision" "setuptools==59.5.0" "pytorch-lightning>=1.4" "matplotlib" "torchmetrics>=0.7" "torch>=1.8"

Throughout this notebook, we make use of PyTorch Lightning. The first cell imports our usual libraries.

[2]:
import math
import os
import time
import urllib.request
from urllib.error import HTTPError

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import seaborn as sns
import tabulate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from IPython.display import HTML, display
from matplotlib_inline.backend_inline import set_matplotlib_formats
from matplotlib.colors import to_rgb
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch import Tensor
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm.notebook import tqdm

%matplotlib inline
set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.reset_orig()

# Path to the folder where the datasets are/should be downloaded (e.g. MNIST)
DATASET_PATH = os.environ.get("PATH_DATASETS", "data")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/tutorial11")

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Fetching the device that will be used throughout this notebook
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda:0")
print("Using device", device)
Global seed set to 42
Using device cuda:0

As before, we provide a few pretrained models, which we download below to the path specified above.

[3]:
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial11/"
# Files to download
pretrained_files = ["MNISTFlow_simple.ckpt", "MNISTFlow_vardeq.ckpt", "MNISTFlow_multiscale.ckpt"]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print("Downloading %s..." % file_url)
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n",
                e,
            )
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial11/MNISTFlow_simple.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial11/MNISTFlow_vardeq.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial11/MNISTFlow_multiscale.ckpt...

We will use the MNIST dataset in this notebook. Despite its simplicity, MNIST constitutes a challenge for small generative models, as it requires a global understanding of an image. At the same time, we can easily judge whether generated images come from the same distribution as the dataset (i.e. represent real digits) or not.

To deal better with the discrete nature of the images, we transform them from a range of 0-1 to a range of 0-255 as integers.

[4]:
# Convert images from 0-1 to 0-255 (integers)
def discretize(sample):
    return (sample * 255).to(torch.int32)


# Transformations applied on each image => make them a tensor and discretize
transform = transforms.Compose([transforms.ToTensor(), discretize])

# Loading the training dataset. We need to split it into a training and validation part
train_dataset = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)
pl.seed_everything(42)
train_set, val_set = torch.utils.data.random_split(train_dataset, [50000, 10000])

# Loading the test set
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
# Note that for actually training a model, we will use different data loaders
# with a lower batch size.
train_loader = data.DataLoader(train_set, batch_size=256, shuffle=False, drop_last=False)
val_loader = data.DataLoader(val_set, batch_size=64, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=64, shuffle=False, drop_last=False, num_workers=4)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /__w/1/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /__w/1/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /__w/1/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /__w/1/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Global seed set to 42

In addition, we define a function below to simplify the visualization of images/samples. Some training examples of the MNIST dataset are shown below.

[5]:
def show_imgs(imgs, title=None, row_size=4):
    # Form a grid of pictures (at most row_size columns)
    num_imgs = imgs.shape[0] if isinstance(imgs, Tensor) else len(imgs)
    is_int = imgs.dtype == torch.int32 if isinstance(imgs, Tensor) else imgs[0].dtype == torch.int32
    nrow = min(num_imgs, row_size)
    ncol = int(math.ceil(num_imgs / nrow))
    imgs = torchvision.utils.make_grid(imgs, nrow=nrow, pad_value=128 if is_int else 0.5)
    np_imgs = imgs.cpu().numpy()
    # Plot the grid
    plt.figure(figsize=(1.5 * nrow, 1.5 * ncol))
    plt.imshow(np.transpose(np_imgs, (1, 2, 0)), interpolation="nearest")
    plt.axis("off")
    if title is not None:
        plt.title(title)
    plt.show()
    plt.close()


show_imgs([train_set[i][0] for i in range(8)])
[Figure: eight training examples from MNIST]

Normalizing Flows as generative model

In the previous lectures, we have seen Energy-based models, Variational Autoencoders (VAEs) and Generative Adversarial Networks (GANs) as examples of generative models. However, none of them explicitly learns the probability density function p(x) of the real input data. While VAEs model a lower bound, energy-based models only implicitly learn the probability density. GANs, on the other hand, provide us with a sampling mechanism for generating new data, without offering a likelihood estimate. The generative model we will look at here, called Normalizing Flows, actually models the true data distribution p(x) and provides us with an exact likelihood estimate. Below, we can visually compare VAEs, GANs and Flows (figure credit - Lilian Weng):

[Figure: visual comparison of GANs, VAEs and flow-based models]

The major difference compared to VAEs is that flows use invertible functions f to map the input data x to a latent representation z. To realize this, z must be of the same shape as x. This is in contrast to VAEs where z is usually much lower dimensional than the original input data. However, an invertible mapping also means that for every data point x, we have a corresponding latent representation z which allows us to perform lossless reconstruction (z to x). In the visualization above, this means that x=x' for flows, no matter what invertible function f and input x we choose.

Nonetheless, how are normalizing flows modeling a probability density with an invertible function? The answer to this question is the rule for change of variables. Specifically, given a prior density p_z(z) (e.g. Gaussian) and an invertible function f, we can determine p_x(x) as follows:

\begin{split}
    \int p_x(x) dx & = \int p_z(z) dz = 1 \hspace{1cm}\text{(by definition of a probability distribution)}\\
    \Leftrightarrow p_x(x) & = p_z(z) \left|\frac{dz}{dx}\right| = p_z(f(x)) \left|\frac{df(x)}{dx}\right|
\end{split}

Hence, to determine the probability of x, we only need its probability in latent space and the derivative of f. Note that this holds for a univariate distribution, and f is required to be invertible and smooth. In the multivariate case, the derivative becomes a Jacobian matrix, of which we take the absolute determinant. As we usually use the log-likelihood as objective, we write the multivariate term with logarithms below:

\log p_x(\mathbf{x}) = \log p_z(f(\mathbf{x})) + \log{} \left|\det \frac{df(\mathbf{x})}{d\mathbf{x}}\right|
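To make the formula concrete, here is a small numerical check for a toy case we chose for illustration: an affine flow f(x) = Ax + b on 2D data with a standard Gaussian prior (the matrix `A`, vector `b` and function name `flow_log_prob` are hypothetical). Because the Jacobian of an affine map is just A, the flow log-likelihood can be compared against the closed-form Gaussian that x follows when z ~ N(0, I).

```python
import torch

# Illustrative invertible affine flow f(x) = A x + b (A and b are arbitrary choices)
A = torch.tensor([[2.0, 0.5], [0.0, 1.5]], dtype=torch.float64)
b = torch.tensor([1.0, -1.0], dtype=torch.float64)
prior = torch.distributions.MultivariateNormal(
    torch.zeros(2, dtype=torch.float64), torch.eye(2, dtype=torch.float64)
)


def flow_log_prob(x):
    """log p_x(x) = log p_z(f(x)) + log|det df/dx| for a batch of row vectors x."""
    z = x @ A.T + b
    log_abs_det = torch.linalg.slogdet(A).logabsdet  # the Jacobian of f is A everywhere
    return prior.log_prob(z) + log_abs_det


# Closed form: if z ~ N(0, I), then x = A^{-1}(z - b) ~ N(-A^{-1} b, (A^T A)^{-1})
exact = torch.distributions.MultivariateNormal(
    -torch.linalg.solve(A, b), precision_matrix=A.T @ A
)

torch.manual_seed(0)
x = torch.randn(5, 2, dtype=torch.float64)
print(torch.allclose(flow_log_prob(x), exact.log_prob(x)))  # True
```

The log-determinant term is exactly what compensates for the flow stretching or squeezing space; dropping it would make the two results disagree.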

Although we now know how a normalizing flow obtains its likelihood, it might not be clear what a normalizing flow does intuitively. For this, we look at the flow from the inverse direction, starting with the prior density p_z(z). If we apply an invertible function to it, we effectively “transform” its probability density. For instance, if f^{-1}(z)=z+1, we shift the density by one, which is invertible and still gives a valid probability distribution. We can also apply more complex transformations such as scaling, f^{-1}(z)=2z+1, but here a difference appears: scaling also changes the volume of the probability density, as shown below for uniform distributions (figure credit - Eric Jang):

[Figure: scaling a uniform distribution lowers the height of its density]

You can see that the height of p(y) must be lower than that of p(x) after scaling. This change in volume is captured by \left|\frac{df(x)}{dx}\right| in our equation above, and ensures that we still have a valid probability distribution after scaling. We could keep making our function f more complex. However, the more complex f becomes, the harder it is to find its inverse f^{-1} and to calculate the log-determinant of the Jacobian \log{} \left|\det \frac{df(\mathbf{x})}{d\mathbf{x}}\right|. An easier trick is to stack multiple invertible functions f_{1,...,K} after each other: their composition is still a single invertible function, and its log-determinant is simply the sum of the individual log-determinants. Using multiple learnable invertible functions, a normalizing flow attempts to transform p_z(z) gradually into a more complex distribution, which should finally be p_x(x). We visualize the idea below (figure credit - Lilian Weng):

[Figure: a chain of invertible transformations f_1, ..., f_K turning a Gaussian z_0 into z_K = x]

Starting from z_0, which follows the prior Gaussian distribution, we sequentially apply the invertible functions f_1,f_2,...,f_K until z_K represents x. Note that in the figure above, the functions f are the inverses of the f we used before (here: f:Z\to X, above: f:X\to Z). This is only a difference in notation and has no impact on the actual flow design, because all f need to be invertible anyway. When we estimate the log-likelihood of a data point x as in the equations above, we run the flows in the opposite direction to the one visualized above. Multiple flow layers parameterized by neural networks have been proposed, such as the planar and radial flow. However, we will focus here on flows that are commonly used in image modeling, and discuss them in the rest of the notebook along with the details of how to train a normalizing flow.
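A minimal sketch of stacking flows, assuming two hypothetical 1D affine layers f_k(x) = a_k x + c_k chosen only for illustration: the layers are applied in sequence while a running log-determinant (`ldj`) is accumulated, mirroring the `(z, ldj)` convention the notebook's flow layers use, and sampling runs the inverses in reverse order.

```python
import math

# Hypothetical 1D affine flows f_k(x) = a_k * x + c_k, stored as (a_k, c_k)
flows = [(2.0, 1.0), (3.0, -2.0)]


def encode(x):
    """Run the flows x -> z in sequence and accumulate the log-determinant (ldj)."""
    z, ldj = x, 0.0
    for a, c in flows:
        z = a * z + c
        ldj += math.log(abs(a))  # log|df_k/dz| of this single layer
    return z, ldj


def decode(z):
    """Invert the flows by applying each inverse in reverse order."""
    for a, c in reversed(flows):
        z = (z - c) / a
    return z


z, ldj = encode(0.5)
print(z)                   # 4.0, i.e. the composite map 6 * x + 1 applied to 0.5
print(ldj, math.log(6.0))  # equal up to rounding: log 2 + log 3 = log 6
print(decode(z))           # 0.5
```

The composite map 6x + 1 has derivative 6, so summing the per-layer log-determinants gives the same value as differentiating the whole stack at once.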

Normalizing Flows on images

To become familiar with normalizing flows, especially for the application of image modeling, it is best to discuss the different elements in a flow along with the implementation. As a general concept, we want to build a normalizing flow that maps an input image (here MNIST) to an equally sized latent space:

[Figure: a normalizing flow mapping an MNIST image to an equally sized latent space]

As a first step, we will implement a template of a normalizing flow in PyTorch Lightning. During training and validation, a normalizing flow performs density estimation in the forward direction. For this, we apply a series of flow transformations on the input x and estimate the probability of the input by determining the probability of the transformed point z given a prior, and the change of volume caused by the transformations. During inference, we can do both density estimation and sampling new points by inverting the flow transformations. Therefore, we define a function _get_likelihood which performs density estimation, and sample to generate new examples. The functions training_step, validation_step and test_step all make use of _get_likelihood.

The standard metric used in generative models, and in particular normalizing flows, is bits per dimension (bpd). Bpd is motivated from an information theory perspective and describes how many bits we would need to encode a particular example in our modeled distribution. The fewer bits we need, the more likely the example is under our distribution. When we evaluate the bits per dimension on our test dataset, we can judge whether our model generalizes to new samples of the dataset and did not memorize the training dataset. To calculate the bits per dimension score, we can rely on the negative log-likelihood and change the log base (bits are measured with base-2 logarithms, while the NLL is usually computed with the natural logarithm, i.e. in nats):

\text{bpd} = \text{nll} \cdot \log_2\left(\exp(1)\right) \cdot \left(\prod d_i\right)^{-1}

where d_1,...,d_K are the dimensions of the input. For images, these are the height, width and number of channels. We divide the negative log-likelihood by the product of these dimensions to obtain a metric that is comparable across image resolutions. In the original image space, a naive encoding of MNIST examples needs 8 bits per dimension (each pixel takes one of 256 possible values, and \log_2 256 = 8).
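As a quick sanity check of the bpd formula, the sketch below (the helper name `nll_to_bpd` is our own) converts the NLL of a hypothetical model that is uniform over all 256 pixel values into bits per dimension for a 1×28×28 MNIST image; it should reproduce the 8 bits mentioned above.

```python
import numpy as np


def nll_to_bpd(nll, shape):
    """Convert a negative log-likelihood in nats into bits per dimension."""
    return nll * np.log2(np.exp(1)) / np.prod(shape)


shape = (1, 28, 28)  # channels, height, width of an MNIST image
# A model that is uniform over the 256 pixel values assigns each pixel probability 1/256
nll_uniform = np.prod(shape) * np.log(256)
print(nll_to_bpd(nll_uniform, shape))  # approximately 8.0
```

Any trained flow that is better than this uniform baseline should report a bpd below 8 on MNIST.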

[6]:
class ImageFlow(pl.LightningModule):
    def __init__(self, flows, import_samples=8):
        """
        Args:
            flows: A list of flows (each a nn.Module) that should be applied on the images.
            import_samples: Number of importance samples to use during testing (see explanation below). Can be changed at any time
        """
        super().__init__()
        self.flows = nn.ModuleList(flows)
        self.import_samples = import_samples
        # Create prior distribution for final latent space
        self.prior = torch.distributions.normal.Normal(loc=0.0, scale=1.0)
        # Example input for visualizing the graph
        self.example_input_array = train_set[0][0].unsqueeze(dim=0)

    def forward(self, imgs):
        # The forward function is only used for visualizing the graph
        return self._get_likelihood(imgs)

    def encode(self, imgs):
        # Given a batch of images, return the latent representation z and ldj of the transformations
        z, ldj = imgs, torch.zeros(imgs.shape[0], device=self.device)
        for flow in self.flows:
            z, ldj = flow(z, ldj, reverse=False)
        return z, ldj

    def _get_likelihood(self, imgs, return_ll=False):
        """Given a batch of images, return the likelihood of those.

        If return_ll is True, this function returns the log likelihood of the input. Otherwise, the output metric is
        bits per dimension (scaled negative log likelihood)
        """
        z, ldj = self.encode(imgs)
        log_pz = self.prior.log_prob(z).sum(dim=[1, 2, 3])
        log_px = ldj + log_pz
        nll = -log_px
        # Calculating bits per dimension
        bpd = nll * np.log2(np.exp(1)) / np.prod(imgs.shape[1:])
        return bpd.mean() if not return_ll else log_px

    @torch.no_grad()
    def sample(self, img_shape, z_init=None):
        """Sample a batch of images from the flow."""
        # Sample latent representation from prior (on the module's own device, not a global variable)
        if z_init is None:
            z = self.prior.sample(sample_shape=img_shape).to(self.device)
        else:
            z = z_init.to(self.device)

        # Transform z to x by inverting the flows
        ldj = torch.zeros(img_shape[0], device=self.device)
        for flow in reversed(self.flows):
            z, ldj = flow(z, ldj, reverse=True)
        return z

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # A scheduler is optional, but can help flows get the last bpd improvement
        scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.99)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        # Normalizing flows are trained by maximum likelihood => return bpd
        loss = self._get_likelihood(batch[0])
        self.log("train_bpd", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._get_likelihood(batch[0])
        self.log("val_bpd", loss)

    def test_step(self, batch, batch_idx):
        # Perform importance sampling during testing => estimate likelihood M times for each image
        samples = []
        for _ in range(self.import_samples):
            img_ll = self._get_likelihood(batch[0], return_ll=True)
            samples.append(img_ll)
        img_ll = torch.stack(samples, dim=-1)

        # To average the probabilities, we need to go from log-space to exp, and back to log.
        # Logsumexp provides us a stable implementation for this
        img_ll = torch.logsumexp(img_ll, dim=-1) - np.log(self.import_samples)

        # Calculate final bpd
        bpd = -img_ll * np.log2(np.exp(1)) / np.prod(batch[0].shape[1:])
        bpd = bpd.mean()

        self.log("test_bpd", bpd)

The test_step function differs from the training and validation steps in that it uses importance sampling. We will discuss the motivation and details behind this after understanding how flows model discrete images in continuous space.

Dequantization

Normalizing flows rely on the change-of-variables rule, which is naturally defined in continuous space. Applying flows directly to discrete data leads to undesired density models where arbitrarily high likelihoods are placed on a few particular values. See the illustration below:

[Figure: a continuous density collapsing onto the discrete points x = 0, 1, 2, 3]

The black points represent the discrete points, and the green volume the density modeled by a normalizing flow in continuous space. The flow would keep increasing the likelihood for x=0,1,2,3 while having no volume on any other point. Remember that in continuous space, the overall volume of the probability density must be 1 (\int p(x)dx=1); otherwise, we no longer model a probability distribution. However, the discrete points x=0,1,2,3 are delta peaks with no width in continuous space. This is why the flow can place an infinitely high likelihood on these few points while still representing a distribution in continuous space. Nonetheless, the learned density does not tell us anything about the distribution among the discrete points, since in discrete space the likelihoods of those four points would have to sum to 1, not to infinity.

To prevent such degenerate solutions, a common approach is to add a small amount of noise to each discrete value, which is referred to as dequantization. Considering x as an integer (as is the case for images), the dequantized representation v can be formulated as v=x+u where u\in[0,1)^D. Thus, the discrete value 1 is modeled by a distribution over the interval [1.0, 2.0), the value 2 by a volume over [2.0, 3.0), etc. Our objective of modeling p(x) becomes:

p(x) = \int p(x+u)du = \int \frac{q(u|x)}{q(u|x)}p(x+u)du = \mathbb{E}_{u\sim q(u|x)}\left[\frac{p(x+u)}{q(u|x)} \right]

with q(u|x) being the noise distribution. For now, we assume it to be uniform, which can also be written as p(x)=\mathbb{E}_{u\sim U(0,1)^D}\left[p(x+u) \right].
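To illustrate the uniform case, the sketch below assumes, purely for illustration, that the continuous density over dequantized values is a standard Gaussian. It estimates p(x) for the discrete value x = 0 by averaging p(x + u) over uniform noise samples and compares the result with the exact probability mass of the interval [0, 1).

```python
import torch

torch.manual_seed(0)
# Illustrative continuous density over dequantized values (a standard Gaussian)
density = torch.distributions.Normal(0.0, 1.0)

x = 0  # a discrete value, e.g. a pixel intensity
u = torch.rand(100_000)  # u ~ U[0, 1)
v = x + u  # dequantized values lie in [x, x + 1), so floor(v) recovers x

# Monte Carlo estimate of p(x) = E_{u ~ U(0,1)}[p(x + u)]
p_mc = density.log_prob(v).exp().mean()
# Exact probability mass of the interval [x, x + 1)
p_exact = density.cdf(torch.tensor(x + 1.0)) - density.cdf(torch.tensor(float(x)))
print(p_mc.item(), p_exact.item())  # both approximately 0.341
```

Because the bin has width 1, the uniform expectation of the density equals the probability mass of the bin, which is what turns the continuous model back into a discrete distribution.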

In the following, we will implement Dequantization as a flow transformation itself. After adding noise to the discrete values, we additionally transform the volume into a Gaussian-like shape. This is done by scaling x+u to the range between 0 and 1 and applying the inverse of the sigmoid function, \sigma^{-1}(z) = \log z - \log(1-z). Without this step, we would face two problems:

  1. The input is scaled between 0 and 256 while the prior distribution is a Gaussian with mean 0 and standard deviation 1. In the first iterations after initializing the parameters of the flow, we would have extremely low likelihoods for large values like 256. This would cause the training to diverge instantaneously.

  2. As the output distribution is a Gaussian, it is beneficial for the flow to have a similarly shaped input distribution. This will reduce the modeling complexity that is required by the flow.

Overall, we can implement dequantization as follows:

[7]:
class Dequantization(nn.Module):
    def __init__(self, alpha=1e-5, quants=256):
        """
        Args:
            alpha: small constant that is used to scale the original input.
                    Prevents dealing with values very close to 0 and 1 when inverting the sigmoid
            quants: Number of possible discrete values (usually 256 for 8-bit image)
        """
        super().__init__()
        self.alpha = alpha
        self.quants = quants

    def forward(self, z, ldj, reverse=False):
        if not reverse:
            z, ldj = self.dequant(z, ldj)
            z, ldj = self.sigmoid(z, ldj, reverse=True)
        else:
            z, ldj = self.sigmoid(z, ldj, reverse=False)
            z = z * self.quants
            ldj += np.log(self.quants) * np.prod(z.shape[1:])
            z = torch.floor(z).clamp(min=0, max=self.quants - 1).to(torch.int32)
        return z, ldj

    def sigmoid(self, z, ldj, reverse=False):
        # Applies an invertible sigmoid transformation
        if not reverse:
            ldj += (-z - 2 * F.softplus(-z)).sum(dim=[1, 2, 3])
            z = torch.sigmoid(z)
        else:
            z = z * (1 - self.alpha) + 0.5 * self.alpha  # Scale to prevent boundaries 0 and 1
            ldj += np.log(1 - self.alpha) * np.prod(z.shape[1:])
            ldj += (-torch.log(z) - torch.log(1 - z)).sum(dim=[1, 2, 3])
            z = torch.log(z) - torch.log(1 - z)
        return z, ldj

    def dequant(self, z, ldj):
        # Transform discrete values to continuous volumes
        z = z.to(torch.float32)
        z = z + torch.rand_like(z).detach()
        z = z / self.quants
        ldj -= np.log(self.quants) * np.prod(z.shape[1:])
        return z, ldj
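The log-determinant terms in the sigmoid method above are easy to get wrong, so one option is to check them against autograd. The sketch below re-implements the two formulas as standalone functions (`sigmoid_ldj` and `logit_ldj` are our own names, with `alpha` playing the same role as in Dequantization) and compares them to the log of the elementwise derivatives in float64.

```python
import torch
import torch.nn.functional as F

alpha = 1e-5  # same role as Dequantization.alpha


def sigmoid_ldj(z):
    # log|d sigmoid(z)/dz|, the formula used in the non-reverse branch of sigmoid
    return -z - 2 * F.softplus(-z)


def logit_ldj(x):
    # log|dy/dx| for y = logit(x * (1 - alpha) + 0.5 * alpha), as in the reverse branch
    x_s = x * (1 - alpha) + 0.5 * alpha
    return torch.log(torch.tensor(1 - alpha, dtype=x.dtype)) - torch.log(x_s) - torch.log(1 - x_s)


z = torch.linspace(-6, 6, 7, dtype=torch.float64, requires_grad=True)
(dz,) = torch.autograd.grad(torch.sigmoid(z).sum(), z)
print(torch.allclose(sigmoid_ldj(z), dz.log()))  # True

x = torch.tensor([0.0, 0.1, 0.5, 0.9], dtype=torch.float64, requires_grad=True)
y = torch.logit(x * (1 - alpha) + 0.5 * alpha)
(dx,) = torch.autograd.grad(y.sum(), x)
print(torch.allclose(logit_ldj(x), dx.log()))  # True
```

Note that thanks to the alpha scaling, the inverse sigmoid stays finite even for an input of exactly 0.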

A good way to check whether a flow is implemented correctly is to verify that it is invertible. Hence, we will dequantize the first training image and then quantize it again. We would expect to get the exact same image out:

[8]:
# Testing invertibility of dequantization layer
pl.seed_everything(42)
orig_img = train_set[0][0].unsqueeze(dim=0)
ldj = torch.zeros(1)
dequant_module = Dequantization()
deq_img, ldj = dequant_module(orig_img, ldj, reverse=False)
reconst_img, ldj = dequant_module(deq_img, ldj, reverse=True)

d1, d2 = torch.where(orig_img.squeeze() != reconst_img.squeeze())
if len(d1) != 0:
    print("Dequantization was not invertible.")
    for i in range(d1.shape[0]):
        print("Original value:", orig_img[0, 0, d1[i], d2[i]].item())
        print("Reconstructed value:", reconst_img[0, 0, d1[i], d2[i]].item())
else:
    print("Successfully inverted dequantization")

# Layer is not strictly invertible due to float precision constraints
# assert (orig_img == reconst_img).all().item()
Global seed set to 42
Dequantization was not invertible.
Original value: 0
Reconstructed value: 1

Contrary to our expectation, the test fails. However, this is no reason to doubt our implementation, as only a single value differs from the original. The mismatch is caused by numerical inaccuracies in the inverse sigmoid: while its input lies between 0 and 1, its output ranges from -\infty to \infty. As we use 32-bit floats to represent the numbers (and apply logs over and over again), such inaccuracies can occur and should not be worrisome. Nevertheless, it is good to be aware of them; they can be reduced by using double-precision tensors (float64).

Finally, we can take our dequantization and actually visualize the distribution it transforms the discrete values into:

[9]:
def visualize_dequantization(quants, prior=None):
    """Function for visualizing the dequantization values of discrete values in continuous space."""
    # Prior over discrete values. If not given, a uniform is assumed
    if prior is None:
        prior = np.ones(quants, dtype=np.float32) / quants
    prior = prior / prior.sum() * quants  # In the following, we assume 1 for each value means uniform distribution

    inp = torch.arange(-4, 4, 0.01).view(-1, 1, 1, 1)  # Possible continuous values we want to consider
    ldj = torch.zeros(inp.shape[0])
    dequant_module = Dequantization(quants=quants)
    # Invert dequantization on continuous values to find corresponding discrete value
    out, ldj = dequant_module.forward(inp, ldj, reverse=True)
    inp, out, prob = inp.squeeze().numpy(), out.squeeze().numpy(), ldj.exp().numpy()
    prob = prob * prior[out]  # Probability scaled by categorical prior

    # Plot volumes and continuous distribution
    sns.set_style("white")
    _ = plt.figure(figsize=(6, 3))
    x_ticks = []
    for v in np.unique(out):
        indices = np.where(out == v)
        color = to_rgb("C%i" % v)
        plt.fill_between(inp[indices], prob[indices], np.zeros(indices[0].shape[0]), color=color + (0.5,), label=str(v))
        plt.plot([inp[indices[0][0]]] * 2, [0, prob[indices[0][0]]], color=color)
        plt.plot([inp[indices[0][-1]]] * 2, [0, prob[indices[0][-1]]], color=color)
        x_ticks.append(inp[indices[0][0]])
    x_ticks.append(inp.max())
    plt.xticks(x_ticks, ["%.1f" % x for x in x_ticks])
    plt.plot(inp, prob, color=(0.0, 0.0, 0.0))
    # Set final plot properties
    plt.ylim(0, prob.max() * 1.1)
    plt.xlim(inp.min(), inp.max())
    plt.xlabel("z")
    plt.ylabel("Probability")
    plt.title("Dequantization distribution for %i discrete values" % quants)
    plt.legend()
    plt.show()
    plt.close()


visualize_dequantization(quants=8)
[Plot: dequantization distribution for 8 discrete values]

The visualized distribution shows the sub-volumes that are assigned to the different discrete values. The value 0 has its volume in (-\infty, -1.9), the value 1 is represented by the interval [-1.9, -1.1), etc. The volume for each discrete value has the same probability mass. That’s why the volumes close to the center (e.g. 3 and 4) occupy a narrower interval on the z-axis than the others (z denotes the output of the whole dequantization flow).

Effectively, the subsequent normalizing flow models discrete images with the following objective:

\log p(x) = \log \mathbb{E}_{u\sim q(u|x)}\left[\frac{p(x+u)}{q(u|x)} \right] \geq \mathbb{E}_{u}\left[\log \frac{p(x+u)}{q(u|x)} \right]

Although normalizing flows are exact in likelihood, we now only have a lower bound. Specifically, this follows from Jensen's inequality, since we need to move the log inside the expectation to use Monte Carlo estimates. In general, the gap of this bound is considerably smaller than that of the ELBO in variational autoencoders. We can tighten the bound further by estimating the expectation with M samples instead of one. In other words, we can apply importance sampling, which leads to the following inequality:

\log p(x) = \log \mathbb{E}_{u\sim q(u|x)}\left[\frac{p(x+u)}{q(u|x)} \right] \geq \mathbb{E}_{u}\left[\log \frac{1}{M} \sum_{m=1}^{M} \frac{p(x+u_m)}{q(u_m|x)} \right] \geq \mathbb{E}_{u}\left[\log \frac{p(x+u)}{q(u|x)} \right]

The importance sampling estimate \frac{1}{M} \sum_{m=1}^{M} \frac{p(x+u_m)}{q(u_m|x)} converges to \mathbb{E}_{u\sim q(u|x)}\left[\frac{p(x+u)}{q(u|x)} \right] as M\to \infty, so the more samples we use, the tighter the bound is. During testing, we make use of this property, as implemented in test_step of ImageFlow. In theory, we could also use this tighter bound during training. However, related work has shown that this does not necessarily lead to an improvement given the additional computational cost, and that it is more efficient to stick with a single estimate [5].
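The averaging in test_step has to happen in probability space, not log space. The sketch below uses made-up log importance weights for a single image (drawn around -700, since whole-image log-likelihoods are large negative numbers) to show two things: the M-sample estimate computed with logsumexp is never below the average of single-sample bounds, and a naive exp/mean/log underflows in float32.

```python
import math

import torch

torch.manual_seed(0)
M = 8
# Hypothetical log importance weights log p(x + u_m) - log q(u_m | x) for one image
log_w = 2.0 * torch.randn(M) - 700.0

single_sample = log_w.mean()  # average of M single-sample bounds
importance = torch.logsumexp(log_w, dim=0) - math.log(M)  # log of the averaged weights
naive = log_w.exp().mean().log()  # exp(-700) underflows to 0 in float32, giving -inf

print(single_sample.item(), importance.item(), naive.item())
```

The first inequality is Jensen's inequality again (the log of a mean is at least the mean of the logs); the second point is why test_step relies on torch.logsumexp.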

Variational Dequantization

Dequantization with a uniform distribution for the noise u effectively represents images as hypercubes (cubes in high dimensions) with sharp borders. However, modeling such sharp borders is not easy for a flow, as it uses smooth transformations to convert them into a Gaussian distribution.

Another way to see this is to change the prior distribution in the previous visualization. Imagine we have independent Gaussian noise on the pixels, which is common for real-world photographs. In that case, the flow would have to model a distribution as above, but with the individual volumes scaled as follows:

[10]:
visualize_dequantization(quants=8, prior=np.array([0.075, 0.2, 0.4, 0.2, 0.075, 0.025, 0.0125, 0.0125]))
[Plot: dequantization distribution for 8 discrete values with a non-uniform prior]

Transforming such a distribution into a Gaussian is a difficult task, especially with such hard borders. Dequantization has therefore been extended to more sophisticated, learnable distributions beyond uniform in a variational framework. In particular, if we recall the learning objective \log p(x) = \log \mathbb{E}_{u}\left[\frac{p(x+u)}{q(u|x)} \right], the uniform distribution can be replaced by a learned distribution q_{\theta}(u|x) with support over u\in[0,1)^D. This approach is called Variational Dequantization and was proposed by Ho et al. [3]. How can we learn such a distribution? We can use a second normalizing flow that takes x as external input and learns a flexible distribution over u. To ensure support over [0,1)^D, we apply a sigmoid activation function as the final flow transformation.
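The effect of the final sigmoid can be illustrated with torch.distributions, under a simplifying assumption: the conditional flow is replaced by a fixed Gaussian (in the real model its output would depend on the image x, and the numbers 0.3 and 0.8 are arbitrary). Pushing it through a sigmoid gives a non-uniform distribution whose samples all lie in (0, 1), and whose density follows the change-of-variables rule.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import SigmoidTransform

torch.manual_seed(0)
# Hypothetical stand-in for the conditional flow: a Gaussian with fixed parameters
base = Normal(torch.tensor(0.3), torch.tensor(0.8))
# A sigmoid as the final transformation maps the real line to (0, 1)
q_u = TransformedDistribution(base, [SigmoidTransform()])

u = q_u.sample((10_000,))
print(u.min().item() > 0.0, u.max().item() < 1.0)  # all samples lie inside (0, 1)

# Change of variables: log q(u) = log N(logit(u); 0.3, 0.8) - log(u * (1 - u))
u_test = torch.tensor([0.1, 0.5, 0.9])
manual = base.log_prob(torch.logit(u_test)) - torch.log(u_test * (1 - u_test))
print(torch.allclose(manual, q_u.log_prob(u_test), atol=1e-5))
```

The `- log(u * (1 - u))` term is the same sigmoid log-determinant that the Dequantization class accumulates in its ldj.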

Inheriting the original dequantization class, we can implement variational dequantization as follows:

[11]:
class VariationalDequantization(Dequantization):
    def __init__(self, var_flows, alpha=1e-5):
        """
        Args:
            var_flows: A list of flow transformations to use for modeling q(u|x)
            alpha: Small constant, see Dequantization for details
        """
        super().__init__(alpha=alpha)
        self.flows = nn.ModuleList(var_flows)

    def dequant(self, z, ldj):
        z = z.to(torch.float32)
        img = (z / 255.0) * 2 - 1  # We condition the flows on x, i.e. the original image

        # Prior of u is a uniform distribution as before
        # As most flow transformations are defined on [-infinity,+infinity], we apply an inverse sigmoid first.
        deq_noise = torch.rand_like(z).detach()
        deq_noise, ldj = self.sigmoid(deq_noise, ldj, reverse=True)
        for flow in self.flows:
            deq_noise, ldj = flow(deq_noise, ldj, reverse=False, orig_img=img)
        deq_noise, ldj = self.sigmoid(deq_noise, ldj, reverse=False)

        # After the flows, apply u as in standard dequantization
        z = (z + deq_noise) / 256.0
        ldj -= np.log(256.0) * np.prod(z.shape[1:])
        return z, ldj

Variational dequantization can be used as a substitute for dequantization. We will compare dequantization and variational dequantization in later experiments.

Coupling layers

Next, we look at possible transformations to apply inside the flow. A popular flow layer that works well in combination with deep neural networks is the coupling layer, introduced by Dinh et al. [1]. The input z is arbitrarily split into two parts, z_{1:j} and z_{j+1:d}, of which the first remains unchanged by the flow. However, z_{1:j} is used to parameterize the transformation of the second part, z_{j+1:d}. Various transformations have been proposed in recent years [3,4], but here we settle for the simplest and most efficient one: affine coupling. In this coupling layer, we apply an affine transformation by shifting the input by a bias \mu and scaling it by \sigma. In other words, our transformation looks as follows:

z'_{j+1:d} = \mu_{\theta}(z_{1:j}) + \sigma_{\theta}(z_{1:j}) \odot z_{j+1:d}

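The functions \mu and \sigma are implemented as a shared neural network, and the sum and multiplication are performed element-wise. Since z_{1:j} passes through unchanged, the Jacobian is triangular with ones on the diagonal for z_{1:j} and the scaling factors for z_{j+1:d}. The LDJ is therefore the sum of the logs of the scaling factors: \sum_i \left[\log \sigma_{\theta}(z_{1:j})\right]_i. For the same reason, \mu and \sigma can be recomputed exactly from z_{1:j} during inversion, and the layer can be inverted simply by subtracting the bias and dividing by the scale: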

z_{j+1:d} = \left(z'_{j+1:d} - \mu_{\theta}(z_{1:j})\right) / \sigma_{\theta}(z_{1:j})
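A minimal sketch of the two equations on flat vectors, using explicit slicing instead of masks. The sizes `D` and `j` and the single linear layer `net` standing in for the shared network are hypothetical choices, and the network predicts log σ so that the scale stays positive. The sketch checks that the inverse recovers the input and that the summed log-scales equal the log-determinant of the full Jacobian.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D, j = 6, 3  # hypothetical number of dimensions and split point
# Hypothetical "shared network": predicts mu and log(sigma) for z_{j+1:d} from z_{1:j}
net = nn.Linear(j, 2 * (D - j)).double()


def coupling_forward(z):
    z1, z2 = z[:, :j], z[:, j:]
    mu, log_sigma = net(z1).chunk(2, dim=1)
    z2_new = mu + torch.exp(log_sigma) * z2  # z'_{j+1:d} = mu + sigma * z_{j+1:d}
    ldj = log_sigma.sum(dim=1)  # LDJ = sum of log scaling factors
    return torch.cat([z1, z2_new], dim=1), ldj


def coupling_inverse(z_new):
    z1, z2_new = z_new[:, :j], z_new[:, j:]
    mu, log_sigma = net(z1).chunk(2, dim=1)  # z1 is unchanged, so mu and sigma are recomputed exactly
    z2 = (z2_new - mu) * torch.exp(-log_sigma)
    return torch.cat([z1, z2], dim=1)


z = torch.randn(4, D, dtype=torch.float64)
z_new, ldj = coupling_forward(z)
print(torch.allclose(coupling_inverse(z_new), z))  # True

# Cross-check the LDJ of the first example against the full Jacobian determinant
jac = torch.autograd.functional.jacobian(lambda v: coupling_forward(v.unsqueeze(0))[0].squeeze(0), z[0])
print(torch.allclose(torch.linalg.slogdet(jac).logabsdet, ldj[0]))  # True
```

Note that the notebook's CouplingLayer below uses the variant (z + t) * exp(s), i.e. shift first and then scale; as its code comment says, this ordering is a design choice.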

We can also visualize the coupling layer in form of a computation graph, where z_1 represents z_{1:j}, and z_2 represents z_{j+1:d}:

[Figure: computation graph of a coupling layer]

In our implementation, we will realize the splitting of variables as masking. The variables to be transformed, z_{j+1:d}, are masked when passing z to the shared network to predict the transformation parameters. When applying the transformation, we mask the parameters for z_{1:j} so that we have an identity operation for those variables:

[12]:
class CouplingLayer(nn.Module):
    def __init__(self, network, mask, c_in):
        """Coupling layer inside a normalizing flow.

        Args:
            network: A PyTorch nn.Module constituting the deep neural network for mu and sigma.
                      Output shape should be twice the channel size as the input.
            mask: Binary mask (0 or 1) where 0 denotes that the element should be transformed,
                   while 1 means the latent will be used as input to the NN.
            c_in: Number of input channels
        """
        super().__init__()
        self.network = network
        self.scaling_factor = nn.Parameter(torch.zeros(c_in))
        # Register mask as buffer as it is a tensor which is not a parameter,
        # but should be part of the modules state.
        self.register_buffer("mask", mask)

    def forward(self, z, ldj, reverse=False, orig_img=None):
        """
        Args:
            z: Latent input to the flow
            ldj: The current ldj of the previous flows.
                  The ldj of this layer will be added to this tensor.
            reverse: If True, we apply the inverse of the layer.
            orig_img (optional): Only needed in VarDeq. Allows external
                                  input to condition the flow on (e.g. original image)
        """
        # Apply network to masked input
        z_in = z * self.mask
        if orig_img is None:
            nn_out = self.network(z_in)
        else:
            nn_out = self.network(torch.cat([z_in, orig_img], dim=1))
        s, t = nn_out.chunk(2, dim=1)

        # Stabilize scaling output
        s_fac = self.scaling_factor.exp().view(1, -1, 1, 1)
        s = torch.tanh(s / s_fac) * s_fac

        # Mask outputs (only transform the second part)
        s = s * (1 - self.mask)
        t = t * (1 - self.mask)

        # Affine transformation
        if not reverse:
            # Whether we first shift and then scale, or the other way round,
            # is a design choice, and usually does not have a big impact
            z = (z + t) * torch.exp(s)
            ldj += s.sum(dim=[1, 2, 3])
        else:
            z = (z * torch.exp(-s)) - t
            ldj -= s.sum(dim=[1, 2, 3])

        return z, ldj

For stabilization purposes, we apply a \tanh activation function to the scaling output. This prevents sudden large output values for the scaling that can destabilize training. To still allow scaling outputs below -1 or above 1, we have a learnable parameter per channel, called scaling_factor, whose exponential sets the limits of the tanh. Below, we visualize the effect of the scaling factor on the output activation of the scaling terms:

[13]:
with torch.no_grad():
    x = torch.arange(-5, 5, 0.01)
    scaling_factors = [0.5, 1, 2]
    sns.set()
    fig, ax = plt.subplots(1, 3, figsize=(12, 3))
    for i, scale in enumerate(scaling_factors):
        y = torch.tanh(x / scale) * scale
        ax[i].plot(x.numpy(), y.numpy())
        ax[i].set_title("Scaling factor: " + str(scale))
        ax[i].set_ylim(-3, 3)
    plt.subplots_adjust(wspace=0.4)
    sns.reset_orig()
    plt.show()
[Plot: tanh(x / s) * s for scaling factors s = 0.5, 1 and 2]

Coupling layers generalize to any masking technique we could think of. However, the most common approach for images is to split the input z in half, using a checkerboard mask or a channel mask. A checkerboard mask splits the variables across the height and width dimensions and assigns every other pixel to z_{j+1:d}; the mask is shared across channels. In contrast, the channel mask assigns half of the channels to z_{j+1:d} and the other half to z_{1:j}. Note that when we apply multiple coupling layers, we invert the mask in every other layer so that each variable is transformed a similar number of times.

Let’s implement a function that creates a checkerboard mask and a channel mask for us:

[14]:
def create_checkerboard_mask(h, w, invert=False):
    x, y = torch.arange(h, dtype=torch.int32), torch.arange(w, dtype=torch.int32)
    xx, yy = torch.meshgrid(x, y, indexing="ij")  # explicit indexing avoids the deprecation warning
    mask = torch.fmod(xx + yy, 2)
    mask = mask.to(torch.float32).view(1, 1, h, w)
    if invert:
        mask = 1 - mask
    return mask


def create_channel_mask(c_in, invert=False):
    mask = torch.cat([torch.ones(c_in // 2, dtype=torch.float32), torch.zeros(c_in - c_in // 2, dtype=torch.float32)])
    mask = mask.view(1, c_in, 1, 1)
    if invert:
        mask = 1 - mask
    return mask

We can also visualize the corresponding masks for an image of size 8\times 8\times 2 (2 channels):

[15]:
checkerboard_mask = create_checkerboard_mask(h=8, w=8).expand(-1, 2, -1, -1)
channel_mask = create_channel_mask(c_in=2).expand(-1, -1, 8, 8)

show_imgs(checkerboard_mask.transpose(0, 1), "Checkerboard mask")
show_imgs(channel_mask.transpose(0, 1), "Channel mask")
[Plots: "Checkerboard mask" and "Channel mask"]

As a last aspect of coupling layers, we need to decide which deep neural network to apply inside them. The input to the layers is an image, so we stick with a CNN. Because the input to a transformation depends on all previous transformations, it is crucial to ensure a good gradient flow through the CNN back to the input, which is best achieved with a ResNet-like architecture. Specifically, we use a Gated ResNet that adds a \sigma-gate to the skip connection, similar to the input gate in LSTMs. The details are not essential here, and the network is strongly inspired by Flow++ [3] in case you are interested in building even stronger models.

[16]:
class ConcatELU(nn.Module):
    """Activation function that applies ELU in both directions (inverted and plain).

    Allows non-linearity while providing strong gradients for any input (important for final convolution)
    """

    def forward(self, x):
        return torch.cat([F.elu(x), F.elu(-x)], dim=1)


class LayerNormChannels(nn.Module):
    def __init__(self, c_in):
        """This module applies layer norm across channels in an image.

        Has been shown to work well with ResNet connections.
        Args:
            c_in: Number of channels of the input
        """
        super().__init__()
        self.layer_norm = nn.LayerNorm(c_in)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.layer_norm(x)
        x = x.permute(0, 3, 1, 2)
        return x


class GatedConv(nn.Module):
    def __init__(self, c_in, c_hidden):
        """
        This module applies a two-layer convolutional ResNet block with input gate
        Args:
            c_in: Number of channels of the input
            c_hidden: Number of hidden dimensions we want to model (usually similar to c_in)
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(c_in, c_hidden, kernel_size=3, padding=1),
            ConcatELU(),
            nn.Conv2d(2 * c_hidden, 2 * c_in, kernel_size=1),
        )

    def forward(self, x):
        out = self.net(x)
        val, gate = out.chunk(2, dim=1)
        return x + val * torch.sigmoid(gate)


class GatedConvNet(nn.Module):
    def __init__(self, c_in, c_hidden=32, c_out=-1, num_layers=3):
        """Module that summarizes the previous blocks to a full convolutional neural network.

        Args:
            c_in: Number of input channels
            c_hidden: Number of hidden dimensions to use within the network
            c_out: Number of output channels. If -1, 2 times the input channels are used (affine coupling)
            num_layers: Number of gated ResNet blocks to apply
        """
        super().__init__()
        c_out = c_out if c_out > 0 else 2 * c_in
        layers = []
        layers += [nn.Conv2d(c_in, c_hidden, kernel_size=3, padding=1)]
        for layer_index in range(num_layers):
            layers += [GatedConv(c_hidden, c_hidden), LayerNormChannels(c_hidden)]
        layers += [ConcatELU(), nn.Conv2d(2 * c_hidden, c_out, kernel_size=3, padding=1)]
        self.nn = nn.Sequential(*layers)

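        # Zero-initialize the final convolution so that the network initially outputs zeros,
        # which makes each coupling layer start out close to an identity mapping
        self.nn[-1].weight.data.zero_()
        self.nn[-1].bias.data.zero_()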

    def forward(self, x):
        return self.nn(x)
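A quick shape check may help here. The sketch below restates the ConcatELU computation as a plain function (`concat_elu` is a hypothetical stand-in, not the class above) and builds a zero-initialized output convolution the same way GatedConvNet does. Assuming the coupling layer applies an affine map with log-scale s and shift t taken from the network output, a zero output means s = t = 0, so every coupling layer starts out as the identity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def concat_elu(x):
    # Same computation as ConcatELU.forward: ELU on x and -x, stacked along channels
    return torch.cat([F.elu(x), F.elu(-x)], dim=1)


x = torch.randn(4, 16, 28, 28)
h = concat_elu(x)
print(h.shape)  # channels double from 16 to 32

# Output convolution initialized to zero, mirroring the last layer of GatedConvNet
out_conv = nn.Conv2d(32, 2, kernel_size=3, padding=1)
out_conv.weight.data.zero_()
out_conv.bias.data.zero_()
params = out_conv(h)
print(params.abs().max().item())  # 0.0, i.e. s = t = 0 at initialization
```

Doubling the channels is why every convolution following a ConcatELU in the notebook takes `2 * c_hidden` input channels.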
Training loop

Finally, we can combine Dequantization, Variational Dequantization and Coupling Layers to build our full normalizing flow on MNIST images. We apply 8 coupling layers in the main flow, and 4 for variational dequantization if it is used. We use a checkerboard mask throughout the network because, with a single channel (black-and-white images), we cannot apply a channel mask. The overall architecture is visualized below.

[Figure: overall architecture of the flow on MNIST]

[17]:
def create_simple_flow(use_vardeq=True):
    flow_layers = []
    if use_vardeq:
        vardeq_layers = [
            CouplingLayer(
                network=GatedConvNet(c_in=2, c_out=2, c_hidden=16),
                mask=create_checkerboard_mask(h=28, w=28, invert=(i % 2 == 1)),
                c_in=1,
            )
            for i in range(4)
        ]
        flow_layers += [VariationalDequantization(var_flows=vardeq_layers)]
    else:
        flow_layers += [Dequantization()]

    for i in range(8):
        flow_layers += [
            CouplingLayer(
                network=GatedConvNet(c_in=1, c_hidden=32),
                mask=create_checkerboard_mask(h=28, w=28, invert=(i % 2 == 1)),
                c_in=1,
            )
        ]

    flow_model = ImageFlow(flow_layers).to(device)
    return flow_model
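The alternating `invert=(i % 2 == 1)` argument is what lets consecutive coupling layers transform complementary halves of the image. The numpy sketch below uses a hypothetical `checkerboard` helper (not the notebook's `create_checkerboard_mask`, whose exact convention is defined earlier) to show that a mask and its inverse together cover every pixel exactly once, so after any two consecutive coupling layers each pixel has been transformed once.

```python
import numpy as np


def checkerboard(h, w, invert=False):
    # Hypothetical helper: 1 where (row + col) is odd, 0 elsewhere, shaped [1, 1, h, w]
    mask = (np.add.outer(np.arange(h), np.arange(w)) % 2).astype(np.float32)
    if invert:
        mask = 1.0 - mask
    return mask.reshape(1, 1, h, w)


masks = [checkerboard(28, 28, invert=(i % 2 == 1)) for i in range(8)]

# Each mask selects half of the 784 pixels
print([int(m.sum()) for m in masks[:2]])  # [392, 392]
# Two consecutive masks are complementary: together they cover every pixel exactly once
print(np.all(masks[0] + masks[1] == 1.0))  # True
```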

To implement the training loop, we use PyTorch Lightning to reduce code overhead. If you are interested, you can take a look at the generated TensorBoard file, in particular the graph, for an overview of the flow transformations that are applied. Note that we again provide pre-trained models (see later in the notebook), as normalizing flows are particularly expensive to train. The pre-trained checkpoints also store the validation and test results, since evaluation takes considerable time as well due to the added importance sampling.

[18]:
def train_flow(flow, model_name="MNISTFlow"):
    # Create a PyTorch Lightning trainer
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, model_name),
        gpus=1 if torch.cuda.is_available() else 0,
        max_epochs=200,
        gradient_clip_val=1.0,
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode="min", monitor="val_bpd"),
            LearningRateMonitor("epoch"),
        ],
    )
    trainer.logger._log_graph = True
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    train_data_loader = data.DataLoader(
        train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=8
    )
    result = None

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, model_name + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        ckpt = torch.load(pretrained_filename, map_location=device)
        flow.load_state_dict(ckpt["state_dict"])
        result = ckpt.get("result", None)
    else:
        print("Start training", model_name)
        trainer.fit(flow, train_data_loader, val_loader)

    # Evaluate the model on the validation and test set if no stored result has been found.
    # Testing can be expensive due to the importance sampling.
    if result is None:
        val_result = trainer.test(flow, dataloaders=val_loader, verbose=False)
        start_time = time.time()
        test_result = trainer.test(flow, dataloaders=test_loader, verbose=False)
        duration = time.time() - start_time
        result = {"test": test_result, "val": val_result, "time": duration / len(test_loader) / flow.import_samples}

    return flow, result

Multi-scale architecture

One disadvantage of normalizing flows is that they operate on the exact same dimensions as the input. If the input is high-dimensional, so is the latent space, which makes learning suitable transformations computationally more expensive. However, particularly in the image domain, many pixels carry little information, in the sense that we could remove them without losing the semantic content of the image.

Based on this intuition, deep normalizing flows on images commonly apply a multi-scale architecture [1]. After the first N flow transformations, we split off half of the latent dimensions and directly evaluate them on the prior. The other half is run through N more flow transformations, and depending on the size of the input, we split it again in half or stop overall at this position. The two operations involved in this setup are Squeeze and Split, which we review more closely and implement below.

Squeeze and Split

When we want to remove half of the pixels in an image, we have the problem of deciding which variables to cut, and how to rearrange the image. Thus, the squeezing operation is commonly used before split, which divides the image into subsquares of shape 2\times 2\times C, and reshapes them into 1\times 1\times 4C blocks. Effectively, we reduce the height and width of the image by a factor of 2 while scaling the number of channels by 4. Afterwards, we can perform the split operation over channels without the need of rearranging the pixels. The smaller scale also makes the overall architecture more efficient. Visually, the squeeze operation should transform the input as follows:

[Figure: squeeze operation on a 4\times 4\times 1 input]

The input of 4\times 4\times 1 is scaled to 2\times 2\times 4 following the idea of grouping the pixels in 2\times 2\times 1 subsquares. Next, let’s try to implement this layer:

[19]:
class SqueezeFlow(nn.Module):
    def forward(self, z, ldj, reverse=False):
        B, C, H, W = z.shape
        if not reverse:
            # Forward direction: H x W x C => H/2 x W/2 x 4C
            z = z.reshape(B, C, H // 2, 2, W // 2, 2)
            z = z.permute(0, 1, 3, 5, 2, 4)
            z = z.reshape(B, 4 * C, H // 2, W // 2)
        else:
            # Reverse direction: H/2 x W/2 x 4C => H x W x C
            z = z.reshape(B, C // 4, 2, 2, H, W)
            z = z.permute(0, 1, 4, 2, 5, 3)
            z = z.reshape(B, C // 4, H * 2, W * 2)
        return z, ldj

Before moving on, we can verify our implementation by comparing our output with the example figure above:

[20]:
sq_flow = SqueezeFlow()
rand_img = torch.arange(1, 17).view(1, 1, 4, 4)
print("Image (before)\n", rand_img)
forward_img, _ = sq_flow(rand_img, ldj=None, reverse=False)
print("\nImage (forward)\n", forward_img.permute(0, 2, 3, 1))  # Permute for readability
reconst_img, _ = sq_flow(forward_img, ldj=None, reverse=True)
print("\nImage (reverse)\n", reconst_img)
Image (before)
 tensor([[[[ 1,  2,  3,  4],
          [ 5,  6,  7,  8],
          [ 9, 10, 11, 12],
          [13, 14, 15, 16]]]])

Image (forward)
 tensor([[[[ 1,  2,  5,  6],
          [ 3,  4,  7,  8]],

         [[ 9, 10, 13, 14],
          [11, 12, 15, 16]]]])

Image (reverse)
 tensor([[[[ 1,  2,  3,  4],
          [ 5,  6,  7,  8],
          [ 9, 10, 11, 12],
          [13, 14, 15, 16]]]])
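SqueezeFlow returns ldj unchanged because squeezing only rearranges entries: it is a permutation, whose Jacobian has determinant \pm 1, so its log-absolute-determinant is 0. The numpy sketch below re-implements the forward squeeze for a single image (`squeeze` is a hypothetical helper mirroring the reshape/permute pattern of SqueezeFlow with the batch dimension dropped), builds the matrix of this linear map on a 4\times 4\times 1 image, and checks that it is a permutation matrix.

```python
import numpy as np


def squeeze(x):
    # (C, H, W) -> (4C, H/2, W/2), same index arithmetic as SqueezeFlow without the batch dim
    C, H, W = x.shape
    x = x.reshape(C, H // 2, 2, W // 2, 2)
    x = x.transpose(0, 2, 4, 1, 3)
    return x.reshape(4 * C, H // 2, W // 2)


# Matches the printed example: the first output position holds pixels 1, 2, 5, 6
print(squeeze(np.arange(1, 17).reshape(1, 4, 4))[:, 0, 0])

# Column k of P is the squeezed k-th basis image, so P is the matrix of the squeeze map
n = 16
P = np.stack([squeeze(np.eye(n)[k].reshape(1, 4, 4)).ravel() for k in range(n)], axis=1)

# A permutation matrix has exactly one 1 per row and column and |det| = 1
print((P.sum(axis=0) == 1).all(), (P.sum(axis=1) == 1).all())
print(abs(np.linalg.det(P)))  # 1.0 up to floating-point error, so log|det| = 0
```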

The split operation divides the input into two parts and evaluates one part directly on the prior. To fit the interface of the previous layers, we add the log prior probability of the split-off part to the log-determinant Jacobian (ldj) of the layer. This has the same effect as combining all split variables at the end of the flow and evaluating them together on the prior.

[21]:
class SplitFlow(nn.Module):
    def __init__(self):
        super().__init__()
        self.prior = torch.distributions.normal.Normal(loc=0.0, scale=1.0)

    def forward(self, z, ldj, reverse=False):
        if not reverse:
            z, z_split = z.chunk(2, dim=1)
            ldj += self.prior.log_prob(z_split).sum(dim=[1, 2, 3])
        else:
            z_split = self.prior.sample(sample_shape=z.shape).to(device)
            z = torch.cat([z, z_split], dim=1)
            ldj -= self.prior.log_prob(z_split).sum(dim=[1, 2, 3])
        return z, ldj
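To see why accumulating the prior log-probability of the split-off half in ldj is equivalent to evaluating all variables at the end, the torch sketch below (hypothetical shapes, standard-normal prior as in SplitFlow) scores the two halves separately and compares the result against scoring the whole tensor at once. The equivalence relies on the prior factorizing over dimensions.

```python
import torch

prior = torch.distributions.normal.Normal(loc=0.0, scale=1.0)

# Hypothetical latent after a squeeze: batch of 2, 4 channels, 14x14
z = torch.randn(2, 4, 14, 14)
ldj = torch.zeros(2)

# Forward split, as in SplitFlow: keep the first half, score the second half on the prior
z_keep, z_split = z.chunk(2, dim=1)
ldj = ldj + prior.log_prob(z_split).sum(dim=[1, 2, 3])

# At the end of the flow, the kept half is scored on the prior as well
total = ldj + prior.log_prob(z_keep).sum(dim=[1, 2, 3])

# With a factorized standard-normal prior this equals scoring the full z at once
full = prior.log_prob(z).sum(dim=[1, 2, 3])
print(torch.allclose(total, full))  # True
```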
Building a multi-scale flow

After defining the squeeze and split operations, we are finally able to build our own multi-scale flow. Deep normalizing flows such as Glow and Flow++ [2,3] often apply a split operation directly after squeezing. However, with shallow flows, we need to be more careful about where to place the split operation, as each variable needs a minimum number of transformations. Our setup is inspired by the original RealNVP architecture [1], which is shallower than more recent state-of-the-art architectures.

Hence, for the MNIST dataset, we apply the first squeeze operation after two coupling layers, but do not apply a split operation yet. Because we have only used two coupling layers and each variable has been transformed only once, a split operation would be too early. We apply two more coupling layers before finally applying a split flow and squeezing again. The last four coupling layers operate on a scale of 7\times 7\times 8. The full flow architecture is shown below.

[Figure: full multi-scale flow architecture]
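The shape bookkeeping of this architecture can be traced with a few lines of plain Python. The helpers `squeeze_shape` and `split_shape` are hypothetical and only track (channels, height, width); the check at the end confirms that the split-off and final variables together still account for all 784 input dimensions, as a bijective flow requires.

```python
def squeeze_shape(c, h, w):
    # Squeeze: H x W x C -> H/2 x W/2 x 4C
    return 4 * c, h // 2, w // 2


def split_shape(c, h, w):
    # Split: keep half of the channels, the other half is scored on the prior
    return c // 2, h, w


shape = (1, 28, 28)                        # MNIST image, channels first
shape = squeeze_shape(*shape)              # after 2 coupling layers -> (4, 14, 14)
kept = split_shape(*shape)                 # after 2 more coupling layers -> (2, 14, 14)
split_off = kept[0] * kept[1] * kept[2]    # dimensions evaluated directly on the prior
shape = squeeze_shape(*kept)               # last 4 coupling layers operate on (8, 7, 7)
final = shape[0] * shape[1] * shape[2]
print(shape, split_off, final, split_off + final)  # (8, 7, 7) 392 392 784
```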

Note that while the feature maps inside the coupling layers shrink with the height and width of the input, the increased number of channels is not automatically accounted for. To counteract this, we increase the hidden dimensions of the coupling layers on the squeezed input. The hidden dimensions are often scaled by 2, as this increases the computational cost of a convolution by approximately a factor of 4, canceling the factor-4 reduction in spatial size from squeezing. However, we choose the hidden dimensionalities 32, 48, 64 for the three scales respectively to keep the number of parameters reasonable and show the efficiency of multi-scale architectures.
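To make the factor-of-4 argument concrete, the short calculation below counts the multiply-accumulates of a single c_hidden to c_hidden 3\times 3 convolution as a rough proxy for per-layer cost. This is a simplifying assumption: it ignores the 1\times 1 convolutions and the input/output layers of GatedConvNet.

```python
def conv_macs(h, w, c_hidden, kernel_size=3):
    # Multiply-accumulates of one c_hidden -> c_hidden convolution on an h x w feature map
    return h * w * c_hidden * c_hidden * kernel_size * kernel_size


base = conv_macs(28, 28, 32)                # first scale, hidden size 32
print(conv_macs(14, 14, 64) / base)         # 1.0: doubling channels cancels the 4x smaller map
print(conv_macs(14, 14, 48) / base)         # 0.5625: hidden size 48 used at the second scale
print(conv_macs(7, 7, 64) / base)           # 0.25: hidden size 64 used at the third scale
```

Under this proxy, the chosen sizes 48 and 64 make the deeper scales cheaper per layer than the first one, which is consistent with the multi-scale flow being faster despite having more parameters.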

[22]:
def create_multiscale_flow():
    flow_layers = []

    vardeq_layers = [
        CouplingLayer(
            network=GatedConvNet(c_in=2, c_out=2, c_hidden=16),
            mask=create_checkerboard_mask(h=28, w=28, invert=(i % 2 == 1)),
            c_in=1,
        )
        for i in range(4)
    ]
    flow_layers += [VariationalDequantization(vardeq_layers)]

    flow_layers += [
        CouplingLayer(
            network=GatedConvNet(c_in=1, c_hidden=32),
            mask=create_checkerboard_mask(h=28, w=28, invert=(i % 2 == 1)),
            c_in=1,
        )
        for i in range(2)
    ]
    flow_layers += [SqueezeFlow()]
    for i in range(2):
        flow_layers += [
            CouplingLayer(
                network=GatedConvNet(c_in=4, c_hidden=48), mask=create_channel_mask(c_in=4, invert=(i % 2 == 1)), c_in=4
            )
        ]
    flow_layers += [SplitFlow(), SqueezeFlow()]
    for i in range(4):
        flow_layers += [
            CouplingLayer(
                network=GatedConvNet(c_in=8, c_hidden=64), mask=create_channel_mask(c_in=8, invert=(i % 2 == 1)), c_in=8
            )
        ]

    flow_model = ImageFlow(flow_layers).to(device)
    return flow_model

We can compare the number of parameters of the three models below:

[23]:
def print_num_params(model):
    num_params = sum(np.prod(p.shape) for p in model.parameters())
    print(f"Number of parameters: {num_params:,}")


print_num_params(create_simple_flow(use_vardeq=False))
print_num_params(create_simple_flow(use_vardeq=True))
print_num_params(create_multiscale_flow())
Number of parameters: 335,128
Number of parameters: 379,556
Number of parameters: 1,062,090

Although the multi-scale flow has almost 3 times the parameters of the single scale flow, it is not necessarily more computationally expensive than its counterpart. We will compare the runtime in the following experiments as well.

Analysing the flows

In the last part of the notebook, we will train all the models we have implemented above, and try to analyze the effect of the multi-scale architecture and variational dequantization.

Training flow variants

Before we can analyse the flow models, we need to train them first. We provide pre-trained models that contain the validation and test performance, as well as run-time information. As flow models are computationally expensive, we advise you to rely on those pre-trained models for a first run through the notebook.

[24]:
flow_dict = {"simple": {}, "vardeq": {}, "multiscale": {}}
flow_dict["simple"]["model"], flow_dict["simple"]["result"] = train_flow(
    create_simple_flow(use_vardeq=False), model_name="MNISTFlow_simple"
)
flow_dict["vardeq"]["model"], flow_dict["vardeq"]["result"] = train_flow(
    create_simple_flow(use_vardeq=True), model_name="MNISTFlow_vardeq"
)
flow_dict["multiscale"]["model"], flow_dict["multiscale"]["result"] = train_flow(
    create_multiscale_flow(), model_name="MNISTFlow_multiscale"
)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Found pretrained model, loading...
Found pretrained model, loading...
Found pretrained model, loading...
Density modeling and sampling

First, we compare the models on their quantitative results. The following table shows the key statistics. The inference time specifies the time needed to determine the probability of a batch of 64 images for each model, and the sampling time the duration it took to sample a batch of 64 images.

[25]:
%%html
<!-- Some HTML code to increase font size in the following table -->
<style>
th {font-size: 120%;}
td {font-size: 120%;}
</style>
[26]:
table = [
    [
        key,
        "%4.3f bpd" % flow_dict[key]["result"]["val"][0]["test_bpd"],
        "%4.3f bpd" % flow_dict[key]["result"]["test"][0]["test_bpd"],
        "%2.0f ms" % (1000 * flow_dict[key]["result"]["time"]),
        "%2.0f ms" % (1000 * flow_dict[key]["result"].get("samp_time", 0)),
        "{:,}".format(sum(np.prod(p.shape) for p in flow_dict[key]["model"].parameters())),
    ]
    for key in flow_dict
]
display(
    HTML(
        tabulate.tabulate(
            table,
            tablefmt="html",
            headers=["Model", "Validation Bpd", "Test Bpd", "Inference time", "Sampling time", "Num Parameters"],
        )
    )
)
Model        Validation Bpd   Test Bpd    Inference time   Sampling time   Num Parameters
simple       1.109 bpd        1.107 bpd   51 ms            50 ms           335,128
vardeq       1.068 bpd        1.066 bpd   69 ms            50 ms           379,556
multiscale   1.029 bpd        1.026 bpd   40 ms            22 ms           1,062,090

As initially expected, using variational dequantization improves upon standard dequantization in terms of bits per dimension. Although a difference of 0.04 bpd may not seem impressive at first, it is a considerable step for generative models (most state-of-the-art models improve upon previous models by 0.02-0.1 bpd on CIFAR, where the bpd values are about three times as high). While it takes longer to evaluate the probability of an image due to the variational dequantization, which also leads to a longer training time, it does not affect the sampling time. This is because inverting variational dequantization is the same as inverting dequantization: rounding down to the next lower integer.
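To put 0.04 bpd into perspective, the short calculation below converts the test-set improvement from the table into bits per image, assuming the standard definition bpd = NLL / (D \ln 2) with D = 28 \cdot 28 = 784 pixels. The helper `nll_nats_to_bpd` is illustrative and not part of the notebook.

```python
import math

NUM_DIMS = 28 * 28  # pixels per MNIST image


def nll_nats_to_bpd(nll_nats, num_dims=NUM_DIMS):
    # Negative log-likelihood in nats per image -> bits per dimension
    return nll_nats / (num_dims * math.log(2))


# Improvement of variational over standard dequantization on the test set (table above)
delta_bpd = 1.107 - 1.066
delta_bits = delta_bpd * NUM_DIMS      # bits saved per image, about 32.1
likelihood_ratio = 2 ** delta_bits     # per-image likelihood gain, on the order of 10^9
print(f"{delta_bits:.1f} bits per image, likelihood ratio {likelihood_ratio:.2e}")
```

In other words, a test image is on average (in the geometric-mean sense) about 2^{32} times more likely under the variational dequantization model.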

When we compare the two models to the multi-scale architecture, we see that the bits per dimension score drops again by about 0.04 bpd. Additionally, both inference time and sampling time improve notably despite the larger number of parameters. Thus, the multi-scale flow is not only stronger for density modeling, but also more efficient.

Next, we can test the sampling quality of the models. The samples for variational dequantization and standard dequantization are very similar, hence we only visualize those for variational dequantization and the multi-scale model here. However, feel free to also test out the "simple" model. The seeds are set to obtain reproducible generations, and the samples are not cherry-picked.

[27]:
pl.seed_everything(44)
samples = flow_dict["vardeq"]["model"].sample(img_shape=[16, 1, 28, 28])
show_imgs(samples.cpu())
Global seed set to 44
[Figure: 16 samples from the variational dequantization model]
[28]:
pl.seed_everything(44)
samples = flow_dict["multiscale"]["model"].sample(img_shape=[16, 8, 7, 7])
show_imgs(samples.cpu())
Global seed set to 44
[Figure: 16 samples from the multi-scale model]

From the few samples, we can see a clear difference between the simple and the multi-scale model. The single-scale model has only learned local, small correlations, while the multi-scale model was able to learn full, global relations that form digits. This showcases another benefit of the multi-scale model. In contrast to VAEs, the outputs are sharp, as normalizing flows can naturally model complex, multi-modal distributions, while VAEs suffer from the independent noise of their decoder output. Nevertheless, the samples from this flow are far from perfect, as not all samples show recognizable digits.

Interpolation in latent space

Another popular test for the smoothness of the latent space of generative models is to interpolate between two training examples. As normalizing flows are strictly invertible, we can guarantee that any image is represented in the latent space. We again compare the variational dequantization model with the multi-scale model below.

[29]:
@torch.no_grad()
def interpolate(model, img1, img2, num_steps=8):
    """
    Args:
        model: object of ImageFlow class that represents the (trained) flow model
        img1, img2: Image tensors of shape [1, 28, 28]. Images between which should be interpolated.
        num_steps: Number of interpolation steps. 8 interpolation steps mean 6 intermediate pictures besides img1 and img2
    """
    imgs = torch.stack([img1, img2], dim=0).to(model.device)
    z, _ = model.encode(imgs)
    alpha = torch.linspace(0, 1, steps=num_steps, device=z.device).view(-1, 1, 1, 1)
    interpolations = z[0:1] * alpha + z[1:2] * (1 - alpha)
    interp_imgs = model.sample(interpolations.shape[:1] + imgs.shape[1:], z_init=interpolations)
    show_imgs(interp_imgs, row_size=8)


exmp_imgs, _ = next(iter(train_loader))
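The guarantee rests on exact invertibility: encode both images, interpolate linearly in latent space, and decode. The numpy sketch below uses a hypothetical elementwise affine map as a stand-in for a trained flow's encode/sample pair, with eight steps as in `interpolate` above. Because the map is exactly invertible, the endpoints of the path reproduce the inputs up to floating-point error.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = rng.normal(size=4), np.exp(rng.normal(size=4))


def encode(x):
    # Hypothetical toy "flow": elementwise affine map, exactly invertible
    return (x - mu) / sigma


def decode(z):
    # Inverse of encode
    return z * sigma + mu


x1, x2 = rng.normal(size=4), rng.normal(size=4)
z1, z2 = encode(x1), encode(x2)

alpha = np.linspace(0, 1, 8)[:, None]   # 8 steps: 2 endpoints + 6 intermediate points
z_path = (1 - alpha) * z1 + alpha * z2
x_path = decode(z_path)

# Invertibility: the endpoints of the path reproduce the original inputs
print(np.allclose(x_path[0], x1), np.allclose(x_path[-1], x2))  # True True
```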
[30]:
pl.seed_everything(42)
for i in range(2):
    interpolate(flow_dict["vardeq"]["model"], exmp_imgs[2 * i], exmp_imgs[2 * i + 1])
Global seed set to 42
[Figure: two rows of latent interpolations with the variational dequantization model]
[31]:
pl.seed_everything(42)
for i in range(2):
    interpolate(flow_dict["multiscale"]["model"], exmp_imgs[2 * i], exmp_imgs[2 * i + 1])
Global seed set to 42
[Figure: two rows of latent interpolations with the multi-scale model]

The interpolations of the multi-scale model result in more realistic digits (first row 7\leftrightarrow 8\leftrightarrow 6, second row 9\leftrightarrow 4\leftrightarrow 6), while the variational dequantization model focuses on local patterns that globally do not form a digit. For the multi-scale model, we did not perform the “true” interpolation between the two images, as we did not consider the variables that were split off along the flow (they were sampled randomly for all samples). However, as we will see in the next experiment, the early variables do not affect the overall image much.

Visualization of latents in different levels of multi-scale

In the following, we focus more on the multi-scale flow. We want to analyse what information is stored in the variables split off at early layers, and what information is stored in the final variables. For this, we sample 8 images that all share the same final latent variables, but differ in the other part of the latent variables. Below we visualize three examples of this:

[32]:
pl.seed_everything(44)
for _ in range(3):
    z_init = flow_dict["multiscale"]["model"].prior.sample(sample_shape=[1, 8, 7, 7])
    z_init = z_init.expand(8, -1, -1, -1)
    samples = flow_dict["multiscale"]["model"].sample(img_shape=z_init.shape, z_init=z_init)
    show_imgs(samples.cpu())
Global seed set to 44
[Figure: three rows of 8 multi-scale samples, each row sharing the same final latent variables]

We see that the early split variables indeed have a smaller effect on the image. Still, small differences can be spotted when we look carefully at the borders of the digits. For instance, the hole at the top of the 8 changes across samples, although all of them represent the same coarse structure. This shows that the flow indeed learns to separate the higher-level information in the final variables, while the early split ones contain local noise patterns.

Visualizing Dequantization

As a final part of this notebook, we look at the effect of variational dequantization. We motivated variational dequantization by the issue that sharp edges/borders are difficult to model, as a flow prefers smooth, prior-like distributions. To check what noise distribution q(u|x) the flows in the variational dequantization module have learned, we can plot a histogram of output values from the dequantization and variational dequantization modules.
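For intuition on the sharp borders, the numpy sketch below applies plain uniform dequantization, v = x + u with u \sim U[0, 1), to hypothetical pixel values. This is an assumption for illustration: the notebook's Dequantization module may additionally rescale or transform these values before they reach the histogram. The values originating from x = 0 fill [0, 1) with a flat density that drops to zero at both ends, the kind of hard-edged block that is hard to model for a flow with a smooth prior, while rounding down recovers x exactly.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 8-bit pixels with a strong bias towards 0 (black), loosely like MNIST
x = rng.choice([0, 128, 255], size=10_000, p=[0.8, 0.1, 0.1])

# Uniform dequantization: add noise u ~ U[0, 1) to each integer value
u = rng.random(size=x.shape)
v = x + u

# Values dequantized from x = 0 fill [0, 1) uniformly: a flat block with sharp edges
hist, edges = np.histogram(v[x == 0], bins=10, range=(0, 1))
print(hist)

# Inverting dequantization is just rounding down to the next lower integer
print(np.array_equal(np.floor(v).astype(x.dtype), x))  # True
```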

[33]:
def visualize_dequant_distribution(model: ImageFlow, imgs: Tensor, title: str = None):
    """
    Args:
        model: The flow of which we want to visualize the dequantization distribution
        imgs: Example training images of which we want to visualize the dequantization distribution
        title: Optional title of the histogram plot
    """
    imgs = imgs.to(device)
    ldj = torch.zeros(imgs.shape[0], dtype=torch.float32).to(device)
    with torch.no_grad():
        dequant_vals = []
        for _ in tqdm(range(8), leave=False):
            d, _ = model.flows[0](imgs, ldj, reverse=False)
            dequant_vals.append(d)
        dequant_vals = torch.cat(dequant_vals, dim=0)
    dequant_vals = dequant_vals.view(-1).cpu().numpy()
    sns.set()
    plt.figure(figsize=(10, 3))
    plt.hist(dequant_vals, bins=256, color=to_rgb("C0") + (0.5,), edgecolor="C0", density=True)
    if title is not None:
        plt.title(title)
    plt.show()
    plt.close()


sample_imgs, _ = next(iter(train_loader))
[34]:
visualize_dequant_distribution(flow_dict["simple"]["model"], sample_imgs, title="Dequantization")
[Figure: histogram of dequantized values, "Dequantization"]
[35]:
visualize_dequant_distribution(flow_dict["vardeq"]["model"], sample_imgs, title="Variational dequantization")
[Figure: histogram of dequantized values, "Variational dequantization"]

The dequantization distribution in the first plot shows that the MNIST images have a strong bias towards 0 (black), and their distribution has a sharp border, as mentioned before. The variational dequantization module has indeed learned a much smoother distribution with a Gaussian-like curve, which can be modeled much better. For the other pixel values, we would need to visualize the distribution q(u|x) conditioned on x. However, as all u’s interact and depend on each other, we would need to visualize a distribution in 784 dimensions, which is no longer intuitive.

Conclusion

In conclusion, we have seen how to implement our own normalizing flow, and what difficulties arise if we want to apply them to images. Dequantization is a crucial step in mapping the discrete images into continuous space to prevent undesirable delta-peak solutions. While dequantization creates hypercubes with hard borders, variational dequantization allows us to fit a flow much better to the data. This allows us to obtain a lower bits per dimension score without affecting the sampling speed. The most common flow element, the coupling layer, is simple to implement, yet effective. Furthermore, multi-scale architectures help to capture the global image context while allowing us to efficiently scale up the flow. Normalizing flows are an interesting alternative to VAEs, as they allow an exact likelihood estimate in continuous space, and we have the guarantee that every possible input x has a corresponding latent vector z. Moreover, beyond continuous inputs and images, flows can be applied to exploit the data structure in latent space, e.g. on graphs for the task of molecule generation [6]. Recent advances in Neural ODEs allow a flow with an infinite number of layers, called Continuous Normalizing Flows, whose potential is yet to be fully explored. Overall, normalizing flows are an exciting research area that will continue to develop over the next couple of years.

References

[1] Dinh, L., Sohl-Dickstein, J., and Bengio, S. (2017). “Density estimation using Real NVP,” In: 5th International Conference on Learning Representations, ICLR 2017.

[2] Kingma, D. P., and Dhariwal, P. (2018). “Glow: Generative Flow with Invertible 1x1 Convolutions,” In: Advances in Neural Information Processing Systems, vol. 31, pp. 10215–10224.

[3] Ho, J., Chen, X., Srinivas, A., Duan, Y., and Abbeel, P. (2019). “Flow++: Improving Flow-Based Generative Models with Variational Dequantization and Architecture Design,” In: Proceedings of the 36th International Conference on Machine Learning, vol. 97, pp. 2722–2730.

[4] Durkan, C., Bekasov, A., Murray, I., and Papamakarios, G. (2019). “Neural Spline Flows,” In: Advances in Neural Information Processing Systems, pp. 7509–7520.

[5] Hoogeboom, E., Cohen, T. S., and Tomczak, J. M. (2020). “Learning Discrete Distributions by Dequantization,” arXiv preprint arXiv:2001.11235v1.

[6] Lippe, P., and Gavves, E. (2021). “Categorical Normalizing Flows via Continuous Transformations,” In: International Conference on Learning Representations, ICLR 2021.

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Many thanks from the entire PyTorch Lightning Team for your interest!


Tutorial 10: Autoregressive Image Modeling

  • Author: Phillip Lippe

  • License: CC BY-SA

  • Generated: 2022-05-03T02:43:16.674251

In this tutorial, we implement an autoregressive likelihood model for the task of image modeling. Autoregressive models are naturally strong generative models that constitute one of the current state-of-the-art architectures on likelihood-based image modeling, and are also the basis for large language generation models such as GPT-3. We will focus on the PixelCNN architecture in this tutorial, and apply it to MNIST modeling. This notebook is part of a lecture series on Deep Learning at the University of Amsterdam. The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.


Open in Colab

Give us a ⭐ on Github | Check out the documentation | Join us on Slack

Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "ipython[notebook]" "torchmetrics>=0.7" "setuptools==59.5.0" "matplotlib" "seaborn" "torchvision" "pytorch-lightning>=1.4" "torch>=1.8"

Similar to the language generation you have seen in assignment 2, autoregressive models work on images by modeling the likelihood of a pixel given all previous ones. For instance, in the picture below, we model the pixel x_i as a conditional probability distribution based on all previous (here blue) pixels (figure credit - Aaron van den Oord et al. ):

[Figure: pixel x_i modeled conditionally on all previous (blue) pixels]

Generally, autoregressive models over high-dimensional data \mathbf{x} factor the joint distribution as the following product of conditionals:

p(\mathbf{x})=p(x_1, ..., x_n)=\prod_{i=1}^{n} p(x_i|x_1,...,x_{i-1})
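As a small worked check of this factorization, the sketch below defines a hypothetical toy model over n = 4 binary pixels, where each conditional depends only on the previous pixel. It multiplies the conditionals following the chain rule and confirms that the resulting joint probabilities sum to one over all 2^4 possible images.

```python
import itertools


def cond_prob(x_i, prev):
    # Hypothetical toy conditional p(x_i | x_1, ..., x_{i-1}) for binary pixels:
    # the first pixel is on with prob 0.5, later pixels copy their predecessor with prob 0.8
    if not prev:
        p_on = 0.5
    else:
        p_on = 0.8 if prev[-1] == 1 else 0.2
    return p_on if x_i == 1 else 1.0 - p_on


def joint_prob(x):
    # Chain rule: p(x) = prod_i p(x_i | x_1, ..., x_{i-1})
    p = 1.0
    for i, x_i in enumerate(x):
        p *= cond_prob(x_i, x[:i])
    return p


n = 4
total = sum(joint_prob(x) for x in itertools.product([0, 1], repeat=n))
print(total)                      # 1.0 up to rounding: the conditionals define a valid joint
print(joint_prob((1, 1, 1, 1)))   # 0.5 * 0.8 ** 3
```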

Learning these conditionals is often much simpler than learning the joint distribution p(\mathbf{x}) all together. However, disadvantages of autoregressive models include slow sampling, especially for large images, as we need height-times-width forward passes through the model. In addition, for some applications, we require a latent space as modeled in VAEs and Normalizing Flows. For instance, in autoregressive models, we cannot interpolate between two images because of the lack of a latent representation. We will explore and discuss these benefits and drawbacks alongside our implementation.
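To see why sampling is slow, the numpy sketch below runs the sequential sampling loop with a hypothetical `toy_model` standing in for a PixelCNN (it simply returns uniform probabilities over the 256 intensities). The point is the loop structure: every pixel needs a full forward pass conditioned on the pixels sampled so far, so a 28\times 28 image needs 784 sequential passes that cannot be parallelized across pixels.

```python
import numpy as np


def toy_model(img):
    # Hypothetical stand-in for a PixelCNN: per-pixel probabilities over 256 intensities
    H, W = img.shape
    return np.full((H, W, 256), 1.0 / 256)


def sample_autoregressive(model, H, W, rng):
    img = np.zeros((H, W), dtype=np.int64)
    num_passes = 0
    for i in range(H):
        for j in range(W):
            probs = model(img)  # one full forward pass, conditioned on pixels sampled so far
            num_passes += 1
            img[i, j] = rng.choice(256, p=probs[i, j])
    return img, num_passes


rng = np.random.default_rng(0)
img, num_passes = sample_autoregressive(toy_model, 28, 28, rng)
print(num_passes)  # 784 = 28 * 28 sequential forward passes
```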

Our implementation will focus on the PixelCNN [2] model which has been discussed in detail in the lecture. Most current SOTA models use PixelCNN as their fundamental architecture, and various additions have been proposed to improve the performance (e.g. PixelCNN++ and PixelSNAIL). Hence, implementing PixelCNN is a good starting point for our short tutorial.

First of all, we need to import our standard libraries. As in the last couple of tutorials, we will use PyTorch Lightning here as well.

[2]:
import math
import os
import urllib.request
from urllib.error import HTTPError

# Imports for plotting
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from IPython.display import set_matplotlib_formats
from matplotlib.colors import to_rgb
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch import Tensor
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm.notebook import tqdm

plt.set_cmap("cividis")
%matplotlib inline
set_matplotlib_formats("svg", "pdf")  # For export

# Path to the folder where the datasets are/should be downloaded (e.g. MNIST)
DATASET_PATH = os.environ.get("PATH_DATASETS", "data")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/tutorial12")

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Fetching the device that will be used throughout this notebook
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda:0")
print("Using device", device)
/usr/lib/python3.8/site-packages/apex/pyprof/__init__.py:5: FutureWarning: pyprof will be removed by the end of June, 2022
  warnings.warn("pyprof will be removed by the end of June, 2022", FutureWarning)
WARNING:root:Bagua cannot detect bundled NCCL library, Bagua will try to use system NCCL instead. If you encounter any error, please run `import bagua_core; bagua_core.install_deps()` or the `bagua_install_deps.py` script to install bundled libraries.
/tmp/ipykernel_4215/1092408338.py:27: DeprecationWarning: `set_matplotlib_formats` is deprecated since IPython 7.23, directly use `matplotlib_inline.backend_inline.set_matplotlib_formats()`
  set_matplotlib_formats("svg", "pdf")  # For export
Global seed set to 42
Using device cuda:0

We again provide a pretrained model, which is downloaded below:

[3]:
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial12/"
# Files to download
pretrained_files = ["PixelCNN.ckpt"]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print("Downloading %s..." % file_url)
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n",
                e,
            )
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial12/PixelCNN.ckpt...

Similar to the Normalizing Flows in Tutorial 11, we will work on the MNIST dataset and use 8 bits per pixel (values between 0 and 255). The dataset is loaded below:

[4]:
# Convert images from 0-1 to 0-255 (integers). We use the long datatype as we will use the images as labels as well
def discretize(sample):
    return (sample * 255).to(torch.long)


# Transformations applied on each image => convert to a tensor and discretize to integers in 0-255
transform = transforms.Compose([transforms.ToTensor(), discretize])

# Loading the training dataset. We need to split it into a training and validation part
train_dataset = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)
pl.seed_everything(42)
train_set, val_set = torch.utils.data.random_split(train_dataset, [50000, 10000])

# Loading the test set
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = data.DataLoader(val_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
Global seed set to 42

A good practice is to always visualize some data examples to get an intuition of the data:

[5]:
def show_imgs(imgs):
    num_imgs = imgs.shape[0] if isinstance(imgs, Tensor) else len(imgs)
    nrow = min(num_imgs, 4)
    ncol = int(math.ceil(num_imgs / nrow))
    imgs = torchvision.utils.make_grid(imgs, nrow=nrow, pad_value=128)
    imgs = imgs.clamp(min=0, max=255)
    np_imgs = imgs.cpu().numpy()
    plt.figure(figsize=(1.5 * nrow, 1.5 * ncol))
    plt.imshow(np.transpose(np_imgs, (1, 2, 0)), interpolation="nearest")
    plt.axis("off")
    plt.show()
    plt.close()


show_imgs([train_set[i][0] for i in range(8)])
_images/notebooks_course_UvA-DL_10-autoregressive-image-modeling_10_0.svg

Masked autoregressive convolutions

The core building block of PixelCNN is the masked convolution. In contrast to language models, we don’t apply an LSTM on each pixel one-by-one. This would be inefficient because images are grids instead of sequences. Thus, it is better to rely on convolutions, which have shown great success in deep CNN classification models.

Nevertheless, we cannot just apply standard convolutions without any changes. Remember that autoregressive models are trained with teacher forcing, which both helps training and significantly reduces the time needed for it. For image modeling, teacher forcing means that we use a training image as input to the model and want to obtain, for each pixel, a prediction based only on its predecessors. Thus, we need to ensure that the prediction for a specific pixel can only be influenced by its predecessors and not by its own value or any “future” pixels. For this, we apply convolutions with a mask.

Which mask we use depends on the ordering of pixels we decide on, i.e. which is the first pixel we predict, which is the second one, etc. The most commonly used ordering is to denote the upper left pixel as the start pixel, and sort the pixels row by row, as shown in the visualization at the top of the tutorial. Thus, the second pixel is on the right of the first one (first row, second column), and once we reach the end of the row, we continue in the second row, first column. If we now want to apply this to our convolutions, we need to ensure that the prediction of a pixel is influenced neither by its own “true” input nor by any pixel to its right or in a lower row. In convolutions, this means that we set to zero those entries of the weight matrix that correspond to pixels on the right and below. As an example, a mask for a 5x5 kernel is shown below (figure credit - Aaron van den Oord):
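
The rule above can be turned into a mask directly. The sketch below assumes a square kernel and the row-by-row ordering; the helper name `raster_scan_mask` and its `include_center` flag are hypothetical illustrations, not part of the tutorial's code. With `include_center=False` the center weight is masked as well, which is what the first layer needs.

```python
import torch


def raster_scan_mask(kernel_size: int, include_center: bool = False) -> torch.Tensor:
    """Hypothetical helper: 1 = the weight may see that input pixel, 0 = masked.

    Assumes the row-by-row ordering (top-left pixel first) and an odd kernel size.
    """
    mask = torch.ones(kernel_size, kernel_size)
    center = kernel_size // 2
    # Zero the center pixel (unless it is kept) and everything to its right in the same row
    start = center + 1 if include_center else center
    mask[center, start:] = 0
    # Zero all rows below the center row
    mask[center + 1 :, :] = 0
    return mask


print(raster_scan_mask(5))
```

For a 5x5 kernel, the first two rows stay all ones, the center row keeps only the two entries to the left of the center, and the two bottom rows are zero.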

Figure: mask of a 5x5 masked convolution kernel

Before looking into the application of masked convolutions in PixelCNN in detail, let’s first implement a module that allows us to apply an arbitrary mask to a convolution:

[6]:
class MaskedConvolution(nn.Module):
    def __init__(self, c_in, c_out, mask, **kwargs):
        """Implements a convolution with mask applied on its weights.

        Args:
            c_in: Number of input channels
            c_out: Number of output channels
            mask: Tensor of shape [kernel_size_H, kernel_size_W] with 0s where
                   the convolution should be masked, and 1s otherwise.
            kwargs: Additional arguments for the convolution
        """
        super().__init__()
        # For simplicity: calculate padding automatically
        kernel_size = (mask.shape[0], mask.shape[1])
        dilation = 1 if "dilation" not in kwargs else kwargs["dilation"]
        padding = tuple(dilation * (kernel_size[i] - 1) // 2 for i in range(2))
        # Actual convolution
        self.conv = nn.Conv2d(c_in, c_out, kernel_size, padding=padding, **kwargs)

        # Mask as buffer => it is not a parameter, but still a tensor of the module
        # (it is moved along with the module across devices)
        self.register_buffer("mask", mask[None, None])

    def forward(self, x):
        self.conv.weight.data *= self.mask  # Ensures zeros at masked positions
        return self.conv(x)
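
One way to convince yourself that multiplying the weights by a mask blocks information flow is to perturb masked input pixels and check that the output at the center does not change. The sketch below is a standalone re-creation using a plain `nn.Conv2d` whose weights are multiplied by a 3x3 mask, in the same spirit as `MaskedConvolution.forward`; the mask, image size, and perturbation are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 3x3 mask for a first layer: the row above plus the pixel left of the center
mask = torch.tensor([[1.0, 1.0, 1.0],
                     [1.0, 0.0, 0.0],
                     [0.0, 0.0, 0.0]])

conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
with torch.no_grad():
    conv.weight *= mask[None, None]  # zero the masked weights, as MaskedConvolution does

x = torch.randn(1, 1, 7, 7)
out_ref = conv(x)[0, 0, 3, 3]

# Change the center pixel, the pixels to its right, and all rows below it
x_perturbed = x.clone()
x_perturbed[0, 0, 3, 3:] += 10.0
x_perturbed[0, 0, 4:, :] += 10.0
out_new = conv(x_perturbed)[0, 0, 3, 3]

print(torch.allclose(out_ref, out_new))  # masked pixels cannot influence the center output
```

Perturbing an unmasked position instead, such as the pixel directly to the left of the center, does change the output.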

Vertical and horizontal convolution stacks

To build our own autoregressive image model, we could simply stack a few masked convolutions on top of each other. This was actually the case for the original PixelCNN model, discussed in the paper Pixel Recurrent Neural Networks, but it leads to a considerable issue. When several masked convolutions are applied in sequence, the receptive field of a pixel turns out to have a “blind spot” on the upper right side, as shown in the figure below (figure credit - Aaron van den Oord et al.):

Figure: blind spot in the receptive field of stacked masked convolutions, and the split into vertical and horizontal stacks

Although a pixel should be able to take into account all pixels in the rows above it and to its left, a stack of masked convolutions does not allow us to look at the upper pixels on the right. This is because the features of the pixels above, which we use for convolution, do not contain any information about the pixels on the right of the same row. If they did, we would be “cheating” and actually looking into the future. To overcome this issue, van den Oord et al. [2] proposed to split the convolutions into a vertical and a horizontal stack. The vertical stack looks at all pixels above the current one, while the horizontal stack takes into account all pixels to the left in the same row. By keeping the two stacks separate, the vertical stack can look at the pixels on the upper right without breaking any of our assumptions. The two convolutions are also shown in the figure above.
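
The blind spot can be reproduced numerically. The sketch below (an illustrative re-creation, not the tutorial's code) stacks four plain 3x3 masked convolutions as in the original PixelCNN: a first mask without the center pixel, then masks that include it. All unmasked weights are 1 and there is no non-linearity, so a nonzero input gradient simply means "reachable".

```python
import torch
import torch.nn.functional as F

# 3x3 masks of the original PixelCNN (illustrative): first layer without the center pixel,
# later layers with it
mask_first = torch.tensor([[1.0, 1.0, 1.0],
                           [1.0, 0.0, 0.0],
                           [0.0, 0.0, 0.0]])
mask_later = mask_first.clone()
mask_later[1, 1] = 1.0

size, center = 11, 5
x = torch.zeros(1, 1, size, size, requires_grad=True)

h = x
for layer in range(4):
    # All unmasked weights are 1, so a nonzero gradient just means "reachable"
    weight = (mask_first if layer == 0 else mask_later)[None, None]
    h = F.conv2d(h, weight, padding=1)

h[0, 0, center, center].backward()
field = x.grad[0, 0].abs() > 1e-6  # binary receptive field of the center pixel

print(field[center - 1, center + 2].item())  # False: row above, two columns right
print(field[center - 2, center + 2].item())  # True
```

The pixel two columns to the right in the row directly above is a valid predecessor, yet it never receives gradient: every step that moves one column to the right also moves one row up, so reaching two columns to the right requires going up at least two rows.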

Let us implement them here as follows:

[7]:
class VerticalStackConvolution(MaskedConvolution):
    def __init__(self, c_in, c_out, kernel_size=3, mask_center=False, **kwargs):
        # Mask out all pixels below. For efficiency, we could also reduce the kernel
        # size in height, but for simplicity, we stick with masking here.
        mask = torch.ones(kernel_size, kernel_size)
        mask[kernel_size // 2 + 1 :, :] = 0

        # For the very first convolution, we will also mask the center row
        if mask_center:
            mask[kernel_size // 2, :] = 0

        super().__init__(c_in, c_out, mask, **kwargs)


class HorizontalStackConvolution(MaskedConvolution):
    def __init__(self, c_in, c_out, kernel_size=3, mask_center=False, **kwargs):
        # Mask out all pixels on the right. Note that our kernel has a size of 1
        # in height because we only look at the pixels in the same row.
        mask = torch.ones(1, kernel_size)
        mask[0, kernel_size // 2 + 1 :] = 0

        # For the very first convolution, we will also mask the center pixel
        if mask_center:
            mask[0, kernel_size // 2] = 0

        super().__init__(c_in, c_out, mask, **kwargs)

Note that we have an input argument called mask_center. Remember that the input to the model is the actual input image. Hence, the very first convolution we apply cannot use the center pixel as input, but must mask it. All subsequent convolutions, however, should use the center pixel, as we would otherwise lose the features of the previous layer. Hence, the input argument mask_center is True for the very first convolution, and False for all others.

Visualizing the receptive field

To validate our implementation of masked convolutions, we can visualize the receptive field we obtain with such convolutions. We should see that with an increasing number of convolutional layers, the receptive field grows in both the vertical and horizontal direction, without a blind spot. The receptive field can be measured empirically by backpropagating an arbitrary loss on the output features of a specific pixel with respect to the input. We implement this idea below and visualize the resulting receptive field.

[8]:
inp_img = torch.zeros(1, 1, 11, 11)
inp_img.requires_grad_()


def show_center_recep_field(img, out):
    """Calculates the gradients of the input with respect to the output center pixel, and visualizes the overall
    receptive field.

    Args:
        img: Input image for which we want to calculate the receptive field on.
        out: Output features/loss which is used for backpropagation, and should be
              the output of the network/computation graph.
    """
    # Determine gradients
    loss = out[0, :, img.shape[2] // 2, img.shape[3] // 2].sum()  # Sum of the center pixel's features as a simple scalar loss
    # Retain graph as we want to stack multiple layers and show the receptive field of all of them
    loss.backward(retain_graph=True)
    img_grads = img.grad.abs()
    img.grad.fill_(0)  # Reset grads

    # Plot receptive field
    img = img_grads.squeeze().cpu().numpy()
    fig, ax = plt.subplots(1, 2)
    _ = ax[0].imshow(img)
    ax[1].imshow(img > 0)
    # Mark the center pixel in red if it doesn't have any gradients (should be
    # the case for standard autoregressive models)
    show_center = img[img.shape[0] // 2, img.shape[1] // 2] == 0
    if show_center:
        center_pixel = np.zeros(img.shape + (4,))
        center_pixel[center_pixel.shape[0] // 2, center_pixel.shape[1] // 2, :] = np.array([1.0, 0.0, 0.0, 1.0])
    for i in range(2):
        ax[i].axis("off")
        if show_center:
            ax[i].imshow(center_pixel)
    ax[0].set_title("Weighted receptive field")
    ax[1].set_title("Binary receptive field")
    plt.show()
    plt.close()


show_center_recep_field(inp_img, inp_img)
_images/notebooks_course_UvA-DL_10-autoregressive-image-modeling_17_0.svg

Let’s first visualize the receptive field of a horizontal convolution without the center pixel. We use a small, arbitrary input image (11\times 11 pixels), and calculate the loss for the center pixel. For simplicity, we initialize all weights with 1 and the bias with 0, and use a single channel. This is sufficient for our visualization purposes.

[9]:
horiz_conv = HorizontalStackConvolution(c_in=1, c_out=1, kernel_size=3, mask_center=True)
horiz_conv.conv.weight.data.fill_(1)
horiz_conv.conv.bias.data.fill_(0)
horiz_img = horiz_conv(inp_img)
show_center_recep_field(inp_img, horiz_img)
_images/notebooks_course_UvA-DL_10-autoregressive-image-modeling_19_0.svg

The receptive field is shown in yellow, the center pixel in red, and all other pixels outside of the receptive field are dark blue. As expected, the receptive field of a single horizontal convolution with the center pixel masked and a 3\times3 kernel is only the pixel on the left. If we use a larger kernel size, more pixels would be taken into account on the left.
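
To check the claim about larger kernels, the sketch below rebuilds the horizontal mask for `kernel_size=5` with `mask_center=True` using the same slicing as `HorizontalStackConvolution`, but applies it with `F.conv2d` and all-ones weights instead of the class, so it runs on its own.

```python
import torch
import torch.nn.functional as F

kernel_size = 5
# Same slicing as HorizontalStackConvolution(..., mask_center=True)
mask = torch.ones(1, kernel_size)
mask[0, kernel_size // 2 + 1 :] = 0
mask[0, kernel_size // 2] = 0

x = torch.zeros(1, 1, 11, 11, requires_grad=True)
# All-ones weights (the mask itself), no bias, padding only along the width
out = F.conv2d(x, mask[None, None], padding=(0, kernel_size // 2))
out[0, 0, 5, 5].backward()

cols = (x.grad[0, 0, 5] != 0).nonzero().flatten().tolist()
print(cols)  # [3, 4]: the two pixels to the left of the center column 5
```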

Next, let’s take a look at the vertical convolution:

[10]:
vert_conv = VerticalStackConvolution(c_in=1, c_out=1, kernel_size=3, mask_center=True)
vert_conv.conv.weight.data.fill_(1)
vert_conv.conv.bias.data.fill_(0)
vert_img = vert_conv(inp_img)
show_center_recep_field(inp_img, vert_img)
_images/notebooks_course_UvA-DL_10-autoregressive-image-modeling_21_0.svg

The vertical convolution takes all pixels above into account. Combining these two, we get the L-shaped receptive field of the original masked convolution:

[11]:
horiz_img = vert_img + horiz_img
show_center_recep_field(inp_img, horiz_img)
_images/notebooks_course_UvA-DL_10-autoregressive-image-modeling_23_0.svg

If we stack multiple horizontal and vertical convolutions, we need to take two aspects into account:

  1. The center should not be masked anymore for the following convolutions as the features at the pixel’s position are already independent of its actual value. If it is hard to imagine why we can do this, just change the value below to mask_center=True and see what happens.

  2. The vertical convolution is not allowed to work on features from the horizontal convolution. In the feature map of the horizontal convolution, a pixel contains information about all of the “true” pixels to its left. If we apply a vertical convolution that also uses features to the right of the current pixel, the receptive field would extend to the true input value of the current pixel itself, which we want to prevent. Thus, the two feature maps may only be merged in the horizontal stack.
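
The second point can be checked with the same gradient trick. The sketch below is a standalone re-creation with `F.conv2d`, all-ones weights, and masks built like those of the two stack classes for `kernel_size=3`. It applies a later vertical convolution (`mask_center=False`) once to vertical-stack features and once to horizontal-stack features, then measures the gradient of the center output with respect to the center input pixel.

```python
import torch
import torch.nn.functional as F

# Masks as built by the stack classes for kernel_size=3 (weights all set to 1)
horiz_first = torch.tensor([[1.0, 0.0, 0.0]])  # horizontal, mask_center=True: left pixel only
vert_first = torch.tensor([[1.0, 1.0, 1.0],
                           [0.0, 0.0, 0.0],
                           [0.0, 0.0, 0.0]])  # vertical, mask_center=True: row above only
vert_later = torch.tensor([[1.0, 1.0, 1.0],
                           [1.0, 1.0, 1.0],
                           [0.0, 0.0, 0.0]])  # vertical, mask_center=False: adds the center row


def vertical_on_vertical(x):
    v = F.conv2d(x, vert_first[None, None], padding=1)
    return F.conv2d(v, vert_later[None, None], padding=1)


def vertical_on_horizontal(x):
    h = F.conv2d(x, horiz_first[None, None], padding=(0, 1))
    return F.conv2d(h, vert_later[None, None], padding=1)


def center_self_grad(stack_fn, size=11):
    """Gradient of the center output w.r.t. the center input pixel (0 = no leakage)."""
    x = torch.zeros(1, 1, size, size, requires_grad=True)
    stack_fn(x)[0, 0, size // 2, size // 2].backward()
    return x.grad[0, 0, size // 2, size // 2].abs().item()


ok = center_self_grad(vertical_on_vertical)
leak = center_self_grad(vertical_on_horizontal)
print(ok, leak)  # ok is 0.0, leak is positive: the pixel would see its own true value
```

The leak comes from the center row of the vertical kernel: its right-hand entry reads the horizontal feature of the next pixel, which was computed from the current pixel's true value.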

Using this, we can stack the convolutions in the following way. We have two feature streams: one for the vertical stack, and one for the horizontal stack. The horizontal convolutions can operate on the joint features of the previous horizontal and vertical convolutions, while the vertical stack only takes its own previous features as input. For a quick implementation, we can therefore sum the horizontal and vertical output features at each layer, and use those as final output features to calculate the loss on. An implementation of 4 consecutive layers is shown below. Note that we continue from the feature maps computed above by the convolutions with mask_center=True.

[12]:
# Initialize convolutions with equal weight to all input pixels
horiz_conv = HorizontalStackConvolution(c_in=1, c_out=1, kernel_size=3, mask_center=False)
horiz_conv.conv.weight.data.fill_(1)
horiz_conv.conv.bias.data.fill_(0)
vert_conv = VerticalStackConvolution(c_in=1, c_out=1, kernel_size=3, mask_center=False)
vert_conv.conv.weight.data.fill_(1)
vert_conv.conv.bias.data.fill_(0)

# We reuse our convolutions for the 4 layers here. Note that in a standard network,
# we don't do that, and instead learn 4 separate convolutions. As this cell is only for
# visualization purposes, we reuse the convolutions for all layers.
for l_idx in range(4):
    vert_img = vert_conv(vert_img)
    horiz_img = horiz_conv(horiz_img) + vert_img
    print("Layer %i" % (l_idx + 2))
    show_center_recep_field(inp_img, horiz_img)
Layer 2
_images/notebooks_course_UvA-DL_10-autoregressive-image-modeling_25_1.svg
Layer 3
_images/notebooks_course_UvA-DL_10-autoregressive-image-modeling_25_3.svg
Layer 4
_images/notebooks_course_UvA-DL_10-autoregressive-image-modeling_25_5.svg
Layer 5
_images/notebooks_course_UvA-DL_10-autoregressive-image-modeling_25_7.svg

The receptive field above is visualized for the horizontal stack, which includes the features of the vertical convolutions. It grows over layers without the blind spot we had before. The difference between the “weighted” and the “binary” receptive field is that the latter only checks whether any gradient flows back to a pixel, which indicates that the center pixel can indeed use information from it. However, due to the convolution weights, some pixels have a stronger effect on the prediction than others. The weighted receptive field visualizes this by plotting the gradient magnitude for each pixel instead of a binary yes/no.

Since the plots above show the receptive field of the horizontal stack, let’s also check the receptive field of the vertical stack:

[13]:
show_center_recep_field(inp_img, vert_img)
_images/notebooks_course_UvA-DL_10-autoregressive-image-modeling_27_0.svg

As discussed before, the vertical stack only looks at pixels above the one we want to predict. Hence, we can confirm that our implementation works as we initially expected. As a final step, let’s clean up the computation graph we still keep in memory for the visualization of the receptive field:

[14]:
del inp_img, horiz_img, vert_img, horiz_conv, vert_conv

Gated PixelCNN

In the next step, we will use the masked convolutions to build a full autoregressive model, called Gated PixelCNN. The difference between the original PixelCNN and Gated PixelCNN is the use of separate horizontal and vertical stacks. However, in literature, you often see that people refer to the Gated PixelCNN simply as “PixelCNN”. Hence, in the following, if we say “PixelCNN”, we usually mean the gated version. What “Gated” refers to in the model name is explained next.

Gated Convolutions

For visualizing the receptive field, we assumed a very simplified stack of vertical and horizontal convolutions. Obviously, there are more sophisticated ways of doing it, and PixelCNN uses gated convolutions for this. Specifically, the Gated Convolution block in PixelCNN looks as follows (figure credit - Aaron van den Oord et al.):

Figure: Gated Convolution block of PixelCNN (vertical stack on the left, horizontal stack on the right)

The left path is the vertical stack (the N\times N convolution is masked correspondingly), and the right path is the horizontal stack. Gated convolutions are implemented by doubling the number of output channels, splitting the result into two halves, and combining them by an element-wise multiplication of a \tanh and a sigmoid. For a linear layer, we can express a gated activation unit as follows:

\mathbf{y} = \tanh\left(\mathbf{W}_{f}\mathbf{x}\right)\odot\sigma\left(\mathbf{W}_{g}\mathbf{x}\right)

For simplicity, biases have been neglected and the linear layer has been split into two parts, \mathbf{W}_{f} and \mathbf{W}_{g}. This concept resembles the input and modulation gate in an LSTM, and has been used in many other architectures as well. The main motivation behind this gated activation is that it might allow the model to capture more complex interactions and simplify learning. But as with any other architecture, this is mostly a design choice and can be considered a hyperparameter.
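
In code, the gated activation amounts to splitting a feature map with twice the channels into two halves along the channel dimension, as `GatedMaskedConv` below does with `chunk`. A minimal sketch, where the helper name `gated_activation` and the tensor sizes are illustrative assumptions:

```python
import torch


def gated_activation(features: torch.Tensor) -> torch.Tensor:
    """tanh(W_f x) * sigmoid(W_g x), given the stacked pre-activations [W_f x; W_g x] on dim 1."""
    value, gate = features.chunk(2, dim=1)
    return torch.tanh(value) * torch.sigmoid(gate)


torch.manual_seed(0)
feats = torch.randn(8, 2 * 64, 28, 28)  # e.g. output of a convolution with 2 * c_in channels, c_in = 64
out = gated_activation(feats)
print(out.shape)  # torch.Size([8, 64, 28, 28])
```

Because \tanh is bounded in (-1, 1) and the sigmoid in (0, 1), the gate can only scale each \tanh feature down, never amplify it.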

Besides the gated convolutions, we also see that the horizontal stack uses a residual connection while the vertical stack does not. This is because we use the output of the horizontal stack for prediction. Each convolution in the vertical stack still receives a strong gradient signal, as it is only two 1\times 1 convolutions away from the residual connection, and does not require another residual connection to all its earlier layers.

The implementation of this block in PyTorch is fairly straightforward, because the visualization above gives us a computation graph to follow:

[15]:
class GatedMaskedConv(nn.Module):
    def __init__(self, c_in, **kwargs):
        """Gated Convolution block implementing the computation graph shown above."""
        super().__init__()
        self.conv_vert = VerticalStackConvolution(c_in, c_out=2 * c_in, **kwargs)
        self.conv_horiz = HorizontalStackConvolution(c_in, c_out=2 * c_in, **kwargs)
        self.conv_vert_to_horiz = nn.Conv2d(2 * c_in, 2 * c_in, kernel_size=1, padding=0)
        self.conv_horiz_1x1 = nn.Conv2d(c_in, c_in, kernel_size=1, padding=0)

    def forward(self, v_stack, h_stack):
        # Vertical stack (left)
        v_stack_feat = self.conv_vert(v_stack)
        v_val, v_gate = v_stack_feat.chunk(2, dim=1)
        v_stack_out = torch.tanh(v_val) * torch.sigmoid(v_gate)

        # Horizontal stack (right)
        h_stack_feat = self.conv_horiz(h_stack)
        h_stack_feat = h_stack_feat + self.conv_vert_to_horiz(v_stack_feat)
        h_val, h_gate = h_stack_feat.chunk(2, dim=1)
        h_stack_feat = torch.tanh(h_val) * torch.sigmoid(h_gate)
        h_stack_out = self.conv_horiz_1x1(h_stack_feat)
        h_stack_out = h_stack_out + h_stack

        return v_stack_out, h_stack_out

Building the model

Using the gated convolutions, we can now build our PixelCNN model. The architecture consists of multiple stacked GatedMaskedConv blocks, where we add a dilation factor to a few convolutions. This increases the receptive field of the model and allows it to take a larger context into account during generation. As a reminder, a dilated convolution works as follows (figure credit - Vincent Dumoulin and Francesco Visin):

Figure: animation of a dilated convolution

Note that the smaller output size is only because the animation assumes no padding. In our implementation, we pad the input image accordingly. As an alternative to dilated convolutions, we could downsample the input and use an encoder-decoder architecture as in PixelCNN++ [3]. This is especially beneficial if we want to build a very deep autoregressive model. Nonetheless, as we seek to train a reasonably small model, dilated convolutions are the more efficient option here.
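
The padding rule in `MaskedConvolution` is what keeps the image size constant despite dilation. The sketch below evaluates the output-size formula from the `nn.Conv2d` documentation for a 3x3 kernel on a 28x28 input; `conv_output_size` is a hypothetical helper, not part of the tutorial.

```python
def conv_output_size(size: int, kernel_size: int, dilation: int = 1, padding: int = 0, stride: int = 1) -> int:
    """Spatial output size of a convolution (formula from the nn.Conv2d documentation)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


kernel_size, size = 3, 28
for dilation in (1, 2, 4):
    span = dilation * (kernel_size - 1) + 1      # input pixels covered along one axis
    padding = dilation * (kernel_size - 1) // 2  # same padding rule as in MaskedConvolution
    print(
        f"dilation={dilation}: span={span}, "
        f"padded output={conv_output_size(size, kernel_size, dilation, padding)}, "
        f"unpadded output={conv_output_size(size, kernel_size, dilation)}"
    )
```

With padding, all three dilations keep the 28x28 size while the kernel spans 3, 5, and 9 pixels; without padding, the output shrinks (to 24 for dilation 2), which is the effect visible in the animation.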

Below, we implement the PixelCNN model as a PyTorch Lightning module. Besides the stack of gated convolutions, we also have the initial horizontal and vertical convolutions which mask the center pixel, and a final 1\times 1 convolution which maps the output features to class predictions. To determine the likelihood of a batch of images, we first create our initial features using the masked horizontal and vertical input convolution. Next, we forward the features through the stack of gated convolutions. Finally, we take the output features of the horizontal stack, and apply the 1\times 1 convolution for classification. We use the bits per dimension metric for the likelihood, similarly to Tutorial 11 and assignment 3.
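
Bits per dimension is the negative log-likelihood per pixel converted from nats to bits (a factor of \log_2 e), which is what `calc_likelihood` below computes. A useful sanity check, sketched here with random images and all-zero logits as stand-ins for real data and model outputs: a model that assigns equal probability to all 256 values needs exactly \log_2 256 = 8 bits per pixel, so any useful model should score well below 8.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Stand-in data: 4 single-channel 28x28 images with integer values in 0..255
imgs = torch.randint(0, 256, (4, 1, 28, 28))
# Stand-in model output shaped [Batch, Classes, Channels, Height, Width];
# all-zero logits = uniform distribution over the 256 pixel values
logits = torch.zeros(4, 256, 1, 28, 28)

nll = F.cross_entropy(logits, imgs, reduction="none")  # nats per pixel, shape [4, 1, 28, 28]
bpd = nll.mean(dim=[1, 2, 3]) * math.log2(math.e)       # bits per dimension for each image
print(bpd.mean().item())  # close to 8.0 = log2(256)
```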

[16]:
class PixelCNN(pl.LightningModule):
    def __init__(self, c_in, c_hidden):
        super().__init__()
        self.save_hyperparameters()

        # Initial convolutions skipping the center pixel
        self.conv_vstack = VerticalStackConvolution(c_in, c_hidden, mask_center=True)
        self.conv_hstack = HorizontalStackConvolution(c_in, c_hidden, mask_center=True)
        # Convolution block of PixelCNN. We use dilation instead of downscaling
        self.conv_layers = nn.ModuleList(
            [
                GatedMaskedConv(c_hidden),
                GatedMaskedConv(c_hidden, dilation=2),
                GatedMaskedConv(c_hidden),
                GatedMaskedConv(c_hidden, dilation=4),
                GatedMaskedConv(c_hidden),
                GatedMaskedConv(c_hidden, dilation=2),
                GatedMaskedConv(c_hidden),
            ]
        )
        # Output classification convolution (1x1)
        self.conv_out = nn.Conv2d(c_hidden, c_in * 256, kernel_size=1, padding=0)

        self.example_input_array = train_set[0][0][None]

    def forward(self, x):
        """Forward image through model and return logits for each pixel.

        Args:
            x: Image tensor with integer values between 0 and 255.
        """
        # Scale input from 0 to 255 back to -1 to 1
        x = (x.float() / 255.0) * 2 - 1

        # Initial convolutions
        v_stack = self.conv_vstack(x)
        h_stack = self.conv_hstack(x)
        # Gated Convolutions
        for layer in self.conv_layers:
            v_stack, h_stack = layer(v_stack, h_stack)
        # 1x1 classification convolution
        # Apply ELU before 1x1 convolution for non-linearity on residual connection
        out = self.conv_out(F.elu(h_stack))

        # Output dimensions: [Batch, Classes, Channels, Height, Width]
        out = out.reshape(out.shape[0], 256, out.shape[1] // 256, out.shape[2], out.shape[3])
        return out

    def calc_likelihood(self, x):
        # Forward pass with bpd likelihood calculation
        pred = self.forward(x)
        nll = F.cross_entropy(pred, x, reduction="none")
        bpd = nll.mean(dim=[1, 2, 3]) * np.log2(np.exp(1))
        return bpd.mean()

    @torch.no_grad()
    def sample(self, img_shape, img=None):
        """Sampling function for the autoregressive model.

        Args:
            img_shape: Shape of the image to generate (B,C,H,W)
            img (optional): If given, this tensor will be used as
                             a starting image. The pixels to fill
                             should be -1 in the input tensor.
        """
        # Create empty image (all pixels set to -1, i.e. still to be filled)
        if img is None:
            img = torch.full(img_shape, -1, dtype=torch.long, device=self.device)
        # Generation loop
        for h in tqdm(range(img_shape[2]), leave=False):
            for w in range(img_shape[3]):
                for c in range(img_shape[1]):
                    # Skip if not to be filled (-1)
                    if (img[:, c, h, w] != -1).all().item():
                        continue
                    # For efficiency, we only have to input the upper part of the image
                    # as all other parts will be skipped by the masked convolutions anyway
                    pred = self.forward(img[:, :, : h + 1, :])
                    probs = F.softmax(pred[:, :, c, h, w], dim=-1)
                    img[:, c, h, w] = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)
        return img

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.99)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        loss = self.calc_likelihood(batch[0])
        self.log("train_bpd", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.calc_likelihood(batch[0])
        self.log("val_bpd", loss)

    def test_step(self, batch, batch_idx):
        loss = self.calc_likelihood(batch[0])
        self.log("test_bpd", loss)

To sample from the autoregressive model, we need to iterate over all dimensions of the input. We start with an empty image and fill the pixels one by one, starting from the upper left corner. Note that when predicting x_i, the pixels below it have no influence on the prediction. Hence, we can crop the image in height without changing the prediction, which increases efficiency. Still, the nested loops in the sampling function already show that sampling will take quite some time. A lot of computation could be reused across loop iterations, as the features of already-predicted pixels do not change over iterations. However, this takes considerable effort to implement and is often not done, because in the end autoregressive sampling remains sequential and slow. Hence, we settle for the default implementation here.
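
The structure of that loop, and its cost, can be illustrated with a stripped-down version. In the sketch below, `sample_raster` and `uniform_model` are hypothetical stand-ins (the latter returns equal logits for all 256 values instead of running a PixelCNN); the loop mirrors the -1 convention and the cropping of rows below `h` used in `PixelCNN.sample`, and counts the forward passes.

```python
import torch
import torch.nn.functional as F


def sample_raster(model_fn, img):
    """Fill every -1 entry of img (B, C, H, W) in raster order, one forward pass per position.

    model_fn stands in for PixelCNN.forward: it maps a (B, C, h, W) long tensor
    to logits of shape (B, 256, C, h, W).
    """
    img = img.clone()
    _, channels, height, width = img.shape
    calls = 0
    for h in range(height):
        for w in range(width):
            for c in range(channels):
                if (img[:, c, h, w] != -1).all():
                    continue  # position already given, keep it
                logits = model_fn(img[:, :, : h + 1, :])  # rows below h cannot matter
                calls += 1
                probs = F.softmax(logits[:, :, c, h, w], dim=-1)
                img[:, c, h, w] = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return img, calls


def uniform_model(x):
    # Hypothetical placeholder model: equal logits for all 256 pixel values
    batch, channels, height, width = x.shape
    return torch.zeros(batch, 256, channels, height, width)


torch.manual_seed(0)
start = torch.full((2, 1, 4, 4), -1, dtype=torch.long)
start[:, :, :2, :] = 0  # pretend the upper half of each image is given
out, calls = sample_raster(uniform_model, start)
print(calls)  # 8: only the 2 x 4 missing positions need a forward pass
```

For a full 28x28 single-channel MNIST batch, the same loop needs 784 sequential forward passes, which is why sampling is slow.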

Before training the model, we can check the full receptive field of the model on an MNIST image of size 28\times 28:

[17]:
test_model = PixelCNN(c_in=1, c_hidden=64)
inp = torch.zeros(1, 1, 28, 28)
inp.requires_grad_()
out = test_model(inp)
show_center_recep_field(inp, out.squeeze(dim=2))
del inp, out, test_model
_images/notebooks_course_UvA-DL_10-autoregressive-image-modeling_35_0.svg

The visualization shows that for predicting the center pixel, we can take almost half of the image into account. However, keep in mind that this is the “theoretical” receptive field and not necessarily the effective receptive field, which is usually much smaller. For a stronger model, we should therefore try to increase the receptive field even further. In particular, the pixel in the bottom right, the very last pixel, would be allowed to take the whole image into account, yet our current receptive field only spans about 1/4 of the image. An encoder-decoder architecture can help with this, but it also shows that autoregressive models require much deeper, more complex networks than VAEs or energy-based models.
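
These numbers can be roughly checked by hand. The back-of-the-envelope sketch below assumes the layer configuration of the `PixelCNN` class above (3x3 kernels, dilations 1, 2, 1, 4, 1, 2, 1) and counts how far the theoretical receptive field reaches; it ignores that learned weights could cancel and says nothing about the effective receptive field.

```python
# Dilations of the seven GatedMaskedConv blocks in the PixelCNN class above (kernel_size=3)
dilations = [1, 2, 1, 4, 1, 2, 1]
image_size = 28

# The first (center-masked) convolutions reach one pixel: the vertical one sees the row above
# (columns -1..+1), the horizontal one the pixel to the left. Each later block with dilation d
# lets the vertical stack reach d more rows up and d more columns to either side, and the
# horizontal stack d more columns to the left.
reach_up = 1 + sum(dilations)
reach_side = 1 + sum(dilations)
print(reach_up, reach_side)  # 13 13

# Bottom-right pixel: bounding box of the pixels it can see (counting its own row and column)
box = (reach_up + 1) * (reach_side + 1)
print(box / image_size**2)  # 0.25
```

A reach of 13 pixels in a 28-pixel-wide image is consistent with the center pixel seeing almost the entire upper half, while the bottom-right pixel only sees a 14x14 corner, a quarter of the image.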

Training loop

To train the model, we can again rely on PyTorch Lightning, and we write a function below that loads the pretrained model if it exists. To reduce the computational cost, we have already saved the validation and test scores in the checkpoint:

[18]:
def train_model(**kwargs):
    # Create a PyTorch Lightning trainer with checkpointing and learning-rate monitoring callbacks
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "PixelCNN"),
        gpus=1 if str(device).startswith("cuda") else 0,
        max_epochs=150,
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode="min", monitor="val_bpd"),
            LearningRateMonitor("epoch"),
        ],
    )
    result = None
    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "PixelCNN.ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        model = PixelCNN.load_from_checkpoint(pretrained_filename)
        ckpt = torch.load(pretrained_filename, map_location=device)
        result = ckpt.get("result", None)
    else:
        model = PixelCNN(**kwargs)
        trainer.fit(model, train_loader, val_loader)
    model = model.to(device)

    if result is None:
        # Test best model on validation and test set
        val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
        test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
        result = {"test": test_result, "val": val_result}
    return model, result

Training the model is time-consuming, and we recommend using the provided pre-trained model when going through this notebook. However, feel free to play around with hyperparameters such as the number of layers if you want to get a feeling for them.

When calling the training function with a pre-trained model, we automatically load it and print its test performance:

[19]:
model, result = train_model(c_in=1, c_hidden=64)
test_res = result["test"][0]
print(
    "Test bits per dimension: %4.3fbpd" % (test_res["test_loss"] if "test_loss" in test_res else test_res["test_bpd"])
)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Found pretrained model, loading...
Test bits per dimension: 0.808bpd

With a test performance of 0.808bpd, the PixelCNN significantly outperforms the normalizing flows we have seen in Tutorial 11. Treating image modeling as an autoregressive problem simplifies learning, as predicting one pixel given the ground truth of all previous pixels is much easier than predicting all pixels at once. In addition, PixelCNN can explicitly predict the pixel values with a discrete softmax, while normalizing flows have to learn transformations in a continuous latent space. These two aspects allow the PixelCNN to achieve notably better performance.

To fully compare the models, let’s also measure the number of parameters of the PixelCNN:

[20]:
num_params = sum(np.prod(param.shape) for param in model.parameters())
print(f"Number of parameters: {num_params:,}")
Number of parameters: 852,160

Compared to the multi-scale normalizing flows, the PixelCNN has considerably fewer parameters. Of course, the number of parameters depends on our hyperparameter choices. Nevertheless, for the reasons stated above, autoregressive models generally require considerably fewer parameters than normalizing flows to reach good performance. However, autoregressive models are much slower to sample from than normalizing flows, which limits their possible applications.

Sampling

One way of qualitatively analysing generative models is by looking at the actual samples. Let’s therefore use our sampling function to generate a few digits:

[21]:
pl.seed_everything(1)
samples = model.sample(img_shape=(16, 1, 28, 28))
show_imgs(samples.cpu())
Global seed set to 1
[Figure: 16 sampled MNIST digits]

Most of the samples can be identified as digits, and overall the quality is better than what we obtained with normalizing flows. This is consistent with the better likelihood (lower bits per dimension) that we achieved with the autoregressive model. Nevertheless, there is still room for improvement, as a considerable number of samples cannot be identified (for example, in the first row). Deeper autoregressive models are expected to achieve better quality, as they can take more context into account when generating each pixel.

Note that on Google Colab, you might see different results, specifically a white line at the top. After some debugging, the difference seemed to occur inside the dilated convolution, which gives different results for different batch sizes. However, it is hard to debug this further, as it might be a bug in the PyTorch version installed on Google Colab.

The trained model itself is not restricted to any specific image size. However, what happens if we sample a larger image than those seen in our training dataset? Let’s try to sample images of size $64\times 64$ instead of $28\times 28$:

[22]:
pl.seed_everything(1)
samples = model.sample(img_shape=(8, 1, 64, 64))
show_imgs(samples.cpu())
Global seed set to 1
[Figure: eight sampled images of size 64x64]

The larger images show that changing the image size at test time confuses the model, which generates abstract figures (you can sometimes spot a digit in the upper left corner). In addition, sampling images of 64x64 pixels takes more than a minute on a GPU. Clearly, autoregressive models cannot be scaled to large images without changing the sampling procedure, for example with forecasting. Our implementation is also not the most efficient, as many computations could be stored and reused throughout the sampling process. Nevertheless, the sampling procedure remains sequential, which is inherently slower than the parallel generation used in normalizing flows.
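To make the cost argument concrete, here is a crude estimate under stated assumptions: a naive sampler runs one forward pass per pixel position, and (assumption) each pass processes the whole image. The helper `sampling_cost` is hypothetical; implementations that reuse computations or feed only part of the image per pass will be cheaper.

```python
def sampling_cost(channels: int, height: int, width: int):
    """Rough cost model for naive raster-order autoregressive sampling."""
    # One forward pass per (channel, row, column) position, executed one after another
    passes = channels * height * width
    # Assumption: every pass processes the full image, so work grows with passes x pixels
    pixel_work = passes * height * width
    return passes, pixel_work


small = sampling_cost(1, 28, 28)
large = sampling_cost(1, 64, 64)
print("sequential passes:", small[0], "->", large[0])
print("relative work: %.1fx" % (large[1] / small[1]))
```

Going from 28x28 to 64x64 increases the number of sequential passes by about 5x, and under the full-image assumption the total work grows by roughly 27x.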

Autocompletion

A common application of autoregressive models is auto-completing an image. As autoregressive models predict pixels one by one, we can set the first N pixels to predefined values and check how the model completes the image. To implement this, we only need to skip the iterations in the sampling loop for pixels that already have a value other than -1 (see the sample method in our PyTorch Lightning module above). In the cell below, we take three images from the training set (indices 1 to 3), mask all rows from row 10 downward, and let the model autocomplete them. To see the diversity of samples, we do this 12 times for each image:

[23]:
def autocomplete_image(img):
    # Remove lower half of the image
    img_init = img.clone()
    img_init[:, 10:, :] = -1
    print("Original image and input image to sampling:")
    show_imgs([img, img_init])
    # Generate 12 example completions
    # clone() materializes the expanded view so that in-place writes into it are safe
    img_init = img_init.unsqueeze(dim=0).expand(12, -1, -1, -1).clone().to(device)
    pl.seed_everything(1)
    img_generated = model.sample(img_init.shape, img_init)
    print("Autocompletion samples:")
    show_imgs(img_generated)


for i in range(1, 4):
    img = train_set[i][0]
    autocomplete_image(img)
Original image and input image to sampling:
[Figure: original image and masked input]
Global seed set to 1
Autocompletion samples:
[Figure: 12 autocompletion samples]
Original image and input image to sampling:
[Figure: original image and masked input]
Global seed set to 1
Autocompletion samples:
[Figure: 12 autocompletion samples]
Original image and input image to sampling:
[Figure: original image and masked input]
Global seed set to 1
Autocompletion samples:
[Figure: 12 autocompletion samples]

For the first two digits (7 and 6), all 12 samples result in a shape that resembles the original digit. Nevertheless, there are some style differences in how the 7 is written, and some of the sixes are deformed. When autocompleting the 9 below, the model fits multiple digits to it: we obtain diverse samples resembling a 0, 3, 8 and 9. This shows that despite having no latent space, we can still obtain diverse samples from an autoregressive model.
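The skipping logic used for autocompletion can be sketched in isolation. This is a simplified, hypothetical version, not the notebook's sample method: `logits_fn` stands in for the PixelCNN and is assumed to return logits of shape [B, 256, C, H, W], and here it is a toy function that predicts a uniform distribution.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def autocomplete(img, logits_fn):
    """Fill every pixel equal to -1 by sampling in raster order; keep all other pixels.

    img:       long tensor [B, C, H, W], known pixels in 0..255, missing pixels set to -1
    logits_fn: hypothetical stand-in for the model, mapping img to logits [B, 256, C, H, W]
    """
    img = img.clone()  # do not modify the caller's tensor
    B, C, H, W = img.shape
    for h in range(H):
        for w in range(W):
            for c in range(C):
                current = img[:, c, h, w]
                # Skip positions that are already filled for every image in the batch
                if (current != -1).all():
                    continue
                probs = F.softmax(logits_fn(img)[:, :, c, h, w], dim=-1)  # [B, 256]
                sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
                # Only overwrite entries that are actually missing
                img[:, c, h, w] = torch.where(current == -1, sampled, current)
    return img


def uniform_logits(x):
    # Toy "model": equal logits for all 256 values at every position
    return torch.zeros(x.shape[0], 256, *x.shape[1:])


torch.manual_seed(0)
img = torch.randint(0, 256, (2, 1, 6, 6))
img[:, :, 3:, :] = -1  # mask the lower half
completed = autocomplete(img, uniform_logits)
print(completed[0, 0])
```

Because known pixels are never resampled, the upper half of each image is returned unchanged. Overwriting only the entries that are -1 also handles batches where each image has a different mask.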

Visualization of the predictive distribution (softmax)

Our autoregressive model uses a softmax over 256 values to predict the next pixel. This gives the model great flexibility, as the probabilities for each pixel value can be learned independently if necessary. However, the values are not actually independent: the values 32 and 33 are much closer than 32 and 255. In the following, we visualize the softmax distribution that the model predicts to gain insight into how it has learned the relationships between close-by pixel values.

To do this, we first run the model on a batch of images and store the output softmax distributions:

[24]:
det_loader = data.DataLoader(train_set, batch_size=128, shuffle=False, drop_last=False)
imgs, _ = next(iter(det_loader))
imgs = imgs.to(device)
with torch.no_grad():
    out = model(imgs)
    out = F.softmax(out, dim=1)
    mean_out = out.mean(dim=[0, 2, 3, 4]).cpu().numpy()
    out = out.cpu().numpy()

Before diving into the model, let’s visualize the distribution of pixel values in this batch of 128 training images as an estimate of the dataset distribution:

[25]:
sns.set()
plot_args = {"color": to_rgb("C0") + (0.5,), "edgecolor": "C0", "linewidth": 0.5, "width": 1.0}
plt.hist(imgs.view(-1).cpu().numpy(), bins=256, density=True, **plot_args)
plt.yscale("log")
plt.xticks([0, 64, 128, 192, 256])
plt.show()
plt.close()
[Figure: histogram of pixel values (log-scale y-axis)]

As we would expect from the images we have seen, the pixel value 0 (black) is the dominant value, followed by a cluster of values between 250 and 255. Note that we use a log scale on the y-axis due to the large imbalance in the dataset. Interestingly, the pixel values 64, 128 and 191 also stand out, which is likely due to the quantization used when the dataset was created. For RGB images, we would also see two peaks around 0 and 255, but the values in between would be much more frequent than in MNIST (see Figure 1 in the PixelCNN++ paper [3] for a visualization on CIFAR10).

Next, we can visualize the distribution our model predicts (on average):

[26]:
plt.bar(np.arange(mean_out.shape[0]), mean_out, **plot_args)
plt.yscale("log")
plt.xticks([0, 64, 128, 192, 256])
plt.show()
plt.close()
[Figure: average predicted distribution over pixel values (log-scale y-axis)]

This distribution is very close to the actual dataset distribution, which is generally a good sign, although the predicted histogram is slightly smoother than the one above.

Finally, to take a closer look at the learned relations between values, we can visualize the distributions predicted for individual pixels. For this, we pick four images from the batch (indices 4 to 7) and visualize the predicted distribution for the pixel at position (14, 14) below:

[27]:
fig, ax = plt.subplots(2, 2, figsize=(10, 6))
for i in range(4):
    ax_sub = ax[i // 2][i % 2]
    ax_sub.bar(np.arange(out.shape[1], dtype=np.int32), out[i + 4, :, 0, 14, 14], **plot_args)
    ax_sub.set_yscale("log")
    ax_sub.set_xticks([0, 64, 128, 192, 256])
plt.show()
plt.close()
[Figure: predicted distributions for the pixel at (14, 14) in four images (log-scale y-axis)]

Overall, we see a very diverse set of distributions, with the usual peaks at 0 and close to 255. However, the distributions in the first row show potentially undesirable behavior. For instance, the value 242 has a 1000x lower likelihood than 243, although the two values are extremely close and often cannot be distinguished. This shows that the model might not have generalized well over pixel values. A better solution to this problem is to use a mixture of discrete logistics instead of a softmax distribution. A discrete logistic distribution can be imagined as a discretized, binned Gaussian. Using a mixture of discrete logistics instead of a softmax introduces an inductive bias that makes the model assign similar likelihoods to close-by values. We can visualize a discrete logistic below:

[28]:
mu = Tensor([128])
sigma = Tensor([2.0])


def discrete_logistic(x, mu, sigma):
    return torch.sigmoid((x + 0.5 - mu) / sigma) - torch.sigmoid((x - 0.5 - mu) / sigma)


x = torch.arange(256)
p = discrete_logistic(x, mu, sigma)

# Visualization
plt.figure(figsize=(6, 3))
plt.bar(x.numpy(), p.numpy(), **plot_args)
plt.xlim(96, 160)
plt.title("Discrete logistic distribution")
plt.xlabel("Pixel value")
plt.ylabel("Probability")
plt.show()
plt.close()
[Figure: discrete logistic distribution with mu=128 and sigma=2, pixel values 96 to 160]

Instead of the softmax, the model would output the mixture weights, means and scales for the K logistics used in the mixture. This is one of the improvements that PixelCNN++ [3] introduced compared to the original PixelCNN.
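A sketch of such a mixture likelihood, assuming pixel values from 0 to 255 and parameters given directly in pixel units (the function name and parameter layout are illustrative, not the PixelCNN++ code). As with the single discrete logistic above, each bin receives the logistic CDF mass between x - 0.5 and x + 0.5; in addition, the lowest and highest bins absorb the tails so that the probabilities sum to one:

```python
import torch


def discretized_logistic_mixture(logit_weights, mu, log_scale, num_values=256):
    """Probability of each integer pixel value under a mixture of K discretized logistics.

    logit_weights, mu, log_scale: tensors of shape [K] holding unnormalized mixture
    logits, means (in pixel units) and log-scales.
    """
    x = torch.arange(num_values, dtype=torch.float32).unsqueeze(-1)  # [V, 1]
    inv_s = torch.exp(-log_scale)                                     # [K]
    cdf_plus = torch.sigmoid((x + 0.5 - mu) * inv_s)                  # [V, K]
    cdf_minus = torch.sigmoid((x - 0.5 - mu) * inv_s)                 # [V, K]
    # Lowest value takes all mass below it, highest value takes all mass above it
    cdf_plus[-1] = 1.0
    cdf_minus[0] = 0.0
    probs_per_component = cdf_plus - cdf_minus                        # [V, K]
    weights = torch.softmax(logit_weights, dim=-1)                    # [K]
    return probs_per_component @ weights                              # [V]


p = discretized_logistic_mixture(
    logit_weights=torch.tensor([0.0, 1.0]),
    mu=torch.tensor([60.0, 190.0]),
    log_scale=torch.tensor([1.0, 2.0]),
)
print(p.sum())
```

Neighboring values share almost the same CDF differences, which is exactly the inductive bias described above. A practical implementation would work in log space for numerical stability; this version favors readability.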

Conclusion

In this tutorial, we have looked at autoregressive image modeling and implemented the PixelCNN architecture. Using masked convolutions, we are able to apply a convolutional network in which each pixel is only influenced by its predecessors. Separating the masked convolution into a horizontal and a vertical stack allowed us to remove the known blind spot above and to the right of a pixel. In our experiments, autoregressive models outperformed normalizing flows in terms of bits per dimension, but are much slower to sample from. Improvements that we have not implemented here include discrete logistic mixtures, a downsampling architecture, and changing the pixel order to a diagonal fashion (see PixelSNAIL). Overall, autoregressive models are another strong family of generative models. However, they are mostly used for sequence tasks, because there their sampling time scales linearly with the sequence length, whereas for images it scales quadratically with the image side length.

References

[1] van den Oord, A., et al. “Pixel Recurrent Neural Networks.” arXiv preprint arXiv:1601.06759 (2016).

[2] van den Oord, A., et al. “Conditional Image Generation with PixelCNN Decoders.” In Advances in Neural Information Processing Systems 29, pp. 4790–4798 (2016).

[3] Salimans, T., et al. “PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications.” arXiv preprint arXiv:1701.05517 (2017).

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time, you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Great thanks from the entire PyTorch Lightning Team for your interest!


Tutorial 11: Vision Transformers

  • Author: Phillip Lippe

  • License: CC BY-SA

  • Generated: 2022-05-03T02:43:19.157102

In this tutorial, we will take a closer look at a recent trend: Transformers for Computer Vision. Since Alexey Dosovitskiy et al. successfully applied a Transformer to a variety of image recognition benchmarks, there has been an incredible number of follow-up works showing that CNNs might no longer be the optimal architecture for Computer Vision. But how exactly do Vision Transformers work, and what benefits and drawbacks do they offer compared to CNNs? We will answer these questions by implementing a Vision Transformer ourselves and training it on the popular, small dataset CIFAR10. We will compare these results to popular convolutional architectures such as Inception, ResNet and DenseNet. This notebook is part of a lecture series on Deep Learning at the University of Amsterdam. The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.



Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "setuptools==59.5.0" "pytorch-lightning>=1.4" "matplotlib" "torch>=1.8" "ipython[notebook]" "torchmetrics>=0.7" "torchvision" "seaborn"

Let’s start with importing our standard set of libraries.

[2]:
import os
import urllib.request
from urllib.error import HTTPError

import matplotlib
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from IPython.display import set_matplotlib_formats
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torchvision import transforms
from torchvision.datasets import CIFAR10

plt.set_cmap("cividis")
%matplotlib inline
set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.reset_orig()

%load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/VisionTransformers/")

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
/tmp/ipykernel_4348/214601731.py:22: DeprecationWarning: `set_matplotlib_formats` is deprecated since IPython 7.23, directly use `matplotlib_inline.backend_inline.set_matplotlib_formats()`
  set_matplotlib_formats("svg", "pdf")  # For export
Global seed set to 42
Device: cuda:0

We provide a pre-trained Vision Transformer, which we download in the next cell. However, Vision Transformers can be trained relatively quickly on CIFAR10, with an overall training time of less than an hour on an NVIDIA TitanRTX. Feel free to experiment with training your own Transformer once you have gone through the whole notebook.

[3]:
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/"
# Files to download
pretrained_files = [
    "tutorial15/ViT.ckpt",
    "tutorial15/tensorboards/ViT/events.out.tfevents.ViT",
    "tutorial5/tensorboards/ResNet/events.out.tfevents.resnet",
]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name.split("/", 1)[1])
    if "/" in file_name.split("/", 1)[1]:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print("Downloading %s..." % file_url)
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n",
                e,
            )
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial15/ViT.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial15/tensorboards/ViT/events.out.tfevents.ViT...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial5/tensorboards/ResNet/events.out.tfevents.resnet...

We load the CIFAR10 dataset below. We use the same dataset setup and data augmentations as for the CNNs in Tutorial 5 to keep the comparison fair. The constants in transforms.Normalize are the per-channel means and standard deviations that shift and scale the data to zero mean and unit standard deviation.
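One way to obtain such normalization constants is to compute the per-channel mean and standard deviation over the training images. The sketch below uses random toy data in place of CIFAR10 (an assumption for self-containment) and checks that normalizing with these statistics gives zero mean and unit standard deviation per channel:

```python
import torch


def channel_mean_std(images: torch.Tensor):
    """Per-channel mean and standard deviation of a batch of images [N, C, H, W]."""
    mean = images.mean(dim=(0, 2, 3))
    std = images.std(dim=(0, 2, 3))
    return mean, std


# Toy data standing in for CIFAR10 after ToTensor() (values in [0, 1])
torch.manual_seed(0)
images = torch.rand(64, 3, 32, 32)
mean, std = channel_mean_std(images)

# Same operation as transforms.Normalize(mean, std), applied to the whole batch
normalized = (images - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
print(normalized.mean(dim=(0, 2, 3)), normalized.std(dim=(0, 2, 3)))
```

For the real dataset, the images could be taken from `train_dataset.data`, a uint8 array of shape [50000, 32, 32, 3], for example via `torch.from_numpy(train_dataset.data).permute(0, 3, 1, 2).float() / 255`.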

[4]:
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.49139968, 0.48215841, 0.44653091], [0.24703223, 0.24348513, 0.26158784]),
    ]
)
# For training, we add some augmentation. Networks are too powerful and would overfit.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize([0.49139968, 0.48215841, 0.44653091], [0.24703223, 0.24348513, 0.26158784]),
    ]
)
# Loading the training dataset. We need to split it into a training and validation part
# We need to do a little trick because the validation set should not use the augmentation.
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
pl.seed_everything(42)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000])
pl.seed_everything(42)
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000])

# Loading the test set
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = data.DataLoader(val_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)

# Visualize some examples
NUM_IMAGES = 4
CIFAR_images = torch.stack([val_set[idx][0] for idx in range(NUM_IMAGES)], dim=0)
img_grid = torchvision.utils.make_grid(CIFAR_images, nrow=4, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(8, 8))
plt.title("Image examples of the CIFAR10 dataset")
plt.imshow(img_grid)
plt.axis("off")
plt.show()
plt.close()
Files already downloaded and verified
Files already downloaded and verified
Global seed set to 42
Global seed set to 42
Files already downloaded and verified
[Figure: image examples of the CIFAR10 dataset]

Transformers for image classification

Transformers have originally been proposed to process sets, since they are a permutation-equivariant architecture, i.e., permuting the input permutes the output in the same way. To apply Transformers to sequences, we simply added a positional encoding to the input feature vectors, and the model learned by itself what to do with it. So, why not do the same thing on images? This is exactly what Alexey Dosovitskiy et al. proposed in their paper “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”. Specifically, the Vision Transformer is a model for image classification that views images as sequences of smaller patches. As a preprocessing step, we split an image of, for example, $48\times 48$ pixels into nine $16\times 16$ patches. Each of those patches is considered to be a “word”/“token” and is projected to a feature space. After adding positional encodings and a token for classification, we can apply a Transformer as usual to this sequence and start training it for our task. A nice GIF visualization of the architecture is shown below (figure credit: Phil Wang):

[Figure: animated visualization of the Vision Transformer architecture]
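The permutation-equivariance property mentioned above can be checked directly with nn.MultiheadAttention. The small sizes and the random positional encoding below are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, embed_dim, num_heads = 5, 2, 16, 4
attn = nn.MultiheadAttention(embed_dim, num_heads)  # default layout: [seq_len, batch, embed_dim]
attn.eval()

x = torch.randn(seq_len, batch, embed_dim)
perm = torch.tensor([2, 0, 4, 1, 3])  # a fixed, non-trivial permutation of the tokens

with torch.no_grad():
    # Without positional information: permuting the tokens permutes the outputs identically
    out = attn(x, x, x)[0]
    out_perm = attn(x[perm], x[perm], x[perm])[0]
    equivariant_without_pos = torch.allclose(out[perm], out_perm, atol=1e-5)

    # With a (here random) positional encoding that stays attached to the positions,
    # moving the content changes the result beyond a simple permutation
    pos = torch.randn(seq_len, 1, embed_dim)
    y, y_perm = x + pos, x[perm] + pos
    out_pos = attn(y, y, y)[0]
    out_pos_perm = attn(y_perm, y_perm, y_perm)[0]
    equivariant_with_pos = torch.allclose(out_pos[perm], out_pos_perm, atol=1e-5)

print(equivariant_without_pos, equivariant_with_pos)
```

The first check holds up to floating-point error, while the second fails: once positional encodings are added, the model can distinguish token positions, which is what turns the set into a sequence.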

We will walk step by step through the Vision Transformer and implement all parts ourselves. First, let’s implement the image preprocessing: an image of size $N\times N$ has to be split into $(N/M)^2$ patches of size $M\times M$. These represent the input words to the Transformer.

[5]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """
    Inputs:
        x - Tensor representing the image of shape [B, C, H, W]
        patch_size - Number of pixels per dimension of the patches (integer)
        flatten_channels - If True, the patches will be returned in a flattened format
                           as a feature vector instead of an image grid.
    """
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
    x = x.flatten(1, 2)  # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2, 4)  # [B, H'*W', C*p_H*p_W]
    return x
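To see that img_to_patch only rearranges pixels, we can invert it. The sketch repeats the patching function so that it runs on its own and adds a hypothetical `patch_to_img` helper that undoes the permutation for the flattened format:

```python
import torch


def img_to_patch(x, patch_size, flatten_channels=True):
    # Same patching as above: [B, C, H, W] -> [B, H'*W', C*p_H*p_W]
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5).flatten(1, 2)
    return x.flatten(2, 4) if flatten_channels else x


def patch_to_img(patches, patch_size, num_channels, height, width):
    """Hypothetical inverse of img_to_patch (flattened format): [B, T, C*p*p] -> [B, C, H, W]."""
    B = patches.shape[0]
    Hp, Wp = height // patch_size, width // patch_size
    x = patches.reshape(B, Hp, Wp, num_channels, patch_size, patch_size)
    x = x.permute(0, 3, 1, 4, 2, 5)  # back to [B, C, H', p_H, W', p_W]
    return x.reshape(B, num_channels, height, width)


x = torch.randn(2, 3, 32, 32)
patches = img_to_patch(x, patch_size=4)
restored = patch_to_img(patches, 4, 3, 32, 32)
print(patches.shape, torch.equal(restored, x))
```

Because both directions only use reshape and permute, the round trip is exact, and the first patch is the top-left $4\times 4$ block of each image.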

Let’s take a look at how that works for our CIFAR examples above. For our images of size $32\times 32$, we choose a patch size of 4. Hence, we obtain sequences of 64 patches of size $4\times 4$. We visualize them below:

[6]:
img_patches = img_to_patch(CIFAR_images, patch_size=4, flatten_channels=False)

fig, ax = plt.subplots(CIFAR_images.shape[0], 1, figsize=(14, 3))
fig.suptitle("Images as input sequences of patches")
for i in range(CIFAR_images.shape[0]):
    img_grid = torchvision.utils.make_grid(img_patches[i], nrow=64, normalize=True, pad_value=0.9)
    img_grid = img_grid.permute(1, 2, 0)
    ax[i].imshow(img_grid)
    ax[i].axis("off")
plt.show()
plt.close()
[Figure: images as input sequences of patches]

Compared to the original images, it is much harder to recognize the objects from these patch lists. Still, this is the input we provide to the Transformer for classifying the images. The model has to learn by itself how to combine the patches to recognize the objects. The inductive bias of CNNs, namely that an image is a grid of pixels, is lost in this input format.

After looking at the preprocessing, we can now start building the Transformer model. Since we discussed the fundamentals of Multi-Head Attention in Tutorial 6, we will use the PyTorch module nn.MultiheadAttention here. Further, we use the Pre-Layer Normalization version of the Transformer blocks proposed by Ruibin Xiong et al. in 2020. The idea is to apply Layer Normalization not in between residual blocks, but instead as the first layer inside each residual block. This reorganization of the layers supports better gradient flow and removes the necessity of a warm-up stage. A visualization of the difference between the standard Post-LN and the Pre-LN version is shown below.

[Figure: Post-LN vs. Pre-LN Transformer layer]

The implementation of the Pre-LN attention block looks as follows:

[7]:
class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of input and attention feature vectors
            hidden_dim - Dimensionality of hidden layer in feed-forward network
                         (usually 2-4x larger than embed_dim)
            num_heads - Number of heads to use in the Multi-Head Attention block
            dropout - Amount of dropout to apply in the feed-forward network
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x
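For comparison, a Post-LN version of the same block (the ordering used in the original Transformer) could look as follows. The class name `PostLNAttentionBlock` is hypothetical and is not used elsewhere in this notebook:

```python
import torch
import torch.nn as nn


class PostLNAttentionBlock(nn.Module):
    """Hypothetical Post-LN counterpart of AttentionBlock, for comparison only.

    Post-LN (original Transformer): x = LN(x + Sublayer(x))
    Pre-LN (AttentionBlock above):  x = x + Sublayer(LN(x))
    """

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )
        self.layer_norm_2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Normalization is applied after each residual addition
        x = self.layer_norm_1(x + self.attn(x, x, x)[0])
        x = self.layer_norm_2(x + self.linear(x))
        return x


torch.manual_seed(0)
block = PostLNAttentionBlock(embed_dim=32, hidden_dim=64, num_heads=4)
x = torch.randn(10, 2, 32)  # [seq_len, batch, embed_dim], the default layout of nn.MultiheadAttention
out = block(x)
print(out.shape)
```

In the Post-LN variant, every residual addition is followed by a LayerNorm, so the output of each block is normalized per token. In the Pre-LN block above, the residual path `x = x + ...` is never normalized inside the block, which gives gradients a direct identity path through all layers.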

Now we have all modules ready to build our own Vision Transformer. Besides the Transformer encoder, we need the following modules:

  • A linear projection layer that maps the input patches to feature vectors of a larger size. It is implemented by a simple linear layer that takes each $M\times M$ patch independently as input.

  • A classification token that is added to the input sequence. We will use the output feature vector of the classification token (CLS token for short) to make the classification prediction.

  • Learnable positional encodings that are added to the tokens before they are processed by the Transformer. These are needed to learn position-dependent information and to turn the set into a sequence. Since we usually work with a fixed resolution, we can learn the positional encodings instead of using a fixed pattern of sine and cosine functions.

  • An MLP head that takes the output feature vector of the CLS token and maps it to a classification prediction. This is usually implemented by a small feed-forward network or even a single linear layer.
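The non-Transformer components can be sketched in isolation to check the shapes. The sizes below (64 patches of 3*4*4 = 48 values, an embedding size of 128, 10 classes) are illustrative assumptions, and the Transformer layers themselves are omitted:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, patch_dim, embed_dim = 2, 64, 48, 128  # 64 patches of 3*4*4 values (CIFAR10, patch size 4)

patches = torch.randn(B, T, patch_dim)                           # output of the patching step
input_layer = nn.Linear(patch_dim, embed_dim)                    # linear projection per patch
cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))           # one learnable CLS token
pos_embedding = nn.Parameter(torch.randn(1, 1 + T, embed_dim))   # learnable, one per position

tokens = input_layer(patches)                                    # [B, T, embed_dim]
tokens = torch.cat([cls_token.expand(B, -1, -1), tokens], dim=1) # prepend CLS: [B, 1 + T, embed_dim]
tokens = tokens + pos_embedding                                  # broadcast over the batch
print(tokens.shape)

# (Transformer layers omitted) Only the CLS position, index 0, feeds the MLP head
mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 10))
logits = mlp_head(tokens[:, 0])
print(logits.shape)
```

The CLS token and the positional embeddings are ordinary parameters, so they are trained by the same optimizer as the rest of the network.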

With those components in mind, let’s implement the full Vision Transformer below:

[8]:
class VisionTransformer(nn.Module):
    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_heads,
        num_layers,
        num_classes,
        patch_size,
        num_patches,
        dropout=0.0,
    ):
        """
        Inputs:
            embed_dim - Dimensionality of the input feature vectors to the Transformer
            hidden_dim - Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            num_channels - Number of channels of the input (3 for RGB)
            num_heads - Number of heads to use in the Multi-Head Attention block
            num_layers - Number of layers to use in the Transformer
            num_classes - Number of classes to predict
            patch_size - Number of pixels that the patches have per dimension
            num_patches - Maximum number of patches an image can have
            dropout - Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels * (patch_size**2), embed_dim)
        self.transformer = nn.Sequential(
            *(AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers))
        )
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

    def forward(self, x):
        # Preprocess input
        x = img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)

        # Add CLS token and positional encoding
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:, : T + 1]

        # Apply Transformer
        x = self.dropout(x)
        x = x.transpose(0, 1)
        x = self.transformer(x)

        # Perform classification prediction
        cls = x[0]
        out = self.mlp_head(cls)
        return out

Finally, we can put everything into a PyTorch Lightning Module as usual. We use torch.optim.AdamW as the optimizer, which is Adam with a corrected weight decay implementation. Since we use the Pre-LN Transformer version, we do not need to use a learning rate warmup stage anymore. Instead, we use the same learning rate scheduler as the CNNs in our previous tutorial on image classification.

[9]:
class ViT(pl.LightningModule):
    def __init__(self, model_kwargs, lr):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)
        self.example_input_array = next(iter(train_loader))[0]

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode="train"):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        self.log("%s_loss" % mode, loss)
        self.log("%s_acc" % mode, acc)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")
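To see what the MultiStepLR schedule in configure_optimizers does, we can step it in isolation. The base learning rate of 3e-4 and the dummy parameter are illustrative assumptions; with Lightning's default epoch interval, the scheduler is stepped once per epoch, as in this loop:

```python
import torch
import torch.optim as optim

# Dummy parameter standing in for the model; lr=3e-4 is an illustrative value only
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.AdamW([param], lr=3e-4)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

lrs = []
for epoch in range(180):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used during this epoch
    optimizer.step()   # the training batches of the epoch would run here
    scheduler.step()   # epoch-level schedulers are stepped once per epoch

print(f"epoch 0: {lrs[0]:.0e}, epoch 100: {lrs[100]:.0e}, epoch 150: {lrs[150]:.0e}")
```

The learning rate stays constant for the first 100 epochs and is divided by 10 at epochs 100 and 150, so the final 30 of the 180 epochs use 1/100 of the initial value.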

Experiments

Commonly, Vision Transformers are applied to large-scale image classification benchmarks such as ImageNet to leverage their full potential. However, here we take a step back and ask: can Vision Transformers also succeed on classical, small benchmarks such as CIFAR10? To find out, we train a Vision Transformer from scratch on the CIFAR10 dataset. Let’s first create a training function for our PyTorch Lightning module, which also loads the pre-trained model if you have downloaded it above.

[10]:
def train_model(**kwargs):
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
        gpus=1 if str(device) == "cuda:0" else 0,
        max_epochs=180,
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
            LearningRateMonitor("epoch"),
        ],
        progress_bar_refresh_rate=1,
    )
    trainer.logger._log_graph = True  # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "ViT.ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model at %s, loading..." % pretrained_filename)
        # Automatically loads the model with the saved hyperparameters
        model = ViT.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)  # To be reproducible
        model = ViT(**kwargs)
        trainer.fit(model, train_loader, val_loader)
        # Load best checkpoint after training
        model = ViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Test best model on validation and test set
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}

    return model, result

Now, we can start training our model. As seen in our implementation, we have a couple of hyperparameters to choose. When creating this notebook, we performed a small grid search over hyperparameters and listed the best ones in the cell below. Nevertheless, it is worth discussing the influence of each hyperparameter and the intuition behind its value.

First, let’s consider the patch size. The smaller we make the patches, the longer the input sequences to the Transformer become. While this generally allows the Transformer to model more complex functions, it requires more computation time, since the memory cost of the attention layer grows quadratically with the sequence length. Furthermore, small patches can make the task more difficult since the Transformer has to learn which patches are close by and which are far away. We experimented with patch sizes of 2, 4 and 8, which give input sequence lengths of 256, 64, and 16 respectively. We found 4 to give the best performance, and hence pick it below.
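To make the sequence lengths above concrete, the short sketch below recomputes them for a 32\times 32 CIFAR10 image, together with the size of one attention map. It assumes square, non-overlapping patches and does not count any extra classification token a ViT implementation may prepend; `num_patches` is an illustrative helper, not part of the notebook.

```python
# Sketch: sequence length of a ViT input for different patch sizes.
# Assumes square, non-overlapping patches and ignores any extra [CLS] token.
IMG_SIZE = 32  # CIFAR10 images are 32x32 pixels


def num_patches(img_size: int, patch_size: int) -> int:
    """Number of non-overlapping patch_size x patch_size patches in a square image."""
    if img_size % patch_size != 0:
        raise ValueError("patch_size must divide img_size")
    return (img_size // patch_size) ** 2


for patch_size in (2, 4, 8):
    seq_len = num_patches(IMG_SIZE, patch_size)
    # Self-attention scores every patch against every other patch,
    # so each head builds a seq_len x seq_len attention map.
    print(f"patch_size={patch_size}: seq_len={seq_len}, attention entries per head={seq_len ** 2}")
```

Halving the patch size from 4 to 2 quadruples the sequence length (64 to 256) and multiplies the number of attention entries by 16 (4096 to 65536), which is why the smallest patches are by far the most expensive option.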

Next, the embedding and hidden dimensionality have a similar impact on a Transformer as on an MLP: the larger the sizes, the more complex the model becomes and the longer it takes to train. In Transformers, however, there is one more aspect to consider: the query-key sizes in the Multi-Head Attention layers. Each key has a feature dimensionality of embed_dim/num_heads. Considering that we have an input sequence length of 64, a minimum reasonable size for the key vectors is 16 or 32, since lower dimensionalities can restrict the possible attention maps too much. We observed that more than 8 heads are not necessary for the Transformer, and therefore pick an embedding dimensionality of 256 (i.e. 32 features per head). The hidden dimensionality in the feed-forward networks is usually 2-4x larger than the embedding dimensionality, and thus we pick 512.
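As a quick check of this reasoning, the sketch below computes the per-head key dimensionality embed_dim/num_heads for the chosen embed_dim=256 and a few head counts. `head_dim` is a hypothetical helper; its divisibility check mirrors the usual requirement that the embedding size splits evenly across heads.

```python
# Sketch: per-head key/query size for the embedding dimensionality chosen in this notebook.
def head_dim(embed_dim: int, num_heads: int) -> int:
    """Feature size of each head's keys/queries when embed_dim is split evenly across heads."""
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    return embed_dim // num_heads


EMBED_DIM, HIDDEN_DIM = 256, 512  # values used in the training cell of this notebook
for num_heads in (4, 8, 16):
    print(f"num_heads={num_heads}: key dim={head_dim(EMBED_DIM, num_heads)}")
print(f"feed-forward expansion factor: {HIDDEN_DIM // EMBED_DIM}x")
```

With 8 heads, each key has 32 features, inside the 16-32 range discussed above; 16 heads would already push it down to 16, and the hidden size of 512 corresponds to a 2x expansion.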

Finally, the learning rate for Transformers is usually relatively small, and in papers, a common value to use is 3e-5. However, since we work with a smaller dataset and have a potentially easier task, we found that we are able to increase the learning rate to 3e-4 without any problems. To reduce overfitting, we use a dropout value of 0.2. Remember that we also use small image augmentations as regularization during training.

Feel free to explore the hyperparameters yourself by changing the values below. In general, the Vision Transformer did not prove to be very sensitive to the hyperparameter choices on the CIFAR10 dataset.

[11]:
model, results = train_model(
    model_kwargs={
        "embed_dim": 256,
        "hidden_dim": 512,
        "num_heads": 8,
        "num_layers": 6,
        "patch_size": 4,
        "num_channels": 3,
        "num_patches": 64,
        "num_classes": 10,
        "dropout": 0.2,
    },
    lr=3e-4,
)
print("ViT results", results)
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:96: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=1)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Found pretrained model at saved_models/VisionTransformers/ViT.ckpt, loading...
Missing logger folder: saved_models/VisionTransformers/ViT/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/usr/local/lib/python3.8/dist-packages/torch/_tensor.py:575: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
ViT results {'test': 0.7559000253677368, 'val': 0.7563999891281128}

The Vision Transformer achieves a validation and test accuracy of about 75%. In comparison, almost all CNN architectures that we tested in Tutorial 5 obtained a classification accuracy of around 90%. This is a considerable gap and shows that, although Vision Transformers perform strongly on ImageNet with potential pre-training, they cannot come close to simple CNNs on CIFAR10 when trained from scratch. The differences between a CNN and a Transformer can be clearly observed in the training curves. Let’s look at them in TensorBoard below:

[12]:
# Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH!
%tensorboard --logdir ../saved_models/tutorial15/tensorboards/

The TensorBoard compares the Vision Transformer to a ResNet trained on CIFAR10. Looking at the training losses, we see that the ResNet learns much more quickly in the first iterations. While the learning rate might influence the initial learning speed, we see the same trend in the validation accuracy: the ResNet reaches the Vision Transformer’s best performance after just 5 epochs (2000 iterations). Furthermore, while the ResNet’s training loss and validation accuracy follow a similar trend, the validation performance of the Vision Transformer changes only marginally after 10k iterations, even though its training loss has only just started going down. Yet, the Vision Transformer is still able to reach close to 100% accuracy on the training set.

All these observed phenomena can be explained with a concept that we have visited before: inductive biases. Convolutional Neural Networks have been designed with the assumption that images are translation invariant. Hence, we apply convolutions with shared filters across the image. Furthermore, a CNN architecture integrates the concept of distance in an image: two pixels that are close to each other are more related than two distant pixels. Local patterns are combined into larger patterns until we make our classification prediction. All these aspects are inductive biases of a CNN. In contrast, a Vision Transformer does not know which two pixels are close to each other and which are far apart. It has to learn this information solely from the sparse learning signal of the classification task. This is a huge disadvantage when we have a small dataset, since such information is crucial for generalizing to an unseen test dataset. With large enough datasets and/or good pre-training, a Transformer can learn this information without the need for inductive biases, and is instead more flexible than a CNN. In particular, long-distance relations between local patterns can be difficult to process in CNNs, while in Transformers every patch is at a distance of one from every other patch, since attention connects all of them directly. This is why Vision Transformers are so strong on large-scale datasets such as ImageNet, but underperform considerably when applied to a small dataset such as CIFAR10.

Conclusion

In this tutorial, we have implemented our own Vision Transformer from scratch and applied it to the task of image classification. Vision Transformers work by splitting an image into a sequence of smaller patches and using those as input to a standard Transformer encoder. While Vision Transformers achieved outstanding results on large-scale image recognition benchmarks such as ImageNet, they considerably underperform when trained from scratch on small-scale datasets like CIFAR10. The reason is that, in contrast to CNNs, Transformers do not have the inductive biases of translation invariance and the feature hierarchy (i.e. larger patterns consist of many smaller patterns). However, these aspects can be learned when enough data is provided, or when the model has been pre-trained on other large-scale tasks. Considering that Vision Transformers were only proposed at the end of 2020, there is likely a lot more to come on Transformers for Computer Vision.

References

Dosovitskiy, Alexey, et al. “An image is worth 16x16 words: Transformers for image recognition at scale.” International Conference on Learning Representations (2021).

Chen, Xiangning, et al. “When Vision Transformers Outperform ResNets without Pretraining or Strong Data Augmentations.” arXiv preprint arXiv:2106.01548 (2021).

Tolstikhin, Ilya, et al. “MLP-Mixer: An all-MLP Architecture for Vision.” arXiv preprint arXiv:2105.01601 (2021).

Xiong, Ruibin, et al. “On layer normalization in the transformer architecture.” International Conference on Machine Learning. PMLR, 2020.

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Many thanks from the entire PyTorch Lightning team for your interest!

Tutorial 12: Meta-Learning - Learning to Learn

  • Author: Phillip Lippe

  • License: CC BY-SA

  • Generated: 2021-10-10T18:35:50.818431

In this tutorial, we will discuss algorithms that learn models which can quickly adapt to new classes and/or tasks with few samples. This area of machine learning is called Meta-Learning, which aims at “learning to learn”. Learning from very few examples is a natural task for humans. In contrast to current deep learning models, we need to see only a few examples of a police car or firetruck to recognize them in daily traffic. This is a crucial ability since, in real-world applications, the data rarely stays static over time. For example, an object detection system for mobile phones trained on data from 2000 will have trouble detecting today’s common mobile phones, and thus needs to adapt to new data without excessive labeling effort. The optimization techniques we have discussed so far struggle with this because they only aim at obtaining good performance on a test set with data similar to the training set. However, what if the test set has classes that we do not have in the training set? Or what if we want to test the model on a completely different task? We will discuss and implement three common Meta-Learning algorithms for such situations. This notebook is part of a lecture series on Deep Learning at the University of Amsterdam. The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.


Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
# ! pip install --quiet "torch>=1.6, <1.9" "matplotlib" "torchmetrics>=0.3" "seaborn" "torchvision" "pytorch-lightning>=1.3"

Meta-Learning offers solutions to these situations, and we will discuss three popular algorithms: Prototypical Networks (Snell et al., 2017), Model-Agnostic Meta-Learning / MAML (Finn et al., 2017), and Proto-MAML (Triantafillou et al., 2020). We will focus on the task of few-shot classification, where the training and test sets have distinct sets of classes. For instance, we would train the model on the binary classifications of cats-birds and flowers-bikes, but at test time, the model would need to learn the difference between dogs and otters, two classes we have not seen during training, from only 4 examples each (Figure credit - Lilian Weng).

A different setup, which is very common in Reinforcement Learning and, more recently, in Natural Language Processing, is to aim at few-shot learning of a completely new task. For example, a robot agent that learned to run, jump and pick up boxes should quickly adapt to collecting and stacking boxes. In NLP, we can think of a model that was trained on sentiment classification, hate speech detection and sarcasm classification, and should adapt to classifying the emotion of a text. All methods we will discuss in this notebook can be easily applied to these settings since we only use a different definition of a ‘task’. For few-shot classification, we consider a task to be distinguishing between M novel classes. Here, we would not only have novel classes, but also a completely different dataset.

First, let’s import our standard libraries. We will again be using PyTorch Lightning.

[2]:
import json
import os
import random
import urllib.request
from collections import defaultdict
from copy import deepcopy
from statistics import mean, stdev
from urllib.error import HTTPError

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from IPython.display import set_matplotlib_formats
from PIL import Image
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torchvision import transforms
from torchvision.datasets import CIFAR100, SVHN
from tqdm.auto import tqdm

plt.set_cmap("cividis")
# %matplotlib inline
set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.reset_orig()

# Import tensorboard
# %load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/MetaLearning/")

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
/tmp/ipykernel_739/3072189054.py:29: DeprecationWarning: `set_matplotlib_formats` is deprecated since IPython 7.23, directly use `matplotlib_inline.backend_inline.set_matplotlib_formats()`
  set_matplotlib_formats("svg", "pdf")  # For export
Global seed set to 42
Device: cuda:0
<Figure size 432x288 with 0 Axes>

Training the models in this notebook can take between 2 and 8 hours, and evaluating some of the algorithms takes a couple of minutes. Hence, we download pre-trained models and results below.

[3]:
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial16/"
# Files to download
pretrained_files = [
    "ProtoNet.ckpt",
    "ProtoMAML.ckpt",
    "tensorboards/ProtoNet/events.out.tfevents.ProtoNet",
    "tensorboards/ProtoMAML/events.out.tfevents.ProtoMAML",
    "protomaml_fewshot.json",
    "protomaml_svhn_fewshot.json",
]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if "/" in file_name:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print("Downloading %s..." % file_url)
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n",
                e,
            )
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial16/ProtoNet.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial16/ProtoMAML.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial16/tensorboards/ProtoNet/events.out.tfevents.ProtoNet...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial16/tensorboards/ProtoMAML/events.out.tfevents.ProtoMAML...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial16/protomaml_fewshot.json...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial16/protomaml_svhn_fewshot.json...

Few-shot classification

We start our implementation by discussing the dataset setup. In this notebook, we will use CIFAR100, which we have already seen in Tutorial 6. CIFAR100 has 100 classes, each with 600 images of size 32\times 32 pixels. Instead of splitting the training, validation and test set over examples, we will split them over classes: we will use 80 classes for training, 10 for validation and 10 for testing. Our overall goal is to obtain a model that can distinguish between the 10 test classes from seeing very few examples. First, let’s load the dataset and visualize some examples.

[4]:
# Loading CIFAR100 dataset
cifar_train_set = CIFAR100(root=DATASET_PATH, train=True, download=True, transform=transforms.ToTensor())
cifar_test_set = CIFAR100(root=DATASET_PATH, train=False, download=True, transform=transforms.ToTensor())
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /__w/1/s/.datasets/cifar-100-python.tar.gz
Extracting /__w/1/s/.datasets/cifar-100-python.tar.gz to /__w/1/s/.datasets
Files already downloaded and verified
[5]:
# Visualize some examples
NUM_IMAGES = 12
cifar_images = [cifar_train_set[np.random.randint(len(cifar_train_set))][0] for idx in range(NUM_IMAGES)]
cifar_images = torch.stack(cifar_images, dim=0)
img_grid = torchvision.utils.make_grid(cifar_images, nrow=6, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(8, 8))
plt.title("Image examples of the CIFAR100 dataset")
plt.imshow(img_grid)
plt.axis("off")
plt.show()
plt.close()
Data preprocessing

Next, we need to split the dataset into training, validation and test sets as mentioned before. The torchvision package gives us the training and test set as two separate dataset objects. The next code cells merge the original training and test set, and then create the new train-val-test split.

[6]:
# Merging original training and test set
cifar_all_images = np.concatenate([cifar_train_set.data, cifar_test_set.data], axis=0)
cifar_all_targets = torch.LongTensor(cifar_train_set.targets + cifar_test_set.targets)

To have an easier time handling the dataset, we define our own, simple dataset class below. It takes a set of images, labels/targets, and image transformations, and returns the corresponding images and labels element-wise.

[7]:
class ImageDataset(data.Dataset):
    def __init__(self, imgs, targets, img_transform=None):
        """
        Inputs:
            imgs - Numpy array of shape [N,32,32,3] containing all images.
            targets - PyTorch tensor of shape [N] containing all labels.
            img_transform - A torchvision transformation that should be applied
                            to the images before returning. If None, no transformation
                            is applied.
        """
        super().__init__()
        self.img_transform = img_transform
        self.imgs = imgs
        self.targets = targets

    def __getitem__(self, idx):
        img, target = self.imgs[idx], self.targets[idx]
        img = Image.fromarray(img)

        if self.img_transform is not None:
            img = self.img_transform(img)

        return img, target

    def __len__(self):
        return self.imgs.shape[0]

Now, we can create the class splits. We will assign the classes randomly to training, validation and test, and use an 80%-10%-10% split.

[8]:
pl.seed_everything(0)  # Set seed for reproducibility
classes = torch.randperm(100)  # Returns random permutation of numbers 0 to 99
train_classes, val_classes, test_classes = classes[:80], classes[80:90], classes[90:]
Global seed set to 0

To get an intuition of the validation and test classes, we print the class names below:

[9]:
# Printing validation and test classes
idx_to_class = {val: key for key, val in cifar_train_set.class_to_idx.items()}
print("Validation classes:", [idx_to_class[c.item()] for c in val_classes])
print("Test classes:", [idx_to_class[c.item()] for c in test_classes])
Validation classes: ['caterpillar', 'castle', 'skunk', 'ray', 'bus', 'motorcycle', 'keyboard', 'chimpanzee', 'possum', 'tiger']
Test classes: ['kangaroo', 'crocodile', 'butterfly', 'shark', 'forest', 'pickup_truck', 'telephone', 'lion', 'worm', 'mushroom']

As we can see, the classes are quite varied, and some classes might be easier to distinguish than others. For instance, among the test classes, ‘pickup_truck’ is the only vehicle, while the classes ‘mushroom’, ‘worm’ and ‘forest’ might be harder to keep apart. Remember that we want to learn to classify these ten classes using the 80 other classes in our training set and only a few examples from the actual test classes. We will experiment with the number of examples per class.

Finally, we can create the training, validation and test dataset according to our split above. For this, we create dataset objects of our previously defined class ImageDataset.

[10]:
def dataset_from_labels(imgs, targets, class_set, **kwargs):
    class_mask = (targets[:, None] == class_set[None, :]).any(dim=-1)
    return ImageDataset(imgs=imgs[class_mask], targets=targets[class_mask], **kwargs)
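dataset_from_labels relies on broadcasting to test each label against the whole class set at once. The toy example below, with made-up labels, shows the intermediate [N, C] comparison matrix and the resulting mask, using the same expression as the function above.

```python
import torch

# Toy labels (hypothetical) to illustrate the mask built in dataset_from_labels.
targets = torch.LongTensor([0, 3, 1, 3, 2, 0])
class_set = torch.LongTensor([0, 3])

# [N, 1] == [1, C] broadcasts to an [N, C] boolean matrix of "target equals class" checks
match = targets[:, None] == class_set[None, :]
# Reduce over classes: keep an example if its label matches any class in class_set
class_mask = match.any(dim=-1)

print(match.shape)                   # torch.Size([6, 2])
print(class_mask.tolist())           # [True, True, False, True, False, True]
print(targets[class_mask].tolist())  # [0, 3, 3, 0]
```

The same mask is then used to index both the image array and the label tensor, so images and labels stay aligned.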

As in our previous CIFAR experiments in Tutorials 5, 6 and 9, we normalize the dataset. Additionally, we use small augmentations during training to prevent overfitting.

[11]:
DATA_MEANS = (cifar_train_set.data / 255.0).mean(axis=(0, 1, 2))
DATA_STD = (cifar_train_set.data / 255.0).std(axis=(0, 1, 2))

test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(DATA_MEANS, DATA_STD)])
# For training, we add some augmentation.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(DATA_MEANS, DATA_STD),
    ]
)

train_set = dataset_from_labels(cifar_all_images, cifar_all_targets, train_classes, img_transform=train_transform)
val_set = dataset_from_labels(cifar_all_images, cifar_all_targets, val_classes, img_transform=test_transform)
test_set = dataset_from_labels(cifar_all_images, cifar_all_targets, test_classes, img_transform=test_transform)
Data sampling

The strategy of how to use the available training data for learning few-shot adaptation is crucial in meta-learning. All three algorithms that we discuss here share a similar idea: simulate few-shot learning during training. Specifically, at each training step, we randomly select a small number of classes and sample a small number of examples for each class. This represents our few-shot training batch, which we refer to as the support set. Additionally, we sample a second set of examples from the same classes, and refer to this batch as the query set. Our training objective is to classify the query set correctly from seeing the support set and its corresponding labels. The main difference between our three methods (ProtoNet, MAML, and Proto-MAML) is in how they use the support set to adapt to the training classes.

This subsection summarizes the code needed to create such training batches. In PyTorch, we can specify the data sampling procedure with a so-called Sampler (documentation). Samplers are iterable objects that return indices in the order in which the data elements should be sampled. In our previous notebooks, we usually used the option shuffle=True in the data.DataLoader objects, which creates a sampler returning the data indices in a random order. Here, we focus on samplers that return batches of indices that correspond to support and query set batches. Below, we implement such a sampler.

[12]:
class FewShotBatchSampler:
    def __init__(self, dataset_targets, N_way, K_shot, include_query=False, shuffle=True, shuffle_once=False):
        """
        Inputs:
            dataset_targets - PyTorch tensor of the labels of the data elements.
            N_way - Number of classes to sample per batch.
            K_shot - Number of examples to sample per class in the batch.
            include_query - If True, returns batch of size N_way*K_shot*2, which
                            can be split into support and query set. Simplifies
                            the implementation of sampling the same classes but
                            distinct examples for support and query set.
            shuffle - If True, examples and classes are newly shuffled in each
                      iteration (for training)
            shuffle_once - If True, examples and classes are shuffled once in
                           the beginning, but kept constant across iterations
                           (for validation)
        """
        super().__init__()
        self.dataset_targets = dataset_targets
        self.N_way = N_way
        self.K_shot = K_shot
        self.shuffle = shuffle
        self.include_query = include_query
        if self.include_query:
            self.K_shot *= 2
        self.batch_size = self.N_way * self.K_shot  # Number of overall images per batch

        # Organize examples by class
        self.classes = torch.unique(self.dataset_targets).tolist()
        self.num_classes = len(self.classes)
        self.indices_per_class = {}
        self.batches_per_class = {}  # Number of K-shot batches that each class can provide
        for c in self.classes:
            self.indices_per_class[c] = torch.where(self.dataset_targets == c)[0]
            self.batches_per_class[c] = self.indices_per_class[c].shape[0] // self.K_shot

        # Create a list of classes from which we select the N classes per batch
        self.iterations = sum(self.batches_per_class.values()) // self.N_way
        self.class_list = [c for c in self.classes for _ in range(self.batches_per_class[c])]
        if shuffle_once or self.shuffle:
            self.shuffle_data()
        else:
            # For testing, we iterate over classes instead of shuffling them
            sort_idxs = [
                i + p * self.num_classes for i, c in enumerate(self.classes) for p in range(self.batches_per_class[c])
            ]
            self.class_list = np.array(self.class_list)[np.argsort(sort_idxs)].tolist()

    def shuffle_data(self):
        # Shuffle the examples per class
        for c in self.classes:
            perm = torch.randperm(self.indices_per_class[c].shape[0])
            self.indices_per_class[c] = self.indices_per_class[c][perm]
        # Shuffle the class list from which we sample. Note that this way of shuffling
        # does not prevent choosing the same class twice in a batch. However, for
        # training and validation, this is not a problem.
        random.shuffle(self.class_list)

    def __iter__(self):
        # Shuffle data
        if self.shuffle:
            self.shuffle_data()

        # Sample few-shot batches
        start_index = defaultdict(int)
        for it in range(self.iterations):
            class_batch = self.class_list[it * self.N_way : (it + 1) * self.N_way]  # Select N classes for the batch
            index_batch = []
            for c in class_batch:  # For each class, select the next K examples and add them to the batch
                index_batch.extend(self.indices_per_class[c][start_index[c] : start_index[c] + self.K_shot])
                start_index[c] += self.K_shot
            if self.include_query:  # If we return support+query set, sort them so that they are easy to split
                index_batch = index_batch[::2] + index_batch[1::2]
            yield index_batch

    def __len__(self):
        return self.iterations

Now, we can create our intended data loaders by passing a FewShotBatchSampler object as the batch_sampler=... argument of the PyTorch data loader. For our experiments, we will use a 5-class 4-shot training setting. This means that each support set contains 5 classes with 4 examples each, i.e., 20 images overall. Usually, it is good to keep the number of shots equal to the number that you aim to test on. However, we will experiment later with different numbers of shots, and hence, we pick 4 as a compromise for now. To get the best performing model, it is recommended to treat the number of training shots as a hyperparameter in a grid search.
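The batch sizes and episode counts in this setting follow directly from the sampler’s bookkeeping. The sketch below reproduces that arithmetic in plain Python, assuming 600 images per CIFAR100 class after merging the original train and test sets; `episode_stats` is a hypothetical helper, not part of the notebook.

```python
# Sketch: batch sizes and episodes per epoch implied by FewShotBatchSampler's bookkeeping,
# assuming 600 images per CIFAR100 class (500 original train + 100 original test images, merged).
IMAGES_PER_CLASS = 600


def episode_stats(num_classes, n_way=5, k_shot=4, include_query=True):
    """Return (images per batch, batches per epoch) for a balanced class split."""
    k = 2 * k_shot if include_query else k_shot  # the sampler doubles K_shot when include_query=True
    batch_size = n_way * k
    batches_per_class = IMAGES_PER_CLASS // k  # same integer division as batches_per_class in the sampler
    iterations = num_classes * batches_per_class // n_way  # same formula as self.iterations
    return batch_size, iterations


for split, num_classes in [("train", 80), ("val", 10), ("test", 10)]:
    batch_size, iterations = episode_stats(num_classes)
    print(f"{split}: {batch_size} images per batch "
          f"({batch_size // 2} support + {batch_size // 2} query), {iterations} batches per epoch")
```

Under these assumptions, each batch holds 40 images (20 support, 20 query), and one training epoch over the 80 training classes consists of 1200 such episodes.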

[13]:
N_WAY = 5
K_SHOT = 4
train_data_loader = data.DataLoader(
    train_set,
    batch_sampler=FewShotBatchSampler(train_set.targets, include_query=True, N_way=N_WAY, K_shot=K_SHOT, shuffle=True),
    num_workers=4,
)
val_data_loader = data.DataLoader(
    val_set,
    batch_sampler=FewShotBatchSampler(
        val_set.targets, include_query=True, N_way=N_WAY, K_shot=K_SHOT, shuffle=False, shuffle_once=True
    ),
    num_workers=4,
)

For simplicity, we implemented the sampling of a support and query set as sampling a support set with twice the number of examples. After sampling a batch from the data loader, we need to split it into a support and query set. We can summarize this step in the following function:

[14]:
def split_batch(imgs, targets):
    support_imgs, query_imgs = imgs.chunk(2, dim=0)
    support_targets, query_targets = targets.chunk(2, dim=0)
    return support_imgs, query_imgs, support_targets, query_targets
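To see why chunking the batch in half yields matching support and query sets, the toy reproduction below, using hypothetical example ids instead of images, applies the same reordering as `FewShotBatchSampler.__iter__` (`index_batch[::2] + index_batch[1::2]`) and then the `split_batch` function from above, repeated here so the snippet runs on its own.

```python
import torch

# Toy reproduction (hypothetical ids) of how a support/query batch is laid out.
N_way, K_shot = 3, 2

# With include_query=True, the sampler takes 2*K_shot examples per class, one class after another.
# Here example id 10*c + i stands for the i-th sampled example of class c.
index_batch = []
for c in range(N_way):
    index_batch.extend(10 * c + i for i in range(2 * K_shot))
# index_batch == [0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 23]

# Same reordering as in FewShotBatchSampler.__iter__: even positions first, then odd positions.
index_batch = index_batch[::2] + index_batch[1::2]
# index_batch == [0, 2, 10, 12, 20, 22, 1, 3, 11, 13, 21, 23]

ids = torch.tensor(index_batch)
labels = torch.tensor([i // 10 for i in index_batch])  # class of each toy example


def split_batch(imgs, targets):
    # Same as the notebook's split_batch: first half is the support set, second half the query set
    support_imgs, query_imgs = imgs.chunk(2, dim=0)
    support_targets, query_targets = targets.chunk(2, dim=0)
    return support_imgs, query_imgs, support_targets, query_targets


support_ids, query_ids, support_labels, query_labels = split_batch(ids, labels)
print(support_labels.tolist(), query_labels.tolist())  # [0, 0, 1, 1, 2, 2] [0, 0, 1, 1, 2, 2]
print(set(support_ids.tolist()) & set(query_ids.tolist()))  # set()
```

This works because every class contributes an even number (2·K_shot) of consecutive examples, so exactly half of them land on even positions and half on odd positions.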

Finally, to ensure that our implementation of the data sampling process is correct, we can sample a batch and visualize its support and query set. What we would like to see is that the support and query set have the same classes, but distinct examples.

[15]:
imgs, targets = next(iter(val_data_loader))  # We use the validation set since it does not apply augmentations
support_imgs, query_imgs, _, _ = split_batch(imgs, targets)
support_grid = torchvision.utils.make_grid(support_imgs, nrow=K_SHOT, normalize=True, pad_value=0.9)
support_grid = support_grid.permute(1, 2, 0)
query_grid = torchvision.utils.make_grid(query_imgs, nrow=K_SHOT, normalize=True, pad_value=0.9)
query_grid = query_grid.permute(1, 2, 0)

fig, ax = plt.subplots(1, 2, figsize=(8, 5))
ax[0].imshow(support_grid)
ax[0].set_title("Support set")
ax[0].axis("off")
ax[1].imshow(query_grid)
ax[1].set_title("Query set")
ax[1].axis("off")
fig.suptitle("Few Shot Batch", weight="bold")
fig.show()
plt.close(fig)

As we can see, the support and query set have the same five classes, but different examples. The models will be tasked to classify the examples in the query set by learning from the support set and its labels. With the data sampling in place, we can now start to implement our first meta-learning model: Prototypical Networks.

Prototypical Networks

The Prototypical Network, or ProtoNet for short, is a metric-based meta-learning algorithm which operates similarly to nearest-neighbor classification. Metric-based meta-learning methods classify a new example \mathbf{x} based on some distance function d_{\varphi} between \mathbf{x} and all elements in the support set. ProtoNet implements this idea with the concept of prototypes in a learned feature space. First, ProtoNet uses an embedding function f_{\theta} to encode each input in the support set into an L-dimensional feature vector. Next, for each class c, we collect the feature vectors of all examples with label c and average them. Formally, we can define this as:

\mathbf{v}_c=\frac{1}{|S_c|}\sum_{(\mathbf{x}_i,y_i)\in S_c}f_{\theta}(\mathbf{x}_i)

where S_c is the part of the support set S for which y_i=c, and \mathbf{v}_c represents the prototype of class c. The prototype calculation is visualized below for a 2-dimensional feature space and 3 classes (Figure credit - Snell et al.). The colored dots represent encoded support elements with color-corresponding class label, and the black dots next to the class label are the averaged prototypes.

(Figure: prototype calculation in a 2-dimensional feature space with 3 classes; credit: Snell et al.)
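The averaging in the prototype equation can also be computed without a Python loop over classes. A minimal sketch, assuming features of shape [N, L] and integer labels of shape [N]; prototypes_vectorized is a hypothetical name, and the one-hot matrix product is an alternative to (not a copy of) the per-class loop in ProtoNet.calculate_prototypes:

```python
import torch
import torch.nn.functional as F


def prototypes_vectorized(features, targets):
    # features: [N, L] embeddings f_theta(x_i); targets: [N] integer class labels
    classes, inverse = torch.unique(targets, sorted=True, return_inverse=True)
    # one_hot[i, c] = 1 if example i belongs to the c-th entry of `classes`
    one_hot = F.one_hot(inverse, num_classes=len(classes)).to(features.dtype)  # [N, C]
    counts = one_hot.sum(dim=0)  # |S_c| for every class
    # v_c = (1 / |S_c|) * sum over i in S_c of f_theta(x_i)
    prototypes = (one_hot.T @ features) / counts[:, None]  # [C, L]
    return prototypes, classes


feats = torch.tensor([[0.0, 0.0], [2.0, 0.0], [0.0, 4.0], [0.0, 6.0]])
labels = torch.tensor([5, 5, 1, 1])
protos, classes = prototypes_vectorized(feats, labels)
print(classes.tolist(), protos.tolist())  # → [1, 5] [[0.0, 5.0], [1.0, 0.0]]
```

As in calculate_prototypes, the returned classes tensor is sorted, so row c of the prototypes belongs to classes[c].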

Based on these prototypes, we want to classify a new example. Remember that since we want to learn the encoding function f_{\theta}, this classification must be differentiable and hence, we need to define a probability distribution across classes. For this, we will make use of the distance function d_{\varphi}: the closer a new example \mathbf{x} is to a prototype \mathbf{v}_c, the higher the probability of \mathbf{x} belonging to class c. Formally, we can simply use a softmax over the negative distances of \mathbf{x} to all class prototypes:

p(y=c\vert\mathbf{x})=\text{softmax}(-d_{\varphi}(f_{\theta}(\mathbf{x}), \mathbf{v}_c))=\frac{\exp\left(-d_{\varphi}(f_{\theta}(\mathbf{x}), \mathbf{v}_c)\right)}{\sum_{c'\in \mathcal{C}}\exp\left(-d_{\varphi}(f_{\theta}(\mathbf{x}), \mathbf{v}_{c'})\right)}

Note that the negative sign is necessary since we want close-by vectors to receive a high probability and distant vectors a low probability. We train the network f_{\theta} with the cross-entropy loss on the query set examples of the training tasks. Thereby, the gradient flows through both the prototypes \mathbf{v}_c and the query set encodings f_{\theta}(\mathbf{x}). For the distance function d_{\varphi}, we can choose any function as long as it is differentiable with respect to both of its inputs. The most common choice, which we also use here, is the squared Euclidean distance, but several works have explored other distance functions as well.
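The classification rule above fits in a few lines. A minimal sketch with made-up 2-dimensional prototypes and queries (the function name proto_log_probs is hypothetical), using the squared Euclidean distance as d_{\varphi}:

```python
import torch
import torch.nn.functional as F


def proto_log_probs(query_feats, prototypes):
    # Squared Euclidean distance between every query embedding and every prototype: [Q, C]
    dist = ((query_feats[:, None, :] - prototypes[None, :, :]) ** 2).sum(dim=-1)
    # log p(y = c | x) = log_softmax over classes of the negative distances
    return F.log_softmax(-dist, dim=1)


prototypes = torch.tensor([[0.0, 0.0], [3.0, 0.0]])  # v_0 and v_1
queries = torch.tensor([[0.5, 0.0], [2.0, 1.0]])     # f_theta(x) for two query examples
labels = torch.tensor([0, 1])                        # their labels as prototype indices

log_p = proto_log_probs(queries, prototypes)
print(log_p.exp())          # each row sums to 1
print(log_p.argmax(dim=1))  # → tensor([0, 1]): each query goes to its nearest prototype

loss = F.nll_loss(log_p, labels)  # cross-entropy on the query set
```

Because log_softmax is idempotent, applying F.cross_entropy to these log-probabilities, as ProtoNet.calculate_loss does further below, yields the same value as F.nll_loss.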

ProtoNet implementation

Now that we know how a ProtoNet works in principle, let’s look at how we can apply it to our specific problem of few-shot image classification, and implement it below. First, we need to define the encoder function f_{\theta}. Since we work with CIFAR images, we can look back at Tutorial 5, where we compared common computer vision architectures, and choose one of the best-performing ones. Here, we go with a DenseNet since it is generally more parameter-efficient than ResNet. Luckily, we do not need to implement DenseNet ourselves again and can rely on torchvision’s model package instead. We use common hyperparameters of 64 initial feature channels, a growth rate of 32 (i.e. each layer adds 32 channels), and a bottleneck size of 64 (i.e. 2 times the growth rate). We use 4 stages of 6 layers each, which results in about 1 million parameters overall. Note that the torchvision package assumes that the last layer is used for classification and hence calls its output size num_classes. However, we can instead use this output as the feature space of ProtoNet and choose an arbitrary dimensionality. We will use the same network for the other algorithms in this notebook to ensure a fair comparison.

[16]:
def get_convnet(output_size):
    convnet = torchvision.models.DenseNet(
        growth_rate=32,
        block_config=(6, 6, 6, 6),
        bn_size=2,
        num_init_features=64,
        num_classes=output_size,  # Output dimensionality
    )
    return convnet

Next, we can implement ProtoNet. We define it as a PyTorch Lightning module to use all of Lightning’s functionality. The first step during training is to encode all images in a batch with our network. Next, we calculate the class prototypes from the support set (function calculate_prototypes), and classify the query set examples according to the prototypes (function classify_feats). Keep in mind that, with the data sampling described before, the support and query set are stacked together in the batch. Thus, we use our previously defined function split_batch to split them apart. The full code can be found below.

[17]:
class ProtoNet(pl.LightningModule):
    def __init__(self, proto_dim, lr):
        """Inputs.

        proto_dim - Dimensionality of prototype feature space
        lr - Learning rate of Adam optimizer
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = get_convnet(output_size=self.hparams.proto_dim)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[140, 180], gamma=0.1)
        return [optimizer], [scheduler]

    @staticmethod
    def calculate_prototypes(features, targets):
        # Given a stack of features vectors and labels, return class prototypes
        # features - shape [N, proto_dim], targets - shape [N]
        classes, _ = torch.unique(targets).sort()  # Determine which classes we have
        prototypes = []
        for c in classes:
            p = features[torch.where(targets == c)[0]].mean(dim=0)  # Average class feature vectors
            prototypes.append(p)
        prototypes = torch.stack(prototypes, dim=0)
        # Return the 'classes' tensor to know which prototype belongs to which class
        return prototypes, classes

    def classify_feats(self, prototypes, classes, feats, targets):
        # Classify new examples with prototypes and return classification error
        dist = torch.pow(prototypes[None, :] - feats[:, None], 2).sum(dim=2)  # Squared euclidean distance
        preds = F.log_softmax(-dist, dim=1)
        labels = (classes[None, :] == targets[:, None]).long().argmax(dim=-1)
        acc = (preds.argmax(dim=1) == labels).float().mean()
        return preds, labels, acc

    def calculate_loss(self, batch, mode):
        # Determine training loss for a given support and query set
        imgs, targets = batch
        features = self.model(imgs)  # Encode all images of support and query set
        support_feats, query_feats, support_targets, query_targets = split_batch(features, targets)
        prototypes, classes = ProtoNet.calculate_prototypes(support_feats, support_targets)
        preds, labels, acc = self.classify_feats(prototypes, classes, query_feats, query_targets)
        loss = F.cross_entropy(preds, labels)

        self.log("%s_loss" % mode, loss)
        self.log("%s_acc" % mode, acc)
        return loss

    def training_step(self, batch, batch_idx):
        return self.calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self.calculate_loss(batch, mode="val")
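classify_feats (and the ProtoMAML code later) converts dataset labels into prototype indices with (classes[None, :] == targets[:, None]).long().argmax(dim=-1). A small sketch with made-up labels showing what that expression computes, plus torch.searchsorted as an equivalent alternative for the sorted classes tensor that calculate_prototypes returns:

```python
import torch

classes = torch.tensor([3, 17, 42])          # sorted class ids, as returned by calculate_prototypes
targets = torch.tensor([42, 3, 17, 17, 42])  # original dataset labels of the query set

# [Q, C] matrix with a single 1 per row at the matching class position;
# argmax over the class dimension returns that position, i.e. a label in 0..C-1
labels = (classes[None, :] == targets[:, None]).long().argmax(dim=-1)
print(labels.tolist())  # → [2, 0, 1, 1, 2]

# Equivalent for a sorted `classes` tensor that contains every target
labels_alt = torch.searchsorted(classes, targets)
```

One caveat of both forms: a target that does not appear in classes is silently mapped to some index (0 for the argmax form), so they rely on every query class also occurring in the support set, which the sampler ensures.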

For validation, we use the same principle as in training and sample support and query sets from the 10 held-out classes. However, this gives us noisy scores that depend on which query sets are paired with which support sets. This is why we will use a different strategy during testing. For validation, our training strategy is sufficient since it is much faster than the testing procedure, and it gives a good estimate of generalization as long as we keep the support-query sets constant across validation iterations.

Training

After implementing the model, we can already start training it. We use our common PyTorch Lightning training function, and train the model for 200 epochs. The training function takes model_class as input argument, i.e. the PyTorch Lightning module class that should be trained, since we will reuse this function for other algorithms as well.

[18]:
def train_model(model_class, train_loader, val_loader, **kwargs):
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, model_class.__name__),
        gpus=1 if str(device) == "cuda:0" else 0,
        max_epochs=200,
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
            LearningRateMonitor("epoch"),
        ],
        progress_bar_refresh_rate=0,
    )
    trainer.logger._default_hp_metric = None

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, model_class.__name__ + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model at %s, loading..." % pretrained_filename)
        # Automatically loads the model with the saved hyperparameters
        model = model_class.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)  # To be reproducible
        model = model_class(**kwargs)
        trainer.fit(model, train_loader, val_loader)
        model = model_class.load_from_checkpoint(
            trainer.checkpoint_callback.best_model_path
        )  # Load best checkpoint after training

    return model

Below is the training call for our ProtoNet. We use a 64-dimensional feature space. Larger feature spaces gave noisier results, since the squared Euclidean distance becomes proportionally larger in expectation, and smaller feature spaces might not allow for enough flexibility. We recommend loading the pre-trained model first, but feel free to play around with the hyperparameters yourself.
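The scaling argument can be made concrete: for two independent vectors with i.i.d. standard normal entries in dimension L, \mathbb{E}\|\mathbf{a}-\mathbf{b}\|^2 = 2L, so the distances that enter the softmax grow linearly with the feature dimensionality. A quick simulation, assuming Gaussian toy embeddings (learned features are not Gaussian, so this only illustrates the trend):

```python
import torch

torch.manual_seed(0)
mean_sq_dist = {}
for dim in [16, 64, 256]:
    a = torch.randn(10_000, dim)  # toy "embeddings" with i.i.d. N(0, 1) entries
    b = torch.randn(10_000, dim)
    mean_sq_dist[dim] = ((a - b) ** 2).sum(dim=1).mean().item()
    # For i.i.d. N(0, 1) entries, E[||a - b||^2] = 2 * dim
    print(f"dim={dim:3d}  mean squared distance={mean_sq_dist[dim]:7.1f}  expected={2 * dim}")
```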

[19]:
protonet_model = train_model(
    ProtoNet, proto_dim=64, lr=2e-4, train_loader=train_data_loader, val_loader=val_data_loader
)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Found pretrained model at saved_models/MetaLearning/ProtoNet.ckpt, loading...

We can also take a closer look at the TensorBoard below.

[20]:
# Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH if needed
# # %tensorboard --logdir ../saved_models/tutorial16/tensorboards/ProtoNet/

(Figure: TensorBoard screenshot of the ProtoNet training run)

In contrast to standard supervised learning, we see that ProtoNet does not overfit as much as we would expect. The validation accuracy is of course lower than the average training accuracy, but the training loss does not stick close to zero. This is because no two training batches are the same, and we also mix new examples into the support and query sets. This gives us slightly different prototypes in every iteration and makes it harder for the network to fully overfit.

Testing

The goal of meta-learning is to obtain a model that can quickly adapt to a new task, or in this case, to new classes to distinguish between. To test this, we use our trained ProtoNet and adapt it to the 10 test classes. We pick k examples per class from which we determine the prototypes, and test the classification accuracy on all other examples. This can be seen as using the k examples per class as the support set, and the rest of the dataset as the query set. We iterate through the dataset such that each example is included in a support set exactly once. The average performance over all support sets tells us how well we can expect ProtoNet to perform when seeing only k examples per class. During training, we used k=4. In testing, we experiment with k\in\{2,4,8,16,32\} to get a better sense of how k influences the results. We would expect to achieve higher accuracies the more examples we have in the support set, but we don’t know how it scales. Hence, let’s first implement a function that executes the testing procedure for a given k:

[21]:
@torch.no_grad()
def test_proto_net(model, dataset, data_feats=None, k_shot=4):
    """Inputs.

    model - Pretrained ProtoNet model
    dataset - The dataset on which the test should be performed.
              Should be instance of ImageDataset
    data_feats - The encoded features of all images in the dataset.
                 If None, they will be newly calculated, and returned
                 for later usage.
    k_shot - Number of examples per class in the support set.
    """
    model = model.to(device)
    model.eval()
    num_classes = dataset.targets.unique().shape[0]
    exmps_per_class = dataset.targets.shape[0] // num_classes  # We assume uniform example distribution here

    # The encoder network remains unchanged across k-shot settings. Hence, we only need
    # to extract the features for all images once.
    if data_feats is None:
        # Dataset preparation
        dataloader = data.DataLoader(dataset, batch_size=128, num_workers=4, shuffle=False, drop_last=False)

        img_features = []
        img_targets = []
        for imgs, targets in tqdm(dataloader, "Extracting image features", leave=False):
            imgs = imgs.to(device)
            feats = model.model(imgs)
            img_features.append(feats.detach().cpu())
            img_targets.append(targets)
        img_features = torch.cat(img_features, dim=0)
        img_targets = torch.cat(img_targets, dim=0)
        # Sort by classes, so that we obtain tensors of shape [num_classes, exmps_per_class, ...]
        # Makes it easier to process later
        img_targets, sort_idx = img_targets.sort()
        img_targets = img_targets.reshape(num_classes, exmps_per_class).transpose(0, 1)
        img_features = img_features[sort_idx].reshape(num_classes, exmps_per_class, -1).transpose(0, 1)
    else:
        img_features, img_targets = data_feats

    # We iterate through the full dataset in two manners. First, to select the k-shot batch.
    # Second, to evaluate the model on all other examples
    accuracies = []
    for k_idx in tqdm(range(0, img_features.shape[0], k_shot), "Evaluating prototype classification", leave=False):
        # Select support set and calculate prototypes
        k_img_feats = img_features[k_idx : k_idx + k_shot].flatten(0, 1)
        k_targets = img_targets[k_idx : k_idx + k_shot].flatten(0, 1)
        prototypes, proto_classes = model.calculate_prototypes(k_img_feats, k_targets)
        # Evaluate accuracy on the rest of the dataset
        batch_acc = 0
        for e_idx in range(0, img_features.shape[0], k_shot):
            if k_idx == e_idx:  # Do not evaluate on the support set examples
                continue
            e_img_feats = img_features[e_idx : e_idx + k_shot].flatten(0, 1)
            e_targets = img_targets[e_idx : e_idx + k_shot].flatten(0, 1)
            _, _, acc = model.classify_feats(prototypes, proto_classes, e_img_feats, e_targets)
            batch_acc += acc.item()
        batch_acc /= img_features.shape[0] // k_shot - 1
        accuracies.append(batch_acc)

    return (mean(accuracies), stdev(accuracies)), (img_features, img_targets)
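test_proto_net sorts all features by label and reshapes them to [exmps_per_class, num_classes, ...], so that the slice img_features[k_idx : k_idx + k_shot] always contains exactly k_shot examples of every class. A toy sketch of this layout with 3 classes and 4 examples per class (labels are made up, and the 1-dimensional "feature" is just each example's original position):

```python
import torch

num_classes, exmps_per_class = 3, 4
targets = torch.tensor([2, 0, 1, 1, 2, 0, 0, 2, 1, 2, 1, 0])  # shuffled labels
features = torch.arange(len(targets)).float()[:, None]        # [N, 1] toy features

sorted_targets, sort_idx = targets.sort()
# [num_classes, exmps_per_class] -> transpose -> [exmps_per_class, num_classes]
grid_targets = sorted_targets.reshape(num_classes, exmps_per_class).transpose(0, 1)
grid_feats = features[sort_idx].reshape(num_classes, exmps_per_class, -1).transpose(0, 1)

k_shot = 2
support_targets = grid_targets[0:k_shot].flatten(0, 1)
print(support_targets.tolist())  # → [0, 1, 2, 0, 1, 2]
```

The reshape assumes every class has exactly exmps_per_class examples, as the comment in test_proto_net states. If exmps_per_class is not a multiple of k_shot, the last slice in the evaluation loop holds fewer than k_shot examples per class.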

Testing ProtoNet is relatively quick once all images have been encoded. Hence, we can run it directly in this notebook:

[22]:
protonet_accuracies = dict()
data_feats = None
for k in [2, 4, 8, 16, 32]:
    protonet_accuracies[k], data_feats = test_proto_net(protonet_model, test_set, data_feats=data_feats, k_shot=k)
    print(
        "Accuracy for k=%i: %4.2f%% (+-%4.2f%%)"
        % (k, 100.0 * protonet_accuracies[k][0], 100 * protonet_accuracies[k][1])
    )
Accuracy for k=2: 44.30% (+-3.63%)
Accuracy for k=4: 52.07% (+-2.27%)
Accuracy for k=8: 57.59% (+-1.30%)
Accuracy for k=16: 62.56% (+-1.02%)
Accuracy for k=32: 66.49% (+-0.87%)

Before discussing the results above, let’s first plot the accuracies over number of examples in the support set:

[23]:
def plot_few_shot(acc_dict, name, color=None, ax=None):
    sns.set()
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(5, 3))
    ks = sorted(list(acc_dict.keys()))
    mean_accs = [acc_dict[k][0] for k in ks]
    std_accs = [acc_dict[k][1] for k in ks]
    ax.plot(ks, mean_accs, marker="o", markeredgecolor="k", markersize=6, label=name, color=color)
    ax.fill_between(
        ks,
        [m - s for m, s in zip(mean_accs, std_accs)],
        [m + s for m, s in zip(mean_accs, std_accs)],
        alpha=0.2,
        color=color,
    )
    ax.set_xticks(ks)
    ax.set_xlim([ks[0] - 1, ks[-1] + 1])
    ax.set_xlabel("Number of shots per class", weight="bold")
    ax.set_ylabel("Accuracy", weight="bold")
    if len(ax.get_title()) == 0:
        ax.set_title("Few-Shot Performance " + name, weight="bold")
    else:
        ax.set_title(ax.get_title() + " and " + name, weight="bold")
    ax.legend()
    return ax
[24]:
ax = plot_few_shot(protonet_accuracies, name="ProtoNet", color="C1")
plt.show()
plt.close()
(Figure: few-shot accuracy of ProtoNet over the number of shots per class)

As we initially expected, the performance of ProtoNet indeed increases the more samples we have. However, even with just two samples per class, we classify almost half of the images correctly, which is well above random accuracy (10%). The curve shows diminishing returns: adding 2 extra examples to k=2 has a much higher impact than adding 2 extra examples when we already have k=16. Nonetheless, we can say that ProtoNet adapts fairly well to new classes.
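To put numbers on the diminishing returns, the mean accuracies printed above can be differenced per doubling of k (the values are copied from the output; nothing is re-run):

```python
# Mean test accuracies (%) printed above for k = 2, 4, 8, 16, 32
ks = [2, 4, 8, 16, 32]
accs = [44.30, 52.07, 57.59, 62.56, 66.49]

# Gain in percentage points from each doubling of k
gains = [round(hi - lo, 2) for lo, hi in zip(accs, accs[1:])]
for k_prev, k, gain in zip(ks, ks[1:], gains):
    print(f"k={k_prev:2d} -> k={k:2d}: +{gain:.2f} pp")
print(gains)  # → [7.77, 5.52, 4.97, 3.93]
```

Each doubling of k adds fewer percentage points than the previous one, falling from about 7.8 to about 3.9.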

MAML and ProtoMAML

The second meta-learning algorithm we will look at is MAML, short for Model-Agnostic Meta-Learning. MAML is an optimization-based meta-learning algorithm, which means that it tries to adjust the standard optimization procedure to a few-shot setting. The idea of MAML is relatively simple: given a model and a support and query set during training, we optimize the model for m steps on the support set, and evaluate the gradients of the query loss with respect to the original model’s parameters. For the same model, we do this for a few different support-query sets and accumulate the gradients. This results in a model that provides a good initialization for quick adaptation to the training tasks. If we denote the model parameters with \theta, we can visualize the procedure as follows (Figure credit - Finn et al.).

(Figure: the MAML procedure visualized in parameter space; credit: Finn et al.)

The full algorithm of MAML is therefore as follows. At each training step, we sample a batch of tasks, i.e., a batch of support-query set pairs. For each task \mathcal{T}_i, we optimize a model f_{\theta} on the support set via SGD, and denote this model as f_{\theta_i'}. We refer to this optimization as the inner loop. Using this new model, we calculate the gradients of the query loss of f_{\theta_i'} with respect to the original parameters \theta. These gradients are accumulated over all tasks and used to update \theta. This is called the outer loop since we iterate over tasks. The full MAML algorithm is summarized below (Figure credit - Finn et al.).

(Figure: Algorithm 2, MAML for few-shot supervised learning; credit: Finn et al.)
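A minimal sketch of one MAML meta-training step, assuming a toy family of 1-D linear regression tasks y = a·x and a single scalar weight instead of a CNN (all names are hypothetical). torch.autograd.grad(..., create_graph=True) keeps the inner-loop gradient differentiable, so the outer gradient includes the second-order terms; a central finite difference of the meta-objective serves as a sanity check:

```python
import torch

torch.manual_seed(0)
alpha, beta = 0.1, 0.05  # inner- and outer-loop learning rates (toy values)


def mse(w, x, y):
    # Toy model f_w(x) = w * x with a single scalar parameter w
    return ((w * x - y) ** 2).mean()


# A batch of toy tasks y = a * x, each with its own support and query set
tasks = []
for a in [0.5, 1.0, 2.0]:
    xs = torch.randn(5, dtype=torch.float64)
    xq = torch.randn(5, dtype=torch.float64)
    tasks.append((xs, a * xs, xq, a * xq))


def meta_objective(w, create_graph):
    # Sum over tasks of the query loss after one inner-loop SGD step on the support set
    total = 0.0
    for xs, ys, xq, yq in tasks:
        (g,) = torch.autograd.grad(mse(w, xs, ys), w, create_graph=create_graph)
        w_adapted = w - alpha * g  # theta_i' = theta - alpha * grad of the support loss
        total = total + mse(w_adapted, xq, yq)
    return total


w = torch.tensor(0.0, dtype=torch.float64, requires_grad=True)
# create_graph=True keeps the inner gradient differentiable -> second-order meta-gradient
(meta_grad,) = torch.autograd.grad(meta_objective(w, create_graph=True), w)
with torch.no_grad():
    w_new = w - beta * meta_grad  # outer-loop update of the initialization


# Sanity check: central finite difference of the meta-objective around w = 0
def objective_value(w0):
    w_probe = torch.tensor(w0, dtype=torch.float64, requires_grad=True)
    return meta_objective(w_probe, create_graph=False).item()


eps = 1e-6
fd_grad = (objective_value(eps) - objective_value(-eps)) / (2 * eps)
print(meta_grad.item(), fd_grad)  # the two values agree closely
```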

To obtain gradients for the initial parameters \theta from the optimized model f_{\theta_i'}, we actually need second-order gradients, i.e. gradients of gradients, as the support set gradients depend on \theta as well. This makes MAML computationally expensive, especially when using multiple inner loop steps. A simpler, yet almost equally well-performing alternative is First-Order MAML (FOMAML), which only uses first-order gradients. This means that the second-order gradients are ignored, and we can calculate the outer loop gradients (line 10 in Algorithm 2) simply by calculating the gradients with respect to \theta_i' and using those as the update for \theta. Hence, the new update rule becomes:

\theta\leftarrow\theta-\beta\sum_{\mathcal{T}_i\sim p(\mathcal{T})}\nabla_{\theta_i'}\mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})

Note that the gradient \nabla is now taken with respect to \theta_i' instead of \theta.
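The difference between the two update rules is easiest to see on a single scalar parameter. A minimal sketch, assuming quadratic toy losses L_s(\theta)=\frac{a}{2}(\theta-t_s)^2 for the support set and L_q(\theta)=\frac{b}{2}(\theta-t_q)^2 for the query set with one inner SGD step: the chain rule gives the MAML gradient b(\theta'-t_q)(1-\alpha a), while FOMAML drops the factor (1-\alpha a), which is exactly the second-order term:

```python
import torch

a, t_s = 2.0, 1.0   # support loss L_s(theta) = 0.5 * a * (theta - t_s)^2
b, t_q = 3.0, -1.0  # query loss   L_q(theta) = 0.5 * b * (theta - t_q)^2
alpha = 0.1         # inner-loop learning rate


def support_loss(theta):
    return 0.5 * a * (theta - t_s) ** 2


def query_loss(theta):
    return 0.5 * b * (theta - t_q) ** 2


def outer_gradient(theta0, first_order):
    theta = torch.tensor(theta0, requires_grad=True)
    (g,) = torch.autograd.grad(support_loss(theta), theta, create_graph=not first_order)
    if first_order:
        g = g.detach()  # FOMAML: treat the inner gradient as a constant
    theta_adapted = theta - alpha * g
    (outer,) = torch.autograd.grad(query_loss(theta_adapted), theta)
    return outer.item(), theta_adapted.item()


maml_grad, theta_adapted = outer_gradient(0.0, first_order=False)
fomaml_grad, _ = outer_gradient(0.0, first_order=True)
# Analytically at theta = 0: theta' = 0.2, so
# MAML:   b * (theta' - t_q) * (1 - alpha * a) = 3 * 1.2 * 0.8 = 2.88
# FOMAML: b * (theta' - t_q)                   = 3 * 1.2       = 3.6
print(maml_grad, fomaml_grad)
```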

ProtoMAML

A problem of MAML is how to design the output classification layer. If tasks have different numbers of classes, we need to initialize the output layer with zeros or randomly in every iteration. Even if we always have the same number of classes, we still start from random predictions. This requires several inner loop steps to reach a reasonable classification result. To overcome this problem, Triantafillou et al. (2020) propose to combine the merits of Prototypical Networks and MAML. Specifically, we can use prototypes to give the output layer a strong initialization. It can be shown that the softmax over squared Euclidean distances can be reformulated as a linear layer with softmax. To see this, let’s first write out the negative squared Euclidean distance between a feature vector f_{\theta}(\mathbf{x}^{*}) of a new data point \mathbf{x}^{*} and a prototype \mathbf{v}_c of class c:

-||f_{\theta}(\mathbf{x}^{*})-\mathbf{v}_c||^2=-f_{\theta}(\mathbf{x}^{*})^Tf_{\theta}(\mathbf{x}^{*})+2\mathbf{v}_c^{T}f_{\theta}(\mathbf{x}^{*})-\mathbf{v}_c^T\mathbf{v}_c

We perform the classification across all classes c\in\mathcal{C} and take a softmax over the negative distances. Hence, any term that is the same for all classes can be removed without changing the output probabilities. In the equation above, this is true for -f_{\theta}(\mathbf{x}^{*})^Tf_{\theta}(\mathbf{x}^{*}) since it is independent of any class prototype. Thus, we can write:

-||f_{\theta}(\mathbf{x}^{*})-\mathbf{v}_c||^2=2\mathbf{v}_c^{T}f_{\theta}(\mathbf{x}^{*})-||\mathbf{v}_c||^2+\text{constant}

Taking a second look at the equation above, it looks a lot like a linear layer. For this, we use \mathbf{W}_{c,\cdot}=2\mathbf{v}_c and b_c=-||\mathbf{v}_c||^2 which gives us the linear layer \mathbf{W}f_{\theta}(\mathbf{x}^{*})+\mathbf{b}. Hence, if we initialize the output weight with twice the prototypes, and the biases by the negative squared L2 norm of the prototypes, we start with a Prototypical Network. MAML allows us to adapt this layer and the rest of the network further.
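The equivalence can be checked numerically. A minimal sketch with random prototypes and query features (made-up data, in float64 to keep rounding noise out of the comparison), using F.linear as ProtoMAML.run_model does below:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
prototypes = torch.randn(5, 64, dtype=torch.float64)  # v_c for N=5 classes, 64-dim features
feats = torch.randn(8, 64, dtype=torch.float64)       # f_theta(x*) for 8 query examples

# ProtoNet: softmax over negative squared Euclidean distances
dist = ((feats[:, None, :] - prototypes[None, :, :]) ** 2).sum(dim=-1)
p_proto = F.softmax(-dist, dim=1)

# Linear layer with W_c = 2 v_c and b_c = -||v_c||^2
weight = 2 * prototypes
bias = -(prototypes ** 2).sum(dim=1)
p_linear = F.softmax(F.linear(feats, weight, bias), dim=1)

# The two sets of logits differ only by -||f(x*)||^2 per example, which the softmax cancels
print(torch.allclose(p_proto, p_linear))  # → True
```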

In the following, we will implement First-Order ProtoMAML for few-shot classification. The implementation of MAML would be the same except for the output layer initialization.

ProtoMAML implementation

For implementing ProtoMAML, we can follow Algorithm 2 with minor modifications. At each training step, we first sample a batch of tasks, and a support and query set for each task. In our case of few-shot classification, this means that we simply sample multiple support-query set pairs from our sampler. For each task, we finetune our current model on the support set. However, since we need to remember the original parameters for the other tasks, the outer loop gradient update and future training steps, we need to create a copy of our model, and finetune only the copy. We can copy a model by using standard Python functions like deepcopy. The inner loop is implemented in the function adapt_few_shot in the PyTorch Lightning module below.

After finetuning the model, we apply it to the query set and calculate the first-order gradients with respect to the original parameters \theta. In contrast to plain MAML, we also have to consider the gradients with respect to the output layer initialization, i.e. the prototypes, since they directly rely on \theta. To realize this efficiently, we take two steps. First, we calculate the prototypes by applying the original model, i.e. not the copied model, on the support elements. When initializing the output layer, we detach the prototypes to stop the gradients. This is because in the inner loop itself, we do not want to consider gradients through the prototypes back to the original model. However, after the inner loop is finished, we re-attach the computation graph of the prototypes by writing output_weight = (output_weight - init_weight).detach() + init_weight. While this line does not change the value of the variable output_weight, it adds its dependency on the prototype initialization init_weight. Thus, if we call .backward() on a loss computed from output_weight, we automatically calculate the first-order gradients with respect to the prototype initialization in the original model.
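The re-attachment trick can be seen in isolation. A minimal sketch with made-up 2×2 prototypes, where a fixed offset stands in for the result of the inner-loop SGD updates: the forward value equals the adapted weights, while the backward pass sends the gradient taken at the adapted weights straight to init_weight and from there to the prototypes:

```python
import torch

prototypes = torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True)
init_weight = 2 * prototypes  # output-layer initialization, differentiable w.r.t. prototypes

# Stand-in for the inner loop: detached weights moved by some fixed offset
adapted = init_weight.detach() + torch.tensor([[0.1, -0.2], [0.3, 0.0]])

# Re-attach: forward value = adapted weights, gradient path = identity into init_weight
output_weight = (adapted - init_weight).detach() + init_weight
print(torch.allclose(output_weight.detach(), adapted))  # → True

# Toy "query loss" on the re-attached weights
loss = (output_weight ** 2).sum()
loss.backward()
# dloss/doutput_weight = 2 * adapted, passed 1:1 to init_weight, times d(init)/d(prototypes) = 2
print(prototypes.grad)  # equals 4 * adapted
```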

After calculating all gradients and summing them in the original model, we can take a standard optimizer step. However, PyTorch Lightning’s training_step is designed to return a loss tensor on which Lightning calls .backward() itself. Since this is not possible here, we perform the optimization step ourselves. All details can be found in the code below.

For implementing (Proto-)MAML with second-order gradients, it is recommended to use libraries such as ∇higher (https://github.com/facebookresearch/higher) from Facebook AI Research. For simplicity, we stick with first-order methods here.

[25]:
class ProtoMAML(pl.LightningModule):
    def __init__(self, proto_dim, lr, lr_inner, lr_output, num_inner_steps):
        """Inputs.

        proto_dim - Dimensionality of prototype feature space
        lr - Learning rate of the outer loop Adam optimizer
        lr_inner - Learning rate of the inner loop SGD optimizer
        lr_output - Learning rate for the output layer in the inner loop
        num_inner_steps - Number of inner loop updates to perform
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = get_convnet(output_size=self.hparams.proto_dim)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[140, 180], gamma=0.1)
        return [optimizer], [scheduler]

    def run_model(self, local_model, output_weight, output_bias, imgs, labels):
        # Execute a model with given output layer weights and inputs
        feats = local_model(imgs)
        preds = F.linear(feats, output_weight, output_bias)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=1) == labels).float()
        return loss, preds, acc

    def adapt_few_shot(self, support_imgs, support_targets):
        # Determine prototype initialization
        support_feats = self.model(support_imgs)
        prototypes, classes = ProtoNet.calculate_prototypes(support_feats, support_targets)
        support_labels = (classes[None, :] == support_targets[:, None]).long().argmax(dim=-1)
        # Create inner-loop model and optimizer
        local_model = deepcopy(self.model)
        local_model.train()
        local_optim = optim.SGD(local_model.parameters(), lr=self.hparams.lr_inner)
        local_optim.zero_grad()
        # Create output layer weights with prototype-based initialization
        init_weight = 2 * prototypes
        init_bias = -torch.norm(prototypes, dim=1) ** 2
        output_weight = init_weight.detach().requires_grad_()
        output_bias = init_bias.detach().requires_grad_()

        # Optimize inner loop model on support set
        for _ in range(self.hparams.num_inner_steps):
            # Determine loss on the support set
            loss, _, _ = self.run_model(local_model, output_weight, output_bias, support_imgs, support_labels)
            # Calculate gradients and perform inner loop update
            loss.backward()
            local_optim.step()
            # Update output layer via SGD
            output_weight.data -= self.hparams.lr_output * output_weight.grad
            output_bias.data -= self.hparams.lr_output * output_bias.grad
            # Reset gradients
            local_optim.zero_grad()
            output_weight.grad.fill_(0)
            output_bias.grad.fill_(0)

        # Re-attach computation graph of prototypes
        output_weight = (output_weight - init_weight).detach() + init_weight
        output_bias = (output_bias - init_bias).detach() + init_bias

        return local_model, output_weight, output_bias, classes

    def outer_loop(self, batch, mode="train"):
        accuracies = []
        losses = []
        self.model.zero_grad()

        # Determine gradients for batch of tasks
        for task_batch in batch:
            imgs, targets = task_batch
            support_imgs, query_imgs, support_targets, query_targets = split_batch(imgs, targets)
            # Perform inner loop adaptation
            local_model, output_weight, output_bias, classes = self.adapt_few_shot(support_imgs, support_targets)
            # Determine loss of query set
            query_labels = (classes[None, :] == query_targets[:, None]).long().argmax(dim=-1)
            loss, preds, acc = self.run_model(local_model, output_weight, output_bias, query_imgs, query_labels)
            # Calculate gradients for query set loss
            if mode == "train":
                loss.backward()

                for p_global, p_local in zip(self.model.parameters(), local_model.parameters()):
                    p_global.grad += p_local.grad  # First-order approx. -> add gradients of finetuned and base model

            accuracies.append(acc.mean().detach())
            losses.append(loss.detach())

        # Perform update of base model
        if mode == "train":
            opt = self.optimizers()
            opt.step()
            opt.zero_grad()

        self.log("%s_loss" % mode, sum(losses) / len(losses))
        self.log("%s_acc" % mode, sum(accuracies) / len(accuracies))

    def training_step(self, batch, batch_idx):
        self.outer_loop(batch, mode="train")
        return None  # Returning None means we skip the default training optimizer steps by PyTorch Lightning

    def validation_step(self, batch, batch_idx):
        # Validation requires to finetune a model, hence we need to enable gradients
        torch.set_grad_enabled(True)
        self.outer_loop(batch, mode="val")
        torch.set_grad_enabled(False)
Training

To train ProtoMAML, we need to change our sampling slightly. Instead of a single support-query set batch, we need to sample multiple. To implement this, we use yet another sampler, which combines multiple batches from a FewShotBatchSampler and returns them together. Additionally, we define a collate_fn for our data loader which takes the stack of support-query set images and returns the tasks as a list. This makes it easier to process them in the PyTorch Lightning module defined above. The implementation of the sampler can be found below.

[26]:
class TaskBatchSampler:
    def __init__(self, dataset_targets, batch_size, N_way, K_shot, include_query=False, shuffle=True):
        """
        Inputs:
            dataset_targets - PyTorch tensor of the labels of the data elements.
            batch_size - Number of tasks to aggregate in a batch
            N_way - Number of classes to sample per batch.
            K_shot - Number of examples to sample per class in the batch.
            include_query - If True, returns batch of size N_way*K_shot*2, which
                            can be split into support and query set. Simplifies
                            the implementation of sampling the same classes but
                            distinct examples for support and query set.
            shuffle - If True, examples and classes are newly shuffled in each
                      iteration (for training)
        """
        super().__init__()
        self.batch_sampler = FewShotBatchSampler(dataset_targets, N_way, K_shot, include_query, shuffle)
        self.task_batch_size = batch_size
        self.local_batch_size = self.batch_sampler.batch_size

    def __iter__(self):
        # Aggregate multiple batches before returning the indices
        batch_list = []
        for batch_idx, batch in enumerate(self.batch_sampler):
            batch_list.extend(batch)
            if (batch_idx + 1) % self.task_batch_size == 0:
                yield batch_list
                batch_list = []

    def __len__(self):
        return len(self.batch_sampler) // self.task_batch_size

    def get_collate_fn(self):
        # Returns a collate function that converts one big tensor into a list of task-specific tensors
        def collate_fn(item_list):
            imgs = torch.stack([img for img, target in item_list], dim=0)
            targets = torch.stack([target for img, target in item_list], dim=0)
            imgs = imgs.chunk(self.task_batch_size, dim=0)
            targets = targets.chunk(self.task_batch_size, dim=0)
            return list(zip(imgs, targets))

        return collate_fn
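To make the aggregation logic concrete, the following framework-free sketch mirrors what __iter__ and the collate function above do, using plain lists and NumPy instead of a real FewShotBatchSampler and torch.chunk. The toy index batches and the helper names aggregate_task_batches and split_into_tasks are hypothetical and only meant for illustration.

```python
import numpy as np


def aggregate_task_batches(batches, task_batch_size):
    """Concatenate every `task_batch_size` consecutive index batches into one list."""
    batch_list = []
    for batch_idx, batch in enumerate(batches):
        batch_list.extend(batch)
        if (batch_idx + 1) % task_batch_size == 0:
            yield batch_list
            batch_list = []
    # A leftover, incomplete group is never yielded (matching __len__'s floor division)


def split_into_tasks(stacked, task_batch_size):
    """Split one stacked array along dim 0 into `task_batch_size` equally sized tasks."""
    return np.split(stacked, task_batch_size, axis=0)


# Five toy few-shot batches of 4 indices each, grouped into task batches of 2
toy_batches = [[4 * i + j for j in range(4)] for i in range(5)]
task_batches = list(aggregate_task_batches(toy_batches, task_batch_size=2))
print(task_batches)  # the fifth batch is dropped

# Pretend we loaded the 8 images of the first task batch, each of shape (3, 32, 32)
imgs = np.zeros((len(task_batches[0]), 3, 32, 32))
tasks = split_into_tasks(imgs, task_batch_size=2)
print([t.shape for t in tasks])
```

Unlike torch.chunk, np.split raises an error if the array cannot be divided evenly. Since every list yielded by TaskBatchSampler consists of exactly task_batch_size batches of the same size, the split is always even here.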

With this sampler, creating the data loaders is straightforward. Note that since many images need to be loaded for a single training batch, it is recommended to use fewer workers than usual.

[27]:
# Training constant (same as for ProtoNet)
N_WAY = 5
K_SHOT = 4

# Training set
train_protomaml_sampler = TaskBatchSampler(
    train_set.targets, include_query=True, N_way=N_WAY, K_shot=K_SHOT, batch_size=16
)
train_protomaml_loader = data.DataLoader(
    train_set, batch_sampler=train_protomaml_sampler, collate_fn=train_protomaml_sampler.get_collate_fn(), num_workers=2
)

# Validation set
val_protomaml_sampler = TaskBatchSampler(
    val_set.targets,
    include_query=True,
    N_way=N_WAY,
    K_shot=K_SHOT,
    batch_size=1,  # We do not update the parameters, hence the batch size is irrelevant here
    shuffle=False,
)
val_protomaml_loader = data.DataLoader(
    val_set, batch_sampler=val_protomaml_sampler, collate_fn=val_protomaml_sampler.get_collate_fn(), num_workers=2
)
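A quick calculation shows how many images a single training step touches. Following the TaskBatchSampler docstring, one few-shot batch with include_query=True holds N_way*K_shot*2 images, and the training loader combines 16 such tasks per step:

```python
# Images loaded per step, following the TaskBatchSampler docstring:
# with include_query=True, one few-shot batch holds N_way * K_shot * 2 images.
N_WAY, K_SHOT = 5, 4
images_per_task = N_WAY * K_SHOT * 2  # support set + query set
train_images_per_step = images_per_task * 16  # batch_size=16 tasks
val_images_per_step = images_per_task * 1  # batch_size=1 task
print(images_per_task, train_images_per_step, val_images_per_step)
```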

Now, we are ready to train our ProtoMAML. We use the same feature space size as for ProtoNet, but can use a higher learning rate since the outer loop gradients are accumulated over 16 tasks per step. The inner loop learning rate is set to 0.1, which is much higher than the outer loop learning rate because we use SGD in the inner loop instead of Adam. Commonly, the learning rate for the output layer is set higher than for the base model if the base model is very deep or pre-trained. However, for our setup, we observed no noticeable impact of using a different learning rate for the output layer than for the base model. The number of inner loop updates is another crucial hyperparameter, and depends on the similarity of our training tasks. Since all tasks are on images from the same dataset, we notice that a single inner loop update achieves similar performance to 3 or 5 updates while training considerably faster. However, especially in RL and NLP, a larger number of inner loop steps is often needed.
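To illustrate what lr_inner and num_inner_steps control, here is a small NumPy sketch of an inner loop under simplifying assumptions. The features are fixed toy vectors, so only the output layer is adapted (ProtoMAML itself also updates the base model). The output layer is initialized from the class prototypes c_k as W_k = 2c_k and b_k = -||c_k||^2, the initialization proposed for ProtoMAML, and each inner step is plain SGD. All helper names are hypothetical and do not correspond to the tutorial's ProtoMAML class.

```python
import numpy as np


def softmax_xent(logits, labels):
    """Mean cross-entropy and its gradient with respect to the logits."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerically stable softmax
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    n = logits.shape[0]
    loss = -np.log(probs[np.arange(n), labels]).mean()
    grad = probs.copy()
    grad[np.arange(n), labels] -= 1.0
    return loss, grad / n


def proto_init(feats, labels, n_classes):
    """Prototype-based output layer: W_k = 2 c_k, b_k = -||c_k||^2."""
    protos = np.stack([feats[labels == k].mean(axis=0) for k in range(n_classes)])
    return 2.0 * protos, -(protos**2).sum(axis=1)


def inner_loop(feats, labels, W, b, lr_inner=0.1, num_inner_steps=1):
    """Plain SGD on the output layer only; the features (base network) stay fixed here."""
    losses = []
    for _ in range(num_inner_steps):
        loss, dlogits = softmax_xent(feats @ W.T + b, labels)
        losses.append(loss)
        W = W - lr_inner * dlogits.T @ feats
        b = b - lr_inner * dlogits.sum(axis=0)
    losses.append(softmax_xent(feats @ W.T + b, labels)[0])  # loss after the last update
    return W, b, losses


# Toy 5-way 4-shot support set with hypothetical 4-dimensional features
rng = np.random.default_rng(0)
N_WAY, K_SHOT, DIM = 5, 4, 4
labels = np.repeat(np.arange(N_WAY), K_SHOT)
feats = rng.normal(size=(N_WAY, DIM))[labels] + 0.5 * rng.normal(size=(N_WAY * K_SHOT, DIM))

W0, b0 = proto_init(feats, labels, N_WAY)
W, b, losses = inner_loop(feats, labels, W0, b0, lr_inner=0.1, num_inner_steps=5)
print([round(float(l), 3) for l in losses])
```

With this initialization, the logits equal -||x - c_k||^2 + ||x||^2, so before any update the model predicts the nearest prototype, exactly like ProtoNet; the inner SGD steps then refine this starting point.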

[28]:
protomaml_model = train_model(
    ProtoMAML,
    proto_dim=64,
    lr=1e-3,
    lr_inner=0.1,
    lr_output=0.1,
    num_inner_steps=1,  # Often values between 1 and 10
    train_loader=train_protomaml_loader,
    val_loader=val_protomaml_loader,
)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Found pretrained model at saved_models/MetaLearning/ProtoMAML.ckpt, loading...

Let’s have a look at the training TensorBoard.

[29]:
# Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH if needed
# # %tensorboard --logdir ../saved_models/tutorial16/tensorboards/ProtoMAML/


One obvious difference to ProtoNet is that the loss curves look much less noisy. This is because we average the outer loop gradients over multiple tasks, and thus obtain a smoother training curve. Additionally, we only have 15k training iterations after 200 epochs. This is again because of the task batches, which result in 16 times fewer iterations. However, each iteration has seen 16 times more data in this experiment. Thus, we still have a fair comparison between ProtoMAML and ProtoNet. Judging by the validation accuracy alone, one might assume that ProtoNet performs better than ProtoMAML, but we have to verify that with proper testing below.
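Why averaging over tasks smooths the curves can be seen in an idealized simulation. Assuming, purely for illustration, that the gradient noise of each task is independent and identically distributed, averaging over 16 tasks shrinks its standard deviation by a factor of sqrt(16) = 4. Real task gradients are not fully independent, so the actual reduction can be smaller.

```python
import numpy as np

rng = np.random.default_rng(0)
num_steps, tasks_per_step = 100_000, 16

# Hypothetical per-task gradients: a true value of 1.0 plus i.i.d. standard normal noise
per_task = 1.0 + rng.normal(size=(num_steps, tasks_per_step))
single_task_std = per_task[:, 0].std()
averaged_std = per_task.mean(axis=1).std()
print(f"single task: {single_task_std:.3f}, averaged over 16 tasks: {averaged_std:.3f}")
```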

Testing

We test ProtoMAML in the same manner as ProtoNet, namely by picking random examples in the test set as support sets and using the rest of the dataset as the query set. Instead of just calculating the prototypes for all examples, we need to finetune a separate model for each support set. This makes testing considerably more expensive than for ProtoNet, and in our case, testing k=\{2,4,8,16,32\} can take almost an hour. Hence, we provide evaluation files alongside the pretrained models.

[30]:
def test_protomaml(model, dataset, k_shot=4):
    pl.seed_everything(42)
    model = model.to(device)
    num_classes = dataset.targets.unique().shape[0]

    # Data loader for full test set as query set
    full_dataloader = data.DataLoader(dataset, batch_size=128, num_workers=4, shuffle=False, drop_last=False)
    # Data loader for sampling support sets
    sampler = FewShotBatchSampler(
        dataset.targets, include_query=False, N_way=num_classes, K_shot=k_shot, shuffle=False, shuffle_once=False
    )
    sample_dataloader = data.DataLoader(dataset, batch_sampler=sampler, num_workers=2)

    # We iterate through the full dataset in two ways. First, to select the k-shot support batch.
    # Second, to evaluate the model on all other examples
    accuracies = []
    for (support_imgs, support_targets), support_indices in tqdm(
        zip(sample_dataloader, sampler), "Performing few-shot finetuning"
    ):
        support_imgs = support_imgs.to(device)
        support_targets = support_targets.to(device)
        # Finetune new model on support set
        local_model, output_weight, output_bias, classes = model.adapt_few_shot(support_imgs, support_targets)
        with torch.no_grad():  # No gradients for query set needed
            local_model.eval()
            batch_acc = torch.zeros((0,), dtype=torch.float32, device=device)
            # Evaluate all examples in test dataset
            for query_imgs, query_targets in full_dataloader:
                query_imgs = query_imgs.to(device)
                query_targets = query_targets.to(device)
                query_labels = (classes[None, :] == query_targets[:, None]).long().argmax(dim=-1)
                _, _, acc = model.run_model(local_model, output_weight, output_bias, query_imgs, query_labels)
                batch_acc = torch.cat([batch_acc, acc.detach()], dim=0)
            # Exclude support set elements
            for s_idx in support_indices:
                batch_acc[s_idx] = 0
            batch_acc = batch_acc.sum().item() / (batch_acc.shape[0] - len(support_indices))
            accuracies.append(batch_acc)
    return mean(accuracies), stdev(accuracies)
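The support-set exclusion at the end of test_protomaml works because full_dataloader uses shuffle=False, so the position of each example in batch_acc equals its dataset index. The following NumPy sketch, with a hypothetical toy correctness vector, shows that zeroing the support entries and dividing by the query count gives the same result as averaging over a boolean query mask:

```python
import numpy as np

# Hypothetical per-example correctness over a toy "test set" of 10 examples (1 = correct)
per_example_acc = np.array([1, 0, 1, 1, 0, 1, 1, 1, 0, 1], dtype=np.float32)
support_indices = [0, 4, 7]  # examples that were used as the support set

# Variant used in test_protomaml: zero out support entries, divide by the query count
zeroed = per_example_acc.copy()
zeroed[support_indices] = 0
acc_zeroing = zeroed.sum() / (len(zeroed) - len(support_indices))

# Equivalent boolean-mask variant
query_mask = np.ones(len(per_example_acc), dtype=bool)
query_mask[support_indices] = False
acc_mask = per_example_acc[query_mask].mean()
print(acc_zeroing, acc_mask)
```

Both variants assume that support_indices contains no duplicates; otherwise the zeroing variant would subtract too many examples from the denominator.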

In contrast to training, it is recommended to use many more inner loop updates during testing. During training, we are not interested in getting the best model from the inner loop, but the model which can provide the best gradients. Hence, a single update might already be sufficient in training, but for testing, it was often observed that a larger number of updates can give a considerable performance boost. Thus, we change the number of inner loop updates to 200 before testing.

[31]:
protomaml_model.hparams.num_inner_steps = 200

Now, we can test our model. For the pre-trained models, we provide a json file with the results to reduce evaluation time.

[32]:
protomaml_result_file = os.path.join(CHECKPOINT_PATH, "protomaml_fewshot.json")

if os.path.isfile(protomaml_result_file):
    # Load pre-computed results
    with open(protomaml_result_file) as f:
        protomaml_accuracies = json.load(f)
    protomaml_accuracies = {int(k): v for k, v in protomaml_accuracies.items()}
else:
    # Perform same experiments as for ProtoNet
    protomaml_accuracies = dict()
    for k in [2, 4, 8, 16, 32]:
        protomaml_accuracies[k] = test_protomaml(protomaml_model, test_set, k_shot=k)
    # Export results
    with open(protomaml_result_file, "w") as f:
        json.dump(protomaml_accuracies, f, indent=4)

for k in protomaml_accuracies:
    print(
        "Accuracy for k=%i: %4.2f%% (+-%4.2f%%)"
        % (k, 100.0 * protomaml_accuracies[k][0], 100.0 * protomaml_accuracies[k][1])
    )
Accuracy for k=2: 42.89% (+-3.82%)
Accuracy for k=4: 52.27% (+-2.72%)
Accuracy for k=8: 59.23% (+-1.50%)
Accuracy for k=16: 63.94% (+-1.24%)
Accuracy for k=32: 67.57% (+-0.90%)

Again, let’s plot the results in our plot from before.

[33]:
ax = plot_few_shot(protonet_accuracies, name="ProtoNet", color="C1")
plot_few_shot(protomaml_accuracies, name="ProtoMAML", color="C2", ax=ax)
plt.show()
plt.close()
_images/notebooks_course_UvA-DL_12-meta-learning_77_0.svg

We can observe that ProtoMAML is indeed able to outperform ProtoNet for k>4. This is because with more samples, it becomes more relevant to also adapt the base model’s parameters. Meanwhile, for k=2, ProtoMAML achieves lower performance than ProtoNet. This is likely related to the choice of 200 inner loop updates, since more updates increase the risk of overfitting to the very small support set. Nonetheless, the high standard deviation for k=2 makes it hard to draw any statistically valid conclusion.

Overall, we can conclude that ProtoMAML slightly outperforms ProtoNet for larger shot counts. However, one disadvantage of ProtoMAML is its much longer training and testing time. ProtoNet provides a simple, efficient, yet strong baseline for ProtoMAML, and might be the better solution in situations where limited resources are available.

Domain adaptation

So far, we have evaluated our meta-learning algorithms on the same dataset on which we have trained them. However, meta-learning algorithms are especially interesting when we want to move from one dataset to another. So, what happens if we apply them to a dataset that is quite different from CIFAR? This is what we try out below, by evaluating ProtoNet and ProtoMAML on the SVHN dataset.

SVHN dataset

The Street View House Numbers (SVHN) dataset is a real-world image dataset for house number detection. Like MNIST, it has the classes 0 to 9, but it is more difficult due to its real-world setting and possibly distracting digits to the left and right of the main one. Let’s first load the dataset, and visualize some images to get an impression of the dataset.

[34]:
SVHN_test_dataset = SVHN(root=DATASET_PATH, split="test", download=True, transform=transforms.ToTensor())
Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to /__w/1/s/.datasets/test_32x32.mat
[35]:
# Visualize some examples
NUM_IMAGES = 12
SVHN_images = [SVHN_test_dataset[np.random.randint(len(SVHN_test_dataset))][0] for idx in range(NUM_IMAGES)]
SVHN_images = torch.stack(SVHN_images, dim=0)
img_grid = torchvision.utils.make_grid(SVHN_images, nrow=6, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(8, 8))
plt.title("Image examples of the SVHN dataset")
plt.imshow(img_grid)
plt.axis("off")
plt.show()
plt.close()
_images/notebooks_course_UvA-DL_12-meta-learning_82_0.svg

Each image is labeled with one class between 0 and 9 representing the main digit in the image. Can our ProtoNet and ProtoMAML learn to classify the digits from only a few examples? This is what we will test out below. The images have the same size as CIFAR (32\times 32 pixels), so we can use them without changes. We first prepare the dataset, for which we take the first 500 images per class. For this dataset, we use our test functions as before to get an estimated performance for different numbers of shots.

[36]:
imgs = np.transpose(SVHN_test_dataset.data, (0, 2, 3, 1))
targets = SVHN_test_dataset.labels
# Limit number of examples to 500 to reduce test time
min_label_count = min(500, np.bincount(SVHN_test_dataset.labels).min())

idxs = np.concatenate([np.where(targets == c)[0][:min_label_count] for c in range(1 + targets.max())], axis=0)
imgs = imgs[idxs]
targets = torch.from_numpy(targets[idxs]).long()

svhn_fewshot_dataset = ImageDataset(imgs, targets, img_transform=test_transform)
svhn_fewshot_dataset.imgs.shape
[36]:
(5000, 32, 32, 3)
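The class-balancing step above caps every class at the size of the rarest class (or at 500, whichever is smaller), which is why we end up with exactly 10 × 500 = 5000 images. The same indexing logic, applied to a small hypothetical label array, looks like this:

```python
import numpy as np

# Toy, imbalanced labels for 3 classes (hypothetical stand-in for SVHN_test_dataset.labels)
toy_targets = np.array([0, 1, 0, 2, 1, 0, 2, 0, 1, 2, 2, 2])
max_per_class = 3  # plays the role of the 500-image cap

min_label_count = min(max_per_class, np.bincount(toy_targets).min())
idxs = np.concatenate(
    [np.where(toy_targets == c)[0][:min_label_count] for c in range(1 + toy_targets.max())], axis=0
)
print(idxs, np.bincount(toy_targets[idxs]))
```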
Experiments

First, we can apply ProtoNet to the SVHN dataset:

[37]:
protonet_svhn_accuracies = dict()
data_feats = None
for k in [2, 4, 8, 16, 32]:
    protonet_svhn_accuracies[k], data_feats = test_proto_net(
        protonet_model, svhn_fewshot_dataset, data_feats=data_feats, k_shot=k
    )
    print(
        "Accuracy for k=%i: %4.2f%% (+-%4.2f%%)"
        % (k, 100.0 * protonet_svhn_accuracies[k][0], 100 * protonet_svhn_accuracies[k][1])
    )
Accuracy for k=2: 18.82% (+-2.28%)
Accuracy for k=4: 21.94% (+-2.09%)
Accuracy for k=8: 25.59% (+-1.76%)
Accuracy for k=16: 29.06% (+-1.84%)
Accuracy for k=32: 32.93% (+-1.33%)

It becomes clear that the results are much lower than the ones on CIFAR, and only moderately above random guessing (10% for ten classes) for k=2. How about ProtoMAML? Again, we provide evaluation files since the evaluation can take several minutes to complete.

[38]:
protomaml_result_file = os.path.join(CHECKPOINT_PATH, "protomaml_svhn_fewshot.json")

if os.path.isfile(protomaml_result_file):
    # Load pre-computed results
    with open(protomaml_result_file) as f:
        protomaml_svhn_accuracies = json.load(f)
    protomaml_svhn_accuracies = {int(k): v for k, v in protomaml_svhn_accuracies.items()}
else:
    # Perform same experiments as for ProtoNet
    protomaml_svhn_accuracies = dict()
    for k in [2, 4, 8, 16, 32]:
        protomaml_svhn_accuracies[k] = test_protomaml(protomaml_model, svhn_fewshot_dataset, k_shot=k)
    # Export results
    with open(protomaml_result_file, "w") as f:
        json.dump(protomaml_svhn_accuracies, f, indent=4)

for k in protomaml_svhn_accuracies:
    print(
        "Accuracy for k=%i: %4.2f%% (+-%4.2f%%)"
        % (k, 100.0 * protomaml_svhn_accuracies[k][0], 100.0 * protomaml_svhn_accuracies[k][1])
    )
Accuracy for k=2: 17.11% (+-1.95%)
Accuracy for k=4: 21.29% (+-1.92%)
Accuracy for k=8: 27.62% (+-1.84%)
Accuracy for k=16: 36.17% (+-1.80%)
Accuracy for k=32: 46.03% (+-1.65%)

While ProtoMAML performs similarly to ProtoNet for k\leq 4, it considerably outperforms ProtoNet for more than 8 shots. This is because we can adapt the base model, which is crucial when the data does not fit the original training data. For k=32, ProtoMAML achieves a classification accuracy about 13 percentage points higher than ProtoNet, whose accuracy already starts to flatten out. We can see the trend more clearly in our plot below.
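The gaps can be computed directly from the mean accuracies printed above. The values below are copied from those outputs; the difference is in percentage points, and chance level for ten digit classes is 10%.

```python
# Mean accuracies (%) copied from the SVHN outputs above
protonet_svhn = {2: 18.82, 4: 21.94, 8: 25.59, 16: 29.06, 32: 32.93}
protomaml_svhn = {2: 17.11, 4: 21.29, 8: 27.62, 16: 36.17, 32: 46.03}
chance = 100.0 / 10  # 10 digit classes

gaps = {k: protomaml_svhn[k] - protonet_svhn[k] for k in protonet_svhn}
for k, gap in gaps.items():
    print(f"k={k:2d}: ProtoMAML - ProtoNet = {gap:+.2f} pp (chance level {chance:.0f}%)")
```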

[39]:
ax = plot_few_shot(protonet_svhn_accuracies, name="ProtoNet", color="C1")
plot_few_shot(protomaml_svhn_accuracies, name="ProtoMAML", color="C2", ax=ax)
plt.show()
plt.close()
_images/notebooks_course_UvA-DL_12-meta-learning_90_0.svg

Conclusion

In this notebook, we have discussed meta-learning algorithms that learn to adapt to new classes and/or tasks with just a few samples. We have discussed three popular algorithms, namely ProtoNet, MAML and ProtoMAML. On the few-shot image classification task of CIFAR100, ProtoNet and ProtoMAML performed similarly well, with slight benefits for ProtoMAML at larger shot sizes. However, for out-of-distribution data (SVHN), the ability to optimize the base model proved crucial and gave ProtoMAML considerable performance gains over ProtoNet. Nonetheless, ProtoNet offers other advantages compared to ProtoMAML, namely very cheap training and testing as well as a simpler implementation. Hence, it is recommended to consider whether the additional complexity of ProtoMAML is worth the extra training computation cost, or whether ProtoNet is already sufficient for the task at hand.

References

[1] Snell, Jake, Kevin Swersky, and Richard S. Zemel. “Prototypical networks for few-shot learning.” NeurIPS 2017.

[2] Finn, Chelsea, Pieter Abbeel, and Sergey Levine. “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.” ICML 2017.

[3] Triantafillou, Eleni, Tyler Zhu, Vincent Dumoulin, Pascal Lamblin, Utku Evci, Kelvin Xu, Ross Goroshin, et al. “Meta-dataset: A dataset of datasets for learning to learn from few examples.” ICLR 2020.

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time, you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Many thanks from the entire PyTorch Lightning team for your interest!

Pytorch Lightning

Tutorial 13: Self-Supervised Contrastive Learning with SimCLR

  • Author: Phillip Lippe

  • License: CC BY-SA

  • Generated: 2022-05-03T02:43:21.313398

In this tutorial, we will take a closer look at self-supervised contrastive learning. Self-supervised learning, sometimes also called unsupervised learning, describes the scenario where we are given input data, but no accompanying labels to train a model in the classical supervised way. However, this data still contains a lot of information from which we can learn: how are the images different from each other? What patterns are descriptive for certain images? Can we cluster the images? To get an insight into these questions, we will implement a popular, simple contrastive learning method, SimCLR, and apply it to the STL10 dataset. This notebook is part of a lecture series on Deep Learning at the University of Amsterdam. The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.



Give us a ⭐ on Github | Check out the documentation | Join us on Slack

Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "torch>=1.8" "torchmetrics>=0.7" "ipython[notebook]" "pytorch-lightning>=1.4" "torchvision" "setuptools==59.5.0" "seaborn" "matplotlib"

Methods for self-supervised learning try to learn as much as possible from the data alone, so that the resulting model can quickly be finetuned for a specific classification task. The benefit of self-supervised learning is that a large dataset can often be obtained easily. For instance, if we want to train a vision model on semantic segmentation for autonomous driving, we can collect large amounts of data by simply installing a camera in a car and driving through a city for an hour. In contrast, if we wanted to do supervised learning, we would have to manually label all those images before training a model. This is extremely expensive, and manually labeling the same amount of data would likely take a couple of months. Further, self-supervised learning can provide an alternative to transfer learning from models pretrained on ImageNet, since we could pretrain a model on a specific dataset/situation, e.g. traffic scenarios for autonomous driving.

Within the last two years, a lot of new approaches have been proposed for self-supervised learning, in particular for images, that have resulted in great improvements over supervised models when few labels are available. The subfield that we will focus on in this tutorial is contrastive learning. Contrastive learning is motivated by the question mentioned above: how are images different from each other? Specifically, contrastive learning methods train a model to cluster an image and its slightly augmented version in latent space, while the distance to other images should be maximized. A very recent and simple method for this is SimCLR, which is visualized below (figure credit - Ting Chen et al.).

simclr contrastive learning

The general setup is that we are given a dataset of images without any labels, and want to train a model on this data such that it can quickly adapt to any image recognition task afterward. During each training iteration, we sample a batch of images as usual. For each image, we create two versions by applying data augmentation techniques like cropping, Gaussian noise, blurring, etc. An example of this is shown on the left with the image of the dog. We will go into the details and effects of the chosen augmentation techniques later. On those images, we apply a CNN like ResNet and obtain a 1D feature vector as output, on which we apply a small MLP. The output features of the two augmented images are then trained to be close to each other, while all other images in that batch should be as different as possible. This way, the model has to learn to recognize the content of the image that remains unchanged under the data augmentations, such as objects, which we usually care about in supervised tasks.
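SimCLR formalizes "close to each other, different from everything else" with the NT-Xent (normalized temperature-scaled cross-entropy) loss: each view must identify its partner view among all other 2N-1 examples in the batch, using cosine similarity divided by a temperature. The following NumPy sketch is not the tutorial's implementation; the temperature of 0.5 and the convention that rows i and i+N hold the two views of the same image are assumptions made for this example.

```python
import numpy as np


def nt_xent_loss(z, temperature=0.5):
    """NT-Xent loss for a batch z of shape (2N, D).

    Rows i and i + N are assumed to be the two augmented views of the same image.
    """
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit vectors -> dot product = cosine
    sim = z @ z.T / temperature
    np.fill_diagonal(sim, -np.inf)  # an example is never compared with itself
    two_n = z.shape[0]
    pos_idx = (np.arange(two_n) + two_n // 2) % two_n  # index of the partner view
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    # Cross-entropy where the "correct class" of row i is its partner view
    return float(np.mean(log_denom - sim[np.arange(two_n), pos_idx]))


rng = np.random.default_rng(0)
view_a = rng.normal(size=(8, 16))
matched = np.concatenate([view_a, view_a + 0.01 * rng.normal(size=(8, 16))])
unmatched = np.concatenate([view_a, rng.normal(size=(8, 16))])
print(nt_xent_loss(matched), nt_xent_loss(unmatched))
```

When the two views of each image map to nearly identical vectors, the loss is clearly lower than when the "pairs" are unrelated, which is exactly the pressure that drives the encoder toward augmentation-invariant features.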

We will now implement this framework ourselves and discuss further details along the way. Let’s first start with importing our standard libraries below:

[2]:
import os
import urllib.request
from copy import deepcopy
from urllib.error import HTTPError

import matplotlib
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from IPython.display import set_matplotlib_formats
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torchvision import transforms
from torchvision.datasets import STL10
from tqdm.notebook import tqdm

plt.set_cmap("cividis")
%matplotlib inline
set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.set()

# Import tensorboard
%load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/ContrastiveLearning/")
# In this notebook, we use data loaders with heavier computational processing. It is recommended to use as many
# workers as possible in a data loader, which corresponds to the number of CPU cores
NUM_WORKERS = os.cpu_count()

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
print("Number of workers:", NUM_WORKERS)
/usr/lib/python3.8/site-packages/apex/pyprof/__init__.py:5: FutureWarning: pyprof will be removed by the end of June, 2022
  warnings.warn("pyprof will be removed by the end of June, 2022", FutureWarning)
WARNING:root:Bagua cannot detect bundled NCCL library, Bagua will try to use system NCCL instead. If you encounter any error, please run `import bagua_core; bagua_core.install_deps()` or the `bagua_install_deps.py` script to install bundled libraries.
/tmp/ipykernel_4616/1782126012.py:24: DeprecationWarning: `set_matplotlib_formats` is deprecated since IPython 7.23, directly use `matplotlib_inline.backend_inline.set_matplotlib_formats()`
  set_matplotlib_formats("svg", "pdf")  # For export
Global seed set to 42
Device: cuda:0
Number of workers: 12

As in many tutorials before, we provide pre-trained models. Note that those models are slightly larger than usual (~100MB overall) since we use the default ResNet-18 architecture. If you are running this notebook locally, make sure to have sufficient disk space available.

[3]:
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/"
# Files to download
pretrained_files = [
    "SimCLR.ckpt",
    "ResNet.ckpt",
    "tensorboards/SimCLR/events.out.tfevents.SimCLR",
    "tensorboards/classification/ResNet/events.out.tfevents.ResNet",
]
pretrained_files += [f"LogisticRegression_{size}.ckpt" for size in [10, 20, 50, 100, 200, 500]]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if "/" in file_name:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n",
                e,
            )
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/SimCLR.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/ResNet.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/tensorboards/SimCLR/events.out.tfevents.SimCLR...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/tensorboards/classification/ResNet/events.out.tfevents.ResNet...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/LogisticRegression_10.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/LogisticRegression_20.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/LogisticRegression_50.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/LogisticRegression_100.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/LogisticRegression_200.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/LogisticRegression_500.ckpt...

SimCLR

We will start our exploration of contrastive learning by discussing the effect of different data augmentation techniques, and how we can implement an efficient data loader for them. Next, we implement SimCLR with PyTorch Lightning, and finally train it on a large, unlabeled dataset.

Data Augmentation for Contrastive Learning

To allow efficient training, we need to prepare the data loading such that we sample two different, random augmentations for each image in the batch. The easiest way to do this is by creating a transformation that, when called, applies a set of data augmentations to an image twice. This is implemented in the class ContrastiveTransformations below:

[4]:
class ContrastiveTransformations:
    def __init__(self, base_transforms, n_views=2):
        self.base_transforms = base_transforms
        self.n_views = n_views

    def __call__(self, x):
        return [self.base_transforms(x) for i in range(self.n_views)]

The contrastive learning framework can easily be extended to have more positive examples by sampling more than two augmentations of the same image. However, the most efficient training is usually obtained by using only two.
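To see how the list returned by ContrastiveTransformations turns into a training batch, the sketch below repeats the class (so the snippet runs on its own), applies it with a hypothetical toy augmentation to NumPy "images", and regroups the views. PyTorch's default collate function groups list outputs view-wise in the same way, so concatenating the views yields all first views followed by all second views; this ordering is a convention assumed here, not something the class itself enforces.

```python
import numpy as np


class ContrastiveTransformations:
    # Same logic as the class above, repeated so that this sketch is self-contained
    def __init__(self, base_transforms, n_views=2):
        self.base_transforms = base_transforms
        self.n_views = n_views

    def __call__(self, x):
        return [self.base_transforms(x) for _ in range(self.n_views)]


rng = np.random.default_rng(0)


def toy_augment(img):
    """Hypothetical stand-in augmentation: random horizontal flip plus Gaussian noise."""
    if rng.random() < 0.5:
        img = img[:, ::-1]
    return img + 0.1 * rng.normal(size=img.shape)


images = [np.full((4, 4), float(i)) for i in range(3)]  # three toy "images"
transform = ContrastiveTransformations(toy_augment, n_views=2)
views_per_image = [transform(img) for img in images]  # 3 lists with 2 views each

# Group by view (all first views, then all second views): shape (n_views * N, H, W)
batch = np.concatenate([np.stack(v) for v in zip(*views_per_image)], axis=0)
print(batch.shape)
```

Passing n_views=3 would simply return three independently augmented copies per image, which is how more positive examples can be obtained.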

Next, we can look at the specific augmentations we want to apply. The choice of the data augmentation to use is the most crucial hyperparameter in SimCLR since it directly affects how the latent space is structured, and what patterns might be learned from the data. Let’s first take a look at some of the most popular data augmentations (figure credit - Ting Chen and Geoffrey Hinton):


All of them can be used, but it turns out that two augmentations stand out in their importance: crop-and-resize, and color distortion. Interestingly, however, they only lead to strong performance if they are used together, as discussed by Ting Chen et al. in their SimCLR paper. When randomly cropping and resizing, we can distinguish between two situations: (a) cropped image A provides a local view of cropped image B, or (b) cropped images C and D show neighboring views of the same image (figure credit - Ting Chen and Geoffrey Hinton).


While situation (a) requires the model to learn some sort of scale invariance to make crops A and B similar in latent space, situation (b) is more challenging since the model needs to recognize an object beyond its limited view. However, without color distortion, there is a loophole that the model can exploit, namely that different crops of the same image usually look very similar in color space. Consider the picture of the dog above. Simply from the color of the fur and the green color tone of the background, you can reason that two patches belong to the same image without actually recognizing the dog in the picture. In this case, the model might end up focusing only on the color histograms of the images, and ignore other more generalizable features. If, however, we distort the colors in the two patches randomly and independently of each other, the model cannot rely on this simple feature anymore. Hence, by combining random cropping and color distortions, the model can only match two patches by learning generalizable representations.

Overall, for our experiments, we apply a set of 5 transformations following the original SimCLR setup: random horizontal flip, crop-and-resize, color distortion, random grayscale, and Gaussian blur. In comparison to the original implementation, we reduce the effect of the color jitter slightly (0.5 instead of 0.8 for brightness, contrast, and saturation, and 0.1 instead of 0.2 for hue). In our experiments, this setting obtained better performance and was faster and more stable to train. If, for instance, the brightness varies strongly within a dataset, the original settings can be more beneficial, since the model can’t rely on this information anymore to distinguish between images.

[5]:
contrast_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(size=96),
        transforms.RandomApply([transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=9),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
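To make the jitter settings above more tangible, the following plain-Python sketch computes the sampling ranges they imply. It relies on torchvision's documented semantics as we understand them (check the documentation of your torchvision version): ColorJitter draws brightness, contrast and saturation factors uniformly from [max(0, 1 - s), 1 + s] and a hue shift from [-h, h]; Normalize computes (x - mean) / std; and RandomResizedCrop's default scale=(0.08, 1.0) lets a crop cover 8% to 100% of the original image area before resizing.

```python
def jitter_range(strength):
    """Interval from which ColorJitter samples a brightness/contrast/saturation factor."""
    return (max(0.0, 1.0 - strength), 1.0 + strength)


def hue_range(strength):
    """Interval from which ColorJitter samples the hue shift."""
    return (-strength, strength)


ours, original = jitter_range(0.5), jitter_range(0.8)
print("brightness/contrast/saturation: ours [%.1f, %.1f], original [%.1f, %.1f]" % (*ours, *original))
print("hue shift: ours [%.1f, %.1f], original [%.1f, %.1f]" % (*hue_range(0.1), *hue_range(0.2)))

# Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, mapping pixel values from [0, 1] to [-1, 1]
normalized = [(x - 0.5) / 0.5 for x in (0.0, 0.5, 1.0)]
print("normalized:", normalized)

# With scale=(0.08, 1.0), the smallest square crop of a 96x96 STL10 image is about this wide
smallest_side = round((0.08 * 96 * 96) ** 0.5)
print("smallest square crop side:", smallest_side)
```

The smallest crops are thus only a few dozen pixels wide before being resized back to 96x96, which is why crops can show very local views of an image.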

After discussing the data augmentation techniques, we can now focus on the dataset. In this tutorial, we will use the STL10 dataset, which, similarly to CIFAR10, contains images of 10 classes: airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck. However, the images have a higher resolution, namely 96\times 96 pixels, and we are only provided with 500 labeled images per class. Additionally, we have a much larger set of 100,000 unlabeled images which are similar to the training images but are sampled from a wider range of animals and vehicles. This makes the dataset ideal to showcase the benefits that self-supervised learning offers.
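The numbers above translate into the following ratios, computed directly from the figures stated in the text:

```python
num_classes, labeled_per_class, num_unlabeled = 10, 500, 100_000
num_labeled = num_classes * labeled_per_class
unlabeled_ratio = num_unlabeled / num_labeled
pixel_ratio = (96 * 96) / (32 * 32)  # STL10 vs. CIFAR10 image size
print("labeled training images:", num_labeled)
print("unlabeled-to-labeled ratio:", unlabeled_ratio)
print("pixels per image relative to CIFAR10:", pixel_ratio)
```

With 20 times more unlabeled than labeled images, each of which has 9 times as many pixels as a CIFAR10 image, the unlabeled split is where most of the available information lies.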

Luckily, the STL10 dataset is provided through torchvision. Keep in mind, however, that since this dataset is relatively large and has a considerably higher resolution than CIFAR10, it requires more disk space (~3GB) and takes a bit of time to download. For our initial discussion of self-supervised learning and SimCLR, we will create two datasets with our contrastive transformations above: unlabeled_data will be used to train our model via contrastive learning, and train_data_contrast will be used as a validation set in contrastive learning.

[6]:
unlabeled_data = STL10(
    root=DATASET_PATH,
    split="unlabeled",
    download=True,
    transform=ContrastiveTransformations(contrast_transforms, n_views=2),
)
train_data_contrast = STL10(
    root=DATASET_PATH,
    split="train",
    download=True,
    transform=ContrastiveTransformations(contrast_transforms, n_views=2),
)
Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to /__w/1/s/.datasets/stl10_binary.tar.gz
Extracting /__w/1/s/.datasets/stl10_binary.tar.gz to /__w/1/s/.datasets
Files already downloaded and verified

Finally, before starting with our implementation of SimCLR, let’s look at some example image pairs sampled with our augmentations:

[7]:
# Visualize some examples
pl.seed_everything(42)
NUM_IMAGES = 6
imgs = torch.stack([img for idx in range(NUM_IMAGES) for img in unlabeled_data[idx][0]], dim=0)
img_grid = torchvision.utils.make_grid(imgs, nrow=6, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(10, 5))
plt.title("Augmented image examples of the STL10 dataset")
plt.imshow(img_grid)
plt.axis("off")
plt.show()
plt.close()
Global seed set to 42
[Figure: Augmented image examples of the STL10 dataset]

The examples show the wide variety of our data augmentations, including random cropping, grayscaling, Gaussian blur, and color distortion. Matching two independently augmented patches of the same image therefore remains a challenging task for the model.

SimCLR implementation

Using the data loading pipeline above, we can now implement SimCLR. At each iteration, we get two differently augmented versions of every image x, which we refer to as \tilde{x}_i and \tilde{x}_j. Both images are encoded into one-dimensional feature vectors, and we want to maximize the similarity between these two vectors while minimizing their similarity to all other images in the batch. The encoder network is split into two parts: a base encoder network f(\cdot), and a projection head g(\cdot). The base network is usually a deep CNN, as we have seen in e.g. Tutorial 5, and is responsible for extracting a representation vector from the augmented data examples. In our experiments, we use the common ResNet-18 architecture as f(\cdot), and refer to its output as f(\tilde{x}_i)=h_i. The projection head g(\cdot) maps the representation h into a space where we apply the contrastive loss, i.e., compare similarities between vectors. It is often chosen to be a small MLP with non-linearities, and for simplicity, we follow the original SimCLR paper setup by defining it as a two-layer MLP with ReLU activation in the hidden layer. Note that in the follow-up paper, SimCLRv2 [2], the authors mention that larger/wider MLPs can boost the performance considerably. This is why we use an MLP whose hidden layer is four times wider than its output, while deeper MLPs overfit on the given dataset in our experiments. The general setup is visualized below (figure credit: Ting Chen et al. [1]):

[Figure: SimCLR framework overview]

After finishing the training with contrastive learning, we will remove the projection head g(\cdot), and use f(\cdot) as a pretrained feature extractor. The representations z that come out of the projection head g(\cdot) have been shown to perform worse than those of the base network f(\cdot) when finetuning the network for a new task. This is likely because the representations z are trained to become invariant to many features, such as color, that can be important for downstream tasks. Thus, g(\cdot) is only needed for the contrastive learning stage.
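Since g(\cdot) is discarded after pretraining, it helps to see how small it is. The plain-Python sketch below counts the parameters of the projection head as it is built in the SimCLR module below: the ResNet-18 output layer becomes a Linear(512, 4*hidden_dim), followed by ReLU and Linear(4*hidden_dim, hidden_dim). It assumes the 512-dimensional features that ResNet-18 produces after its final pooling layer and hidden_dim=128 as used for training later on; the helper names are our own.

```python
def linear_params(in_features, out_features):
    """Number of parameters of a fully connected layer: weight matrix plus bias."""
    return in_features * out_features + out_features


def projection_head_params(feature_dim, hidden_dim):
    """Parameters of g(.) = Linear(feature_dim, 4*hidden_dim) -> ReLU -> Linear(4*hidden_dim, hidden_dim).

    ReLU has no parameters, so only the two linear layers contribute.
    """
    wide_dim = 4 * hidden_dim
    return linear_params(feature_dim, wide_dim) + linear_params(wide_dim, hidden_dim)


# ResNet-18 produces 512-dimensional features h before its final linear layer.
print(projection_head_params(feature_dim=512, hidden_dim=128))  # 328320
```

All of these weights are thrown away after contrastive training; only f(\cdot) is kept as the feature extractor.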

Now that the architecture is described, let’s take a closer look at how we train the model. As mentioned before, we want to maximize the similarity between the representations of the two augmented versions of the same image, i.e., z_i and z_j in the figure above, while minimizing it to all other examples in the batch. SimCLR thereby applies the InfoNCE loss, originally proposed by Aaron van den Oord et al. [3] for contrastive learning. In short, the InfoNCE loss compares the similarity of z_i and z_j to the similarity of z_i to any other representation in the batch by performing a softmax over the similarity values. The loss can be formally written as:

\ell_{i,j}=-\log \frac{\exp(\text{sim}(z_i,z_j)/\tau)}{\sum_{k=1}^{2N}\mathbb{1}_{[k\neq i]}\exp(\text{sim}(z_i,z_k)/\tau)}=-\text{sim}(z_i,z_j)/\tau+\log\left[\sum_{k=1}^{2N}\mathbb{1}_{[k\neq i]}\exp(\text{sim}(z_i,z_k)/\tau)\right]

The function \text{sim} is a similarity metric, and the hyperparameter \tau, called the temperature, determines how peaked the distribution is. Since many similarity metrics are bounded, the temperature parameter allows us to balance the influence of many dissimilar image patches versus one similar patch. The similarity metric used in SimCLR is cosine similarity, as defined below:

\text{sim}(z_i,z_j) = \frac{z_i^\top \cdot z_j}{||z_i||\cdot||z_j||}

The maximum possible cosine similarity is 1, while the minimum is -1. In general, we will see that the features of two different images converge to a cosine similarity around zero, since the minimum of -1 would require z_i and z_j to point in exactly opposite directions in all feature dimensions, which does not allow for much flexibility.
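To make the loss concrete, here is a framework-free sketch of InfoNCE over a batch of 2N projected vectors, written in plain Python for readability rather than speed. It assumes the same ordering as the Lightning module below, where the two views of image i sit at positions i and i+N, and it sums the denominator over all k \neq i as in the formula above. The function names and toy vectors are illustrative only.

```python
import math


def cosine_similarity(a, b):
    """sim(z_i, z_j) = z_i^T z_j / (||z_i|| * ||z_j||)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def info_nce(z, tau):
    """Mean InfoNCE loss over 2N vectors; z[i] and z[(i + N) % 2N] are positive pairs."""
    two_n = len(z)
    n = two_n // 2
    losses = []
    for i in range(two_n):
        j = (i + n) % two_n  # index of the other augmented view of the same image
        pos_logit = cosine_similarity(z[i], z[j]) / tau
        # Denominator: every k != i, i.e. the positive and all 2N - 2 negatives.
        # A real implementation would use a numerically stable logsumexp instead.
        logits = [cosine_similarity(z[i], z[k]) / tau for k in range(two_n) if k != i]
        log_denominator = math.log(sum(math.exp(logit) for logit in logits))
        losses.append(-pos_logit + log_denominator)
    return sum(losses) / two_n


# Two images (N = 2): views 0 and 2 belong to image A, views 1 and 3 to image B
aligned = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.1], [0.1, 1.0]]
swapped = [[1.0, 0.0], [0.0, 1.0], [0.1, 1.0], [1.0, 0.1]]
for tau in (1.0, 0.07):
    print(tau, info_nce(aligned, tau), info_nce(swapped, tau))
```

With aligned pairs the loss shrinks sharply as tau decreases, while with swapped pairs it grows, which is the peaking effect of the temperature described above.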

Finally, now that we have discussed all details, let’s implement SimCLR below as a PyTorch Lightning module:

[8]:
class SimCLR(pl.LightningModule):
    def __init__(self, hidden_dim, lr, temperature, weight_decay, max_epochs=500):
        super().__init__()
        self.save_hyperparameters()
        assert self.hparams.temperature > 0.0, "The temperature must be a positive float!"
        # Base model f(.)
        self.convnet = torchvision.models.resnet18(
            pretrained=False, num_classes=4 * hidden_dim
        )  # num_classes is the output size of the last linear layer
        # The MLP for g(.) consists of Linear->ReLU->Linear
        self.convnet.fc = nn.Sequential(
            self.convnet.fc,  # Linear(ResNet output, 4*hidden_dim)
            nn.ReLU(inplace=True),
            nn.Linear(4 * hidden_dim, hidden_dim),
        )

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.hparams.max_epochs, eta_min=self.hparams.lr / 50
        )
        return [optimizer], [lr_scheduler]

    def info_nce_loss(self, batch, mode="train"):
        imgs, _ = batch
        imgs = torch.cat(imgs, dim=0)

        # Encode all images
        feats = self.convnet(imgs)
        # Calculate cosine similarity
        cos_sim = F.cosine_similarity(feats[:, None, :], feats[None, :, :], dim=-1)
        # Mask out cosine similarity to itself
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
        cos_sim.masked_fill_(self_mask, -9e15)
        # Find positive example -> batch_size//2 away from the original example
        pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)
        # InfoNCE loss
        cos_sim = cos_sim / self.hparams.temperature
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
        nll = nll.mean()

        # Logging loss
        self.log(mode + "_loss", nll)
        # Get ranking position of positive example
        comb_sim = torch.cat(
            [cos_sim[pos_mask][:, None], cos_sim.masked_fill(pos_mask, -9e15)],  # First position positive example
            dim=-1,
        )
        sim_argsort = comb_sim.argsort(dim=-1, descending=True).argmin(dim=-1)
        # Logging ranking metrics
        self.log(mode + "_acc_top1", (sim_argsort == 0).float().mean())
        self.log(mode + "_acc_top5", (sim_argsort < 5).float().mean())
        self.log(mode + "_acc_mean_pos", 1 + sim_argsort.float().mean())

        return nll

    def training_step(self, batch, batch_idx):
        return self.info_nce_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self.info_nce_loss(batch, mode="val")
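The ranking metrics logged in info_nce_loss can be hard to read from the tensor code. The plain-Python sketch below computes the same quantities from a small similarity matrix: for each anchor, it counts how many non-self candidates score strictly higher than the positive, which gives the 0-based rank of the positive. Ties are resolved in favor of the positive here, which may differ from how argsort orders equal values; the function names and the toy matrix are hypothetical.

```python
def positive_ranks(sim, n):
    """0-based rank of each anchor's positive among all non-self candidates.

    sim is a 2N x 2N similarity matrix (list of lists); the positive of row i
    is column (i + N) % 2N, mirroring pos_mask = self_mask.roll(N) in the module.
    """
    two_n = 2 * n
    ranks = []
    for i in range(two_n):
        j = (i + n) % two_n
        # Count negatives (not the anchor, not the positive) that score strictly higher
        rank = sum(1 for k in range(two_n) if k not in (i, j) and sim[i][k] > sim[i][j])
        ranks.append(rank)
    return ranks


def ranking_metrics(ranks):
    """Same summary statistics as the *_acc_top1, *_acc_top5 and *_acc_mean_pos logs."""
    count = len(ranks)
    return {
        "acc_top1": sum(r == 0 for r in ranks) / count,
        "acc_top5": sum(r < 5 for r in ranks) / count,
        "acc_mean_pos": 1 + sum(ranks) / count,  # 1-based average position
    }


sim = [
    [1.00, 0.10, 0.90, 0.20],
    [0.10, 1.00, 0.30, 0.50],
    [0.90, 0.30, 1.00, 0.95],
    [0.20, 0.50, 0.95, 1.00],
]
ranks = positive_ranks(sim, n=2)
print(ranks, ranking_metrics(ranks))  # [0, 0, 1, 1] {'acc_top1': 0.5, 'acc_top5': 1.0, 'acc_mean_pos': 1.5}
```

Because dividing by a positive temperature does not change the ordering, it does not matter whether the ranks are computed on the raw or the temperature-scaled similarities.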

Instead of also validating on the contrastive learning loss, we could take a simple, small downstream task and track the performance of the base network f(\cdot) on it. In this tutorial, however, we restrict ourselves to the STL10 dataset and use image classification on STL10 as our test task.

Training

Now that we have implemented SimCLR and the data loading pipeline, we are ready to train the model. We will use the same training function setup as usual. For saving the best model checkpoint, we track the metric val_acc_top5, which describes how often the correct image patch is within the top-5 most similar examples in the batch. This is usually less noisy than the top-1 metric, making it a better metric to choose the best model from.

[9]:
def train_simclr(batch_size, max_epochs=500, **kwargs):
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "SimCLR"),
        gpus=1 if str(device) == "cuda:0" else 0,
        max_epochs=max_epochs,
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc_top5"),
            LearningRateMonitor("epoch"),
        ],
        progress_bar_refresh_rate=1,
    )
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "SimCLR.ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        # Automatically loads the model with the saved hyperparameters
        model = SimCLR.load_from_checkpoint(pretrained_filename)
    else:
        train_loader = data.DataLoader(
            unlabeled_data,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            num_workers=NUM_WORKERS,
        )
        val_loader = data.DataLoader(
            train_data_contrast,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            pin_memory=True,
            num_workers=NUM_WORKERS,
        )
        pl.seed_everything(42)  # To be reproducible
        model = SimCLR(max_epochs=max_epochs, **kwargs)
        trainer.fit(model, train_loader, val_loader)
        # Load best checkpoint after training
        model = SimCLR.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    return model

A common observation in contrastive learning is that the larger the batch size, the better the models perform. A larger batch size allows us to compare each image to more negative examples, leading to overall smoother loss gradients. However, in our case, we found that a batch size of 256 was sufficient to get good results.
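For a sense of scale: with two views per image, as in our ContrastiveTransformations setup, each of the 2N augmented views in a batch has exactly one positive (the other view of its image) and 2N - 2 negatives, since the anchor itself is masked out. A small helper, whose name is our own:

```python
def contrastive_batch_sizes(batch_size):
    """Return (rows of the similarity matrix, negatives per anchor) for two views per image."""
    num_views = 2 * batch_size  # 2N augmented views after torch.cat(imgs, dim=0)
    negatives = num_views - 2   # exclude the anchor itself and its positive view
    return num_views, negatives


for batch_size in (64, 256, 1024):
    rows, negatives = contrastive_batch_sizes(batch_size)
    print(f"batch_size={batch_size}: {rows}x{rows} similarity matrix, {negatives} negatives per anchor")
```

The number of negatives grows linearly with the batch size, but the similarity matrix grows quadratically, which is one reason very large batches become expensive in memory.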

[10]:
simclr_model = train_simclr(
    batch_size=256, hidden_dim=128, lr=5e-4, temperature=0.07, weight_decay=1e-4, max_epochs=500
)
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:96: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=1)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Found pretrained model at saved_models/ContrastiveLearning/SimCLR.ckpt, loading...

To get an intuition of how training with contrastive learning behaves, we can take a look at the TensorBoard below:

[11]:
%tensorboard --logdir ../saved_models/tutorial17/tensorboards/SimCLR/

[Figure: TensorBoard training curves of SimCLR]

One thing to note is that contrastive learning benefits a lot from long training. The plot above is from a training run that took approximately one day on an NVIDIA Titan RTX. Training the model for even longer might reduce its loss further, but we did not observe any gains from it on the downstream image classification task. In general, contrastive learning can also benefit from larger models, if sufficient unlabeled data is available.

Logistic Regression

After we have trained our model via contrastive learning, we can deploy it on downstream tasks and see how well it performs with little data. A common setup, which also verifies whether the model has learned generalized representations, is to perform Logistic Regression on the features. In other words, we learn a single, linear layer that maps the representations to a class prediction. Since the base network f(\cdot) is not changed during the training process, the model can only perform well if the representations h describe all features that might be necessary for the task. Further, we do not have to worry too much about overfitting since we train very few parameters. Hence, we might expect that the model can perform well even with very little data.
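To put "very few parameters" into numbers: the linear layer maps the representations h to the 10 STL10 classes, so it has one weight per feature and class plus one bias per class. A quick count, assuming the 512-dimensional features that ResNet-18 produces once its final layer is replaced by an identity (the helper name is hypothetical):

```python
def linear_probe_params(feature_dim, num_classes):
    """Weights (feature_dim x num_classes) plus one bias per class."""
    return feature_dim * num_classes + num_classes


probe_params = linear_probe_params(feature_dim=512, num_classes=10)
labeled_images = 10 * 500  # full STL10 training split: 500 images for each of 10 classes
print(probe_params, labeled_images)  # 5130 5000
```

Even on the full training split, the probe has roughly one trainable parameter per labeled image, while the millions of weights in the frozen backbone stay fixed.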

First, let’s implement a simple Logistic Regression setup for which we assume that the images have already been encoded into their feature vectors. If very little data is available, it might be beneficial to encode the images dynamically during training so that we can also apply data augmentations. However, the way we implement it here is much more efficient and can be trained within a few seconds. Further, using data augmentations did not show any significant gain in this simple setup.

[12]:
class LogisticRegression(pl.LightningModule):
    def __init__(self, feature_dim, num_classes, lr, weight_decay, max_epochs=100):
        super().__init__()
        self.save_hyperparameters()
        # Mapping from representation h to classes
        self.model = nn.Linear(feature_dim, num_classes)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[int(self.hparams.max_epochs * 0.6), int(self.hparams.max_epochs * 0.8)], gamma=0.1
        )
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode="train"):
        feats, labels = batch
        preds = self.model(feats)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        self.log(mode + "_loss", loss)
        self.log(mode + "_acc", acc)
        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")

The data we use is the training and test set of STL10. The training set contains 500 images per class, while the test set has 800 images per class.

[13]:
img_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_img_data = STL10(root=DATASET_PATH, split="train", download=True, transform=img_transforms)
test_img_data = STL10(root=DATASET_PATH, split="test", download=True, transform=img_transforms)

print("Number of training examples:", len(train_img_data))
print("Number of test examples:", len(test_img_data))
Files already downloaded and verified
Files already downloaded and verified
Number of training examples: 5000
Number of test examples: 8000

Next, we implement a small function to encode all images in our datasets. The output representations are then used as inputs to the Logistic Regression model.

[14]:
@torch.no_grad()
def prepare_data_features(model, dataset):
    # Prepare model
    network = deepcopy(model.convnet)
    network.fc = nn.Identity()  # Removing projection head g(.)
    network.eval()
    network.to(device)

    # Encode all images
    data_loader = data.DataLoader(dataset, batch_size=64, num_workers=NUM_WORKERS, shuffle=False, drop_last=False)
    feats, labels = [], []
    for batch_imgs, batch_labels in tqdm(data_loader):
        batch_imgs = batch_imgs.to(device)
        batch_feats = network(batch_imgs)
        feats.append(batch_feats.detach().cpu())
        labels.append(batch_labels)

    feats = torch.cat(feats, dim=0)
    labels = torch.cat(labels, dim=0)

    # Sort images by labels
    labels, idxs = labels.sort()
    feats = feats[idxs]

    return data.TensorDataset(feats, labels)

Let’s apply the function to both training and test set below.

[15]:
train_feats_simclr = prepare_data_features(simclr_model, train_img_data)
test_feats_simclr = prepare_data_features(simclr_model, test_img_data)

Finally, we can write a training function as usual. We evaluate the model on the test set every 10 epochs to allow early stopping, but the low frequency of the validation ensures that we do not overfit too much on the test set.

[16]:
def train_logreg(batch_size, train_feats_data, test_feats_data, model_suffix, max_epochs=100, **kwargs):
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "LogisticRegression"),
        gpus=1 if str(device) == "cuda:0" else 0,
        max_epochs=max_epochs,
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
            LearningRateMonitor("epoch"),
        ],
        progress_bar_refresh_rate=0,
        check_val_every_n_epoch=10,
    )
    trainer.logger._default_hp_metric = None

    # Data loaders
    train_loader = data.DataLoader(
        train_feats_data, batch_size=batch_size, shuffle=True, drop_last=False, pin_memory=True, num_workers=0
    )
    test_loader = data.DataLoader(
        test_feats_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True, num_workers=0
    )

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, f"LogisticRegression_{model_suffix}.ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        model = LogisticRegression.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)  # To be reproducible
        model = LogisticRegression(**kwargs)
        trainer.fit(model, train_loader, test_loader)
        model = LogisticRegression.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Test best model on train and test set
    train_result = trainer.test(model, dataloaders=train_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    result = {"train": train_result[0]["test_acc"], "test": test_result[0]["test_acc"]}

    return model, result

Although the STL10 training set already has only 500 labeled images per class, we will perform experiments with even smaller datasets. Specifically, we train a Logistic Regression model for datasets with only 10, 20, 50, 100, 200, and all 500 examples per class. This gives us an intuition of how well the representations learned by contrastive learning can be transferred to an image recognition task like this classification. First, let’s define a function to create the intended sub-datasets from the full training set:

[17]:
def get_smaller_dataset(original_dataset, num_imgs_per_label):
    new_dataset = data.TensorDataset(
        *(t.unflatten(0, (10, 500))[:, :num_imgs_per_label].flatten(0, 1) for t in original_dataset.tensors)
    )
    return new_dataset
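The one-liner above relies on prepare_data_features having sorted the features by label, and on every class having exactly 500 examples, so that unflatten(0, (10, 500)) puts one class per row. A plain-Python analogue of the same selection, with a hypothetical function name and toy labels:

```python
def first_k_per_class(items, num_classes, per_class, k):
    """Plain-Python analogue of t.unflatten(0, (num_classes, per_class))[:, :k].flatten(0, 1).

    Assumes `items` is sorted by label and every class has exactly `per_class` entries.
    """
    if len(items) != num_classes * per_class:
        raise ValueError("expected exactly num_classes * per_class items")
    subset = []
    for c in range(num_classes):
        block = items[c * per_class:(c + 1) * per_class]  # all entries of class c
        subset.extend(block[:k])                          # keep the first k of them
    return subset


labels = [0] * 4 + [1] * 4 + [2] * 4  # toy stand-in for the sorted STL10 labels
print(first_k_per_class(labels, num_classes=3, per_class=4, k=2))  # [0, 0, 1, 1, 2, 2]
```

If the features were not sorted by label, the reshape would silently mix classes, which is why prepare_data_features sorts them before returning the dataset.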

Next, let’s run all models. Although we train 6 models, this cell runs within a minute or two even without the pretrained models.

[18]:
results = {}
for num_imgs_per_label in [10, 20, 50, 100, 200, 500]:
    sub_train_set = get_smaller_dataset(train_feats_simclr, num_imgs_per_label)
    _, small_set_results = train_logreg(
        batch_size=64,
        train_feats_data=sub_train_set,
        test_feats_data=test_feats_simclr,
        model_suffix=num_imgs_per_label,
        feature_dim=train_feats_simclr.tensors[0].shape[1],
        num_classes=10,
        lr=1e-3,
        weight_decay=1e-3,
    )
    results[num_imgs_per_label] = small_set_results
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:96: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=0)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/ContrastiveLearning/LogisticRegression/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:486: PossibleUserWarning: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test/predict dataloaders.
  rank_zero_warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model at saved_models/ContrastiveLearning/LogisticRegression_10.ckpt, loading...
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model at saved_models/ContrastiveLearning/LogisticRegression_20.ckpt, loading...
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model at saved_models/ContrastiveLearning/LogisticRegression_50.ckpt, loading...
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model at saved_models/ContrastiveLearning/LogisticRegression_100.ckpt, loading...
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model at saved_models/ContrastiveLearning/LogisticRegression_200.ckpt, loading...
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Found pretrained model at saved_models/ContrastiveLearning/LogisticRegression_500.ckpt, loading...

Finally, let’s plot the results.

[19]:
dataset_sizes = sorted(k for k in results)
test_scores = [results[k]["test"] for k in dataset_sizes]

fig = plt.figure(figsize=(6, 4))
plt.plot(
    dataset_sizes,
    test_scores,
    "--",
    color="#000",
    marker="*",
    markeredgecolor="#000",
    markerfacecolor="y",
    markersize=16,
)
plt.xscale("log")
plt.xticks(dataset_sizes, labels=dataset_sizes)
plt.title("STL10 classification over dataset size", fontsize=14)
plt.xlabel("Number of images per class")
plt.ylabel("Test accuracy")
plt.minorticks_off()
plt.show()

for k, score in zip(dataset_sizes, test_scores):
    print(f"Test accuracy for {k:3d} images per label: {100*score:4.2f}%")
[Figure: STL10 classification over dataset size]
Test accuracy for  10 images per label: 62.79%
Test accuracy for  20 images per label: 68.60%
Test accuracy for  50 images per label: 74.44%
Test accuracy for 100 images per label: 77.20%
Test accuracy for 200 images per label: 79.06%
Test accuracy for 500 images per label: 81.33%

As one would expect, the classification performance improves the more data we have. However, with only 10 images per class, we can already classify more than 60% of the images correctly. This is quite impressive, considering that the images are also higher dimensional than e.g. CIFAR10. With the full dataset, we achieve an accuracy of 81%. The increase from 50 to 500 images per class might suggest a linear increase in performance with an exponentially larger dataset. However, with even more data, we could also finetune f(\cdot) in the training process, allowing the representations to adapt more to the specific classification task.
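To check the log-linear reading quantitatively, the sketch below fits test accuracy against log2 of the number of images per class, using only the four results printed above for 50 to 500 images per class. It is a plain ordinary least-squares fit written by hand; the function name is our own.

```python
import math

# Test accuracies (%) printed above for 50 to 500 labeled images per class
acc_by_size = {50: 74.44, 100: 77.20, 200: 79.06, 500: 81.33}


def fit_log2_linear(accuracies):
    """Ordinary least-squares fit of accuracy = intercept + slope * log2(images per class)."""
    xs = [math.log2(size) for size in accuracies]
    ys = [accuracies[size] for size in accuracies]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    slope = cov / var
    intercept = mean_y - slope * mean_x
    residuals = [y - (intercept + slope * x) for x, y in zip(xs, ys)]
    return slope, intercept, residuals


slope, intercept, residuals = fit_log2_linear(acc_by_size)
print(f"{slope:.2f} points per doubling, max residual {max(abs(r) for r in residuals):.2f}")
```

On these four numbers, the slope comes out at roughly two percentage points per doubling of the labeled set, with residuals below half a percentage point. Four points are far too few to extrapolate from, but they are consistent with the log-linear trend.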

To set the results above into perspective, we will train the base network, a ResNet-18, on the classification task from scratch.

Baseline

As a baseline to our results above, we will train a standard ResNet-18 with random initialization on the labeled training set of STL10. The results will give us an indication of the advantages that contrastive learning on unlabeled data has compared to using only supervised training. The implementation of the model is straightforward since the ResNet architecture is provided in the torchvision library.

[20]:
class ResNet(pl.LightningModule):
    def __init__(self, num_classes, lr, weight_decay, max_epochs=100):
        super().__init__()
        self.save_hyperparameters()
        self.model = torchvision.models.resnet18(pretrained=False, num_classes=num_classes)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[int(self.hparams.max_epochs * 0.7), int(self.hparams.max_epochs * 0.9)], gamma=0.1
        )
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode="train"):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        self.log(mode + "_loss", loss)
        self.log(mode + "_acc", acc)
        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")

It is clear that the ResNet easily overfits the training data since its parameter count is more than 1000 times larger than the dataset size. To make the comparison to the contrastive learning models fair, we apply data augmentations similar to the ones we used before: horizontal flip, crop-and-resize, grayscale, and Gaussian blur. We do not use the color distortions from before, because the color distribution of an image turned out to be an important feature for the classification, and we observed no noticeable performance gains when adding color distortions to the set of augmentations. Similarly, we restrict the resizing operation before cropping to at most 125% of the original resolution, instead of 1250% as done in SimCLR. This is because, for classification, the model needs to recognize the full object, while in contrastive learning, we only want to check whether two patches belong to the same image/object. Hence, the augmentations chosen below are overall weaker than in the contrastive learning case.
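The 125% and 1250% figures follow from the scale argument of RandomResizedCrop, which samples a crop covering a fraction of the image area and resizes it back to the 96x96 output, the same size as the STL10 images. The sketch below assumes the contrastive pipeline used the torchvision default scale=(0.08, 1.0), which is what the 1250% figure implies, while the baseline uses scale=(0.8, 1.0) as in the code below; the helper name is hypothetical.

```python
def max_area_upscale(min_scale):
    """Largest enlargement (in area) applied by a crop-and-resize that keeps at least
    `min_scale` of the image area and resizes back to the original image size."""
    return 1.0 / min_scale


for name, min_scale in [("baseline", 0.8), ("SimCLR", 0.08)]:
    print(f"{name}: up to {100 * max_area_upscale(min_scale):.0f}% of the original resolution")
```

The smallest allowed crop is the one that gets enlarged the most, so the lower bound of scale alone determines how aggressive the zoom can be.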

[21]:
train_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(size=96, scale=(0.8, 1.0)),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 0.5)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

train_img_aug_data = STL10(root=DATASET_PATH, split="train", download=True, transform=train_transforms)
Files already downloaded and verified

The training function for the ResNet is almost identical to the Logistic Regression setup. Note that we allow the ResNet to perform validation every 2 epochs to also check whether the model overfits strongly in the first iterations or not.

[22]:
def train_resnet(batch_size, max_epochs=100, **kwargs):
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "ResNet"),
        gpus=1 if str(device) == "cuda:0" else 0,
        max_epochs=max_epochs,
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
            LearningRateMonitor("epoch"),
        ],
        progress_bar_refresh_rate=1,
        check_val_every_n_epoch=2,
    )
    trainer.logger._default_hp_metric = None

    # Data loaders
    train_loader = data.DataLoader(
        train_img_aug_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=NUM_WORKERS,
    )
    test_loader = data.DataLoader(
        test_img_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True, num_workers=NUM_WORKERS
    )

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "ResNet.ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model at %s, loading..." % pretrained_filename)
        model = ResNet.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)  # To be reproducible
        model = ResNet(**kwargs)
        trainer.fit(model, train_loader, test_loader)
        model = ResNet.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Test best model on train and test set
    train_result = trainer.test(model, dataloaders=train_loader, verbose=False)
    val_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    result = {"train": train_result[0]["test_acc"], "test": val_result[0]["test_acc"]}

    return model, result

Finally, let’s train the model and check its results:

[23]:
resnet_model, resnet_result = train_resnet(batch_size=64, num_classes=10, lr=1e-3, weight_decay=2e-4, max_epochs=100)
print(f"Accuracy on training set: {100*resnet_result['train']:4.2f}%")
print(f"Accuracy on test set: {100*resnet_result['test']:4.2f}%")
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:96: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=1)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/ContrastiveLearning/ResNet/lightning_logs
Found pretrained model at saved_models/ContrastiveLearning/ResNet.ckpt, loading...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:486: PossibleUserWarning: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test/predict dataloaders.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Accuracy on training set: 99.70%
Accuracy on test set: 73.31%

The ResNet trained from scratch achieves 73.31% on the test set. This is about 8 percentage points lower than the contrastive learning model, and even slightly lower than what SimCLR achieves with 1/10 of the data. This shows that self-supervised, contrastive learning provides considerable performance gains by leveraging large amounts of unlabeled data when little labeled data is available.

Conclusion

In this tutorial, we have discussed self-supervised contrastive learning and implemented SimCLR as an example method. We applied it to the STL10 dataset and showed that it can learn generalizable representations that we can use to train simple classification models. With 500 images per label, it achieved an accuracy about 8 percentage points higher than a similar model trained solely with supervision, and it performs on par with that model when using only a tenth of the labeled data. Our experimental results are limited to a single dataset, but recent works such as Ting Chen et al. showed similar trends for larger datasets like ImageNet. Besides the discussed hyperparameters, the size of the model also seems to be important in contrastive learning. If a lot of unlabeled data is available, larger models can achieve much stronger results and come close to their supervised baselines. Further, there are also approaches that combine contrastive and supervised learning, leading to performance gains beyond supervision (see Khosla et al. [5]). Moreover, contrastive learning is not the only approach to self-supervised learning that has emerged in the last two years with great results. Other methods include distillation-based methods like BYOL [4] and redundancy reduction techniques like Barlow Twins [6]. There is a lot more to explore in the self-supervised domain, and more impressive progress is to be expected.

References

[1] Chen, T., Kornblith, S., Norouzi, M., and Hinton, G. (2020). A simple framework for contrastive learning of visual representations. In International Conference on Machine Learning (pp. 1597-1607). PMLR.

[2] Chen, T., Kornblith, S., Swersky, K., Norouzi, M., and Hinton, G. (2020). Big self-supervised models are strong semi-supervised learners. NeurIPS 2020.

[3] Oord, A. V. D., Li, Y., and Vinyals, O. (2018). Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748.

[4] Grill, J.B., Strub, F., Altché, F., Tallec, C., Richemond, P.H., Buchatskaya, E., Doersch, C., Pires, B.A., Guo, Z.D., Azar, M.G. and Piot, B. (2020). Bootstrap your own latent: A new approach to self-supervised learning. arXiv preprint arXiv:2006.07733.

[5] Khosla, P., Teterwak, P., Wang, C., Sarna, A., Tian, Y., Isola, P., Maschinot, A., Liu, C. and Krishnan, D. (2020). Supervised contrastive learning. arXiv preprint arXiv:2004.11362.

[6] Zbontar, J., Jing, L., Misra, I., LeCun, Y. and Deny, S. (2021). Barlow twins: Self-supervised learning via redundancy reduction. arXiv preprint arXiv:2103.03230.

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Lightning Bolts GitHub Issues page and filter for “good first issue”.

Many thanks from the entire PyTorch Lightning team for your interest!

PyTorch Lightning

GPU and batched data augmentation with Kornia and PyTorch-Lightning

  • Author: PL/Kornia team

  • License: CC BY-SA

  • Generated: 2023-01-03T14:46:27.309679

In this tutorial we will show how to combine both Kornia and PyTorch Lightning to perform efficient data augmentation to train a simple model using the GPU in batch mode without additional effort.


Give us a ⭐ on Github | Check out the documentation | Join us on Slack

Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "pytorch-lightning>=1.4" "pytorch-lightning !=1.8.0, !=1.8.0.post1" "torchmetrics>=0.7" "matplotlib" "torchmetrics" "seaborn" "ipython[notebook]" "kornia" "torch>=1.8" "pandas" "setuptools==59.5.0" "torchvision"
[2]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import torch.nn as nn
import torchmetrics
import torchvision
from IPython.display import display
from kornia import image_to_tensor, tensor_to_image
from kornia.augmentation import ColorJitter, RandomChannelShuffle, RandomHorizontalFlip, RandomThinPlateSpline
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch import Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

sn.set()

Define a Data Augmentation module

Kornia is a low-level computer vision library that provides a dedicated module, kornia.augmentation (https://kornia.readthedocs.io/en/latest/augmentation.html), implementing an extensive set of data augmentation techniques for images and videos.

As in Lightning, Kornia encourages encapsulating functionality inside classes for readability and efficiency. Here, we define a data augmentation pipeline by subclassing nn.Module, where the Kornia augmentations (which also subclass nn.Module) are combined with other PyTorch components such as nn.Sequential.

Check out the different augmentation operators in the Kornia docs and experiment yourself!

[3]:
class DataAugmentation(nn.Module):
    """Module to perform data augmentation using Kornia on torch tensors."""

    def __init__(self, apply_color_jitter: bool = False) -> None:
        super().__init__()
        self._apply_color_jitter = apply_color_jitter

        self.transforms = nn.Sequential(
            RandomHorizontalFlip(p=0.75),
            RandomChannelShuffle(p=0.75),
            RandomThinPlateSpline(p=0.75),
        )

        self.jitter = ColorJitter(0.5, 0.5, 0.5, 0.5)

    @torch.no_grad()  # disable gradients for efficiency
    def forward(self, x: Tensor) -> Tensor:
        x_out = self.transforms(x)  # BxCxHxW
        if self._apply_color_jitter:
            x_out = self.jitter(x_out)
        return x_out
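To make the idea of batched augmentation concrete, the sketch below implements a per-sample random horizontal flip in plain PyTorch: it draws one random decision per image and applies it to the whole BxCxHxW batch in a single vectorized call, on whatever device the batch lives on. It is only an illustration of the principle, not Kornia's implementation of RandomHorizontalFlip, and the helper name random_hflip_batch is hypothetical.

```python
import torch
from torch import Tensor


def random_hflip_batch(x: Tensor, p: float = 0.75) -> Tensor:
    """Flip each image of a BxCxHxW batch left-to-right with probability p (hypothetical helper)."""
    batch_size = x.shape[0]
    # One random draw per sample, created on the same device as the batch
    flip_mask = torch.rand(batch_size, device=x.device) < p
    # Flip the width axis of the whole batch at once
    flipped = torch.flip(x, dims=[-1])
    # Broadcast the per-sample mask over C, H, W and pick flipped or original pixels
    return torch.where(flip_mask.view(-1, 1, 1, 1), flipped, x)


x = torch.arange(2 * 1 * 2 * 3, dtype=torch.float32).view(2, 1, 2, 3)
out = random_hflip_batch(x, p=0.5)
print(out.shape)  # torch.Size([2, 1, 2, 3])
```

Because every step is a tensor operation, the same function runs unchanged on a CUDA batch, which is what lets augmentation pipelines like DataAugmentation run on the GPU.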

Define a Pre-processing module

In addition to the DataAugmentation module, which samples random parameters during the training stage, we define a Preprocess class that converts the images into tensors.

For this example we use torchvision's CIFAR10, which returns samples as PIL.Image; however, to take full advantage of PyTorch and Kornia we need to cast the images into tensors.

To do that we use kornia.image_to_tensor, which casts the images and permutes their dimensions into the right format.

[4]:
class Preprocess(nn.Module):
    """Module to perform pre-process using Kornia on torch tensors."""

    @torch.no_grad()  # disable gradients for efficiency
    def forward(self, x) -> Tensor:
        x_tmp: np.ndarray = np.array(x)  # HxWxC
        x_out: Tensor = image_to_tensor(x_tmp, keepdim=True)  # CxHxW
        return x_out.float() / 255.0
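For readers who want to see what the conversion in Preprocess amounts to, here is a rough plain-PyTorch equivalent, under the assumption that the input is an HxWxC uint8 array such as a converted CIFAR10 PIL.Image: permute to CxHxW, cast to float and rescale to [0, 1]. The helper to_chw_float is hypothetical and not part of Kornia.

```python
import numpy as np
import torch
from torch import Tensor


def to_chw_float(image: np.ndarray) -> Tensor:
    """Convert an HxWxC uint8 image to a CxHxW float tensor in [0, 1] (hypothetical helper)."""
    tensor = torch.from_numpy(np.ascontiguousarray(image))  # HxWxC, dtype uint8
    tensor = tensor.permute(2, 0, 1)  # CxHxW, the layout PyTorch and Kornia expect
    return tensor.float() / 255.0


# A random 32x32 RGB image standing in for one CIFAR10 sample
fake_image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
chw = to_chw_float(fake_image)
print(chw.shape, chw.dtype)  # torch.Size([3, 32, 32]) torch.float32
```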

Define PyTorch Lightning model

The next step is to define our LightningModule to properly organise our training pipeline. This is a simple example that only shows how to structure a baseline to be used as a reference; do not expect high performance.

Notice that the Preprocess class is injected into the dataset and will be applied per sample.

The interesting part of this approach happens inside the on_after_batch_transfer hook, where a single line of code applies the data augmentation to the whole batch, and only during training. There is no need to worry about the device: Lightning calls this hook after the batch has been moved to the target device, so our DataAugmentation pipeline automatically runs on the GPU.

[5]:
class CoolSystem(LightningModule):
    def __init__(self):
        super().__init__()
        # not the best model: experiment yourself
        self.model = torchvision.models.resnet18(pretrained=True)
        self.preprocess = Preprocess()  # per-sample transforms
        self.transform = DataAugmentation()  # per-batch Kornia augmentation
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()

    def forward(self, x):
        return self.model(x)

    def compute_loss(self, y_hat, y):
        return F.cross_entropy(y_hat, y)

    def show_batch(self, win_size=(10, 10)):
        def _to_vis(data):
            return tensor_to_image(torchvision.utils.make_grid(data, nrow=8))

        # get a batch from the training set: try with `val_dataloader` :)
        imgs, labels = next(iter(self.train_dataloader()))
        imgs_aug = self.transform(imgs)  # apply transforms
        # use matplotlib to visualize
        plt.figure(figsize=win_size)
        plt.imshow(_to_vis(imgs))
        plt.figure(figsize=win_size)
        plt.imshow(_to_vis(imgs_aug))

    def on_after_batch_transfer(self, batch, dataloader_idx):
        x, y = batch
        if self.trainer.training:
            x = self.transform(x)  # => we perform GPU/Batched data augmentation
        return x, y

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.compute_loss(y_hat, y)
        self.train_accuracy.update(y_hat, y)
        self.log("train_loss", loss, prog_bar=False)
        self.log("train_acc", self.train_accuracy, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.compute_loss(y_hat, y)
        self.val_accuracy.update(y_hat, y)
        self.log("valid_loss", loss, prog_bar=False)
        self.log("valid_acc", self.val_accuracy, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.trainer.max_epochs, 0)
        return [optimizer], [scheduler]

    def prepare_data(self):
        CIFAR10(os.getcwd(), train=True, download=True, transform=self.preprocess)
        CIFAR10(os.getcwd(), train=False, download=True, transform=self.preprocess)

    def train_dataloader(self):
        dataset = CIFAR10(os.getcwd(), train=True, download=True, transform=self.preprocess)
        loader = DataLoader(dataset, batch_size=32)
        return loader

    def val_dataloader(self):
        dataset = CIFAR10(os.getcwd(), train=False, download=True, transform=self.preprocess)
        loader = DataLoader(dataset, batch_size=32)
        return loader
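configure_optimizers pairs AdamW with CosineAnnealingLR(optimizer, self.trainer.max_epochs, 0), i.e. T_max equal to the number of epochs and a minimum learning rate of 0. The sketch below, assuming the max_epochs=10 used by the Trainer in this notebook, steps the scheduler once per simulated epoch and compares the learning rates with the closed-form cosine schedule lr_t = lr_0 * (1 + cos(pi * t / T_max)) / 2. The single dummy parameter only exists so that the optimizer can be constructed.

```python
import math

import torch

base_lr, t_max = 1e-4, 10  # lr from configure_optimizers, max_epochs from the Trainer cell
param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter, never updated (no gradient)
optimizer = torch.optim.AdamW([param], lr=base_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, t_max, 0)

scheduled = []
for epoch in range(t_max):
    scheduled.append(optimizer.param_groups[0]["lr"])  # lr used during this epoch
    optimizer.step()  # stands in for one epoch of training
    scheduler.step()  # epoch-level schedulers are stepped once per epoch

# Closed-form cosine annealing with eta_min = 0
closed_form = [base_lr * (1 + math.cos(math.pi * t / t_max)) / 2 for t in range(t_max)]
for t in (0, 5, 9):
    print(t, f"{scheduled[t]:.3e}", f"{closed_form[t]:.3e}")
```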

Visualize images

[6]:
# init model
model = CoolSystem()
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/AzDevOps_azpcontainer/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
[7]:
model.show_batch(win_size=(14, 14))
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /__w/1/s/cifar-10-python.tar.gz
Extracting /__w/1/s/cifar-10-python.tar.gz to /__w/1/s
[Figure: a batch of CIFAR10 training images before augmentation]
[Figure: the same batch after the Kornia DataAugmentation pipeline]

Run training

[8]:
# Initialize a trainer
trainer = Trainer(
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
    max_epochs=10,
    logger=CSVLogger(save_dir="logs/"),
)

# Train the model ⚡
trainer.fit(model)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Files already downloaded and verified
Files already downloaded and verified
Missing logger folder: logs/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name           | Type             | Params
----------------------------------------------------
0 | model          | ResNet           | 11.7 M
1 | preprocess     | Preprocess       | 0
2 | transform      | DataAugmentation | 0
3 | train_accuracy | Accuracy         | 0
4 | val_accuracy   | Accuracy         | 0
----------------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.758    Total estimated model params size (MB)
Files already downloaded and verified
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Files already downloaded and verified
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=10` reached.
Visualize the training results
[9]:
metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
del metrics["step"]
metrics.set_index("epoch", inplace=True)
display(metrics.dropna(axis=1, how="all").head())
sn.relplot(data=metrics, kind="line")
       train_loss  valid_loss  valid_acc
epoch
0        4.545157         NaN        NaN
0        2.842391         NaN        NaN
0        2.503857         NaN        NaN
0        2.339133         NaN        NaN
0        2.238966         NaN        NaN
[9]:
<seaborn.axisgrid.FacetGrid at 0x7fa60c10f6a0>
[Figure: line plot of the logged training and validation metrics]
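The NaN entries in the table above come from how the metrics are logged: train_loss is logged per training step, while valid_loss and valid_acc are logged once per epoch, so each row of metrics.csv only fills some of the columns. One possible way to get a single row per epoch is to average within each epoch, since pandas' mean skips NaN values by default. The tiny DataFrame below uses made-up numbers purely to show the mechanics; the same idea could be applied to the real metrics table.

```python
import numpy as np
import pandas as pd

# Made-up rows mimicking metrics.csv: step-level train rows plus one validation row per epoch
metrics = pd.DataFrame(
    {
        "epoch": [0, 0, 0, 1, 1, 1],
        "train_loss": [2.5, 2.3, np.nan, 1.9, 1.7, np.nan],
        "valid_loss": [np.nan, np.nan, 2.0, np.nan, np.nan, 1.6],
        "valid_acc": [np.nan, np.nan, 0.40, np.nan, np.nan, 0.55],
    }
)

# mean() ignores NaN, so each epoch collapses to one fully populated row
per_epoch = metrics.groupby("epoch").mean()
print(per_epoch)
```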


Barlow Twins Tutorial

This notebook describes the self-supervised learning method Barlow Twins. Barlow Twins differs from other recently proposed algorithms as it doesn’t fall under the category of either contrastive learning, or methods like knowledge distillation or clustering. The simplicity of the loss function and its effectiveness in comparison to the current state of the art makes Barlow Twins an interesting case study.


Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "setuptools==59.5.0" "matplotlib" "ipython[notebook]" "torch>=1.8" "torchvision" "torchmetrics>=0.7" "pytorch-lightning>=1.4"

Barlow Twins

Barlow Twins occupies a unique place among current state-of-the-art self-supervised learning methods. It does not fall under the existing categories of contrastive learning, knowledge distillation or clustering-based methods. Instead, it creates its own category of redundancy reduction and achieves competitive performance with a simple yet effective loss function. In this tutorial, we code up a small version of the Barlow Twins algorithm using PyTorch Lightning.

[2]:
from functools import partial
from typing import Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as VisionF
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import Tensor
from torch.utils.data import DataLoader
from torchmetrics.functional import accuracy
from torchvision.datasets import CIFAR10
from torchvision.models.resnet import resnet18
from torchvision.utils import make_grid

batch_size = 32
num_workers = 0  # to run notebook on CPU
max_epochs = 200
z_dim = 128
Transforms

We first define the data augmentation pipeline used in Barlow Twins. Here, we use the pipeline proposed in SimCLR, which generates two views of an input image by applying the following transformations in sequence.

First, it takes a random crop of the image and resizes it to a fixed, pre-specified size. Then, it applies a random left-to-right flip with a probability of 0.5. This is followed by color jitter (applied with a probability of 0.8), conversion to grayscale with a probability of 0.2 and, optionally, a Gaussian blur filter (applied with a probability of 0.5). Finally, we convert the image to a tensor and normalize it.

Within this transform, we also add a third view for our online finetuner, which we explain later on. In short, this additional transform lets us test our encoder on a downstream classification task.

[3]:
class BarlowTwinsTransform:
    def __init__(self, train=True, input_height=224, gaussian_blur=True, jitter_strength=1.0, normalize=None):
        self.input_height = input_height
        self.gaussian_blur = gaussian_blur
        self.jitter_strength = jitter_strength
        self.normalize = normalize
        self.train = train

        color_jitter = transforms.ColorJitter(
            0.8 * self.jitter_strength,
            0.8 * self.jitter_strength,
            0.8 * self.jitter_strength,
            0.2 * self.jitter_strength,
        )

        color_transform = [transforms.RandomApply([color_jitter], p=0.8), transforms.RandomGrayscale(p=0.2)]

        if self.gaussian_blur:
            kernel_size = int(0.1 * self.input_height)
            if kernel_size % 2 == 0:
                kernel_size += 1

            color_transform.append(transforms.RandomApply([transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5))

        self.color_transform = transforms.Compose(color_transform)

        if normalize is None:
            self.final_transform = transforms.ToTensor()
        else:
            self.final_transform = transforms.Compose([transforms.ToTensor(), normalize])

        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(self.input_height),
                transforms.RandomHorizontalFlip(p=0.5),
                self.color_transform,
                self.final_transform,
            ]
        )

        self.finetune_transform = None
        if self.train:
            self.finetune_transform = transforms.Compose(
                [
                    transforms.RandomCrop(32, padding=4, padding_mode="reflect"),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ]
            )
        else:
            self.finetune_transform = transforms.ToTensor()

    def __call__(self, sample):
        return self.transform(sample), self.transform(sample), self.finetune_transform(sample)
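The Gaussian blur kernel size in BarlowTwinsTransform is derived from the input height: 10% of it, rounded down and bumped to the next odd number, since torchvision's GaussianBlur only accepts odd kernel sizes. A quick check of that arithmetic for the default input_height=224 and for CIFAR10's 32x32 images (blur_kernel_size is a hypothetical helper that mirrors the rule in the class):

```python
def blur_kernel_size(input_height: int) -> int:
    """Mirror the kernel-size rule inside BarlowTwinsTransform (hypothetical helper)."""
    kernel_size = int(0.1 * input_height)
    if kernel_size % 2 == 0:  # GaussianBlur needs an odd kernel size
        kernel_size += 1
    return kernel_size


for height in (224, 32):
    print(height, blur_kernel_size(height))
```

For CIFAR10 the kernel would be only 3 pixels wide, and this tutorial disables the blur for CIFAR10 anyway (gaussian_blur=False).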
Dataset

We select CIFAR10 as the dataset to demonstrate the pre-training process for Barlow Twins. CIFAR10 images are 32x32 in size and we do not apply a Gaussian blur transformation on them. In this step, we create the training and validation dataloaders for CIFAR10.

[4]:
def cifar10_normalization():
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
    )
    return normalize


train_transform = BarlowTwinsTransform(
    train=True, input_height=32, gaussian_blur=False, jitter_strength=0.5, normalize=cifar10_normalization()
)
train_dataset = CIFAR10(root=".", train=True, download=True, transform=train_transform)

val_transform = BarlowTwinsTransform(
    train=False, input_height=32, gaussian_blur=False, jitter_strength=0.5, normalize=cifar10_normalization()
)
val_dataset = CIFAR10(root=".", train=False, download=True, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz
Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified
Plot images

To see how the CIFAR10 images look after the data augmentation pipeline, we load a few images from the dataloader and plot them here.

[5]:
for batch in val_loader:
    (img1, img2, _), label = batch
    break

img_grid = make_grid(img1, normalize=True)


def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = VisionF.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


show(img_grid)
[Figure: grid of augmented CIFAR10 images from a validation batch (first view)]
Barlow Twins Loss

Here we define the loss function for Barlow Twins. It first normalizes each of the D dimensions of the projection-head outputs across the batch (zero mean, unit standard deviation) and then computes the DxD cross-correlation matrix between the normalized vectors of the two views of each image.

The loss then treats this cross-correlation matrix in two parts. The first part pushes the diagonal elements of the matrix towards 1, which increases the correlation between matching dimensions of the latent vectors of the two views of each image and thus makes the backbone invariant to the transformations applied to the views. The second part pushes the off-diagonal elements of the cross-correlation matrix towards 0, which reduces the redundancy between the different dimensions of the latent vector.
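For reference, the loss computed by BarlowTwinsLoss below can be written as follows, where N is the batch size, D the projection dimension (z_dim), λ corresponds to lambda_coeff, and \hat{z}^{(1)}, \hat{z}^{(2)} denote the per-dimension standardized embeddings of the two views (the code uses torch.std, i.e. the unbiased standard deviation, for σ):

```latex
\hat{z}^{(k)}_{b,i} = \frac{z^{(k)}_{b,i} - \mu^{(k)}_{i}}{\sigma^{(k)}_{i}},
\qquad
\mathcal{C}_{ij} = \frac{1}{N} \sum_{b=1}^{N} \hat{z}^{(1)}_{b,i}\, \hat{z}^{(2)}_{b,j}

\mathcal{L}_{\mathrm{BT}} = \sum_{i=1}^{D} \left(1 - \mathcal{C}_{ii}\right)^{2}
  + \lambda \sum_{i=1}^{D} \sum_{j \neq i} \mathcal{C}_{ij}^{2}
```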

[6]:
class BarlowTwinsLoss(nn.Module):
    def __init__(self, batch_size, lambda_coeff=5e-3, z_dim=128):
        super().__init__()

        self.z_dim = z_dim
        self.batch_size = batch_size
        self.lambda_coeff = lambda_coeff

    def off_diagonal_ele(self, x):
        # taken from: https://github.com/facebookresearch/barlowtwins/blob/main/main.py
        # return a flattened view of the off-diagonal elements of a square matrix
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

    def forward(self, z1, z2):
        # N x D, where N is the batch size and D is output dim of projection head
        z1_norm = (z1 - torch.mean(z1, dim=0)) / torch.std(z1, dim=0)
        z2_norm = (z2 - torch.mean(z2, dim=0)) / torch.std(z2, dim=0)

        cross_corr = torch.matmul(z1_norm.T, z2_norm) / self.batch_size

        on_diag = torch.diagonal(cross_corr).add_(-1).pow_(2).sum()
        off_diag = self.off_diagonal_ele(cross_corr).pow_(2).sum()

        return on_diag + self.lambda_coeff * off_diag
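The flatten/view trick in off_diagonal_ele can look cryptic. After dropping the last element, an nxn matrix has (n - 1)(n + 1) entries, and viewing it as (n - 1, n + 1) places every remaining diagonal entry in column 0, so slicing off that column keeps exactly the off-diagonal values. The sketch below checks this against an explicit boolean mask and compares an out-of-place copy of the loss with a masked reference version; the function names are illustrative, and batch_size=32 and z_dim=128 mirror this notebook's settings.

```python
import torch


def off_diagonal(x: torch.Tensor) -> torch.Tensor:
    # Same trick as BarlowTwinsLoss.off_diagonal_ele: after dropping the last element,
    # viewing as (n - 1, n + 1) puts every remaining diagonal entry in column 0
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


def standardize(z: torch.Tensor) -> torch.Tensor:
    # Per-dimension normalization across the batch, as in BarlowTwinsLoss.forward
    return (z - z.mean(dim=0)) / z.std(dim=0)


def barlow_twins_loss(z1, z2, lambda_coeff=5e-3):
    # Out-of-place copy of the loss above, using the off_diagonal trick
    c = standardize(z1).T @ standardize(z2) / z1.shape[0]
    return (torch.diagonal(c) - 1).pow(2).sum() + lambda_coeff * off_diagonal(c).pow(2).sum()


def barlow_twins_loss_masked(z1, z2, lambda_coeff=5e-3):
    # Reference version selecting diagonal and off-diagonal entries with an explicit mask
    c = standardize(z1).T @ standardize(z2) / z1.shape[0]
    eye = torch.eye(c.shape[0], dtype=torch.bool)
    return ((c[eye] - 1) ** 2).sum() + lambda_coeff * (c[~eye] ** 2).sum()


x = torch.arange(16.0).view(4, 4)
print(off_diagonal(x))  # the 12 off-diagonal values: 1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14

torch.manual_seed(0)
z1, z2 = torch.randn(32, 128), torch.randn(32, 128)  # batch_size=32, z_dim=128
print(barlow_twins_loss(z1, z2).item(), barlow_twins_loss_masked(z1, z2).item())
```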
Backbone

This is a standard ResNet backbone that we pre-train using the Barlow Twins method. To accommodate the 32x32 CIFAR10 images, we replace the first 7x7 convolution of the ResNet backbone with a 3x3 convolution with stride 1. We also effectively remove the first max-pooling layer by replacing it with a 1x1 pooling of stride 1, which leaves its input unchanged.

[7]:
encoder = resnet18()

# for CIFAR10, replace the first 7x7 conv with smaller 3x3 conv and remove the first maxpool
encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
encoder.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)

# replace classification fc layer of Resnet to obtain representations from the backbone
encoder.fc = nn.Identity()
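To see why the stem change matters, the sketch traces the spatial size of a 32x32 input through ResNet-18 using the standard output-size formula floor((H + 2p - k) / s) + 1. It assumes torchvision's ResNet-18 layout (7x7 stride-2 conv1, 3x3 stride-2 max pool, and stride-2 3x3 convolutions at the start of layer2, layer3 and layer4); the helper functions are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a convolution or pooling layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_final_size(size: int, cifar_stem: bool) -> int:
    """Trace the spatial size through the ResNet-18 stem and its four stages (illustrative)."""
    if cifar_stem:
        size = conv_out(size, kernel=3, stride=1, padding=1)  # replaced conv1: 3x3, stride 1
        size = conv_out(size, kernel=1, stride=1, padding=0)  # replaced maxpool: 1x1, stride 1
    else:
        size = conv_out(size, kernel=7, stride=2, padding=3)  # original conv1: 7x7, stride 2
        size = conv_out(size, kernel=3, stride=2, padding=1)  # original maxpool: 3x3, stride 2
    # layer1 keeps the resolution; layer2 to layer4 each start with a 3x3, stride-2 conv
    for _ in range(3):
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


print(resnet18_final_size(32, cifar_stem=False))  # 1
print(resnet18_final_size(32, cifar_stem=True))  # 4
```

With the original stem, the final feature map is only 1x1 before global average pooling; with the CIFAR10 stem it is 4x4. In both cases, global average pooling followed by the nn.Identity() head yields a 512-dimensional representation, which is why encoder_out_dim is set to 512 later on.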
Projection head

Unlike SimCLR and BYOL, the downstream performance of Barlow Twins greatly benefits from a larger projection head after the backbone network. The paper uses a 3-layer MLP with 8192 hidden dimensions and 8192 as the output dimension of the projection head. For the purposes of this tutorial, we use a smaller projection head. Note, however, that in practice Barlow Twins should be trained with a bigger projection head, as it is highly sensitive to the head's architecture and output dimensionality.

[8]:
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
        super().__init__()

        self.projection_head = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=True),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim, bias=False),
        )

    def forward(self, x):
        return self.projection_head(x)
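In this notebook, ProjectionHead is instantiated with input_dim=hidden_dim=512 (encoder_out_dim) and output_dim=z_dim=128. The sketch below counts its trainable parameters with a small formula, cross-checks the formula against an equivalent nn.Sequential, and then evaluates it for a hypothetical head with the same two-layer structure but 8192 hidden and output units. That wider variant is only a size comparison, not the paper's three-layer head.

```python
import torch.nn as nn


def projection_head_params(input_dim: int, hidden_dim: int, output_dim: int) -> int:
    """Trainable parameters of the two-layer ProjectionHead defined above."""
    first_linear = input_dim * hidden_dim + hidden_dim  # weights + bias
    batch_norm = 2 * hidden_dim  # affine weight and bias (running stats are buffers)
    second_linear = hidden_dim * output_dim  # bias=False
    return first_linear + batch_norm + second_linear


# Cross-check the formula against a module with the same layers and the tutorial's sizes
head = nn.Sequential(
    nn.Linear(512, 512, bias=True),
    nn.BatchNorm1d(512),
    nn.ReLU(),
    nn.Linear(512, 128, bias=False),
)
print(sum(p.numel() for p in head.parameters()))  # 329216
print(projection_head_params(512, 512, 128))  # 329216
# Hypothetical wider head with the same two-layer structure but 8192 units
print(projection_head_params(512, 8192, 8192))  # 71327744
```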
Learning rate warmup

For the purposes of this tutorial, we keep things simple and use a linear warmup schedule with the Adam optimizer: the learning rate ramps up linearly during the warmup steps and then stays constant (no cosine decay is applied here). In our previous experiments, we found that the linear warmup part is much more important for the final performance of a model than the cosine decay component of the schedule.

[9]:
def fn(warmup_steps, step):
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    else:
        return 1.0


def linear_warmup_decay(warmup_steps):
    return partial(fn, warmup_steps)
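As a worked example of this schedule with the settings used below: CIFAR10 has 50,000 training images, so with batch_size=32 BarlowTwins computes train_iters_per_epoch = 50000 // 32 = 1562, and warmup_epochs=10 gives 15,620 warmup steps. LambdaLR multiplies the base learning rate by the value returned by the lambda, and configure_optimizers steps the scheduler after every training step ("interval": "step"). The snippet repeats fn so that it runs on its own.

```python
from functools import partial


def fn(warmup_steps, step):
    # Same rule as above: linear ramp from 0 to 1, then constant
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    return 1.0


num_training_samples, batch_size, warmup_epochs = 50_000, 32, 10
train_iters_per_epoch = num_training_samples // batch_size  # 1562, as in BarlowTwins.__init__
warmup_steps = train_iters_per_epoch * warmup_epochs  # 15620
lr_lambda = partial(fn, warmup_steps)

for step in (0, warmup_steps // 2, warmup_steps, 20_000):
    # LambdaLR scales the base learning rate (1e-4 by default in BarlowTwins) by this factor
    print(step, lr_lambda(step), 1e-4 * lr_lambda(step))
```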
Barlow Twins Lightning Module

We keep the LightningModule for Barlow Twins neat and simple. It takes in a backbone encoder and initializes the projection head and the loss function. We configure the optimizer and the learning rate scheduler in the configure_optimizers method.

[10]:
class BarlowTwins(LightningModule):
    def __init__(
        self,
        encoder,
        encoder_out_dim,
        num_training_samples,
        batch_size,
        lambda_coeff=5e-3,
        z_dim=128,
        learning_rate=1e-4,
        warmup_epochs=10,
        max_epochs=200,
    ):
        super().__init__()

        self.encoder = encoder
        self.projection_head = ProjectionHead(input_dim=encoder_out_dim, hidden_dim=encoder_out_dim, output_dim=z_dim)
        self.loss_fn = BarlowTwinsLoss(batch_size=batch_size, lambda_coeff=lambda_coeff, z_dim=z_dim)

        self.learning_rate = learning_rate
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs

        self.train_iters_per_epoch = num_training_samples // batch_size

    def forward(self, x):
        return self.encoder(x)

    def shared_step(self, batch):
        (x1, x2, _), _ = batch

        z1 = self.projection_head(self.encoder(x1))
        z2 = self.projection_head(self.encoder(x2))

        return self.loss_fn(z1, z2)

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        warmup_steps = self.train_iters_per_epoch * self.warmup_epochs

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                linear_warmup_decay(warmup_steps),
            ),
            "interval": "step",
            "frequency": 1,
        }

        return [optimizer], [scheduler]
Evaluation

We define a callback that appends a linear layer on top of the encoder and trains this classification head in an online manner. We make sure not to backpropagate gradients into the encoder while tuning the linear layer. This technique was also used in SimCLR, where it was shown that the final downstream classification performance is very similar to the online finetuning results obtained as training progresses.

[11]:
class OnlineFineTuner(Callback):
    def __init__(
        self,
        encoder_output_dim: int,
        num_classes: int,
    ) -> None:
        super().__init__()

        self.optimizer: torch.optim.Optimizer

        self.encoder_output_dim = encoder_output_dim
        self.num_classes = num_classes

    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # add linear_eval layer and optimizer
        pl_module.online_finetuner = nn.Linear(self.encoder_output_dim, self.num_classes).to(pl_module.device)
        self.optimizer = torch.optim.Adam(pl_module.online_finetuner.parameters(), lr=1e-4)

    def extract_online_finetuning_view(
        self, batch: Sequence, device: Union[str, torch.device]
    ) -> Tuple[Tensor, Tensor]:
        (_, _, finetune_view), y = batch
        finetune_view = finetune_view.to(device)
        y = y.to(device)

        return finetune_view, y

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int = 0,  # default lets the hook run when Lightning does not pass this argument
    ) -> None:
        x, y = self.extract_online_finetuning_view(batch, pl_module.device)

        with torch.no_grad():
            feats = pl_module(x)

        feats = feats.detach()
        preds = pl_module.online_finetuner(feats)
        loss = F.cross_entropy(preds, y)

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        acc = accuracy(F.softmax(preds, dim=1), y)
        pl_module.log("online_train_acc", acc, on_step=True, on_epoch=False)
        pl_module.log("online_train_loss", loss, on_step=True, on_epoch=False)

    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        x, y = self.extract_online_finetuning_view(batch, pl_module.device)

        with torch.no_grad():
            feats = pl_module(x)

        feats = feats.detach()
        preds = pl_module.online_finetuner(feats)
        loss = F.cross_entropy(preds, y)

        acc = accuracy(F.softmax(preds, dim=1), y)
        pl_module.log("online_val_acc", acc, on_step=False, on_epoch=True, sync_dist=True)
        pl_module.log("online_val_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
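The key property of OnlineFineTuner is that the encoder receives no gradient from the linear probe: the features are computed under torch.no_grad() (and additionally detached), so loss.backward() only populates gradients of the linear layer. The toy sketch below checks this with two small nn.Linear stand-ins for the encoder and for pl_module.online_finetuner; the sizes are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
encoder = nn.Linear(8, 4)  # stand-in for the pretrained backbone
online_finetuner = nn.Linear(4, 3)  # stand-in for pl_module.online_finetuner
optimizer = torch.optim.Adam(online_finetuner.parameters(), lr=1e-4)

x = torch.randn(5, 8)  # a toy batch of 5 "images"
y = torch.randint(0, 3, (5,))  # toy labels for 3 classes

with torch.no_grad():  # no autograd graph is recorded through the encoder
    feats = encoder(x)

preds = online_finetuner(feats.detach())
loss = F.cross_entropy(preds, y)
loss.backward()
head_got_grad = online_finetuner.weight.grad is not None  # the probe is trained
optimizer.step()
optimizer.zero_grad()

print(encoder.weight.grad is None, head_got_grad)  # True True
```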

Finally, we define the trainer for training the model. We pass in the train_loader and val_loader we had initialized earlier to the fit function.

[12]:
encoder_out_dim = 512

model = BarlowTwins(
    encoder=encoder,
    encoder_out_dim=encoder_out_dim,
    num_training_samples=len(train_dataset),
    batch_size=batch_size,
    z_dim=z_dim,
)

online_finetuner = OnlineFineTuner(encoder_output_dim=encoder_out_dim, num_classes=10)
checkpoint_callback = ModelCheckpoint(every_n_epochs=100, save_top_k=-1, save_last=True)

trainer = Trainer(
    max_epochs=max_epochs,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
    callbacks=[online_finetuner, checkpoint_callback],
)

# Training is commented out so that the tutorial notebook builds quickly;
# uncomment the line below to train the model.
# trainer.fit(model, train_loader, val_loader)
ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None) will duplicate the last checkpoint saved.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Using the trained encoder for downstream tasks

Once the encoder is pretrained on CIFAR10, we can use it to get image embeddings and use them further downstream on tasks like classification, detection, segmentation etc.
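As a minimal sketch of such a downstream use, the snippet below trains a linear classifier (a "linear probe") on frozen encoder embeddings, similar in spirit to the online fine-tuner above. The tiny `encoder` is a hypothetical stand-in for the pretrained Barlow Twins encoder, and the random images and labels are placeholders for a real labelled dataset such as CIFAR-10.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the pretrained encoder: 3x32x32 images -> 512-d embeddings.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 512))
encoder.eval()  # the encoder stays frozen: its parameters are never given to the optimizer

# Placeholder labelled data (replace with a real downstream dataset).
images = torch.randn(64, 3, 32, 32)
labels = torch.randint(0, 10, (64,))

# Compute the embeddings once, without building a graph through the encoder.
with torch.no_grad():
    feats = encoder(images)

# Linear probe: a single linear layer trained on top of the frozen features.
probe = nn.Linear(512, 10)
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)

for step in range(100):
    loss = F.cross_entropy(probe(feats), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

preds = probe(feats).argmax(dim=1)
train_acc = (preds == labels).float().mean().item()
print(f"linear-probe training accuracy: {train_acc:.2f}")
```

Keeping the encoder frozen means the probe's accuracy reflects the quality of the learned embeddings rather than further fine-tuning of the encoder.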

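In this tutorial, we did not fully train our encoder for hundreds of epochs with the Barlow Twins pretraining method. The intent is to load pretrained encoder weights from a checkpoint and show the image embeddings obtained from them; until that checkpoint is available, the checkpoint-loading lines in the next cell are commented out and the encoder of the model defined above is used instead.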

To create this checkpoint, the encoder was pretrained for 200 epochs and obtained an online finetune accuracy of x% on CIFAR-10.
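Once a checkpoint is available, one way to recover just the encoder is to filter the checkpoint's `state_dict`. Lightning checkpoints store the LightningModule's parameters under the `"state_dict"` key, with each name prefixed by the attribute holding the submodule (here `encoder.`). The sketch below works on a plain dictionary so that it runs without a checkpoint; the helper `extract_submodule_state` and all key names, including the hypothetical `projection_head.` entry, are illustrative assumptions.

```python
def extract_submodule_state(state_dict, prefix="encoder."):
    """Return the entries of `state_dict` that belong to one submodule, with the prefix stripped."""
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


# Toy stand-in for torch.load(path)["state_dict"]; in a real checkpoint the values are tensors.
checkpoint = {
    "state_dict": {
        "encoder.conv1.weight": "w_conv1",
        "encoder.fc.bias": "b_fc",
        "projection_head.0.weight": "w_proj",  # hypothetical non-encoder parameter
    },
}

encoder_state = extract_submodule_state(checkpoint["state_dict"])
print(encoder_state)
# With real tensors, the result could then be loaded with: encoder.load_state_dict(encoder_state)
```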

[13]:
# ckpt_model = torch.load('')  # upload checkpoint to aws
# encoder = ckpt_model.encoder
encoder = model.encoder

downstream_dataset = CIFAR10(root=".", train=False, transform=transforms.ToTensor())
dataloader = DataLoader(downstream_dataset, batch_size=4, shuffle=False)

for batch in dataloader:
    img, label = batch
    print(encoder(img).shape)
    break
torch.Size([4, 512])

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Bolts GitHub Issues page and filter for “good first issue”.

Many thanks from the entire PyTorch Lightning team for your interest!

Pytorch Lightning

PyTorch Lightning Basic GAN Tutorial

  • Author: PL team

  • License: CC BY-SA

  • Generated: 2022-08-15T09:28:43.606365

How to train a GAN!

Main takeaways:

  1. Generator and discriminator are arbitrary PyTorch modules.

  2. training_step does both the generator and discriminator training.
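Both halves of the training_step below optimise binary cross-entropy on the discriminator's output probability. As a plain-Python sketch of the two objectives (an illustration of the math, not Lightning code), given discriminator outputs `d_real` on real images and `d_fake` on generated ones:

```python
import math


def bce(p, target):
    """Binary cross-entropy for one probability p in (0, 1) and a 0/1 target."""
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))


def mean(values):
    return sum(values) / len(values)


def generator_loss(d_fake):
    # The generator wants its samples classified as real, so its target is 1.
    return mean([bce(p, 1) for p in d_fake])


def discriminator_loss(d_real, d_fake):
    # The discriminator wants real -> 1 and fake -> 0; the two terms are averaged.
    real_loss = mean([bce(p, 1) for p in d_real])
    fake_loss = mean([bce(p, 0) for p in d_fake])
    return (real_loss + fake_loss) / 2


# A discriminator that is confident and right has a low loss, while the generator's loss is high.
print(discriminator_loss(d_real=[0.9], d_fake=[0.1]), generator_loss(d_fake=[0.1]))
```

At the point where the discriminator outputs 0.5 for everything, both losses equal ln 2, which is a useful reference value when reading the g_loss and d_loss curves.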



Give us a ⭐ on Github | Check out the documentation | Join us on Slack

Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "pytorch-lightning>=1.4" "torch>=1.8" "torchvision" "ipython[notebook]" "torchmetrics>=0.7" "setuptools==59.5.0"
[2]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
NUM_WORKERS = int(os.cpu_count() / 2)
MNIST DataModule

Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial on them or see the latest release docs.

[3]:
class MNISTDataModule(LightningDataModule):
    def __init__(
        self,
        data_dir: str = PATH_DATASETS,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)
A. Generator
[4]:
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # the second positional argument of BatchNorm1d is eps, so this sets eps=0.8
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img
B. Discriminator
[5]:
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity
C. GAN
A couple of features worth noting in this example:
  • We use some_tensor.type_as(another_tensor) to make sure new tensors are created on the right device (e.g., GPU or CPU).

    • Lightning moves the data from your dataloaders to the right device automatically.

    • In this example, we sample latent vectors on the fly, so we need to move these newly created tensors to the right device ourselves.

    • type_as is the way we recommend doing this.

  • This example shows how to use multiple optimizers in your LightningModule: configure_optimizers returns one optimizer for the generator and one for the discriminator, and training_step uses optimizer_idx to decide which network to update.
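A minimal sketch of what type_as does: it casts a tensor to the tensor type of a reference tensor (its dtype, and CPU vs. CUDA), so noise sampled inside training_step follows the batch. On a CPU-only machine only the dtype change is visible, which is why the stand-in batch below uses float64; the explicit `device=`/`dtype=` form is shown as an alternative.

```python
import torch

latent_dim = 100

# Stand-in for a batch Lightning has already moved to the right device.
# float64 is used only so that the conversion is visible on a CPU-only machine.
imgs = torch.zeros(4, 1, 28, 28, dtype=torch.float64)

# Newly created tensors default to float32 on the CPU, wherever `imgs` lives.
z = torch.randn(imgs.shape[0], latent_dim)

# type_as casts z to the tensor type of imgs (dtype, and CPU vs. CUDA).
z = z.type_as(imgs)

# Explicit alternative: create the tensor directly with the batch's device and dtype.
z_explicit = torch.randn(imgs.shape[0], latent_dim, device=imgs.device, dtype=imgs.dtype)

print(z.dtype, z.device, z_explicit.dtype, z_explicit.device)
```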

[6]:
class GAN(LightningModule):
    def __init__(
        self,
        channels,
        width,
        height,
        latent_dim: int = 100,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = BATCH_SIZE,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()

        # networks
        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

        self.validation_z = torch.randn(8, self.hparams.latent_dim)

        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        imgs, _ = batch

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # train generator
        if optimizer_idx == 0:

            # generate images
            self.generated_imgs = self(z)

            # log sampled images
            sample_imgs = self.generated_imgs[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image("generated_images", grid, 0)

            # ground truth result (ie: all fake)
            # put on GPU because we created this tensor inside training_loop
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            # adversarial loss is binary cross-entropy
            g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
            self.log("g_loss", g_loss, prog_bar=True)
            return g_loss

        # train discriminator
        if optimizer_idx == 1:
            # Measure discriminator's ability to classify real from generated samples

            # how well can it label as real?
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

            # how well can it label as fake?
            fake = torch.zeros(imgs.size(0), 1)
            fake = fake.type_as(imgs)

            fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

            # discriminator loss is the average of these
            d_loss = (real_loss + fake_loss) / 2
            self.log("d_loss", d_loss, prog_bar=True)
            return d_loss

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

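    def on_train_epoch_end(self):
        # Log generated samples once per epoch. on_validation_epoch_end would never run here:
        # this GAN defines no validation_step, so Lightning skips the validation loop.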
        z = self.validation_z.type_as(self.generator.model[0].weight)

        # log sampled images
        sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
[7]:
dm = MNISTDataModule()
model = GAN(*dm.dims)
trainer = Trainer(
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
    max_epochs=5,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
)
trainer.fit(model, dm)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:115: UserWarning: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
Extracting /__w/1/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /__w/1/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /__w/1/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz
Missing logger folder: /__w/1/s/lightning_logs
Extracting /__w/1/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type          | Params | In sizes | Out sizes
----------------------------------------------------------------------------
0 | generator     | Generator     | 1.5 M  | [2, 100] | [2, 1, 28, 28]
1 | discriminator | Discriminator | 533 K  | ?        | ?
----------------------------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.174     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=5` reached.
[8]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/


PyTorch Lightning CIFAR10 ~94% Baseline Tutorial

  • Author: PL team

  • License: CC BY-SA

  • Generated: 2022-04-28T08:05:29.967173

Train a ResNet to 94% accuracy on CIFAR10!



Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "pandas" "torch>=1.6, <1.9" "torchvision" "ipython[notebook]" "seaborn" "pytorch-lightning>=1.4" "torchmetrics>=0.6" "lightning-bolts"
[2]:
# Run this if you intend to use TPUs
# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
[3]:
import os

import pandas as pd
import seaborn as sn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from IPython.display import display
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.swa_utils import AveragedModel, update_bn
from torchmetrics.functional import accuracy

seed_everything(7)

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
NUM_WORKERS = int(os.cpu_count() / 2)
Global seed set to 7
CIFAR10 Data Module

Import the existing data module from bolts and modify the train and test transforms.

[4]:

train_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        cifar10_normalization(),
    ]
)

test_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        cifar10_normalization(),
    ]
)

cifar10_dm = CIFAR10DataModule(
    data_dir=PATH_DATASETS,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    val_transforms=test_transforms,
)
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:60: LightningDeprecationWarning: DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7.
  rank_zero_deprecation(
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:64: LightningDeprecationWarning: DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7.
  rank_zero_deprecation(
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:68: LightningDeprecationWarning: DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7.
  rank_zero_deprecation(
Resnet

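Modify the pre-existing ResNet architecture from TorchVision. That architecture was designed for ImageNet images (224x224), so we adapt it to CIFAR10 images (32x32): the 7x7, stride-2 input convolution is replaced by a 3x3, stride-1 convolution and the initial max-pooling layer is removed, so that the small images are not downsampled too early.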
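To see why the stem matters, the spatial size after a convolution or pooling layer (dilation 1) is floor((n + 2p - k) / s) + 1 for input size n, kernel k, stride s and padding p. The plain-Python sketch below applies this to a 32x32 input, using the standard torchvision ResNet-18 stem (7x7 convolution, stride 2, padding 3, followed by a 3x3 max-pool, stride 2, padding 1) and the three stride-2 stages that follow it; the helper names are our own.

```python
def out_size(n, kernel, stride, padding):
    """Spatial output size of a 2D convolution or pooling layer with dilation 1."""
    return (n + 2 * padding - kernel) // stride + 1


def final_feature_map(n, imagenet_stem=True):
    """Side length of the last ResNet-18 feature map for an n x n input."""
    if imagenet_stem:
        n = out_size(n, kernel=7, stride=2, padding=3)  # original conv1: 7x7, stride 2
        n = out_size(n, kernel=3, stride=2, padding=1)  # original maxpool: 3x3, stride 2
    else:
        n = out_size(n, kernel=3, stride=1, padding=1)  # modified conv1: 3x3, stride 1; no maxpool
    # layer1 keeps the size; layer2, layer3 and layer4 each start with a stride-2 3x3 conv.
    for _ in range(3):
        n = out_size(n, kernel=3, stride=2, padding=1)
    return n


print("ImageNet stem, 224x224 input:", final_feature_map(224, imagenet_stem=True))
print("ImageNet stem, 32x32 input:  ", final_feature_map(32, imagenet_stem=True))
print("CIFAR stem, 32x32 input:     ", final_feature_map(32, imagenet_stem=False))
```

With the original stem, a 32x32 image shrinks to a single spatial position before global pooling, while the modified stem keeps a 4x4 map, which leaves the later layers more spatial detail to work with.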

[5]:
def create_model():
    model = torchvision.models.resnet18(pretrained=False, num_classes=10)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    return model
Lightning Module

Check out the configure_optimizers method (https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#configure-optimizers) to use custom learning rate schedulers. OneCycleLR with SGD will get you to around 92-93% accuracy in 20-30 epochs and 93-94% accuracy in 40-50 epochs. Feel free to experiment with different LR schedules from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate.
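The plain-Python sketch below reproduces, as far as we understand PyTorch's implementation, the learning-rate curve OneCycleLR follows with its default arguments (pct_start=0.3, cosine annealing, div_factor=25, final_div_factor=1e4): a cosine warm-up from max_lr / 25 to max_lr over the first 30% of steps, then a cosine decay down to max_lr / 25 / 1e4. It ignores the momentum cycling OneCycleLR also performs and is only meant to show the shape of the schedule configured below (max_lr=0.1); the function names are our own.

```python
import math


def cosine_anneal(start, end, pct):
    """Cosine interpolation from `start` (pct=0) to `end` (pct=1)."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)


def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1  # last step of the warm-up phase
    if step <= warmup_end:
        # Phase 1: anneal up from initial_lr to max_lr.
        return cosine_anneal(initial_lr, max_lr, step / warmup_end)
    # Phase 2: anneal down from max_lr to min_lr over the remaining steps.
    pct = (step - warmup_end) / (total_steps - 1 - warmup_end)
    return cosine_anneal(max_lr, min_lr, pct)


# Values from configure_optimizers below, assuming BATCH_SIZE=256: 30 epochs of 45000 // 256 steps.
total = 30 * (45000 // 256)
for step in (0, total // 4, total // 2, total - 1):
    print(f"step {step}: lr={one_cycle_lr(step, total, max_lr=0.1):.2e}")
```

Because the scheduler is stepped after every batch here ("interval": "step"), PyTorch expects at most epochs * steps_per_epoch scheduler steps in total, so steps_per_epoch should match the number of training batches per epoch.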

[6]:
class LitResnet(LightningModule):
    def __init__(self, lr=0.05):
        super().__init__()

        self.save_hyperparameters()
        self.model = create_model()

    def forward(self, x):
        out = self.model(x)
        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4,
        )
        steps_per_epoch = 45000 // BATCH_SIZE
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                0.1,
                epochs=self.trainer.max_epochs,
                steps_per_epoch=steps_per_epoch,
            ),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
[7]:
model = LitResnet(lr=0.05)

trainer = Trainer(
    max_epochs=30,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
    logger=CSVLogger(save_dir="logs/"),
    callbacks=[LearningRateMonitor(logging_interval="step"), TQDMProgressBar(refresh_rate=10)],
)

trainer.fit(model, cifar10_dm)
trainer.test(model, datamodule=cifar10_dm)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /__w/1/s/.datasets/cifar-10-python.tar.gz
Extracting /__w/1/s/.datasets/cifar-10-python.tar.gz to /__w/1/s/.datasets
Files already downloaded and verified
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:88: LightningDeprecationWarning: DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7.
  rank_zero_deprecation(
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:107: LightningDeprecationWarning: DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7.
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.696    Total estimated model params size (MB)
Files already downloaded and verified
Files already downloaded and verified
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:126: LightningDeprecationWarning: DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7.
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_acc              0.9193999767303467     │
│         test_loss             0.28191840648651123    │
└───────────────────────────┴───────────────────────────┘
[7]:
[{'test_loss': 0.28191840648651123, 'test_acc': 0.9193999767303467}]
[8]:

metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
del metrics["step"]
metrics.set_index("epoch", inplace=True)
display(metrics.dropna(axis=1, how="all").head())
sn.relplot(data=metrics, kind="line")
         lr-SGD  train_loss  val_loss  val_acc  test_loss  test_acc
epoch
NaN    0.004229         NaN       NaN      NaN        NaN       NaN
0.0         NaN    1.847524       NaN      NaN        NaN       NaN
NaN    0.004934         NaN       NaN      NaN        NaN       NaN
0.0         NaN    1.724640       NaN      NaN        NaN       NaN
NaN    0.006107         NaN       NaN      NaN        NaN       NaN
[8]:
[Figure: seaborn line plot of the metrics logged to metrics.csv]
Bonus: Use Stochastic Weight Averaging to get a boost on performance

Use SWA from torch.optim.swa_utils to get a quick performance boost. This example also shows a couple of useful Lightning features:

  • Use training_epoch_end to run code after the end of every epoch.

  • Use a pretrained model directly with this SWA wrapper.
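Stochastic Weight Averaging keeps an equal-weight running average of the weights seen at each update. As far as we understand, AveragedModel's default update is avg = avg + (w - avg) / (n + 1), where n is the number of models averaged so far, with the first call simply copying the weights. A plain-Python sketch on lists of floats standing in for parameter tensors (the helper `swa_update` is hypothetical):

```python
def swa_update(avg_params, new_params, num_averaged):
    """Return the running equal-weight average after adding one more set of weights."""
    if num_averaged == 0:
        return list(new_params)  # the first snapshot is copied as-is
    return [a + (w - a) / (num_averaged + 1) for a, w in zip(avg_params, new_params)]


# Weights of a toy two-parameter model at the end of three epochs.
snapshots = [[1.0, 4.0], [2.0, 5.0], [3.0, 9.0]]

avg, n = None, 0
for weights in snapshots:
    avg = swa_update(avg, weights, n)
    n += 1

print(avg)  # element-wise mean of the three snapshots
```

Averaging weights does not produce matching BatchNorm running statistics, which is why the SWAResnet below calls update_bn over the training data in on_train_end.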

[9]:
class SWAResnet(LitResnet):
    def __init__(self, trained_model, lr=0.01):
        super().__init__()

        self.save_hyperparameters("lr")
        self.model = trained_model
        self.swa_model = AveragedModel(self.model)

    def forward(self, x):
        out = self.swa_model(x)
        return F.log_softmax(out, dim=1)

    def training_epoch_end(self, training_step_outputs):
        self.swa_model.update_parameters(self.model)

    def validation_step(self, batch, batch_idx, stage=None):
        x, y = batch
        logits = F.log_softmax(self.model(x), dim=1)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
        return optimizer

    def on_train_end(self):
        update_bn(self.trainer.datamodule.train_dataloader(), self.swa_model, device=self.device)
[10]:
swa_model = SWAResnet(model.model, lr=0.01)
swa_model.datamodule = cifar10_dm

swa_trainer = Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    logger=CSVLogger(save_dir="logs/"),
)

swa_trainer.fit(swa_model, cifar10_dm)
swa_trainer.test(swa_model, datamodule=cifar10_dm)
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py:261: UserWarning: Attribute 'trained_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['trained_model'])`.
  rank_zero_warn(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Files already downloaded and verified
Files already downloaded and verified
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:88: LightningDeprecationWarning: DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7.
  rank_zero_deprecation(
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:107: LightningDeprecationWarning: DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7.
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type          | Params
--------------------------------------------
0 | model     | ResNet        | 11.2 M
1 | swa_model | AveragedModel | 11.2 M
--------------------------------------------
22.3 M    Trainable params
0         Non-trainable params
22.3 M    Total params
89.392    Total estimated model params size (MB)
Files already downloaded and verified
Files already downloaded and verified
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:126: LightningDeprecationWarning: DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7.
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_acc              0.9204999804496765     │
│         test_loss             0.25821828842163086    │
└───────────────────────────┴───────────────────────────┘
[10]:
[{'test_loss': 0.25821828842163086, 'test_acc': 0.9204999804496765}]
[11]:

# read the metrics of the SWA run (swa_trainer), not the baseline run
metrics = pd.read_csv(f"{swa_trainer.logger.log_dir}/metrics.csv")
del metrics["step"]
metrics.set_index("epoch", inplace=True)
display(metrics.dropna(axis=1, how="all").head())
sn.relplot(data=metrics, kind="line")
[11]:
[Figure: seaborn line plot of the metrics logged to metrics.csv]


PyTorch Lightning DataModules

  • Author: PL team

  • License: CC BY-SA

  • Generated: 2023-01-03T15:18:14.618737

This notebook will walk you through how to start using DataModules. With the release of pytorch-lightning version 0.9.0, we have included a new class called LightningDataModule to help you decouple data-related hooks from your LightningModule. The most up-to-date documentation on DataModules can be found here.



Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "ipython[notebook]" "torchmetrics>=0.7" "torchvision" "pytorch-lightning>=1.4" "setuptools==59.5.0" "torch>=1.8"

Introduction

First, we’ll go over a regular LightningModule implementation without the use of a LightningDataModule.

[2]:
import os

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision import transforms

# Note - you must have torchvision installed for this example
from torchvision.datasets import CIFAR10, MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
Defining the LitMNIST Model

Below, we reuse a LightningModule from our hello world tutorial that classifies MNIST Handwritten Digits.

Unfortunately, we have hardcoded dataset-specific items within the model, forever limiting it to working with MNIST Data. 😢

This is fine if you don’t plan on training/evaluating your model on different datasets. However, in many cases, this can become bothersome when you want to try out your architecture with different datasets.
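In the LitMNIST model below, only two layers actually depend on the dataset: the first Linear layer, whose input size is channels × width × height, and the last one, whose output size is the number of classes. The plain-Python sketch below (the helper `mlp_layer_sizes` is hypothetical) derives the Linear layer shapes of that model from those two values, which is what makes it possible to pass them in from outside instead of hardcoding them.

```python
import math


def mlp_layer_sizes(dims, num_classes, hidden_size=64):
    """(in_features, out_features) of each nn.Linear layer in the LitMNIST MLP."""
    in_features = math.prod(dims)  # nn.Flatten turns (C, H, W) into C * H * W features
    return [
        (in_features, hidden_size),
        (hidden_size, hidden_size),
        (hidden_size, num_classes),
    ]


print("MNIST:  ", mlp_layer_sizes((1, 28, 28), num_classes=10))
print("CIFAR10:", mlp_layer_sizes((3, 32, 32), num_classes=10))
```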

[3]:
class LitMNIST(LightningModule):
    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        # We hardcode dataset specific stuff here.
        self.data_dir = data_dir
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Build model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=128)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=128)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=128)
Training the LitMNIST Model
[4]:
model = LitMNIST()
trainer = Trainer(
    max_epochs=2,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
    callbacks=[TQDMProgressBar(refresh_rate=20)],
)
trainer.fit(model)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /__w/1/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /__w/1/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /__w/1/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz
Missing logger folder: /__w/1/s/lightning_logs
Extracting /__w/1/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55.1 K
-------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=2` reached.

Using DataModules

DataModules are a way of decoupling data-related hooks from the LightningModule so you can develop dataset-agnostic models.

Defining The MNISTDataModule

Let’s go over each function in the class below and talk about what it does:

  1. __init__

    • Takes in a data_dir arg that points to where you have downloaded/wish to download the MNIST dataset.

    • Defines a transform that will be applied across train, val, and test dataset splits.

    • Defines default self.dims and self.num_classes.

  2. prepare_data

    • This is where we can download the dataset. We point to our desired dataset and ask torchvision’s MNIST dataset class to download if the dataset isn’t found there.

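    • Note that we do not make any state assignments in this function (i.e. self.something = ...), because prepare_data is called from a single process and that state would not be available in the other processes.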

  3. setup

    • Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).

    • Setup expects a ‘stage’ arg which is used to separate logic for ‘fit’ and ‘test’.

    • If you don’t mind loading all your datasets at once, you can set up a condition to allow for both ‘fit’ related setup and ‘test’ related setup to run whenever None is passed to stage.

    • Note that this runs on every GPU (i.e. in every process), so it is safe to make state assignments here.

  4. x_dataloader

    • train_dataloader(), val_dataloader(), and test_dataloader() all return PyTorch DataLoader instances created by wrapping the respective datasets that we prepared in setup().
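The split between prepare_data and setup can be made concrete with a toy simulation. This is a minimal sketch under stated assumptions: ToyDataModule and simulate_fit are hypothetical names, nothing here imports Lightning, and plain objects stand in for processes (one module copy per process). It mimics the rule above, where prepare_data runs in one process and setup runs in every process, to show why state assigned in prepare_data is missing elsewhere.

```python
class ToyDataModule:
    """Hypothetical stand-in for a LightningDataModule; no Lightning import needed."""

    def __init__(self):
        self.flag_from_prepare = False
        self.train_split = None

    def prepare_data(self):
        # Anti-pattern on purpose: state assigned here only exists in this process.
        self.flag_from_prepare = True

    def setup(self, stage=None):
        # Safe: every process runs setup, so every copy builds its own split.
        if stage == "fit" or stage is None:
            self.train_split = list(range(10))


def simulate_fit(world_size):
    """One module copy per simulated process; only rank 0 calls prepare_data."""
    copies = [ToyDataModule() for _ in range(world_size)]
    copies[0].prepare_data()
    for dm in copies:
        dm.setup(stage="fit")
    return copies


copies = simulate_fit(world_size=2)
print([dm.flag_from_prepare for dm in copies])  # [True, False]
print([len(dm.train_split) for dm in copies])   # [10, 10]
```

The second process never sees flag_from_prepare set, while every process gets a train split, which is why downloads go in prepare_data and dataset assignments go in setup.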

[5]:
class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = PATH_DATASETS):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
Defining the dataset-agnostic LitModel

Below, we define the same model as the LitMNIST model we made earlier.

However, this time our model has the freedom to use any input data that we’d like 🔥.

[6]:
class LitModel(LightningModule):
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        # We take in input dimensions as parameters and use those to dynamically build model.
        self.channels = channels
        self.width = width
        self.height = height
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
Training the LitModel using the MNISTDataModule

Now, we initialize and train the LitModel using the MNISTDataModule’s configuration settings and dataloaders.

[7]:
# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
# Init trainer
trainer = Trainer(
    max_epochs=3,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
)
# Pass the datamodule as arg to trainer.fit to override model hooks :)
trainer.fit(model, dm)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55.1 K
-------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=3` reached.
Defining the CIFAR10 DataModule

Let’s prove the LitModel we made earlier is dataset-agnostic by defining a new datamodule for the CIFAR10 dataset.

[8]:
class CIFAR10DataModule(LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        self.dims = (3, 32, 32)
        self.num_classes = 10

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=BATCH_SIZE)
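The two datamodules use different transforms.Normalize statistics. Normalize computes (x - mean) / std per channel, so a quick worked check shows what each choice does to ToTensor's [0, 1] pixel range. The helper name normalize is illustrative; the mean/std values are the ones used in the cells above (0.1307 and 0.3081 are the commonly quoted MNIST pixel mean and standard deviation).

```python
def normalize(value, mean, std):
    """Per-channel normalization as done by transforms.Normalize: (x - mean) / std."""
    return (value - mean) / std


# CIFAR10 cell: mean=0.5, std=0.5 maps [0, 1] onto [-1, 1]
print(normalize(0.0, 0.5, 0.5), normalize(1.0, 0.5, 0.5))  # -1.0 1.0

# MNIST cells: mean=0.1307, std=0.3081 centre and rescale using dataset statistics
print(round(normalize(0.0, 0.1307, 0.3081), 4), round(normalize(1.0, 0.1307, 0.3081), 4))  # -0.4242 2.8215
```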
Training the LitModel using the CIFAR10DataModule

Our model isn’t very good, so it will perform poorly on the CIFAR10 dataset.

The point here is that our LitModel has no problem using a different datamodule as its input data.

[9]:
dm = CIFAR10DataModule()
model = LitModel(*dm.dims, dm.num_classes, hidden_size=256)
tqdm_progress_bar = TQDMProgressBar(refresh_rate=20)
trainer = Trainer(
    max_epochs=5,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
    callbacks=[tqdm_progress_bar],
)
trainer.fit(model, dm)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz
Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 855 K
-------------------------------------
855 K     Trainable params
0         Non-trainable params
855 K     Total params
3.420     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=5` reached.
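The parameter counts in the two model summaries (55.1 K for MNIST, 855 K for CIFAR10) follow directly from LitModel's layer sizes. A short check, where mlp_param_count is an illustrative helper: each nn.Linear has n_in * n_out weights plus n_out biases, and Flatten, ReLU and Dropout have no parameters.

```python
def mlp_param_count(in_features, hidden_size, num_classes):
    """Weights + biases of the three nn.Linear layers in LitModel's nn.Sequential."""
    layer_shapes = [
        (in_features, hidden_size),
        (hidden_size, hidden_size),
        (hidden_size, num_classes),
    ]
    return sum(n_in * n_out + n_out for n_in, n_out in layer_shapes)


mnist = mlp_param_count(1 * 28 * 28, 64, 10)   # default hidden_size=64
cifar = mlp_param_count(3 * 32 * 32, 256, 10)  # hidden_size=256 in the CIFAR10 cell
print(mnist, cifar)  # 55050 855050

# Size estimate: 4 bytes per float32 parameter, 1 MB = 10**6 bytes
print(mnist * 4 / 1e6, cifar * 4 / 1e6)
```

The size line gives roughly 0.2202 and 3.4202, which agrees with the 0.220 MB and 3.420 MB in the summaries. Almost all of the CIFAR10 growth comes from the first layer, whose input width rises from 784 to 3072.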

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Great thanks from the entire PyTorch Lightning Team for your interest!


Fine-Tuning Scheduler

  • Author: Dan Dale

  • License: CC BY-SA

  • Generated: 2023-01-05T12:03:14.890703

This notebook introduces the Fine-Tuning Scheduler extension and demonstrates its use to fine-tune a small foundational model on the RTE task of SuperGLUE, with iterative early-stopping defined according to a user-specified schedule. It uses Hugging Face’s datasets and transformers libraries to retrieve the relevant benchmark data and foundational model weights. The required dependencies are installed via the finetuning-scheduler [examples] extra.



Give us a ⭐ on Github | Check out the documentation | Join us on Slack

Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "pytorch-lightning>=1.4, <1.9" "finetuning-scheduler[examples]>=0.3.0" "ipython[notebook]>=8.0.0, <8.9.0" "datasets<2.8.0" "torch>=1.8.1, <1.14.0" "setuptools==65.6.3" "torchmetrics>=0.7, <0.12"
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

Scheduled Fine-Tuning with the Fine-Tuning Scheduler Extension


The Fine-Tuning Scheduler extension accelerates and enhances model experimentation with flexible fine-tuning schedules.

Training with the extension is simple and confers a host of benefits:

  • it dramatically increases fine-tuning flexibility

  • expedites and facilitates exploration of model tuning dynamics

  • enables marginal performance improvements of fine-tuned models

Setup is straightforward: just install from PyPI! Since this notebook-based example requires a few additional packages (e.g. transformers, sentencepiece), we installed the finetuning-scheduler package with the [examples] extra above. Once the finetuning-scheduler package is installed, the FinetuningScheduler callback is available for use with PyTorch Lightning. For additional installation options, please see the Fine-Tuning Scheduler README.

Fundamentally, Fine-Tuning Scheduler enables scheduled, multi-phase fine-tuning of foundational models. Gradual unfreezing (i.e. thawing) can help maximize foundational model knowledge retention while allowing (typically upper layers of) the model to optimally adapt to new tasks during transfer learning [1], [2], [3].

The FinetuningScheduler callback orchestrates the gradual unfreezing of models via a fine-tuning schedule that is either implicitly generated (the default) or explicitly provided by the user (more computationally efficient). Fine-tuning phase transitions are driven by FTSEarlyStopping criteria (a multi-phase extension of EarlyStopping packaged with FinetuningScheduler), user-specified epoch transitions or a composition of the two (the default mode). A FinetuningScheduler training session completes when the final phase of the schedule has its stopping criteria met. See the early stopping documentation for more details on that callback’s configuration.


Basic Usage

If no fine-tuning schedule is provided by the user, FinetuningScheduler will generate a default schedule and proceed to fine-tune according to the generated schedule, using default FTSEarlyStopping and FTSCheckpoint callbacks with monitor=val_loss.

from pytorch_lightning import Trainer
from finetuning_scheduler import FinetuningScheduler
trainer = Trainer(callbacks=[FinetuningScheduler()])

The Default Fine-Tuning Schedule

Schedule definition is facilitated via the gen_ft_schedule method which dumps a default fine-tuning schedule (by default using a naive, 2-parameters per level heuristic) which can be adjusted as desired by the user and/or subsequently passed to the callback. Using the default/implicitly generated schedule will likely be less computationally efficient than a user-defined fine-tuning schedule but is useful for exploring a model’s fine-tuning behavior and can serve as a good baseline for subsequent explicit schedule refinement. While the current version of FinetuningScheduler only supports single optimizer and (optional) lr_scheduler configurations, per-phase maximum learning rates can be set as demonstrated in the next section.
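To make the “2-parameters per level” idea concrete, here is a toy version of such a heuristic. It is not FinetuningScheduler’s gen_ft_schedule, whose exact ordering rules may differ; naive_schedule and the parameter names are hypothetical. It walks the parameter names from the output end of the model backward and assigns two names per phase, returning a dict shaped like the schedule YAML (phase index mapped to a params list).

```python
def naive_schedule(param_names, per_phase=2):
    """Toy sketch: group parameter names into phases, output-side parameters first."""
    # Assumption: later-registered parameters sit closer to the output, so they thaw first.
    reversed_names = list(reversed(param_names))
    schedule = {}
    for phase, start in enumerate(range(0, len(reversed_names), per_phase)):
        schedule[phase] = {"params": reversed_names[start:start + per_phase]}
    return schedule


names = ["embed.weight", "layer1.weight", "layer1.bias", "classifier.weight", "classifier.bias"]
for phase, cfg in naive_schedule(names).items():
    print(phase, cfg["params"])
# 0 ['classifier.bias', 'classifier.weight']
# 1 ['layer1.bias', 'layer1.weight']
# 2 ['embed.weight']
```

A schedule built this way is a reasonable starting point for exploration, which is why the text suggests refining it by hand rather than using it as-is.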

Specifying a Fine-Tuning Schedule

To specify a fine-tuning schedule, it’s convenient to first generate the default schedule and then alter the thawed/unfrozen parameter groups associated with each fine-tuning phase as desired. Fine-tuning phases are zero-indexed and executed in ascending order.

  1. First, generate the default schedule to Trainer.log_dir. It will be named after your LightningModule subclass with the suffix _ft_schedule.yaml.

from pytorch_lightning import Trainer
from finetuning_scheduler import FinetuningScheduler
trainer = Trainer(callbacks=[FinetuningScheduler(gen_ft_sched_only=True)])
  2. Alter the schedule as desired.

  3. Once the fine-tuning schedule has been altered as desired, pass it to FinetuningScheduler to commence scheduled training:

from pytorch_lightning import Trainer
from finetuning_scheduler import FinetuningScheduler

trainer = Trainer(callbacks=[FinetuningScheduler(ft_schedule="/path/to/my/schedule/my_schedule.yaml")])
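Because phases are zero-indexed and run in ascending order, and thawing is gradual, the set of trainable parameters grows phase by phase. The sketch below (a hypothetical helper, thawed_by_phase, with no FinetuningScheduler import) computes which parameters would be unfrozen once each phase is reached, assuming parameters thawed in earlier phases stay thawed. The contiguity check is this sketch’s own assumption, not a documented rule of the extension, and the parameter names are made up.

```python
def thawed_by_phase(schedule):
    """Return the cumulative list of trainable parameter names at each phase."""
    phases = sorted(schedule)
    # Assumed constraint for this sketch: phases are 0, 1, 2, ... with no gaps.
    if phases != list(range(len(phases))):
        raise ValueError(f"expected phases 0..{len(phases) - 1}, got {phases}")
    thawed, result = [], []
    for phase in phases:  # ascending order, as in the schedule
        thawed.extend(schedule[phase]["params"])
        result.append(list(thawed))
    return result


schedule = {
    0: {"params": ["model.classifier.bias", "model.classifier.weight"]},
    1: {"params": ["model.pooler.dense.bias", "model.pooler.dense.weight"]},
    2: {"params": ["model.embeddings.word_embeddings.weight"]},
}
for phase, names in enumerate(thawed_by_phase(schedule)):
    print(phase, len(names))  # prints 0 2, then 1 4, then 2 5
```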

Early-Stopping and Epoch-Driven Phase Transition Criteria

By default, FTSEarlyStopping and epoch-driven transition criteria are composed. If a max_transition_epoch is specified for a given phase, the next fine-tuning phase will begin at that epoch unless FTSEarlyStopping criteria are met first. If FinetuningScheduler.epoch_transitions_only is True, FTSEarlyStopping will not be used and transitions will be exclusively epoch-driven.
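The composition rule above can be written as a small predicate. This is a simplified sketch, not the callback’s implementation: should_transition is a hypothetical name, early_stop_triggered stands in for “FTSEarlyStopping criteria are met”, and the real callback’s epoch off-by-one conventions may differ.

```python
def should_transition(current_epoch, max_transition_epoch, early_stop_triggered,
                      epoch_transitions_only=False):
    """Toy sketch: True if the next fine-tuning phase should begin at current_epoch."""
    epoch_reached = max_transition_epoch is not None and current_epoch >= max_transition_epoch
    if epoch_transitions_only:
        # Early-stopping criteria are ignored; only the epoch limit matters.
        return epoch_reached
    # Default (composed) mode: whichever criterion is met first triggers the transition.
    return early_stop_triggered or epoch_reached


print(should_transition(2, 5, early_stop_triggered=False))  # False
print(should_transition(5, 5, early_stop_triggered=False))  # True
print(should_transition(2, 5, early_stop_triggered=True))   # True
print(should_transition(2, 5, early_stop_triggered=True, epoch_transitions_only=True))  # False
```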

Tip: Regular expressions can be convenient for specifying more complex schedules. A per-phase base maximum lr can also be specified.
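To illustrate how regex entries and a per-phase lr could fit together, here is a sketch that expands each phase’s patterns against a list of parameter names. It is hypothetical: resolve_phase_params and the parameter names are illustrative, it uses re.fullmatch (the extension’s own matching rules may differ), and it does not check for a parameter matching more than one phase.

```python
import re


def resolve_phase_params(schedule, param_names):
    """Expand each phase's regex patterns into concrete parameter names."""
    resolved = {}
    for phase, cfg in sorted(schedule.items()):
        matched = [n for n in param_names if any(re.fullmatch(p, n) for p in cfg["params"])]
        resolved[phase] = {"params": matched}
        if "lr" in cfg:
            # Optional per-phase base maximum learning rate is carried through unchanged.
            resolved[phase]["lr"] = cfg["lr"]
    return resolved


param_names = [
    "model.encoder.layer.0.weight",
    "model.encoder.layer.1.weight",
    "model.classifier.weight",
    "model.classifier.bias",
]
schedule = {
    0: {"params": [r"model\.classifier\..*"]},
    1: {"params": [r"model\.encoder\.layer\.\d+\.weight"], "lr": 1e-5},
}
for phase, cfg in resolve_phase_params(schedule, param_names).items():
    print(phase, cfg)
```

One pattern per phase here replaces what would otherwise be a list of fully qualified names, which is where regexes pay off for deep models with many repeated layers.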

The end-to-end example in this notebook (Scheduled Fine-Tuning For SuperGLUE) uses FinetuningScheduler in explicit mode to fine-tune a small foundational model on the RTE task of SuperGLUE. Please see the official Fine-Tuning Scheduler documentation if you are interested in a similar CLI-based example using the LightningCLI.

Resuming Scheduled Fine-Tuning Training Sessions

Resumption of scheduled fine-tuning training is identical to the continuation of other training sessions with the caveat that the provided checkpoint must have been saved by a FinetuningScheduler session. FinetuningScheduler uses FTSCheckpoint (an extension of ModelCheckpoint) to maintain schedule state with special metadata.

from pytorch_lightning import Trainer
from finetuning_scheduler import FinetuningScheduler
trainer = Trainer(callbacks=[FinetuningScheduler()])
trainer.fit(..., ckpt_path="some/path/to/my_checkpoint.ckpt")

Training will resume at the depth/level of the provided checkpoint according to the specified schedule. Schedules can be altered between training sessions but schedule compatibility is left to the user for maximal flexibility. If executing a user-defined schedule, typically the same schedule should be provided for the original and resumed training sessions.

By default (FinetuningScheduler.restore_best is True), FinetuningScheduler will attempt to restore the best available checkpoint before fine-tuning depth transitions.

trainer = Trainer(callbacks=[FinetuningScheduler()])
trainer.fit(..., ckpt_path="some/path/to/my_kth_best_checkpoint.ckpt")

Note that, similar to the behavior of ModelCheckpoint (specifically this PR), when resuming training with a different FTSCheckpoint dirpath from the provided checkpoint, the new training session’s checkpoint state will be re-initialized at the resumption depth with the provided checkpoint being set as the best checkpoint.

Note: Currently, FinetuningScheduler supports the following strategy types:

  • DP

  • DDP

  • DDP_FORK (and its aliases e.g. ddp_notebook)

  • DDP_SPAWN

  • DDP_SHARDED

  • DDP_SHARDED_SPAWN

Custom or officially unsupported strategies can be used by setting FinetuningScheduler.allow_untested to True. Note that most currently unsupported strategies are unsupported because they require varying degrees of modification to be compatible (e.g. deepspeed requires an add_param_group method, and tpu_spawn an override of the current broadcast method to include Python objects).

Scheduled Fine-Tuning For SuperGLUE

The following example demonstrates the use of FinetuningScheduler to fine-tune a small foundational model on the RTE task of SuperGLUE. Iterative early-stopping will be applied according to a user-specified schedule.

[2]:
import os
import warnings
from datetime import datetime
from typing import Any, Dict, List, Optional

from packaging.version import Version

import sentencepiece as sp  # noqa: F401 # isort: split
import datasets
import evaluate
import pytorch_lightning as pl
import torch
from datasets import logging as datasets_logging
from lightning_lite.accelerators.cuda import is_cuda_available
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.utilities import rank_zero_warn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers import logging as transformers_logging
from transformers.tokenization_utils_base import BatchEncoding

if Version(torch.__version__) == Version("1.12.0") or torch.__version__.startswith("1.12.0"):
    # we need to use a patched version of AdamW to fix https://github.com/pytorch/pytorch/issues/80809
    # and allow examples to succeed with torch 1.12.0 (this torch bug is fixed in 1.12.1)
    from fts_examples.patched_adamw import AdamW
else:
    from torch.optim.adamw import AdamW
[3]:
# Import the `FinetuningScheduler` PyTorch Lightning extension module we want to use. This will import all necessary callbacks.
import finetuning_scheduler as fts  # isort: split

# set notebook-level variables
TASK_NUM_LABELS = {"boolq": 2, "rte": 2}
DEFAULT_TASK = "rte"

# reduce hf logging verbosity to focus on tutorial-relevant code/messages
for hflogger in [transformers_logging, datasets_logging]:
    hflogger.set_verbosity_error()
# ignore warnings related to the tokenizers_parallelism/DataLoader parallelism trade-off and
# expected logging behavior
for warnf in [
    r".*does not have many workers.*",
    r".*The number of training samples.*",
    r".*converting to a fast.*",
    r".*number of training batches.*",
]:
    warnings.filterwarnings("ignore", warnf)
[4]:
class RteBoolqDataModule(pl.LightningDataModule):
    """A ``LightningDataModule`` designed for either the RTE or BoolQ SuperGLUE Hugging Face datasets."""

    TASK_TEXT_FIELD_MAP = {"rte": ("premise", "hypothesis"), "boolq": ("question", "passage")}
    LOADER_COLUMNS = (
        "datasets_idx",
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
    )

    def __init__(
        self,
        model_name_or_path: str,
        task_name: str = DEFAULT_TASK,
        max_seq_length: int = 128,
        train_batch_size: int = 16,
        eval_batch_size: int = 16,
        tokenizers_parallelism: bool = True,
        **dataloader_kwargs: Any,
    ):
        r"""Initialize the ``LightningDataModule`` designed for either the RTE or BoolQ SuperGLUE Hugging Face
        datasets.

        Args:
            model_name_or_path (str):
                Can be either:
                    - A string, the ``model id`` of a pretrained model hosted inside a model repo on huggingface.co.
                        Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                        a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a ``directory`` containing model weights saved using
                        :meth:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
            task_name (str, optional): Name of the SuperGLUE task to execute. This module supports 'rte' or 'boolq'.
                Defaults to DEFAULT_TASK which is 'rte'.
            max_seq_length (int, optional): Length to which we will pad sequences or truncate input. Defaults to 128.
            train_batch_size (int, optional): Training batch size. Defaults to 16.
            eval_batch_size (int, optional): Batch size to use for validation and testing splits. Defaults to 16.
            tokenizers_parallelism (bool, optional): Whether to use parallelism in the tokenizer. Defaults to True.
            \**dataloader_kwargs: Arguments passed when initializing the dataloader
        """
        super().__init__()
        task_name = task_name if task_name in TASK_NUM_LABELS.keys() else DEFAULT_TASK
        self.text_fields = self.TASK_TEXT_FIELD_MAP[task_name]
        self.dataloader_kwargs = {
            "num_workers": dataloader_kwargs.get("num_workers", 0),
            "pin_memory": dataloader_kwargs.get("pin_memory", False),
        }
        self.save_hyperparameters()
        os.environ["TOKENIZERS_PARALLELISM"] = "true" if self.hparams.tokenizers_parallelism else "false"
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.hparams.model_name_or_path, use_fast=True, local_files_only=False
        )

    def prepare_data(self):
        """Load the SuperGLUE dataset."""
        # N.B. PL calls prepare_data from a single process (rank 0) so do not use it to assign
        # state (e.g. self.x=y)
        datasets.load_dataset("super_glue", self.hparams.task_name)

    def setup(self, stage):
        """Setup our dataset splits for training/validation."""
        self.dataset = datasets.load_dataset("super_glue", self.hparams.task_name)
        for split in self.dataset.keys():
            self.dataset[split] = self.dataset[split].map(
                self._convert_to_features, batched=True, remove_columns=["label"]
            )
            self.columns = [c for c in self.dataset[split].column_names if c in self.LOADER_COLUMNS]
            self.dataset[split].set_format(type="torch", columns=self.columns)

        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]

    def train_dataloader(self):
        return DataLoader(self.dataset["train"], batch_size=self.hparams.train_batch_size, **self.dataloader_kwargs)

    def val_dataloader(self):
        return DataLoader(self.dataset["validation"], batch_size=self.hparams.eval_batch_size, **self.dataloader_kwargs)

    def _convert_to_features(self, example_batch: datasets.arrow_dataset.Batch) -> BatchEncoding:
        """Convert raw text examples to a :class:`~transformers.tokenization_utils_base.BatchEncoding` container
        (derived from python dict) of features that includes helpful methods for translating between word/character
        space and token space.

        Args:
            example_batch (datasets.arrow_dataset.Batch): The set of examples to convert to token space.

        Returns:
            ``BatchEncoding``: A batch of encoded examples (note default tokenizer batch_size=1000)
        """
        text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
        # Tokenize the text/text pairs
        features = self.tokenizer.batch_encode_plus(
            text_pairs, max_length=self.hparams.max_seq_length, padding="longest", truncation=True
        )
        # Rename label to labels to make it easier to pass to model forward
        features["labels"] = example_batch["label"]
        return features
[5]:
class RteBoolqModule(pl.LightningModule):
    """A ``LightningModule`` that can be used to fine-tune a foundational model on either the RTE or BoolQ
    SuperGLUE tasks using Hugging Face implementations of a given model and the SuperGLUE Hugging Face dataset."""

    def __init__(
        self,
        model_name_or_path: str,
        optimizer_init: Dict[str, Any],
        lr_scheduler_init: Dict[str, Any],
        model_cfg: Optional[Dict[str, Any]] = None,
        task_name: str = DEFAULT_TASK,
        experiment_tag: str = "default",
    ):
        """
        Args:
            model_name_or_path (str): Path to pretrained model or identifier from https://huggingface.co/models
            optimizer_init (Dict[str, Any]): The desired optimizer configuration.
            lr_scheduler_init (Dict[str, Any]): The desired learning rate scheduler config
            model_cfg (Optional[Dict[str, Any]], optional): Defines overrides of the default model config. Defaults to
                ``None``.
            task_name (str, optional): The SuperGLUE task to execute, one of ``'rte'``, ``'boolq'``. Defaults to "rte".
            experiment_tag (str, optional): The tag to use for the experiment and tensorboard logs. Defaults to
                "default".
        """
        super().__init__()
        if task_name not in TASK_NUM_LABELS.keys():
            rank_zero_warn(f"Invalid task_name {task_name!r}. Proceeding with the default task: {DEFAULT_TASK!r}")
            task_name = DEFAULT_TASK
        self.num_labels = TASK_NUM_LABELS[task_name]
        self.model_cfg = model_cfg or {}
        conf = AutoConfig.from_pretrained(model_name_or_path, num_labels=self.num_labels, local_files_only=False)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=conf)
        self.model.config.update(self.model_cfg)  # apply model config overrides
        self.init_hparams = {
            "optimizer_init": optimizer_init,
            "lr_scheduler_init": lr_scheduler_init,
            "model_config": self.model.config,
            "model_name_or_path": model_name_or_path,
            "task_name": task_name,
            "experiment_id": f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{experiment_tag}",
        }
        self.save_hyperparameters(self.init_hparams)
        self.metric = evaluate.load("super_glue", self.hparams.task_name, experiment_id=self.hparams.experiment_id)
        self.no_decay = ["bias", "LayerNorm.weight"]

    @property
    def finetuningscheduler_callback(self) -> fts.FinetuningScheduler:
        fts_callback = [c for c in self.trainer.callbacks if isinstance(c, fts.FinetuningScheduler)]
        return fts_callback[0] if fts_callback else None

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]
        self.log("train_loss", loss)
        return loss

    def training_epoch_end(self, outputs: List[Any]) -> None:
        if self.finetuningscheduler_callback:
            self.log("finetuning_schedule_depth", float(self.finetuningscheduler_callback.curr_depth))

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss, logits = outputs[:2]
        if self.num_labels > 1:
            # classification: pick the highest-scoring class
            preds = torch.argmax(logits, dim=1)
        elif self.num_labels == 1:
            # regression: a single output per example
            preds = logits.squeeze()
        labels = batch["labels"]
        self.log("val_loss", val_loss, prog_bar=True)
        metric_dict = self.metric.compute(predictions=preds, references=labels)
        self.log_dict(metric_dict, prog_bar=True)

    def _init_param_groups(self) -> List[Dict]:
        """Initialize the parameter groups. Used to ensure weight_decay is not applied to our specified bias
        parameters when we initialize the optimizer.

        Returns:
            List[Dict]: A list of parameter group dictionaries.
        """
        return [
            {
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in self.no_decay) and p.requires_grad
                ],
                "weight_decay": self.hparams.optimizer_init["weight_decay"],
            },
            {
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if any(nd in n for nd in self.no_decay) and p.requires_grad
                ],
                "weight_decay": 0.0,
            },
        ]

    def configure_optimizers(self):
        # the phase 0 parameters will have been set to require gradients during setup
        # you can initialize the optimizer with a simple requires_grad filter as is often done,
        # but in this case we pass a list of parameter groups to ensure weight_decay is
        # not applied to the bias parameter (for completeness, in this case it won't make much
        # performance difference)
        optimizer = AdamW(params=self._init_param_groups(), **self.hparams.optimizer_init)
        scheduler = {
            "scheduler": CosineAnnealingWarmRestarts(optimizer, **self.hparams.lr_scheduler_init),
            "interval": "epoch",
        }
        return [optimizer], [scheduler]
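The _init_param_groups helper above routes each trainable parameter into a weight-decay group or a no-decay group by substring match against self.no_decay. The pure-Python sketch below mirrors that routing on plain (name, requires_grad) pairs so the grouping can be inspected without loading a model. The parameter names are hypothetical stand-ins, and build_param_groups is an illustrative helper of our own, not part of any library.

```python
from typing import Dict, List, Tuple

# Same substrings as self.no_decay in RteBoolqModule
NO_DECAY = ["bias", "LayerNorm.weight"]


def build_param_groups(
    named_params: List[Tuple[str, bool]], weight_decay: float, no_decay: List[str] = NO_DECAY
) -> List[Dict]:
    """Split (name, requires_grad) pairs into a decay group and a no-decay group.

    Frozen parameters (requires_grad=False) are left out of both groups,
    mirroring the `p.requires_grad` filter in _init_param_groups.
    """
    decay, no_decay_names = [], []
    for name, requires_grad in named_params:
        if not requires_grad:
            continue
        # plain substring match, so any name containing "bias" skips weight decay
        if any(nd in name for nd in no_decay):
            no_decay_names.append(name)
        else:
            decay.append(name)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay_names, "weight_decay": 0.0},
    ]


# Hypothetical parameter names; the embedding is frozen, as in an early fine-tuning phase
params = [
    ("model.classifier.weight", True),
    ("model.classifier.bias", True),
    ("model.deberta.encoder.LayerNorm.weight", True),
    ("model.deberta.encoder.LayerNorm.bias", True),
    ("model.deberta.embeddings.word_embeddings.weight", False),
]
groups = build_param_groups(params, weight_decay=1e-5)
for group in groups:
    print(group["weight_decay"], group["params"])
```

When AdamW receives groups like these, each group's own weight_decay takes precedence over the default weight_decay passed in through optimizer_init.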
Our Training Sessions

We’ll be comparing three different fine-tuning training configurations. Every configuration in this example depends upon a shared set of defaults, only differing in their respective fine-tuning schedules.

Experiment Tag  | Training Scenario Description
--------------- | ------------------------------------------------------------------------
fts_explicit    | Training with a fine-tuning schedule explicitly provided by the user
nofts_baseline  | A baseline fine-tuning training session (without scheduled fine-tuning)
fts_implicit    | Training with an implicitly generated fine-tuning schedule (the default)

Let’s begin by configuring the fts_explicit scenario. We’ll subsequently run the other two scenarios for comparison.

[6]:
# Let's create a fine-tuning schedule for our model and run an explicitly scheduled fine-tuning training scenario with it
# Please see the [FinetuningScheduler documentation](https://finetuning-scheduler.readthedocs.io/en/stable/index.html) for a full description of the schedule format


ft_schedule_yaml = """
0:
  params:
  - model.classifier.bias
  - model.classifier.weight
  - model.pooler.dense.bias
  - model.pooler.dense.weight
  - model.deberta.encoder.LayerNorm.bias
  - model.deberta.encoder.LayerNorm.weight
  - model.deberta.encoder.rel_embeddings.weight
  - model.deberta.encoder.layer.{0,11}.(output|attention|intermediate).*
1:
  params:
  - model.deberta.embeddings.LayerNorm.bias
  - model.deberta.embeddings.LayerNorm.weight
2:
  params:
  - model.deberta.embeddings.word_embeddings.weight
"""
ft_schedule_name = "RteBoolqModule_ft_schedule_deberta_base.yaml"
# Let's write the schedule to a file so we can simulate loading an explicitly defined fine-tuning
# schedule.
with open(ft_schedule_name, "w") as f:
    f.write(ft_schedule_yaml)
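The last phase 0 entry is a pattern rather than a literal parameter name. Assuming each entry is interpreted as a Python regular expression that must match the full parameter name (an assumption for illustration; consult the FinetuningScheduler documentation for the exact matching rules), the sketch below shows which names it would select. Under regex semantics, `{0,11}` is a repetition count on `.` (any character), not a range of layer indices, so the pattern matches every encoder layer, which appears consistent with the 86.0 M trainable parameters reported once training starts. The parameter names are hypothetical examples in the DeBERTa-v3 naming style.

```python
import re
from typing import List

# The pattern entry from phase 0 of the schedule above
PHASE_0_PATTERN = r"model.deberta.encoder.layer.{0,11}.(output|attention|intermediate).*"


def select_params(pattern: str, param_names: List[str]) -> List[str]:
    """Return the names that the pattern matches in full (assumed matching rule)."""
    regex = re.compile(pattern)
    return [name for name in param_names if regex.fullmatch(name)]


# Hypothetical parameter names in the DeBERTa-v3 naming style
names = [
    "model.deberta.encoder.layer.0.attention.self.query_proj.weight",
    "model.deberta.encoder.layer.5.intermediate.dense.weight",
    "model.deberta.encoder.layer.11.output.dense.bias",
    "model.deberta.encoder.LayerNorm.weight",  # capital "L": not matched by "layer"
    "model.deberta.embeddings.word_embeddings.weight",
]
for name in select_params(PHASE_0_PATTERN, names):
    print(name)
```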
[7]:
datasets.logging.disable_progress_bar()
pl.seed_everything(42)
dm = RteBoolqDataModule(model_name_or_path="microsoft/deberta-v3-base", tokenizers_parallelism=True)
Global seed set to 42
Optimizer Configuration

Though other optimizers can arguably yield some marginal advantage depending on the context, the Adam optimizer (and the AdamW variant, which implements decoupled weight decay) remains robust to hyperparameter choices and is commonly used for fine-tuning foundational language models. See (Sivaprasad et al., 2020) and (Mosbach, Andriushchenko & Klakow, 2020) for theoretical and systematic empirical justifications of Adam and its use in fine-tuning large transformer-based language models. The values used here have some justification in the referenced literature but have been largely determined empirically; while they are a good starting point, they could be further tuned.

[8]:
optimizer_init = {"weight_decay": 1e-05, "eps": 1e-07, "lr": 1e-05}
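To make "decoupled weight decay" concrete, here is a scalar sketch of one AdamW step using the optimizer_init values above. It follows the update rule in the PyTorch AdamW documentation (shrink the parameter by lr * weight_decay, then apply the bias-corrected Adam step); betas=(0.9, 0.999) are PyTorch's defaults, which this notebook does not override. It is an illustration only, not a replacement for torch.optim.AdamW.

```python
import math
from typing import Tuple


def adamw_step(
    param: float,
    grad: float,
    m: float,
    v: float,
    t: int,
    lr: float = 1e-5,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-7,
    weight_decay: float = 1e-5,
) -> Tuple[float, float, float]:
    """One AdamW update for a single scalar parameter; t is the 1-based step count."""
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the weight directly, independently of the gradient
    param = param * (1 - lr * weight_decay)
    # Exponential moving averages of the gradient and the squared gradient
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction compensates for m and v starting at zero
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


p, m, v = adamw_step(param=1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(p)  # close to 1 - lr: the first step moves the weight by roughly lr
```

With a zero gradient only the decay term acts. That is what distinguishes AdamW from adding an L2 penalty to the loss, where the decay would pass through Adam's adaptive scaling.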
LR Scheduler Configuration

The CosineAnnealingWarmRestarts scheduler fits nicely with our iterative fine-tuning since it does not depend upon a global max_epoch value. The importance of initial warmup is reduced due to the innate warmup effect of Adam bias correction [5] and the gradual thawing we are performing. Note that commonly used LR schedulers that depend on providing max_iterations/epochs (e.g. the CosineWarmupScheduler used in other pytorch-lightning tutorials) also work with FinetuningScheduler. Though the LR scheduler is theoretically justified (Loshchilov & Hutter, 2016), the particular values provided here are primarily empirically driven.

FinetuningScheduler also supports LR scheduler reinitialization in both explicit and implicit fine-tuning schedule modes. See the advanced usage documentation for explanations and a demonstration of the extension’s support for more complex requirements.

[9]:
lr_scheduler_init = {"T_0": 1, "T_mult": 2, "eta_min": 1e-07}
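With T_0=1 and T_mult=2, and the scheduler stepped once per epoch ("interval": "epoch"), cycle lengths double (1, 2, 4, 8, ...) and the learning rate restarts at epochs 0, 1, 3, 7, 15, ... The sketch below evaluates the SGDR formula from Loshchilov & Hutter (2016) at integer epochs. It is our own re-derivation for illustration, not PyTorch's implementation, and it ignores any interaction between the scheduler and fine-tuning phase transitions.

```python
import math
from typing import List


def warm_restart_lr(epoch: int, eta_max: float, T_0: int = 1, T_mult: int = 2, eta_min: float = 1e-7) -> float:
    """Cosine annealing with warm restarts, evaluated at an integer epoch (SGDR formula)."""
    if T_0 < 1 or T_mult < 1:
        raise ValueError("T_0 and T_mult must be positive integers")
    T_i, T_cur = T_0, epoch
    # Walk through completed cycles; each new cycle is T_mult times longer
    while T_cur >= T_i:
        T_cur -= T_i
        T_i *= T_mult
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * T_cur / T_i))


# eta_max is the optimizer's initial lr (1e-5 in optimizer_init)
schedule: List[float] = [warm_restart_lr(e, eta_max=1e-5) for e in range(8)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr={lr:.3e}")
```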
[10]:
# Load our lightning module...
lightning_module_kwargs = {
    "model_name_or_path": "microsoft/deberta-v3-base",
    "optimizer_init": optimizer_init,
    "lr_scheduler_init": lr_scheduler_init,
}
model = RteBoolqModule(**lightning_module_kwargs, experiment_tag="fts_explicit")
Callback Configuration

The only callback required to invoke the FinetuningScheduler is the FinetuningScheduler callback itself. Default versions of FTSCheckpoint and FTSEarlyStopping (if not specifying epoch_only_transitions) will be included (as discussed above) if not provided in the callbacks list. For demonstration purposes I’m including example configurations of all three callbacks below.

[11]:
# let's save our callback configurations for the explicit scenario since we'll be reusing the same
# configurations for the implicit and nofts_baseline scenarios (except, of course, the config for the
# FinetuningScheduler callback itself in the case of nofts_baseline)
earlystopping_kwargs = {"monitor": "val_loss", "min_delta": 0.001, "patience": 2}
checkpoint_kwargs = {"monitor": "val_loss", "save_top_k": 1}
fts_kwargs = {"max_depth": 1}
callbacks = [
    fts.FinetuningScheduler(ft_schedule=ft_schedule_name, **fts_kwargs),
    fts.FTSEarlyStopping(**earlystopping_kwargs),
    fts.FTSCheckpoint(**checkpoint_kwargs),
]
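Our reading of fts_kwargs = {"max_depth": 1} is that only phases 0 and 1 of the schedule are executed, with each phase thawing its parameters on top of those already thawed, so phase 2 (the word embeddings) never trains in this run. The toy function below models that cumulative thawing on an abbreviated copy of the schedule; thawed_by_phase is a hypothetical helper for illustration, not FinetuningScheduler's implementation.

```python
from typing import Dict, List

# Abbreviated copy of ft_schedule_yaml (phase -> parameter names/patterns)
SCHEDULE: Dict[int, List[str]] = {
    0: ["model.classifier.bias", "model.classifier.weight"],  # abbreviated
    1: ["model.deberta.embeddings.LayerNorm.bias", "model.deberta.embeddings.LayerNorm.weight"],
    2: ["model.deberta.embeddings.word_embeddings.weight"],
}


def thawed_by_phase(schedule: Dict[int, List[str]], max_depth: int) -> Dict[int, List[str]]:
    """Map each executed phase to the cumulative list of thawed entries."""
    thawed: List[str] = []
    result: Dict[int, List[str]] = {}
    for depth in sorted(schedule):
        if depth > max_depth:
            break  # phases deeper than max_depth are never reached
        thawed.extend(schedule[depth])
        result[depth] = list(thawed)
    return result


phases = thawed_by_phase(SCHEDULE, max_depth=1)
for depth, names in phases.items():
    print(depth, len(names), names[-1])
```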
[12]:
logger = TensorBoardLogger("lightning_logs", name="fts_explicit")
# optionally start tensorboard and monitor progress graphically while viewing multi-phase fine-tuning specific training
# logs in the cell output below by uncommenting the next 2 lines
# %load_ext tensorboard
# %tensorboard --logdir lightning_logs
# disable progress bar by default to focus on multi-phase training logs. Set to True to re-enable if desired
enable_progress_bar = False
[13]:


def train() -> None:
    trainer = pl.Trainer(
        enable_progress_bar=enable_progress_bar,
        max_epochs=100,
        precision=16,
        accelerator="auto",
        devices=1 if is_cuda_available() else None,
        callbacks=callbacks,
        logger=logger,
    )
    trainer.fit(model, datamodule=dm)


print(
    "Note given the computation associated w/ the multiple phases of fine-tuning demonstrated, this notebook is best used with an accelerator"
)
train()
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Note given the computation associated w/ the multiple phases of fine-tuning demonstrated, this notebook is best used with an accelerator
Downloading and preparing dataset super_glue/rte to /root/.cache/huggingface/datasets/super_glue/rte/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed...
Missing logger folder: lightning_logs/fts_explicit
Dataset super_glue downloaded and prepared to /root/.cache/huggingface/datasets/super_glue/rte/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed. Subsequent calls will reuse this data.
fine-tuning schedule dumped to lightning_logs/fts_explicit/version_0/RteBoolqModule_ft_schedule.yaml.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type                               | Params
-------------------------------------------------------------
0 | model | DebertaV2ForSequenceClassification | 184 M
-------------------------------------------------------------
86.0 M    Trainable params
98.4 M    Non-trainable params
184 M     Total params
368.847   Total estimated model params size (MB)
Restoring states from the checkpoint path at lightning_logs/fts_explicit/version_0/checkpoints/epoch=1-step=312.ckpt
Restored all states from the checkpoint file at lightning_logs/fts_explicit/version_0/checkpoints/epoch=1-step=312.ckpt
Multi-phase fine-tuned training continuing at level 1.
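The parameter summary above splits the 184 M parameters into 86.0 M trainable and 98.4 M non-trainable at phase 0. Assuming the microsoft/deberta-v3-base config values vocab_size=128100 and hidden_size=768 (taken from the public model config, not from this notebook), the embedding parameters reserved for phases 1 and 2 are consistent with that 98.4 M, as this arithmetic check shows.

```python
VOCAB_SIZE = 128_100  # assumption: microsoft/deberta-v3-base config
HIDDEN_SIZE = 768  # assumption: microsoft/deberta-v3-base config


def embedding_params(num_embeddings: int, dim: int) -> int:
    return num_embeddings * dim


def layernorm_params(dim: int) -> int:
    return 2 * dim  # elementwise weight and bias


# Phase 1 thaws embeddings.LayerNorm; phase 2 thaws embeddings.word_embeddings
frozen_at_phase_0 = embedding_params(VOCAB_SIZE, HIDDEN_SIZE) + layernorm_params(HIDDEN_SIZE)
print(f"{frozen_at_phase_0:,} parameters ~ {frozen_at_phase_0 / 1e6:.1f} M")
```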
Running the Baseline and Implicit Fine-Tuning Scenarios

Let’s now compare our nofts_baseline and fts_implicit scenarios with the fts_explicit one we just ran.

We’ll need to update our callbacks list, using the core PL EarlyStopping and ModelCheckpoint callbacks for the nofts_baseline (which operate identically to their FTS analogs apart from the recursive training support). For both core PyTorch Lightning and user-registered callbacks, we can define our callbacks using a dictionary as we do with the LightningCLI. This allows us to avoid managing imports and support more complex configuration separated from code.

Note that we’ll be using identical callback configurations to the fts_explicit scenario. Keeping max_depth at 1 for the implicit schedule will limit fine-tuning to just the last 4 parameters of the model, which is only a small fraction of the parameters you’d want to tune for maximum performance. Since the implicit schedule is quite computationally intensive and most useful for exploring model behavior, leaving max_depth at 1 allows us to demo implicit mode behavior while keeping the computational cost and runtime of this notebook reasonable. To review how a full implicit mode run compares to the nofts_baseline and fts_explicit scenarios, please see the following tensorboard experiment summary.

[14]:
nofts_callbacks = [EarlyStopping(**earlystopping_kwargs), ModelCheckpoint(**checkpoint_kwargs)]
fts_implicit_callbacks = [
    fts.FinetuningScheduler(**fts_kwargs),
    fts.FTSEarlyStopping(**earlystopping_kwargs),
    fts.FTSCheckpoint(**checkpoint_kwargs),
]
scenario_callbacks = {"nofts_baseline": nofts_callbacks, "fts_implicit": fts_implicit_callbacks}
[15]:
for scenario_name, scenario_cbs in scenario_callbacks.items():
    model = RteBoolqModule(**lightning_module_kwargs, experiment_tag=scenario_name)
    logger = TensorBoardLogger("lightning_logs", name=scenario_name)
    callbacks = scenario_cbs  # train() reads the global `callbacks`, `model` and `logger`
    print(f"Beginning training the '{scenario_name}' scenario")
    train()
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Beginning training the 'nofts_baseline' scenario
Missing logger folder: lightning_logs/nofts_baseline
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type                               | Params
-------------------------------------------------------------
0 | model | DebertaV2ForSequenceClassification | 184 M
-------------------------------------------------------------
184 M     Trainable params
0         Non-trainable params
184 M     Total params
368.847   Total estimated model params size (MB)
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Beginning training the 'fts_implicit' scenario
Missing logger folder: lightning_logs/fts_implicit
fine-tuning schedule dumped to lightning_logs/fts_implicit/version_0/RteBoolqModule_ft_schedule.yaml.
Generated default fine-tuning schedule 'lightning_logs/fts_implicit/version_0/RteBoolqModule_ft_schedule.yaml' for iterative fine-tuning
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type                               | Params
-------------------------------------------------------------
0 | model | DebertaV2ForSequenceClassification | 184 M
-------------------------------------------------------------
1.5 K     Trainable params
184 M     Non-trainable params
184 M     Total params
368.847   Total estimated model params size (MB)
Restoring states from the checkpoint path at lightning_logs/fts_implicit/version_0/checkpoints/epoch=0-step=156.ckpt
Restored all states from the checkpoint file at lightning_logs/fts_implicit/version_0/checkpoints/epoch=0-step=156.ckpt
Multi-phase fine-tuned training continuing at level 1.
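In the fts_implicit run only 1.5 K parameters are trainable at phase 0. Assuming RTE is configured with two labels and the classification head is a Linear(768, 2) layer on top of DeBERTa's 768-dimensional pooler (both assumptions about this model's configuration), the generated phase 0 would contain just model.classifier.weight and model.classifier.bias, and phase 1 would presumably add the pooler's dense layer, giving the "last 4 parameters" mentioned earlier.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameter count of a fully connected layer: weight matrix plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)


HIDDEN_SIZE = 768  # assumption: deberta-v3-base hidden size
NUM_LABELS = 2  # assumption: RTE is a binary task

classifier = linear_params(HIDDEN_SIZE, NUM_LABELS)  # phase 0: classifier.weight, classifier.bias
pooler_dense = linear_params(HIDDEN_SIZE, HIDDEN_SIZE)  # assumed phase 1: pooler.dense.weight/.bias
print(f"phase 0 trainable: {classifier:,} (~{classifier / 1e3:.1f} K)")
print(f"phases 0-1 trainable: {classifier + pooler_dense:,}")
```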
Reviewing the Training Results

See the tensorboard experiment summaries to get a sense of the relative computational and performance tradeoffs associated with these FinetuningScheduler configurations. The summary compares a full fts_implicit execution to fts_explicit and nofts_baseline scenarios using DDP training with 2 GPUs. The full logs/schedules and detailed system configuration used for all three scenarios are available here and the checkpoints produced in the scenarios here (caution, ~3.5GB).

[Figures: fts_explicit_accuracy; nofts_baseline]

Note that the results above may differ slightly from the tensorboard summaries generated by this notebook, which uses DP with 1 GPU and, by the time you run it, likely different versions of certain software components (e.g. pytorch, transformers).

FinetuningScheduler expands the space of possible fine-tuning schedules, and composing more sophisticated schedules can yield marginal fine-tuning performance gains. That said, it should be emphasized that the primary utility of FinetuningScheduler is to grant greater fine-tuning flexibility for model exploration in research. For example, glancing at DeBERTa-v3’s implicit training run, a critical tuning transition point is immediately apparent:

[Figure: implicit_training_transition]

Our val_loss begins a precipitous decline at step 3119, which corresponds to phase 17 in the schedule. Referring to our schedule, in phase 17 we begin tuning the attention parameters of our 10th encoder layer (of 11). Interesting! Though beyond the scope of this tutorial, it might be worth investigating these dynamics further, and FinetuningScheduler makes that quite easy.

Note that though this example is intended to capture a common usage scenario, substantial variation is expected among use cases and models. In summary, FinetuningScheduler provides increased fine-tuning flexibility that can be useful in a variety of contexts from exploring model tuning behavior to maximizing performance.

Footnotes

  1. Howard, J., & Ruder, S. (2018). Fine-tuned Language Models for Text Classification. arXiv preprint arXiv:1801.06146.

  2. Chronopoulou, A., Baziotis, C., & Potamianos, A. (2019). An embarrassingly simple approach for transfer learning from pretrained language models. arXiv preprint arXiv:1902.10547.

  3. Peters, M. E., Ruder, S., & Smith, N. A. (2019). To tune or not to tune? Adapting pretrained representations to diverse tasks. arXiv preprint arXiv:1903.05987.

  4. Sivaprasad, P. T., Mai, F., Vogels, T., Jaggi, M., & Fleuret, F. (2020). Optimizer benchmarking needs to account for hyperparameter tuning. In International Conference on Machine Learning (pp. 9036-9045). PMLR.

  5. Mosbach, M., Andriushchenko, M., & Klakow, D. (2020). On the stability of fine-tuning BERT: Misconceptions, explanations, and strong baselines. arXiv preprint arXiv:2006.04884.

  6. Loshchilov, I., & Hutter, F. (2016). SGDR: Stochastic gradient descent with warm restarts. arXiv preprint arXiv:1608.03983.

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Great thanks from the entire PyTorch Lightning Team for your interest!

Pytorch Lightning

Introduction to PyTorch Lightning

  • Author: PL team

  • License: CC BY-SA

  • Generated: 2023-01-05T12:09:29.379466

In this notebook, we’ll go over the basics of Lightning by preparing models to train on the MNIST Handwritten Digits dataset.



Give us a ⭐ on Github | Check out the documentation | Join us on Slack

Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "torchmetrics>=0.11.0, <0.12" "seaborn" "ipython[notebook]>=8.0.0, <8.9.0" "pytorch-lightning>=1.4, <1.9" "setuptools==65.6.3" "pandas" "torchvision" "torch>=1.8.1, <1.14.0"
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

[2]:
import os

import pandas as pd
import seaborn as sn
import torch
from IPython.display import display
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64

Simplest example

Here’s the simplest, most minimal example, with just a training loop (no validation, no testing).

Keep in Mind - A LightningModule is a PyTorch nn.Module - it just has a few more helpful features.

[3]:
class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

By using the Trainer you automatically get:
  1. TensorBoard logging
  2. Model checkpointing
  3. Training and validation loops
  4. Early-stopping support (via the EarlyStopping callback)

[4]:
# Init our model
mnist_model = MNISTModel()

# Init DataLoader from MNIST Dataset
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

# Initialize a trainer
trainer = Trainer(
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting for IPython runs
    max_epochs=3,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
)

# Train the model ⚡
trainer.fit(mnist_model, train_loader)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /__w/6/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /__w/6/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /__w/6/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /__w/6/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /__w/6/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /__w/6/s/.datasets/MNIST/raw

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /__w/6/s/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 7.9 K
--------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=3` reached.
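The summary numbers follow directly from the architecture: MNISTModel has a single Linear(28 * 28, 10) layer. The check below reproduces the 7.9 K parameter count (7,850 exactly) and the 0.031 MB size estimate, assuming 4 bytes per float32 parameter and 1 MB = 10^6 bytes, which is the convention these logged sizes appear to follow.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


BYTES_PER_FLOAT32 = 4  # assumption: parameters stored as float32

n_params = linear_params(28 * 28, 10)
size_mb = n_params * BYTES_PER_FLOAT32 / 1e6
print(n_params, f"{size_mb:.3f} MB")
```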

A more complete MNIST Lightning Module Example

That wasn’t so hard, was it?

Now that we’ve got our feet wet, let’s dive in a bit deeper and write a more complete LightningModule for MNIST…

This time, we’ll bake all the dataset-specific pieces directly into the LightningModule. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.


Note what the following built-in functions are doing:
  1. prepare_data() 💾

    • This is where we can download the dataset. We point to our desired dataset and ask torchvision’s MNIST dataset class to download if the dataset isn’t found there.

    • Note we do not make any state assignments in this function (i.e. self.something = ...)

  2. setup(stage) ⚙️

    • Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).

    • Setup expects a ‘stage’ arg which is used to separate logic for ‘fit’ and ‘test’.

    • If you don’t mind loading all your datasets at once, you can set up a condition to allow for both ‘fit’ related setup and ‘test’ related setup to run whenever None is passed to stage (or ignore it altogether and exclude any conditionals).

    • Note this runs on every process (i.e. across all GPUs), so it is safe to make state assignments here.

  3. x_dataloader() ♻️

    • train_dataloader(), val_dataloader(), and test_dataloader() all return PyTorch DataLoader instances that are created by wrapping their respective datasets that we prepared in setup()
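In setup(), random_split(mnist_full, [55000, 5000]) carves the 60,000 MNIST training images into disjoint train and validation subsets. The underlying idea is to shuffle the indices once and slice them by the requested lengths. The stdlib sketch below shows that idea with a hypothetical split_indices helper; it is not torch's implementation and will not reproduce torch's exact split.

```python
import random
from typing import List, Sequence


def split_indices(n_total: int, lengths: Sequence[int], seed: int = 42) -> List[List[int]]:
    """Shuffle range(n_total) once and cut it into consecutive, non-overlapping chunks."""
    if sum(lengths) != n_total:
        raise ValueError("lengths must sum to n_total")
    permutation = random.Random(seed).sample(range(n_total), n_total)
    splits, offset = [], 0
    for length in lengths:
        splits.append(permutation[offset : offset + length])
        offset += length
    return splits


train_idx, val_idx = split_indices(60_000, [55_000, 5_000])
print(len(train_idx), len(val_idx))
```

Fixing the seed makes the split reproducible across runs; with torch, the equivalent is passing a seeded torch.Generator to random_split via its generator argument.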

[5]:
class LitMNIST(LightningModule):
    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

        self.val_accuracy = Accuracy(task="multiclass", num_classes=10)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy.update(preds, y)

        # Calling self.log will surface scalars in the Trainer's logger (TensorBoard by default)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_accuracy.update(preds, y)

        # Calling self.log will surface scalars in the Trainer's logger (TensorBoard by default)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
[6]:
model = LitMNIST()
trainer = Trainer(
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting for IPython runs
    max_epochs=3,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    logger=CSVLogger(save_dir="logs/"),
)
trainer.fit(model)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: logs/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type               | Params
-----------------------------------------------------
0 | model         | Sequential         | 55.1 K
1 | val_accuracy  | MulticlassAccuracy | 0
2 | test_accuracy | MulticlassAccuracy | 0
-----------------------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=3` reached.
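Likewise, the 55.1 K parameters reported for LitMNIST come entirely from its three Linear layers: Flatten, ReLU and Dropout have no parameters, and the two Accuracy metrics report 0. The sketch assumes the default hidden_size=64 and, as before, 4-byte float32 parameters with 1 MB = 10^6 bytes.

```python
from typing import List, Tuple


def linear_params(in_features: int, out_features: int) -> int:
    return in_features * out_features + out_features  # weights + biases


HIDDEN_SIZE = 64  # LitMNIST default
layers: List[Tuple[int, int]] = [(1 * 28 * 28, HIDDEN_SIZE), (HIDDEN_SIZE, HIDDEN_SIZE), (HIDDEN_SIZE, 10)]

total = sum(linear_params(i, o) for i, o in layers)
size_mb = total * 4 / 1e6  # assumes 4-byte float32 parameters and 1 MB = 10**6 bytes
print(total, f"{size_mb:.3f} MB")
```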
Testing

To test a model, call trainer.test(model).

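Or, if you’ve just trained a model, you can just call trainer.test() and Lightning will automatically test using the best saved checkpoint. Because the default ModelCheckpoint callback does not monitor any metric, the “best” checkpoint here is simply the one saved after the last epoch.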

[7]:
trainer.test()
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py:134: UserWarning: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.
  rank_zero_warn(
Restoring states from the checkpoint path at logs/lightning_logs/version_0/checkpoints/epoch=2-step=645.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loaded model weights from checkpoint at logs/lightning_logs/version_0/checkpoints/epoch=2-step=645.ckpt
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_acc              0.9228000044822693     │
│         test_loss             0.2596317529678345     │
└───────────────────────────┴───────────────────────────┘
[7]:
[{'test_loss': 0.2596317529678345, 'test_acc': 0.9228000044822693}]
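The checkpoint restored above is named epoch=2-step=645.ckpt. With 55,000 training samples, BATCH_SIZE=256 (a GPU was available) and the DataLoader default drop_last=False, each epoch runs ceil(55000 / 256) = 215 optimizer steps, so three epochs give 645 steps, and the epoch in the name is the zero-based index of the last epoch. The sketch below reproduces the name; checkpoint_name is a hypothetical helper that mimics the filename seen in the log.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch."""
    return num_samples // batch_size if drop_last else math.ceil(num_samples / batch_size)


def checkpoint_name(epoch: int, step: int) -> str:
    # Hypothetical helper mimicking the filename seen in the log above
    return f"epoch={epoch}-step={step}.ckpt"


per_epoch = steps_per_epoch(55_000, 256)
max_epochs = 3
print(per_epoch, checkpoint_name(max_epochs - 1, per_epoch * max_epochs))
```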
Bonus Tip

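Calling trainer.fit(model) again continues from the Trainer’s current state. Because this Trainer has already reached max_epochs=3, the call below stops right away (see the log); to keep training the same model, raise max_epochs, e.g. by creating a new Trainer with a larger value.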

[8]:
trainer.fit(model)
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory logs/lightning_logs/version_0/checkpoints exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type               | Params
-----------------------------------------------------
0 | model         | Sequential         | 55.1 K
1 | val_accuracy  | MulticlassAccuracy | 0
2 | test_accuracy | MulticlassAccuracy | 0
-----------------------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=3` reached.

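Because we passed a CSVLogger to the Trainer, the logged metrics were written to a metrics.csv file, which we can load with pandas and plot with seaborn. (In Colab, you can also use the TensorBoard magic function to view logs written by the default TensorBoardLogger.)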

[9]:

metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
del metrics["step"]
metrics.set_index("epoch", inplace=True)
display(metrics.dropna(axis=1, how="all").head())
sn.relplot(data=metrics, kind="line")
       val_loss  val_acc  test_loss  test_acc
epoch
0      0.432111   0.8884        NaN       NaN
1      0.310814   0.9124        NaN       NaN
2      0.264833   0.9224        NaN       NaN
2           NaN      NaN   0.259632    0.9228
[9]:
<seaborn.axisgrid.FacetGrid at 0x7ff9b037ee20>
[Figure: seaborn line plot of the logged metrics by epoch]
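The NaN cells appear because each row of metrics.csv only holds the metrics logged at that moment: validation rows carry val_*, while the single test row carries test_*. The stdlib sketch below rebuilds the displayed values (without the step column) and merges them per epoch. merge_by_epoch is a hypothetical helper, and the CSV text is a simplified reconstruction of the table above rather than the real file.

```python
import csv
import io
from typing import Dict

# Simplified reconstruction of the table above (the real file also has a "step" column)
METRICS_CSV = """epoch,val_loss,val_acc,test_loss,test_acc
0,0.432111,0.8884,,
1,0.310814,0.9124,,
2,0.264833,0.9224,,
2,,,0.259632,0.9228
"""


def merge_by_epoch(text: str) -> Dict[int, Dict[str, float]]:
    """Combine rows that share an epoch, skipping empty (unlogged) cells."""
    merged: Dict[int, Dict[str, float]] = {}
    for row in csv.DictReader(io.StringIO(text)):
        epoch = int(row.pop("epoch"))
        values = {k: float(v) for k, v in row.items() if v != ""}
        merged.setdefault(epoch, {}).update(values)
    return merged


for epoch, values in merge_by_epoch(METRICS_CSV).items():
    print(epoch, values)
```

In pandas, metrics.groupby("epoch").last() gives a similar merged view, since GroupBy.last takes the last non-null value in each column.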

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Great thanks from the entire PyTorch Lightning Team for your interest!

PyTorch Lightning

TPU training with PyTorch Lightning

  • Author: PL team

  • License: CC BY-SA

  • Generated: 2023-01-03T15:41:22.863312

In this notebook, we’ll train a model on TPUs. Updating one Trainer flag is all you need for that. The most up-to-date documentation related to TPU training can be found here.


Give us a ⭐ on Github | Check out the documentation | Join us on Slack

Setup

This notebook requires some packages besides pytorch-lightning.

[ ]:
! pip install --quiet "pytorch-lightning>=1.4" "setuptools==59.5.0" "torch>=1.8" "torchmetrics>=0.7" "torchvision" "ipython[notebook]"
Install Colab TPU compatible PyTorch/TPU wheels and dependencies
[ ]:
! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
[ ]:
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision import transforms

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST

BATCH_SIZE = 1024
Defining The MNISTDataModule

Below we define MNISTDataModule. You can learn more about datamodules in the docs.

[ ]:
class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
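For reference, the Normalize transform in MNISTDataModule maps each pixel value x (already scaled to [0, 1] by ToTensor) to (x - mean) / std, using 0.1307 and 0.3081, the values commonly quoted as the MNIST training-set mean and standard deviation. The pure-Python sketch below works through that arithmetic for a black and a white pixel; `normalize_pixel` is a hypothetical helper for illustration, not a torchvision function.

```python
MNIST_MEAN = 0.1307  # mean passed to transforms.Normalize above
MNIST_STD = 0.3081   # standard deviation passed to transforms.Normalize above


def normalize_pixel(x: float, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Apply (x - mean) / std to a single pixel value in [0, 1] (hypothetical helper)."""
    return (x - mean) / std


# A black pixel (0.0) and a white pixel (1.0) as produced by ToTensor
low = normalize_pixel(0.0)
high = normalize_pixel(1.0)
print(f"{low:.4f} {high:.4f}")
```

So after normalization, MNIST pixel values lie roughly in [-0.42, 2.82] rather than in [0, 1].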
Defining the LitModel

Below, we define the model, LitModel.

[ ]:
class LitModel(LightningModule):
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
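        # torchmetrics >= 0.11 requires the task type for functional accuracy
        acc = accuracy(preds, y, task="multiclass", num_classes=self.hparams.num_classes)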
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer
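As a quick sanity check on model size, the parameter count of LitModel can be worked out by hand: only the three nn.Linear layers have parameters (Flatten, ReLU and Dropout have none), and each contributes n_in * n_out weights plus n_out biases. The sketch below assumes the MNIST dimensions (1, 28, 28), 10 classes and the default hidden_size=64, and estimates memory at 4 bytes per float32 parameter; the helper names are hypothetical.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of an nn.Linear(n_in, n_out) layer."""
    return n_in * n_out + n_out


def litmodel_params(channels: int = 1, width: int = 28, height: int = 28,
                    num_classes: int = 10, hidden_size: int = 64) -> int:
    """Parameter count of the LitModel MLP (hypothetical helper)."""
    # Flatten, ReLU and Dropout are parameter-free; only the Linear layers count.
    return (
        linear_params(channels * width * height, hidden_size)
        + linear_params(hidden_size, hidden_size)
        + linear_params(hidden_size, num_classes)
    )


n_params = litmodel_params()
size_mb = n_params * 4 / 1e6  # assuming float32, i.e. 4 bytes per parameter
print(n_params, round(size_mb, 3))
```

This gives 55,050 parameters, roughly 0.22 MB in float32.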
TPU Training

Lightning supports training on a single TPU core or 8 TPU cores.

The Trainer parameter devices, used together with accelerator='tpu', defines either how many TPU cores to train on (1 or 8) or, when passed as a one-element list such as [1], which single TPU core to train on.

For single-core training on a specific core, pass the TPU core ID (1-8) in a list. Setting devices=[5] will train on TPU core ID 5.
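The rules for the devices argument described above can be summarised as a small decision function. `describe_tpu_devices` below is a hypothetical illustration of those rules only; it is not how Lightning parses the argument internally, and it assumes core IDs run from 1 to 8 as stated in the text.

```python
from typing import List, Union


def describe_tpu_devices(devices: Union[int, List[int]]) -> str:
    """Hypothetical helper restating the devices rules from the text (not Lightning's parser)."""
    if isinstance(devices, list):
        # A list selects one specific core by its ID
        if len(devices) != 1:
            raise ValueError("pass exactly one core ID in the list")
        core_id = devices[0]
        if not 1 <= core_id <= 8:
            raise ValueError("TPU core IDs range from 1 to 8")
        return f"single TPU core, ID {core_id}"
    # An integer gives the number of cores: 1 or 8
    if devices in (1, 8):
        return f"{devices} TPU core(s)"
    raise ValueError("devices must be 1, 8, or a one-element list of core IDs")


print(describe_tpu_devices([5]))
print(describe_tpu_devices(8))
```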

Train on TPU core ID 5 with devices=[5].

[ ]:
# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
# Init trainer
trainer = Trainer(
    max_epochs=3,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    accelerator="tpu",
    devices=[5],
)
# Train
trainer.fit(model, dm)

Train on single TPU core with devices=1.

[ ]:
# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
# Init trainer
trainer = Trainer(
    max_epochs=3,
    accelerator="tpu",
    devices=1,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
)
# Train
trainer.fit(model, dm)

Train on 8 TPU cores with accelerator='tpu' and devices=8. You might have to restart the notebook to run it on 8 TPU cores after training on a single TPU core.

[ ]:
# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
# Init trainer
trainer = Trainer(
    max_epochs=3,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    accelerator="tpu",
    devices=8,
)
# Train
trainer.fit(model, dm)

How to train a Deep Q Network

  • Author: PL team

  • License: CC BY-SA

  • Generated: 2022-04-28T08:05:34.347059

Main takeaways:

  1. RL has the same flow as previous models we have seen, with a few additions

  2. Handle unsupervised learning by using an IterableDataset where the dataset itself is constantly updated during training

  3. Each training step has the agent take an action in the environment and store the experience in the IterableDataset


Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "ipython[notebook]" "seaborn" "torchmetrics>=0.6" "pygame" "gym" "pandas" "pytorch-lightning>=1.4" "torch>=1.6, <1.9"
WARNING: You are using pip version 21.3.1; however, version 22.0.4 is available.
You should consider upgrading via the '/usr/bin/python3.8 -m pip install --upgrade pip' command.
[2]:
import os
from collections import OrderedDict, deque, namedtuple
from typing import Iterator, List, Tuple

import gym
import numpy as np
import pandas as pd
import seaborn as sn
import torch
from IPython.display import display
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import CSVLogger
from torch import Tensor, nn
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
[3]:
class DQN(nn.Module):
    """Simple MLP network."""

    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        """
        Args:
            obs_size: observation/state size of the environment
            n_actions: number of discrete actions available in the environment
            hidden_size: size of hidden layers
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )

    def forward(self, x):
        return self.net(x.float())
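The parameter count of DQN follows directly from its two nn.Linear layers. For CartPole, the environment used below, observations are 4-dimensional and there are 2 discrete actions; the sketch assumes those sizes and the default hidden_size=128. `dqn_params` is a hypothetical helper.

```python
def dqn_params(obs_size: int, n_actions: int, hidden_size: int = 128) -> int:
    """Parameter count of the DQN MLP (hypothetical helper)."""
    # nn.Linear(obs_size, hidden_size): weights + biases
    first = obs_size * hidden_size + hidden_size
    # nn.Linear(hidden_size, n_actions): weights + biases
    second = hidden_size * n_actions + n_actions
    return first + second


# CartPole: 4-dimensional observations, 2 discrete actions
per_net = dqn_params(obs_size=4, n_actions=2)
print(per_net, 2 * per_net)  # online net, then online + target net
```

That is 898 parameters per network, or 1,796 (about 1.8 K) for the online and target networks together, which agrees with the model summary printed during training.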
Memory
[4]:

# Named tuple for storing experience steps gathered in training
Experience = namedtuple(
    "Experience",
    field_names=["state", "action", "reward", "done", "new_state"],
)
[5]:
class ReplayBuffer:
    """Replay Buffer for storing past experiences allowing the agent to learn from them.

    Args:
        capacity: size of the buffer
    """

    def __init__(self, capacity: int) -> None:
        self.buffer = deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, experience: Experience) -> None:
        """Add experience to the buffer.

        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> Tuple:
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*(self.buffer[idx] for idx in indices))

        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=bool),
            np.array(next_states),
        )
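The buffer relies on deque(maxlen=capacity): once the deque is full, each append silently drops the oldest experience, so the buffer always holds the most recent capacity steps. The stdlib-only sketch below shows this with a capacity of 3. Note also that sample() draws indices with replace=False, so NumPy raises a ValueError if batch_size exceeds the number of stored experiences; this is one reason DQNLightning fills the buffer with random steps (populate) before training starts.

```python
from collections import deque, namedtuple

Experience = namedtuple("Experience", ["state", "action", "reward", "done", "new_state"])

buffer = deque(maxlen=3)  # tiny capacity so that eviction is visible
for step in range(5):
    buffer.append(Experience(state=step, action=0, reward=1.0, done=False, new_state=step + 1))

# Only the three most recent experiences survive
print([exp.state for exp in buffer])
```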
[6]:
class RLDataset(IterableDataset):
    """Iterable Dataset containing the ExperienceBuffer which will be updated with new experiences during training.

    Args:
        buffer: replay buffer
        sample_size: number of experiences to sample at a time
    """

    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self) -> Iterator[Tuple]:
        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
        for i in range(len(dones)):
            yield states[i], actions[i], rewards[i], dones[i], new_states[i]
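Because every iteration over RLDataset draws sample_size experiences and then stops, a Lightning "epoch" in this notebook is only a handful of optimizer steps: ceil(sample_size / batch_size), since DataLoader keeps the last partial batch by default. The sketch below assumes the DQNLightning defaults used later (episode_length=200 passed as sample_size, batch_size=16); `epoch_of` is a hypothetical helper.

```python
import math

SAMPLE_SIZE = 200  # episode_length, passed to RLDataset as sample_size
BATCH_SIZE = 16    # batch_size hyperparameter of DQNLightning

# DataLoader's default drop_last=False keeps the final partial batch
batches_per_epoch = math.ceil(SAMPLE_SIZE / BATCH_SIZE)


def epoch_of(global_step: int) -> int:
    """Epoch in which a given optimizer step falls, counting both from 0."""
    return global_step // batches_per_epoch


print(batches_per_epoch, epoch_of(49), epoch_of(99))
```

With these values an epoch is 13 steps, so global steps 49 and 99 fall in epochs 3 and 7, which matches the epoch index of the first rows in the logged metrics table.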
Agent
[7]:
class Agent:
    """Base Agent class handling the interaction with the environment."""

    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
        """
        Args:
            env: training environment
            replay_buffer: replay buffer storing experiences
        """
        self.env = env
        self.replay_buffer = replay_buffer
        self.reset()

    def reset(self) -> None:
        """Resets the environment and updates the state."""
        self.state = self.env.reset()

    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
        """Using the given network, decide what action to carry out using an epsilon-greedy policy.

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            action
        """
        if np.random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            state = torch.tensor([self.state])

            if device not in ["cpu"]:
                state = state.cuda(device)

            q_values = net(state)
            _, action = torch.max(q_values, dim=1)
            action = int(action.item())

        return action

    @torch.no_grad()
    def play_step(
        self,
        net: nn.Module,
        epsilon: float = 0.0,
        device: str = "cpu",
    ) -> Tuple[float, bool]:
        """Carries out a single interaction step between the agent and the environment.

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            reward, done
        """

        action = self.get_action(net, epsilon, device)

        # do step in the environment
        new_state, reward, done, _ = self.env.step(action)

        exp = Experience(self.state, action, reward, done, new_state)

        self.replay_buffer.append(exp)

        self.state = new_state
        if done:
            self.reset()
        return reward, done
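Agent.get_action implements epsilon-greedy exploration: with probability epsilon it samples a random action, otherwise it takes the action with the highest predicted Q-value. The sketch below restates that rule in plain Python, with a list of toy Q-values standing in for the network output and a uniform random index standing in for env.action_space.sample() (an assumption that holds for a discrete action space such as CartPole's). `epsilon_greedy` is a hypothetical helper.

```python
import random
from typing import List


def epsilon_greedy(q_values: List[float], epsilon: float, rng: random.Random) -> int:
    """Pick a random action with probability epsilon, otherwise the argmax of q_values."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))  # explore
    return max(range(len(q_values)), key=lambda a: q_values[a])  # exploit


rng = random.Random(0)
q = [0.1, 0.7]  # toy Q-values for a 2-action task
greedy = [epsilon_greedy(q, epsilon=0.0, rng=rng) for _ in range(5)]
explore = [epsilon_greedy(q, epsilon=1.0, rng=rng) for _ in range(200)]
print(greedy, sorted(set(explore)))
```

With epsilon=0.0 the agent always exploits (action 1 here); with epsilon=1.0 every action is random, which is what populate() uses to fill the buffer.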
DQN Lightning Module
[8]:
class DQNLightning(LightningModule):
    """Basic DQN Model."""

    def __init__(
        self,
        batch_size: int = 16,
        lr: float = 1e-2,
        env: str = "CartPole-v0",
        gamma: float = 0.99,
        sync_rate: int = 10,
        replay_size: int = 1000,
        warm_start_size: int = 1000,
        eps_last_frame: int = 1000,
        eps_start: float = 1.0,
        eps_end: float = 0.01,
        episode_length: int = 200,
        warm_start_steps: int = 1000,
    ) -> None:
        """
        Args:
            batch_size: size of the batches
            lr: learning rate
            env: gym environment tag
            gamma: discount factor
            sync_rate: how many frames between updates of the target network
            replay_size: capacity of the replay buffer
            warm_start_size: how many samples we use to fill our buffer at the start of training
            eps_last_frame: frame at which epsilon stops decaying
            eps_start: starting value of epsilon
            eps_end: final value of epsilon
            episode_length: number of experiences sampled from the buffer per epoch (RLDataset sample_size)
            warm_start_steps: number of random steps used to fill the replay buffer before training
        """
        super().__init__()
        self.save_hyperparameters()

        self.env = gym.make(self.hparams.env)
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n

        self.net = DQN(obs_size, n_actions)
        self.target_net = DQN(obs_size, n_actions)

        self.buffer = ReplayBuffer(self.hparams.replay_size)
        self.agent = Agent(self.env, self.buffer)
        self.total_reward = 0
        self.episode_reward = 0
        self.populate(self.hparams.warm_start_steps)

    def populate(self, steps: int = 1000) -> None:
        """Carries out several random steps through the environment to initially fill up the replay buffer with
        experiences.

        Args:
            steps: number of random steps to populate the buffer with
        """
        for _ in range(steps):
            self.agent.play_step(self.net, epsilon=1.0)

    def forward(self, x: Tensor) -> Tensor:
        """Passes in a state x through the network and gets the q_values of each action as an output.

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.net(x)
        return output

    def dqn_mse_loss(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
        """Calculates the mse loss using a mini batch from the replay buffer.

        Args:
            batch: current mini batch of replay data

        Returns:
            loss
        """
        states, actions, rewards, dones, next_states = batch

        state_action_values = self.net(states).gather(1, actions.long().unsqueeze(-1)).squeeze(-1)

        with torch.no_grad():
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0
            next_state_values = next_state_values.detach()

        expected_state_action_values = next_state_values * self.hparams.gamma + rewards

        return nn.MSELoss()(state_action_values, expected_state_action_values)

    def get_epsilon(self, start: float, end: float, frames: int) -> float:
        if self.global_step > frames:
            return end
        return start - (self.global_step / frames) * (start - end)

    def training_step(self, batch: Tuple[Tensor, Tensor], nb_batch) -> Tensor:
        """Carries out a single step through the environment to update the replay buffer. Then calculates loss
        based on the minibatch received.

        Args:
            batch: current mini batch of replay data
            nb_batch: batch number

        Returns:
            Training loss
        """
        device = self.get_device(batch)
        epsilon = self.get_epsilon(self.hparams.eps_start, self.hparams.eps_end, self.hparams.eps_last_frame)
        self.log("epsilon", epsilon)

        # step through environment with agent
        reward, done = self.agent.play_step(self.net, epsilon, device)
        self.episode_reward += reward
        self.log("episode reward", self.episode_reward)

        # calculates training loss
        loss = self.dqn_mse_loss(batch)

        if done:
            self.total_reward = self.episode_reward
            self.episode_reward = 0

        # Hard update of the target network: copy the online weights every sync_rate steps
        if self.global_step % self.hparams.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())

        self.log_dict(
            {
                "reward": reward,
                "train_loss": loss,
            }
        )
        self.log("total_reward", self.total_reward, prog_bar=True)
        self.log("steps", self.global_step, logger=False, prog_bar=True)

        return loss

    def configure_optimizers(self) -> Optimizer:
        """Initialize Adam optimizer."""
        optimizer = Adam(self.net.parameters(), lr=self.hparams.lr)
        return optimizer

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences."""
        dataset = RLDataset(self.buffer, self.hparams.episode_length)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader."""
        return self.__dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch."""
        return batch[0].device.index if self.on_gpu else "cpu"
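In equation form, dqn_mse_loss regresses the online network (parameters θ) towards a one-step Bellman target computed with the target network (parameters θ⁻). For a minibatch of N transitions (s_i, a_i, r_i, d_i, s'_i), where d_i is the done flag (1 if the episode ended), the code computes the following; θ⁻ is overwritten with θ every sync_rate steps in training_step.

```latex
y_i = r_i + \gamma \, (1 - d_i) \max_{a'} Q_{\theta^-}(s'_i, a'),
\qquad
\mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \bigl( Q_{\theta}(s_i, a_i) - y_i \bigr)^2
```

The factor (1 - d_i) corresponds to `next_state_values[dones] = 0.0`, and the average over the batch is the default mean reduction of nn.MSELoss.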
Trainer
[9]:

model = DQNLightning()

trainer = Trainer(
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
    max_epochs=150,
    val_check_interval=50,
    logger=CSVLogger(save_dir="logs/"),
)

trainer.fit(model)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name       | Type | Params
------------------------------------
0 | net        | DQN  | 898
1 | target_net | DQN  | 898
------------------------------------
1.8 K     Trainable params
0         Non-trainable params
1.8 K     Total params
0.007     Total estimated model params size (MB)
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:229: UserWarning: You called `self.log('total_reward', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.
  warning_cache.warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:229: UserWarning: You called `self.log('steps', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.
  warning_cache.warn(
[10]:

metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
del metrics["step"]
metrics.set_index("epoch", inplace=True)
display(metrics.dropna(axis=1, how="all").head())
sn.relplot(data=metrics, kind="line")
       epsilon  episode reward  reward  train_loss  total_reward
epoch
3      0.95149             5.0     1.0    0.189056          22.0
7      0.90199            15.0     1.0    1.432721          12.0
11     0.85249            18.0     1.0   30.838800          14.0
15     0.80299            68.0     1.0    3.394485          14.0
19     0.75349            21.0     1.0   18.886366          15.0
[10]:
<seaborn.axisgrid.FacetGrid at 0x7f02190640d0>
[Figure: line plot of the logged DQN metrics (epsilon, episode reward, reward, train_loss, total_reward) by epoch]
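The epsilon column can be reproduced by hand. With the defaults eps_start=1.0, eps_end=0.01 and eps_last_frame=1000, the linear schedule in get_epsilon gives exactly the logged values at global steps 49, 99, 149, 199 and 249, which suggests the rows were written every 50 steps (Lightning's default log_every_n_steps, assumed here). The standalone function below mirrors DQNLightning.get_epsilon with those defaults.

```python
def get_epsilon(global_step: int, start: float = 1.0, end: float = 0.01, frames: int = 1000) -> float:
    """Linear epsilon decay, mirroring DQNLightning.get_epsilon with its default hyperparameters."""
    if global_step > frames:
        return end
    return start - (global_step / frames) * (start - end)


for step in (49, 99, 149, 199, 249):
    print(step, round(get_epsilon(step), 5))
```

After eps_last_frame steps, epsilon stays at eps_end, so the agent keeps exploring 1% of the time.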

Finetune Transformers Models with PyTorch Lightning

  • Author: PL team

  • License: CC BY-SA

  • Generated: 2023-01-03T15:49:54.952421

This notebook uses HuggingFace’s datasets library to get data, which is wrapped in a LightningDataModule. Then, we write a class to perform text classification on any dataset from the GLUE Benchmark. (We only show CoLA and MRPC due to compute and disk constraints.)


Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "torchmetrics>=0.7" "torchtext>=0.9" "pytorch-lightning>=1.4" "setuptools==59.5.0" "torch>=1.8" "transformers" "datasets" "scikit-learn" "ipython[notebook]" "scipy"
[2]:
from datetime import datetime
from typing import Optional

import datasets
import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

Training BERT with Lightning

Lightning DataModule for GLUE
[3]:
class GLUEDataModule(LightningDataModule):
    task_text_field_map = {
        "cola": ["sentence"],
        "sst2": ["sentence"],
        "mrpc": ["sentence1", "sentence2"],
        "qqp": ["question1", "question2"],
        "stsb": ["sentence1", "sentence2"],
        "mnli": ["premise", "hypothesis"],
        "qnli": ["question", "sentence"],
        "rte": ["sentence1", "sentence2"],
        "wnli": ["sentence1", "sentence2"],
        "ax": ["premise", "hypothesis"],
    }

    glue_task_num_labels = {
        "cola": 2,
        "sst2": 2,
        "mrpc": 2,
        "qqp": 2,
        "stsb": 1,
        "mnli": 3,
        "qnli": 2,
        "rte": 2,
        "wnli": 2,
        "ax": 3,
    }

    loader_columns = [
        "datasets_idx",
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
    ]

    def __init__(
        self,
        model_name_or_path: str,
        task_name: str = "mrpc",
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.task_name = task_name
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size

        self.text_fields = self.task_text_field_map[task_name]
        self.num_labels = self.glue_task_num_labels[task_name]
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def setup(self, stage: str):
        self.dataset = datasets.load_dataset("glue", self.task_name)

        for split in self.dataset.keys():
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                batched=True,
                remove_columns=["label"],
            )
            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
            self.dataset[split].set_format(type="torch", columns=self.columns)

        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]

    def prepare_data(self):
        datasets.load_dataset("glue", self.task_name)
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def train_dataloader(self):
        return DataLoader(self.dataset["train"], batch_size=self.train_batch_size, shuffle=True)

    def val_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset["validation"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

    def test_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset["test"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

    def convert_to_features(self, example_batch, indices=None):
        # Either encode single sentence or sentence pairs
        if len(self.text_fields) > 1:
            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
        else:
            texts_or_text_pairs = example_batch[self.text_fields[0]]

        # Tokenize the text/text pairs
        features = self.tokenizer.batch_encode_plus(
            texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True
        )

        # Rename label to labels to make it easier to pass to model forward
        features["labels"] = example_batch["label"]

        return features
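The first half of convert_to_features decides what to hand to the tokenizer: for tasks with two text fields in task_text_field_map it builds (first, second) pairs, otherwise it passes the single column through unchanged. The sketch below isolates that step in plain Python; `texts_for_tokenizer` is a hypothetical helper and the sentences are made-up toy inputs.

```python
from typing import Dict, List, Tuple, Union


def texts_for_tokenizer(
    example_batch: Dict[str, List[str]], text_fields: List[str]
) -> Union[List[str], List[Tuple[str, str]]]:
    """Standalone copy of the input-selection step in convert_to_features (hypothetical helper)."""
    if len(text_fields) > 1:
        # Sentence-pair tasks (e.g. MRPC): one (first, second) tuple per example
        return list(zip(example_batch[text_fields[0]], example_batch[text_fields[1]]))
    # Single-sentence tasks (e.g. CoLA): the column itself, a list of strings
    return example_batch[text_fields[0]]


# Toy batches shaped like the dict a batched datasets map passes in
mrpc_batch = {"sentence1": ["A cat sat.", "It rained."], "sentence2": ["A cat was sitting.", "It was dry."]}
cola_batch = {"sentence": ["The book was read.", "Read book the was."]}

print(texts_for_tokenizer(mrpc_batch, ["sentence1", "sentence2"]))
print(texts_for_tokenizer(cola_batch, ["sentence"]))
```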

You could use this datamodule with standalone PyTorch if you wanted…

[4]:
dm = GLUEDataModule("distilbert-base-uncased")
dm.prepare_data()
dm.setup("fit")
next(iter(dm.train_dataloader()))
Downloading and preparing dataset glue/mrpc to /home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...
Dataset glue downloaded and prepared to /home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.
Found cached dataset glue (/home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2336: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
  warnings.warn(
[4]:
{'input_ids': tensor([[  101,  1996, 12598,  ...,     0,     0,     0],
         [  101,  1000,  1045,  ...,     0,     0,     0],
         [  101,  2610,  2056,  ...,     0,     0,     0],
         ...,
         [  101,  3041,  2023,  ...,     0,     0,     0],
         [  101,  1996, 12368,  ...,     0,     0,     0],
         [  101,  8040,  2278,  ...,     0,     0,     0]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'labels': tensor([0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1,
         1, 0, 0, 1, 0, 1, 0, 1])}
Transformer LightningModule
[5]:
class GLUETransformer(LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        task_name: str,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        eval_splits: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()

        self.save_hyperparameters()

        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
        self.metric = datasets.load_metric(
            "glue", self.hparams.task_name, experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
        )

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss, logits = outputs[:2]

        if self.hparams.num_labels > 1:
            preds = torch.argmax(logits, dim=1)
        elif self.hparams.num_labels == 1:
            preds = logits.squeeze()

        labels = batch["labels"]

        return {"loss": val_loss, "preds": preds, "labels": labels}

    def validation_epoch_end(self, outputs):
        if self.hparams.task_name == "mnli":
            for i, output in enumerate(outputs):
                # matched or mismatched
                split = self.hparams.eval_splits[i].split("_")[-1]
                preds = torch.cat([x["preds"] for x in output]).detach().cpu().numpy()
                labels = torch.cat([x["labels"] for x in output]).detach().cpu().numpy()
                loss = torch.stack([x["loss"] for x in output]).mean()
                self.log(f"val_loss_{split}", loss, prog_bar=True)
                split_metrics = {
                    f"{k}_{split}": v for k, v in self.metric.compute(predictions=preds, references=labels).items()
                }
                self.log_dict(split_metrics, prog_bar=True)
            return loss

        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
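The scheduler returned above scales the base learning rate by a factor that rises linearly during warmup and then decays linearly to zero at `num_training_steps` (here `self.trainer.estimated_stepping_batches`). The sketch below reimplements that multiplier in plain Python, assuming the same piecewise-linear formula as `get_linear_schedule_with_warmup`; `linear_warmup_decay` is a hypothetical helper, not part of Transformers or Lightning. Because the scheduler dict sets `"interval": "step"`, the factor is updated after every optimizer step rather than once per epoch.

```python
def linear_warmup_decay(step: int, warmup_steps: int, total_steps: int) -> float:
    """Return the learning-rate multiplier for a given optimizer step."""
    if step < warmup_steps:
        # Warmup: ramp linearly from 0 up to 1.
        return step / max(1, warmup_steps)
    # Decay: ramp linearly from 1 down to 0 at total_steps, then stay at 0.
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


if __name__ == "__main__":
    base_lr = 2e-5  # hypothetical value for self.hparams.learning_rate
    warmup, total = 10, 100  # hypothetical warmup_steps / estimated_stepping_batches
    for step in (0, 5, 10, 55, 100):
        print(step, base_lr * linear_warmup_decay(step, warmup, total))
```

With `warmup_steps=0` (a common default) the warmup branch is never taken and the schedule is a pure linear decay.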

Training

CoLA

See an interactive view of the CoLA dataset in NLP Viewer

[6]:
seed_everything(42)

dm = GLUEDataModule(model_name_or_path="albert-base-v2", task_name="cola")
dm.setup("fit")
model = GLUETransformer(
    model_name_or_path="albert-base-v2",
    num_labels=dm.num_labels,
    eval_splits=dm.eval_splits,
    task_name=dm.task_name,
)

trainer = Trainer(
    max_epochs=1,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limit to a single device for interactive (IPython) runs
)
trainer.fit(model, datamodule=dm)
Global seed set to 42
Downloading and preparing dataset glue/cola to /home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...
Dataset glue downloaded and prepared to /home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.
Some weights of the model checkpoint at albert-base-v2 were not used when initializing AlbertForSequenceClassification: ['predictions.bias', 'predictions.decoder.bias', 'predictions.decoder.weight', 'predictions.LayerNorm.bias', 'predictions.dense.bias', 'predictions.dense.weight', 'predictions.LayerNorm.weight']
- This IS expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/tmp/ipykernel_2900/3453308743.py:22: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate
  self.metric = datasets.load_metric(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Found cached dataset glue (/home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Missing logger folder: /__w/1/s/lightning_logs
Found cached dataset glue (/home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2336: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
  warnings.warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
Loading `train_dataloader` to estimate number of stepping batches.
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

  | Name  | Type                            | Params
----------------------------------------------------------
0 | model | AlbertForSequenceClassification | 11.7 M
----------------------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.740    Total estimated model params size (MB)
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=1` reached.
MRPC

See an interactive view of the MRPC dataset in NLP Viewer

[7]:
seed_everything(42)

dm = GLUEDataModule(
    model_name_or_path="distilbert-base-cased",
    task_name="mrpc",
)
dm.setup("fit")
model = GLUETransformer(
    model_name_or_path="distilbert-base-cased",
    num_labels=dm.num_labels,
    eval_splits=dm.eval_splits,
    task_name=dm.task_name,
)

trainer = Trainer(
    max_epochs=3,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limit to a single device for interactive (IPython) runs
)
trainer.fit(model, datamodule=dm)
Global seed set to 42
Found cached dataset glue (/home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2336: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
  warnings.warn(
Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'pre_classifier.bias', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Found cached dataset glue (/home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Found cached dataset glue (/home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2336: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
  warnings.warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
Loading `train_dataloader` to estimate number of stepping batches.
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

  | Name  | Type                                | Params
--------------------------------------------------------------
0 | model | DistilBertForSequenceClassification | 65.8 M
--------------------------------------------------------------
65.8 M    Trainable params
0         Non-trainable params
65.8 M    Total params
263.132   Total estimated model params size (MB)
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=3` reached.
MNLI
  • The MNLI dataset is huge, so we won’t train on it here.

  • Instead, we skip training and go straight to validation. Because the classification head is newly initialized, expect accuracy close to chance (about 1/3 for MNLI’s three classes).

See an interactive view of the MNLI dataset in NLP Viewer

[8]:
dm = GLUEDataModule(
    model_name_or_path="distilbert-base-cased",
    task_name="mnli",
)
dm.setup("fit")
model = GLUETransformer(
    model_name_or_path="distilbert-base-cased",
    num_labels=dm.num_labels,
    eval_splits=dm.eval_splits,
    task_name=dm.task_name,
)

trainer = Trainer(
    max_epochs=3,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limit to a single device for interactive (IPython) runs
)
trainer.validate(model, dm)
Downloading and preparing dataset glue/mnli to /home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...
Dataset glue downloaded and prepared to /home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2336: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
  warnings.warn(
Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'pre_classifier.bias', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Found cached dataset glue (/home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Found cached dataset glue (/home/AzDevOps_azpcontainer/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2336: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
  warnings.warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 1, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Validate metric             DataLoader 0               DataLoader 1        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     accuracy_matched          0.3181864619255066         0.3181864619255066     │
│    accuracy_mismatched        0.3182465434074402         0.3182465434074402     │
│     val_loss_matched          1.1081156730651855         1.1081156730651855     │
│    val_loss_mismatched         1.108543872833252          1.108543872833252     │
└───────────────────────────┴───────────────────────────┴───────────────────────────┘
[8]:
[{'val_loss_matched': 1.1081156730651855,
  'accuracy_matched': 0.3181864619255066,
  'val_loss_mismatched': 1.108543872833252,
  'accuracy_mismatched': 0.3182465434074402},
 {'val_loss_matched': 1.1081156730651855,
  'accuracy_matched': 0.3181864619255066,
  'val_loss_mismatched': 1.108543872833252,
  'accuracy_mismatched': 0.3182465434074402}]
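With two validation dataloaders, `validation_epoch_end` receives one list of step outputs per dataloader, which is why the MNLI branch loops over `outputs` and suffixes each metric with the split name. The plain-Python sketch below mimics that aggregation on made-up step outputs. It computes accuracy directly instead of calling `self.metric.compute`, assumes the split names are `validation_matched` and `validation_mismatched`, and `aggregate_split` is a hypothetical helper.

```python
from statistics import mean


def aggregate_split(step_outputs, split):
    """Combine per-batch outputs for one dataloader into split-suffixed metrics."""
    preds = [p for out in step_outputs for p in out["preds"]]
    labels = [y for out in step_outputs for y in out["labels"]]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return {
        f"val_loss_{split}": mean(out["loss"] for out in step_outputs),
        f"accuracy_{split}": correct / len(labels),
    }


# Hypothetical outputs: one list of batch dicts per validation dataloader.
outputs = [
    [{"loss": 1.0, "preds": [0, 1], "labels": [0, 2]}, {"loss": 1.2, "preds": [2], "labels": [2]}],
    [{"loss": 0.8, "preds": [1, 1], "labels": [1, 0]}],
]
eval_splits = ["validation_matched", "validation_mismatched"]

metrics = {}
for split_name, step_outputs in zip(eval_splits, outputs):
    split = split_name.split("_")[-1]  # "matched" or "mismatched"
    metrics.update(aggregate_split(step_outputs, split))
print(metrics)
```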

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time, you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Many thanks from the entire PyTorch Lightning team for your interest!


Multi-agent Reinforcement Learning With WarpDrive

This notebook introduces multi-agent reinforcement learning (MARL) with WarpDrive (Lan et al. https://arxiv.org/abs/2108.13976). WarpDrive is a flexible, lightweight, and easy-to-use open-source framework that implements end-to-end deep MARL on GPUs. By exploiting the parallelization capability of GPUs, together with several design choices that minimize communication overhead, WarpDrive achieves orders-of-magnitude speedups over implementations that mix CPU simulation with GPU models. WarpDrive also prioritizes user-friendliness: it provides utility functions to easily build MARL environments in CUDA and quality-of-life tools to run end-to-end MARL with just a few lines of code, and it is compatible with PyTorch. WarpDrive includes the following resources: code (https://github.com/salesforce/warp-drive), documentation (http://opensource.salesforce.com/warp-drive/), and a white paper (https://arxiv.org/abs/2108.13976).


Open in Colab

Give us a ⭐ on GitHub | Check out the documentation | Join us on Slack

Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "ffmpeg-python" "rl-warp-drive>=1.6.5" "setuptools==59.5.0" "ipython[notebook]" "torch>=1.8" "torch==1.10.*" "torchvision==0.11.*" "torchtext==0.11.*" "torchmetrics>=0.7" "pytorch-lightning>=1.4"

⚠️ PLEASE NOTE: This notebook requires a GPU runtime. If running on Colab, choose Runtime > Change runtime type from the menu, then select GPU in the ‘Hardware accelerator’ dropdown menu.

Introduction

This tutorial demonstrates a multi-agent reinforcement learning (RL) training loop with WarpDrive. WarpDrive is a flexible, lightweight, and easy-to-use RL framework that implements end-to-end deep multi-agent RL on a GPU (Graphics Processing Unit). By exploiting the massive parallelism of GPUs, it enables orders-of-magnitude faster RL than common implementations that blend CPU simulations and GPU models. WarpDrive is efficient because it runs simulations for multiple agents and multiple environment replicas in parallel, and it eliminates the back-and-forth copying of data between the CPU and the GPU at every step. As a result, WarpDrive:

  • Can simulate thousands of agents in each environment and thousands of environments in parallel.

  • Eliminates communication between the CPU and the GPU, and also within the GPU, since read and write operations occur in place.

  • Is fully compatible with PyTorch, a highly flexible and very fast deep learning framework.

  • Implements parallel action sampling in CUDA C, which is ~3x faster than using PyTorch’s sampling methods.

  • Allows for large-scale distributed training on multiple GPUs.

WarpDrive lays out computation and data on a single GPU as follows. Computations are organized into blocks, with multiple threads in each block. Each block runs one simulation environment, and each thread simulates one agent in that environment. Blocks can access the shared GPU memory that stores simulation data and neural network policy models. A DataManager and a FunctionManager make it possible to define multi-agent RL GPU workflows with Python APIs. For more details, please read our white paper.
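As a rough mental model of that layout, the plain-Python sketch below emulates one kernel launch: the outer loop stands in for blocks (environment replicas) and the inner loop for threads (agents), and each (block, thread) pair writes only to its own slot of a flat state array. The row-major `env_id * num_agents + agent_id` indexing and the helper names are illustrative assumptions, not WarpDrive's actual memory layout or API.

```python
def flat_index(env_id: int, agent_id: int, num_agents: int) -> int:
    """Row-major offset of (env, agent) in a flat [num_envs * num_agents] array."""
    return env_id * num_agents + agent_id


def step_all(positions, velocities, num_envs, num_agents):
    """Emulate one kernel launch: every (block, thread) pair updates its own slot in place."""
    for block in range(num_envs):  # one block per environment replica
        for thread in range(num_agents):  # one thread per agent
            i = flat_index(block, thread, num_agents)
            positions[i] += velocities[i]  # in-place write, no copy back to the host


num_envs, num_agents = 2, 3
positions = [0.0] * (num_envs * num_agents)
velocities = [float(i) for i in range(num_envs * num_agents)]
step_all(positions, velocities, num_envs, num_agents)
print(positions)  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
```

On a GPU the two loops disappear: all (block, thread) pairs execute concurrently, which is only safe because no two threads write to the same slot.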

The WarpDrive framework comprises several utility functions that make it easy to implement any (OpenAI) Gym-style RL environment, and it provides quality-of-life tools to train that environment end-to-end with just a few lines of code. You can familiarize yourself with WarpDrive with the help of these tutorials.

We invite everyone to contribute to WarpDrive, including adding new multi-agent environments, proposing new features, and reporting issues on our open-source repository.

We have integrated WarpDrive with the PyTorch Lightning framework, which greatly reduces trainer boilerplate code and improves training modularity and flexibility. It abstracts away most of the engineering code, so users can focus on research and building models and iterate on experiments quickly. PyTorch Lightning also makes it easy to run the model on any hardware and supports distributed training, model checkpointing, performance profiling, logging, and visualization.

Below, we demonstrate how to use WarpDrive and PyTorch Lightning together to train a game of Tag in which multiple tagger agents chase and try to tag multiple runner agents. Here’s a sample depiction of the game of Tag with 100 runners and 5 taggers.

Dependencies

[2]:
import logging

import torch
from example_envs.tag_continuous.tag_continuous import TagContinuous
from pytorch_lightning import Trainer
from warp_drive.env_wrapper import EnvWrapper
from warp_drive.training.pytorch_lightning import CUDACallback, PerfStatsCallback, WarpDriveModule

# Uncomment below for enabling animation visualizations.
# from example_envs.utils.generate_rollout_animation import generate_tag_env_rollout_animation
# from IPython.display import HTML
/usr/local/lib/python3.8/dist-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.
  setattr(self, word, getattr(machar, word).flat[0])
/usr/local/lib/python3.8/dist-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.
  return self._float_to_str(self.smallest_subnormal)
/usr/local/lib/python3.8/dist-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.
  setattr(self, word, getattr(machar, word).flat[0])
/usr/local/lib/python3.8/dist-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.
  return self._float_to_str(self.smallest_subnormal)
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/torchvision/transforms/functional_pil.py:296: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/torchvision/transforms/functional_pil.py:329: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/usr/local/lib/python3.8/dist-packages/comet_ml/monkey_patching.py:19: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
  import imp
/usr/local/lib/python3.8/dist-packages/mlflow/types/schema.py:48: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  binary = (7, np.dtype("bytes"), "BinaryType", np.object)
WARNING:root:Bagua cannot detect bundled NCCL library, Bagua will try to use system NCCL instead. If you encounter any error, please run `import bagua_core; bagua_core.install_deps()` or the `bagua_install_deps.py` script to install bundled libraries.
/usr/local/lib/python3.8/dist-packages/sklearn/utils/multiclass.py:14: DeprecationWarning: Please use `spmatrix` from the `scipy.sparse` namespace, the `scipy.sparse.base` namespace is deprecated.
  from scipy.sparse.base import spmatrix
/usr/local/lib/python3.8/dist-packages/sklearn/utils/optimize.py:18: DeprecationWarning: Please use `line_search_wolfe2` from the `scipy.optimize` namespace, the `scipy.optimize.linesearch` namespace is deprecated.
  from scipy.optimize.linesearch import line_search_wolfe2, line_search_wolfe1
/usr/local/lib/python3.8/dist-packages/sklearn/utils/optimize.py:18: DeprecationWarning: Please use `line_search_wolfe1` from the `scipy.optimize` namespace, the `scipy.optimize.linesearch` namespace is deprecated.
  from scipy.optimize.linesearch import line_search_wolfe2, line_search_wolfe1
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pycuda/compyte/dtypes.py:120: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  reg.get_or_register_dtype("bool", np.bool)
[3]:
assert torch.cuda.device_count() > 0, "This notebook only runs on a GPU!"
[4]:
# Set logger level e.g., DEBUG, INFO, WARNING, ERROR.
logging.getLogger().setLevel(logging.ERROR)

Specify a set of run configurations for your experiments

The run configuration is a dictionary comprising the environment parameters, the trainer and the policy network settings, as well as configurations for saving.

For our experiment, we consider an environment in which 5 taggers and 100 runners play the game of Tag on a 20 × 20 plane. The game lasts 200 timesteps. At every timestep, each agent chooses its own acceleration and turn actions, and the environment’s mechanics determine how the agents move over the grid. When a tagger gets close to a runner, the runner is tagged and eliminated from the game. In the configuration below, runners and taggers have the same unit skill level, i.e., the same top speed.
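The tagging rule and the rewards it triggers can be sketched for a single timestep as follows. This is a simplified illustration using the reward values from `run_config` below. It treats `tagging_distance` as an absolute Euclidean distance, which is an assumption (the actual `TagContinuous` environment may scale it, for example by the grid length), and `tag_step` is a hypothetical helper rather than WarpDrive code.

```python
import math

TAGGING_DISTANCE = 0.02  # assumed to be an absolute distance in this sketch
TAG_REWARD_FOR_TAGGER = 10.0
TAG_PENALTY_FOR_RUNNER = -10.0


def tag_step(taggers, runners, active):
    """Tag every active runner within TAGGING_DISTANCE of some tagger.

    taggers, runners: lists of (x, y) positions.
    active: one boolean per runner; tagged runners are set to False in place.
    Returns (tagger_rewards, runner_rewards) for this timestep.
    """
    tagger_rewards = [0.0] * len(taggers)
    runner_rewards = [0.0] * len(runners)
    for r, (rx, ry) in enumerate(runners):
        if not active[r]:
            continue  # this runner has already left the game
        for t, (tx, ty) in enumerate(taggers):
            if math.hypot(rx - tx, ry - ty) < TAGGING_DISTANCE:
                tagger_rewards[t] += TAG_REWARD_FOR_TAGGER
                runner_rewards[r] += TAG_PENALTY_FOR_RUNNER
                active[r] = False  # mirrors runner_exits_game_after_tagged=True
                break  # a runner can be tagged only once
    return tagger_rewards, runner_rewards


taggers = [(0.0, 0.0)]
runners = [(0.01, 0.0), (5.0, 5.0)]
active = [True, True]
tagger_rewards, runner_rewards = tag_step(taggers, runners, active)
print(tagger_rewards, runner_rewards, active)
```

At the final timestep, each runner that is still active would additionally receive `end_of_game_reward_for_runner`; that step is omitted here.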

We train the agents using 50 environments (simulations) running in parallel. With WarpDrive, each simulation runs on a separate GPU block.

Two separate policy networks are used, one for the taggers and one for the runners. Each network is a fully connected model with two hidden layers of 256 units each. We train with the Advantage Actor-Critic (A2C) algorithm; WarpDrive also currently supports Proximal Policy Optimization (PPO) as an alternative.
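A2C trains the actor with the advantage, i.e. the discounted return minus the critic’s value estimate. The sketch below computes both for one short trajectory with `gamma=0.98` from `run_config`. It is a textbook illustration (no bootstrapping from a final state value, no batching or normalization), not WarpDrive’s trainer code, and the rewards and value estimates are made up.

```python
def discounted_returns(rewards, gamma):
    """G_t = r_t + gamma * G_{t+1}, computed backwards from the end of the trajectory."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def advantages(rewards, values, gamma):
    """A_t = G_t - V(s_t): how much better the outcome was than the critic expected."""
    return [g - v for g, v in zip(discounted_returns(rewards, gamma), values)]


rewards = [0.0, 0.0, 10.0]  # e.g. a tagger that tags a runner at the last step
values = [9.0, 9.5, 10.0]  # hypothetical critic estimates V(s_t)
print(discounted_returns(rewards, 0.98))
print(advantages(rewards, values, 0.98))
```

In a standard A2C update, the actor loss weights -log π(a_t | s_t) by A_t, and the critic regresses V(s_t) toward G_t; how WarpDrive batches or normalizes these terms is not covered by this sketch.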

[5]:
run_config = dict(
    name="tag_continuous",
    # Environment settings.
    env=dict(
        # number of taggers in the environment
        num_taggers=5,
        # number of runners in the environment
        num_runners=100,
        # length of the (square) grid on which the game is played
        grid_length=20.0,
        # episode length in timesteps
        episode_length=200,
        # maximum acceleration
        max_acceleration=0.1,
        # minimum acceleration
        min_acceleration=-0.1,
        # maximum turn (in radians)
        max_turn=2.35,  # 3pi/4 radians
        # minimum turn (in radians)
        min_turn=-2.35,  # -3pi/4 radians
        # number of discretized acceleration actions
        num_acceleration_levels=10,
        # number of discretized turn actions
        num_turn_levels=10,
        # skill level for the tagger
        skill_level_tagger=1.0,
        # skill level for the runner
        skill_level_runner=1.0,
        # each agent sees the full (or partial) information of the world
        use_full_observation=False,
        # flag to indicate whether a runner exits the game after getting tagged
        runner_exits_game_after_tagged=True,
        # number of other agents each agent can see
        # used in the case use_full_observation is False
        num_other_agents_observed=10,
        # positive reward for a tagger upon tagging a runner
        tag_reward_for_tagger=10.0,
        # negative reward for a runner upon getting tagged
        tag_penalty_for_runner=-10.0,
        # reward at the end of the game for a runner that isn't tagged
        end_of_game_reward_for_runner=1.0,
        # distance margin between a tagger and runner
        # to consider the runner as being 'tagged'
        tagging_distance=0.02,
    ),
    # Trainer settings.
    trainer=dict(
        # number of environment replicas (number of GPU blocks used)
        num_envs=50,
        # total batch size used for training per iteration (across all the environments)
        train_batch_size=10000,
        # total number of episodes to run the training for
        # This can be set arbitrarily high!
        num_episodes=500,
    ),
    # Policy network settings.
    policy=dict(
        runner=dict(
            # flag indicating whether the model needs to be trained
            to_train=True,
            # algorithm used to train the policy
            algorithm="A2C",
            # discount rate
            gamma=0.98,
            # learning rate
            lr=0.005,
            # policy model settings
            model=dict(type="fully_connected", fc_dims=[256, 256], model_ckpt_filepath=""),
        ),
        tagger=dict(
            to_train=True,
            algorithm="A2C",
            gamma=0.98,
            lr=0.002,
            model=dict(type="fully_connected", fc_dims=[256, 256], model_ckpt_filepath=""),
        ),
    ),
    # Checkpoint saving setting.
    saving=dict(
        # how often (in iterations) to print the metrics
        metrics_log_freq=10,
        # how often (in iterations) to save the model parameters
        model_params_save_freq=5000,
        # base folder used for saving
        basedir="/tmp",
        # experiment name
        name="continuous_tag",
        # experiment tag
        tag="example",
    ),
)

Instantiate the WarpDrive Module

To instantiate the WarpDrive module, we first wrap the environment with an EnvWrapper, which specifies that the environment should run on the GPU (via the use_cuda flag). Agents in the environment can also share policy models, so we specify a dictionary that maps each policy model to the list of IDs of the agents that use it.

[6]:
# Create a wrapped environment object via the EnvWrapper.
# Ensure that use_cuda is set to True (in order to run on the GPU).
env_wrapper = EnvWrapper(
    TagContinuous(**run_config["env"]),
    num_envs=run_config["trainer"]["num_envs"],
    use_cuda=True,
)

# Agents can share policy models: this dictionary maps policy model names to agent ids.
policy_tag_to_agent_id_map = {
    "tagger": list(env_wrapper.env.taggers),
    "runner": list(env_wrapper.env.runners),
}

wd_module = WarpDriveModule(
    env_wrapper=env_wrapper,
    config=run_config,
    policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,
    verbose=True,
)
Global seed set to 1652830369
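Because the policy-to-agent mapping is a plain dictionary, it is easy to sanity-check before training. The sketch below uses a hypothetical helper, invert_policy_map, which is not part of the WarpDrive API. It inverts the map into an agent-to-policy lookup and verifies that every agent is controlled by exactly one policy. The agent IDs are made up for illustration: 5 taggers followed by 100 runners.

```python
from typing import Dict, List


def invert_policy_map(policy_map: Dict[str, List[int]]) -> Dict[int, str]:
    """Return an agent_id -> policy_tag lookup, rejecting agents mapped to two policies."""
    agent_to_policy: Dict[int, str] = {}
    for policy_tag, agent_ids in policy_map.items():
        for agent_id in agent_ids:
            if agent_id in agent_to_policy:
                raise ValueError(
                    f"agent {agent_id} is mapped to both "
                    f"{agent_to_policy[agent_id]!r} and {policy_tag!r}"
                )
            agent_to_policy[agent_id] = policy_tag
    return agent_to_policy


# Hypothetical IDs: 5 taggers followed by 100 runners.
example_map = {
    "tagger": list(range(0, 5)),
    "runner": list(range(5, 105)),
}
agent_to_policy = invert_policy_map(example_map)

# Every agent must be assigned to some policy.
num_agents = 105
missing = set(range(num_agents)) - set(agent_to_policy)
assert not missing, f"agents without a policy: {sorted(missing)}"
print(agent_to_policy[0], agent_to_policy[42])  # tagger runner
```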

Visualizing an episode roll-out before training

We have created a helper function (see below) to visualize an episode rollout. Internally, this function uses the WarpDrive module’s fetch_episode_states API to fetch the data arrays on the GPU for the duration of an entire episode. Specifically, we fetch the state arrays for the agents’ x and y locations on the plane, along with the indicators of which agents are still active in the game. Note that this function may be invoked at any time during training, and it will use the state of the policy models at that time to sample actions and generate the visualization.

The animation below shows a sample realization of the game episode before training, i.e., with randomly chosen agent actions. The 5 taggers are marked in pink, while the 100 blue agents are the runners. Both the taggers and runners move around randomly and about half the runners remain at the end of the episode.

[7]:
# Uncomment below for enabling animation visualizations.
# anim = generate_tag_env_rollout_animation(wd_module, fps=25)
# HTML(anim.to_html5_video())

Create the Lightning Trainer

Next, we create the trainer for training the WarpDrive model. We add the CUDA and performance stats callbacks to the trainer. The performance stats callback lets us view WarpDrive's throughput.

[8]:
log_freq = run_config["saving"]["metrics_log_freq"]

# Define callbacks.
cuda_callback = CUDACallback(module=wd_module)
perf_stats_callback = PerfStatsCallback(
    batch_size=wd_module.training_batch_size,
    num_iters=wd_module.num_iters,
    log_freq=log_freq,
)

# Instantiate the PyTorch Lightning trainer with the callbacks.
# Also, set the number of GPUs to 1, since this notebook uses just a single GPU.
num_gpus = 1
num_episodes = run_config["trainer"]["num_episodes"]
episode_length = run_config["env"]["episode_length"]
training_batch_size = run_config["trainer"]["train_batch_size"]
# max_epochs expects an integer, so use integer division here.
num_epochs = num_episodes * episode_length // training_batch_size

trainer = Trainer(
    accelerator="gpu",
    devices=num_gpus,
    callbacks=[cuda_callback, perf_stats_callback],
    max_epochs=num_epochs,
    log_every_n_steps=1,
    reload_dataloaders_every_n_epochs=1,
)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[9]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

Train the WarpDrive Module

Finally, we invoke training.

Note: please scroll up to the tensorboard cell to visualize the curves during training.

[10]:
trainer.fit(wd_module)
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:376: LightningDeprecationWarning: The `Callback.on_batch_start` hook was deprecated in v1.6 and will be removed in v1.8. Please use `Callback.on_train_batch_start` instead.
  rank_zero_deprecation(
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:376: LightningDeprecationWarning: The `Callback.on_batch_end` hook was deprecated in v1.6 and will be removed in v1.8. Please use `Callback.on_train_batch_end` instead.
  rank_zero_deprecation(
Missing logger folder: /__w/1/s/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name | Type | Params
------------------------------
------------------------------
0         Trainable params
0         Non-trainable params
0         Total params
0.000     Total estimated model params size (MB)
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
========================================
Metrics for policy 'runner'
========================================
VF loss coefficient                     :    0.01000
Entropy coefficient                     :    0.05000
Total loss                              :   -1.51269
Policy loss                             :   -1.31748
Value function loss                     :    4.30106
Mean rewards                            :   -0.02525
Max. rewards                            :    1.00000
Min. rewards                            :  -10.00000
Mean value function                     :   -0.86170
Mean advantages                         :   -0.27768
Mean (norm.) advantages                 :   -0.27768
Mean (discounted) returns               :   -1.13938
Mean normalized returns                 :   -1.13938
Mean entropy                            :    4.76451
Variance explained by the value function:    0.11032
Std. of action_0 over agents            :    3.04816
Std. of action_0 over envs              :    3.04446
Std. of action_0 over time              :    3.04757
Std. of action_1 over agents            :    3.23549
Std. of action_1 over envs              :    3.23271
Std. of action_1 over time              :    3.23722
Current timestep                        : 90000.00000
Gradient norm                           :    0.05845
Mean episodic reward                    : -408.38889
[Device 0]: Saving the results to the file '/tmp/continuous_tag/example/1652830363/results.json'
[Device 0]: Saving the 'runner' torch model to the file: '/tmp/continuous_tag/example/1652830363/runner_90000.state_dict'.
[Device 0]: Saving the 'tagger' torch model to the file: '/tmp/continuous_tag/example/1652830363/tagger_80000.state_dict'.
========================================
Metrics for policy 'tagger'
========================================
VF loss coefficient                     :    0.01000
Entropy coefficient                     :    0.05000
Total loss                              :   79.46014
Policy loss                             :   75.07774
Value function loss                     :  460.96414
Mean rewards                            :    0.53500
Max. rewards                            :   20.00000
Min. rewards                            :    0.00000
Mean value function                     :    3.43005
Mean advantages                         :   16.50640
Mean (norm.) advantages                 :   16.50640
Mean (discounted) returns               :   19.93644
Mean normalized returns                 :   19.93644
Mean entropy                            :    4.54485
Variance explained by the value function:   -0.00764
Std. of action_0 over agents            :    3.04688
Std. of action_0 over envs              :    3.19368
Std. of action_0 over time              :    3.19806
Std. of action_1 over agents            :    2.74155
Std. of action_1 over envs              :    2.85016
Std. of action_1 over time              :    2.85594
Current timestep                        : 90000.00000
Gradient norm                           :    1.21257
Mean episodic reward                    :  449.24444
[Device 0]: Saving the results to the file '/tmp/continuous_tag/example/1652830363/results.json'
[Device 0]: Saving the 'runner' torch model to the file: '/tmp/continuous_tag/example/1652830363/runner_90000.state_dict'.
[Device 0]: Saving the 'tagger' torch model to the file: '/tmp/continuous_tag/example/1652830363/tagger_90000.state_dict'.
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:229: UserWarning: You called `self.log('Current timestep_runner', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.
  warning_cache.warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:229: UserWarning: You called `self.log('Current timestep_tagger', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.
  warning_cache.warn(
========================================
Metrics for policy 'runner'
========================================
VF loss coefficient                     :    0.01000
Entropy coefficient                     :    0.05000
Total loss                              :   -1.06076
Policy loss                             :   -0.86573
Value function loss                     :    4.28389
Mean rewards                            :   -0.02681
Max. rewards                            :    1.00000
Min. rewards                            :  -10.00000
Mean value function                     :   -1.03110
Mean advantages                         :   -0.18345
Mean (norm.) advantages                 :   -0.18345
Mean (discounted) returns               :   -1.21455
Mean normalized returns                 :   -1.21455
Mean entropy                            :    4.75726
Variance explained by the value function:    0.13849
Std. of action_0 over agents            :    3.08665
Std. of action_0 over envs              :    3.08295
Std. of action_0 over time              :    3.08616
Std. of action_1 over agents            :    3.21539
Std. of action_1 over envs              :    3.21178
Std. of action_1 over time              :    3.21630
Current timestep                        : 100000.00000
Gradient norm                           :    0.05899
Mean episodic reward                    : -536.14000
[Device 0]: Saving the results to the file '/tmp/continuous_tag/example/1652830363/results.json'
========================================
Metrics for policy 'tagger'
========================================
VF loss coefficient                     :    0.01000
Entropy coefficient                     :    0.05000
Total loss                              :   77.55455
Policy loss                             :   72.94509
Value function loss                     :  482.91556
Mean rewards                            :    0.56020
Max. rewards                            :   20.00000
Min. rewards                            :    0.00000
Mean value function                     :    4.44337
Mean advantages                         :   16.58761
Mean (norm.) advantages                 :   16.58761
Mean (discounted) returns               :   21.03099
Mean normalized returns                 :   21.03099
Mean entropy                            :    4.39390
Variance explained by the value function:   -0.00993
Std. of action_0 over agents            :    2.94368
Std. of action_0 over envs              :    3.11596
Std. of action_0 over time              :    3.12263
Std. of action_1 over agents            :    2.66070
Std. of action_1 over envs              :    2.78366
Std. of action_1 over time              :    2.79009
Current timestep                        : 100000.00000
Gradient norm                           :    1.13135
Mean episodic reward                    :  560.20000
[Device 0]: Saving the results to the file '/tmp/continuous_tag/example/1652830363/results.json'
========================================
Speed performance stats
========================================
Iteration                               : 10 / 10
Mean training time per iter (ms)        :     131.28
Mean steps per sec (training time)      :   76172.00


Training is complete!
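As a quick sanity check on the speed stats above, the logged throughput matches the batch size divided by the mean time per iteration. The sketch below assumes that each iteration processes exactly train_batch_size environment steps. That assumption rests on the configuration and on the logged timestep count, not on the source of PerfStatsCallback.

```python
num_iters = 10                 # "Iteration: 10 / 10" in the log above
train_batch_size = 10_000      # run_config["trainer"]["train_batch_size"]
mean_ms_per_iter = 131.28      # "Mean training time per iter (ms)" in the log above

# Total environment steps processed; compare with "Current timestep" in the log.
total_env_steps = num_iters * train_batch_size
# Steps per second = steps per iteration / seconds per iteration.
steps_per_sec = train_batch_size / (mean_ms_per_iter / 1000.0)
print(f"{total_env_steps} steps, about {steps_per_sec:,.0f} steps/sec")
```

The small gap from the logged value of 76172 is consistent with the displayed time per iteration being rounded to two decimals.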

Visualize an episode-rollout after training

[11]:
# Uncomment below for enabling animation visualizations.
# anim = generate_tag_env_rollout_animation(wd_module, fps=25)
# HTML(anim.to_html5_video())

Note: In the configuration above, we set the trainer to train on only 500 rollout episodes, but you can increase the num_episodes configuration parameter to train further. As training progresses, the runners learn to escape the taggers, and the taggers learn to chase after the runners. Sometimes the taggers also collaborate to team-tag runners. For the configuration used here, a good number of episodes to train on is 2M or higher.
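To estimate how long a longer run will take, apply the same arithmetic the trainer cell uses to derive max_epochs. In this notebook, one epoch appears to correspond to one training iteration over train_batch_size environment steps. The sketch assumes episode_length=200. That value is not shown in this section, but it is consistent with the log, where 10 iterations of 10,000 steps cover 500 episodes. The helper name num_training_iterations is hypothetical.

```python
def num_training_iterations(num_episodes: int, episode_length: int, train_batch_size: int) -> int:
    """Iterations (used as max_epochs in this notebook) needed to cover num_episodes episodes."""
    total_env_steps = num_episodes * episode_length
    # Ceiling division: a partial final batch still gets its own iteration.
    return -(-total_env_steps // train_batch_size)


EPISODE_LENGTH = 200   # assumption, inferred from the training log
BATCH_SIZE = 10_000    # run_config["trainer"]["train_batch_size"]

print(num_training_iterations(500, EPISODE_LENGTH, BATCH_SIZE))        # 10
print(num_training_iterations(2_000_000, EPISODE_LENGTH, BATCH_SIZE))  # 40000
```

Rounding up is a design choice. It guarantees that at least num_episodes episodes are simulated when the total is not a multiple of the batch size.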

[12]:
# Finally, close the WarpDrive module to clear up the CUDA memory heap
wd_module.graceful_close()
[Device 0]: Trainer exits gracefully

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time, you can go to the Lightning or Bolts GitHub Issues page and filter for “good first issue”.

Many thanks from the entire PyTorch Lightning team for your interest!


Contributor Covenant Code of Conduct

Our Pledge

In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.

Our Standards

Examples of behavior that contributes to creating a positive environment include:

  • Using welcoming and inclusive language

  • Being respectful of differing viewpoints and experiences

  • Gracefully accepting constructive criticism

  • Focusing on what is best for the community

  • Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

  • The use of sexualized language or imagery and unwelcome sexual attention or advances

  • Trolling, insulting/derogatory comments, and personal or political attacks

  • Public or private harassment

  • Publishing others’ private information, such as a physical or electronic address, without explicit permission

  • Other conduct which could reasonably be considered inappropriate in a professional setting

Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.

Scope

This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.

Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at community@lightning.ai. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership.

Attribution

This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq

Contributing

Welcome to the PyTorch Lightning community! We’re building the most advanced research platform on the planet to implement the latest best practices and integrations that the amazing PyTorch team and other research organizations roll out!

If you are new to open source, check out this blog to get started with your first Open Source contribution.

Main Core Value: One less thing to remember

Simplify the API as much as possible from the user perspective. Any additions or improvements should minimize the things the user needs to remember.

For example, one benefit of the validation_step is that the user doesn’t have to remember to set the model to .eval(). This helps users avoid all sorts of subtle errors.
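To make the example concrete, the plain-PyTorch sketch below shows the bookkeeping a user would otherwise have to remember: switch to eval mode, disable gradient tracking while validating, then restore training mode. It is a simplified illustration, not Lightning's actual implementation, and run_validation is a hypothetical helper.

```python
import torch
from torch import nn


def run_validation(model: nn.Module, batches):
    """Evaluate model on batches, restoring its previous train/eval mode afterwards."""
    was_training = model.training
    model.eval()  # e.g. disables dropout and makes BatchNorm use running statistics
    try:
        with torch.no_grad():  # no autograd graph is needed for validation
            losses = [nn.functional.mse_loss(model(x), y) for x, y in batches]
    finally:
        model.train(was_training)  # restore the caller's mode even if an error occurs
    return torch.stack(losses).mean()


model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 1))
batches = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(3)]

# Dropout is inactive in eval mode, so two validation passes give identical results.
assert torch.equal(run_validation(model, batches), run_validation(model, batches))
assert model.training  # training mode was restored afterwards
```

Without model.eval(), the two passes would differ because dropout randomly zeroes activations. This is the kind of subtle error the principle above guards against.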

Lightning Design Principles

We encourage all sorts of contributions you’re interested in adding! When coding for Lightning, please follow these principles.

No PyTorch Interference

We don’t want to add any abstractions on top of pure PyTorch. This gives researchers all the control they need without having to learn yet another framework.

Simple Internal Code

It’s useful for users to look at the code and understand very quickly what’s happening. Many users won’t be engineers. Thus we need to value clear, simple code over condensed ninja moves. While that’s super cool, this isn’t the project for that :)

Force User Decisions To Best Practices

There are 1,000 ways to do something. However, eventually one popular solution becomes standard practice, and everyone follows. We try to find the best way to solve a particular problem, and then force our users to use it for readability and simplicity. A good example is accumulated gradients. There are many different ways to implement it; we just pick one and force users to use it. A bad forced decision would be to make users use a specific library to do something.

When something becomes a best practice, we add it to the framework. This is usually something like bits of code in utils or in the model file that everyone keeps adding over and over again across projects. When this happens, bring that code inside the trainer and add a flag for it.
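Lightning exposes accumulated gradients through the Trainer's accumulate_grad_batches flag, for example Trainer(accumulate_grad_batches=4). The plain-PyTorch sketch below is a simplified illustration of what the technique does, not Lightning's internal code. With a mean-reduced loss, summing the gradients of k equally sized micro-batches, each loss scaled by 1/k, gives the same gradient as one pass over the full batch.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)
loss_fn = nn.MSELoss()  # mean reduction

# Reference: gradient of one full batch of 8 samples.
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 2 samples, each loss scaled by 1/4.
accumulate_grad_batches = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accumulate_grad_batches), y.chunk(accumulate_grad_batches)):
    # backward() adds into .grad, so gradients accumulate across micro-batches.
    (loss_fn(model(xb), yb) / accumulate_grad_batches).backward()
accumulated_grad = model.weight.grad.clone()
# In a real loop, optimizer.step() and zero_grad() would follow once per accumulation window.

assert torch.allclose(full_grad, accumulated_grad, atol=1e-6)
```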

Simple External API

What makes sense to you may not make sense to others. When creating an issue with an API change suggestion, please validate that it makes sense for others. Treat code changes the way you treat a startup: validate that it’s a needed feature, then add if it makes sense for many people.

Backward-compatible API

We all hate updating our deep learning packages because we don’t want to refactor a bunch of stuff. In Lightning, we make sure every change we make which could break an API is backward compatible with good deprecation warnings.

You shouldn’t be afraid to upgrade Lightning :)
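The sketch below shows the general pattern using only the standard library. The old name keeps working but emits a DeprecationWarning that points to its replacement. The function names are hypothetical. Lightning emits its own deprecation warnings (such as LightningDeprecationWarning) through its own utilities, not through this exact code.

```python
import warnings


def new_api(x: int) -> int:
    """The current, supported function."""
    return x * 2


def old_api(x: int) -> int:
    """Deprecated alias kept for backward compatibility."""
    warnings.warn(
        "`old_api` is deprecated and will be removed in a future release. Use `new_api` instead.",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller's line, not this one
    )
    return new_api(x)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is recorded
    result = old_api(21)

assert result == 42
assert issubclass(caught[0].category, DeprecationWarning)
```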

Gain User Trust

As a researcher, you can’t have any part of your code going wrong. So, write thorough tests to ensure that every implementation of a new trick or subtle change is correct.

Interoperability

Have a favorite feature from other libraries like fast.ai or transformers? Those should just work with Lightning as well. Grab your favorite model or learning rate scheduler from your favorite library and run it in Lightning.


Contribution Types

We are always open to contributions of new features or bug fixes.

A lot of good work has already been done in project mechanics (requirements.txt, setup.py, pep8, badges, ci, etc…) so we’re in a good state there thanks to all the early contributors (even pre-beta release)!

Bug Fixes:
  1. If you find a bug please submit a GitHub issue.

    • Make sure the title explains the issue.

    • Describe your setup, what you are trying to do, expected vs. actual behaviour. Please add configs and code samples.

    • Add details on how to reproduce the issue. A minimal test case is always best, and a Colab notebook is also great. Note that the sample code should be minimal and, if data is needed, should use publicly available data.

  2. Try to fix it or recommend a solution. We highly recommend a test-driven approach:

    • Convert your minimal code example to a unit/integration test with assert on expected results.

    • Start by debugging the issue… You can run just this particular test in your IDE and draft a fix.

    • Verify that your test case fails on the master branch and only passes with the fix applied.

  3. Submit a PR!

Note that even if you do not find the solution, sending a PR with a test covering the issue is a valid contribution, and we can help you or finish it with you :]

New Features:
  1. Submit a GitHub issue describing the motivation for the feature (adding the use case or an example is helpful).

  2. Determine the feature scope with us.

  3. Submit a PR! We recommend a test-driven approach to adding new features as well:

    • Write a test for the functionality you want to add.

    • Write the functional code until the test passes.

  4. Add/update the relevant tests!

Test cases:

Want to keep Lightning healthy? Love seeing those green tests? So do we! How do we keep it that way? We write tests! We value test contributions even more than new features.

Most of the tests in PyTorch Lightning train a random BoringModel under various trainer conditions (ddp, amp, etc…). Want to add a new test case and not sure how? Talk to us!


Guidelines

Development scripts

To build the documentation and run the tests locally, execute the following commands from the project root (Unix only):

  • make clean cleans repo from temp/generated files

  • make docs builds documentation under docs/build/html

  • make test runs all project’s tests with coverage

Original code

All added or edited code must be the contributor’s own original work. If you use a third-party implementation, all such blocks/functions/modules must be properly credited and, if possible, approved by the code’s author. For example: This code is inspired from http://.... If you add new dependencies, make sure they are compatible with the PyTorch Lightning license (i.e., dependencies should be at least as permissive as the PyTorch Lightning license).

Coding Style
  1. Use f-strings for output formatting (except for logging, where we stick with lazy formatting: logging.info("Hello %s!", name)).

  2. You can use pre-commit to make sure your code style is correct.
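The first convention can be illustrated with the standard library. An f-string is formatted immediately, which is fine for output you always build. Passing arguments to the logger instead defers formatting until a handler actually emits the record, so a message filtered out by the log level is never formatted. The logger name and the Expensive class are hypothetical and exist only to make the laziness observable.

```python
import logging
import sys

log = logging.getLogger("style_example")  # hypothetical logger name
log.setLevel(logging.INFO)
log.addHandler(logging.StreamHandler(sys.stdout))
log.propagate = False  # keep this example's output self-contained


class Expensive:
    """Counts how often it is converted to a string."""

    str_calls = 0

    def __str__(self) -> str:
        Expensive.str_calls += 1
        return "expensive value"


name = "Lightning"
greeting = f"Hello {name}!"  # output formatting: use an f-string

obj = Expensive()
log.debug("Debug detail: %s", obj)  # DEBUG is disabled, so str(obj) is never called
log.info("Hello %s!", name)         # lazy formatting, done only when the record is emitted
log.info("Info detail: %s", obj)    # emitted, so str(obj) is called exactly once
# Avoid log.info(f"Detail: {obj}"): the f-string is formatted even if the record is dropped.
```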

Documentation

To learn about development of docs, check out the docs README.md.

Testing

To learn about tests, check out the tests README.md.

Pull Request

We welcome any useful contribution! For your convenience here’s a recommended workflow:

  1. Think about what you want to do: fix a bug, repair docs, etc. If you want to implement a new feature or enhance an existing one:

    • Start by opening a GitHub issue to explain the feature and the motivation. In the case of features, ask yourself first - Is this NECESSARY for Lightning? There are some PRs that are just purely about adding engineering complexity which has no place in Lightning.

    • Core contributors will take a look (it might take some time - we are often overloaded with issues!) and discuss it.

    • Once an agreement is reached, start coding.

  2. Start your work locally.

    • Create a branch and prepare your changes.

    • Tip: do not work on your master branch directly, it may become complicated when you need to rebase.

    • Tip: give your PR a good name! It will be useful later when you may work on multiple tasks/PRs.

  3. Test your code!

    • It is always good practice to start coding by creating a test case, verifying it breaks with current behaviour, and passes with your new changes.

    • Make sure your new tests cover all different edge cases.

    • Make sure all exceptions raised are tested.

    • Make sure all warnings raised are tested.

  4. If your PR is not ready for reviews, but you want to run it on our CI, open a “Draft PR” to let us know you don’t need feedback yet.

  5. If any of the existing tests fail in your PR on our CI, refer to the following READMEs to identify what’s failing and try to address it.

  6. When you feel ready for integrating your work, mark your PR “Ready for review”.

    • Your code should be readable and follow the project’s design principles.

    • Make sure all tests are passing and any new code is tested for (coverage!).

    • Make sure you link the GitHub issue to your PR.

    • Make sure any docs for that piece of code are updated, or added.

    • The code should be elegant and simple. No over-engineering or hard-to-read code.

    Do your best, but don’t sweat over perfection! We do code reviews to find any missed items. If you need help, don’t hesitate to ping the core team on the PR.

  7. Use tags in PR name for the following cases:

    • [blocked by #] if your work is dependent on other PRs.

    • [wip] when you start to re-edit your work, mark it so no one will accidentally merge it in the meantime.

Question & Answer
How can I help/contribute?

All types of contributions are welcome: reporting bugs, fixing documentation, adding test cases, solving issues, and preparing bug fixes. To get started with code contributions, look for issues marked with the label good first issue or choose something close to your domain with the label help wanted. Before coding, make sure that the issue description is clear and comment on the issue so that we can assign it to you (or simply self-assign if you can).

Is there a recommendation for branch names?

We recommend you follow this convention <type>/<issue-id>_<short-name> where the types are: bugfix, feature, docs, or tests (but if you are using your own fork that’s optional).
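The sketch below is a small, hypothetical helper that builds and checks names following the recommended <type>/<issue-id>_<short-name> convention. The slug rules (lowercase, hyphen-separated) are an assumption. The guide only prescribes the overall shape, and the issue number in the example is made up.

```python
import re

BRANCH_TYPES = ("bugfix", "feature", "docs", "tests")
# <type>/<issue-id>_<short-name>, with the short name as a lowercase hyphenated slug (assumed).
BRANCH_RE = re.compile(r"(bugfix|feature|docs|tests)/(\d+)_([a-z0-9]+(?:-[a-z0-9]+)*)")


def make_branch_name(kind: str, issue_id: int, short_name: str) -> str:
    """Build a branch name such as 'bugfix/1234_fix-ddp-hang'."""
    if kind not in BRANCH_TYPES:
        raise ValueError(f"type must be one of {BRANCH_TYPES}, got {kind!r}")
    # Collapse every run of non-alphanumeric characters into a single hyphen.
    slug = re.sub(r"[^a-z0-9]+", "-", short_name.lower()).strip("-")
    return f"{kind}/{issue_id}_{slug}"


def is_valid_branch_name(name: str) -> bool:
    """Check that the whole name follows the convention."""
    return BRANCH_RE.fullmatch(name) is not None


print(make_branch_name("bugfix", 1234, "Fix DDP hang"))  # bugfix/1234_fix-ddp-hang
```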

How to rebase my PR?

We recommend creating a PR in a separate branch other than master, especially if you plan to submit several changes and do not want to wait until the first one is resolved (we can work on them in parallel).

First, make sure you have set upstream by running:

git remote add upstream https://github.com/Lightning-AI/lightning.git

You’ll know it’s set up right if you run git remote -v and see something similar to this:

origin  https://github.com/{YOUR_USERNAME}/pytorch-lightning.git (fetch)
origin  https://github.com/{YOUR_USERNAME}/pytorch-lightning.git (push)
upstream        https://github.com/Lightning-AI/lightning.git (fetch)
upstream        https://github.com/Lightning-AI/lightning.git (push)

Checkout your feature branch and rebase it with upstream’s master before pushing up your feature branch:

git fetch --all --prune
git rebase upstream/master
# follow git instructions to resolve conflicts
git push -f
How to add new tests?

We are using pytest in PyTorch Lightning.

Here are tutorials:

Here is the process to create a new test:

    1. Optional: follow the tutorials!

    2. Find a file in tests/ that matches what you want to test. If none exists, create one.

    3. Use this template to get started!

    4. Use BoringModel and its derivatives to test out your code.

# TEST SHOULD BE IN YOUR FILE: tests/.../test_file.py
# TEST CODE TEMPLATE
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel  # import path may differ in older versions


# [OPTIONAL] pytest decorator
# @RunIf(min_cuda_gpus=1)
def test_explain_what_is_being_tested(tmpdir):
    """
    Describe what this test checks and why.
    """

    class ExtendedModel(BoringModel):
        ...

    model = ExtendedModel()

    # BoringModel is a functional model. You might want to set methods to None to test your behaviour
    # Example: model.training_step_end = None

    # Replace fast_dev_run with the Trainer flags under test;
    # everything is saved within the tmpdir generated for this test.
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)
    trainer.test(model)  # [OPTIONAL]

    # assert the behaviour is correct (replace the placeholder below with real checks).
    assert ...

Run your test with:

python -m pytest tests/.../test_file.py::test_explain_what_is_being_tested -v --capture=no
How to fix PR with mixed base and target branches?

Sometimes you start your PR as a bug fix, but it turns out to be more of a feature (or the other way around). Do not panic; the solution is straightforward. All you need to do is complete these two steps, in any order:

  • Ask someone from Core to change the base/target branch to the correct one

  • Rebase or cherry-pick your commits onto the correct base branch…

Here is how to handle this in git. As a sample case, suppose you are moving a PR from master to release/1.2-dev, your branch is named my-branch, the last commit that truly belongs to master is ccc111, and your first commit is mmm222.

  • Cherry-picking way

    git checkout my-branch
    # create a local backup of your branch
    git branch my-branch-backup
    # reset your branch to the correct base
    git reset release/1.2-dev --hard
    # ACTION: this step is much easier to do with an IDE,
    #  so open one and cherry-pick your last commits from `my-branch-backup`
    #  (or from the command line: git cherry-pick ccc111..my-branch-backup)
    #  resolve any conflicts, as the new base may contain different code
    # when all done, push back to the open PR
    git push -f
    
  • Rebasing way, see more about rebase onto usage

    git checkout my-branch
    # rebase your commits on the correct branch
    git rebase --onto release/1.2-dev ccc111
    # if there are no conflicts, the rebase simply succeeds
    #  otherwise, resolve the conflicts by following the instructions in the terminal
    # when all done, push back to the open PR
    git push -f
    
How to run an app on the cloud with a local version of lightning

The lightning cloud uses the latest release by default. However, you might want to run your app with some local changes you’ve made to the lightning framework. To use your local version of lightning on the cloud, set the following environment variable:

git clone https://github.com/Lightning-AI/lightning.git
cd lightning
pip install -e .
export PACKAGE_LIGHTNING=1  # <- this is the magic to use your version (not mainstream PyPI)!
lightning run app app.py --cloud

By setting PACKAGE_LIGHTNING=1, lightning packages the lightning source code in your local directory in addition to your app source code and uploads them to the cloud.

Bonus Workflow Tip

If you don’t want to remember all the commands above every time you want to push some code/setup a Lightning Dev environment on a new VM, you can set up bash aliases for some common commands. You can add these to one of your ~/.bashrc, ~/.zshrc, or ~/.bash_aliases files.

NOTE: Once you edit one of these files, remember to source it or restart your shell. (ex. source ~/.bashrc if you added these to your ~/.bashrc file).

plclone (){
    git clone https://github.com/{YOUR_USERNAME}/pytorch-lightning.git
    cd pytorch-lightning
    git remote add upstream https://github.com/Lightning-AI/lightning.git
    # This is just here to print out info about your remote upstream/origin
    git remote -v
}

plfetch (){
    git fetch --all --prune
    git checkout master
    git merge upstream/master
}

# Rebase your branch with upstream's master
# plrebase <your-branch-name>
plrebase (){
    git checkout "$1"  # quoted so branch names are passed through unchanged
    git rebase master
}

Now, you can:

  • clone your fork and set up upstream by running plclone from your terminal

  • fetch upstream and update your local master branch with it by running plfetch

  • rebase your feature branch (after running plfetch) by running plrebase your-branch-name

How to Become a Core Contributor

Thanks for your interest in joining the Lightning team! We’re a rapidly growing project which is poised to become the go-to framework for DL researchers!

As a core maintainer, you will have a strong say in the direction of the project. Big changes will require a majority of maintainers to agree. Our development is fully open, so you can still raise your voice just by commenting on issues and pull requests! Doing so is a big step in becoming part of core.

Code of Conduct

First and foremost, you'll be evaluated against our core values. Any code we commit or feature we add needs to align with those core values. Any collaboration and communication must adhere to our code of conduct.

The Bar for Joining the Team

Lightning is being used to solve really hard problems at the top AI labs in the world. As such, the bar for adding team members is extremely high. Candidates must have solid engineering skills, have a good eye for user experience, and must be a power user of Lightning and PyTorch.

With that said, the Lightning team will be diverse and a reflection of an inclusive AI community. You don’t have to be an engineer to contribute! Scientists with great usability intuition and PyTorch ninja skills are welcomed!

Responsibilities:

Here, we describe general expectations from core contributors:

GitHub Issues
  • Our community is the main motivation for our work. Help them have an amazing experience. Issues range from answering questions from new people getting into deep learning to helping researchers doing something esoteric. They often require some sort of bug fix, document clarification, or new functionality to be scoped out. You can help them solve their issues and guide them to completion.

  • Weigh in on discussions in a timely fashion, most importantly on the RFCs (requests for comments) that will shape the future of Lightning. There are some big decisions which the project must make. For these, we expect core contributors to have something meaningful to add, especially if it's their area of expertise.

  • Propose your own RFCs that align with the API design goals for Lightning.

  • Identify opportunities from an issue or bug that can solve other related issues or make the framework more flexible.

  • Don’t make users feel like they don’t know what they’re doing. We’re here to help and to make everyone’s experience delightful.

  • Help out with critical bugs. Nobody likes bugs so you’ll be a hero if you fix them!

Pull Requests (PRs)
  • Pull requests are the evolutionary mechanism of Lightning, so quality is extremely important. Make sure contributors adhere to the guidelines described in the contributing section.

  • Some PRs are from people who want to get involved and try to add something unnecessary. We do want their help though! So don't approve the PR, but direct them to a GitHub issue that they might be interested in helping with instead!

  • Provide strong and valuable feedback during reviews. This is expected both when reviewing community PRs as well as PRs from other core contributors. Even if you are not part of core yet, you can still review and approve PRs. This will show us your abilities.

Diversity

Lightning should reflect the broader community it serves. As such we should have scientists/researchers from different fields contributing!

Community

We have an active Slack community, where questions are asked daily. This is a great way to show off your Lightning and PyTorch knowledge, and help out others. There’s also GitHub discussions.

Applying

There are no precise targets for becoming a core contributor. In the past, community members have become core after consistently meeting the expectations above. We are always on the lookout for new people to join; if you feel you already meet the expectations and we haven't reached out to you yet, feel free to ping us privately on Slack!

Employment

You can also become a Lightning AI employee or intern and work on Lightning. To get started, you can email careers@lightning.ai with your resume or check out our open job postings.

Lightning Governance

This document describes governance processes we follow in developing PyTorch Lightning.

Persons of Interest

BDFL

Role: All final decisions related to Lightning.

Maintainers
Emeritus Maintainers
Alumni

Project Management and Decision Making

The decision on what goes into a release is governed by the maintainers of lightning.pytorch. Whenever possible, discussion happens publicly on GitHub and includes the whole community. For controversial changes, it is mandatory to consult the BDFL for a final decision. When a consensus is reached, maintainers assign milestones and labels to the issue and/or pull request and start tracking the development. Priorities may change over time.

Commits must be added to the project exclusively through pull requests on GitHub, and anyone in the community is welcome to review them. However, reviews submitted by code owners have higher weight, and a pull request cannot be merged without the approval of code owners. Additional requirements may apply case by case.

Versioning Policy

PyTorch Lightning follows its own versioning policy, which differs from semantic versioning (SemVer).

Versioning

A Lightning release number is in the format of MAJOR.MINOR.PATCH.

  • A patch release contains only bug fixes. Since it introduces no breaking changes, we recommend users always update the package to the latest version within the minor version whenever possible.

  • A minor release, unlike SemVer, contains backwards-incompatible changes, such as API changes and removals, as well as new features, deprecations and all bugfixes since the last release.

With every release, we publish a changelog where we list additions, removals, deprecations, changed functionality and fixes.
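To make the patch-versus-minor distinction concrete, here is a minimal sketch of a check a user might run before upgrading. The helper names parse_version and is_patch_upgrade are hypothetical, and the sketch assumes plain numeric MAJOR.MINOR.PATCH strings with no rc/dev suffixes (for real tooling, packaging.version.Version handles those).

```python
def parse_version(version):
    """Split a plain 'MAJOR.MINOR.PATCH' string into integers (no rc/dev suffixes)."""
    major, minor, patch = (int(part) for part in version.split("."))
    return major, minor, patch


def is_patch_upgrade(current, target):
    """True if `target` stays on the same MAJOR.MINOR line and is not older than `current`.

    Under the policy above, such an upgrade contains only bug fixes; anything that
    changes MAJOR or MINOR may include backwards-incompatible changes.
    """
    cur = parse_version(current)
    new = parse_version(target)
    return cur[:2] == new[:2] and new[2] >= cur[2]
```

For example, moving from 1.9.0 to 1.9.1 is a patch upgrade, while 1.8.6 to 1.9.0 crosses a minor version and may require reading the changelog for removals and API changes.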

API Stability

In Lightning, all API and features are marked as either stable or experimental.

Experimental API

Experimental APIs are labelled as experimental or beta in the documentation and/or in the release notes. They are considered unstable and should not be used in production.

For experimental features, any of the following may be true:

  • The feature has unstable dependencies.

  • The API may change without notice in future versions.

  • The performance of the feature has not been verified.

  • The docs for this feature are under active development.

Stable API

Everything not specifically labelled as experimental is stable.

For stable APIs, all of the following are true:

  • The API is not expected to change.

  • If anything does change, we show a deprecation warning before applying the breaking change following the rule described below.

API Evolution

Lightning’s development is driven by research and best practices in a rapidly developing field of AI and machine learning. Change is inevitable and when it happens, the Lightning team is committed to minimizing user friction and maximizing ease of transition from one version to the next. We take backwards compatibility and reproducibility very seriously.

For API removal, renaming or other forms of backwards-incompatible changes, the procedure is:

  1. A deprecation process is initiated at a minor version X, producing a deprecation warning at runtime and in the documentation.

  2. The deprecated API remains unchanged during the deprecation phase for two minor versions.

  3. The breaking change takes effect at a minor version X+2.

  4. From version X+2 onward, the deprecation warning gets converted into a helpful error, which will remain as long as possible.

The X+2 rule is a recommendation and not a strict requirement. Shorter or longer deprecation cycles may apply in some cases. In the past, DDP2 was removed without a deprecation process because the feature was broken and unusable beyond fixing, as discussed in #12584. Conversely, #10410 is an example of a longer deprecation: we deprecated the accelerator arguments, such as Trainer(gpus=...), in 1.7, but because these APIs were so core that they would impact almost all use cases, we decided not to introduce the breaking change until 2.0.
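The default procedure can be sketched as a decorator that warns during versions X and X+1 and raises a helpful error from X+2 onward. This is an illustration of the behavior a user sees, not Lightning's actual deprecation utility: the names deprecated, removal_version and CURRENT_VERSION are hypothetical, and removal_version only models the default rule within one major series (it cannot express a case like 1.7 to 2.0).

```python
import functools
import warnings

# Hypothetical installed version (MAJOR, MINOR), used only for this illustration.
CURRENT_VERSION = (1, 9)


def removal_version(deprecated_in):
    """Apply the default X+2 rule: deprecated at minor X, removed at minor X+2."""
    major, minor = deprecated_in
    return (major, minor + 2)


def deprecated(deprecated_in, replacement, current=CURRENT_VERSION):
    """Hypothetical decorator: warn during the deprecation phase, raise afterwards."""
    removed_in = removal_version(deprecated_in)
    since = "{}.{}".format(*deprecated_in)
    until = "{}.{}".format(*removed_in)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if current >= removed_in:
                # From X+2 onward, the old API only raises a helpful error.
                raise RuntimeError(
                    f"`{func.__name__}` was removed in v{until}. Use `{replacement}` instead."
                )
            # During versions X and X+1, the old API still works but warns at runtime.
            warnings.warn(
                f"`{func.__name__}` was deprecated in v{since} and will be removed in v{until}."
                f" Use `{replacement}` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@deprecated((1, 8), replacement="new_scale")
def old_scale(x):
    return x * 2


@deprecated((1, 7), replacement="new_scale")
def older_scale(x):
    return x * 2
```

With CURRENT_VERSION set to (1, 9), old_scale(3) still returns 6 but emits a DeprecationWarning (removal is due in 1.10), whereas older_scale(3) raises RuntimeError because its X+2 version, 1.9, has been reached.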

Python Support

PyTorch Lightning follows NEP 29, which PyTorch also follows (#74203).
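NEP 29 recommends supporting all Python minor versions released in the 42 months before a given date, and at minimum the two latest minor versions. The sketch below applies that rule to a mapping of release dates that the caller supplies; no dates are bundled. The helper names months_before and nep29_supported are hypothetical, the NEP document itself remains the authoritative definition, and this is not how Lightning computes its supported versions.

```python
import calendar
from datetime import date


def months_before(when, months):
    """Shift `when` back by whole calendar months, clamping to the month's last day."""
    total = when.year * 12 + (when.month - 1) - months
    year, month_index = divmod(total, 12)
    month = month_index + 1
    last_day = calendar.monthrange(year, month)[1]
    return date(year, month, min(when.day, last_day))


def nep29_supported(release_dates, today, window_months=42, minimum=2):
    """Versions released within `window_months` of `today`, but never fewer than `minimum`.

    `release_dates` maps a version string (e.g. "3.10") to its release `date`.
    The result is ordered from oldest to newest release.
    """
    cutoff = months_before(today, window_months)
    newest_first = sorted(release_dates, key=release_dates.get, reverse=True)
    recent = [v for v in newest_first if release_dates[v] >= cutoff]
    # Keep everything inside the window, or the `minimum` newest versions if that is more.
    keep = max(len(recent), minimum)
    return sorted(newest_first[:keep], key=release_dates.get)
```

Because the release dates are passed in, the same function can be rerun as new Python versions appear without changing its logic.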

PyTorch Support

PyTorch Lightning supports the latest four minor versions of PyTorch at the time of release. For example, PyTorch Lightning 1.8 supports PyTorch 1.10, 1.11, 1.12 and 1.13.
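As a worked illustration of the four-minor-version window, this sketch lists the supported PyTorch minors given the newest one available at release time. The helper name supported_torch_minors is hypothetical, and it assumes the whole window lies within one major version, so it refuses windows that would cross a major boundary (such as one ending at 2.0) instead of guessing which 1.x releases to include.

```python
def supported_torch_minors(latest, count=4):
    """Return the `count` newest PyTorch minor versions up to `latest`, oldest first.

    Assumes every version in the window shares the major version of `latest`.
    """
    major, minor = (int(part) for part in latest.split("."))
    if minor - count + 1 < 0:
        raise ValueError("window crosses a major-version boundary; list versions explicitly")
    return [f"{major}.{m}" for m in range(minor - count + 1, minor + 1)]
```

Calling supported_torch_minors("1.13") reproduces the example above: ["1.10", "1.11", "1.12", "1.13"].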

Changelog

All notable changes to this project will be documented in this file.

The format is based on Keep a Changelog.

[1.9.1] - 2023-02-10

[1.9.1] - Fixed
  • Fixed an unintended limitation for calling save_hyperparameters on mixin classes that don’t subclass LightningModule/LightningDataModule (#16369)

  • Fixed an issue with MLFlowLogger logging the wrong keys with .log_hyperparams() (#16418)

  • Fixed logging more than 100 parameters with MLFlowLogger; long values are now truncated (#16451)

  • Fixed strict availability check for torch_xla requirement (#16476)

  • Fixed an issue where PL would wrap DataLoaders with XLA’s MpDeviceLoader more than once (#16571)

  • Fixed the batch_sampler reference for DataLoaders wrapped with XLA’s MpDeviceLoader (#16571)

  • Fixed an import error when torch.distributed is not available (#16658)

[1.9.0] - 2023-01-17

[1.9.0] - Added
  • Added support for native logging of MetricCollection with enabled compute groups (#15580)

  • Added support for custom artifact names in pl.loggers.WandbLogger (#16173)

  • Added support for DDP with LRFinder (#15304)

  • Added utilities to migrate checkpoints from one Lightning version to another (#15237)

  • Added support to upgrade all checkpoints in a folder using the pl.utilities.upgrade_checkpoint script (#15333)

  • Added an axes argument ax to .lr_find().plot() to enable writing to user-defined axes in a matplotlib figure (#15652)

  • Added log_model parameter to MLFlowLogger (#9187)

  • Added a check to validate that wrapped FSDP models are used while initializing optimizers (#15301)

  • Added a warning when self.log(..., logger=True) is called without a configured logger (#15814)

  • Added support for colossalai 0.1.11 (#15888)

  • Added LightningCLI support for optimizer and learning schedulers via callable type dependency injection (#15869)

  • Added support for activation checkpointing for the DDPFullyShardedNativeStrategy strategy (#15826)

  • Added the option to set DDPFullyShardedNativeStrategy(cpu_offload=True|False) via bool instead of needing to pass a configuration object (#15832)

  • Added info message for Ampere CUDA GPU users to enable tf32 matmul precision (#16037)

  • Added support for returning optimizer-like classes in LightningModule.configure_optimizers (#16189)

[1.9.0] - Changed
  • Dropped PyTorch 1.9 support (#15347)

  • Switched from tensorboard to tensorboardx in TensorBoardLogger (#15728)

  • From now on, Lightning Trainer and LightningModule.load_from_checkpoint automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning (#15237)

  • Trainer.{validate,test,predict}(ckpt_path=...) no longer restores the Trainer.global_step and trainer.current_epoch values from the checkpoints. From now on, only Trainer.fit will restore these values (#15532)

  • The ModelCheckpoint.save_on_train_epoch_end attribute is now computed dynamically every epoch, accounting for changes to the validation dataloaders (#15300)

  • The Trainer now raises an error if it is given multiple stateful callbacks of the same type with colliding state keys (#15634)

  • MLFlowLogger now logs hyperparameters and metrics in batched API calls (#15915)

  • Overriding the on_train_batch_{start,end} hooks in conjunction with taking a dataloader_iter in the training_step no longer errors out and instead shows a warning (#16062)

  • Moved tensorboardX to extra dependencies; the CSVLogger is now used by default (#16349)

[1.9.0] - Deprecated
  • Deprecated description, env_prefix and env_parse parameters in LightningCLI.__init__ in favour of giving them through parser_kwargs (#15651)

  • Deprecated pytorch_lightning.profiler in favor of pytorch_lightning.profilers (#16059)

  • Deprecated Trainer(auto_select_gpus=...) in favor of pytorch_lightning.accelerators.find_usable_cuda_devices (#16147)

  • Deprecated pytorch_lightning.tuner.auto_gpu_select.{pick_single_gpu,pick_multiple_gpus} in favor of pytorch_lightning.accelerators.find_usable_cuda_devices (#16147)

  • nvidia/apex deprecation (#16039)

    • Deprecated pytorch_lightning.plugins.NativeMixedPrecisionPlugin in favor of pytorch_lightning.plugins.MixedPrecisionPlugin

    • Deprecated the LightningModule.optimizer_step(using_native_amp=...) argument

    • Deprecated the Trainer(amp_backend=...) argument

    • Deprecated the Trainer.amp_backend property

    • Deprecated the Trainer(amp_level=...) argument

    • Deprecated the pytorch_lightning.plugins.ApexMixedPrecisionPlugin class

    • Deprecated the pytorch_lightning.utilities.enums.AMPType enum

    • Deprecated the DeepSpeedPrecisionPlugin(amp_type=..., amp_level=...) arguments

  • horovod deprecation (#16141)

    • Deprecated Trainer(strategy="horovod")

    • Deprecated the HorovodStrategy class

  • Deprecated pytorch_lightning.lite.LightningLite in favor of lightning.fabric.Fabric (#16314)

  • FairScale deprecation (in favor of PyTorch’s FSDP implementation) (#16353)

    • Deprecated the pytorch_lightning.overrides.fairscale.LightningShardedDataParallel class

    • Deprecated the pytorch_lightning.plugins.precision.fully_sharded_native_amp.FullyShardedNativeMixedPrecisionPlugin class

    • Deprecated the pytorch_lightning.plugins.precision.sharded_native_amp.ShardedNativeMixedPrecisionPlugin class

    • Deprecated the pytorch_lightning.strategies.fully_sharded.DDPFullyShardedStrategy class

    • Deprecated the pytorch_lightning.strategies.sharded.DDPShardedStrategy class

    • Deprecated the pytorch_lightning.strategies.sharded_spawn.DDPSpawnShardedStrategy class

[1.9.0] - Removed
  • Removed deprecated pytorch_lightning.utilities.memory.get_gpu_memory_map in favor of pytorch_lightning.accelerators.cuda.get_nvidia_gpu_stats (#15617)

  • Temporarily removed support for Hydra multi-run (#15737)

  • Removed deprecated pytorch_lightning.profiler.base.AbstractProfiler in favor of pytorch_lightning.profilers.profiler.Profiler (#15637)

  • Removed deprecated pytorch_lightning.profiler.base.BaseProfiler in favor of pytorch_lightning.profilers.profiler.Profiler (#15637)

  • Removed deprecated code in pytorch_lightning.utilities.meta (#16038)

  • Removed the deprecated LightningDeepSpeedModule (#16041)

  • Removed the deprecated pytorch_lightning.accelerators.GPUAccelerator in favor of pytorch_lightning.accelerators.CUDAAccelerator (#16050)

  • Removed the deprecated pytorch_lightning.profiler.* classes in favor of pytorch_lightning.profilers (#16059)

  • Removed the deprecated pytorch_lightning.utilities.cli module in favor of pytorch_lightning.cli (#16116)

  • Removed the deprecated pytorch_lightning.loggers.base module in favor of pytorch_lightning.loggers.logger (#16120)

  • Removed the deprecated pytorch_lightning.loops.base module in favor of pytorch_lightning.loops.loop (#16142)

  • Removed the deprecated pytorch_lightning.core.lightning module in favor of pytorch_lightning.core.module (#16318)

  • Removed the deprecated pytorch_lightning.callbacks.base module in favor of pytorch_lightning.callbacks.callback (#16319)

  • Removed the deprecated Trainer.reset_train_val_dataloaders() in favor of Trainer.reset_{train,val}_dataloader (#16131)

  • Removed support for LightningCLI(seed_everything_default=None) (#16131)

  • Removed support in LightningLite for FairScale’s sharded training (strategy='ddp_sharded'|'ddp_sharded_spawn'). Use Fully-Sharded Data Parallel instead (strategy='fsdp') (#16329)

[1.9.0] - Fixed
  • Enhanced reduce_boolean_decision to accommodate any-analogous semantics expected by the EarlyStopping callback (#15253)

  • Fixed the incorrect optimizer step synchronization when running across multiple TPU devices (#16020)

  • Fixed a type error when dividing the chunk size in the ColossalAI strategy (#16212)

  • Fixed a bug where the interval key of the scheduler would be ignored during manual optimization, making the LearningRateMonitor callback fail to log the learning rate (#16308)

  • Fixed an issue with MLFlowLogger not finalizing correctly when status code ‘finished’ was passed (#16340)

[1.8.6] - 2022-12-21

  • Minor cleanup

[1.8.5] - 2022-12-15

  • Added a function to remove a checkpoint, allowing extended classes to override it (#16067)

[1.8.4] - 2022-12-08

[1.8.4] - Changed
[1.8.4] - Fixed
  • Fixed issue with unsupported torch.inference_mode() on hpu backends (#15918)

  • Fixed LRScheduler import for PyTorch 2.0 (#15940)

  • Fixed fit_loop.restarting to be False for lr finder (#15620)

  • Fixed torch.jit.script-ing a LightningModule causing an unintended error message about deprecated use_amp property (#15947)

  • Fixed the XLAProfiler not recording anything due to mismatching of action names (#15885)

[1.8.3] - 2022-11-22

[1.8.3] - Changed
  • Temporarily removed support for Hydra multi-run (#15737)

  • Switched from tensorboard to tensorboardx in TensorBoardLogger (#15728)

[1.8.2] - 2022-11-17

[1.8.2] - Fixed
  • Made sure save_dir can be an empty string (#15638)

  • Fixed the automatic fallback from Trainer(strategy="ddp_spawn", ...) to Trainer(strategy="ddp", ...) when on an LSF cluster (#15103)

[1.8.1] - 2022-11-10

[1.8.1] - Added
  • Added back the accidentally removed pytorch_lightning.utilities.distributed.rank_zero_only function (#15536)

[1.8.1] - Deprecated
  • Deprecated pytorch_lightning.utilities.distributed.rank_zero_only in favor of pytorch_lightning.utilities.rank_zero_only (#15536)

[1.8.1] - Fixed
  • Fixed TensorBoardLogger not validating the input array type when logging the model graph (#15323)

  • Fixed an attribute error in ColossalAIStrategy at import time when torch.distributed is not available (#15535)

  • Fixed an issue when calling fs.listdir with file URI instead of path in CheckpointConnector (#15413)

  • Fixed an issue with the BaseFinetuning callback not setting the track_running_stats attribute for batch normalization layers (#15063)

  • Fixed an issue with WandbLogger(log_model=True|'all') raising an error and not being able to serialize tensors in the metadata (#15544)

  • Fixed the gradient unscaling logic when using Trainer(precision=16) and fused optimizers such as Adam(..., fused=True) (#15544)

  • Fixed model state transfer in multiprocessing launcher when running multi-node (#15567)

  • Fixed manual optimization raising AttributeError with Bagua Strategy (#12534)

  • Fixed the import of pytorch_lightning causing a warning ‘Redirects are currently not supported in Windows or MacOs’ (#15610)

[1.8.0] - 2022-11-01

[1.8.0] - Added
  • Added support for requeueing SLURM array jobs (#15040)

  • Added native AMP support for ddp_fork (and associated alias strategies) with CUDA GPUs (#14983)

  • Added BatchSizeFinder callback (#11089)

  • Added LearningRateFinder callback (#13802)

  • Tuner now supports a new method argument which will determine when to run the BatchSizeFinder: one of fit, validate, test or predict (#11089)

  • Added prefix to log message in seed_everything with rank info (#14031)

  • Added support for auto wrapping for DDPFullyShardedNativeStrategy (#14252)

  • Added support for passing extra init-parameters to the LightningDataModule.from_datasets (#14185)

  • Added support for saving sharded optimizer state dict outside of DDPShardedStrategy (#14208)

  • Added support for auto wrapping for DDPFullyShardedStrategy (#14383)

  • Integrated the lightning_utilities package (#14475, #14537, #14556, #14558, #14575, #14620)

  • Added args parameter to LightningCLI to ease running from within Python (#14596)

  • Added WandbLogger.download_artifact and WandbLogger.use_artifact for managing artifacts with Weights and Biases (#14551)

  • Added an option to configure the signal SLURM sends when a job is preempted or requeued (#14626)

  • Added a warning when the model passed to LightningLite.setup() does not have all parameters on the same device (#14822)

  • The CometLogger now flags the Comet Experiments as being created from Lightning for analytics purposes (#14906)

  • Introduced the ckpt_path="hpc" keyword for checkpoint loading (#14911)

  • Added a more descriptive error message when attempting to fork processes with pre-initialized CUDA context (#14709)

  • Added support for custom parameters in subclasses of SaveConfigCallback (#14998)

  • Added inference_mode flag to Trainer to let users enable/disable inference mode during evaluation (#15034)

  • Added LightningLite.no_backward_sync for control over efficient gradient accumulation with distributed strategies (#14966)

  • Added a sanity check that scripts are executed with the srun command in SLURM and that environment variables are not conflicting (#15011)

  • Added an error message when attempting to launch processes with python -i and an interactive-incompatible strategy (#15293)

[1.8.0] - Changed
  • The Trainer.{fit,validate,test,predict,tune} methods now raise a useful error message if the input is not a LightningModule (#13892)

  • Raised a MisconfigurationException if batch transfer hooks are overridden with IPUAccelerator (#13961)

  • Replaced the unwrapping logic in strategies with direct access to unwrapped LightningModule (#13738)

  • Enabled on_before_batch_transfer for DPStrategy and IPUAccelerator (#14023)

  • When resuming training with Apex enabled, the Trainer will now raise an error (#14341)

  • Included the torch.cuda RNG state in the aggregate _collect_rng_states() and _set_rng_states() (#14384)

  • Changed trainer.should_stop to not stop in between an epoch and run until min_steps/min_epochs only (#13890)

  • The pyDeprecate dependency is no longer installed (#14472)

  • When using multiple loggers, by default checkpoints and profiler output now get saved to the log dir of the first logger in the list (#14325)

  • In Lightning Lite, state-dict access to the module wrapper now gets passed through to the original module reference (#14629)

  • Removed fall-back to LightningEnvironment when number of SLURM tasks does not correspond to number of processes in Trainer (#14300)

  • Aligned DDP and DDPSpawn strategies in setting up the environment (#11073)

  • Integrated the Lite Precision plugins into the PL Precision plugins - the base class in PL now extends the lightning_lite.precision.Precision base class (#14798)

    • The PrecisionPlugin.backward signature changed: The closure_loss argument was renamed to tensor

    • The PrecisionPlugin.{pre_,post_}backward signature changed: The closure_loss argument was renamed to tensor and moved as the first argument

    • The PrecisionPlugin.optimizer_step signature changed: The model, optimizer_idx and closure arguments need to be passed as keyword arguments now

  • Trainer queries the CUDA devices through NVML if available to avoid initializing CUDA before forking, which eliminates the need for the PL_DISABLE_FORK environment variable introduced in v1.7.4 (#14631)

  • The MLFlowLogger.finalize() now sets the status to FAILED when an exception occurred in Trainer, and sets the status to FINISHED on successful completion (#12292)

  • It is no longer needed to call model.double() when using precision=64 in Lightning Lite (#14827)

  • HPC checkpoints are now loaded automatically only in a SLURM environment when no specific value for ckpt_path has been set (#14911)

  • The Callback.on_load_checkpoint now gets the full checkpoint dictionary and the callback_state argument was renamed to checkpoint (#14835)

  • Moved the warning about saving nn.Module in save_hyperparameters() to before the deepcopy (#15132)

  • To avoid issues with forking processes, from PyTorch 1.13 and higher, Lightning will directly use the PyTorch NVML-based check for torch.cuda.device_count and from PyTorch 2.0 and higher, Lightning will configure PyTorch to use an NVML-based check for torch.cuda.is_available. (#15110, #15133)

  • The NeptuneLogger now uses neptune.init_run instead of the deprecated neptune.init to initialize a run (#15393)

[1.8.0] - Deprecated
  • Deprecated LightningDeepSpeedModule (#14000)

  • Deprecated amp_level from Trainer in favour of passing it explicitly via precision plugin (#13898)

  • Deprecated the calls to pytorch_lightning.utilities.meta functions in favor of built-in https://github.com/pytorch/torchdistx support (#13868)

  • Deprecated the unwrap_lightning_module and unwrap_lightning_module_sharded utility functions in favor of accessing the unwrapped LightningModule on the strategy directly (#13738)

  • Deprecated the pl_module argument in LightningParallelModule, LightningDistributedModule, LightningShardedDataParallel, LightningBaguaModule and LightningDeepSpeedModule wrapper classes (#13738)

  • Deprecated the on_colab_kaggle function (#14247)

  • Deprecated the internal pl.core.mixins.DeviceDtypeModuleMixin class (#14511, #14548)

  • Deprecated all functions in pytorch_lightning.utilities.xla_device (#14514, #14550)

    • Deprecated the internal inner_f function

    • Deprecated the internal pl_multi_process function

    • Deprecated the internal XLADeviceUtils.xla_available staticmethod

    • Deprecated the XLADeviceUtils.tpu_device_exists staticmethod in favor of pytorch_lightning.accelerators.TPUAccelerator.is_available()

  • Deprecated pytorch_lightning.utilities.distributed.tpu_distributed in favor of lightning_lite.accelerators.tpu.tpu_distributed (#14550)

  • Deprecated all functions in pytorch_lightning.utilities.cloud_io in favor of lightning_lite.utilities.cloud_io (#14515)

  • Deprecated the functions in pytorch_lightning.utilities.apply_func in favor of lightning_utilities.core.apply_func (#14516, #14537)

  • Deprecated all functions in pytorch_lightning.utilities.device_parser (#14492, #14753)

    • Deprecated the pytorch_lightning.utilities.device_parser.determine_root_gpu_device in favor of lightning_lite.utilities.device_parser.determine_root_gpu_device

    • Deprecated the pytorch_lightning.utilities.device_parser.parse_gpu_ids in favor of lightning_lite.utilities.device_parser.parse_gpu_ids

    • Deprecated the pytorch_lightning.utilities.device_parser.is_cuda_available in favor of lightning_lite.accelerators.cuda.is_cuda_available

    • Deprecated the pytorch_lightning.utilities.device_parser.num_cuda_devices in favor of lightning_lite.accelerators.cuda.num_cuda_devices

    • Deprecated the pytorch_lightning.utilities.device_parser.parse_cpu_cores in favor of lightning_lite.accelerators.cpu.parse_cpu_cores

    • Deprecated the pytorch_lightning.utilities.device_parser.parse_tpu_cores in favor of lightning_lite.accelerators.tpu.parse_tpu_cores

    • Deprecated the pytorch_lightning.utilities.device_parser.parse_hpus in favor of pytorch_lightning.accelerators.hpu.parse_hpus

  • Deprecated duplicate SaveConfigCallback parameters in LightningCLI.__init__: save_config_kwargs, save_config_overwrite and save_config_multifile. New save_config_kwargs parameter should be used instead (#14998)

  • Deprecated TrainerFn.TUNING, RunningStage.TUNING and trainer.tuning property (#15100)

  • Deprecated custom pl.utilities.distributed.AllGatherGrad implementation in favor of PyTorch’s (#15364)

[1.8.0] - Removed
  • Removed the deprecated Trainer.training_type_plugin property in favor of Trainer.strategy (#14011)

  • Removed all deprecated training type plugins (#14011)

  • Removed the deprecated DDP2Strategy (#14026)

  • Removed the deprecated DistributedType and DeviceType enum classes (#14045)

  • Removed deprecated support for passing the rank_zero_warn warning category positionally (#14470)

  • Removed the legacy and unused Trainer.get_deprecated_arg_names() (#14415)

  • Removed the deprecated on_train_batch_end(outputs) format when multiple optimizers are used and TBPTT is enabled (#14373)

  • Removed the deprecated training_epoch_end(outputs) format when multiple optimizers are used and TBPTT is enabled (#14373)

  • Removed the experimental pytorch_lightning.utilities.meta functions in favor of built-in https://github.com/pytorch/torchdistx support (#13868)

  • Removed the deprecated LoggerCollection; Trainer.logger and LightningModule.logger now return the first logger when more than one gets passed to the Trainer (#14283)

  • Removed the deprecated trainer.lr_schedulers (#14408)

  • Removed the deprecated LightningModule.{on_hpc_load,on_hpc_save} hooks in favor of the general purpose hooks LightningModule.{on_load_checkpoint,on_save_checkpoint} (#14315)

  • Removed deprecated support for old torchtext versions (#14375)

  • Removed deprecated support for the old neptune-client API in the NeptuneLogger (#14727)

  • Removed the deprecated weights_save_path Trainer argument and Trainer.weights_save_path property (#14424)

  • Removed the following deprecated import paths (#14471):

    • pytorch_lightning.utilities.distributed.rank_zero_only in favor of pytorch_lightning.utilities.rank_zero.rank_zero_only

    • pytorch_lightning.utilities.distributed.rank_zero_debug in favor of pytorch_lightning.utilities.rank_zero.rank_zero_debug

    • pytorch_lightning.utilities.distributed.rank_zero_info in favor of pytorch_lightning.utilities.rank_zero.rank_zero_info

    • pytorch_lightning.utilities.warnings.rank_zero_warn in favor of pytorch_lightning.utilities.rank_zero.rank_zero_warn

    • pytorch_lightning.utilities.warnings.rank_zero_deprecation in favor of pytorch_lightning.utilities.rank_zero.rank_zero_deprecation

    • pytorch_lightning.utilities.warnings.LightningDeprecationWarning in favor of pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning

  • Removed deprecated Trainer.num_processes attribute in favour of Trainer.num_devices (#14423)

  • Removed the deprecated Trainer.data_parallel_device_ids hook in favour of Trainer.device_ids (#14422)

  • Removed the deprecated class TrainerCallbackHookMixin (#14401)

  • Removed the deprecated BaseProfiler and AbstractProfiler classes (#14404)

  • Removed the deprecated way to set the distributed backend via the environment variable PL_TORCH_DISTRIBUTED_BACKEND, in favor of setting the process_group_backend in the strategy constructor (#14693)

  • Removed deprecated callback hooks (#14834)

    • Callback.on_configure_sharded_model in favor of Callback.setup

    • Callback.on_before_accelerator_backend_setup in favor of Callback.setup

    • Callback.on_batch_start in favor of Callback.on_train_batch_start

    • Callback.on_batch_end in favor of Callback.on_train_batch_end

    • Callback.on_epoch_start in favor of Callback.on_{train,validation,test}_epoch_start

    • Callback.on_epoch_end in favor of Callback.on_{train,validation,test}_epoch_end

    • Callback.on_pretrain_routine_{start,end} in favor of Callback.on_fit_start

  • Removed the deprecated device attributes Trainer.{devices,gpus,num_gpus,ipus,tpu_cores} in favor of the accelerator-agnostic Trainer.num_devices (#14829)

  • Removed the deprecated LightningIPUModule (#14830)

  • Removed the deprecated Logger.agg_and_log_metrics hook in favour of Logger.log_metrics and the agg_key_funcs and agg_default_func arguments. (#14840)

  • Removed the deprecated precision plugin checkpoint hooks PrecisionPlugin.on_load_checkpoint and PrecisionPlugin.on_save_checkpoint (#14833)

  • Removed the deprecated Trainer.root_gpu attribute in favor of Trainer.strategy.root_device (#14829)

  • Removed the deprecated Trainer.use_amp and LightningModule.use_amp attributes (#14832)

  • Removed the deprecated callback hooks Callback.on_init_start and Callback.on_init_end (#14867)

  • Removed the deprecated Trainer.run_stage in favor of Trainer.{fit,validate,test,predict} (#14870)

  • Removed the deprecated SimpleProfiler.profile_iterable and AdvancedProfiler.profile_iterable attributes (#14864)

  • Removed the deprecated Trainer.verbose_evaluate (#14884)

  • Removed the deprecated Trainer.should_rank_save_checkpoint (#14885)

  • Removed the deprecated TrainerOptimizersMixin (#14887)

  • Removed the deprecated Trainer.lightning_optimizers (#14889)

  • Removed the deprecated TrainerDataLoadingMixin (#14888)

  • Removed the deprecated Trainer.call_hook in favor of Trainer._call_callback_hooks, Trainer._call_lightning_module_hook, Trainer._call_ttp_hook, and Trainer._call_accelerator_hook (#14869)

  • Removed the deprecated Trainer.{validated,tested,predicted}_ckpt_path (#14897)

  • Removed the deprecated device_stats_monitor_prefix_metric_keys (#14890)

  • Removed the deprecated LightningDataModule.on_{save,load}_checkpoint hooks (#14909)

  • Removed support for returning a value in Callback.on_save_checkpoint in favor of implementing Callback.state_dict (#14835)
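
Before upgrading to 1.8.0, it can help to find custom callbacks that still define one of the removed hooks listed above. The following is a minimal sketch, not part of Lightning: REMOVED_CALLBACK_HOOKS and find_removed_hooks are hypothetical names, and the replacements are copied from the entries above (on_init_start and on_init_end are listed as removed without a named replacement). It only inspects method names, so it runs without importing Lightning. If your class inherits from a pre-1.8 Lightning Callback, pass that base as stop_at so the base class's own no-op hooks are not reported.

```python
from typing import Dict, Optional

# Hypothetical lookup table built from the 1.8.0 "Removed" entries:
# removed Callback hook -> suggested replacement (None = no replacement listed).
REMOVED_CALLBACK_HOOKS: Dict[str, Optional[str]] = {
    "on_configure_sharded_model": "setup",
    "on_before_accelerator_backend_setup": "setup",
    "on_batch_start": "on_train_batch_start",
    "on_batch_end": "on_train_batch_end",
    "on_epoch_start": "on_{train,validation,test}_epoch_start",
    "on_epoch_end": "on_{train,validation,test}_epoch_end",
    "on_pretrain_routine_start": "on_fit_start",
    "on_pretrain_routine_end": "on_fit_start",
    "on_init_start": None,
    "on_init_end": None,
}


def find_removed_hooks(cls: type, stop_at: type = object) -> Dict[str, Optional[str]]:
    """Return {removed_hook: replacement} for every removed hook defined on
    `cls` or its base classes, walking the MRO until `stop_at` is reached."""
    found: Dict[str, Optional[str]] = {}
    for klass in cls.__mro__:
        if klass is stop_at:
            break  # don't report hooks defined on the framework base class itself
        for name in vars(klass):
            if name in REMOVED_CALLBACK_HOOKS:
                found[name] = REMOVED_CALLBACK_HOOKS[name]
    return found


class LegacyCallback:
    # Written against a pre-1.8 API: the first two hooks no longer exist.
    def on_epoch_end(self, trainer, pl_module):
        pass

    def on_pretrain_routine_start(self, trainer, pl_module):
        pass

    def setup(self, trainer, pl_module, stage):
        pass


for hook, replacement in find_removed_hooks(LegacyCallback).items():
    print(f"{hook} -> {replacement or '(no replacement listed)'}")
```

The check is name-based on purpose: a removed hook is simply never called after the upgrade, so it fails silently rather than raising, and a static scan catches it earlier.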

[1.8.0] - Fixed
  • Fixed an issue with LightningLite.setup() not setting the .device attribute correctly on the returned wrapper (#14822)

  • Fixed an attribute error when running the tuner together with the StochasticWeightAveraging callback (#14836)

  • Fixed MissingFieldException in offline mode for the NeptuneLogger() (#14919)

  • Fixed the WandbLogger save_dir being overridden by a None dir when using the CLI (#14878)

  • Fixed a missing call to LightningDataModule.load_state_dict hook while restoring checkpoint using LightningDataModule.load_from_checkpoint (#14883)

  • Fixed torchscript error with containers of LightningModules (#14904)

  • Fixed reloading of the last checkpoint on run restart (#14907)

  • SaveConfigCallback instances now save the config only once, so the overwrite=False safeguard works when using LightningCLI(..., run=False) (#14927)

  • Fixed an issue with terminating the trainer profiler when a StopIteration exception is raised while using an IterableDataset (#14940)

  • Do not update on-plateau schedulers when reloading from an end-of-epoch checkpoint (#14702)

  • Fixed Trainer support for PyTorch built without distributed support (#14971)

  • Fixed batch normalization statistics calculation in StochasticWeightAveraging callback (#14866)

  • Avoided initializing optimizers during DeepSpeed inference (#14944)

  • Fixed LightningCLI parse_env and description in subcommands (#15138)

  • Fixed an exception that would occur when creating a multiprocessing.Pool after importing Lightning (#15292)

  • Fixed a pickling error when using RichProgressBar together with checkpointing (#15319)

  • Fixed the RichProgressBar crashing when used with distributed strategies (#15376)

  • Fixed an issue with RichProgressBar not resetting the internal state for the sanity check progress (#15377)

  • Fixed an issue with DataLoader re-instantiation when the attribute is an array and the default value of the corresponding argument changed (#15409)

[1.7.7] - 2022-09-22

[1.7.7] - Fixed
  • Fixed the availability check for the neptune-client package (#14714)

  • Break HPU Graphs into two parts (forward + backward as one and optimizer as another) for better performance (#14656)

  • Fixed torchscript error with ensembles of LightningModules (#14657, #14724)

  • Fixed an issue with TensorBoardLogger.finalize creating a new experiment when none was created during the Trainer’s execution (#14762)

  • Fixed TypeError on import when torch.distributed is not available (#14809)

[1.7.6] - 2022-09-13

[1.7.6] - Changed
  • Improved the error messaging when passing Trainer.method(model, x_dataloader=None) with no module-method implementations available (#14614)

[1.7.6] - Fixed
  • Reset the dataloaders on OOM failure in batch size finder to use the last successful batch size (#14372)

  • Fixed the batch size finder with mode="power" to keep downscaling the batch size when no successful batch size has been found yet (#14372)

  • Fixed an issue where logging a tensor with self.log would trigger a PyTorch user warning about cloning tensors (#14599)

  • Fixed compatibility when torch.distributed is not available (#14454)
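
The two batch size finder fixes above describe a "power" search: the batch size is doubled until a trial runs out of memory, the last successful size is kept, and if no size has succeeded yet the search keeps shrinking instead. The sketch below is a simplified, framework-free illustration of that behavior, not Lightning's implementation. power_search and its fits callback are hypothetical; fits stands in for running a trial step at a given batch size and returning False on an out-of-memory error.

```python
from typing import Callable, Optional


def power_search(fits: Callable[[int], bool], init_size: int = 2, max_trials: int = 25) -> int:
    """Simplified 'power' batch size search.

    `fits(size)` stands in for running a trial step at `size` and
    returning False when it runs out of memory.
    """
    if init_size < 1:
        raise ValueError("init_size must be at least 1")
    size = init_size
    last_good: Optional[int] = None  # last batch size that ran successfully
    for _ in range(max_trials):
        if fits(size):
            last_good = size
            size *= 2  # success: try twice as large
        elif last_good is not None:
            return last_good  # failure after a success: fall back to the last good size
        elif size == 1:
            raise RuntimeError("even a batch size of 1 does not fit")
        else:
            size //= 2  # failure with no success yet: keep downscaling
    if last_good is None:
        raise RuntimeError("no batch size succeeded within max_trials")
    return last_good


# Simulated device that can hold at most 40 samples per batch.
print(power_search(lambda b: b <= 40))                 # tries 2, 4, ..., 64; prints 32
print(power_search(lambda b: b <= 40, init_size=128))  # shrinks 128 -> 64 -> 32, then fails at 64; prints 32
```

Falling back to last_good rather than the size that failed mirrors the first fix above (reuse the last successful batch size after an OOM), and the halving branch mirrors the second.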

[1.7.5] - 2022-09-06

[1.7.5] - Fixed
  • Squeezed tensor values when logging with LightningModule.log (#14489)

  • Fixed WandbLogger.save_dir not being set after creation (#14326)

  • Fixed Trainer.estimated_stepping_batches when maximum number of epochs is not set (#14317)

[1.7.4] - 2022-08-31

[1.7.4] - Added
  • Added an environment variable PL_DISABLE_FORK that can be used to disable all forking in the Trainer (#14319)

[1.7.4] - Fixed
  • Fixed LightningDataModule hparams parsing (#12806)

  • Reset epoch progress with batch size scaler (#13846)

  • Fixed restoring the trainer after using lr_find() so that the correct LR schedule is used for the actual training (#14113)

  • Fixed incorrect values after transferring data to an MPS device (#14368)

[1.7.3] - 2022-08-25

[1.7.3] - Fixed
  • Fixed an assertion error when using a ReduceOnPlateau scheduler with the Horovod strategy (#14215)

  • Fixed an AttributeError when accessing LightningModule.logger and the Trainer has multiple loggers (#14234)

  • Added back support for logging in the configure_gradient_clipping hook after unintended removal in v1.7.2 (#14298)

  • Fixed incorrect number padding in RichProgressBar (#14296)

  • Fixed the sanity check interfering with reload_dataloaders_every_n_epochs for validation (#13964)

[1.7.2] - 2022-08-17

[1.7.2] - Added
  • Added FullyShardedNativeNativeMixedPrecisionPlugin to handle precision for DDPFullyShardedNativeStrategy (#14092)

  • Added profiling to these hooks: on_before_batch_transfer, transfer_batch_to_device, on_after_batch_transfer, configure_gradient_clipping, clip_gradients (#14069)

[1.7.2] - Changed
  • The WandbLogger.name property no longer returns the name of the experiment, and instead returns the project’s name (#14145)

  • The default project name in WandbLogger is now “lightning_logs” (#14145)

  • Updated compatibility for LightningLite to run with the latest DeepSpeed 0.7.0 (#13967)

[1.7.2] - Fixed
  • Fixed a bug that caused spurious AttributeError when multiple DataLoader classes are imported (#14117)

  • Fixed epoch-end logging results not being reset after the end of the epoch (#14061)

  • Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) (#9938)

  • Fixed the device placement when LightningModule.cuda() gets called without specifying a device index and the current cuda device was not 0 (#14128)

  • Avoided false positive warning about using sync_dist when using torchmetrics (#14143)

  • Avoided the metadata.entry_points deprecation warning on Python 3.10 (#14052)

  • Avoided raising the sampler warning if num_replicas=1 (#14097)

  • Fixed saving hyperparameters in a composition where the parent class is not a LightningModule or LightningDataModule (#14151)

  • Avoided requiring the FairScale package to use precision with the fsdp native strategy (#14092)

  • Fixed an issue in which the default name for a run in WandbLogger would be set to the project name instead of a randomly generated string (#14145)

  • Fixed attributes set on DataLoader and BatchSampler not being preserved when they are instantiated inside *_dataloader hooks (#14212)

[1.7.1] - 2022-08-09

[1.7.1] - Fixed
  • Cast only floating point tensors to fp16 with IPUs (#13983)

  • Cast tensors to fp16 before moving them to the device with DeepSpeedStrategy (#14000)

  • Fixed the NeptuneLogger dependency being unrecognized (#13988)

  • Fixed an issue where users would be warned about unset max_epochs even when fast_dev_run was set (#13262)

  • Fixed MPS device being unrecognized (#13992)

  • Fixed incorrect precision="mixed" being used with DeepSpeedStrategy and IPUStrategy (#14041)

  • Fixed dtype inference during gradient norm computation (#14051)

  • Fixed a bug that caused ddp_find_unused_parameters to be set to False, whereas the intended default is True (#14095)

[1.7.0] - 2022-08-02

[1.7.0] - Added
  • Added ServableModule and its associated callback, ServableModuleValidator, to ensure the model can be served (#13614)

  • Converted validation loop config warnings to PossibleUserWarning (#13377)

  • Added a flag named log_rank_zero_only to EarlyStopping to disable logging to non-zero rank processes (#13233)

  • Added support for reloading the last checkpoint saved by passing ckpt_path="last" (#12816)

  • Added LightningDataModule.load_from_checkpoint to support loading datamodules directly from checkpoint (#12550)

  • Added a friendly error message when attempting to call Trainer.save_checkpoint() without a model attached (#12772)

  • Added a friendly error message when attempting to use DeepSpeedStrategy on unsupported accelerators (#12699)

  • Enabled torch.inference_mode for evaluation and prediction (#12715)

  • Added support for setting val_check_interval to a value higher than the number of training batches when check_val_every_n_epoch=None (#11993)

  • Included the pytorch_lightning version as a header in the CLI config files (#12532)

  • Added support for Callback registration through entry points (#12739)

  • Added support for Trainer(deterministic="warn") to warn instead of fail when a non-deterministic operation is encountered (#12588)

  • Added profiling to the loops’ dataloader __next__ calls (#12124)

  • Hivemind Strategy

    • Added CollaborativeStrategy (#12842)

    • Renamed CollaborativeStrategy to HivemindStrategy (#13388)

    • Removed unnecessary endpoint logic, renamed collaborative to hivemind (#13392)

  • Included a version suffix for new “last” checkpoints of later runs in the same directory (#12902)

  • Show a better error message when a Metric that does not return a Tensor is logged (#13164)

  • Added missing predict_dataset argument in LightningDataModule.from_datasets to create predict dataloaders (#12942)

  • Added class name prefix to metrics logged by DeviceStatsMonitor (#12228)

  • Automatically wrapped custom samplers in a distributed environment using DistributedSamplerWrapper (#12959)

  • Added profiling of LightningDataModule hooks (#12971)

  • Added Native FSDP Strategy (#12447)

  • Added breaking of the lazy graph across training, validation, test and predict steps when training with Habana accelerators for better performance (#12938)

  • Added a Checkpoint base class to inherit from (#13024)

  • Added CPU metric tracking to DeviceStatsMonitor (#11795)

  • Added teardown() method to Accelerator (#11935)

  • Added support for using custom Trainers that don’t include callbacks using the CLI (#13138)

  • Added a timeout argument to DDPStrategy and DDPSpawnStrategy (#13244, #13383)

  • Added XLAEnvironment cluster environment plugin (#11330)

  • Added logging messages to notify when FitLoop stopping conditions are met (#9749)

  • Added support for calling unknown methods with DummyLogger (#13224)

  • Added support for recursively setting the Trainer reference for ensembles of LightningModules (#13638)

  • Added Apple Silicon Support via MPSAccelerator (#13123)

  • Added support for DDP Fork (#13405)

  • Added support for async checkpointing (#13658)

  • Added support for HPU Device stats monitor (#13819)

[1.7.0] - Changed
  • accelerator="gpu" now automatically selects an available GPU backend (CUDA and MPS currently) (#13642)

  • Enabled validation during overfitting (#12527)

  • Added dataclass support to extract_batch_size (#12573)

  • Changed checkpoints save path in the case of one logger and user-provided weights_save_path from weights_save_path/name/version/checkpoints to weights_save_path/checkpoints (#12372)

  • Changed checkpoints save path in the case of multiple loggers and user-provided weights_save_path from weights_save_path/name1_name2/version1_version2/checkpoints to weights_save_path/checkpoints (#12372)

  • Marked swa_lrs argument in StochasticWeightAveraging callback as required (#12556)

  • LightningCLI’s shorthand notation changed to use jsonargparse native feature (#12614)

  • LightningCLI changed to use jsonargparse native support for list append (#13129)

  • Changed seed_everything_default argument in the LightningCLI to type Union[bool, int]. If set to True a seed is automatically generated for the parser argument --seed_everything. (#12822, #13110)

  • Made positional arguments required for classes passed into the add_argparse_args function (#12504)

  • Raised an error if there are insufficient training batches when using a float value of limit_train_batches (#12885)

  • DataLoader instantiated inside a *_dataloader hook will not set the passed arguments as attributes anymore (#12981)

  • When a multi-element tensor is logged, an error is now raised instead of silently taking the mean of all elements (#13164)

  • The WandbLogger will now use the run name in the logs folder if it is provided, and otherwise the project name (#12604)

  • Enabled using any Sampler in distributed environment in Lite (#13646)

  • Raised a warning instead of forcing sync_dist=True on epoch end (#13364)

  • Updated val_check_interval(int) to consider total train batches processed instead of _batches_that_stepped for validation check during training (#12832)

  • Updated Habana Accelerator’s auto_device_count, is_available & get_device_name methods based on the latest torch habana package (#13423)

  • Disallowed using BatchSampler when running on multiple IPUs (#13854)

[1.7.0] - Deprecated
  • Deprecated pytorch_lightning.accelerators.gpu.GPUAccelerator in favor of pytorch_lightning.accelerators.cuda.CUDAAccelerator (#13636)

  • Deprecated pytorch_lightning.loggers.base.LightningLoggerBase in favor of pytorch_lightning.loggers.logger.Logger, and deprecated pytorch_lightning.loggers.base in favor of pytorch_lightning.loggers.logger (#120148)

  • Deprecated pytorch_lightning.callbacks.base.Callback in favor of pytorch_lightning.callbacks.callback.Callback (#13031)

  • Deprecated num_processes, gpus, tpu_cores, and ipus from the Trainer constructor in favor of using the accelerator and devices arguments (#11040)

  • Deprecated setting LightningCLI(seed_everything_default=None) in favor of False (#12804)

  • Deprecated pytorch_lightning.core.lightning.LightningModule in favor of pytorch_lightning.core.module.LightningModule (#12740)

  • Deprecated pytorch_lightning.loops.base.Loop in favor of pytorch_lightning.loops.loop.Loop (#13043)

  • Deprecated Trainer.reset_train_val_dataloaders() in favor of Trainer.reset_{train,val}_dataloader (#12184)

  • Deprecated LightningCLI’s registries in favor of importing the respective package (#13221)

  • Deprecated public utilities in pytorch_lightning.utilities.cli.LightningCLI in favor of equivalent copies in pytorch_lightning.cli.LightningCLI (#13767)

  • Deprecated pytorch_lightning.profiler.* in favor of pytorch_lightning.profilers (#12308)
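
The deprecation of num_processes, gpus, tpu_cores and ipus in the Trainer constructor is mostly a mechanical rename to the accelerator and devices arguments. The helper below is a hypothetical sketch, not a Lightning API: it rewrites a plain dict of Trainer keyword arguments, assuming each old flag maps to the accelerator name in the table and that None or 0 meant the flag was not in use (for example, gpus=0 meant CPU training). Verify the result against the documentation of your installed version.

```python
from typing import Any, Dict

# Deprecated Trainer flag -> accelerator it implied (assumed mapping for this sketch).
DEPRECATED_DEVICE_FLAGS = {
    "gpus": "gpu",
    "tpu_cores": "tpu",
    "ipus": "ipu",
    "num_processes": "cpu",
}


def migrate_trainer_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `kwargs` with deprecated device flags rewritten
    to the accelerator/devices pair."""
    new = dict(kwargs)
    old = {k: new.pop(k) for k in DEPRECATED_DEVICE_FLAGS if k in new}
    # Treat None and 0 as "flag not used" (e.g. gpus=0 meant CPU training).
    used = {k: v for k, v in old.items() if v not in (None, 0)}
    if not used:
        return new
    if len(used) > 1:
        raise ValueError(f"set only one of {sorted(used)}")
    if "accelerator" in new or "devices" in new:
        raise ValueError("do not combine deprecated device flags with accelerator/devices")
    flag, devices = next(iter(used.items()))
    new["accelerator"] = DEPRECATED_DEVICE_FLAGS[flag]
    new["devices"] = devices
    return new


print(migrate_trainer_kwargs({"gpus": 2, "max_epochs": 10}))
# {'max_epochs': 10, 'accelerator': 'gpu', 'devices': 2}
```

The rewritten dict can then be passed as Trainer(**migrate_trainer_kwargs(old_kwargs)). Refusing to combine an old flag with an explicit accelerator or devices value keeps the sketch from silently picking one of two conflicting settings.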

[1.7.0] - Removed
  • Removed deprecated IndexBatchSamplerWrapper.batch_indices (#13565)

  • Removed the deprecated LightningModule.add_to_queue and LightningModule.get_from_queue methods (#13600)

  • Removed deprecated pytorch_lightning.core.decorators.parameter_validation from decorators (#13514)

  • Removed the deprecated Logger.close method (#13149)

  • Removed the deprecated weights_summary argument from the Trainer constructor (#13070)

  • Removed the deprecated flush_logs_every_n_steps argument from the Trainer constructor (#13074)

  • Removed the deprecated process_position argument from the Trainer constructor (#13071)

  • Removed the deprecated checkpoint_callback argument from the Trainer constructor (#13027)

  • Removed the deprecated on_{train,val,test,predict}_dataloader hooks from the LightningModule and LightningDataModule (#13033)

  • Removed the deprecated TestTubeLogger (#12859)

  • Removed the deprecated pytorch_lightning.core.memory.LayerSummary and pytorch_lightning.core.memory.ModelSummary (#12593)

  • Removed the deprecated summarize method from the LightningModule (#12559)

  • Removed the deprecated model_size property from the LightningModule class (#12641)

  • Removed the deprecated stochastic_weight_avg argument from the Trainer constructor (#12535)

  • Removed the deprecated progress_bar_refresh_rate argument from the Trainer constructor (#12514)

  • Removed the deprecated prepare_data_per_node argument from the Trainer constructor (#12536)

  • Removed the deprecated pytorch_lightning.core.memory.{get_gpu_memory_map,get_memory_profile} (#12659)

  • Removed the deprecated terminate_on_nan argument from the Trainer constructor (#12553)

  • Removed the deprecated XLAStatsMonitor callback (#12688)

  • Removed the deprecated pytorch_lightning.callbacks.progress.progress (#12658)

  • Removed the deprecated dim and size arguments from the LightningDataModule constructor (#12780)

  • Removed the deprecated train_transforms argument from the LightningDataModule constructor (#12662)

  • Removed the deprecated log_gpu_memory argument from the Trainer constructor (#12657)

  • Removed the deprecated automatic logging of GPU stats by the logger connector (#12657)

  • Removed deprecated GPUStatsMonitor callback (#12554)

  • Removed support for passing strategy names or strategy instances to the accelerator Trainer argument (#12696)

  • Removed support for passing strategy names or strategy instances to the plugins Trainer argument (#12700)

  • Removed the deprecated val_transforms argument from the LightningDataModule constructor (#12763)

  • Removed the deprecated test_transforms argument from the LightningDataModule constructor (#12773)

  • Removed deprecated Trainer(max_steps=None) (#13591)

  • Removed deprecated dataloader_idx argument from on_train_batch_start/end hooks Callback and LightningModule (#12769, #12977)

  • Removed deprecated get_progress_bar_dict property from LightningModule (#12839)

  • Removed sanity check for multi-optimizer support with habana backends (#13217)

  • Removed the need to explicitly load habana module (#13338)

  • Removed the deprecated Strategy.post_dispatch() hook (#13461)

  • Removed deprecated pytorch_lightning.callbacks.lr_monitor.LearningRateMonitor.lr_sch_names (#13353)

  • Removed deprecated Trainer.slurm_job_id in favor of SLURMEnvironment.job_id (#13459)

  • Removed support for the DDP2Strategy (#12705)

  • Removed deprecated LightningDistributed (#13549)

  • Removed deprecated ClusterEnvironment properties master_address and master_port in favor of main_address and main_port (#13458)

  • Removed deprecated ClusterEnvironment methods KubeflowEnvironment.is_using_kubeflow(), LSFEnvironment.is_using_lsf() and TorchElasticEnvironment.is_using_torchelastic() in favor of the detect() method (#13458)

  • Removed deprecated Callback.on_keyboard_interrupt (#13438)

  • Removed deprecated LightningModule.on_post_move_to_device (#13548)

  • Removed TPUSpawnStrategy.{tpu_local_core_rank,tpu_global_core_rank} attributes in favor of TPUSpawnStrategy.{local_rank,global_rank} (#11163)

  • Removed SingleTPUStrategy.{tpu_local_core_rank,tpu_global_core_rank} attributes in favor of SingleTPUStrategy.{local_rank,global_rank} (#11163)

[1.7.0] - Fixed
  • Improved support for custom DataLoaders when instantiated in *_dataloader hook (#12981)

  • Allowed custom BatchSamplers when instantiated in *_dataloader hook (#13640)

  • Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad (#13014)

  • The model wrapper returned by LightningLite.setup() now properly supports pass-through when looking up attributes (#12597)

  • Fixed issue where the CLI fails with certain torch objects (#13153)

  • Fixed LightningCLI signature parameter resolving for some lightning classes (#13283)

  • Fixed Model Summary when using DeepSpeed Stage 3 (#13427)

  • Fixed pytorch_lightning.utilities.distributed.gather_all_tensors to handle tensors of different dimensions (#12630)

  • Fixed the input validation for the accelerator Trainer argument when passed as a string (#13417)

  • Fixed Trainer.predict(return_predictions=False) to track prediction’s batch_indices (#13629)

  • Fixed an issue that prevented setting a custom CheckpointIO plugin with strategies (#13785)

  • Fixed main progress bar counter when val_check_interval=int and check_val_every_n_epoch=None (#12832)

  • Improved support for custom ReduceLROnPlateau scheduler if reduce_on_plateau is set by the user in scheduler config (#13838)

  • Used global_step while restoring logging step for old checkpoints (#13645)

  • When training with precision=16 on IPU, the cast has been moved off the IPU onto the host, making the copies from host to IPU cheaper (#13880)

  • Fixed error handling in learning rate finder when not enough data points are available to give a good suggestion (#13845)

  • Fixed an issue that caused the learning rate finder to set the model’s learning rate to None when no suggestion was possible (#13845)

  • Fixed an issue causing deterministic algorithms and other globals to get reset in spawned processes (#13921)

  • Fixed default amp_level for DeepSpeedPrecisionPlugin to O2 (#13897)

  • Fixed Python 3.10 compatibility for truncated back-propagation through time (TBPTT) (#13973)

  • Fixed TQDMProgressBar reset and update to show correct time estimation (2/2) (#13962)

[1.6.5] - 2022-07-13

[1.6.5] - Fixed
  • Fixed estimated_stepping_batches requiring distributed comms in configure_optimizers for the DeepSpeedStrategy (#13350)

  • Fixed bug with Python version check that prevented use with development versions of Python (#13420)

  • The loops now call .set_epoch() also on batch samplers if the dataloader has one wrapped in a distributed sampler (#13396)

  • Fixed the restoration of log step during restart (#13467)

[1.6.4] - 2022-06-01

[1.6.4] - Added
  • Exposed all DDP parameters through the HPU parallel strategy (#13067)

[1.6.4] - Changed
  • Keep torch.backends.cudnn.benchmark=False by default (unlike in v1.6.{0-3}) after speed and memory problems depending on the data used. Please consider tuning Trainer(benchmark) manually. (#13154)

  • Prevented modification of torch.backends.cudnn.benchmark when Trainer(benchmark=...) is not set (#13154)

[1.6.4] - Fixed
  • Fixed an issue causing zero-division error for empty dataloaders (#12885)

  • Fixed mismatching default values for the types of some arguments in the DeepSpeed and Fully-Sharded strategies which made the CLI unable to use them (#12989)

  • Avoided a redundant callback restore warning while tuning (#13026)

  • Fixed Trainer(precision=64) during evaluation which now uses the wrapped precision module (#12983)

  • Fixed an issue to use wrapped LightningModule for evaluation during trainer.fit for BaguaStrategy (#12983)

  • Fixed unnecessary usage of the Habana mixed precision package for fp32 types (#13028)

  • Fixed the number of references of LightningModule so it can be deleted (#12897)

  • Fixed materialize_module setting a module’s child recursively (#12870)

  • Fixed issue where the CLI could not pass a Profiler to the Trainer (#13084)

  • Fixed torchelastic detection with non-distributed installations (#13142)

  • Fixed logging’s step values when multiple dataloaders are used during evaluation (#12184)

  • Fixed epoch logging on train epoch end (#13025)

  • Fixed DDPStrategy and DDPSpawnStrategy to initialize optimizers only after moving the module to the device (#11952)

[1.6.3] - 2022-05-03

[1.6.3] - Fixed
  • Used only a single instance of rich.console.Console throughout the codebase (#12886)

  • Fixed an issue to ensure all the checkpoint states are saved in a common filepath with DeepSpeedStrategy (#12887)

  • Fixed trainer.logger deprecation message (#12671)

  • Fixed an issue where sharded grad scaler is passed in when using BF16 with the ShardedStrategy (#12915)

  • Fixed an issue with recursive invocation of the DDP configuration in the HPU parallel plugin (#12912)

  • Fixed printing of ragged dictionaries in Trainer.validate and Trainer.test (#12857)

  • Fixed threading support for legacy loading of checkpoints (#12814)

  • Fixed pickling of KFoldLoop (#12441)

  • Stopped optimizer_zero_grad from being called after IPU execution (#12913)

  • Fixed fuse_modules to be QAT-aware for torch>=1.11 (#12891)

  • Enforced eval shuffle warning only for default samplers in DataLoader (#12653)

  • Enabled mixed precision in DDPFullyShardedStrategy when precision=16 (#12965)

  • Fixed TQDMProgressBar reset and update to show correct time estimation (1/2) (#12889)

  • Fixed fit loop restart logic to enable resume using the checkpoint (#12821)

[1.6.2] - 2022-04-27

[1.6.2] - Fixed
  • Fixed ImportError when torch.distributed is not available (#12794)

  • When using custom DataLoaders in LightningDataModule, multiple inheritance is resolved properly (#12716)

  • Fixed encoding issues on terminals that do not support unicode characters (#12828)

  • Fixed support for ModelCheckpoint monitors with dots (#12783)

[1.6.1] - 2022-04-13

[1.6.1] - Changed
  • Made the strategy argument case-insensitive (#12528)

[1.6.1] - Fixed
  • Ran main progress bar updates independently of validation progress bar updates in TQDMProgressBar (#12563)

  • Avoided calling average_parameters multiple times per optimizer step (#12452)

  • Properly passed some arguments of the Logger’s parent class to super().__init__() (#12609)

  • Fixed an issue where incorrect type warnings appear when the overridden LightningLite.run method accepts user-defined arguments (#12629)

  • Fixed rank_zero_only decorator in LSF environments (#12587)

  • Stopped raising a warning when nn.Module is not saved under hparams (#12669)

  • Raised a MisconfigurationException when the accelerator is available but the user passes invalid ([]/0/"0") values to the devices flag (#12708)

  • Supported auto_select_gpus with the accelerator and devices API (#12608)

[1.6.0] - 2022-03-29

[1.6.0] - Added
  • Allowed logging to an existing run ID in MLflow with MLFlowLogger (#12290)

  • Enabled gradient accumulation using Horovod’s backward_passes_per_step (#11911)

  • Added a new DETAIL log level to provide useful logs for improving monitoring and debugging of batch jobs (#11008)

  • Added a flag SLURMEnvironment(auto_requeue=True|False) to control whether Lightning handles the requeuing (#10601)

  • Fault Tolerant Manual

    • Added the _Stateful protocol to detect if classes are stateful (#10646)

    • Added the _FaultTolerantMode enum used to track different supported fault tolerant modes (#10645)

    • Added a _rotate_worker_indices utility to reload the state according to the latest worker (#10647)

    • Added stateful workers (#10674)

    • Added a utility to collect the states across processes (#10639)

    • Added logic to reload the states across data loading components (#10699)

    • Cleaned up some fault tolerant utilities (#10703)

    • Enabled Fault Tolerant Manual Training (#10707)

    • Broadcast _terminate_gracefully to all processes and added support for DDP (#10638)

  • Added support for re-instantiation of custom (subclasses of) DataLoaders returned in the *_dataloader() methods, i.e., automatic replacement of samplers now works with custom types of DataLoader (#10680)

  • Added a function to validate if fault tolerant training is supported (#10465)

  • Added a private callback to manage the creation and deletion of fault-tolerance checkpoints (#11862)

  • Show a better error message when a custom DataLoader implementation is not well implemented and we need to reconstruct it (#10719)

  • Show a better error message when frozen dataclass is used as a batch (#10927)

  • Save the Loop’s state by default in the checkpoint (#10784)

  • Added Loop.replace to easily switch one loop for another (#10324)

  • Added support for --lr_scheduler=ReduceLROnPlateau to the LightningCLI (#10860)

  • Added LightningCLI.configure_optimizers to override the configure_optimizers return value (#10860)

  • Added a LightningCLI(auto_registry) flag to automatically register all subclasses of the registrable components (#12108)

  • Added a warning that shows when max_epochs in the Trainer is not set (#10700)

  • Added support for returning a single Callback from LightningModule.configure_callbacks without wrapping it into a list (#11060)

  • Added console_kwargs for RichProgressBar to initialize inner Console (#10875)

  • Added support for shorthand notation to instantiate loggers with the LightningCLI (#11533)

  • Added a LOGGER_REGISTRY instance to register custom loggers to the LightningCLI (#11533)

  • Added info message when the Trainer arguments limit_*_batches, overfit_batches, or val_check_interval are set to 1 or 1.0 (#11950)

  • Added a PrecisionPlugin.teardown method (#10990)

  • Added LightningModule.lr_scheduler_step (#10249)

  • Added support for no pre-fetching to DataFetcher (#11606)

  • Added support for optimizer step progress tracking with manual optimization (#11848)

  • Returned the output of optimizer.step, which can be useful for LightningLite users, manual optimization users, or users overriding LightningModule.optimizer_step (#11711)

  • Tore down the active loop and strategy on exception (#11620)

  • Added a MisconfigurationException if the user-provided opt_idx in the scheduler config doesn’t match the actual index of its respective optimizer (#11247)

  • Added a loggers property to Trainer which returns a list of loggers provided by the user (#11683)

  • Added a loggers property to LightningModule which retrieves the loggers property from Trainer (#11683)

  • Added support for DDP when using a CombinedLoader for the training data (#11648)

  • Added a warning when using DistributedSampler during validation/testing (#11479)

  • Added support for Bagua training strategy (#11146)

  • Added support for manually returning a poptorch.DataLoader in a *_dataloader hook (#12116)

  • Added rank_zero module to centralize utilities (#11747)

  • Added _Stateful support for LightningDataModule (#11637)

  • Added _Stateful support for PrecisionPlugin (#11638)

  • Added Accelerator.is_available to check device availability (#11797)

  • Enabled static type-checking on the signature of Trainer (#11888)

  • Added utility functions for moving optimizers to devices (#11758)

  • Added a warning when saving an instance of nn.Module with save_hyperparameters() (#12068)

  • Added estimated_stepping_batches property to Trainer (#11599)

  • Added support for pluggable Accelerators (#12030)

  • Added profiling for on_load_checkpoint/on_save_checkpoint callback and LightningModule hooks (#12149)

  • Added LayerSync and NativeSyncBatchNorm plugins (#11754)

  • Added optional storage_options argument to Trainer.save_checkpoint() to pass to custom CheckpointIO implementations (#11891)

  • Added support to explicitly specify the process group backend for parallel strategies (#11745)

  • Added device_ids and num_devices property to Trainer (#12151)

  • Added Callback.state_dict() and Callback.load_state_dict() methods (#12232)

  • Added AcceleratorRegistry (#12180)

  • Added support for Habana Accelerator (HPU) (#11808)

  • Added support for dataclasses in apply_to_collections (#11889)

[1.6.0] - Changed
  • Dropped PyTorch 1.7 support (#12191, #12432)

  • Made the benchmark flag optional, with its value set based on the deterministic flag (#11944)

  • Implemented a new native and rich format in _print_results method of the EvaluationLoop (#11332)

  • Do not print an empty table at the end of the EvaluationLoop (#12427)

  • Set the prog_bar flag to False in LightningModule.log_grad_norm (#11472)

  • Raised an exception in init_dist_connection() when torch distributed is not available (#10418)

  • The monitor argument in the EarlyStopping callback is no longer optional (#10328)

  • Do not fail if batch size could not be inferred for logging when using DeepSpeed (#10438)

  • Raised MisconfigurationException when enable_progress_bar=False and a progress bar instance has been passed in the callback list (#10520)

  • Moved trainer.connectors.env_vars_connector._defaults_from_env_vars to utilities.argsparse._defaults_from_env_vars (#10501)

  • Changes in LightningCLI required for the new major release of jsonargparse v4.0.0 (#10426)

  • Renamed refresh_rate_per_second parameter to refresh_rate for RichProgressBar signature (#10497)

  • Moved ownership of the PrecisionPlugin into TrainingTypePlugin and updated all references (#10570)

  • Fault Tolerant relies on signal.SIGTERM to gracefully exit instead of signal.SIGUSR1 (#10605)

  • Loop.restarting=... now sets the value recursively for all subloops (#11442)

  • Raised an error if the batch_size cannot be inferred from the current batch if it contained a string or was a custom batch object (#10541)

  • The validation loop is now disabled when overfit_batches > 0 is set in the Trainer (#9709)

  • Moved optimizer related logics from Accelerator to TrainingTypePlugin (#10596)

  • Moved ownership of the lightning optimizers from the Trainer to the Strategy (#11444)

  • Moved ownership of the data fetchers from the DataConnector to the Loops (#11621)

  • Moved batch_to_device method from Accelerator to TrainingTypePlugin (#10649)

  • The DDPSpawnPlugin no longer overrides the post_dispatch plugin hook (#10034)

  • Integrate the progress bar implementation with progress tracking (#11213)

  • The LightningModule.{add_to_queue,get_from_queue} hooks no longer get a torch.multiprocessing.SimpleQueue and instead receive a list based queue (#10034)

  • Changed training_step, validation_step, test_step and predict_step method signatures in Accelerator and updated input from caller side (#10908)

  • Changed the name of the temporary checkpoint that the DDPSpawnPlugin and related plugins save (#10934)

  • LoggerCollection returns only unique logger names and versions (#10976)

  • Redesigned process creation for spawn-based plugins (DDPSpawnPlugin, TPUSpawnPlugin, etc.) (#10896)

    • All spawn-based plugins now spawn processes immediately upon calling Trainer.{fit,validate,test,predict}

    • The hooks/callbacks prepare_data, setup, configure_sharded_model and teardown now run under initialized process group for spawn-based plugins just like their non-spawn counterparts

    • Some configuration errors that were previously raised as MisconfigurationExceptions will now be raised as ProcessRaisedException (torch>=1.8) or as Exception (torch<1.8)

    • Removed the TrainingTypePlugin.pre_dispatch() method and merged it with TrainingTypePlugin.setup() (#11137)

  • Changed profiler to index and display the names of the hooks with a new pattern []. (#11026)

  • Changed batch_to_device entry in profiling from stage-specific to generic, to match profiling of other hooks (#11031)

  • Changed the info message for finalizing ddp-spawn worker processes to a debug-level message (#10864)

  • Removed duplicated file extension when uploading model checkpoints with NeptuneLogger (#11015)

  • Removed __getstate__ and __setstate__ of RichProgressBar (#11100)

  • The DDPPlugin and DDPSpawnPlugin and their subclasses now remove the SyncBatchNorm wrappers in teardown() to enable proper support at inference after fitting (#11078)

  • Moved ownership of the Accelerator instance to the TrainingTypePlugin; all training-type plugins now take an optional parameter accelerator (#11022)

  • Renamed the TrainingTypePlugin to Strategy (#11120)

    • Renamed the ParallelPlugin to ParallelStrategy (#11123)

    • Renamed the DataParallelPlugin to DataParallelStrategy (#11183)

    • Renamed the DDPPlugin to DDPStrategy (#11142)

    • Renamed the DDP2Plugin to DDP2Strategy (#11185)

    • Renamed the DDPShardedPlugin to DDPShardedStrategy (#11186)

    • Renamed the DDPFullyShardedPlugin to DDPFullyShardedStrategy (#11143)

    • Renamed the DDPSpawnPlugin to DDPSpawnStrategy (#11145)

    • Renamed the DDPSpawnShardedPlugin to DDPSpawnShardedStrategy (#11210)

    • Renamed the DeepSpeedPlugin to DeepSpeedStrategy (#11194)

    • Renamed the HorovodPlugin to HorovodStrategy (#11195)

    • Renamed the TPUSpawnPlugin to TPUSpawnStrategy (#11190)

    • Renamed the IPUPlugin to IPUStrategy (#11193)

    • Renamed the SingleDevicePlugin to SingleDeviceStrategy (#11182)

    • Renamed the SingleTPUPlugin to SingleTPUStrategy (#11182)

    • Renamed the TrainingTypePluginsRegistry to StrategyRegistry (#11233)

  • Marked the ResultCollection, ResultMetric, and ResultMetricCollection classes as protected (#11130)

  • Marked trainer.checkpoint_connector as protected (#11550)

  • The epoch start/end hooks are now called by the FitLoop instead of the TrainingEpochLoop (#11201)

  • DeepSpeed does not require LightningModule ZeRO Stage 3 partitioning (#10655)

  • Moved Strategy classes to the strategies directory (#11226)

  • Renamed training_type_plugin file to strategy (#11239)

  • Changed DeviceStatsMonitor to group metrics based on the logger’s group_separator (#11254)

  • Raised UserWarning if evaluation is triggered with best ckpt and trainer is configured with multiple checkpoint callbacks (#11274)

  • Trainer.logged_metrics now always contains scalar tensors, even when a Python scalar was logged (#11270)

  • The tuner now uses the checkpoint connector to copy and restore its state (#11518)

  • Changed MisconfigurationException to ModuleNotFoundError when rich isn’t available (#11360)

  • The trainer.current_epoch value is now increased by 1 during and after on_train_end (#8578)

  • The trainer.global_step value now accounts for multiple optimizers and TBPTT splits (#11805)

  • The trainer.global_step value is now increased right after the optimizer.step() call which will impact users who access it during an intra-training validation hook (#11805)

  • The filename of checkpoints created with ModelCheckpoint(filename='{step}') is different compared to previous versions. A checkpoint saved after 1 step will be named step=1.ckpt instead of step=0.ckpt (#11805)

  • Inherit from ABC for Accelerator: Users need to implement auto_device_count (#11521)

  • Changed parallel_devices property in ParallelStrategy to be lazy initialized (#11572)

  • Updated TQDMProgressBar to run a separate progress bar for each eval dataloader (#11657)

  • Sorted SimpleProfiler(extended=False) summary based on mean duration for each hook (#11671)

  • Avoid enforcing shuffle=False for eval dataloaders (#11575)

  • When using DP (data-parallel), Lightning will no longer automatically reduce all tensors returned in training_step; it will only reduce the loss unless training_step_end is overridden (#11594)

  • When using DP (data-parallel), the training_epoch_end hook will no longer receive reduced outputs from training_step and instead get the full tensor of results from all GPUs (#11594)

  • Changed default logger name to lightning_logs for consistency (#11762)

  • Rewrote accelerator_connector (#11448)

  • When manual optimization is used with DDP, we no longer force find_unused_parameters=True (#12425)

  • Disabled loading dataloaders if the corresponding limit_batches=0 (#11576)

  • Removed is_global_zero check in training_epoch_loop before logger.save. If you have a custom logger that implements save, the Trainer will now call save on all ranks by default. To change this behavior, add @rank_zero_only to your save implementation (#12134)

  • Disabled tuner with distributed strategies (#12179)

  • Marked trainer.logger_connector as protected (#12195)

  • Move Strategy.process_dataloader function call from fit/evaluation/predict_loop.py to data_connector.py (#12251)

  • ModelCheckpoint(save_last=True, every_n_epochs=N) now saves a “last” checkpoint every epoch (disregarding every_n_epochs) instead of only once at the end of training (#12418)

  • The strategies that support sync_batchnorm now only apply it when fitting (#11919)

  • Avoided fallback on CPU if no devices are provided for other accelerators (#12410)

  • Modified supporters.py so that the accumulator element (for loss) is created directly on the device (#12430)

  • Removed EarlyStopping.on_save_checkpoint and EarlyStopping.on_load_checkpoint in favor of EarlyStopping.state_dict and EarlyStopping.load_state_dict (#11887)

  • Removed BaseFinetuning.on_save_checkpoint and BaseFinetuning.on_load_checkpoint in favor of BaseFinetuning.state_dict and BaseFinetuning.load_state_dict (#11887)

  • Removed BackboneFinetuning.on_save_checkpoint and BackboneFinetuning.on_load_checkpoint in favor of BackboneFinetuning.state_dict and BackboneFinetuning.load_state_dict (#11887)

  • Removed ModelCheckpoint.on_save_checkpoint and ModelCheckpoint.on_load_checkpoint in favor of ModelCheckpoint.state_dict and ModelCheckpoint.load_state_dict (#11887)

  • Removed Timer.on_save_checkpoint and Timer.on_load_checkpoint in favor of Timer.state_dict and Timer.load_state_dict (#11887)

  • Replaced PostLocalSGDOptimizer with a dedicated model averaging component (#12378)
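
The global_step changes listed above (#11805) can be illustrated with a toy loop. This is not Lightning's actual loop. In the toy, the counter increases right after every optimizer.step() call, so with two optimizers it grows by two per batch. A checkpoint written after the first step therefore sees global_step == 1. The helper names run_toy_loop and checkpoint_name are hypothetical; the filename format mimics ModelCheckpoint(filename='{step}').

```python
from typing import List


def run_toy_loop(num_batches: int, num_optimizers: int) -> List[int]:
    """Hypothetical loop: return the global_step value observed after each optimizer step."""
    global_step = 0
    observed = []
    for _batch in range(num_batches):
        for _opt_idx in range(num_optimizers):
            # ... backward pass and optimizer.step() would happen here ...
            global_step += 1  # incremented right after each optimizer step
            observed.append(global_step)
    return observed


def checkpoint_name(step: int) -> str:
    # Mimics the naming produced by ModelCheckpoint(filename="{step}").
    return f"step={step}.ckpt"


print(run_toy_loop(num_batches=2, num_optimizers=1))  # [1, 2]
print(run_toy_loop(num_batches=2, num_optimizers=2))  # [1, 2, 3, 4]
print(checkpoint_name(run_toy_loop(num_batches=1, num_optimizers=1)[-1]))  # step=1.ckpt
```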

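The toy schedule below illustrates the new ModelCheckpoint(save_last=True, every_n_epochs=N) behavior (#12418). It is not Lightning's code. It assumes zero-indexed epochs. It also assumes that the periodic (top-k) checkpoint is written when (epoch + 1) is a multiple of N, while the "last" checkpoint is written at the end of every epoch. The function name save_schedule is hypothetical.

```python
from typing import Dict, List


def save_schedule(max_epochs: int, every_n_epochs: int) -> Dict[str, List[int]]:
    """Hypothetical: which epochs write a periodic checkpoint and which refresh "last"."""
    periodic = [e for e in range(max_epochs) if (e + 1) % every_n_epochs == 0]
    # Since 1.6.0, "last" is refreshed every epoch regardless of every_n_epochs.
    last = list(range(max_epochs))
    return {"periodic": periodic, "last": last}


print(save_schedule(max_epochs=6, every_n_epochs=3))
# {'periodic': [2, 5], 'last': [0, 1, 2, 3, 4, 5]}
```
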
[1.6.0] - Deprecated
  • Deprecated training_type_plugin property in favor of strategy in Trainer and updated the references (#11141)

  • Deprecated Trainer.{validated,tested,predicted}_ckpt_path and replaced with read-only property Trainer.ckpt_path set when checkpoints loaded via Trainer.{fit,validate,test,predict} (#11696)

  • Deprecated ClusterEnvironment.master_{address,port} in favor of ClusterEnvironment.main_{address,port} (#10103)

  • Deprecated DistributedType in favor of _StrategyType (#10505)

  • Deprecated the precision_plugin constructor argument from Accelerator (#10570)

  • Deprecated DeviceType in favor of _AcceleratorType (#10503)

  • Deprecated the property Trainer.slurm_job_id in favor of the new SLURMEnvironment.job_id() method (#10622)

  • Deprecated the access to the attribute IndexBatchSamplerWrapper.batch_indices in favor of IndexBatchSamplerWrapper.seen_batch_indices (#10870)

  • Deprecated on_init_start and on_init_end callback hooks (#10940)

  • Deprecated Trainer.call_hook in favor of Trainer._call_callback_hooks, Trainer._call_lightning_module_hook, Trainer._call_ttp_hook, and Trainer._call_accelerator_hook (#10979)

  • Deprecated TrainingTypePlugin.post_dispatch in favor of TrainingTypePlugin.teardown (#10939)

  • Deprecated ModelIO.on_hpc_{save/load} in favor of CheckpointHooks.on_{save/load}_checkpoint (#10911)

  • Deprecated Trainer.run_stage in favor of Trainer.{fit,validate,test,predict} (#11000)

  • Deprecated Trainer.lr_schedulers in favor of Trainer.lr_scheduler_configs which returns a list of dataclasses instead of dictionaries (#11443)

  • Deprecated Trainer.verbose_evaluate in favor of EvaluationLoop(verbose=...) (#10931)

  • Deprecated Trainer.should_rank_save_checkpoint Trainer property (#11068)

  • Deprecated Trainer.lightning_optimizers (#11444)

  • Deprecated TrainerOptimizersMixin and moved functionality to core/optimizer.py (#11155)

  • Deprecated the on_train_batch_end(outputs) format when multiple optimizers are used and TBPTT is enabled (#12182)

  • Deprecated the training_epoch_end(outputs) format when multiple optimizers are used and TBPTT is enabled (#12182)

  • Deprecated TrainerCallbackHookMixin (#11148)

  • Deprecated TrainerDataLoadingMixin and moved functionality to Trainer and DataConnector (#11282)

  • Deprecated function pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys (#11254)

  • Deprecated Callback.on_epoch_start hook in favour of Callback.on_{train/val/test}_epoch_start (#11578)

  • Deprecated Callback.on_epoch_end hook in favour of Callback.on_{train/val/test}_epoch_end (#11578)

  • Deprecated LightningModule.on_epoch_start hook in favor of LightningModule.on_{train/val/test}_epoch_start (#11578)

  • Deprecated LightningModule.on_epoch_end hook in favor of LightningModule.on_{train/val/test}_epoch_end (#11578)

  • Deprecated on_before_accelerator_backend_setup callback hook in favour of setup (#11568)

  • Deprecated on_batch_start and on_batch_end callback hooks in favor of on_train_batch_start and on_train_batch_end (#11577)

  • Deprecated on_configure_sharded_model callback hook in favor of setup (#11627)

  • Deprecated pytorch_lightning.utilities.distributed.rank_zero_only in favor of pytorch_lightning.utilities.rank_zero.rank_zero_only (#11747)

  • Deprecated pytorch_lightning.utilities.distributed.rank_zero_debug in favor of pytorch_lightning.utilities.rank_zero.rank_zero_debug (#11747)

  • Deprecated pytorch_lightning.utilities.distributed.rank_zero_info in favor of pytorch_lightning.utilities.rank_zero.rank_zero_info (#11747)

  • Deprecated pytorch_lightning.utilities.warnings.rank_zero_warn in favor of pytorch_lightning.utilities.rank_zero.rank_zero_warn (#11747)

  • Deprecated pytorch_lightning.utilities.warnings.rank_zero_deprecation in favor of pytorch_lightning.utilities.rank_zero.rank_zero_deprecation (#11747)

  • Deprecated pytorch_lightning.utilities.warnings.LightningDeprecationWarning in favor of pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning (#11747)

  • Deprecated on_pretrain_routine_start and on_pretrain_routine_end callback hooks in favor of on_fit_start (#11794)

  • Deprecated LightningModule.on_pretrain_routine_start and LightningModule.on_pretrain_routine_end hooks in favor of on_fit_start (#12122)

  • Deprecated agg_key_funcs and agg_default_func parameters from LightningLoggerBase (#11871)

  • Deprecated LightningLoggerBase.update_agg_funcs (#11871)

  • Deprecated LightningLoggerBase.agg_and_log_metrics in favor of LightningLoggerBase.log_metrics (#11832)

  • Deprecated passing weights_save_path to the Trainer constructor in favor of adding the ModelCheckpoint callback with dirpath directly to the list of callbacks (#12084)

  • Deprecated pytorch_lightning.profiler.AbstractProfiler in favor of pytorch_lightning.profiler.Profiler (#12106)

  • Deprecated pytorch_lightning.profiler.BaseProfiler in favor of pytorch_lightning.profiler.Profiler (#12150)

  • Deprecated BaseProfiler.profile_iterable (#12102)

  • Deprecated LoggerCollection in favor of trainer.loggers (#12147)

  • Deprecated PrecisionPlugin.on_{save,load}_checkpoint in favor of PrecisionPlugin.{state_dict,load_state_dict} (#11978)

  • Deprecated LightningDataModule.on_save/load_checkpoint in favor of state_dict/load_state_dict (#11893)

  • Deprecated Trainer.use_amp in favor of Trainer.amp_backend (#12312)

  • Deprecated LightningModule.use_amp in favor of Trainer.amp_backend (#12315)

  • Deprecated specifying the process group backend through the environment variable PL_TORCH_DISTRIBUTED_BACKEND (#11745)

  • Deprecated ParallelPlugin.torch_distributed_backend in favor of DDPStrategy.process_group_backend property (#11745)

  • Deprecated ModelCheckpoint.save_checkpoint in favor of Trainer.save_checkpoint (#12456)

  • Deprecated Trainer.devices in favor of Trainer.num_devices and Trainer.device_ids (#12151)

  • Deprecated Trainer.root_gpu in favor of Trainer.strategy.root_device.index when GPU is used (#12262)

  • Deprecated Trainer.num_gpus in favor of Trainer.num_devices when GPU is used (#12384)

  • Deprecated Trainer.ipus in favor of Trainer.num_devices when IPU is used (#12386)

  • Deprecated Trainer.num_processes in favor of Trainer.num_devices (#12388)

  • Deprecated Trainer.data_parallel_device_ids in favor of Trainer.device_ids (#12072)

  • Deprecated returning state from Callback.on_save_checkpoint in favor of returning state in Callback.state_dict for checkpointing (#11887)

  • Deprecated passing only the callback state to Callback.on_load_checkpoint(callback_state) in favor of passing the callback state to Callback.load_state_dict and in 1.8, passing the entire checkpoint dictionary to Callback.on_load_checkpoint(checkpoint) (#11887)

  • Deprecated Trainer.gpus in favor of Trainer.device_ids or Trainer.num_devices (#12436)

  • Deprecated Trainer.tpu_cores in favor of Trainer.num_devices (#12437)
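
Several entries above move rank_zero_only to pytorch_lightning.utilities.rank_zero. The [1.6.0] - Changed list also notes that a custom logger's save is now called on all ranks unless it is decorated with @rank_zero_only. In real code, import the decorator from pytorch_lightning.utilities.rank_zero. The standalone sketch below only illustrates the idea behind such a decorator and is not Lightning's implementation. It assumes the process rank can be read from the RANK environment variable (defaulting to 0) at call time, which may differ from how Lightning determines the rank. The names rank_zero_only_sketch and MyLogger are hypothetical.

```python
import functools
import os
from typing import Any, Callable, Optional, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def rank_zero_only_sketch(fn: F) -> F:
    """Hypothetical decorator: run `fn` only on global rank 0, otherwise return None."""

    @functools.wraps(fn)
    def wrapped(*args: Any, **kwargs: Any) -> Optional[Any]:
        # Assumption: the launcher exports the global rank as RANK; default to 0 when unset.
        rank = int(os.environ.get("RANK", 0))
        if rank == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped  # type: ignore[return-value]


class MyLogger:
    """Hypothetical custom logger; only the save() method matters here."""

    def __init__(self) -> None:
        self.save_calls = 0

    @rank_zero_only_sketch
    def save(self) -> None:
        # Writing files here on every rank could race; the decorator restricts it to rank 0.
        self.save_calls += 1


logger = MyLogger()
os.environ["RANK"] = "1"
logger.save()  # skipped on a non-zero rank
os.environ["RANK"] = "0"
logger.save()  # runs on rank 0
print(logger.save_calls)  # 1
```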

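The deprecations above move callback checkpoint state from Callback.on_save_checkpoint to Callback.state_dict() and Callback.load_state_dict() (#11887, #12232). The sketch below shows that round trip on a plain Python class with the same two method names, so it runs without Lightning installed. In real code, the class would subclass pytorch_lightning.Callback, and the Trainer would call these methods when saving and restoring a checkpoint. The class name BatchCounter, its batches_seen attribute and the record_batch method are hypothetical.

```python
import copy
from typing import Any, Dict


class BatchCounter:
    """Hypothetical callback-like class; real code would subclass pytorch_lightning.Callback."""

    def __init__(self) -> None:
        self.batches_seen = 0

    def record_batch(self) -> None:
        # Stand-in for work a real callback would do in a hook such as on_train_batch_end.
        self.batches_seen += 1

    def state_dict(self) -> Dict[str, Any]:
        # Return plain, picklable values so they can be stored in the checkpoint.
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Restore the values produced by state_dict().
        self.batches_seen = state_dict["batches_seen"]


counter = BatchCounter()
for _ in range(3):
    counter.record_batch()

saved = copy.deepcopy(counter.state_dict())  # roughly what would be stored in a checkpoint
restored = BatchCounter()
restored.load_state_dict(saved)
print(restored.batches_seen)  # 3
```
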
[1.6.0] - Removed
  • Removed deprecated parameter method in pytorch_lightning.utilities.model_helpers.is_overridden (#10507)

  • Removed deprecated method ClusterEnvironment.creates_children (#10339)

  • Removed deprecated TrainerModelHooksMixin.is_function_implemented and TrainerModelHooksMixin.has_arg (#10322)

  • Removed deprecated pytorch_lightning.utilities.device_dtype_mixin.DeviceDtypeModuleMixin in favor of pytorch_lightning.core.mixins.device_dtype_mixin.DeviceDtypeModuleMixin (#10442)

  • Removed deprecated LightningModule.loaded_optimizer_states_dict property (#10346)

  • Removed deprecated Trainer.fit(train_dataloader=), Trainer.validate(val_dataloaders=), and Trainer.test(test_dataloader=) (#10325)

  • Removed deprecated has_prepared_data, has_setup_fit, has_setup_validate, has_setup_test, has_setup_predict, has_teardown_fit, has_teardown_validate, has_teardown_test and has_teardown_predict datamodule lifecycle properties (#10350)

  • Removed deprecated every_n_val_epochs parameter of ModelCheckpoint (#10366)

  • Removed deprecated import pytorch_lightning.profiler.profilers in favor of import pytorch_lightning.profiler (#10443)

  • Removed deprecated property configure_slurm_dpp from accelerator connector (#10370)

  • Removed deprecated arguments num_nodes and sync_batchnorm from DDPPlugin, DDPSpawnPlugin, DeepSpeedPlugin (#10357)

  • Removed deprecated property is_slurm_managing_tasks from AcceleratorConnector (#10353)

  • Removed deprecated LightningModule.log(tbptt_reduce_fx, tbptt_reduce_token, sync_dist_op) (#10423)

  • Removed deprecated Plugin.task_idx (#10441)

  • Removed deprecated method master_params from PrecisionPlugin (#10372)

  • Removed the automatic detachment of “extras” returned from training_step. For example, return {'loss': ..., 'foo': foo.detach()} will now be necessary if foo has gradients which you do not want to store (#10424)

  • Removed deprecated passthrough methods and properties from Accelerator base class:

  • Removed deprecated signature for transfer_batch_to_device hook. The new argument dataloader_idx is now required (#10480)

  • Removed deprecated utilities.distributed.rank_zero_{warn/deprecation} (#10451)

  • Removed deprecated mode argument from ModelSummary class (#10449)

  • Removed deprecated Trainer.train_loop property in favor of Trainer.fit_loop (#10482)

  • Removed deprecated disable_validation property from Trainer (#10450)

  • Removed deprecated CheckpointConnector.hpc_load property in favor of CheckpointConnector.restore (#10525)

  • Removed deprecated reload_dataloaders_every_epoch from Trainer in favour of reload_dataloaders_every_n_epochs (#10481)

  • Removed the precision_plugin attribute from Accelerator in favor of its equivalent attribute precision_plugin in the TrainingTypePlugin (#10570)

  • Removed DeepSpeedPlugin.{precision,amp_type,amp_level} properties (#10657)

  • Removed patching of on_before_batch_transfer, transfer_batch_to_device and on_after_batch_transfer hooks in LightningModule (#10603)

  • Removed argument return_result from the DDPSpawnPlugin.spawn() method (#10867)

  • Removed the property TrainingTypePlugin.results and corresponding properties in subclasses (#10034)

  • Removed the mp_queue attribute from DDPSpawnPlugin and TPUSpawnPlugin (#10034)

  • Removed unnecessary _move_optimizer_state method overrides from TPUSpawnPlugin and SingleTPUPlugin (#10849)

  • Removed should_rank_save_checkpoint property from TrainingTypePlugin (#11070)

  • Removed model_sharded_context method from Accelerator (#10886)

  • Removed method pre_dispatch from the PrecisionPlugin (#10887)

  • Removed method setup_optimizers_in_pre_dispatch from the strategies and achieved the same logic in the setup and pre_dispatch methods (#10906)

  • Removed methods pre_dispatch, dispatch and post_dispatch from the Accelerator (#10885)

  • Removed methods training_step, test_step, validation_step and predict_step from the Accelerator (#10890)

  • Removed TrainingTypePlugin.start_{training,evaluating,predicting} hooks and the same in all subclasses (#10989, #10896)

  • Removed Accelerator.on_train_start (#10999)

  • Removed support for Python 3.6 (#11117)

  • Removed Strategy.init_optimizers in favor of Strategy.setup_optimizers (#11236)

  • Removed profile("training_step_and_backward") in Closure class since training_step and backward are already profiled (#11222)

  • Removed Strategy.optimizer_zero_grad (#11246)

  • Removed Strategy.on_gpu (#11537)

  • Removed Strategy.on_tpu property (#11536)

  • Removed the abstract property LightningLoggerBase.experiment (#11603)

  • Removed FitLoop.current_epoch getter and setter (#11562)

  • Removed access to _short_id in NeptuneLogger (#11517)

  • Removed log_text and log_image from the LightningLoggerBase API (#11857)

  • Removed calls to profile("model_forward") in favor of profiling training_step (#12032)

  • Removed get_mp_spawn_kwargs from DDPSpawnStrategy and TPUSpawnStrategy in favor of configuration in the _SpawnLauncher (#11966)

  • Removed _aggregate_metrics, _reduce_agg_metrics, and _finalize_agg_metrics from LightningLoggerBase (#12053)

  • Removed the AcceleratorConnector.device_type property (#12081)

  • Removed AcceleratorConnector.num_nodes (#12107)

  • Removed AcceleratorConnector.has_ipu property (#12111)

  • Removed AcceleratorConnector.use_ipu property (#12110)

  • Removed AcceleratorConnector.has_tpu property (#12109)

  • Removed AcceleratorConnector.use_dp property (#12112)

  • Removed configure_sync_batchnorm from ParallelStrategy and all other strategies that inherit from it (#11754)

  • Removed public attribute sync_batchnorm from strategies (#11754)

  • Removed AcceleratorConnector.root_gpu property (#12262)

  • Removed AcceleratorConnector.tpu_id property (#12387)

  • Removed AcceleratorConnector.num_gpus property (#12384)

  • Removed AcceleratorConnector.num_ipus property (#12386)

  • Removed AcceleratorConnector.num_processes property (#12388)

  • Removed AcceleratorConnector.parallel_device_ids property (#12072)

  • Removed AcceleratorConnector.devices property (#12435)

  • Removed AcceleratorConnector.parallel_devices property (#12075)

  • Removed AcceleratorConnector.tpu_cores property (#12437)

[1.6.0] - Fixed
  • Fixed an issue where ModelCheckpoint could delete last checkpoint from the old directory when dirpath has changed during resumed training (#12225)

  • Fixed an issue where ModelCheckpoint could delete older checkpoints when dirpath has changed during resumed training (#12045)

  • Fixed an issue where HorovodStrategy.teardown() did not complete gracefully if an exception was thrown during callback setup (#11752)

  • Fixed security vulnerabilities CVE-2020-1747 and CVE-2020-14343 caused by the PyYAML dependency (#11099)

  • Fixed security vulnerability “CWE-94: Improper Control of Generation of Code (Code Injection)” (#12212)

  • Fixed logging on {test,validation}_epoch_end with multiple dataloaders (#11132)

  • Reset the validation progress tracking state after sanity checking (#11218)

  • Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped (#11119)

  • Fixed an issue with the TPUSpawnPlugin handling the XLA_USE_BF16 environment variable incorrectly (#10990)

  • Fixed wrong typehint for Trainer.lightning_optimizers (#11155)

  • Fixed the lr-scheduler state not being dumped to checkpoint when using the deepspeed strategy (#11307)

  • Fixed bug that forced overriding configure_optimizers with the CLI (#11672)

  • Fixed type promotion when tensors of higher category than float are logged (#11401)

  • Fixed SimpleProfiler summary (#11414)

  • No longer set a DistributedSampler to the poptorch.DataLoader when IPUs are used (#12114)

  • Fixed bug where progress bar was not being disabled when not in rank zero during predict (#11377)

  • Fixed the mid-epoch warning call while resuming training (#11556)

  • Fixed LightningModule.{un,}toggle_model when only 1 optimizer is used (#12088)

  • Fixed an issue in RichProgressBar so that logged metrics are displayed only on the main progress bar (#11690)

  • Fixed RichProgressBar progress when refresh rate does not evenly divide the total counter (#11668)

  • Fixed RichProgressBar progress validation bar total when using multiple validation runs within a single training epoch (#11668)

  • Configure native DeepSpeed schedulers with interval='step' (#11788), (#12031)

  • Update RichProgressBarTheme styles after detecting light theme on colab (#10993)

  • Fixed passing _ddp_params_and_buffers_to_ignore (#11949)

  • Fixed an AttributeError when calling save_hyperparameters and no parameters need saving (#11827)

  • Fixed environment variable priority for global rank determination (#11406)

  • Fixed an issue that caused the Trainer to produce identical results on subsequent runs without explicit re-seeding (#11870)

  • Fixed an issue that caused the Tuner to affect the random state (#11870)

  • Avoided the common hook warning if no hook is overridden (#12131)

  • Fixed DeepSpeed keeping old sub-folders in the same checkpoint path (#12194)

  • Fixed returning logged metrics instead of callback metrics during evaluation (#12224)

  • Fixed the case where logger=None is passed to the Trainer (#12249)

  • Fixed bug where the global step tracked by ModelCheckpoint was still set even if no checkpoint was saved (#12418)

  • Fixed bug where ModelCheckpoint was overriding the epoch and step logged values (#12418)

  • Fixed bug where monitoring the default epoch and step values with ModelCheckpoint would fail (#12418)

  • Fixed initializing optimizers unnecessarily in DDPFullyShardedStrategy (#12267)

  • Fixed check for horovod module (#12377)

  • Fixed logging to loggers with multiple eval dataloaders (#12454)

  • Fixed an issue with resuming from a checkpoint trained with QAT (#11346)

[1.5.10] - 2022-02-08

[1.5.10] - Fixed
  • Fixed an issue so that the validation loop does not run on restart (#11552)

  • The RichProgressBar now correctly shows the on_epoch logged values on train epoch end (#11689)

  • Fixed an issue to make the step argument in WandbLogger.log_image work (#11716)

  • Fixed restore_optimizers for mapping states (#11757)

  • With DPStrategy, the batch is not explicitly moved to the device (#11780)

  • Fixed an issue where the validation progress bar disappeared after trainer.validate() (#11700)

  • Fixed supporting remote filesystems with Trainer.weights_save_path for fault-tolerant training (#11776)

  • Fixed check for available modules (#11526)

  • Fixed bug where the path for “last” checkpoints was not getting saved correctly which caused newer runs to not remove the previous “last” checkpoint (#11481)

  • Fixed bug where the path for best checkpoints was not getting saved correctly when no metric was monitored which caused newer runs to not use the best checkpoint (#11481)

[1.5.9] - 2022-01-20

[1.5.9] - Fixed
  • Pinned sphinx-autodoc-typehints to <v1.15 (#11400)

  • Skipped testing with PyTorch 1.7 and Python 3.9 on Ubuntu (#11217)

  • Fixed type promotion when tensors of higher category than float are logged (#11401)

  • Fixed the format of the configuration saved automatically by the CLI’s SaveConfigCallback (#11532)

[1.5.9] - Changed
  • Changed LSFEnvironment to use LSB_DJOB_RANKFILE environment variable instead of LSB_HOSTS for determining node rank and main address (#10825)

  • Disabled sampler replacement when using IterableDataset (#11507)

[1.5.8] - 2022-01-05

[1.5.8] - Fixed
  • Fixed LightningCLI race condition while saving the config (#11199)

  • Fixed the default value used with log(reduce_fx=min|max) (#11310)

  • Fixed data fetcher selection (#11294)

  • Fixed a race condition that could result in incorrect (zero) values being observed in prediction writer callbacks (#11288)

  • Fixed dataloaders not getting reloaded the correct number of times when setting reload_dataloaders_every_n_epochs and check_val_every_n_epoch (#10948)

  • Fixed DeepSpeed strategy not restoring the lr-scheduler states when lr-scheduler(s) are configured through LightningModule.configure_optimizers (#11322)

[1.5.7] - 2021-12-21

[1.5.7] - Fixed
  • Fixed NeptuneLogger when using DDP (#11030)

  • Fixed a bug to disable logging hyperparameters in logger if there are no hparams (#11105)

  • Avoid the deprecated onnx.export(example_outputs=...) in torch 1.10 (#11116)

  • Fixed an issue when torch-scripting a LightningModule after training with Trainer(sync_batchnorm=True) (#11078)

  • Fixed an AttributeError occurring when using a CombinedLoader (multiple dataloaders) for prediction (#11111)

  • Fixed bug where Trainer(track_grad_norm=..., logger=False) would fail (#11114)

  • Fixed an incorrect warning being produced by the model summary when using bf16 precision on CPU (#11161)

[1.5.7] - Changed
  • DeepSpeed does not require LightningModule ZeRO Stage 3 partitioning (#10655)

  • The ModelCheckpoint callback now saves and restores attributes best_k_models, kth_best_model_path, kth_value, and last_model_path (#10995)

[1.5.6] - 2021-12-15

[1.5.6] - Fixed
  • Fixed a bug where the DeepSpeedPlugin arguments cpu_checkpointing and contiguous_memory_optimization were not being forwarded to deepspeed correctly (#10874)

  • Fixed an issue with NeptuneLogger causing checkpoints to be uploaded with a duplicated file extension (#11015)

  • Fixed support for logging within callbacks returned from LightningModule (#10991)

  • Fixed running sanity check with RichProgressBar (#10913)

  • Fixed support for CombinedLoader while checking for warning raised with eval dataloaders (#10994)

  • The TQDM progress bar now correctly shows the on_epoch logged values on train epoch end (#11069)

  • Fixed bug where the TQDM updated the training progress bar during trainer.validate (#11069)

[1.5.5] - 2021-12-07

[1.5.5] - Fixed
  • Disabled batch_size extraction for torchmetric instances because they accumulate the metrics internally (#10815)

  • Fixed an issue with SignalConnector not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled (#10611)

  • Fixed SignalConnector._has_already_handler check for callable type (#10483)

  • Fixed an issue to return the results for each dataloader separately instead of duplicating them for each (#10810)

  • Improved exception message if rich version is less than 10.2.2 (#10839)

  • Fixed uploading best model checkpoint in NeptuneLogger (#10369)

  • Fixed early schedule reset logic in PyTorch profiler that was causing data leak (#10837)

  • Fixed a bug that caused incorrect batch indices to be passed to the BasePredictionWriter hooks when using a dataloader with num_workers > 0 (#10870)

  • Fixed an issue with item assignment on the logger on rank > 0 for loggers that support it (#10917)

  • Fixed importing torch_xla.debug for torch-xla<1.8 (#10836)

  • Fixed an issue with DDPSpawnPlugin and related plugins leaving a temporary checkpoint behind (#10934)

  • Fixed a TypeError occurring in the SignalConnector.teardown() method (#10961)

[1.5.4] - 2021-11-30

[1.5.4] - Fixed
  • Fixed support for --key.help=class with the LightningCLI (#10767)

  • Fixed _compare_version for python packages (#10762)

  • Fixed the TensorBoardLogger SummaryWriter not being closed before spawning the processes (#10777)

  • Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer (#10746)

  • Fixed the default logging level for batch hooks associated with training from on_step=False, on_epoch=True to on_step=True, on_epoch=False (#10756)

[1.5.4] - Removed

[1.5.3] - 2021-11-24

[1.5.3] - Fixed
  • Fixed ShardedTensor state dict hook registration to check if torch distributed is available (#10621)

  • Fixed an issue with self.log not respecting a tensor’s dtype when applying computations (#10076)

  • Fixed LightningLite _wrap_init popping nonexistent keys from DataLoader signature parameters (#10613)

  • Fixed signals being registered within threads (#10610)

  • Fixed an issue that caused Lightning to extract the batch size even though it was set by the user in LightningModule.log (#10408)

  • Fixed Trainer(move_metrics_to_cpu=True) not moving the evaluation logged results to CPU (#10631)

  • Fixed the {validation,test}_step outputs getting moved to CPU with Trainer(move_metrics_to_cpu=True) (#10631)

  • Fixed an issue with collecting logged test results with multiple dataloaders (#10522)

[1.5.2] - 2021-11-16

[1.5.2] - Fixed
  • Fixed CombinedLoader with max_size_cycle not receiving a DistributedSampler (#10374)

  • Fixed an issue where class or init-only variables of dataclasses were passed to the dataclass constructor in utilities.apply_to_collection (#9702)

  • Fixed isinstance not working with init_meta_context and the materialized model not being moved to the device (#10493)

  • Fixed an issue that prevented the Trainer from shutting down workers when execution is interrupted due to failure (#10463)

  • Squeeze the early stopping monitor to remove empty tensor dimensions (#10461)

  • Fixed sampler replacement logic with overfit_batches to only replace the sampler when SequentialSampler is not used (#10486)

  • Fixed scripting causing false positive deprecation warnings (#10470, #10555)

  • Do not fail if batch size could not be inferred for logging when using DeepSpeed (#10438)

  • Fixed propagation of device and dtype information to submodules of LightningLite when they inherit from DeviceDtypeModuleMixin (#10559)

[1.5.1] - 2021-11-09

[1.5.1] - Fixed
  • Fixed apply_to_collection(defaultdict) (#10316)

  • Fixed failure when DataLoader(batch_size=None) is passed (#10345)

  • Fixed interception of __init__ arguments for sub-classed DataLoader re-instantiation in Lite (#10334)

  • Fixed issue with pickling CSVLogger after a call to CSVLogger.save (#10388)

  • Fixed an import error being caused by PostLocalSGD when torch.distributed is not available (#10359)

  • Fixed the logging with on_step=True in epoch-level hooks causing unintended side-effects. Logging with on_step=True in epoch-level hooks will now correctly raise an error (#10409)

  • Fixed deadlocks for distributed training with RichProgressBar (#10428)

  • Fixed an issue where the model wrapper in Lite converted non-floating point tensors to float (#10429)

  • Fixed an issue with inferring the dataset type in fault-tolerant training (#10432)

  • Fixed dataloader workers with persistent_workers being deleted on every iteration (#10434)

[1.5.0] - 2021-11-02

[1.5.0] - Added
  • Added support for monitoring the learning rate without schedulers in LearningRateMonitor (#9786)

  • Added registration of ShardedTensor state dict hooks in LightningModule.__init__ if the PyTorch version supports ShardedTensor (#8944)

  • Added error handling including calling of on_keyboard_interrupt() and on_exception() for all entrypoints (fit, validate, test, predict) (#8819)

  • Added a flavor of training_step that takes dataloader_iter as an argument (#8807)

  • Added a state_key property to the Callback base class (#6886)

  • Added progress tracking to loops:

    • Integrated TrainingEpochLoop.total_batch_idx (#8598)

    • Added BatchProgress and integrated TrainingEpochLoop.is_last_batch (#9657)

    • Avoid optional Tracker attributes (#9320)

    • Reset current progress counters when restarting an epoch loop that had already finished (#9371)

    • Call reset_on_restart in the loop’s reset hook instead of when loading a checkpoint (#9561)

    • Use completed over processed in reset_on_restart (#9656)

    • Renamed reset_on_epoch to reset_on_run (#9658)

  • Added batch_size and rank_zero_only arguments for log_dict to match log (#8628)

  • Added a check for unique GPU ids (#8666)

  • Added ResultCollection state_dict to the Loop state_dict and added support for distributed reload (#8641)

  • Added DeepSpeed collate checkpoint utility function (#8701)

  • Added a handles_accumulate_grad_batches property to the training type plugins (#8856)

  • Added a warning to WandbLogger when reusing a wandb run (#8714)

  • Added log_graph argument for watch method of WandbLogger (#8662)

  • LightningCLI additions:

    • Added LightningCLI(run=False|True) to choose whether to run a Trainer subcommand (#8751)

    • Added support to call any trainer function from the LightningCLI via subcommands (#7508)

    • Allow easy trainer re-instantiation (#7508)

    • Automatically register all optimizers and learning rate schedulers (#9565)

    • Allow registering custom optimizers and learning rate schedulers without subclassing the CLI (#9565)

    • Support shorthand notation to instantiate optimizers and learning rate schedulers (#9565)

    • Support passing lists of callbacks via command line (#8815)

    • Support shorthand notation to instantiate models (#9588)

    • Support shorthand notation to instantiate datamodules (#10011)

    • Added multifile option to LightningCLI to enable/disable config saving to preserve a multiple-file structure (#9073)

  • Fault-tolerant training:

    • Added FastForwardSampler and CaptureIterableDataset injection to data loading utilities (#8366)

    • Added DataFetcher to control fetching flow (#8890)

    • Added SharedCycleIteratorState to prevent an infinite loop (#8889)

    • Added CaptureMapDataset for state management in map-style datasets (#8891)

    • Added Fault Tolerant Training to DataFetcher (#8891)

    • Replaced old prefetch iterator with new DataFetcher in training loop (#8953)

    • Added partial support for global random state fault-tolerance in map-style datasets (#8950)

    • Converted state to tuple explicitly when setting Python random state (#9401)

    • Added support for restarting an optimizer loop (multiple optimizers) (#9537)

    • Added support for restarting within Evaluation Loop (#9563)

    • Added mechanism to detect that a signal has been sent so the Trainer can gracefully exit (#9566)

    • Added support for skipping ahead to validation during the auto-restart of fitting (#9681)

    • Added support for auto-restart if a fault-tolerant checkpoint is available (#9722)

  • Checkpoint saving and loading extensibility:

    • Added CheckpointIO plugin to expose checkpoint IO from training type plugin (#8743)

    • Refactored CheckpointConnector to offload validation logic to the CheckpointIO plugin (#9045)

    • Added remove_checkpoint to CheckpointIO plugin by moving the responsibility out of the ModelCheckpoint callback (#9373)

    • Added XLACheckpointIO plugin (#9972)

  • Loop customization:

    • Added Closure and AbstractClosure classes (#8642)

    • Refactored TrainingBatchLoop and extracted OptimizerLoop, splitting off automatic optimization into its own loop (#9191)

    • Removed TrainingBatchLoop.backward(); manual optimization now calls directly into Accelerator.backward() and automatic optimization handles backward in new OptimizerLoop (#9265)

    • Extracted ManualOptimization logic from TrainingBatchLoop into its own separate loop class (#9266)

    • Added OutputResult and ManualResult classes (#9437, #9424)

    • Marked OptimizerLoop.backward as protected (#9514)

    • Marked FitLoop.should_accumulate as protected (#9515)

    • Marked several methods in PredictionLoop as protected: on_predict_start, on_predict_epoch_end, on_predict_end, on_predict_model_eval (#9516)

    • Marked several methods in EvaluationLoop as protected: get_max_batches, on_evaluation_model_eval, on_evaluation_model_train, on_evaluation_start, on_evaluation_epoch_start, on_evaluation_epoch_end, on_evaluation_end, reload_evaluation_dataloaders (#9516)

    • Marked several methods in EvaluationEpochLoop as protected: on_evaluation_batch_start, evaluation_step, evaluation_step_end (#9516)

    • Added yielding_training_step example (#9983)

  • Added support for saving and loading state of multiple callbacks of the same type (#7187)

  • Added DeepSpeed Stage 1 support (#8974)

  • Added Python dataclass support for LightningDataModule (#8272)

  • Added sanitization of tensors when they get logged as hyperparameters in TensorBoardLogger (#9031)

  • Added InterBatchParallelDataFetcher (#9020)

  • Added DataLoaderIterDataFetcher (#9020)

  • Added DataFetcher within Fit / Evaluation Loop (#9047)

  • Added a friendly error message when DDP attempts to spawn new distributed processes with rank > 0 (#9005)

  • Added Rich integration:

    • Added Rich progress bar (#8929, #9559)

    • Added support for iterable datasets (#9734)

    • Added RichModelSummary callback (#9546)

    • Added configure_columns method to RichProgressBar (#10288)

    • Added leave argument to RichProgressBar (#10301)

  • Added input validation logic for precision (#9080)

  • Added support for CPU AMP autocast (#9084)

  • Added on_exception callback hook (#9183)

  • Added a warning to DeepSpeed when inferring batch size (#9221)

  • Added ModelSummary callback (#9344)

  • Added log_images, log_text and log_table to WandbLogger (#9545)

  • Added PL_RECONCILE_PROCESS environment variable to enable process reconciliation regardless of cluster environment settings (#9389)

  • Added get_device_stats to the Accelerator interface and added its implementation for GPU and TPU (#9586)

  • Added a warning when an unknown key is encountered in the optimizer configuration, and when OneCycleLR is used with "interval": "epoch" (#9666)

  • Added DeviceStatsMonitor callback (#9712)

  • Added enable_progress_bar to the Trainer constructor (#9664)

  • Added pl_legacy_patch load utility for loading old checkpoints that have pickled legacy Lightning attributes (#9166)

  • Added support for torch.use_deterministic_algorithms (#9121)

  • Added automatic parameter tying for TPUs (#9525)

  • Added support for torch.autograd.set_detect_anomaly through Trainer constructor argument detect_anomaly (#9848)

  • Added enable_model_summary flag to Trainer (#9699)

  • Added strategy argument to Trainer (#8597)

  • Added init_meta_context, materialize_module utilities (#9920)

  • Added TPUPrecisionPlugin (#10020)

  • Added torch.bfloat16 support:

    • Added bfloat16 support for Lightning Trainer (#9049)

    • Renamed TPUHalfPrecisionPlugin to TPUBf16PrecisionPlugin (#10026)

    • Default to precision=bf16 on CPU when precision=16 is passed (#10033)

    • Added support for torch.autocast (#10053)

  • Added kfold example for loop customization (#9965)

  • LightningLite:

    • Added PrecisionPlugin.forward_context, making it the default implementation for all {train,val,test,predict}_step_context() methods (#9988)

    • Added DDPSpawnPlugin.spawn() for spawning new processes of a given function (#10018, #10022)

    • Added TrainingTypePlugin.{_setup_model, _setup_optimizer} methods (#9994, #10064)

    • Implemented DataParallelPlugin._setup_model (#10010)

    • Implemented DeepSpeedPlugin._setup_model_and_optimizers (#10009, #10064)

    • Implemented {DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_model_and_optimizers (#10028, #10064)

    • Added optional model argument to the optimizer_step methods in accelerators and plugins (#10023)

    • Updated precision attributes in DeepSpeedPlugin (#10164)

    • Added the ability to return a result from rank 0 in DDPSpawnPlugin.spawn (#10162)

    • Added pytorch_lightning.lite package (#10175)

    • Added LightningLite documentation (#10043)

    • Added LightningLite examples (#9987)

    • Make the _LiteDataLoader an iterator and add support for custom dataloaders (#10279)

  • Added use_omegaconf argument to save_hparams_to_yaml plugin (#9170)

  • Added ckpt_path argument for Trainer.fit() (#10061)

  • Added auto_device_count method to Accelerators (#10222)

  • Added support for devices="auto" (#10264)

  • Added a filename argument in ModelCheckpoint.format_checkpoint_name (#9818)

  • Added support for empty gpus list to run on CPU (#10246)

  • Added a warning if multiple batch sizes are found from an ambiguous batch (#10247)
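
The loop progress-tracking entries above (#9371, #9561, #9656) rely on counters that are rewound when a loop restarts. The following is a minimal sketch, assuming a simplified tracker with ready, started, processed and completed counters loosely modeled on Lightning's progress dataclasses; ProgressSketch and its methods are hypothetical names, not the library API. It illustrates why rewinding to completed matters: a batch that was processed but interrupted before completing is run again after the restart.

```python
from dataclasses import dataclass


@dataclass
class ProgressSketch:
    """Hypothetical progress tracker; not Lightning's actual class."""

    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0

    def run_batch(self, finish: bool = True) -> None:
        # Advance the counters in the order a batch moves through the loop.
        self.ready += 1
        self.started += 1
        self.processed += 1
        if finish:
            self.completed += 1

    def reset_on_restart(self) -> None:
        # Rewind every intermediate counter to the last fully completed batch,
        # so work interrupted mid-way is repeated rather than silently skipped.
        self.ready = self.completed
        self.started = self.completed
        self.processed = self.completed


progress = ProgressSketch()
progress.run_batch()
progress.run_batch()
# A third batch is processed, but the run is interrupted before it completes.
progress.run_batch(finish=False)

progress.reset_on_restart()
print(progress)  # ProgressSketch(ready=2, started=2, processed=2, completed=2)
```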

[1.5.0] - Changed
  • Trainer now raises a MisconfigurationException when its methods are called with ckpt_path="best" but a checkpoint callback isn’t configured (#9841)

  • Setting Trainer(accelerator="ddp_cpu") no longer spawns a subprocess if num_processes is kept at 1 along with num_nodes > 1 (#9603)

  • Module imports are now catching ModuleNotFoundError instead of ImportError (#9867)

  • pytorch_lightning.loggers.neptune.NeptuneLogger is now consistent with the new neptune-client API; the old neptune-client API is supported by NeptuneClient from the neptune-contrib repo (#6867)

  • Parsing of enum-type hyperparameters to be saved in the hparams.yaml file by the TensorBoard and CSV loggers has been fixed and made consistent with how OmegaConf parses them (#9170)

  • Parsing of the gpus Trainer argument has changed: gpus="n" (str) no longer selects the GPU index n and instead selects the first n devices (#8770)

  • iteration_count and other index attributes in the loops have been replaced with progress dataclasses (#8477)

  • The trainer.lightning_module reference is now properly set at the very beginning of a run (#8536)

  • The model weights now get loaded in all cases when the checkpoint path gets provided in validate/test/predict, regardless of whether the model instance is provided or not (#8352)

  • The model argument of the Trainer functions reset_{train,val,test,predict}_dataloader, reset_train_val_dataloaders, and request_dataloader is now optional (#8536)

  • Saved checkpoints will no longer use the type of a Callback as the key to avoid issues with unpickling (#6886)

  • Improved string conversion for ResultCollection (#8622)

  • LightningCLI changes:

    • LightningCLI.init_parser now returns the parser instance (#8721)

    • LightningCLI.add_core_arguments_to_parser, LightningCLI.parse_arguments now take a parser argument (#8721)

    • LightningCLI.instantiate_trainer now takes a config and a list of callbacks (#8721)

    • Split LightningCLI.add_core_arguments_to_parser into LightningCLI.add_default_arguments_to_parser + LightningCLI.add_core_arguments_to_parser (#8721)

  • The accelerator and training type plugin setup hooks no longer have a model argument (#8536)

  • The accelerator and training type plugin update_global_step hook has been removed (#8856)

  • The coverage of self.log-ing in any LightningModule or Callback hook has been improved (#8498)

  • self.log-ing without a Trainer reference now raises a warning instead of an exception (#9733)

  • Removed restrictions in the Trainer that loggers can only log from rank 0; the existing logger behavior has not changed (#8608)

  • Trainer.request_dataloader now takes a RunningStage enum instance (#8858)

  • Changed rank_zero_warn to NotImplementedError in the {train, val, test, predict}_dataloader hooks that Lightning(Data)Module uses (#9161)

  • Moved block_ddp_sync_behaviour out of TrainingBatchLoop to loop utilities (#9192)

  • Executing the optimizer_closure is now required when overriding the optimizer_step hook (#9360)

  • Changed logging of LightningModule and LightningDataModule hyperparameters to raise an exception only if there are colliding keys with different values (#9496)

  • seed_everything now fails when an invalid seed value is passed instead of selecting a random seed (#8787)

  • The Trainer now calls TrainingTypePlugin collective APIs directly instead of going through the Accelerator reference (#9677, #9901)

  • The tuner now uses a unique filename to save a temporary checkpoint (#9682)

  • Changed HorovodPlugin.all_gather to return a torch.Tensor instead of a list (#9696)

  • Changed Trainer connectors to be protected attributes:

    • Configuration Validator (#9779)

  • The current_epoch and global_step attributes now get restored irrespective of the Trainer task (#9413)

  • Trainer now raises an exception when requesting amp_level with native amp_backend (#9755)

  • Update the logic to check for accumulation steps with deepspeed (#9826)

  • pytorch_lightning.utilities.grads.grad_norm now raises an exception if parameter norm_type <= 0 (#9765)

  • Updated error message for interactive incompatible plugins (#9896)

  • Moved the optimizer_step and clip_gradients hooks from the Accelerator and TrainingTypePlugin into the PrecisionPlugin (#10143, #10029)

  • NativeMixedPrecisionPlugin and its subclasses now take an optional GradScaler instance (#10055)

  • Trainer now raises a MisconfigurationException instead of a warning if Trainer.{validate/test} is missing required methods (#10016)

  • Changed default value of the max_steps Trainer argument from None to -1 (#9460)

  • LightningModule now raises an error when calling log(on_step=False, on_epoch=False) (#10227)

  • Quantization aware training observers are now disabled by default during validating/testing/predicting stages (#8540)

  • Raised a MisconfigurationException when the total length of the dataloader across ranks is zero, and a warning when the total length is non-zero but only the local rank length is zero (#9827)

  • Changed the model size calculation to use ByteCounter (#10123)

  • Enabled on_load_checkpoint for LightningDataModule for all trainer_fn (#10238)

  • Allowed separate config files for parameters with class type when LightningCLI is in subclass_mode=False (#10286)
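
To make the gpus string-parsing change (#8770) concrete, here is a minimal sketch of how a gpus value maps to device indices after 1.5, assuming num_available GPUs are visible. parse_gpus_sketch is a hypothetical helper that covers only int, list and str inputs. It is not Lightning's own device parser and skips that parser's validation against the devices actually present.

```python
from typing import List, Optional, Union


def parse_gpus_sketch(
    gpus: Union[None, int, str, List[int]], num_available: int
) -> Optional[List[int]]:
    """Hypothetical sketch: return the GPU indices to use, or None for CPU."""
    if gpus is None:
        return None
    if isinstance(gpus, str):
        gpus = gpus.strip()
        if "," in gpus:
            # A comma-separated string lists explicit device indices, e.g. "1, 3".
            gpus = [int(part) for part in gpus.split(",") if part.strip()]
        else:
            # A bare number string behaves like the int: "3" means the first 3 GPUs.
            gpus = int(gpus)
    if isinstance(gpus, int):
        if gpus == -1:
            return list(range(num_available))  # -1 selects every available GPU
        if gpus == 0:
            return None  # 0 means run on CPU
        return list(range(gpus))  # n selects the first n GPUs
    # A list names the device indices directly; an empty list falls back to CPU.
    return list(gpus) or None


print(parse_gpus_sketch("3", num_available=4))     # [0, 1, 2]
print(parse_gpus_sketch("1, 3", num_available=4))  # [1, 3]
print(parse_gpus_sketch([], num_available=4))      # None
```

Before 1.5, the string "3" selected only GPU index 3; under the new rule, as the first call shows, it selects the first three devices.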

[1.5.0] - Deprecated
  • Deprecated Trainer argument terminate_on_nan in favor of detect_anomaly (#9175)

  • Deprecated Trainer.terminate_on_nan public attribute access (#9849)

  • Deprecated LightningModule.summarize() in favor of pytorch_lightning.utilities.model_summary.summarize() (#8513)

  • Deprecated LightningModule.model_size (#8343)

  • Deprecated DataModule properties: train_transforms, val_transforms, test_transforms, size, dims (#8851)

  • Deprecated add_to_queue, get_from_queue from LightningModule in favor of corresponding methods in the DDPSpawnPlugin (#9118)

  • Deprecated LightningModule.get_progress_bar_dict and Trainer.progress_bar_dict in favor of pytorch_lightning.callbacks.progress.base.get_standard_metrics and ProgressBarBase.get_metrics (#8985)

  • Deprecated prepare_data_per_node flag on Trainer and set it as a property of DataHooks, accessible in the LightningModule and LightningDataModule (#8958)

  • Deprecated the TestTubeLogger (#9065)

  • Deprecated on_{train/val/test/predict}_dataloader() from LightningModule and LightningDataModule (#9098)

  • Deprecated on_keyboard_interrupt callback hook in favor of new on_exception hook (#9260)

  • Deprecated passing process_position to the Trainer constructor in favor of adding the ProgressBar callback with process_position directly to the list of callbacks (#9222)

  • Deprecated passing flush_logs_every_n_steps as a Trainer argument, instead pass it to the logger init if supported (#9366)

  • Deprecated LightningLoggerBase.close, LoggerCollection.close in favor of LightningLoggerBase.finalize, LoggerCollection.finalize (#9422)

  • Deprecated passing progress_bar_refresh_rate to the Trainer constructor in favor of adding the ProgressBar callback with refresh_rate directly to the list of callbacks, or passing enable_progress_bar=False to disable the progress bar (#9616)

  • Deprecated LightningDistributed and moved the broadcast logic to DDPPlugin and DDPSpawnPlugin directly (#9691)

  • Deprecated passing stochastic_weight_avg to the Trainer constructor in favor of adding the StochasticWeightAveraging callback directly to the list of callbacks (#8989)

  • Deprecated Accelerator collective API barrier, broadcast, and all_gather in favor of calling the TrainingTypePlugin collective API directly (#9677)

  • Deprecated checkpoint_callback from the Trainer constructor in favor of enable_checkpointing (#9754)

  • Deprecated the LightningModule.on_post_move_to_device method (#9525)

  • Deprecated pytorch_lightning.core.decorators.parameter_validation in favor of pytorch_lightning.utilities.parameter_tying.set_shared_parameters (#9525)

  • Deprecated passing weights_summary to the Trainer constructor in favor of adding the ModelSummary callback with max_depth directly to the list of callbacks (#9699)

  • Deprecated log_gpu_memory, gpu_metrics, and util funcs in favor of DeviceStatsMonitor callback (#9921)

  • Deprecated GPUStatsMonitor and XLAStatsMonitor in favor of DeviceStatsMonitor callback (#9924)

  • Deprecated setting Trainer(max_steps=None); to turn off the limit, set Trainer(max_steps=-1) (default) (#9460)

  • Deprecated access to the AcceleratorConnector.is_slurm_managing_tasks attribute and marked it as protected (#10101)

  • Deprecated access to the AcceleratorConnector.configure_slurm_ddp method and marked it as protected (#10101)

  • Deprecated passing resume_from_checkpoint to the Trainer constructor in favor of trainer.fit(ckpt_path=) (#10061)

  • Deprecated ClusterEnvironment.creates_children() in favor of ClusterEnvironment.creates_processes_externally (property) (#10106)

  • Deprecated PrecisionPlugin.master_params() in favor of PrecisionPlugin.main_params() (#10105)

  • Deprecated lr_sch_names from LearningRateMonitor (#10066)

  • Deprecated ProgressBar callback in favor of TQDMProgressBar (#10134)
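
With the max_steps deprecation (#9460), -1 rather than None now signals "no step limit". The following minimal sketch shows how calling code might normalize the old value. It assumes two hypothetical helpers, resolve_max_steps and reached_max_steps; neither is part of Lightning's API.

```python
import warnings
from typing import Optional


def resolve_max_steps(max_steps: Optional[int]) -> int:
    """Hypothetical helper: map the deprecated None to the new -1 default."""
    if max_steps is None:
        warnings.warn(
            "max_steps=None is deprecated; use max_steps=-1 to disable the step limit.",
            DeprecationWarning,
        )
        return -1
    return max_steps


def reached_max_steps(global_step: int, max_steps: int) -> bool:
    # -1 disables the limit, so training is never stopped by the step count.
    return max_steps != -1 and global_step >= max_steps
```

Using a sentinel integer keeps the argument a plain int, so comparisons such as global_step >= max_steps never have to special-case None.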

[1.5.0] - Removed
  • Removed deprecated metrics (#8586)

  • Removed the deprecated outputs argument in both the LightningModule.on_train_epoch_end and Callback.on_train_epoch_end hooks (#8587)

  • Removed the deprecated TrainerLoggingMixin class (#8609)

  • Removed the deprecated TrainerTrainingTricksMixin class (#8679)

  • Removed the deprecated optimizer_idx from training_step as an accepted argument in manual optimization (#8576)

  • Removed support for the deprecated on_save_checkpoint signature. The hook now takes a checkpoint positional parameter (#8697)

  • Removed support for the deprecated on_load_checkpoint signature. The hook now takes a pl_module positional parameter (#8697)

  • Removed the deprecated save_function property in ModelCheckpoint (#8680)

  • Removed the deprecated model argument from ModelCheckpoint.save_checkpoint (#8688)

  • Removed the deprecated sync_step argument from WandbLogger (#8763)

  • Removed the deprecated Trainer.truncated_bptt_steps in favor of LightningModule.truncated_bptt_steps (#8826)

  • Removed LightningModule.write_predictions and LightningModule.write_predictions_dict (#8850)

  • Removed on_reset_*_dataloader hooks in TrainingType Plugins and Accelerators (#8858)

  • Removed deprecated GradInformation module in favor of pytorch_lightning.utilities.grads (#8831)

  • Removed TrainingTypePlugin.on_save and Accelerator.on_save (#9023)

  • Removed {Accelerator,TrainingTypePlugin,PrecisionPlugin}.post_optimizer_step (#9746)

  • Removed deprecated connect_precision_plugin and connect_training_type_plugin from Accelerator (#9019)

  • Removed on_train_epoch_end from Accelerator (#9035)

  • Removed InterBatchProcessor in favor of DataLoaderIterDataFetcher (#9052)

  • Removed Plugin in base_plugin.py in favor of accessing TrainingTypePlugin and PrecisionPlugin directly instead (#9066)

  • Removed teardown from ParallelPlugin (#8943)

  • Removed deprecated profiled_functions argument from PyTorchProfiler (#9178)

  • Removed deprecated pytorch_lightning.utilities.argparse_utils module (#9166)

  • Removed deprecated property Trainer.running_sanity_check in favor of Trainer.sanity_checking (#9209)

  • Removed deprecated BaseProfiler.output_filename arg from it and its descendants in favor of dirpath and filename (#9214)

  • Removed deprecated property ModelCheckpoint.period in favor of ModelCheckpoint.every_n_epochs (#9213)

  • Removed deprecated auto_move_data decorator (#9231)

  • Removed deprecated property LightningModule.datamodule in favor of Trainer.datamodule (#9233)

  • Removed deprecated properties DeepSpeedPlugin.cpu_offload* in favor of offload_optimizer, offload_parameters and pin_memory (#9244)

  • Removed deprecated property AcceleratorConnector.is_using_torchelastic in favor of TorchElasticEnvironment.is_using_torchelastic() (#9729)

  • Removed pytorch_lightning.utilities.debugging.InternalDebugger (#9680)

  • Removed call_configure_sharded_model_hook property from Accelerator and TrainingTypePlugin (#9612)

  • Removed TrainerProperties mixin and moved property definitions directly into Trainer (#9495)

  • Removed a redundant warning with ModelCheckpoint(monitor=None) callback (#9875)

  • Remove epoch from trainer.logged_metrics (#9904)

  • Remove deprecated distributed_backend from Trainer (#10017)

  • Removed process_idx from the {DDPSpawnPlugin,TPUSpawnPlugin}.new_process methods (#10022)

  • Removed automatic patching of {train,val,test,predict}_dataloader() on the LightningModule (#9764)

  • Removed pytorch_lightning.trainer.connectors.OptimizerConnector (#10120)

[1.5.0] - Fixed
  • Fixed ImageNet evaluation in example (#10179)

  • Fixed an issue with logger outputs not being finalized correctly after prediction runs (#8685)

  • Fixed move_metrics_to_cpu moving the loss to CPU while training on device (#9308)

  • Fixed incorrect main progress bar indicator when resuming training mid-epoch (#9310)

  • Fixed an issue with freeing memory of datafetchers during teardown (#9387)

  • Fixed a bug where the training step output needed to be deepcopy-ed (#9349)

  • Fixed an issue with freeing memory allocated by the data iterators in Loop.on_run_end (#9386, #9915)

  • Fixed BasePredictionWriter not returning the batch indices in a non-distributed setting (#9432)

  • Fixed an error when running in XLA environments with no TPU attached (#9572)

  • Fixed the check on logged torchmetrics metrics whose compute() output is a multi-element tensor (#9582)

  • Fixed gradient accumulation for DDPShardedPlugin (#9122)

  • Fixed missing DeepSpeed distributed call (#9540)

  • Fixed an issue with wrapped LightningModule during evaluation; the LightningModule no longer gets wrapped with data-parallel modules when not fitting in DDPPlugin, DDPSpawnPlugin, DDPShardedPlugin, DDPSpawnShardedPlugin (#9096)

  • Fixed trainer.accumulate_grad_batches to be an int on init. The default value for it is now None inside Trainer (#9652)

  • Fixed broadcast in DDPPlugin and DDPSpawnPlugin to respect the src input (#9691)

  • Fixed self.log(on_epoch=True, reduce_fx=sum) for the on_batch_start and on_train_batch_start hooks (#9791)

  • Fixed self.log(on_epoch=True) for the on_batch_start and on_train_batch_start hooks (#9780)

  • Fixed restoring training state during Trainer.fit only (#9413)

  • Fixed DeepSpeed and Lightning both calling the scheduler (#9788)

  • Fixed missing arguments when saving hyperparameters from the parent class but not from the child class (#9800)

  • Fixed DeepSpeed GPU device IDs (#9847)

  • Reset val_dataloader in tuner/batch_size_scaling (#9857)

  • Fixed use of LightningCLI in computer_vision_fine_tuning.py example (#9934)

  • Fixed issue with non-init dataclass fields in apply_to_collection (#9963)

  • Reset val_dataloader in tuner/batch_size_scaling for binsearch (#9975)

  • Fixed logic to check for spawn in dataloader TrainerDataLoadingMixin._worker_check (#9902)

  • Fixed train_dataloader getting loaded twice when resuming from a checkpoint during Trainer.fit() (#9671)

  • Fixed LearningRateMonitor logging with multiple param groups optimizer with no scheduler (#10044)

  • Fixed undesired side effects being caused by Trainer patching dataloader methods on the LightningModule (#9764)

  • Fixed gradients not being unscaled when clipping or logging the gradient norm (#9287)

  • Fixed on_before_optimizer_step getting called before the optimizer closure (including backward) has run (#10167)

  • Fixed monitor value in ModelCheckpoint getting moved to the wrong device in a special case where it becomes NaN (#10118)

  • Fixed creation of dirpath in BaseProfiler if it doesn’t exist (#10073)

  • Fixed incorrect handling of sigterm (#10189)

  • Fixed bug where log(on_step=True, on_epoch=True, sync_dist=True) wouldn’t reduce the value on step (#10227)

  • Fixed an issue with pl.utilities.seed.reset_seed converting the PL_SEED_WORKERS environment variable to bool (#10099)

  • Fixed iterating over a logger collection when fast_dev_run > 0 (#10232)

  • Fixed batch_size in ResultCollection not being reset to 1 on epoch end (#10242)

  • Fixed distrib_type not being set when training plugin instances are being passed to the Trainer (#10251)
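
The gradient-unscaling fix (#9287) matters because mixed-precision training multiplies the loss, and therefore every gradient, by a loss-scale factor. The following minimal pure-Python sketch uses a two-element gradient and a hypothetical scale of 1024. It shows that clipping the still-scaled gradient and only then unscaling it shrinks the gradient far more than intended.

```python
import math
from typing import List


def l2_norm(grads: List[float]) -> float:
    return math.sqrt(sum(g * g for g in grads))


def clip_by_norm(grads: List[float], max_norm: float) -> List[float]:
    # Rescale the whole gradient vector so that its L2 norm is at most max_norm.
    norm = l2_norm(grads)
    factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * factor for g in grads]


true_grads = [3.0, 4.0]  # L2 norm 5.0
scale = 1024.0  # hypothetical loss-scale factor
scaled_grads = [g * scale for g in true_grads]

# Buggy order: clip the scaled gradients, then unscale them.
wrong = [g / scale for g in clip_by_norm(scaled_grads, max_norm=1.0)]
# Correct order: unscale first, then clip.
right = clip_by_norm([g / scale for g in scaled_grads], max_norm=1.0)

print(l2_norm(wrong))  # roughly 1/1024 instead of 1.0
print(l2_norm(right))  # roughly 1.0
```

In plain PyTorch AMP, the correct ordering corresponds to calling GradScaler.unscale_(optimizer) before torch.nn.utils.clip_grad_norm_.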

[1.4.9] - 2021-09-30

  • Fixed lr_find to generate same results on multiple calls (#9704)

  • Fixed reset metrics on validation epoch end (#9717)

  • Fixed input validation for gradient_clip_val, gradient_clip_algorithm, track_grad_norm and terminate_on_nan Trainer arguments (#9595)

  • Reset metrics before each task starts (#9410)

[1.4.8] - 2021-09-22

  • Fixed error reporting in DDP process reconciliation when processes are launched by an external agent (#9389)

  • Added PL_RECONCILE_PROCESS environment variable to enable process reconciliation regardless of cluster environment settings (#9389)

  • Fixed add_argparse_args raising TypeError when args are typed as typing.Generic in Python 3.6 (#9554)

  • Fixed back-compatibility for saving hyperparameters from a single container and inferring its argument name by reverting #9125 (#9642)

[1.4.7] - 2021-09-14

  • Fixed logging of nan parameters (#9364)

  • Fixed replace_sampler missing the batch size under specific conditions (#9367)

  • Pass init args to ShardedDataParallel (#9483)

  • Fixed collision of user argument when using ShardedDDP (#9512)

  • Fixed DeepSpeed crash for RNNs (#9489)

[1.4.6] - 2021-09-07

  • Fixed an issue with export to ONNX format when a model has multiple inputs (#8800)

  • Removed deprecation warnings being called for on_{task}_dataloader (#9279)

  • Fixed save/load/resume from checkpoint for DeepSpeed Plugin (#8397, #8644, #8627)

  • Fixed EarlyStopping running on train epoch end when check_val_every_n_epoch>1 is set (#9156)

  • Fixed an issue with logger outputs not being finalized correctly after prediction runs (#8333)

  • Fixed the Apex and DeepSpeed plugin closure running after the on_before_optimizer_step hook (#9288)

  • Fixed the Native AMP plugin closure not running with manual optimization (#9288)

  • Fixed a bug where data-loading functions were not getting the correct running stage passed (#8858)

  • Fixed intra-epoch evaluation outputs staying in memory when the respective *_epoch_end hook wasn’t overridden (#9261)

  • Fixed error handling in DDP process reconciliation when _sync_dir was not initialized (#9267)

  • Fixed PyTorch Profiler not enabled for manual optimization (#9316)

  • Fixed inspection of other args when a container is specified in save_hyperparameters (#9125)

  • Fixed signature of Timer.on_train_epoch_end and StochasticWeightAveraging.on_train_epoch_end to prevent unwanted deprecation warnings (#9347)

[1.4.5] - 2021-08-31

  • Fixed reduction using self.log(sync_dist=True, reduce_fx={mean,max}) (#9142)

  • Fixed not setting a default value for max_epochs if max_time was specified on the Trainer constructor (#9072)

  • Fixed the CometLogger so that it no longer modifies the metrics in place; it now creates a copy of the metrics before performing any operations (#9150)

  • Fixed DDP “CUDA error: initialization error” due to a copy instead of deepcopy on ResultCollection (#9239)

[1.4.4] - 2021-08-24

  • Fixed a bug in the binary search mode of auto batch size scaling where an exception was raised if the first trainer run resulted in OOM (#8954)

  • Fixed a bug that caused logging with log_gpu_memory='min_max' to not work (#9013)

[1.4.3] - 2021-08-17

  • Fixed plateau scheduler stepping on incomplete epoch (#8861)

  • Fixed infinite loop with CycleIterator and multiple loaders (#8889)

  • Fixed StochasticWeightAveraging with a list of learning rates not applying them to each param group (#8747)

  • Restore original loaders if replaced by entrypoint (#8885)

  • Fixed lost reference to _Metadata object in ResultMetricCollection (#8932)

  • Ensure the existence of DDPPlugin._sync_dir in reconciliate_processes (#8939)

[1.4.2] - 2021-08-10

  • Fixed recursive call for apply_to_collection(include_none=False) (#8719)

  • Fixed truncated backprop through time enablement when set as a property on the LightningModule and not the Trainer (#8804)

  • Fixed comments and exception message for metrics_to_scalars (#8782)

  • Fixed a typo in the LightningLoggerBase.after_save_checkpoint docstring (#8737)

[1.4.1] - 2021-08-03

  • Fixed trainer.fit_loop.split_idx always returning None (#8601)

  • Fixed references for ResultCollection.extra (#8622)

  • Fixed reference issues during epoch end result collection (#8621)

  • Fixed horovod auto-detection when horovod is not installed and the launcher is mpirun (#8610)

  • Fixed an issue with training_step outputs not getting collected correctly for training_epoch_end (#8613)

  • Fixed distributed types support for CPUs (#8667)

  • Fixed a deadlock issue with DDP and torchelastic (#8655)

  • Fixed accelerator=ddp choice for CPU (#8645)

[1.4.0] - 2021-07-27

[1.4.0] - Added
  • Added extract_batch_size utility and corresponding tests to extract batch dimension from multiple batch types (#8357)

  • Added support for named parameter groups in LearningRateMonitor (#7987)

  • Added dataclass support for pytorch_lightning.utilities.apply_to_collection (#7935)

  • Added support to LightningModule.to_torchscript for saving to custom filesystems with fsspec (#7617)

  • Added KubeflowEnvironment for use with the PyTorchJob operator in Kubeflow

  • Added LightningCLI support for config files on object stores (#7521)

  • Added ModelPruning(prune_on_train_epoch_end=True|False) to choose when to apply pruning (#7704)

  • Added support for checkpointing based on a provided time interval during training (#7515)

  • Progress tracking

    • Added dataclasses for progress tracking (#6603, #7574, #8140, #8362)

    • Add {,load_}state_dict to the progress tracking dataclasses (#8140)

    • Connect the progress tracking dataclasses to the loops (#8244, #8362)

    • Do not reset the progress tracking dataclasses total counters (#8475)

  • Added support for passing a LightningDataModule positionally as the second argument to trainer.{validate,test,predict} (#7431)

  • Added argument trainer.predict(ckpt_path) (#7430)

  • Added clip_grad_by_value support for TPUs (#7025)

  • Added support for passing any class to is_overridden (#7918)

  • Added sub_dir parameter to TensorBoardLogger (#6195)

  • Added correct dataloader_idx to batch transfer hooks (#6241)

  • Added include_none=bool argument to apply_to_collection (#7769)

  • Added apply_to_collections to apply a function to two zipped collections (#7769)

  • Added ddp_fully_sharded support (#7487)

  • Added should_rank_save_checkpoint property to Training Plugins (#7684)

  • Added log_grad_norm hook to LightningModule to customize the logging of gradient norms (#7873)

  • Added save_config_filename init argument to LightningCLI to ease resolving name conflicts (#7741)

  • Added save_config_overwrite init argument to LightningCLI to ease overwriting existing config files (#8059)

  • Added reset dataloader hooks to Training Plugins and Accelerators (#7861)

  • Added trainer stage hooks for Training Plugins and Accelerators (#7864)

  • Added the on_before_optimizer_step hook (#8048)

  • Added IPU Accelerator (#7867)

  • Fault-tolerant training

    • Added {,load_}state_dict to ResultCollection (#7948)

    • Added {,load_}state_dict to Loops (#8197)

    • Added FastForwardSampler and CaptureIterableDataset (#8307)

    • Set Loop.restarting=False at the end of the first iteration (#8362)

    • Save the loops state with the checkpoint (opt-in) (#8362)

    • Save a checkpoint to restore the state on exception (opt-in) (#8362)

    • Added state_dict and load_state_dict utilities for CombinedLoader + utilities for dataloader (#8364)

  • Added rank_zero_only to LightningModule.log function (#7966)

  • Added metric_attribute to LightningModule.log function (#7966)

  • Added a warning if Trainer(log_every_n_steps) is set to a value too high for the training dataloader (#7734)

  • Added LightningCLI support for argument links applied on instantiation (#7895)

  • Added LightningCLI support for configurable callbacks that should always be present (#7964)

  • Added DeepSpeed Infinity support and updated to DeepSpeed 0.4.0 (#7234)

  • Added support for torch.nn.UninitializedParameter in ModelSummary (#7642)

  • Added support for LightningModule.save_hyperparameters when the LightningModule is a dataclass (#7992)

  • Added support for overriding optimizer_zero_grad and optimizer_step when using accumulate_grad_batches (#7980)

  • Added logger boolean flag to save_hyperparameters (#7960)

  • Added support for calling scripts using the module syntax (python -m package.script) (#8073)

  • Added support for optimizers and learning rate schedulers to LightningCLI (#8093)

  • Added XLA Profiler (#8014)

  • Added PrecisionPlugin.{pre,post}_backward (#8328)

  • Added on_load_checkpoint and on_save_checkpoint hooks to the PrecisionPlugin base class (#7831)

  • Added max_depth parameter in ModelSummary (#8062)

  • Added XLAStatsMonitor callback (#8235)

  • Added restore function and restarting attribute to base Loop (#8247)

  • Added support for save_hyperparameters in LightningDataModule (#3792)

  • Added ModelCheckpoint(save_on_train_epoch_end) to choose when to run the saving logic (#8389)

  • Added LSFEnvironment for distributed training with the LSF resource manager jsrun (#5102)

  • Added support for accelerator='cpu'|'gpu'|'tpu'|'ipu'|'auto' (#7808)

  • Added tpu_spawn_debug to plugin registry (#7933)

  • Enabled traditional/manual launching of DDP processes through LOCAL_RANK and NODE_RANK environment variable assignments (#7480)

  • Added quantize_on_fit_end argument to QuantizationAwareTraining (#8464)

  • Added experimental support for loop specialization (#8226)

  • Added support for devices flag to Trainer (#8440)

  • Added private prevent_trainer_and_dataloaders_deepcopy context manager on the LightningModule (#8472)

  • Added support for providing callables to the Lightning CLI instead of types (#8400)
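
Several entries above extend pytorch_lightning.utilities.apply_to_collection, for example with dataclass support (#7935). To show what "applying a function to a collection" means, here is a minimal, framework-free sketch. apply_to_collection_sketch is a hypothetical name; it is not Lightning's implementation, which handles more container types and extra options.

```python
import dataclasses
from typing import Any, Callable


def apply_to_collection_sketch(data: Any, dtype: type, function: Callable[[Any], Any]) -> Any:
    """Recursively apply `function` to every element of type `dtype` in `data`.

    Hypothetical, simplified helper: handles dicts, lists, tuples, namedtuples
    and dataclass instances; every other value is returned unchanged.
    """
    if isinstance(data, dtype):
        return function(data)
    if isinstance(data, dict):
        # Rebuild the mapping (dict or OrderedDict) with transformed values.
        return type(data)((k, apply_to_collection_sketch(v, dtype, function)) for k, v in data.items())
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        # namedtuple: pass the transformed fields positionally.
        return type(data)(*(apply_to_collection_sketch(v, dtype, function) for v in data))
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_collection_sketch(v, dtype, function) for v in data)
    if dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Dataclass instance: return a copy whose init fields hold the transformed values.
        changes = {
            f.name: apply_to_collection_sketch(getattr(data, f.name), dtype, function)
            for f in dataclasses.fields(data)
            if f.init
        }
        return dataclasses.replace(data, **changes)
    return data


@dataclasses.dataclass
class Batch:
    inputs: list
    label: int


batch = {"a": [1, 2.5, (3, "x")], "b": Batch(inputs=[4, 5], label=6)}
doubled = apply_to_collection_sketch(batch, int, lambda x: x * 2)
print(doubled)  # {'a': [2, 2.5, (6, 'x')], 'b': Batch(inputs=[8, 10], label=12)}
```

Only values matching dtype are transformed (the float and the string are left alone), and the input collection is not mutated.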

[1.4.0] - Changed
  • Decoupled device parsing logic from Accelerator connector to Trainer (#8180)

  • Changed the Trainer’s checkpoint_callback argument to allow only boolean values (#7539)

  • Log epoch metrics before the on_evaluation_end hook (#7272)

  • Explicitly disallow calling self.log(on_epoch=False) during epoch-only or single-call hooks (#7874)

  • Changed these Trainer methods to be protected: call_setup_hook, call_configure_sharded_model, pre_dispatch, dispatch, post_dispatch, call_teardown_hook, run_train, run_sanity_check, run_evaluate, run_evaluation, run_predict, track_output_for_epoch_end

  • Changed metrics_to_scalars to work with any collection or value (#7888)

  • Changed clip_grad_norm to use torch.nn.utils.clip_grad_norm_ (#7025)

  • Validation is now always run inside the training epoch scope (#7357)

  • ModelCheckpoint now runs at the end of the training epoch by default (#8389)

  • EarlyStopping now runs at the end of the training epoch by default (#8286)

  • Refactored Loops

    • Moved attributes global_step, current_epoch, max/min_steps, max/min_epochs, batch_idx, and total_batch_idx to TrainLoop (#7437)

    • Refactored result handling in training loop (#7506)

    • Moved attributes hiddens and split_idx to TrainLoop (#7507)

    • Refactored the logic around manual and automatic optimization inside the optimizer loop (#7526)

    • Simplified “should run validation” logic (#7682)

    • Simplified logic for updating the learning rate for schedulers (#7682)

    • Removed the on_epoch guard from the “should stop” validation check (#7701)

    • Refactored internal loop interface; added new classes FitLoop, TrainingEpochLoop, TrainingBatchLoop (#7871, #8077)

    • Removed pytorch_lightning/trainer/training_loop.py (#7985)

    • Refactored evaluation loop interface; added new classes DataLoaderLoop, EvaluationLoop, EvaluationEpochLoop (#7990, #8077)

    • Removed pytorch_lightning/trainer/evaluation_loop.py (#8056)

    • Restricted public access to several internal functions (#8024)

    • Refactored the trainer _run_* functions and separated the evaluation loops (#8065)

    • Refactored prediction loop interface; added new classes PredictionLoop, PredictionEpochLoop (#7700, #8077)

    • Removed pytorch_lightning/trainer/predict_loop.py (#8094)

    • Moved result teardown to the loops (#8245)

    • Improve Loop API to better handle children state_dict and progress (#8334)

  • Refactored logging

    • Renamed and moved core/step_result.py to trainer/connectors/logger_connector/result.py (#7736)

    • Dramatically simplify the LoggerConnector (#7882)

    • trainer.{logged,progress_bar,callback}_metrics are now updated on-demand (#7882)

    • Completely overhaul the Result object in favor of ResultMetric (#7882)

    • Improve epoch-level reduction time and overall memory usage (#7882)

    • Allow passing self.log(batch_size=...) (#7891)

    • Each of the training loops now keeps its own results collection (#7891)

    • Remove EpochResultStore and HookResultStore in favor of ResultCollection (#7909)

    • Remove MetricsHolder (#7909)

  • Moved ignore_scalar_return_in_dp warning suppression to the DataParallelPlugin class (#7421)

  • Changed the behaviour when logging evaluation step metrics to no longer append /epoch_* to the metric name (#7351)

  • Raised ValueError when a None value is self.log-ed (#7771)

  • Changed resolve_training_type_plugins to allow setting num_nodes and sync_batchnorm from Trainer setting (#7026)

  • Default seed_everything(workers=True) in the LightningCLI (#7504)

  • Changed model.state_dict() in CheckpointConnector to allow training_type_plugin to customize the model’s state_dict() (#7474)

  • MLflowLogger now uses the env variable MLFLOW_TRACKING_URI as default tracking URI (#7457)

  • Changed Trainer arg and functionality from reload_dataloaders_every_epoch to reload_dataloaders_every_n_epochs (#5043)

  • Changed WandbLogger(log_model={True/'all'}) to log models as artifacts (#6231)

  • MLFlowLogger now accepts run_name as a constructor argument (#7622)

  • Changed teardown() in Accelerator to allow training_type_plugin to customize teardown logic (#7579)

  • Trainer.fit now raises an error when using manual optimization with unsupported features such as gradient_clip_val or accumulate_grad_batches (#7788)

  • Accelerator hooks are called regardless of whether the LightningModule overrides the same hooks (#7826)

  • Moved profilers to their own file (#7822)

  • The on_after_backward hook is now called on accumulating iterations. Use the on_before_optimizer_step hook to mimic the old behaviour (#8328)

  • The mixed precision loss is no longer unscaled before the on_after_backward hook. Use the on_before_optimizer_step hook to mimic the old behaviour (#8328)

  • The TrainingTypePlugin.{pre,post}_backward hooks no longer take the optimizer, opt_idx, should_accumulate arguments (#8328)

  • The PrecisionPlugin.backward hook no longer returns a value (#8328)

  • The PrecisionPlugin.backward hook no longer takes a should_accumulate argument (#8328)

  • Added the on_before_backward hook (#7865)

  • LightningCLI now aborts with a clearer message if the config already exists, and disables saving the config during fast_dev_run (#7963)

  • Saved the LightningCLI config on setup and only on the main process (#8017)

  • Dropped the LightningCLI ArgumentParser when pickling (#8017)

  • Skip broadcast if distributed not initialized for the spawn plugins (#8017)

  • Trainer(resume_from_checkpoint=...) now restores the model directly after LightningModule.setup(), which is before LightningModule.configure_sharded_model() (#7652)

  • Moved torch.cuda.set_device() to enable collective calls earlier in setup (#8312)

  • Used XLA utility API to move data to CPU (Single TPU core) (#8078)

  • Improved error messages in replace_sampler when the DataLoader attributes are not included in the signature or the signature is missing optional arguments (#8519)

  • Moved the DeviceDtypeModuleMixin and HyperparametersMixin mixins to core (#8396)

  • Return the default_root_dir as the log_dir when the logger is a LoggerCollection (#8187)
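
Several changes above (#8328, #7865) concern when the backward-related hooks fire under gradient accumulation: on_after_backward now runs on every accumulating iteration, while on_before_optimizer_step only runs on iterations where the optimizer actually steps. The pure-Python simulation below is a hedged sketch of that ordering. simulate_hook_calls is a hypothetical helper that does not call Lightning, and the rule "step every accumulate_grad_batches batches and on the last batch of the epoch" is an assumption made for illustration.

```python
from typing import List


def simulate_hook_calls(num_batches: int, accumulate_grad_batches: int) -> List[str]:
    """Hypothetical simulation of backward-related hook order for one training epoch.

    Assumes the optimizer steps after every `accumulate_grad_batches` batches
    and on the final batch of the epoch. Nothing here calls Lightning.
    """
    calls: List[str] = []
    for batch_idx in range(num_batches):
        calls.append(f"on_before_backward:{batch_idx}")
        # loss.backward() would run here.
        calls.append(f"on_after_backward:{batch_idx}")  # every iteration, accumulating or not
        accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
        is_last_batch = batch_idx + 1 == num_batches
        if accumulation_done or is_last_batch:
            # Only iterations that actually step the optimizer reach this hook.
            calls.append(f"on_before_optimizer_step:{batch_idx}")
            # optimizer.step() and optimizer.zero_grad() would run here.
    return calls


for call in simulate_hook_calls(num_batches=5, accumulate_grad_batches=2):
    print(call)
```

With 5 batches and accumulate_grad_batches=2, on_after_backward fires five times but on_before_optimizer_step only fires for batches 1, 3 and 4. This is why logic that should see fully accumulated gradients (for example, gradient-norm inspection) belongs in on_before_optimizer_step after this change.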

[1.4.0] - Deprecated
  • Deprecated LightningModule.loaded_optimizer_states_dict (#8229)

  • Standardized the dataloaders arguments of trainer.{fit,validate,test,tune} (#7431)

  • Deprecated DataModule properties: has_prepared_data, has_setup_fit, has_setup_validate, has_setup_test, has_setup_predict, has_teardown_fit, has_teardown_validate, has_teardown_test, has_teardown_predict (#7657)

  • Deprecated TrainerModelHooksMixin in favor of pytorch_lightning.utilities.signature_utils (#7422)

  • Deprecated num_nodes and sync_batchnorm arguments in DDPPlugin and DDPSpawnPlugin (#7026)

  • Deprecated self.log(sync_dist_op) in favor of self.log(reduce_fx) (#7891)

  • Deprecated is_overridden(model=...) in favor of is_overridden(instance=...) (#7918)

  • Deprecated automatically detaching returned extras with grads (#7994)

  • Deprecated default value of monitor argument in EarlyStopping callback to enforce monitor as a required argument (#7907)

  • Deprecated importing rank_zero_{warn,deprecation} directly from pytorch_lightning.utilities.distributed (#8085)

  • Deprecated the use of CheckpointConnector.hpc_load() in favor of CheckpointConnector.restore() (#7652)

  • Deprecated ModelCheckpoint(every_n_val_epochs) in favor of ModelCheckpoint(every_n_epochs) (#8383)

  • Deprecated DDPPlugin.task_idx in favor of DDPPlugin.local_rank (#8203)

  • Deprecated the Trainer.train_loop property in favor of Trainer.fit_loop (#8025)

  • Deprecated the Trainer.disable_validation property in favor of not Trainer.enable_validation (#8291)

  • Deprecated mode parameter in ModelSummary in favor of max_depth (#8062)

  • Deprecated reload_dataloaders_every_epoch argument of Trainer in favor of reload_dataloaders_every_n_epochs (#5043)

  • Deprecated distributed_backend argument for Trainer (#8575)
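
Two of the Trainer deprecations above have direct replacements: reload_dataloaders_every_epoch=True corresponds to reload_dataloaders_every_n_epochs=1 (#5043), and the value formerly given to distributed_backend (for example "ddp") is passed to accelerator instead (#8575). The sketch below rewrites a dictionary of Trainer keyword arguments accordingly. migrate_trainer_kwargs is a hypothetical helper, not part of Lightning, and it assumes reload_dataloaders_every_n_epochs=0 means "never reload".

```python
from typing import Any, Dict


def migrate_trainer_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: rewrite two Trainer arguments deprecated in 1.4.0.

    Covers only reload_dataloaders_every_epoch (#5043) and distributed_backend (#8575).
    An explicitly passed replacement argument always wins.
    """
    new_kwargs = dict(kwargs)  # never mutate the caller's dict
    if "reload_dataloaders_every_epoch" in new_kwargs:
        reload_every_epoch = new_kwargs.pop("reload_dataloaders_every_epoch")
        # True meant "reload every epoch" (n=1); False meant "never reload" (n=0, assumed default).
        new_kwargs.setdefault("reload_dataloaders_every_n_epochs", 1 if reload_every_epoch else 0)
    if "distributed_backend" in new_kwargs:
        # In 1.4 the same value (e.g. "ddp") is passed via `accelerator` instead.
        new_kwargs.setdefault("accelerator", new_kwargs.pop("distributed_backend"))
    return new_kwargs


old = {"max_epochs": 3, "reload_dataloaders_every_epoch": True, "distributed_backend": "ddp"}
print(migrate_trainer_kwargs(old))
```

The result could then be unpacked into the Trainer constructor; the other deprecations in this list (such as is_overridden(model=...) or ModelCheckpoint(every_n_val_epochs)) are call-site renames and are not covered.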

[1.4.0] - Removed
  • Dropped official support/testing for PyTorch <1.6 (#8288)

  • Removed ProfilerConnector (#7654)

  • Pruned deprecated classification metrics from pytorch_lightning.metrics.functional.classification (#7499)

  • Removed deprecated data parallel classes LightningDataParallel and LightningDistributedDataParallel from pytorch_lightning.overrides.data_parallel (#7510)

  • Removed deprecated trainer attributes - get_model and accelerator_backend (#7502)

  • Removed support for automatically monitoring the val_loss key with ModelCheckpoint. Pass your monitor of choice to the ModelCheckpoint instance instead (#8293)

  • Removed support for self.log(tbptt_reduce_fx) and self.log(tbptt_pad_token). Please open a discussion explaining your use case if you relied on these (#7644)

  • Removed deprecated utils modules model_utils, warning_utils, xla_device_utils and partially argparse_utils (#7503)

  • Removed RPCPlugin and RPCSequentialPlugin. If you were successfully using these plugins, please open a GitHub discussion about your use case (#8101)

  • Removed deprecated trainer attributes - on_cpu, on_tpu, use_tpu, on_gpu, use_dp, use_ddp, use_ddp2, use_horovod, use_single_gpu (#7501)

  • Removed the deprecated optimizer argument in LightningModule.manual_backward(); toggling optimizers in manual optimization should be done using LightningModule.{un}toggle_optimizer() (#8287)

  • Removed DeepSpeed FP16 Exception as FP32 is now supported (#8462)

  • Removed environment variable PL_EXP_VERSION from DDP subprocesses (#7403)

[1.4.0] - Fixed
  • Fixed the GPUStatsMonitor callbacks to use the correct GPU IDs if CUDA_VISIBLE_DEVICES set (#8260)

  • Fixed lr_scheduler checkpointed state by calling update_lr_schedulers before saving checkpoints (#7877)

  • Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled (#7685)

  • Fixed dev debugger memory growing due to tracking events even when disabled (#7875)

  • Fixed None loss keys getting added in training_epoch_end when using manual optimization and not returning a loss (#7772)

  • Fixed a bug where precision=64 with accelerator='ddp_spawn' would throw a pickle error (#6924)

  • Do not override the existing epoch value in logged_metrics when already logged by the user (#7982)

  • Support for manual optimization with DeepSpeed (#7970)

  • Fixed dataloader_idx argument value when predicting with only one DataLoader (#7941)

  • Fixed passing the stage argument of Callback.{setup,teardown} as a keyword (#7973)

  • Fixed metrics generated during validation sanity checking so that they are cleaned up at the end (#8171)

  • Fixed log_gpu_memory metrics not being added to logging when nothing else is logged (#8174)

  • Fixed a bug where calling log with a Metric instance would raise an error if it was a nested attribute of the model (#8181)

  • Fixed a bug where using precision=64 would cause buffers with complex dtype to be cast to real (#8208)

  • Fixed is_overridden returning true for wrapped functions with no changes (#8296)

  • Fixed a bug where truncated_bptt_steps would throw an AttributeError when the target RNN has multiple hidden states (#8145)

  • Fixed self.optimizers() not returning a single optimizer if it had been wrapped (#8326)

  • Fixed the on_after_backward hook not getting called when using manual optimization and no plugins (#8328)

  • Fixed the LightningModule.backward hook only getting called with the apex plugin when using manual optimization (#8328)

  • Fixed moving batch to device before sending it to the on_*_batch_start/on_*_batch_end callbacks and model hooks (#7378)

  • Fixed passing a custom DDPPlugin when choosing accelerator="ddp_cpu" for the accelerator (#6208)

  • Fixed missing call to LightningModule.untoggle_optimizer in training loop when running gradient accumulation with multiple optimizers (#8284)

  • Fixed hash of LightningEnum to work with value instead of name (#8421)

  • Fixed a bug where an extra checkpoint was saved at the end of training if the val_check_interval did not align with the number of training batches (#7724)

  • Fixed move_data_to_device to return the batch if the object's to method didn't return self (#8433)

  • Fixed progress bar updates for Pod Training (#8258)

  • Fixed clearing dataloader references before attaching new dataloaders in consecutive Trainer.{fit,validate,test,predict} runs (#8442)

  • Fixed memory leaks on GPU by moving optimizer_states, ResultCollection.extra, ResultMetric attributes, and LoggerConnector metrics to CPU, and by deleting the DDP wrapper on teardown (#8490)

  • Fixed SWA callback using LightningModule prevent_trainer_and_dataloaders_deepcopy to avoid OOM (#8472)

  • Fixed ModelPruning callback on_save_checkpoint to avoid making a deepcopy potentially leading to OOM (#8472)

  • Fixed the sampler replacement logic for DataLoaders which do not define all DataLoader attributes as __init__ parameters (#8519)

  • Fixed DeepSpeed Windows support (#8488)

  • Fixed DeepSpeed not properly setting the trainer lr_schedulers attribute (#8527)

  • Fixed experiment version and log-dir divergence in DDP when using multiple Trainer instances in sequence (#7403)

  • Enabled manual optimization for TPUs (#8458)

  • Fixed accumulate_grad_batches not being recomputed during model reload (#5334)

  • Fixed a TypeError when wrapping optimizers in the HorovodPlugin and running Trainer.test (#7840)

  • Fixed BackboneFinetuning restoration (#8501)

  • Fixed lr_scheduler with metric (e.g. torch.optim.lr_scheduler.ReduceLROnPlateau) when using automatic_optimization = False (#7643)

  • Fixed DeepSpeed breaking with no schedulers (#8580)

[1.3.8] - 2021-07-01

[1.3.8] - Fixed
  • Fixed a sync deadlock when checkpointing a LightningModule that uses a torchmetrics 0.4 Metric (#8218)

  • Fixed compatibility with TorchMetrics v0.4 (#8206)

  • Added torchelastic check when sanitizing GPUs (#8095)

  • Fixed a DDP info message that was never shown (#8111)

  • Fixed metrics deprecation message at module import level (#8163)

  • Fixed a bug where an infinite recursion would be triggered when using the BaseFinetuning callback on a model that contains a ModuleDict (#8170)

  • Added a mechanism to detect a DDP deadlock when only one process triggers an exception; when this happens, the mechanism kills the processes (#8167)

  • Fixed NCCL error when selecting non-consecutive device ids (#8165)

  • Fixed SWA to also work with IterableDataset (#8172)

[1.3.7] - 2021-06-22

[1.3.7] - Fixed
  • Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error (#7975)

  • Fixed deprecation messages not showing due to incorrect stacklevel (#8002, #8005)

  • Fixed setting a DistributedSampler when using a distributed plugin in a custom accelerator (#7814)

  • Improved PyTorchProfiler chrome traces names (#8009)

  • Fixed moving the best score to device in EarlyStopping callback for TPU devices (#7959)

  • Fixed access to callback_metrics in ddp_spawn (#7916)

[1.3.6] - 2021-06-15

[1.3.6] - Fixed
  • Fixed logs overwriting issue for remote filesystems (#7889)

  • Fixed DataModule.prepare_data only being callable on the global rank 0 process (#7945)

  • Fixed setting worker_init_fn to seed dataloaders correctly when using DDP (#7942)

  • Fixed BaseFinetuning callback to properly handle parent modules with parameters (#7931)

[1.3.5] - 2021-06-08

[1.3.5] - Added
  • Added warning to Training Step output (#7779)

[1.3.5] - Fixed
  • Fixed LearningRateMonitor and BackboneFinetuning (#7835)

  • Minor improvements to apply_to_collection and type signature of log_dict (#7851)

  • Fixed docker versions (#7834)

  • Fixed sharded training check for fp16 precision (#7825)

  • Fixed support for torch Module type hints in LightningCLI (#7807)

[1.3.5] - Changed
  • Moved training_output validation to after train_step_end (#7868)

[1.3.4] - 2021-06-01

[1.3.4] - Fixed
  • Fixed info message when max training time reached (#7780)

  • Fixed the missing __len__ method in IndexBatchSamplerWrapper (#7681)

[1.3.3] - 2021-05-27

[1.3.3] - Changed
  • Moved the call to untoggle_optimizer(opt_idx) out of the closure function (#7563)

[1.3.3] - Fixed
  • Fixed ProgressBar pickling after calling trainer.predict (#7608)

  • Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 (#7592)

  • Fixed dataloaders not being reset when tuning the model (#7566)

  • Fixed print errors in ProgressBar when trainer.fit is not called (#7674)

  • Fixed global step update when the epoch is skipped (#7677)

  • Fixed training loop total batch counter when accumulate grad batches was enabled (#7692)

[1.3.2] - 2021-05-18

[1.3.2] - Changed
  • DataModules now avoid duplicate {setup,teardown,prepare_data} calls for the same stage (#7238)

[1.3.2] - Fixed
  • Fixed parsing of multiple training dataloaders (#7433)

  • Fixed recursive passing of wrong_type keyword argument in pytorch_lightning.utilities.apply_to_collection (#7433)

  • Fixed setting correct DistribType for ddp_cpu (spawn) backend (#7492)

  • Fixed incorrect number of calls to LR scheduler when check_val_every_n_epoch > 1 (#7032)

[1.3.1] - 2021-05-11

[1.3.1] - Fixed
  • Fixed DeepSpeed with IterableDatasets (#7362)

  • Fixed Trainer.current_epoch not getting restored after tuning (#7434)

  • Fixed local rank displayed in console log (#7395)

[1.3.0] - 2021-05-06

[1.3.0] - Added
  • Added support for the EarlyStopping callback to run at the end of the training epoch (#6944)

  • Added synchronization points before and after setup hooks are run (#7202)

  • Added a teardown hook to ClusterEnvironment (#6942)

  • Added utils for metrics to scalar conversions (#7180)

  • Added utils for NaN/Inf detection for gradients and parameters (#6834)

  • Added more explicit exception message when trying to execute trainer.test() or trainer.validate() with fast_dev_run=True (#6667)

  • Added LightningCLI class to provide a simple, reproducible training CLI with minimal boilerplate (#4492, #6862, #7156, #7299)

  • Added gradient_clip_algorithm argument to Trainer for gradient clipping by value (#6123)

  • Added a way to print to terminal without breaking up the progress bar (#5470)

  • Added support to checkpoint after training steps in ModelCheckpoint callback (#6146)

  • Added TrainerStatus.{INITIALIZING,RUNNING,FINISHED,INTERRUPTED} (#7173)

  • Added Trainer.validate() method to perform one evaluation epoch over the validation set (#4948)

  • Added LightningEnvironment for Lightning-specific DDP (#5915)

  • Added teardown() hook to LightningDataModule (#4673)

  • Added auto_insert_metric_name parameter to ModelCheckpoint (#6277)

  • Added an argument to self.log that lets users give custom names when dealing with multiple dataloaders (#6274)

  • Added teardown method to BaseProfiler to enable subclasses defining post-profiling steps outside of __del__ (#6370)

  • Added setup method to BaseProfiler to enable subclasses defining pre-profiling steps for every process (#6633)

  • Added no return warning to predict (#6139)

  • Added Trainer.predict config validation (#6543)

  • Added AbstractProfiler interface (#6621)

  • Added support for including module names for forward in the autograd trace of PyTorchProfiler (#6349)

  • Added support for the PyTorch 1.8.1 autograd profiler (#6618)

  • Added outputs parameter to callback’s on_validation_epoch_end & on_test_epoch_end hooks (#6120)

  • Added configure_sharded_model hook (#6679)

  • Added support for precision=64, enabling training with double precision (#6595)

  • Added support for DDP communication hooks (#6736)

  • Added artifact_location argument to MLFlowLogger which will be passed to the MlflowClient.create_experiment call (#6677)

  • Added model parameter to precision plugins' clip_gradients signature (#6764, #7231)

  • Added is_last_batch attribute to Trainer (#6825)

  • Added LightningModule.lr_schedulers() for manual optimization (#6567)

  • Added MpModelWrapper in TPU Spawn (#7045)

  • Added max_time Trainer argument to limit training time (#6823)

  • Added on_predict_{batch,epoch}_{start,end} hooks (#7141)

  • Added new EarlyStopping parameters stopping_threshold and divergence_threshold (#6868)

  • Added debug flag to TPU Training Plugins (PT_XLA_DEBUG) (#7219)

  • Added new UnrepeatedDistributedSampler and IndexBatchSamplerWrapper for tracking distributed predictions (#7215)

  • Added trainer.predict(return_predictions=None|False|True) (#7215)

  • Added BasePredictionWriter callback to implement prediction saving (#7127)

  • Added trainer.tune(scale_batch_size_kwargs, lr_find_kwargs) arguments to configure the tuning algorithms (#7258)

  • Added tpu_distributed check for TPU Spawn barrier (#7241)

  • Added device updates to TPU Spawn for Pod training (#7243)

  • Added a warning when a callback is missing while using resume_from_checkpoint (#7254)

  • Added DeepSpeed single-file saving (#6900)

  • Added Training Type Plugins Registry (#6982, #7063, #7214, #7224)

  • Added ignore parameter to save_hyperparameters (#6056)
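
The new EarlyStopping parameters stopping_threshold and divergence_threshold (#6868) let training stop early once a metric is good enough, or once it has clearly diverged, instead of waiting for patience to run out. The framework-free class below is a hedged sketch of such a rule for a metric to be minimized. EarlyStoppingSketch is hypothetical; the real pytorch_lightning.callbacks.EarlyStopping also supports mode="max", and the order of checks here is an assumption made for illustration.

```python
import math
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class EarlyStoppingSketch:
    """Hypothetical, framework-free early-stopping rule for a metric to minimize."""

    patience: int = 3
    min_delta: float = 0.0
    stopping_threshold: Optional[float] = None    # stop as soon as the metric drops below this
    divergence_threshold: Optional[float] = None  # stop as soon as the metric rises above this
    best: float = field(default=math.inf, init=False)
    wait: int = field(default=0, init=False)

    def should_stop(self, current: float) -> bool:
        if not math.isfinite(current):
            return True  # NaN or inf: treat as a failed run
        if self.stopping_threshold is not None and current < self.stopping_threshold:
            return True  # good enough: target reached
        if self.divergence_threshold is not None and current > self.divergence_threshold:
            return True  # diverged: no point waiting for patience to run out
        if current < self.best - self.min_delta:
            self.best = current  # improvement: remember it and reset the counter
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


es = EarlyStoppingSketch(patience=2)
print([es.should_stop(v) for v in [5.0, 4.0, 4.5, 4.2]])  # [False, False, False, True]
```

The thresholds short-circuit the patience counter: EarlyStoppingSketch(divergence_threshold=10.0).should_stop(12.0) returns True on the very first call.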

[1.3.0] - Changed
  • Changed LightningModule.truncated_bptt_steps to be property (#7323)

  • Changed the EarlyStopping callback so that EarlyStopping.on_validation_end is no longer run by default when only training is run. Set check_on_train_epoch_end to run the callback at the end of the training epoch instead of at the end of the validation epoch (#7069)

  • Renamed pytorch_lightning.callbacks.swa to pytorch_lightning.callbacks.stochastic_weight_avg (#6259)

  • Refactored RunningStage and TrainerState usage (#4945, #7173)

    • Added RunningStage.SANITY_CHECKING

    • Added TrainerFn.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}

    • Changed trainer.evaluating to return True if validating or testing

  • Changed setup() and teardown() stage argument to take any of {fit,validate,test,predict} (#6386)

  • Changed profilers to save separate report files per state and rank (#6621)

  • The trainer no longer tries to save a checkpoint on exception or run callback’s on_train_end functions (#6864)

  • Changed PyTorchProfiler to use torch.autograd.profiler.record_function to record functions (#6349)

  • Disabled lr_scheduler.step() in manual optimization (#6825)

  • Changed warnings and recommendations for dataloaders in ddp_spawn (#6762)

  • pl.seed_everything will now also set the seed on the DistributedSampler (#7024)

  • Changed default setting for communication of multi-node training using DDPShardedPlugin (#6937)

  • trainer.tune() now returns the tuning result (#7258)

  • LightningModule.from_datasets() now accepts IterableDataset instances as training datasets. (#7503)

  • Changed resume_from_checkpoint warning to an error when the checkpoint file does not exist (#7075)

  • Automatically set sync_batchnorm for training_type_plugin (#6536)

  • Allowed training type plugin to delay optimizer creation (#6331)

  • Removed ModelSummary validation from train loop on_trainer_init (#6610)

  • Moved save_function to accelerator (#6689)

  • Updated DeepSpeed ZeRO (#6546, #6752, #6142, #6321)

  • Improved verbose logging for EarlyStopping callback (#6811)

  • Run ddp_spawn dataloader checks on Windows (#6930)

  • Updated the MLflow logger to use resolve_tags (#6746)

  • Moved save_hyperparameters to its own function (#7119)

  • Replaced _DataModuleWrapper with __new__ (#7289)

  • Reset current_fx properties on lightning module in teardown (#7247)

  • Auto-set DataLoader.worker_init_fn with seed_everything (#6960)

  • Remove model.trainer call inside of dataloading mixin (#7317)

  • Split profilers module (#6261)

  • Ensure accelerator is valid if running interactively (#5970)

  • Disabled batch transfer in DP mode (#6098)

[1.3.0] - Deprecated
  • Deprecated outputs in both LightningModule.on_train_epoch_end and Callback.on_train_epoch_end hooks (#7339)

  • Deprecated Trainer.truncated_bptt_steps in favor of LightningModule.truncated_bptt_steps (#7323)

  • Deprecated LightningModule.grad_norm in favor of pytorch_lightning.utilities.grads.grad_norm (#7292)

  • Deprecated the save_function property from the ModelCheckpoint callback (#7201)

  • Deprecated LightningModule.write_predictions and LightningModule.write_predictions_dict (#7066)

  • Deprecated TrainerLoggingMixin in favor of a separate utilities module for metric handling (#7180)

  • Deprecated TrainerTrainingTricksMixin in favor of a separate utilities module for NaN/Inf detection for gradients and parameters (#6834)

  • Deprecated period in favor of every_n_val_epochs in the ModelCheckpoint callback (#6146)

  • Deprecated trainer.running_sanity_check in favor of trainer.sanity_checking (#4945)

  • Deprecated Profiler(output_filename) in favor of dirpath and filename (#6621)

  • Deprecated PyTorchProfiler(profiled_functions) in favor of record_functions (#6349)

  • Deprecated @auto_move_data in favor of trainer.predict (#6993)

  • Deprecated Callback.on_load_checkpoint(checkpoint) in favor of Callback.on_load_checkpoint(trainer, pl_module, checkpoint) (#7253)

  • Deprecated metrics in favor of torchmetrics (#6505, #6530, #6540, #6547, #6515, #6572, #6573, #6584, #6636, #6637, #6649, #6659, #7131)

  • Deprecated the LightningModule.datamodule getter and setter methods; access them through Trainer.datamodule instead (#7168)

  • Deprecated the use of Trainer(gpus="i") (string) for selecting the i-th GPU; from v1.5 this will set the number of GPUs instead of the index (#6388)
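
The gpus deprecation above changes what a single-integer string means. A minimal sketch of the two readings, using hypothetical helpers (gpus_string_as_index, gpus_string_as_count) that are not part of the Lightning API, and assuming that a request for n GPUs resolves to device indices 0 to n-1:

```python
from typing import List


def gpus_string_as_index(gpus: str) -> List[int]:
    """Hypothetical helper: the deprecated reading, where "i" selects GPU index i."""
    return [int(gpus)]


def gpus_string_as_count(gpus: str) -> List[int]:
    """Hypothetical helper: the v1.5 reading, where "i" requests i GPUs.

    Assumes the requested GPUs are the first i devices (indices 0..i-1).
    """
    return list(range(int(gpus)))


# The same string selects different devices under the two readings.
print(gpus_string_as_index("3"))  # [3]
print(gpus_string_as_count("3"))  # [0, 1, 2]
```

To keep targeting one specific device across versions, passing a list of indices (for example Trainer(gpus=[3])) should be the unambiguous form.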

[1.3.0] - Removed
  • Removed the exp_save_path property from the LightningModule (#7266)

  • Removed training loop explicitly calling EarlyStopping.on_validation_end if no validation is run (#7069)

  • Removed automatic_optimization as a property from the training loop in favor of LightningModule.automatic_optimization (#7130)

  • Removed evaluation loop legacy returns for *_epoch_end hooks (#6973)

  • Removed support for passing a bool value to profiler argument of Trainer (#6164)

  • Removed the "no return" warning from validation/test steps (#6139)

  • Removed passing a ModelCheckpoint instance to Trainer(checkpoint_callback) (#6166)

  • Removed deprecated Trainer arguments enable_pl_optimizer and automatic_optimization (#6163)

  • Removed deprecated metrics (#6161)

    • Removed from pytorch_lightning.metrics.functional.classification: to_onehot, to_categorical, get_num_classes, roc, multiclass_roc, average_precision, precision_recall_curve, multiclass_precision_recall_curve

    • Removed from pytorch_lightning.metrics.functional.reduction: reduce, class_reduce

  • Removed deprecated ModelCheckpoint arguments prefix, mode="auto" (#6162)

  • Removed mode='auto' from EarlyStopping (#6167)

  • Removed epoch and step arguments from ModelCheckpoint.format_checkpoint_name(); these are now included in the metrics argument (#7344)

  • Removed legacy references for magic keys in the Result object (#6016)

  • Removed deprecated LightningModule hparams setter (#6207)

  • Removed legacy code to log or include metrics in the progress bar by returning them in a dict with the "log"/"progress_bar" magic keys. Use self.log instead (#6734)

  • Removed the return value of 1 from trainer.fit(); it now returns nothing (#7237)

  • Removed logger_connector legacy code (#6733)

  • Removed unused mixin attributes (#6487)
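
Because epoch and step now travel inside the metrics argument, a filename template can treat them like any other metric. The sketch below is a hypothetical stand-in for ModelCheckpoint.format_checkpoint_name (the function name and signature here are assumptions, not the library's API): it expands each {key} or {key:fmt} group into key=value and fills every value from a single metrics dict.

```python
import re
from typing import Dict


def format_checkpoint_name(template: str, metrics: Dict[str, float]) -> str:
    """Hypothetical sketch: expand "{key}" groups into "key={key}" and fill from metrics."""
    # Capture "{name" up to the closing brace or the start of a format spec.
    groups = re.findall(r"(\{.*?)[:\}]", template)
    for group in groups:
        name = group[1:]
        # Prefix the first occurrence of the group with "name=".
        template = template.replace(group, name + "=" + group, 1)
    # epoch and step are looked up exactly like any other metric.
    return template.format(**metrics)


metrics = {"epoch": 1, "step": 100, "val_loss": 0.1234}
print(format_checkpoint_name("{epoch}-{step}-{val_loss:.2f}", metrics))
# epoch=1-step=100-val_loss=0.12
```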

[1.3.0] - Fixed
  • Fixed NaN errors in progress bars when training with iterable datasets with no length defined (#7306)

  • Fixed attaching train and validation dataloaders when reload_dataloaders_every_epoch=True and num_sanity_val_steps=0 (#7207)

  • Added a barrier in the accelerator teardown to synchronize processes before execution finishes (#6814)

  • Fixed multi-node DDP sub-process launch by using local_rank instead of global_rank for main process assertion (#7061)

  • Fixed incorrect removal of WORLD_SIZE environment variable in DDP training when launching with torch distributed/torchelastic (#6942)

  • Made the Plugin.reduce method more consistent across all Plugins to reflect a mean-reduction by default (#6011)

  • Moved the LightningModule to the correct device type when using LightningDistributedWrapper (#6070)

  • Do not print top-k verbose log with ModelCheckpoint(monitor=None) (#6109)

  • Fixed ModelCheckpoint(save_top_k=0, save_last=True) not saving the last checkpoint (#6136)

  • Fixed .teardown(stage='fit') and .on_fit_{start,end}() getting called during trainer.test (#6386)

  • Fixed LightningModule all_gather on cpu tensors (#6416)

  • Fixed torch.distributed not being available in the setup hook for DDP (#6506)

  • Fixed trainer.tuner.{lr_find,scale_batch_size} not setting the Trainer state properly (#7258)

  • Fixed bug where the learning rate schedulers did not follow the optimizer frequencies (#4868)

  • Fixed pickle error checker to now check for pickle.PickleError to catch all pickle errors (#6917)

  • Fixed a bug where the outputs object passed to LightningModule.training_epoch_end was different from the object passed to the on_train_epoch_end hook (#6969)

  • Fixed a bug where the outputs passed to train_batch_end would be lists even when using a single optimizer and no truncated backprop through time steps (#6969)

  • Fixed a bug in trainer error handling that could cause distributed training to hang (#6864)

  • Fixed self.device not returning the correct device in replicas of data-parallel (#6414)

  • Fixed lr_find running beyond num_training steps and suggesting a learning rate that was too high (#7076)

  • Fixed logger creating incorrect version folder in DDP with repeated Trainer.fit calls (#7077)

  • Fixed metric objects passed directly to self.log not being reset correctly (#7055)

  • Fixed CombinedLoader in distributed settings for validation / testing (#7102)

  • Fixed the save_dir in WandbLogger when the run was initiated externally (#7106)

  • Fixed num_sanity_val_steps affecting reproducibility of training data shuffling (#7014)

  • Fixed resetting device after fitting/evaluating/predicting (#7188)

  • Fixed bug where trainer.tuner.scale_batch_size(max_trials=0) would not return the correct batch size result (#7262)

  • Fixed metrics not being properly logged with precision=16 and manual_optimization (#7228)

  • Fixed BaseFinetuning not properly reloading optimizer_states when using resume_from_checkpoint (#6891)

  • Fixed parameters_to_ignore not being properly passed to DDPWrapper (#7239)

  • Fixed parsing of fast_dev_run=True with the built-in ArgumentParser (#7240)

  • Fixed handling an IterableDataset that fails to produce a batch at the beginning of an epoch (#7294)

  • Fixed LightningModule.save_hyperparameters() when attempting to save an empty container (#7268)

  • Fixed apex not properly instantiated when running with ddp (#7274)

  • Fixed optimizer state not moved to GPU (#7277)

  • Fixed custom init args for WandbLogger (#6989)

  • Fixed a bug where an error would be raised if the train dataloader sometimes produced None for a batch (#7342)

  • Fixed examples (#6600, #6638, #7096, #7246, #6357, #6476, #6294, #6373, #6088, #7398)

  • Resolved schedule step bug for PyTorch Profiler (#6674, #6681)

  • Updated logic for checking TPUs availability (#6767)

  • Resolved missed TPU rendezvous (#6781)

  • Fixed auto-scaling mode when calling tune method on trainer (#7321)

  • Fixed finetuning of complex models so that they unfreeze correctly (#6880)

  • Ensure we set the eval/train flag correctly on accelerator model (#6877)

  • Set better defaults for rank_zero_only.rank when training is launched with SLURM and torchelastic (#6802)

  • Fixed matching the number of outputs of backward with forward for AllGatherGrad (#6625)

  • Fixed gradient_clip_algorithm having no effect (#6928)

  • Fixed CUDA OOM detection and handling (#6934)

  • Fixed unfreeze_and_add_param_group to expect modules rather than module (#6822)

  • Fixed DDP + SyncBN when moving to device (#6838)

  • Fixed missing arguments in lr_find call (#6784)

  • Fixed set_default_tensor_type to torch.DoubleTensor with precision=64 (#7108)

  • Fixed NeptuneLogger.log_text(step=None) (#7194)

  • Fixed importing torchtext batch (#6365, #6323, #6211)
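
The Plugin.reduce change above standardizes on a mean reduction by default. A minimal, framework-free sketch of what that default computes; the function name and the reduce_op strings are assumptions for illustration, and a real plugin reduces tensors across processes rather than a Python list:

```python
from typing import Sequence


def reduce(values: Sequence[float], reduce_op: str = "mean") -> float:
    """Hypothetical reduction over one value per process; defaults to the mean."""
    if not values:
        raise ValueError("need at least one value to reduce")
    if reduce_op == "mean":
        return sum(values) / len(values)
    if reduce_op == "sum":
        return sum(values)
    raise ValueError(f"unsupported reduce_op: {reduce_op!r}")


# e.g. a validation loss reported by each of four processes
per_process_loss = [0.5, 0.25, 0.75, 0.5]
print(reduce(per_process_loss))         # 0.5
print(reduce(per_process_loss, "sum"))  # 2.0
```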

[1.2.9] - 2021-04-20

[1.2.9] - Fixed
  • Fixed the order to call for world ranks & the root_device property in TPUSpawnPlugin (#7074)

  • Fixed multi-gpu join for Horovod (#6954)

  • Fixed parsing for pre-release package versions (#6999)

[1.2.8] - 2021-04-14

[1.2.8] - Added
  • Added TPUSpawn + IterableDataset error message (#6875)

[1.2.8] - Fixed
  • Fixed process rank not being available right away after Trainer instantiation (#6941)

  • Fixed sync_dist for TPUs (#6950)

  • Fixed AttributeError for require_backward_grad_sync when running manual optimization with sharded plugin (#6915)

  • Fixed --gpus default for parser returned by Trainer.add_argparse_args (#6898)

  • Fixed TPU Spawn all gather (#6896)

  • Fixed EarlyStopping logic when min_epochs or min_steps requirement is not met (#6705)

  • Fixed csv extension check (#6436)

  • Fixed checkpoint issue when using Horovod distributed backend (#6958)

  • Fixed tensorboard exception raising (#6901)

  • Fixed setting the eval/train flag correctly on accelerator model (#6983)

  • Fixed DDP_SPAWN compatibility with bug_report_model.py (#6892)

  • Fixed bug where BaseFinetuning.flatten_modules() was duplicating leaf node parameters (#6879)

  • Set better defaults for rank_zero_only.rank when training is launched with SLURM and torchelastic:

    • Support SLURM and torchelastic global rank environment variables (#5715)

    • Remove hardcoding of local rank in accelerator connector (#6878)

[1.2.7] - 2021-04-06

[1.2.7] - Fixed
  • Fixed a bug with omegaconf and xm.save (#6741)

  • Fixed an issue with IterableDataset when len is not defined (#6828)

  • Sanitize None params during pruning (#6836)

  • Enforce an epoch scheduler interval when using SWA (#6588)

  • Fixed a TPU Colab hang after training (#6816)

  • Fixed a bug where TensorBoardLogger would give a warning and not log correctly to a symbolic link save_dir (#6730)

  • Fixed bug where predict could not be used when progress_bar_refresh_rate=0 (#6884)

[1.2.6] - 2021-03-30

[1.2.6] - Changed
  • Changed the behavior of on_epoch_start to run at the beginning of validation & test epoch (#6498)

[1.2.6] - Removed
  • Removed legacy code to include step dictionary returns in callback_metrics. Use self.log_dict instead. (#6682)

[1.2.6] - Fixed
  • Fixed DummyLogger.log_hyperparams raising a TypeError when running with fast_dev_run=True (#6398)

  • Fixed error on TPUs when there was no ModelCheckpoint (#6654)

  • Fixed trainer.test freeze on TPUs (#6654)

  • Fixed a bug where gradients were disabled after calling Trainer.predict (#6657)

  • Fixed bug where no TPUs were detected in a TPU pod env (#6719)

[1.2.5] - 2021-03-23

[1.2.5] - Changed
  • Update Gradient Clipping for the TPU Accelerator (#6576)

  • Refactored setup to be typing-friendly (#6590)

[1.2.5] - Fixed
  • Fixed a bug where all_gather would not work correctly with tpu_cores=8 (#6587)

  • Fixed comparing required versions (#6434)

  • Fixed duplicate logs appearing in console when using the python logging module (#6275)

  • Added Autocast in validation, test and predict modes for Native AMP (#6565)

[1.2.4] - 2021-03-16

[1.2.4] - Changed
  • Changed the default of find_unused_parameters back to True in DDP and DDP Spawn (#6438)

[1.2.4] - Fixed
  • Expose DeepSpeed loss parameters to allow users to fix loss instability (#6115)

  • Fixed DP reduction with collection (#6324)

  • Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size (#4688)

  • Fixed broadcast to use PyTorch broadcast_object_list and add reduce_decision (#6410)

  • Fixed logger creating directory structure too early in DDP (#6380)

  • Fixed DeepSpeed additional memory use on rank 0 when default device not set early enough (#6460)

  • Fixed an issue with Tuner.scale_batch_size not finding the batch size attribute in the datamodule (#5968)

  • Fixed an exception in the layer summary when the model contains torch.jit scripted submodules (#6511)

  • Fixed the train loop config being run during Trainer.predict (#6541)

[1.2.3] - 2021-03-09

[1.2.3] - Fixed
  • Fixed ModelPruning(make_pruning_permanent=True) pruning buffers getting removed when saved during training (#6073)

  • Fixed _stable_1d_sort to work when n >= N (#6177)

  • Fixed AttributeError when logger=None on TPU (#6221)

  • Fixed PyTorch Profiler with emit_nvtx (#6260)

  • Fixed trainer.test from best_path hanging after calling trainer.fit (#6272)

  • Fixed SingleTPU calling all_gather (#6296)

  • Ensure we check DeepSpeed/Sharded in multi-node DDP (#6297)

  • Check LightningOptimizer doesn’t delete optimizer hooks (#6305)

  • Resolve memory leak for evaluation (#6326)

  • Ensure that clip gradients is only called if the value is greater than 0 (#6330)

  • Fixed Trainer not resetting lightning_optimizers when calling Trainer.fit() multiple times (#6372)

[1.2.2] - 2021-03-02

[1.2.2] - Added
  • Added checkpoint parameter to callback’s on_save_checkpoint hook (#6072)

[1.2.2] - Changed
  • Changed the order of backward, step, zero_grad to zero_grad, backward, step (#6147)

  • Changed default for DeepSpeed CPU Offload to False, due to prohibitively slow speeds at smaller scale (#6262)
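
The new zero_grad, backward, step order matters most for closure-based optimizers such as LBFGS (see the related LBFGS fix under 1.2.2 Fixed), because the optimizer may evaluate the closure several times within one step. A framework-free toy sketch, assuming a scalar parameter whose gradient accumulates like autograd's .grad; ToyParam and ToyClosureOptimizer are hypothetical classes, not PyTorch APIs:

```python
class ToyParam:
    """Scalar parameter whose .grad accumulates across backward calls, like autograd."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0


class ToyClosureOptimizer:
    """Hypothetical optimizer that, like LBFGS, calls the closure several times per step."""

    def __init__(self, param: ToyParam, lr: float, evals_per_step: int = 2) -> None:
        self.param = param
        self.lr = lr
        self.evals_per_step = evals_per_step

    def zero_grad(self) -> None:
        self.param.grad = 0.0

    def step(self, closure) -> float:
        loss = 0.0
        for _ in range(self.evals_per_step):
            loss = closure()  # the closure recomputes param.grad
            self.param.value -= self.lr * self.param.grad
        return loss


def make_closure(opt: ToyClosureOptimizer, param: ToyParam, target: float):
    def closure() -> float:
        opt.zero_grad()                             # 1. clear stale gradients
        loss = (param.value - target) ** 2          # forward pass
        param.grad += 2.0 * (param.value - target)  # 2. "backward": accumulate d(loss)/d(value)
        return loss
    return closure


param = ToyParam(0.0)
opt = ToyClosureOptimizer(param, lr=0.25)
opt.step(make_closure(opt, param, target=1.0))      # 3. step evaluates the closure twice
print(param.value)  # 0.75
```

Without the zero_grad() call inside the closure, the second evaluation would see an accumulated gradient of -3 instead of -1 and move the parameter to 1.25, overshooting the target.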

[1.2.2] - Fixed
  • Fixed epoch level schedulers not being called when val_check_interval < 1.0 (#6075)

  • Fixed multiple early stopping callbacks (#6197)

  • Fixed incorrect usage of detach(), cpu(), to() (#6216)

  • Fixed LBFGS optimizer support which didn’t converge in automatic optimization (#6147)

  • Prevent WandbLogger from dropping values (#5931)

  • Fixed error thrown when using valid distributed mode in multi node (#6297)

[1.2.1] - 2021-02-23

[1.2.1] - Fixed
  • Fixed incorrect yield logic for the amp autocast context manager (#6080)

  • Fixed priority of plugin/accelerator when setting distributed mode (#6089)

  • Fixed error message for AMP + CPU incompatibility (#6107)

  • Disabled batch transfer in DP mode (#6093)

[1.2.0] - 2021-02-18

[1.2.0] - Added
  • Added DataType, AverageMethod and MDMCAverageMethod enum in metrics (#5657)

  • Added support for summarized model total params size in megabytes (#5590)

  • Added support for multiple train loaders (#1959)

  • Accuracy metric now generalizes to top-k accuracy for (multi-dimensional) multi-class inputs using the top_k parameter (#4838)

  • Accuracy metric now enables the computation of subset accuracy for multi-label or multi-dimensional multi-class inputs with the subset_accuracy parameter (#4838)

  • Added HammingDistance metric to compute the hamming distance (loss) (#4838)

  • Added max_fpr parameter to auroc metric for computing partial auroc metric (#3790)

  • Added StatScores metric to compute the number of true positives, false positives, true negatives and false negatives (#4839)

  • Added R2Score metric (#5241)

  • Added LambdaCallback (#5347)

  • Added BackboneLambdaFinetuningCallback (#5377)

  • Accelerator all_gather supports collection (#5221)

  • Added image_gradients functional metric to compute the image gradients of a given input image. (#5056)

  • Added MetricCollection (#4318)

  • Added .clone() method to metrics (#4318)

  • Added IoU class interface (#4704)

  • Added support for tying weights after moving the model to TPU via the on_post_move_to_device hook

  • Added missing val/test hooks in LightningModule (#5467)

  • The Recall and Precision metrics (and their functional counterparts recall and precision) can now be generalized to Recall@K and Precision@K with the use of top_k parameter (#4842)

  • Added ModelPruning Callback (#5618, #5825, #6045)

  • Added PyTorchProfiler (#5560)

  • Added compositional metrics (#5464)

  • Added Trainer method predict(...) for high performance predictions (#5579)

  • Added on_before_batch_transfer and on_after_batch_transfer data hooks (#3671)

  • Added AUC/AUROC class interface (#5479)

  • Added PredictLoop object (#5752)

  • Added QuantizationAwareTraining callback (#5706, #6040)

  • Added LightningModule.configure_callbacks to enable the definition of model-specific callbacks (#5621)

  • Added dim to PSNR metric for mean-squared-error reduction (#5957)

  • Added proximal policy optimization template to pl_examples (#5394)

  • Added log_graph to CometLogger (#5295)

  • Added possibility for nested loaders (#5404)

  • Added sync_step to Wandb logger (#5351)

  • Added StochasticWeightAveraging callback (#5640)

  • Added LightningDataModule.from_datasets(...) (#5133)

  • Added PL_TORCH_DISTRIBUTED_BACKEND env variable to select backend (#5981)

  • Added Trainer flag to activate Stochastic Weight Averaging (SWA) Trainer(stochastic_weight_avg=True) (#6038)

  • Added DeepSpeed integration (#5954, #6042)
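
Two of the metric definitions above can be sketched in plain Python for the binary case: StatScores counts true/false positives and negatives, and HammingDistance is the fraction of label positions where prediction and target disagree, which for binary labels equals (FP + FN) / total. These helper functions are hypothetical illustrations, not the metrics package API, which also handles multi-class and multi-label inputs and thresholding of probabilities.

```python
from typing import Sequence, Tuple


def stat_scores(preds: Sequence[int], target: Sequence[int]) -> Tuple[int, int, int, int]:
    """Return (tp, fp, tn, fn) for binary 0/1 predictions and targets."""
    pairs = list(zip(preds, target))
    tp = sum(1 for p, t in pairs if p == 1 and t == 1)
    fp = sum(1 for p, t in pairs if p == 1 and t == 0)
    tn = sum(1 for p, t in pairs if p == 0 and t == 0)
    fn = sum(1 for p, t in pairs if p == 0 and t == 1)
    return tp, fp, tn, fn


def hamming_distance(preds: Sequence[int], target: Sequence[int]) -> float:
    """Fraction of positions where prediction and target differ."""
    if len(preds) != len(target) or not target:
        raise ValueError("preds and target must be non-empty and the same length")
    return sum(p != t for p, t in zip(preds, target)) / len(target)


preds = [1, 0, 1, 1, 0]
target = [1, 0, 0, 1, 1]
print(stat_scores(preds, target))       # (2, 1, 1, 1)
print(hamming_distance(preds, target))  # 0.4
```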

[1.2.0] - Changed
  • Changed stat_scores metric now calculates stat scores over all classes and gains new parameters, in line with the new StatScores metric (#4839)

  • Changed computer_vision_fine_tunning example to use BackboneLambdaFinetuningCallback (#5377)

  • Changed automatic casting for LoggerConnector metrics (#5218)

  • Changed iou [func] to allow float input (#4704)

  • Metric compute() method will no longer automatically call reset() (#5409)

  • Set PyTorch 1.4 as the minimum requirement; testing and examples also require torchvision>=0.5 and torchtext>=0.5 (#5418)

  • Changed callbacks argument in Trainer to allow Callback input (#5446)

  • Changed the default of find_unused_parameters to False in DDP (#5185)

  • Changed ModelCheckpoint version suffixes to start at 1 (#5008)

  • Progress bar metrics tensors are now converted to float (#5692)

  • Changed the default value for the progress_bar_refresh_rate Trainer argument in Google COLAB notebooks to 20 (#5516)

  • Extended support for purely iteration-based training (#5726)

  • Made LightningModule.global_rank, LightningModule.local_rank and LightningModule.logger read-only properties (#5730)

  • Forced ModelCheckpoint callbacks to run after all others to guarantee all states are saved to the checkpoint (#5731)

  • Refactored Accelerators and Plugins:

    • Added base classes for plugins (#5715)

    • Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod (#5714)

    • Precision Plugins (#5718)

    • Added new Accelerators for CPU, GPU and TPU (#5719)

    • Added RPC and Sharded plugins (#5732)

    • Added missing LightningModule-wrapper logic to new plugins and accelerator (#5734)

    • Moved device-specific teardown logic from training loop to accelerator (#5973)

    • Moved accelerator_connector.py to the connectors subfolder (#6033)

    • Trainer only references accelerator (#6039)

    • Made parallel devices optional across all plugins (#6051)

    • Cleaning (#5948, #5949, #5950)

  • Enabled self.log in callbacks (#5094)

  • Renamed xxx_AVAILABLE as protected (#5082)

  • Unified module names in Utils (#5199)

  • Separated utils: imports & enums (#5256, #5874)

  • Refactor: clean trainer device & distributed getters (#5300)

  • Simplified training phase as LightningEnum (#5419)

  • Updated metrics to use LightningEnum (#5689)

  • Changed the sequence of on_train_batch_end, on_batch_end & on_train_epoch_end, on_epoch_end hooks (#5688)

  • Refactored setup_training and removed test_mode (#5388)

  • Disabled training with zero num_training_batches when insufficient limit_train_batches (#5703)

  • Refactored EpochResultStore (#5522)

  • Update lr_finder to check for attribute if not running fast_dev_run (#5990)

  • LightningOptimizer manual optimizer is more flexible and exposes toggle_model (#5771)

  • MLFlowLogger now limits parameter values to 250 characters (#5893)

  • Re-introduced fix for Hydra directory sync with multiple process (#5993)
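
Since Metric compute() no longer calls reset() (listed above), a metric keeps accumulating until reset() is called explicitly, for example at the end of an epoch. A minimal sketch of that update/compute/reset contract; RunningMean is a hypothetical class, not the library's Metric base class:

```python
class RunningMean:
    """Hypothetical metric following an update/compute/reset contract."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def compute(self) -> float:
        # Reads the accumulated state without clearing it.
        if self.count == 0:
            raise ValueError("compute() called before any update()")
        return self.total / self.count


metric = RunningMean()
metric.update(2.0)
metric.update(4.0)
print(metric.compute())  # 3.0
metric.update(6.0)       # state survived the compute() call above
print(metric.compute())  # 4.0
metric.reset()           # explicit reset, e.g. at epoch end
```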

[1.2.0] - Deprecated
  • Function stat_scores_multiple_classes is deprecated in favor of stat_scores (#4839)

  • Moved accelerators and plugins to their legacy package (#5645)

  • Deprecated LightningDistributedDataParallel in favor of new wrapper module LightningDistributedModule (#5185)

  • Deprecated LightningDataParallel in favor of new wrapper module LightningParallelModule (#5670)

  • Renamed utils modules (#5199)

    • argparse_utils >> argparse

    • model_utils >> model_helpers

    • warning_utils >> warnings

    • xla_device_utils >> xla_device

  • Deprecated using 'val_loss' to set the ModelCheckpoint monitor (#6012)

  • Deprecated .get_model() in favor of the explicit .lightning_module property (#6035)

  • Deprecated Trainer attribute accelerator_backend in favor of accelerator (#6034)

[1.2.0] - Removed
  • Removed deprecated checkpoint argument filepath (#5321)

  • Removed deprecated Fbeta, f1_score and fbeta_score metrics (#5322)

  • Removed deprecated TrainResult (#5323)

  • Removed deprecated EvalResult (#5633)

  • Removed LoggerStages (#5673)

[1.2.0] - Fixed
  • Fixed distributed setting and ddp_cpu only with num_processes>1 (#5297)

  • Fixed num_workers for Windows example (#5375)

  • Fixed loading yaml (#5619)

  • Fixed support for custom DataLoaders with DDP if they can be re-instantiated (#5745)

  • Fixed repeated .fit() calls ignoring the max_steps iteration bound (#5936)

  • Fixed throwing MisconfigurationError on unknown mode (#5255)

  • Resolve bug with Finetuning (#5744)

  • Fixed ModelCheckpoint race condition in file existence check (#5155)

  • Fixed some compatibility with PyTorch 1.8 (#5864)

  • Fixed forward cache (#5895)

  • Fixed recursive detach of tensors to CPU (#6007)

  • Fixed passing invalid strings for the scheduler interval not throwing an error (#5923)

  • Fixed wrong requires_grad state after return None with multiple optimizers (#5738)

  • Fixed: added the on_epoch_end hook at the end of validation and test epochs (#5986)

  • Fixed missing process_dataloader call for TPUSpawn when in distributed mode (#6015)

  • Fixed progress bar flickering by appending 0 to floats/strings (#6009)

  • Fixed synchronization issues with TPU training (#6027)

  • Fixed hparams.yaml saved twice when using TensorBoardLogger (#5953)

  • Fixed basic examples (#5912, #5985)

  • Fixed fairscale compatibility with PyTorch 1.8 (#5996)

  • Ensured process_dataloader is called when tpu_cores > 1 to use Parallel DataLoader (#6015)

  • Attempted SLURM auto resume call when non-shell call fails (#6002)

  • Fixed wrapping optimizers upon assignment (#6006)

  • Fixed allowing hashing of metrics with lists in their state (#5939)

[1.1.8] - 2021-02-08

[1.1.8] - Fixed
  • Separate epoch validation from step validation (#5208)

  • Fixed toggle_optimizers not handling all optimizer parameters (#5775)

[1.1.7] - 2021-02-03

[1.1.7] - Fixed
  • Fixed TensorBoardLogger not closing SummaryWriter on finalize (#5696)

  • Fixed filtering of pytorch “unsqueeze” warning when using DP (#5622)

  • Fixed num_classes argument in F1 metric (#5663)

  • Fixed log_dir property (#5537)

  • Fixed a race condition in ModelCheckpoint when checking if a checkpoint file exists (#5144)

  • Remove unnecessary intermediate layers in Dockerfiles (#5697)

  • Fixed auto learning rate ordering (#5638)

[1.1.6] - 2021-01-26

[1.1.6] - Changed
  • Increased TPU check timeout from 20s to 100s (#5598)

  • Ignored step param in Neptune logger’s log_metric method (#5510)

  • Pass batch outputs to on_train_batch_end instead of epoch_end outputs (#4369)

[1.1.6] - Fixed
  • Fixed toggle_optimizer to reset requires_grad state (#5574)

  • Fixed FileNotFoundError for best checkpoint when using DDP with Hydra (#5629)

  • Fixed an error when logging a progress bar metric with a reserved name (#5620)

  • Fixed Metric’s state_dict not being included when metrics are child modules (#5614)

  • Fixed Neptune logger creating multiple experiments when GPUs > 1 (#3256)

  • Fixed duplicate logs appearing in console when using the python logging module (#5509)

  • Fixed tensor printing in trainer.test() (#5138)

  • Fixed not using dataloader when hparams present (#4559)

[1.1.5] - 2021-01-19

[1.1.5] - Fixed
  • Fixed a visual bug in the progress bar display initialization (#4579)

  • Fixed logging on_train_batch_end in a callback with multiple optimizers (#5521)

  • Fixed reinit_scheduler_properties with correct optimizer (#5519)

  • Fixed val_check_interval with fast_dev_run (#5540)

[1.1.4] - 2021-01-12

[1.1.4] - Added
  • Added automatic optimization property setter to LightningModule (#5169)

[1.1.4] - Changed
  • Changed deprecated enable_pl_optimizer=True (#5244)

[1.1.4] - Fixed
  • Fixed transfer_batch_to_device for DDP with len(devices_ids) == 1 (#5195)

  • Log only when not should_accumulate() during training (#5417)

  • Resolve interpolation bug with Hydra (#5406)

  • Check environ before selecting a seed to prevent warning message (#4743)

  • Fixed signature mismatch in model_to_device of DDPCPUHPCAccelerator (#5505)

[1.1.3] - 2021-01-05

[1.1.3] - Added
  • Added a check for optimizer attached to lr_scheduler (#5338)

  • Added support for passing non-existing filepaths to resume_from_checkpoint (#4402)

[1.1.3] - Changed
  • Skip restore from resume_from_checkpoint while testing (#5161)

  • Allowed log_momentum for adaptive optimizers in LearningRateMonitor (#5333)

  • Disabled checkpointing, earlystopping and logging with fast_dev_run (#5277)

  • Distributed group defaults to WORLD if None (#5125)

[1.1.3] - Fixed
  • Fixed trainer.test returning non-test metrics (#5214)

  • Fixed metric state reset (#5273)

  • Fixed --num-nodes on DDPSequentialPlugin (#5327)

  • Fixed invalid value for weights_summary (#5296)

  • Fixed Trainer.test not using the latest best_model_path (#5161)

  • Fixed existence check for hparams not using underlying filesystem (#5250)

  • Fixed LightningOptimizer AMP bug (#5191)

  • Fixed casting of keys to strings in _flatten_dict (#5354)

[1.1.2] - 2020-12-23

[1.1.2] - Added
  • Support number for logging with sync_dist=True (#5080)

  • Added offset logging step when resuming for Wandb logger (#5050)

[1.1.2] - Removed
  • enable_pl_optimizer=False by default to temporarily fix AMP issues (#5163)

[1.1.2] - Fixed
  • Metric reduction with Logging (#5150)

  • Remove nan loss in manual optimization (#5121)

  • Un-balanced logging properly supported (#5119)

  • Fix hanging in DDP HPC accelerators (#5157)

  • Fix reset TensorRunningAccum (#5106)

  • Updated DALIClassificationLoader to not use deprecated arguments (#4925)

  • Corrected call to torch.no_grad (#5124)

[1.1.1] - 2020-12-15

[1.1.1] - Added
  • Add a notebook example to reach a quick baseline of ~94% accuracy on CIFAR10 using Resnet in Lightning (#4818)

[1.1.1] - Changed
  • Simplify accelerator steps (#5015)

  • Refactor load in checkpoint connector (#4593)

  • Fixed the saved filename in ModelCheckpoint when it already exists (#4861)

[1.1.1] - Removed
  • Drop duplicate metrics (#5014)

  • Remove beta arg from F1 class and functional (#5076)

[1.1.1] - Fixed
  • Fixed trainer by default None in DDPAccelerator (#4915)

  • Fixed LightningOptimizer to expose optimizer attributes (#5095)

  • Do not warn when the name key is used in the lr_scheduler dict (#5057)

  • Check if optimizer supports closure (#4981)

  • Add deprecated metric utility functions back to functional (#5067, #5068)

  • Allow any input in to_onnx and to_torchscript (#4378)

  • Fixed DDPHPCAccelerator hangs in DDP construction by calling init_device (#5157)

[1.1.0] - 2020-12-09

[1.1.0] - Added
  • Added “monitor” key to saved ModelCheckpoints (#4383)

  • Added ConfusionMatrix class interface (#4348)

  • Added multiclass AUROC metric (#4236)

  • Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience (#3807)

  • Added optimizer hooks in callbacks (#4379)

  • Added option to log momentum (#4384)

  • Added current_score to ModelCheckpoint.on_save_checkpoint (#4721)

  • Added logging using self.log in train and evaluation for epoch end hooks (#4552, #4495, #4439, #4684, #4913)

  • Added ability for DDP plugin to modify optimizer state saving (#4675)

  • Added prefix argument in loggers (#4557)

  • Added printing of total num of params, trainable and non-trainable params in ModelSummary (#4521)

  • Added PrecisionRecallCurve, ROC, AveragePrecision class metric (#4549)

  • Added custom Apex and NativeAMP as Precision plugins (#4355)

  • Added DALI MNIST example (#3721)

  • Added sharded plugin for DDP for multi-gpu training memory optimizations (#4639, #4686, #4737, #4773)

  • Added experiment_id to the NeptuneLogger (#3462)

  • Added PyTorch Geometric integration example with Lightning (#4568)

  • Added all_gather method to LightningModule which allows gradient based tensor synchronizations for use-cases such as negative sampling. (#5012)

  • Enabled self.log in most functions (#4969)

  • Added changeable extension variable for ModelCheckpoint (#4977)

[1.1.0] - Changed
  • Tuner algorithms will be skipped if fast_dev_run=True (#3903)

  • WandbLogger does not force wandb reinit arg to True anymore and creates a run only when needed (#4648)

  • Changed automatic_optimization to be a model attribute (#4602)

  • Changed Simple Profiler report to order by percentage time spent + num calls (#4880)

  • Simplify optimization Logic (#4984)

  • Classification metrics overhaul (#4837)

  • Updated fast_dev_run to accept integer representing num_batches (#4629)

  • Refactored optimizer (#4658)

[1.1.0] - Deprecated
  • Deprecated prefix argument in ModelCheckpoint (#4765)

  • Deprecated the old way of assigning hyper-parameters through self.hparams = ... (#4813)

  • Deprecated mode='auto' from ModelCheckpoint and EarlyStopping (#4695)

[1.1.0] - Removed
  • Removed reorder parameter of the auc metric (#5004)

  • Removed multiclass_roc and multiclass_precision_recall_curve, use roc and precision_recall_curve instead (#4549)

[1.1.0] - Fixed
  • Added feature to move tensors to CPU before saving (#4309)

  • Fixed LoggerConnector to have logged metrics on root device in DP (#4138)

  • Auto convert tensors to contiguous format when gather_all (#4907)

  • Fixed PYTHONPATH for ddp test model (#4528)

  • Fixed allowing logger to support indexing (#4595)

  • Fixed DDP and manual_optimization (#4976)

[1.0.8] - 2020-11-24

[1.0.8] - Added
  • Added casting to python types for numpy scalars when logging hparams (#4647)

  • Added warning when progress bar refresh rate is less than 20 on Google Colab to prevent crashing (#4654)

  • Added F1 class metric (#4656)

[1.0.8] - Changed
  • Consistently use step=trainer.global_step in LearningRateMonitor independently of logging_interval (#4376)

  • Metric states are no longer added to state_dict by default (#4685)

  • Renamed class metric Fbeta >> FBeta (#4656)

  • Model summary: add 1 decimal place (#4745)

  • Do not override PYTHONWARNINGS (#4700)

  • Moved init_ddp_connection from DDP to DDPPlugin (#4407)

[1.0.8] - Fixed
  • Fixed checkpoint hparams dict casting when omegaconf is available (#4770)

  • Fixed incomplete progress bars when total batches not divisible by refresh rate (#4577)

  • Updated SSIM metric (#4566)

  • Fixed a batch_arg_name bug by adding batch_arg_name to all calls to _adjust_batch_size (#4812)

  • Fixed torchtext data to GPU (#4785)

  • Fixed a crash bug in MLFlow logger (#4716)
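
The incomplete-progress-bar fix above concerns refresh rates that do not evenly divide the batch count. One way to avoid a bar that stalls short of the total, sketched with a hypothetical helper that lists the batch counts at which a bar would redraw, is to always redraw on the last batch. This is an illustration under that assumption, not Lightning's progress bar code:

```python
from typing import List


def refresh_points(total_batches: int, refresh_rate: int) -> List[int]:
    """Batch counts at which a progress bar redraws (hypothetical helper)."""
    if refresh_rate <= 0:
        return []  # assumption: a non-positive refresh rate disables the bar
    return [
        i
        for i in range(1, total_batches + 1)
        if i % refresh_rate == 0 or i == total_batches  # always show the final count
    ]


print(refresh_points(10, 3))  # [3, 6, 9, 10]
print(refresh_points(9, 3))   # [3, 6, 9]
```

Without the i == total_batches clause, a 10-batch epoch with a refresh rate of 3 would leave the bar showing 9/10.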

[1.0.7] - 2020-11-17

[1.0.7] - Added
  • Added lambda closure to manual_optimizer_step (#4618)

[1.0.7] - Changed
  • Change Metrics persistent default mode to False (#4685)

  • LoggerConnector log_metrics will use total_batch_idx instead of global_step when logging on training step (#4738)

[1.0.7] - Fixed
  • Prevent crash if sync_dist=True on CPU (#4626)

  • Fixed average pbar Metrics (#4534)

  • Fixed setup callback hook to correctly pass the LightningModule through (#4608)

  • Allowed decorating the model init when hparams are saved inside it (#4662)

  • Fixed split_idx set by LoggerConnector in on_trainer_init to Trainer (#4697)

[1.0.6] - 2020-11-11

[1.0.6] - Added
  • Added metrics aggregation in Horovod and fixed early stopping (#3775)

  • Added manual_optimizer_step, which works with native AMP and accumulate_grad_batches (#4485)

  • Added persistent(mode) method to metrics, to enable and disable metric states being added to state_dict (#4482)

  • Added congratulations at the end of our notebooks (#4555)

  • Added parameter move_metrics_to_cpu to Trainer to prevent GPU memory leaks (#4592)

[1.0.6] - Changed
[1.0.6] - Fixed
  • Fixed missing feature in hpc_load (#4526)

  • Fixed metrics states being overridden in DDP mode (#4482)

  • Fixed lightning_getattr, lightning_hasattr not finding the correct attributes in datamodule (#4347)

  • Fixed automatic optimization AMP by manual_optimization_step (#4485)

  • Replaced MisconfigurationException with warning in ModelCheckpoint Callback (#4560)

  • Fixed logged keys in mlflow logger (#4412)

  • Fixed is_picklable by catching AttributeError (#4508)

  • Fixed AttributeError with a dict of multiple test dataloaders (#4480)

  • Fixed progress bar to show only for progress_rank 0 on DDP_SLURM (#4437)

[1.0.5] - 2020-11-03

[1.0.5] - Added
  • Added PyTorch 1.7 Stable support (#3821)

  • Added timeout for tpu_device_exists to ensure process does not hang indefinitely (#4340)

[1.0.5] - Changed
  • W&B log in sync with Trainer step (#4405)

  • Hook on_after_backward is called only when optimizer_step is being called (#4439)

  • Moved track_and_norm_grad into training loop and called only when optimizer_step is being called (#4439)

  • Changed type checker with explicit cast of ref_model object (#4457)

  • Changed distributed_backend -> accelerator (#4429)

[1.0.5] - Deprecated
  • Deprecated passing ModelCheckpoint instance to checkpoint_callback Trainer argument (#4336)

[1.0.5] - Fixed
  • Disabled saving checkpoints if not trained (#4372)

  • Fixed error using auto_select_gpus=True with gpus=-1 (#4209)

  • Disabled training when limit_train_batches=0 (#4371)

  • Fixed metrics so they no longer store the computational graph for all seen data (#4313)

  • Fixed AMP unscale for on_after_backward (#4439)

  • Fixed TorchScript export when module includes Metrics (#4428)

  • Fixed TorchScript trace method’s data to device and docstring (#4360)

  • Fixed CSV logger warning (#4419)

  • Fixed skip DDP parameter sync (#4301)

  • Fixed WandbLogger _sanitize_callable function (#4422)

  • Fixed AMP Native _unscale gradient (#4441)

[1.0.4] - 2020-10-27

[1.0.4] - Added
  • Added dirpath and filename parameter in ModelCheckpoint (#4213)

  • Added plugins docs and DDPPlugin to customize ddp across all accelerators (#4258)

  • Added strict option to the scheduler dictionary (#3586)

  • Added fsspec support for profilers (#4162)

  • Added autogenerated helptext to Trainer.add_argparse_args (#4344)

  • Added support for string values in Trainer’s profiler parameter (#3656)

  • Added optimizer_closure to optimizer.step when supported (#4190)

  • Added unification of regression metrics (#4166)

  • Added loading checkpoints from bytes (#4314)

[1.0.4] - Changed
  • Improved error messages for invalid configure_optimizers returns (#3587)

  • Allow changing the logged step value in validation_step (#4130)

  • Allow setting replace_sampler_ddp=True with a distributed sampler already added (#4273)

  • Fixed sanitized parameters for WandbLogger.log_hyperparams (#4320)

[1.0.4] - Deprecated
  • Deprecated filepath in ModelCheckpoint (#4213)

  • Deprecated reorder parameter of the auc metric (#4237)

  • Deprecated bool values in Trainer’s profiler parameter (#3656)

[1.0.4] - Fixed
  • Fixed setting device ids in DDP (#4297)

  • Fixed synchronization of best model path in ddp_accelerator (#4323)

  • Fixed WandbLogger not uploading checkpoint artifacts at the end of training (#4341)

  • Fixed FBeta computation (#4183)

  • Fixed accumulation across batches so it completes before breaking the training loop (#4278)

  • Fixed ModelCheckpoint so current_epoch and global_step are not increased when not training (#4291)

  • Fixed COMET_EXPERIMENT_KEY environment variable usage in comet logger (#4230)

[1.0.3] - 2020-10-20

[1.0.3] - Added
  • Added persistent flag to Metric.add_state (#4195)

[1.0.3] - Changed
  • Used checkpoint_connector.hpc_save in SLURM (#4217)

  • Moved base requirements to root (#4219)

[1.0.3] - Fixed
  • Fixed hparams assign in init (#4189)

  • Fixed overwrite check for model hooks (#4010)

[1.0.2] - 2020-10-15

[1.0.2] - Added
  • Added trace functionality to the function to_torchscript (#4142)

[1.0.2] - Changed
  • Called on_load_checkpoint before loading state_dict (#4057)

[1.0.2] - Removed
  • Removed duplicate metric vs step log for train loop (#4173)

[1.0.2] - Fixed
  • Fixed the self.log problem in validation_step() (#4169)

  • Fixed hparams saving - save the state when save_hyperparameters() is called [in __init__] (#4163)

  • Fixed runtime failure while exporting hparams to yaml (#4158)

[1.0.1] - 2020-10-14

[1.0.1] - Added
  • Added getstate/setstate method for torch.save serialization (#4127)

[1.0.0] - 2020-10-13

[1.0.0] - Added
  • Added Explained Variance Metric + metric fix (#4013)

  • Added Metric <-> Lightning Module integration tests (#4008)

  • Added parsing OS env vars in Trainer (#4022)

  • Added classification metrics (#4043)

  • Updated explained variance metric (#4024)

  • Enabled plugins (#4041)

  • Enabled custom clusters (#4048)

  • Enabled passing in custom accelerators (#4050)

  • Added LightningModule.toggle_optimizer (#4058)

  • Added LightningModule.manual_backward (#4063)

  • Added output argument to *_batch_end hooks (#3965, #3966)

  • Added output argument to *_epoch_end hooks (#3967)

[1.0.0] - Changed
[1.0.0] - Removed
  • Removed support for EvalResult and TrainResult (#3968)

  • Removed deprecated trainer flags: overfit_pct, log_save_interval, row_log_interval (#3969)

  • Removed deprecated early_stop_callback (#3982)

  • Removed deprecated model hooks (#3980)

  • Removed deprecated callbacks (#3979)

  • Removed trainer argument in LightningModule.backward (#4056)

[1.0.0] - Fixed
  • Fixed current_epoch property update to reflect true epoch number inside LightningDataModule, when reload_dataloaders_every_epoch=True. (#3974)

  • Fixed printing scalar values in progress bar (#4053)

  • Fixed mismatch between docstring and code regarding when on_load_checkpoint hook is called (#3996)

[0.10.0] - 2020-10-07

[0.10.0] - Added
  • Added new Metrics API. (#3868, #3921)

  • Enabled PyTorch 1.7 compatibility (#3541)

  • Added LightningModule.to_torchscript to support exporting as ScriptModule (#3258)

  • Added warning when dropping unpicklable hparams (#2874)

  • Added EMB similarity (#3349)

  • Added ModelCheckpoint.to_yaml method (#3048)

  • Allow ModelCheckpoint monitor to be None, meaning it will always save (#3630)

  • Disabled optimizers setup during testing (#3059)

  • Added support for datamodules to save and load checkpoints when training (#3563)

  • Added support for datamodule in learning rate finder (#3425)

  • Added gradient clip test for native AMP (#3754)

  • Added dist lib to enable syncing anything across devices (#3762)

  • Added broadcast to TPUBackend (#3814)

  • Added XLADeviceUtils class to check XLA device type (#3274)

[0.10.0] - Changed
  • Refactored accelerator backends:

    • moved TPU xxx_step to backend (#3118)

    • refactored DDP backend forward (#3119)

    • refactored GPU backend __step (#3120)

    • refactored Horovod backend (#3121, #3122)

    • remove obscure forward call in eval + CPU backend ___step (#3123)

    • reduced all simplified forward (#3126)

    • added hook base method (#3127)

    • refactor eval loop to use hooks - use test_mode for if so we can split later (#3129)

    • moved ___step_end hooks (#3130)

    • training forward refactor (#3134)

    • training AMP scaling refactor (#3135)

    • eval step scaling factor (#3136)

    • add eval loop object to streamline eval loop (#3138)

    • refactored dataloader process hook (#3139)

    • refactored inner eval loop (#3141)

    • final inner eval loop hooks (#3154)

    • clean up hooks in run_evaluation (#3156)

    • clean up data reset (#3161)

    • expand eval loop out (#3165)

    • moved hooks around in eval loop (#3195)

    • remove _evaluate fx (#3197)

    • Trainer.fit hook clean up (#3198)

    • DDP train hooks (#3203)

    • refactor DDP backend (#3204, #3207, #3208, #3209, #3210)

    • reduced accelerator selection (#3211)

    • group prepare data hook (#3212)

    • added data connector (#3285)

    • modular is_overridden (#3290)

    • adding Trainer.tune() (#3293)

    • move run_pretrain_routine -> setup_training (#3294)

    • move train outside of setup training (#3297)

    • move prepare_data to data connector (#3307)

    • moved accelerator router (#3309)

    • train loop refactor - moving train loop to own object (#3310, #3312, #3313, #3314)

    • duplicate data interface definition up into DataHooks class (#3344)

    • inner train loop (#3359, #3361, #3362, #3363, #3365, #3366, #3367, #3368, #3369, #3370, #3371, #3372, #3373, #3374, #3375, #3376, #3385, #3388, #3397)

    • all logging related calls in a connector (#3395)

    • device parser (#3400, #3405)

    • added model connector (#3407)

    • moved eval loop logging to loggers (#3408)

    • moved eval loop (#3412, #3408)

    • trainer/separate argparse (#3421, #3428, #3432)

    • move lr_finder (#3434)

    • organize args (#3435, #3442, #3447, #3448, #3449, #3456)

    • move specific accelerator code (#3457)

    • group connectors (#3472)

    • accelerator connector methods x/n (#3469, #3470, #3474)

    • merge backends x/n (#3476, #3477, #3478, #3480, #3482)

    • apex plugin (#3502)

    • precision plugins (#3504)

    • Result - make monitor default to checkpoint_on to simplify (#3571)

    • reference to the Trainer on the LightningDataModule (#3684)

    • add .log to lightning module (#3686, #3699, #3701, #3704, #3715)

    • enable tracking original metric when step and epoch are both true (#3685)

    • deprecated results obj, added support for simpler comms (#3681)

    • move backends back to individual files (#3712)

    • fixes logging for eval steps (#3763)

    • decoupled DDP, DDP spawn (#3733, #3766, #3767, #3774, #3802, #3806, #3817, #3819, #3927)

    • remove weight loading hack for ddp_cpu (#3808)

    • separate torchelastic from DDP (#3810)

    • separate SLURM from DDP (#3809)

    • decoupled DDP2 (#3816)

    • bug fix with logging val epoch end + monitor (#3812)

    • callback system and init DDP (#3836)

    • adding compute environments (#3837, #3842)

    • epoch can now log independently (#3843)

    • test selecting the correct backend. temp backends while slurm and TorchElastic are decoupled (#3848)

    • fixed init_slurm_connection causing hostname errors (#3856)

    • moves init apex from LM to apex connector (#3923)

    • moves sync bn to each backend (#3925)

    • moves configure ddp to each backend (#3924)

  • Deprecation warning (#3844)

  • Changed LearningRateLogger to LearningRateMonitor (#3251)

  • Used fsspec instead of gfile for all IO (#3320)

    • Swapped torch.load for fsspec load in DDP spawn backend (#3787)

    • Swapped torch.load for fsspec load in cloud_io loading (#3692)

    • Added support for to_disk() to use remote filepaths with fsspec (#3930)

    • Updated model_checkpoint’s to_yaml to use fsspec open (#3801)

    • Fixed fsspec is inconsistent when doing fs.ls (#3805)

  • Refactor GPUStatsMonitor to improve training speed (#3257)

  • Changed IoU score behavior for classes absent in target and pred (#3098)

  • Changed IoU remove_bg bool to ignore_index optional int (#3098)

  • Changed defaults of save_top_k and save_last to None in ModelCheckpoint (#3680)

  • row_log_interval and log_save_interval are now based on training loop’s global_step instead of epoch-internal batch index (#3667)

  • Silenced some warnings; verified DDP refactors (#3483)

  • Cleaning up stale logger tests (#3490)

  • Allow ModelCheckpoint monitor to be None (#3633)

  • Enable None model checkpoint default (#3669)

  • Skipped best_model_path if checkpoint_callback is None (#2962)

  • Used raise .. from .. to explicitly chain exceptions (#3750)

  • Mocking loggers (#3596, #3617, #3851, #3859, #3884, #3853, #3910, #3889, #3926)

  • Write predictions in LightningModule instead of EvalResult (#3882)

[0.10.0] - Deprecated
  • Deprecated TrainResult and EvalResult, use self.log and self.write from the LightningModule to log metrics and write predictions. training_step can now only return a scalar (for the loss) or a dictionary with anything you want. (#3681)

  • Deprecated early_stop_callback Trainer argument (#3845)

  • Renamed Trainer arguments row_log_interval >> log_every_n_steps and log_save_interval >> flush_logs_every_n_steps (#3748)

[0.10.0] - Removed
  • Removed experimental Metric API (#3943, #3949, #3946), listed changes before final removal:

    • Added EmbeddingSimilarity metric (#3349, #3358)

    • Added hooks to metric module interface (#2528)

    • Added error when AUROC metric is used for multiclass problems (#3350)

    • Fixed ModelCheckpoint with save_top_k=-1 option not tracking the best models when a monitor metric is available (#3735)

    • Fixed counter-intuitive error being thrown in Accuracy metric for zero target tensor (#3764)

    • Fixed aggregation of metrics (#3517)

    • Fixed Metric aggregation (#3321)

    • Fixed RMSLE metric (#3188)

    • Renamed reduction to class_reduction in classification metrics (#3322)

    • Changed class_reduction similar to sklearn for classification metrics (#3322)

    • Renaming of precision recall metric (#3308)

[0.10.0] - Fixed
  • Fixed on_train_batch_start hook to end epoch early (#3700)

  • Fixed num_sanity_val_steps to be clipped to limit_val_batches (#2917)

  • Fixed ONNX model save on GPU (#3145)

  • Fixed GpuUsageLogger to work on different platforms (#3008)

  • Fixed auto-scale batch size not dumping auto_lr_find parameter (#3151)

  • Fixed batch_outputs with optimizer frequencies (#3229)

  • Fixed setting batch size in LightningModule.datamodule when using auto_scale_batch_size (#3266)

  • Fixed Horovod distributed backend compatibility with native AMP (#3404)

  • Fixed batch size auto scaling exceeding the size of the dataset (#3271)

  • Fixed getting experiment_id from MLFlow only once instead of each training loop (#3394)

  • Fixed overfit_batches which now correctly disables shuffling for the training loader. (#3501)

  • Fixed gradient norm tracking for row_log_interval > 1 (#3489)

  • Fixed ModelCheckpoint name formatting (#3164)

  • Fixed example implementation of AutoEncoder (#3190)

  • Fixed invalid paths when remote logging with TensorBoard (#3236)

  • Changed t() to transpose(), as XLA devices do not support .t() on 1-dim tensors (#3252)

  • Fixed (weights only) checkpoints loading without PL (#3287)

  • Fixed gather_all_tensors cross GPUs in DDP (#3319)

  • Fixed CometML save dir (#3419)

  • Fixed forward key metrics (#3467)

  • Fixed normalize mode at confusion matrix (replace NaNs with zeros) (#3465)

  • Fixed global step increment in training loop when training_epoch_end hook is used (#3673)

  • Fixed dataloader shuffling not getting turned off with overfit_batches > 0 and distributed_backend = "ddp" (#3534)

  • Fixed determinism in DDPSpawnBackend when using seed_everything in main process (#3335)

  • Fixed ModelCheckpoint period to actually save every period epochs (#3630)

  • Fixed val_progress_bar total with num_sanity_val_steps (#3751)

  • Fixed Tuner dump: add current_epoch to dumped_params (#3261)

  • Fixed current_epoch and global_step properties mismatch between Trainer and LightningModule (#3785)

  • Fixed learning rate scheduler for optimizers with internal state (#3897)

  • Fixed tbptt_reduce_fx when non-floating tensors are logged (#3796)

  • Fixed model checkpoint frequency (#3852)

  • Fixed logging a non-tensor scalar with result breaking subsequent epoch aggregation (#3855)

  • Fixed TrainerEvaluationLoopMixin to activate model.train() at the end (#3858)

  • Fixed overfit_batches when using with multiple val/test_dataloaders (#3857)

  • Enabled training_step to return None (#3862)

  • Fixed init nan for checkpointing (#3863)

  • Fixed load_from_checkpoint (#2776)

  • Fixed incorrect batch_sizes when Dataloader returns a dict with multiple tensors (#3668)

  • Fixed unexpected signature for validation_step (#3947)

[0.9.0] - 2020-08-20

[0.9.0] - Added
  • Added SyncBN for DDP (#2801, #2838)

  • Added basic CSVLogger (#2721)

  • Added SSIM metrics (#2671)

  • Added BLEU metrics (#2535)

  • Added support to export a model to ONNX format (#2596)

  • Added support for Trainer(num_sanity_val_steps=-1) to check all validation data before training (#2246)

  • Added structured output:

    • tests for val loop flow (#2605)

    • EvalResult support for train and val. loop (#2615, #2651)

    • weighted average in results obj (#2930)

    • fix result obj DP auto reduce (#3013)

  • Added class LightningDataModule (#2668)

  • Added support for PyTorch 1.6 (#2745)

  • Added implicit calls to DataModule hooks in Trainer (#2755)

  • Added support for Mean in DDP Sync (#2568)

  • Added remaining sklearn metrics: AveragePrecision, BalancedAccuracy, CohenKappaScore, DCG, Hamming, Hinge, Jaccard, MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError, MedianAbsoluteError, R2Score, MeanPoissonDeviance, MeanGammaDeviance, MeanTweedieDeviance, ExplainedVariance (#2562)

  • Added support for limit_{mode}_batches (int) to work with infinite dataloader (IterableDataset) (#2840)

  • Added support for returning Python scalars in DP (#1935)

  • Added support to Tensorboard logger for OmegaConf hparams (#2846)

  • Added tracking of basic states in Trainer (#2541)

  • Tracks all outputs including TBPTT and multiple optimizers (#2890)

  • Added GPU Usage Logger (#2932)

  • Added strict=False for load_from_checkpoint (#2819)

  • Added saving test predictions on multiple GPUs (#2926)

  • Auto log the computational graph for loggers that support this (#3003)

  • Added warning when changing monitor and using results obj (#3014)

  • Added a hook transfer_batch_to_device to the LightningDataModule (#3038)

[0.9.0] - Changed
  • Truncated long version numbers in progress bar (#2594)

  • Enabled disabling of val/test loops (#2692)

  • Refactored into accelerator module:

    • GPU training (#2704)

    • TPU training (#2708)

    • DDP(2) backend (#2796)

    • Retrieve last logged val from result by key (#3049)

  • Using .comet.config file for CometLogger (#1913)

  • Updated hooks arguments - breaking for setup and teardown (#2850)

  • Using gfile to support remote directories (#2164)

  • Moved optimizer creation after device placement for DDP backends (#2904)

  • Support **DictConfig for hparam serialization (#2519)

  • Removed callback metrics from test results obj (#2994)

  • Re-enabled naming metrics in ckpt name (#3060)

  • Changed progress bar epoch counting to start from 0 (#3061)

[0.9.0] - Deprecated
  • Deprecated Trainer attribute ckpt_path, which will now be set by weights_save_path (#2681)

[0.9.0] - Removed
  • Removed deprecated: (#2760)

    • core decorator data_loader

    • Module hook on_sanity_check_start and loading load_from_metrics

    • package pytorch_lightning.logging

    • Trainer arguments: show_progress_bar, num_tpu_cores, use_amp, print_nan_grads

    • LR Finder argument num_accumulation_steps

[0.9.0] - Fixed
  • Fixed accumulate_grad_batches for last batch (#2853)

  • Fixed setup call while testing (#2624)

  • Fixed local rank zero casting (#2640)

  • Fixed single scalar return from training (#2587)

  • Fixed Horovod backend to scale LR schedulers with the optimizer (#2626)

  • Fixed dtype and device properties not getting updated in submodules (#2657)

  • Fixed fast_dev_run to run for all dataloaders (#2581)

  • Fixed save_dir in loggers getting ignored by default value of weights_save_path when user did not specify weights_save_path (#2681)

  • Fixed weights_save_path getting ignored when logger=False is passed to Trainer (#2681)

  • Fixed TPU multi-core and Float16 (#2632)

  • Fixed test metrics not being logged with LoggerCollection (#2723)

  • Fixed data transfer to device when using torchtext.data.Field and include_lengths is True (#2689)

  • Fixed shuffle argument for distributed sampler (#2789)

  • Fixed logging interval (#2694)

  • Fixed wrong loss value in the progress bar when accumulate_grad_batches > 1 (#2738)

  • Fixed CWD for ddp sub-processes when using Hydra (#2719)

  • Fixed selecting GPUs using CUDA_VISIBLE_DEVICES (#2739)

  • Fixed false num_classes warning in metrics (#2781)

  • Fixed shell injection vulnerability in subprocess call (#2786)

  • Fixed LR finder and hparams compatibility (#2821)

  • Fixed ModelCheckpoint not saving the latest information when save_last=True (#2881)

  • Fixed ImageNet example: learning rate scheduler, number of workers and batch size when using DDP (#2889)

  • Fixed apex gradient clipping (#2829)

  • Fixed saving apex scaler states (#2828)

  • Fixed a model loading issue with inheritance and variable positional arguments (#2911)

  • Fixed passing non_blocking=True when transferring a batch object that does not support it (#2910)

  • Fixed checkpointing to remote file paths (#2925)

  • Fixed adding val step argument to metrics (#2986)

  • Fixed an issue that caused Trainer.test() to stall in ddp mode (#2997)

  • Fixed gathering of results with tensors of varying shape (#3020)

  • Fixed batch size auto-scaling feature to set the new value on the correct model attribute (#3043)

  • Fixed automatic batch scaling not working with half precision (#3045)

  • Fixed setting device to root gpu (#3042)

[0.8.5] - 2020-07-09

[0.8.5] - Added
  • Added a PSNR metric: peak signal-to-noise ratio (#2483)

  • Added functional regression metrics (#2492)

[0.8.5] - Removed
  • Removed auto val reduce (#2462)

[0.8.5] - Fixed
  • Flattened Wandb hyperparameters (#2459)

  • Fixed using the same DDP python interpreter and actually running (#2482)

  • Fixed model summary input type conversion for models that have input dtype different from model parameters (#2510)

  • Made TensorBoardLogger and CometLogger pickleable (#2518)

  • Fixed a problem with MLflowLogger creating multiple run folders (#2502)

  • Fixed global_step increment (#2455)

  • Fixed TPU hanging example (#2488)

  • Fixed argparse default value bug (#2526)

  • Fixed Dice and IoU to avoid NaN by adding small eps (#2545)

  • Fixed accumulate gradients schedule at epoch 0 (continued) (#2513)

  • Fixed Trainer .fit() returning last not best weights in “ddp_spawn” (#2565)

  • Fixed passing (do not pass) TPU weights back on test (#2566)

  • Fixed DDP tests and .test() (#2512, #2570)

[0.8.4] - 2020-07-01

[0.8.4] - Added
  • Added reduction of DDP results on eval (#2434)

  • Added a warning when an IterableDataset has __len__ defined (#2437)

[0.8.4] - Changed
  • Enabled no returns from eval (#2446)

[0.8.4] - Fixed
  • Fixed train outputs (#2428)

  • Fixed Conda dependencies (#2412)

  • Fixed Apex scaling with decoupled backward (#2433)

  • Fixed progress bar crashing or displaying incorrectly because of missing ipywidgets (#2417)

  • Fixed TPU saving dir (fc26078e, 04e68f02)

  • Fixed logging on rank 0 only (#2425)

[0.8.3] - 2020-06-29

[0.8.3] - Fixed

[0.8.2] - 2020-06-28

[0.8.2] - Added
  • Added TorchText support for moving data to GPU (#2379)

[0.8.2] - Changed
  • Changed epoch indexing from 0 instead of 1 (#2289)

  • Refactored Model backward (#2276)

  • Refactored training_batch + tests to verify correctness (#2327, #2328)

  • Refactored training loop (#2336)

  • Made optimization steps for hooks (#2363)

  • Changed default apex level to ‘O2’ (#2362)

[0.8.2] - Removed
  • Moved TrainsLogger to Bolts (#2384)

[0.8.2] - Fixed
  • Fixed parsing TPU arguments and TPU tests (#2094)

  • Fixed number of batches in case of multiple dataloaders and limit_{*}_batches (#1920, #2226)

  • Fixed an issue with forward hooks not being removed after model summary (#2298)

  • Fixed load_from_checkpoint() not working with absolute path on Windows (#2294)

  • Fixed how _has_len handles NotImplementedError, e.g. raised by torchtext.data.Iterator (#2293, #2307)

  • Fixed average_precision metric (#2319)

  • Fixed ROC metric for CUDA tensors (#2304)

  • Fixed lost compatibility with custom datatypes implementing .to (#2335)

  • Fixed loading model with kwargs (#2387)

  • Fixed sum(0) for trainer.num_val_batches (#2268)

  • Fixed checking if the parameters are a DictConfig Object (#2216)

  • Fixed SLURM weights saving (#2341)

  • Fixed swapped LR scheduler order (#2356)

  • Fixed adding tensorboard hparams logging test (#2342)

  • Fixed using model ref for teardown (#2360)

  • Fixed logger crash on DDP (#2388)

  • Fixed several issues with early stopping and checkpoint callbacks (#1504, #2391)

  • Fixed loading past checkpoints from v0.7.x (#2405)

  • Fixed loading model without arguments (#2403)

  • Fixed Windows compatibility issue (#2358)

[0.8.1] - 2020-06-19

[0.8.1] - Fixed
  • Fixed a bug where the load_from_checkpoint path was detected as a URL (#2244)

  • Fixed hooks - added barrier (#2245, #2257, #2260)

  • Fixed hparams - remove frame inspection on self.hparams (#2253)

  • Fixed setup and on fit calls (#2252)

  • Fixed GPU template (#2255)

[0.8.0] - 2020-06-18

[0.8.0] - Added
  • Added overfit_batches, limit_{val|test}_batches flags (overfit now uses training set for all three) (#2213)

  • Added metrics

  • Allow dataloaders without sampler field present (#1907)

  • Added option save_last to save the model at the end of every epoch in ModelCheckpoint (#1908)

  • Early stopping checks on_validation_end (#1458)

  • Speed up single-core TPU training by loading data using ParallelLoader (#2033)

  • Added a model hook transfer_batch_to_device that enables moving custom data structures to the target device (#1756)

  • Added black formatter for the code with code-checker on pull (#1610)

  • Added back the slow spawn ddp implementation as ddp_spawn (#2115)

  • Added loading checkpoints from URLs (#1667)

  • Added a callback method on_keyboard_interrupt for handling KeyboardInterrupt events during training (#2134)

  • Added a decorator auto_move_data that moves data to the correct device when using the LightningModule for inference (#1905)

  • Added ckpt_path option to LightningModule.test(...) to load particular checkpoint (#2190)

  • Added setup and teardown hooks for model (#2229)

[0.8.0] - Changed
  • Allow user to select individual TPU core to train on (#1729)

  • Removed non-finite values from loss in LRFinder (#1862)

  • Allow passing model hyperparameters as complete kwarg list (#1896)

  • Renamed ModelCheckpoint’s attributes best to best_model_score and kth_best_model to kth_best_model_path (#1799)

  • Re-enabled Logger’s ImportErrors (#1938)

  • Changed the default value of the Trainer argument weights_summary from full to top (#2029)

  • Raised an error when Lightning replaces an existing sampler (#2020)

  • Enabled prepare_data from correct processes - clarify local vs global rank (#2166)

  • Removed explicit flush from tensorboard logger (#2126)

  • Changed epoch indexing from 1 instead of 0 (#2206)

[0.8.0] - Deprecated
  • Deprecated flags: (#2213)

    • overfit_pct in favour of overfit_batches

    • val_percent_check in favour of limit_val_batches

    • test_percent_check in favour of limit_test_batches

  • Deprecated ModelCheckpoint’s attributes best and kth_best_model (#1799)

  • Dropped official support/testing for older PyTorch versions <1.3 (#1917)

  • Deprecated Trainer proc_rank in favour of global_rank (#2166, #2269)

[0.8.0] - Removed
  • Removed unintended Trainer argument progress_bar_callback, the callback should be passed in by Trainer(callbacks=[...]) instead (#1855)

  • Removed obsolete self._device in Trainer (#1849)

  • Removed deprecated API (#2073)

    • Packages: pytorch_lightning.pt_overrides, pytorch_lightning.root_module

    • Modules: pytorch_lightning.logging.comet_logger, pytorch_lightning.logging.mlflow_logger, pytorch_lightning.logging.test_tube_logger, pytorch_lightning.overrides.override_data_parallel, pytorch_lightning.core.model_saving, pytorch_lightning.core.root_module

    • Trainer arguments: add_row_log_interval, default_save_path, gradient_clip, nb_gpu_nodes, max_nb_epochs, min_nb_epochs, nb_sanity_val_steps

    • Trainer attributes: nb_gpu_nodes, num_gpu_nodes, gradient_clip, max_nb_epochs, min_nb_epochs, nb_sanity_val_steps, default_save_path, tng_tqdm_dic

[0.8.0] - Fixed
  • Run graceful training teardown on interpreter exit (#1631)

  • Fixed user warning when apex was used together with learning rate schedulers (#1873)

  • Fixed multiple calls of EarlyStopping callback (#1863)

  • Fixed an issue with Trainer.from_argparse_args when passing in unknown Trainer args (#1932)

  • Fixed bug related to logger not being reset correctly for model after tuner algorithms (#1933)

  • Fixed root node resolution for SLURM cluster with dash in host name (#1954)

  • Fixed LearningRateLogger in multi-scheduler setting (#1944)

  • Fixed test configuration check and testing (#1804)

  • Fixed an issue with Trainer constructor silently ignoring unknown/misspelled arguments (#1820)

  • Fixed save_weights_only in ModelCheckpoint (#1780)

  • Allow use of same WandbLogger instance for multiple training loops (#2055)

  • Fixed an issue with _auto_collect_arguments collecting local variables that are not constructor arguments and not working for signatures that have the instance not named self (#2048)

  • Fixed mistake in parameters’ grad norm tracking (#2012)

  • Fixed CPU and hanging GPU crash (#2118)

  • Fixed an issue with the model summary and example_input_array depending on a specific ordering of the submodules in a LightningModule (#1773)

  • Fixed TPU logging (#2230)

  • Fixed Pid port + duplicate rank_zero logging (#2140, #2231)

[0.7.6] - 2020-05-16

[0.7.6] - Added
  • Added callback for logging learning rates (#1498)

  • Added transfer learning example (for a binary classification task in computer vision) (#1564)

  • Added type hints in Trainer.fit() and Trainer.test() to reflect that a list of dataloaders can also be passed in (#1723)

  • Added auto scaling of batch size (#1638)

  • The progress bar metrics now also get updated in training_epoch_end (#1724)

  • Enabled NeptuneLogger to work with distributed_backend=ddp (#1753)

  • Added option to provide seed to random generators to ensure reproducibility (#1572)

  • Added override for hparams in load_from_ckpt (#1797)

  • Added support multi-node distributed execution under torchelastic (#1811, #1818)

  • Added use of store_true for bool args (#1822, #1842)

  • Added dummy logger for internally disabling logging for some features (#1836)

[0.7.6] - Changed
  • Enabled non-blocking for device transfers to GPU (#1843)

  • Replaced meta_tags.csv with hparams.yaml (#1271)

  • Reduction when batch_size < num_gpus (#1609)

  • Updated LightningTemplateModel to look more like Colab example (#1577)

  • Don’t convert namedtuple to tuple when transferring the batch to target device (#1589)

  • Allow passing hparams as keyword argument to LightningModule when loading from checkpoint (#1639)

  • Args should come after the last positional argument (#1807)

  • Made ddp the default if no backend specified with multiple GPUs (#1789)

[0.7.6] - Deprecated
  • Deprecated tags_csv in favor of hparams_file (#1271)

[0.7.6] - Fixed
  • Fixed broken link in PR template (#1675)

  • Fixed ModelCheckpoint not None checking filepath (#1654)

  • Trainer now calls on_load_checkpoint() when resuming from a checkpoint (#1666)

  • Fixed sampler logic for ddp with iterable dataset (#1734)

  • Fixed _reset_eval_dataloader() for IterableDataset (#1560)

  • Fixed Horovod distributed backend to set the root_gpu property (#1669)

  • Fixed wandb logger global_step affects other loggers (#1492)

  • Fixed disabling progress bar on non-zero ranks using Horovod backend (#1709)

  • Fixed bugs that prevented the lr finder from being used together with early stopping and validation dataloaders (#1676)

  • Fixed a bug in Trainer that prepended the checkpoint path with version_ when it shouldn’t (#1748)

  • Fixed lr key name in case of param groups in LearningRateLogger (#1719)

  • Fixed accumulation parameter and suggestion method for learning rate finder (#1801)

  • Fixed num processes not being set properly and the auto sampler failing with ddp (#1819)

  • Fixed bugs in semantic segmentation example (#1824)

  • Fixed saving native AMP scaler state (#1777)

  • Fixed native amp + ddp (#1788)

  • Fixed hparam logging with metrics (#1647)

[0.7.5] - 2020-04-27

[0.7.5] - Changed
  • Allow logging of metrics together with hparams (#1630)

[0.7.5] - Removed
  • Removed Warning from trainer loop (#1634)

[0.7.5] - Fixed
  • Fixed ModelCheckpoint not being fixable (#1632)

  • Fixed CPU DDP breaking change and DDP change (#1635)

  • Tested pickling (#1636)

[0.7.4] - 2020-04-26

[0.7.4] - Added
  • Added flag replace_sampler_ddp to manually disable sampler replacement in DDP (#1513)

  • Added auto_select_gpus flag to trainer that enables automatic selection of available GPUs on exclusive mode systems.

  • Added learning rate finder (#1347)

  • Added support for DDP mode in clusters without SLURM (#1387)

  • Added test_dataloaders parameter to Trainer.test() (#1434)

  • Added terminate_on_nan flag to trainer that performs a NaN check with each training iteration when set to True (#1475)

  • Added speed parity tests (max 1 sec difference per epoch) (#1482)

  • Added ddp_cpu backend for testing ddp without GPUs (#1158)

  • Added Horovod support as a distributed backend Trainer(distributed_backend='horovod') (#1529)

  • Added support for 8-core distributed training on Kaggle TPUs (#1568)

  • Added support for native AMP (#1561, #1580)

[0.7.4] - Changed
  • Changed the default behaviour to no longer include a NaN check with each training iteration (#1475)

  • Decoupled the progress bar from the trainer; it is now a callback and can be customized or even replaced entirely (#1450).

  • Changed lr schedule step interval behavior to update every backwards pass instead of every forwards pass (#1477)

  • Defined a shared proc. rank and removed rank from instances (e.g. loggers) (#1408)

  • Updated semantic segmentation example with custom U-Net and logging (#1371)

  • Disabled val and test shuffling (#1600)

[0.7.4] - Deprecated
  • Deprecated training_tqdm_dict in favor of progress_bar_dict (#1450).

[0.7.4] - Removed
  • Removed test_dataloaders parameter from Trainer.fit() (#1434)

[0.7.4] - Fixed
  • Added the possibility to pass nested metrics dictionaries to loggers (#1582)

  • Fixed memory leak from opt return (#1528)

  • Fixed saving checkpoint before deleting old ones (#1453)

  • Fixed loggers to flush the last logged metrics before continuing, e.g. trainer.test() results (#1459)

  • Fixed optimizer configuration when configure_optimizers returns dict without lr_scheduler (#1443)

  • Fixed LightningModule - mixing hparams and arguments in LightningModule.__init__() crashes load_from_checkpoint() (#1505)

  • Added a missing call to the on_before_zero_grad model hook (#1493).

  • Allow use of sweeps with WandbLogger (#1512)

  • Fixed a bug that caused the callbacks Trainer argument to reference a global variable (#1534).

  • Fixed a bug that set all boolean CLI arguments from Trainer.add_argparse_args always to True (#1571)

  • Fixed: the batch is no longer copied when training on a single GPU (#1576, #1579)

  • Fixed soft checkpoint removing on DDP (#1408)

  • Fixed automatic parser bug (#1585)

  • Fixed bool conversion from string (#1606)

[0.7.3] - 2020-04-09

[0.7.3] - Added
  • Added rank_zero_warn for warning only in rank 0 (#1428)

[0.7.3] - Fixed
  • Fixed default DistributedSampler for DDP training (#1425)

  • Fixed the workers warning so it is not shown on Windows (#1430)

  • Fixed returning tuple from run_training_batch (#1431)

  • Fixed gradient clipping (#1438)

  • Fixed pretty print (#1441)

[0.7.2] - 2020-04-07

[0.7.2] - Added
  • Added same step loggers’ metrics aggregation (#1278)

  • Added parity test between a vanilla MNIST model and lightning model (#1284)

  • Added parity test between a vanilla RNN model and lightning model (#1351)

  • Added Reinforcement Learning - Deep Q-network (DQN) lightning example (#1232)

  • Added support for hierarchical dict (#1152)

  • Added TrainsLogger class (#1122)

  • Added type hints to pytorch_lightning.core (#946)

  • Added support for IterableDataset in validation and testing (#1104)

  • Added support for non-primitive types in hparams for TensorboardLogger (#1130)

  • Added a check that stops the training when loss or weights contain NaN or inf values. (#1097)

  • Added support for IterableDataset when val_check_interval=1.0 (default); this triggers validation at the end of each epoch. (#1283)

  • Added summary method to Profilers. (#1259)

  • Added informative errors if user defined dataloader has zero length (#1280)

  • Added testing for python 3.8 (#915)

  • Added model configuration checking (#1199)

  • Added support for optimizer frequencies through LightningModule.configure_optimizers() (#1269)

  • Added option to run without an optimizer by returning None from configure_optimizers. (#1279)

  • Added a warning when the number of data loader workers is small. (#1378)

[0.7.2] - Changed
  • Changed (renamed and refactored) TensorRunningMean -> TensorRunningAccum: running accumulations were generalized. (#1278)

  • Changed progress_bar_refresh_rate trainer flag to disable progress bar when set to 0. (#1108)

  • Enhanced load_from_checkpoint to also forward params to the model (#1307)

  • Updated references to self.forward() to instead use the __call__ interface. (#1211)

  • Changed default behaviour of configure_optimizers to use no optimizer rather than Adam. (#1279)

  • Allowed uploading models to W&B (#1339)

  • On DP and DDP2 unsqueeze is automated now (#1319)

  • No longer always create a DataLoader during reinstantiation; instead, create the same type as before (if it is a subclass of DataLoader) (#1346)

  • No longer interfere with a default sampler (#1318)

  • Remove default Adam optimizer (#1317)

  • Give warnings for unimplemented required lightning methods (#1317)

  • Made evaluate method private >> Trainer._evaluate(...). (#1260)

  • Simplify the PL examples structure (shallower and more readable) (#1247)

  • Changed min max gpu memory to be on their own plots (#1358)

  • Remove .item which causes sync issues (#1254)

  • Changed smoothing in TQDM to decrease variability of time remaining between training / eval (#1194)

  • Changed the default logger to a dedicated one (#1064)

[0.7.2] - Deprecated
  • Deprecated Trainer argument print_nan_grads (#1097)

  • Deprecated Trainer argument show_progress_bar (#1108)

[0.7.2] - Removed
  • Removed test for no test dataloader in .fit (#1495)

  • Removed duplicated module pytorch_lightning.utilities.arg_parse for loading CLI arguments (#1167)

  • Removed wandb logger’s finalize method (#1193)

  • Dropped torchvision dependency in tests and added own MNIST dataset class instead (#986)

[0.7.2] - Fixed
  • Fixed model_checkpoint when saving all models (#1359)

  • Trainer.add_argparse_args classmethod fixed. Now it adds a type for the arguments (#1147)

  • Fixed bug related to type checking of ReduceLROnPlateau lr schedulers (#1126)

  • Fixed a bug to ensure Lightning checkpoints are backward compatible (#1132)

  • Fixed a bug that created an extra dataloader with active reload_dataloaders_every_epoch (#1196)

  • Fixed all warnings and errors in the docs build process (#1191)

  • Fixed an issue where val_percent_check=0 would not disable validation (#1251)

  • Fixed average of incomplete TensorRunningMean (#1309)

  • Fixed WandbLogger.watch with wandb.init() (#1311)

  • Fixed an issue with early stopping that would prevent it from monitoring training metrics when validation is disabled / not implemented (#1235).

  • Fixed a bug that would cause trainer.test() to run on the validation set when overloading validation_epoch_end and test_end (#1353)

  • Fixed WandbLogger.watch - use of the watch method without importing wandb (#1311)

  • Fixed WandbLogger to be used with ‘ddp’ - allow reinits in sub-processes (#1149, #1360)

  • Made training_epoch_end behave like validation_epoch_end (#1357)

  • Fixed fast_dev_run running validation twice (#1365)

  • Fixed pickle error from quick patch __code__ (#1352)

  • Fixed memory leak on GPU0 (#1094, #1349)

  • Fixed checkpointing interval (#1272)

  • Fixed validation and training loops running on the partial dataset (#1192)

  • Fixed running on_validation_end only on main process in DDP (#1125)

  • Fixed load_spawn_weights only in proc rank 0 (#1385)

  • Fixed using the deprecated use_amp attribute (#1145)

  • Fixed Tensorboard logger error: lightning_logs directory does not exist in multi-node DDP on nodes with rank != 0 (#1377)

  • Fixed Unimplemented backend XLA error on TPU (#1387)

[0.7.1] - 2020-03-07

[0.7.1] - Fixed
  • Fixed print issues and data_loader (#1080)

[0.7.0] - 2020-03-06

[0.7.0] - Added
  • Added automatic sampler setup. Depending on DDP or TPU, lightning configures the sampler correctly (user needs to do nothing) (#926)

  • Added reload_dataloaders_every_epoch=False flag for trainer. Some users require reloading data every epoch (#926)

  • Added progress_bar_refresh_rate=50 flag for trainer. Throttle refresh rate on notebooks (#926)

  • Updated governance docs

  • Added a check to ensure that the metric used for early stopping exists before training commences (#542)

  • Added optimizer_idx argument to backward hook (#733)

  • Added entity argument to WandbLogger to be passed to wandb.init (#783)

  • Added a tool for profiling training runs (#782)

  • Improved flexibility for naming of TensorBoard logs, can now set version to a str to just save to that directory, and use name='' to prevent experiment-name directory (#804)

  • Added option to specify step key when logging metrics (#808)

  • Added train_dataloader, val_dataloader and test_dataloader arguments to Trainer.fit(), for alternative data parsing (#759)

  • Added Tensor Processing Unit (TPU) support (#868)

  • Added semantic segmentation example (#751, #876, #881)

  • Split callbacks in multiple files (#849)

  • Support for user defined callbacks (#889 and #950)

  • Added support for multiple loggers to be passed to Trainer as an iterable (e.g. list, tuple, etc.) (#903)

  • Added support for step-based learning rate scheduling (#941)

  • Added support for logging hparams as dict (#1029)

  • Checkpoint and early stopping now work without val. step (#1041)

  • Support graceful training cleanup after Keyboard Interrupt (#856, #1019)

  • Added type hints for function arguments (#912, )

  • Added default argparser for Trainer (#952, #1023)

  • Added TPU gradient clipping (#963)

  • Added max/min number of steps in Trainer (#728)

[0.7.0] - Changed
  • Improved NeptuneLogger by adding close_after_fit argument to allow logging after training (#908)

  • Changed default TQDM to use tqdm.auto for prettier outputs in IPython notebooks (#752)

  • Changed pytorch_lightning.logging to pytorch_lightning.loggers (#767)

  • Moved the default tqdm_dict definition from Trainer to LightningModule, so it can be overridden by the user (#749)

  • Moved functionality of LightningModule.load_from_metrics into LightningModule.load_from_checkpoint (#995)

  • Changed Checkpoint path parameter from filepath to dirpath (#1016)

  • Froze model hparams as a Namespace property (#1029)

  • Dropped logging config in package init (#1015)

  • Renamed model steps (#1051)

    • training_end >> training_epoch_end

    • validation_end >> validation_epoch_end

    • test_end >> test_epoch_end

  • Refactored dataloading to support infinite dataloaders (#955)

  • Created a single file in TensorBoardLogger (#777)

[0.7.0] - Deprecated
  • Deprecated pytorch_lightning.logging (#767)

  • Deprecated LightningModule.load_from_metrics in favour of LightningModule.load_from_checkpoint (#995, #1079)

  • Deprecated @data_loader decorator (#926)

  • Deprecated model steps training_end, validation_end and test_end (#1051, #1056)

[0.7.0] - Removed
  • Removed dependency on pandas (#736)

  • Removed dependency on torchvision (#797)

  • Removed dependency on scikit-learn (#801)

[0.7.0] - Fixed
  • Fixed a bug where early stopping on_end_epoch would be called inconsistently when check_val_every_n_epoch == 0 (#743)

  • Fixed a bug where the model checkpointer didn’t write to the same directory as the logger (#771)

  • Fixed a bug where the TensorBoardLogger class would create an additional empty log file during fitting (#777)

  • Fixed a bug where global_step was advanced incorrectly when using accumulate_grad_batches > 1 (#832)

  • Fixed a bug when calling self.logger.experiment with multiple loggers (#1009)

  • Fixed a bug when calling logger.append_tags on a NeptuneLogger with a single tag (#1009)

  • Fixed sending back data from .spawn by saving and loading the trained model in/out of the process (#1017)

  • Fixed port collision on DDP (#1010)

  • Fixed/tested pass overrides (#918)

  • Fixed comet logger to log after train (#892)

  • Remove deprecated args to learning rate step function (#890)

[0.6.0] - 2020-01-21

[0.6.0] - Added
  • Added support for resuming from a specific checkpoint via resume_from_checkpoint argument (#516)

  • Added support for ReduceLROnPlateau scheduler (#320)

  • Added support for Apex mode O2 in conjunction with Data Parallel (#493)

  • Added option (save_top_k) to save the top k models in the ModelCheckpoint class (#128)

  • Added on_train_start and on_train_end hooks to ModelHooks (#598)

  • Added TensorBoardLogger (#607)

  • Added support for weight summary of model with multiple inputs (#543)

  • Added map_location argument to load_from_metrics and load_from_checkpoint (#625)

  • Added option to disable validation by setting val_percent_check=0 (#649)

  • Added NeptuneLogger class (#648)

  • Added WandbLogger class (#627)

[0.6.0] - Changed
  • Changed the default progress bar to print to stdout instead of stderr (#531)

  • Renamed step_idx to step, epoch_idx to epoch, max_num_epochs to max_epochs and min_num_epochs to min_epochs (#589)

  • Renamed total_batch_nb to total_batches, nb_val_batches to num_val_batches, nb_training_batches to num_training_batches, max_nb_epochs to max_epochs, min_nb_epochs to min_epochs, and nb_test_batches to num_test_batches (#567)

  • Changed gradient logging to use parameter names instead of indexes (#660)

  • Changed the default logger to TensorBoardLogger (#609)

  • Changed the directory for tensorboard logging to be the same as model checkpointing (#706)

[0.6.0] - Deprecated
  • Deprecated max_nb_epochs and min_nb_epochs (#567)

  • Deprecated the on_sanity_check_start hook in ModelHooks (#598)

[0.6.0] - Removed
  • Removed the save_best_only argument from ModelCheckpoint, use save_top_k=1 instead (#128)

[0.6.0] - Fixed
  • Fixed a bug which occurred when using Adagrad with cuda (#554)

  • Fixed a bug where training would be on the GPU despite setting gpus=0 or gpus=[] (#561)

  • Fixed an error with print_nan_gradients when some parameters do not require gradient (#579)

  • Fixed a bug where the progress bar would show an incorrect number of total steps during the validation sanity check when using multiple validation data loaders (#597)

  • Fixed support for PyTorch 1.1.0 (#552)

  • Fixed an issue with early stopping when using a val_check_interval < 1.0 in Trainer (#492)

  • Fixed bugs relating to the CometLogger object that would cause it to not work properly (#481)

  • Fixed a bug that would occur when returning -1 from on_batch_start following an early exit or when the batch was None (#509)

  • Fixed a potential race condition with several processes trying to create checkpoint directories (#530)

  • Fixed a bug where batch ‘segments’ would remain on the GPU when using truncated_bptt > 1 (#532)

  • Fixed a bug when using IterableDataset (#547)

  • Fixed a bug where .item was called on non-tensor objects (#602)

  • Fixed a bug where Trainer.train would crash on an uninitialized variable if the trainer was run after resuming from a checkpoint that was already at max_epochs (#608)

  • Fixed a bug where early stopping would begin two epochs early (#617)

  • Fixed a bug where num_training_batches and num_test_batches would sometimes be rounded down to zero (#649)

  • Fixed a bug where an additional batch would be processed when manually setting num_training_batches (#653)

  • Fixed a bug when batches did not have a .copy method (#701)

  • Fixed a bug when using log_gpu_memory=True in Python 3.6 (#715)

  • Fixed a bug where checkpoint writing could exit before completion, giving incomplete checkpoints (#689)

  • Fixed a bug where on_train_end was not called when early stopping (#723)

[0.5.3] - 2019-11-06

[0.5.3] - Added
  • Added option to disable default logger, checkpointer, and early stopping by passing logger=False, checkpoint_callback=False and early_stop_callback=False respectively

  • Added CometLogger for use with Comet.ml

  • Added val_check_interval argument to Trainer allowing validation to be performed at every given number of batches

  • Added functionality to save and load hyperparameters using the standard checkpoint mechanism

  • Added call to torch.cuda.empty_cache before training starts

  • Added option for user to override the call to backward

  • Added support for truncated backprop through time via the truncated_bptt_steps argument in Trainer

  • Added option to operate on all outputs from training_step in DDP2

  • Added a hook for modifying DDP init

  • Added a hook for modifying Apex

[0.5.3] - Changed
  • Changed experiment version to be padded with zeros (e.g. /dir/version_9 becomes /dir/version_0009)

  • Changed callback metrics to include any metrics given in logs or progress bar

  • Changed the default for save_best_only in ModelCheckpoint to True

  • Added tng_data_loader for backwards compatibility

  • Renamed MLFlowLogger.client to MLFlowLogger.experiment for consistency

  • Moved global_step increment to happen after the batch has been processed

  • Changed weights restore to first attempt HPC weights before restoring normally, preventing both weights being restored and running out of memory

  • Changed progress bar functionality to add multiple progress bars for train/val/test

  • Changed calls to print to use logging instead

[0.5.3] - Deprecated
  • Deprecated tng_dataloader

[0.5.3] - Fixed
  • Fixed an issue where the number of batches was off by one during training

  • Fixed a bug that occurred when setting a checkpoint callback and early_stop_callback=False

  • Fixed an error when importing CometLogger

  • Fixed a bug where the gpus argument had some unexpected behaviour

  • Fixed a bug where the computed total number of batches was sometimes incorrect

  • Fixed a bug where the progress bar would sometimes not show the total number of batches in test mode

  • Fixed a bug when using the log_gpu_memory='min_max' option in Trainer

  • Fixed a bug where checkpointing would sometimes erase the current directory

[0.5.2] - 2019-10-10

[0.5.2] - Added
  • Added weights_summary argument to Trainer to be set to full (full summary), top (just top level modules) or other

  • Added tags argument to MLFlowLogger

[0.5.2] - Changed
  • Changed default for amp_level to O1

[0.5.2] - Removed
  • Removed the print_weights_summary argument from Trainer

[0.5.2] - Fixed
  • Fixed a bug where logs were not written properly

  • Fixed a bug where logger.finalize wasn’t called after training is complete

  • Fixed callback metric errors in DDP

  • Fixed a bug where TestTubeLogger didn’t log to the correct directory

[0.5.1] - 2019-10-05

[0.5.1] - Added
  • Added the LightningLoggerBase class for experiment loggers

  • Added MLFlowLogger for logging with mlflow

  • Added TestTubeLogger for logging with test_tube

  • Added a different implementation of DDP (distributed_backend='ddp2') where every node has one model using all GPUs

  • Added support for optimisers which require a closure (e.g. LBFGS)

  • Added automatic MASTER_PORT default for DDP when not set manually

  • Added new GPU memory logging options 'min_max' (log only the min/max utilization) and 'all' (log all the GPU memory)

[0.5.1] - Changed
  • Changed schedulers to always be called with the current epoch

  • Changed test_tube to an optional dependency

  • Changed data loaders to internally use a getter instead of a python property

  • Disabled auto GPU loading when restoring weights to prevent out of memory errors

  • Changed logging, early stopping and checkpointing to occur by default

[0.5.1] - Fixed
  • Fixed a bug with samplers that do not specify set_epoch

  • Fixed a bug when using the MLFlowLogger with unsupported data types, this will now raise a warning

  • Fixed a bug where gradient norms were always zero using track_grad_norm

  • Fixed a bug which causes a crash when logging memory

[0.5.0] - 2019-09-26

[0.5.0] - Changed
  • Changed data_batch argument to batch throughout

  • Changed batch_i argument to batch_idx throughout

  • Changed tng_dataloader method to train_dataloader

  • Changed on_tng_metrics method to on_training_metrics

  • Changed gradient_clip argument to gradient_clip_val

  • Changed add_log_row_interval to row_log_interval

[0.5.0] - Fixed
  • Fixed a bug with tensorboard logging in multi-gpu setup

[0.4.9] - 2019-09-16

[0.4.9] - Added
  • Added the flag log_gpu_memory to Trainer to deactivate logging of GPU memory utilization

  • Added SLURM resubmit functionality (port from test-tube)

  • Added optional weight_save_path to trainer to remove the need for a checkpoint_callback when using cluster training

  • Added option to use single gpu per node with DistributedDataParallel

[0.4.9] - Changed
  • Changed functionality of validation_end and test_end with multiple dataloaders to be given all of the dataloaders at once rather than in separate calls

  • Changed print_nan_grads to only print the parameter value and gradients when they contain NaN

  • Changed gpu API to take integers as well (e.g. gpus=2 instead of gpus=[0, 1])

  • All models are now loaded onto the CPU to avoid device and out-of-memory issues in PyTorch

[0.4.9] - Fixed
  • Fixed a bug where data types that implement .to but not .cuda would not be properly moved onto the GPU

  • Fixed a bug where data would not be re-shuffled every epoch when using a DistributedSampler

[0.4.8] - 2019-08-31

[0.4.8] - Added
  • Added test_step and test_end methods, used when Trainer.test is called

  • Added GradientAccumulationScheduler callback which can be used to schedule changes to the number of accumulation batches

  • Added option to skip the validation sanity check by setting nb_sanity_val_steps = 0

[0.4.8] - Fixed
  • Fixed a bug when setting nb_sanity_val_steps = 0

[0.4.7] - 2019-08-24

[0.4.7] - Changed
  • Changed the default val_check_interval to 1.0

  • Changed defaults for nb_val_batches, nb_tng_batches and nb_test_batches to 0

[0.4.7] - Fixed
  • Fixed a bug where the full validation set was used despite setting val_percent_check

  • Fixed a bug where an Exception was thrown when using a data set containing a single batch

  • Fixed a bug where an Exception was thrown if no val_dataloader was given

  • Fixed a bug where tuples were not properly transferred to the GPU

  • Fixed a bug where data of a non-standard type was not properly handled by the trainer

  • Fixed a bug when loading data as a tuple

  • Fixed a bug where AttributeError could be suppressed by the Trainer

[0.4.6] - 2019-08-15

[0.4.6] - Added
  • Added support for data to be given as a dict or list with a single gpu

  • Added support for configure_optimizers to return a single optimizer, two lists (optimizers and schedulers), or a single list

[0.4.6] - Fixed
  • Fixed a bug where returning just an optimizer list (i.e. without schedulers) from configure_optimizers would throw an Exception

[0.4.5] - 2019-08-13

[0.4.5] - Added
  • Added optimizer_step method that can be overridden to change the standard optimizer behaviour

[0.4.4] - 2019-08-12

[0.4.4] - Added
  • Added support for multiple validation dataloaders

  • Added support for latest test-tube logger (optimised for torch==1.2.0)

[0.4.4] - Changed
  • validation_step and val_dataloader are now optional

  • lr_scheduler is now activated after epoch

[0.4.4] - Fixed
  • Fixed a bug where a warning would show when using lr_scheduler in torch>1.1.0

  • Fixed a bug where an Exception would be thrown if using torch.DistributedDataParallel without using a DistributedSampler, this now throws a Warning instead

[0.4.3] - 2019-08-10

[0.4.3] - Fixed
  • Fixed a bug where accumulate gradients would scale the loss incorrectly

[0.4.2] - 2019-08-08

[0.4.2] - Changed
  • Changed install requirement to torch==1.2.0

[0.4.1] - 2019-08-08

[0.4.1] - Changed
  • Changed install requirement to torch==1.1.0

[0.4.0] - 2019-08-08

[0.4.0] - Added
  • Added 16-bit support for a single GPU

  • Added support for training continuation (preserves epoch, global step etc.)

[0.4.0] - Changed
  • Changed training_step and validation_step so that outputs are no longer automatically reduced

[0.4.0] - Removed
  • Removed need for Experiment object in Trainer

[0.4.0] - Fixed
  • Fixed issues with reducing outputs from generative models (such as images and text)

[0.3.6] - 2019-07-25

[0.3.6] - Added
  • Added a decorator to do lazy data loading internally

[0.3.6] - Fixed
  • Fixed a bug where Experiment object was not process safe, potentially causing logs to be overwritten

[0.3.5] - 2019-07-25

[0.3.4] - 2019-07-22

[0.3.3] - 2019-07-22

[0.3.2] - 2019-07-21

[0.3.1] - 2019-07-21

[0.2.x] - 2019-07-09

[0.1.x] - 2019-06-DD


© Copyright Copyright (c) 2018-2023, Lightning AI et al... Revision 3d69e466.
