# Barlow Twins Tutorial¶

This notebook describes the self-supervised learning method Barlow Twins. Barlow Twins differs from other recently proposed algorithms in that it is neither a contrastive method nor one based on knowledge distillation or clustering. The simplicity of its loss function and its effectiveness relative to the current state of the art make Barlow Twins an interesting case study.


## Setup¶

This notebook requires some packages besides pytorch-lightning.

[1]:

# ! pip install --quiet "torch>=1.6, <1.9" "torchmetrics>=0.3" "torchvision" "matplotlib" "pytorch-lightning>=1.3"


## Barlow Twins¶

Barlow Twins occupies a unique place amongst the current state-of-the-art self-supervised learning methods. It does not fall under the existing categories of contrastive learning, knowledge distillation or clustering based methods. Instead, it creates its own category of redundancy reduction and achieves competitive performance with a simple yet effective loss function. In this tutorial, we look at coding up a small version of the Barlow Twins algorithm using PyTorch Lightning.

[2]:

from functools import partial
from typing import Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as VisionF
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from torchmetrics.functional import accuracy
from torchvision.datasets import CIFAR10
from torchvision.models.resnet import resnet18
from torchvision.utils import make_grid

batch_size = 32
num_workers = 0  # to run notebook on CPU
max_epochs = 200
z_dim = 128


### Transforms¶

We first define the data augmentation pipeline used in Barlow Twins. Here, we use the pipeline proposed in SimCLR, which generates two copies/views of an input image by applying the following transformations in a sequence.

First it takes a random crop of the image and resizes it to a fixed pre-specified size. Then, it applies a left-to-right random flip with a probability of 0.5. This step is followed by a composition of color jitter, conversion to grayscale with a probability of 0.2 and the application of a Gaussian blur filter. Finally, we normalize the image and convert it to a tensor.

Within this transform, we add a third view for our online finetuner, which we explain later on. In short, this additional transform produces the view we use to test our encoder on a downstream classification task.

[3]:

class BarlowTwinsTransform:
    def __init__(self, train=True, input_height=224, gaussian_blur=True, jitter_strength=1.0, normalize=None):

        self.input_height = input_height
        self.gaussian_blur = gaussian_blur
        self.jitter_strength = jitter_strength
        self.normalize = normalize
        self.train = train

        color_jitter = transforms.ColorJitter(
            0.8 * self.jitter_strength,
            0.8 * self.jitter_strength,
            0.8 * self.jitter_strength,
            0.2 * self.jitter_strength,
        )

        color_transform = [transforms.RandomApply([color_jitter], p=0.8), transforms.RandomGrayscale(p=0.2)]

        if self.gaussian_blur:
            kernel_size = int(0.1 * self.input_height)
            if kernel_size % 2 == 0:
                kernel_size += 1

            color_transform.append(transforms.RandomApply([transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5))

        self.color_transform = transforms.Compose(color_transform)

        if normalize is None:
            self.final_transform = transforms.ToTensor()
        else:
            self.final_transform = transforms.Compose([transforms.ToTensor(), normalize])

        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(self.input_height),
                transforms.RandomHorizontalFlip(p=0.5),
                self.color_transform,
                self.final_transform,
            ]
        )

        self.finetune_transform = None
        if self.train:
            self.finetune_transform = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ]
            )
        else:
            self.finetune_transform = transforms.ToTensor()

    def __call__(self, sample):
        return self.transform(sample), self.transform(sample), self.finetune_transform(sample)
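
As a quick sketch of what this transform returns, each call produces two augmented views plus the finetune view. The placeholder image below is hypothetical and only there to make the sketch self-contained:

# hypothetical usage sketch: one input image in, three views out
from PIL import Image

sample_img = Image.new("RGB", (32, 32))  # placeholder image
transform = BarlowTwinsTransform(train=True, input_height=32, gaussian_blur=False)
view1, view2, finetune_view = transform(sample_img)
print(view1.shape, view2.shape, finetune_view.shape)  # three 3x32x32 tensors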


### Dataset¶

We select CIFAR10 as the dataset to demonstrate the pre-training process for Barlow Twins. CIFAR10 images are 32x32 in size and we do not apply a Gaussian blur transformation on them. In this step, we create the training and validation dataloaders for CIFAR10.

[4]:

def cifar10_normalization():
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
    )
    return normalize


train_transform = BarlowTwinsTransform(
    train=True, input_height=32, gaussian_blur=False, jitter_strength=0.5, normalize=cifar10_normalization()
)
train_dataset = CIFAR10(root=".", train=True, download=True, transform=train_transform)

val_transform = BarlowTwinsTransform(
    train=False, input_height=32, gaussian_blur=False, jitter_strength=0.5, normalize=cifar10_normalization()
)
val_dataset = CIFAR10(root=".", train=False, download=True, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz

Extracting ./cifar-10-python.tar.gz to .


### Plot images¶

To see how the CIFAR10 images look after the data augmentation pipeline, we load a few images from the dataloader and plot them here.

[5]:

for batch in val_loader:
    (img1, img2, _), label = batch
    break

img_grid = make_grid(img1, normalize=True)


def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = VisionF.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


show(img_grid)


### Barlow Twins Loss¶

Here we define the loss function for Barlow Twins. It first normalizes the D-dimensional vectors from the projection head and then computes the DxD cross-correlation matrix between the normalized vectors of the two views of each image.

Then it splits this cross-correlation matrix into two parts. The first part of the loss pushes the diagonal elements of the matrix towards 1: this maximizes the cosine similarity between the latent vectors of the two views of each image, making the backbone invariant to the transformations applied to the views. The second part pushes the off-diagonal elements of the cross-correlation matrix towards 0, which reduces the redundancy between the different dimensions of the latent vector.
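
Concretely, the loss from the paper can be written in terms of the cross-correlation matrix $\mathcal{C}$ as

$$\mathcal{L}_{BT} = \sum_i (1 - \mathcal{C}_{ii})^2 + \lambda \sum_i \sum_{j \neq i} \mathcal{C}_{ij}^2,$$

where the first sum is the invariance term, the second is the redundancy reduction term, and $\lambda$ (the lambda_coeff in the code below) trades off the two.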

[6]:

class BarlowTwinsLoss(nn.Module):
    def __init__(self, batch_size, lambda_coeff=5e-3, z_dim=128):
        super().__init__()

        self.z_dim = z_dim
        self.batch_size = batch_size
        self.lambda_coeff = lambda_coeff

    def off_diagonal_ele(self, x):
        # return a flattened view of the off-diagonal elements of a square matrix
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

    def forward(self, z1, z2):
        # N x D, where N is the batch size and D is output dim of projection head
        z1_norm = (z1 - torch.mean(z1, dim=0)) / torch.std(z1, dim=0)
        z2_norm = (z2 - torch.mean(z2, dim=0)) / torch.std(z2, dim=0)

        cross_corr = torch.matmul(z1_norm.T, z2_norm) / self.batch_size

        # invariance term: push the diagonal of the cross-correlation matrix towards 1
        on_diag = torch.diagonal(cross_corr).add_(-1).pow_(2).sum()
        # redundancy reduction term: push the off-diagonal elements towards 0
        off_diag = self.off_diagonal_ele(cross_corr).pow_(2).sum()

        return on_diag + self.lambda_coeff * off_diag
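
As a quick smoke test (a sketch; the random inputs are just placeholders for real projections):

loss_fn = BarlowTwinsLoss(batch_size=batch_size, z_dim=z_dim)
z1, z2 = torch.randn(batch_size, z_dim), torch.randn(batch_size, z_dim)
print(loss_fn(z1, z2))  # a scalar tensor; random, uncorrelated views give a large loss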


### Backbone¶

This is a standard Resnet backbone that we pre-train using the Barlow Twins method. To accommodate the 32x32 CIFAR10 images, we replace the first 7x7 convolution of the Resnet backbone with a 3x3 convolution. We also remove the first max-pool layer from the network for CIFAR10 images.

[7]:

encoder = resnet18()

# for CIFAR10, replace the first 7x7 conv with smaller 3x3 conv and remove the first maxpool
encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
encoder.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)

# replace classification fc layer of Resnet to obtain representations from the backbone
encoder.fc = nn.Identity()
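
As a quick check with a random input (a sketch, not part of the training code), the modified backbone maps a 32x32 image to a 512-dimensional representation:

# sanity check: resnet18 with the fc layer replaced outputs 512-d features
out = encoder(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 512])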


Unlike SimCLR and BYOL, the downstream performance of Barlow Twins greatly benefits from having a larger projection head after the backbone network. The paper uses a 3-layer MLP with 8192 hidden dimensions and 8192 as the output dimension of the projection head. For the purposes of this tutorial, we use a smaller projection head. Note, however, that in practice Barlow Twins needs to be trained with a bigger projection head, as it is highly sensitive to the head's architecture and output dimensionality.

[8]:

class ProjectionHead(nn.Module):
    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
        super().__init__()

        self.projection_head = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=True),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim, bias=False),
        )

    def forward(self, x):
        return self.projection_head(x)
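
For instance (a quick sketch with the dimensions used later in this tutorial), the head projects 512-d backbone features down to the z_dim-dimensional space:

head = ProjectionHead(input_dim=512, hidden_dim=512, output_dim=128)
print(head(torch.randn(4, 512)).shape)  # torch.Size([4, 128])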


### Learning rate warmup¶

For the purposes of this tutorial, we keep things simple and use a linear warmup schedule with the Adam optimizer. In our previous experiments we found that the linear warmup part is much more important for the final performance of a model than the cosine decay component of the schedule.

[9]:

def fn(warmup_steps, step):
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    else:
        return 1.0


def linear_warmup_decay(warmup_steps):
    return partial(fn, warmup_steps)
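
The returned function is meant to be used as a LambdaLR multiplier on the base learning rate; for example:

# the multiplier ramps linearly from 0 to 1 over warmup_steps, then stays at 1
schedule = linear_warmup_decay(warmup_steps=10)
print(schedule(5))   # 0.5, halfway through warmup
print(schedule(20))  # 1.0, warmup finished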


### Barlow Twins Lightning Module¶

We keep the LightningModule for Barlow Twins neat and simple. It takes in a backbone encoder and initializes the projection head and the loss function. We configure the optimizer and the learning rate scheduler in the configure_optimizers method.

[10]:

class BarlowTwins(LightningModule):
    def __init__(
        self,
        encoder,
        encoder_out_dim,
        num_training_samples,
        batch_size,
        lambda_coeff=5e-3,
        z_dim=128,
        learning_rate=1e-4,
        warmup_epochs=10,
        max_epochs=200,
    ):
        super().__init__()

        self.encoder = encoder
        self.projection_head = ProjectionHead(input_dim=encoder_out_dim, hidden_dim=encoder_out_dim, output_dim=z_dim)
        self.loss_fn = BarlowTwinsLoss(batch_size=batch_size, lambda_coeff=lambda_coeff, z_dim=z_dim)

        self.learning_rate = learning_rate
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs

        self.train_iters_per_epoch = num_training_samples // batch_size

    def forward(self, x):
        return self.encoder(x)

    def shared_step(self, batch):
        (x1, x2, _), _ = batch

        z1 = self.projection_head(self.encoder(x1))
        z2 = self.projection_head(self.encoder(x2))

        return self.loss_fn(z1, z2)

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)

        self.log("train_loss", loss.item(), on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        warmup_steps = self.train_iters_per_epoch * self.warmup_epochs

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                linear_warmup_decay(warmup_steps),
            ),
            "interval": "step",
            "frequency": 1,
        }

        return [optimizer], [scheduler]


### Evaluation¶

We define a callback which appends a linear layer on top of the encoder and trains this classification evaluation head in an online manner. We make sure not to backpropagate the gradients into the encoder while tuning the linear layer. This technique was used in SimCLR as well, where it was shown that the final downstream classification performance closely tracks the online finetuning results as training progresses.

[11]:

class OnlineFineTuner(Callback):
    def __init__(
        self,
        encoder_output_dim: int,
        num_classes: int,
    ) -> None:
        super().__init__()

        self.optimizer: torch.optim.Optimizer

        self.encoder_output_dim = encoder_output_dim
        self.num_classes = num_classes

    def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:

        # add linear_eval layer and its own optimizer (the 1e-4 learning rate is a nominal choice)
        pl_module.online_finetuner = nn.Linear(self.encoder_output_dim, self.num_classes).to(pl_module.device)
        self.optimizer = torch.optim.Adam(pl_module.online_finetuner.parameters(), lr=1e-4)

    def extract_online_finetuning_view(
        self, batch: Sequence, device: Union[str, torch.device]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        (_, _, finetune_view), y = batch
        finetune_view = finetune_view.to(device)
        y = y.to(device)

        return finetune_view, y

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
    ) -> None:
        x, y = self.extract_online_finetuning_view(batch, pl_module.device)

        # no gradients flow back into the encoder; only the linear head is trained
        with torch.no_grad():
            feats = pl_module(x)

        feats = feats.detach()
        preds = pl_module.online_finetuner(feats)
        loss = F.cross_entropy(preds, y)

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        acc = accuracy(F.softmax(preds, dim=1), y)
        pl_module.log("online_train_acc", acc, on_step=True, on_epoch=False)
        pl_module.log("online_train_loss", loss, on_step=True, on_epoch=False)

    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
    ) -> None:
        x, y = self.extract_online_finetuning_view(batch, pl_module.device)

        with torch.no_grad():
            feats = pl_module(x)

        feats = feats.detach()
        preds = pl_module.online_finetuner(feats)
        loss = F.cross_entropy(preds, y)

        acc = accuracy(F.softmax(preds, dim=1), y)
        pl_module.log("online_val_acc", acc, on_step=False, on_epoch=True, sync_dist=True)
        pl_module.log("online_val_loss", loss, on_step=False, on_epoch=True, sync_dist=True)


Finally, we define the trainer for training the model. We pass the train_loader and val_loader we initialized earlier to the fit function.

[12]:

encoder_out_dim = 512

model = BarlowTwins(
    encoder=encoder,
    encoder_out_dim=encoder_out_dim,
    num_training_samples=len(train_dataset),
    batch_size=batch_size,
    z_dim=z_dim,
)

online_finetuner = OnlineFineTuner(encoder_output_dim=encoder_out_dim, num_classes=10)
checkpoint_callback = ModelCheckpoint(every_n_epochs=100, save_top_k=-1, save_last=True)

trainer = Trainer(
    max_epochs=max_epochs,
    gpus=torch.cuda.device_count(),
    precision=16 if torch.cuda.device_count() > 0 else 32,
    callbacks=[online_finetuner, checkpoint_callback],
)

# uncomment this to train the model
# this is commented out for the tutorial so that the notebook builds quickly
# trainer.fit(model, train_loader, val_loader)

Using native 16bit precision.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


### Using the trained encoder for downstream tasks¶

Once the encoder is pretrained on CIFAR10, we can use it to get image embeddings and use them on downstream tasks such as classification, detection, segmentation, etc.

In this tutorial, we did not fully train our encoder for 100s of epochs using the Barlow Twins pretraining method. So, we will load the pretrained encoder weights from a checkpoint and show the image embeddings obtained from it.

To create this checkpoint, the encoder was pretrained for 200 epochs and reached an online finetune accuracy of x% on CIFAR-10.

[13]:

# ckpt_model = torch.load('')  # upload checkpoint to aws
# encoder = ckpt_model.encoder
encoder = model.encoder

downstream_dataset = CIFAR10(root=".", train=False, transform=transforms.ToTensor())
dataloader = DataLoader(downstream_dataset, batch_size=4, shuffle=False)

for batch in dataloader:
    img, label = batch
    print(encoder(img).shape)
    break

torch.Size([4, 512])
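
As an illustration of downstream use, here is a minimal sketch of training a linear probe on the frozen embeddings. The optimizer settings, batch size and epoch count here are nominal choices for the sketch, not values from the paper:

# minimal linear-probe sketch on frozen embeddings (hypothetical settings)
encoder.eval()  # freeze batch-norm statistics as well
probe = nn.Linear(encoder_out_dim, 10)
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)

for epoch in range(5):  # a nominal number of epochs
    for img, label in DataLoader(downstream_dataset, batch_size=256, shuffle=True):
        with torch.no_grad():
            feats = encoder(img)  # frozen encoder: no gradients flow into it
        loss = F.cross_entropy(probe(feats), label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()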


## Congratulations - Time to Join the Community!¶

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

### Star Lightning on GitHub¶

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

### Join our Slack!¶

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.