
Graphics Processing Unit (GPU)

Single GPU Training

Make sure you’re running on a machine with at least one GPU. There’s no need to specify any NVIDIA flags as Lightning will do it for you.

trainer = Trainer(accelerator="gpu", devices=1)

Multi GPU Training

Lightning supports multiple ways of doing distributed training.


Model Parallel Training

Check out the Model Parallel Guide documentation.


Preparing your code

To train on CPU/GPU/TPU without changing your code, we need to build a few good habits :)

Delete .cuda() or .to() calls

Delete any calls to .cuda() or .to(device).

# before lightning
def forward(self, x):
    x = x.cuda(0)
    self.layer_1.cuda(0)
    x_hat = self.layer_1(x)
    return x_hat


# after lightning
def forward(self, x):
    x_hat = self.layer_1(x)
    return x_hat

Init tensors using type_as and register_buffer

When you need to create a new tensor, use type_as. This will make your code scale to any arbitrary number of GPUs or TPUs with Lightning.

# before lightning
def forward(self, x):
    z = torch.Tensor(2, 3)
    z = z.cuda(0)


# with lightning
def forward(self, x):
    z = torch.Tensor(2, 3)
    z = z.type_as(x)

The LightningModule knows what device it is on. You can access the reference via self.device. Sometimes it is necessary to store tensors as module attributes. However, if they are not parameters, they will remain on the CPU even if the module gets moved to a new device. To prevent that and remain device agnostic, register the tensor as a buffer in your module’s __init__ method with register_buffer().

class LitModel(LightningModule):
    def __init__(self):
        ...
        self.register_buffer("sigma", torch.eye(3))
        # you can now access self.sigma anywhere in your module

Remove samplers

DistributedSampler is automatically handled by Lightning.

See replace_sampler_ddp for more information.

Synchronize validation and test logging

When running in distributed mode, we have to ensure that the validation and test step logging calls are synchronized across processes. This is done by adding sync_dist=True to all self.log calls in the validation and test step. This ensures that each GPU worker has the same behaviour when tracking model checkpoints, which is important for later downstream tasks such as testing the best checkpoint across all workers. The sync_dist option can also be used in logging calls during the step methods, but be aware that this can lead to significant communication overhead and slow down your training.

Note: if you use any built-in metrics or custom metrics that use TorchMetrics, these do not need to be updated and are automatically handled for you.

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = self.loss(logits, y)
    # Add sync_dist=True to sync logging across all GPU workers (may have performance impact)
    self.log("validation_loss", loss, on_step=True, on_epoch=True, sync_dist=True)


def test_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = self.loss(logits, y)
    # Add sync_dist=True to sync logging across all GPU workers (may have performance impact)
    self.log("test_loss", loss, on_step=True, on_epoch=True, sync_dist=True)
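To see why synchronized logging matters for checkpoint tracking, the following pure-Python sketch simulates two GPU workers that each log a validation loss. It is not Lightning code: the loss values are made up, and the mean reduction is an assumption about how the per-worker values are combined.

```python
from statistics import mean


def sync_mean(per_rank_values):
    """Simulate a mean reduction: every rank ends up with the same averaged value."""
    reduced = mean(per_rank_values)
    return [reduced] * len(per_rank_values)


def best_epoch(losses_per_epoch):
    """Index of the epoch with the lowest loss, as a checkpoint tracker might pick it."""
    return min(range(len(losses_per_epoch)), key=losses_per_epoch.__getitem__)


# Hypothetical validation losses: rows are epochs, columns are GPU workers (ranks).
val_loss = [
    [0.50, 0.70],  # epoch 0
    [0.40, 0.75],  # epoch 1
    [0.60, 0.30],  # epoch 2
]
num_ranks = 2

# Without syncing, each rank only tracks the loss on its own shard of the data.
unsynced_best = [
    best_epoch([row[rank] for row in val_loss]) for rank in range(num_ranks)
]

# With syncing, each rank sees the same reduced loss for every epoch.
synced = [sync_mean(row) for row in val_loss]
synced_best = [
    best_epoch([row[rank] for row in synced]) for rank in range(num_ranks)
]

print("best epoch per rank without sync:", unsynced_best)
print("best epoch per rank with sync:   ", synced_best)
```

In this toy setup the unsynchronized workers disagree about which epoch was best, while the synchronized ones agree.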

It is possible to perform some computation manually and log the reduced result on rank 0 as follows:

def test_step(self, batch, batch_idx):
    x, y = batch
    tensors = self(x)
    return tensors


def test_epoch_end(self, outputs):
    mean = torch.mean(self.all_gather(outputs))

    # When logging only on rank 0, don't forget to add
    # ``rank_zero_only=True`` to avoid deadlocks on synchronization.
    if self.trainer.is_global_zero:
        self.log("my_reduced_metric", mean, rank_zero_only=True)

Make models pickleable

It’s very likely your code is already pickleable; in that case no change is necessary. However, if you run a distributed model and get the following error:

self._launch(process_obj)
File "/net/software/local/python/3.6.5/lib/python3.6/multiprocessing/popen_spawn_posix.py", line 47,
in _launch reduction.dump(process_obj, fp)
File "/net/software/local/python/3.6.5/lib/python3.6/multiprocessing/reduction.py", line 60, in dump
ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function <lambda> at 0x2b599e088ae8>:
attribute lookup <lambda> on __main__ failed

This means something in your model definition, transforms, optimizer, dataloader or callbacks cannot be pickled, and the following code will fail:

import pickle

pickle.dumps(some_object)

This is a limitation of using multiple processes for distributed training within PyTorch. To fix this issue, find the piece of your code that cannot be pickled. The end of the stack trace is usually helpful. For example, in the stack trace shown here, there seems to be a lambda function somewhere in the code which cannot be pickled.

self._launch(process_obj)
File "/net/software/local/python/3.6.5/lib/python3.6/multiprocessing/popen_spawn_posix.py", line 47,
in _launch reduction.dump(process_obj, fp)
File "/net/software/local/python/3.6.5/lib/python3.6/multiprocessing/reduction.py", line 60, in dump
ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle [THIS IS THE THING TO FIND AND DELETE]:
attribute lookup <lambda> on __main__ failed
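One way to narrow down the offending object is to try pickling each candidate separately. The sketch below uses only the standard library; the helper names and the example components are hypothetical and are not part of Lightning.

```python
import functools
import operator
import pickle


def is_picklable(obj):
    """Return True if ``obj`` survives ``pickle.dumps``, False otherwise."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


def find_unpicklable(named_objects):
    """Return the names of the objects that cannot be pickled."""
    return [name for name, obj in named_objects.items() if not is_picklable(obj)]


# Hypothetical pieces of a training setup; the names are for illustration only.
components = {
    # A lambda has no importable name, so pickle cannot look it up.
    "transform_lambda": lambda x: x * 2,
    # A partial over a named function does the same job and can be pickled.
    "transform_partial": functools.partial(operator.mul, 2),
    "hparams": {"lr": 0.01, "batch_size": 32},
}

print(find_unpicklable(components))
```

Replacing the lambda with a named function or a functools.partial over a named function is usually enough to make the component picklable.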

Select GPU devices

You can select the GPU devices using ranges, a list of indices or a string containing a comma separated list of GPU ids:

# DEFAULT (int) specifies how many GPUs to use per node
Trainer(accelerator="gpu", devices=k)

# Above is equivalent to
Trainer(accelerator="gpu", devices=list(range(k)))

# Specify which GPUs to use (don't use when running on cluster)
Trainer(accelerator="gpu", devices=[0, 1])

# Equivalent using a string
Trainer(accelerator="gpu", devices="0, 1")

# To use all available GPUs put -1 or '-1'
# equivalent to list(range(torch.cuda.device_count()))
Trainer(accelerator="gpu", devices=-1)

The table below lists examples of possible input formats and how they are interpreted by Lightning.

devices    Type    Parsed          Meaning
3          int     [0, 1, 2]       first 3 GPUs
-1         int     [0, 1, 2, …]    all available GPUs
[0]        list    [0]             GPU 0
[1, 3]     list    [1, 3]          GPUs 1 and 3
"3"        str     [0, 1, 2]       first 3 GPUs
"1, 3"     str     [1, 3]          GPUs 1 and 3
"-1"       str     [0, 1, 2, …]    all available GPUs
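The interpretation rules in the table can be mimicked with a few lines of plain Python. This is only a sketch of the documented behaviour, not Lightning's actual parser; parse_devices and its num_available argument are hypothetical names introduced for illustration.

```python
def parse_devices(devices, num_available):
    """Turn a ``devices`` value into a list of GPU indices (illustrative only)."""
    if isinstance(devices, str):
        if "," in devices:
            # A comma separated string lists explicit GPU ids.
            devices = [int(part) for part in devices.split(",") if part.strip()]
        else:
            # A single number in a string is treated like an int.
            devices = int(devices)
    if isinstance(devices, int):
        if devices == -1:
            # -1 means "all available GPUs".
            return list(range(num_available))
        # A non-negative int k means "the first k GPUs".
        return list(range(devices))
    # Lists are already explicit GPU ids.
    return list(devices)


for value in (3, -1, [0], [1, 3], "3", "1, 3", "-1"):
    print(repr(value), "->", parse_devices(value, num_available=4))
```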

Note

When specifying number of devices as an integer devices=k, setting the trainer flag auto_select_gpus=True will automatically help you find k GPUs that are not occupied by other processes. This is especially useful when GPUs are configured to be in “exclusive mode”, such that only one process at a time can access them. For more details see the trainer guide.

Select torch distributed backend

By default, Lightning will select the nccl backend over gloo when running on GPUs. Find more information about PyTorch’s supported backends here.

Lightning allows explicitly specifying the backend via the process_group_backend constructor argument on the relevant Strategy classes. By default, Lightning will select the appropriate process group backend based on the hardware used.

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# Explicitly specify the process group backend if you choose to
ddp = DDPStrategy(process_group_backend="nccl")

# Configure the strategy on the Trainer
trainer = Trainer(strategy=ddp, accelerator="gpu", devices=8)

Distributed modes

Lightning allows multiple ways of training:

  • Data Parallel (strategy='dp') (multiple-gpus, 1 machine)

  • DistributedDataParallel (strategy='ddp') (multiple-gpus across many machines (python script based)).

  • DistributedDataParallel (strategy='ddp_spawn') (multiple-gpus across many machines (spawn based)).

  • DistributedDataParallel 2 (strategy='ddp2') (DP in a machine, DDP across machines).

  • Horovod (strategy='horovod') (multi-machine, multi-gpu, configured at runtime)

  • Bagua (strategy='bagua') (multiple-gpus across many machines with advanced training algorithms)

  • TPUs (accelerator="tpu", devices=8|x) (tpu or TPU pod)

Note

If you request multiple GPUs or nodes without setting a mode, DDP Spawn will be automatically used.

For a deeper understanding of what Lightning is doing, feel free to read this guide.

Data Parallel

DataParallel (DP) splits a batch across k GPUs. That is, if you have a batch of 32 and use DP with 2 GPUs, each GPU will process 16 samples, after which the root node will aggregate the results.

Warning

DP use is discouraged by PyTorch and Lightning. State is not maintained on the replicas created by the DataParallel wrapper and you may see errors or misbehavior if you assign state to the module in the forward() or *_step() methods. For the same reason we cannot fully support Manual Optimization with DP. Use DDP which is more stable and at least 3x faster.

Warning

DP only supports scattering and gathering primitive collections of tensors like lists, dicts, etc. Therefore the transfer_batch_to_device() hook does not apply in this mode and if you have overridden it, it will not be called.

# train on 2 GPUs (using DP mode)
trainer = Trainer(accelerator="gpu", devices=2, strategy="dp")

Distributed Data Parallel

DistributedDataParallel (DDP) works as follows:

  1. Each GPU across each node gets its own process.

  2. Each GPU gets visibility into a subset of the overall dataset. It will only ever see that subset.

  3. Each process inits the model.

  4. Each process performs a full forward and backward pass in parallel.

  5. The gradients are synced and averaged across all processes.

  6. Each process updates its optimizer.

# train on 8 GPUs (same machine (ie: node))
trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp")

# train on 32 GPUs (4 nodes)
trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp", num_nodes=4)
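The DDP steps listed above can be simulated without any GPUs. The sketch below is a toy, pure-Python model of the idea (a one-parameter linear model, a strided data split, and a plain average standing in for the gradient sync); none of these function names come from Lightning or PyTorch.

```python
def shard(dataset, rank, world_size):
    """Give each rank a disjoint, strided subset of the data."""
    return dataset[rank::world_size]


def local_gradient(w, samples):
    """Gradient of the mean squared error of y = w * x over ``samples``."""
    return sum(2 * (w * x - y) * x for x, y in samples) / len(samples)


def ddp_step(weights, dataset, lr):
    """One simulated DDP step: local backward passes, then gradient averaging."""
    world_size = len(weights)
    grads = [
        local_gradient(w, shard(dataset, rank, world_size))
        for rank, w in enumerate(weights)
    ]
    avg_grad = sum(grads) / world_size  # stands in for the gradient sync
    # Every process applies the same averaged gradient to its own model copy.
    return [w - lr * avg_grad for w in weights]


# Toy data generated from y = 3 * x; each of the 4 "processes" sees 2 samples.
dataset = [(float(x), 3.0 * x) for x in range(1, 9)]
weights = [0.0] * 4  # one model copy per process, identically initialised

for _ in range(50):
    weights = ddp_step(weights, dataset, lr=0.01)

print(weights)
```

Because every copy starts from the same weight and applies the same averaged gradient, the copies stay identical even though each one only ever sees its own subset of the data.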

This Lightning implementation of DDP calls your script under the hood multiple times with the correct environment variables:

# example for 3 GPUs DDP
MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=0 python my_file.py --accelerator 'gpu' --devices 3 --etc
MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=1 python my_file.py --accelerator 'gpu' --devices 3 --etc
MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=2 python my_file.py --accelerator 'gpu' --devices 3 --etc
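As a rough sketch of how these variables relate to each other, the snippet below builds one environment per process, assuming that NODE_RANK identifies the machine and LOCAL_RANK identifies the GPU within that machine. The helper names and the fixed port are hypothetical, and the global rank formula is the usual convention rather than something taken from the text above.

```python
def global_rank(node_rank, local_rank, devices_per_node):
    """Conventional global rank: processes are numbered node by node."""
    return node_rank * devices_per_node + local_rank


def ddp_process_envs(num_nodes, devices_per_node, master_port=29500):
    """Build the environment variables for every process of a DDP job."""
    world_size = num_nodes * devices_per_node
    envs = []
    for node_rank in range(num_nodes):
        for local_rank in range(devices_per_node):
            envs.append(
                {
                    "MASTER_ADDR": "localhost",
                    "MASTER_PORT": str(master_port),  # hypothetical fixed port
                    "WORLD_SIZE": str(world_size),
                    "NODE_RANK": str(node_rank),
                    "LOCAL_RANK": str(local_rank),
                }
            )
    return envs


# One machine with 3 GPUs: three processes that differ only in LOCAL_RANK.
for env in ddp_process_envs(num_nodes=1, devices_per_node=3):
    print(" ".join(f"{key}={value}" for key, value in env.items()))
```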

We use DDP this way because ddp_spawn has a few limitations (due to Python and PyTorch):

  1. Since .spawn() trains the model in subprocesses, the model on the main process does not get updated.

  2. DataLoader(num_workers=N), where N is large, bottlenecks training with DDP, i.e., it will be VERY slow or won’t work at all. This is a PyTorch limitation.

  3. Forces everything to be picklable.

There are cases in which it is NOT possible to use DDP. Examples are:

  • Jupyter Notebook, Google COLAB, Kaggle, etc.

  • You have a nested script without a root package

In these situations you should use dp or ddp_spawn instead.

Distributed Data Parallel 2

In certain cases, it’s advantageous to use all batches on the same machine instead of a subset. For instance, you might want to compute an NCE loss where it pays to have more negative samples.

In this case, we can use DDP2 which behaves like DP in a machine and DDP across nodes. DDP2 does the following:

  1. Copies a subset of the data to each node.

  2. Inits a model on each node.

  3. Runs a forward and backward pass using DP.

  4. Syncs gradients across nodes.

  5. Applies the optimizer updates.

# train on 32 GPUs (4 nodes)
trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp2", num_nodes=4)

Distributed Data Parallel Spawn

ddp_spawn is exactly like ddp except that it uses .spawn to start the training processes.

Warning

It is STRONGLY recommended to use DDP for speed and performance.

mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

If your script does not support being called from the command line (ie: it is nested without a root project module) you can use the following method:

# train on 8 GPUs (same machine (ie: node))
trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp_spawn")

We STRONGLY discourage this use because it has limitations (due to Python and PyTorch):

  1. The model you pass in will not update. Please save a checkpoint and restore from there.

  2. Set Dataloader(num_workers=0) or it will bottleneck training.

ddp is MUCH faster than ddp_spawn. We recommend that you:

  1. Install a top-level module for your project using setup.py

#!/usr/bin/env python
# setup.py

from setuptools import setup, find_packages

setup(
    name="src",
    version="0.0.1",
    description="Describe Your Cool Project",
    author="",
    author_email="",
    url="https://github.com/YourSeed",  # REPLACE WITH YOUR OWN GITHUB PROJECT LINK
    install_requires=["pytorch-lightning"],
    packages=find_packages(),
)

  2. Set up your project like so:

/project
    /src
        some_file.py
        /or_a_folder
    setup.py

  3. Install as a root-level package:

cd /project
pip install -e .

You can then call your scripts anywhere

cd /project/src
python some_file.py --accelerator 'gpu' --devices 8 --strategy 'ddp'

Horovod

Horovod allows the same training script to be used for single-GPU, multi-GPU, and multi-node training.

Like Distributed Data Parallel, every process in Horovod operates on a single GPU with a fixed subset of the data. Gradients are averaged across all GPUs in parallel during the backward pass, then synchronously applied before beginning the next step.

The number of worker processes is configured by a driver application (horovodrun or mpirun). In the training script, Horovod will detect the number of workers from the environment, and automatically scale the learning rate to compensate for the increased total batch size.

Horovod can be configured in the training script to run with any number of GPUs / processes as follows:

# train Horovod on GPU (number of GPUs / machines provided on command-line)
trainer = Trainer(strategy="horovod", accelerator="gpu", devices=1)

# train Horovod on CPU (number of processes / machines provided on command-line)
trainer = Trainer(strategy="horovod")

When starting the training job, the driver application will then be used to specify the total number of worker processes:

# run training with 4 GPUs on a single machine
horovodrun -np 4 python train.py

# run training with 8 GPUs on two machines (4 GPUs each)
horovodrun -np 8 -H hostname1:4,hostname2:4 python train.py

See the official Horovod documentation for details on installation and performance tuning.

Bagua

Bagua is a deep learning training acceleration framework which supports multiple advanced distributed training algorithms including:

  • Gradient AllReduce for centralized synchronous communication, where gradients are averaged among all workers.

  • Decentralized SGD for decentralized synchronous communication, where each worker exchanges data with one or a few specific workers.

  • ByteGrad and QAdam for low precision communication, where data is compressed into low precision before communication.

  • Asynchronous Model Average for asynchronous communication, where workers are not required to be synchronized in the same iteration in a lock-step style.

By default, Bagua uses the Gradient AllReduce algorithm, which is also the algorithm implemented in Distributed Data Parallel and Horovod, but Bagua can usually produce a higher training throughput due to its backend written in Rust.

# train on 4 GPUs (using Bagua mode)
trainer = Trainer(strategy="bagua", accelerator="gpu", devices=4)

By specifying the algorithm in the BaguaStrategy, you can select more advanced training algorithms featured by Bagua:

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import BaguaStrategy

# train on 4 GPUs, using Bagua Gradient AllReduce algorithm
trainer = Trainer(
    strategy=BaguaStrategy(algorithm="gradient_allreduce"),
    accelerator="gpu",
    devices=4,
)

# train on 4 GPUs, using Bagua ByteGrad algorithm
trainer = Trainer(
    strategy=BaguaStrategy(algorithm="bytegrad"),
    accelerator="gpu",
    devices=4,
)

# train on 4 GPUs, using Bagua Decentralized SGD
trainer = Trainer(
    strategy=BaguaStrategy(algorithm="decentralized"),
    accelerator="gpu",
    devices=4,
)

# train on 4 GPUs, using Bagua Low Precision Decentralized SGD
trainer = Trainer(
    strategy=BaguaStrategy(algorithm="low_precision_decentralized"),
    accelerator="gpu",
    devices=4,
)

# train on 4 GPUs, using Asynchronous Model Average algorithm, with a synchronization interval of 100ms
trainer = Trainer(
    strategy=BaguaStrategy(algorithm="async", sync_interval_ms=100),
    accelerator="gpu",
    devices=4,
)

To use QAdam, we need to initialize QAdamOptimizer first:

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import BaguaStrategy
from bagua.torch_api.algorithms.q_adam import QAdamOptimizer


class MyModel(pl.LightningModule):
    ...

    def configure_optimizers(self):
        # initialize QAdam Optimizer
        return QAdamOptimizer(self.parameters(), lr=0.05, warmup_steps=100)


model = MyModel()
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=BaguaStrategy(algorithm="qadam"),
)
trainer.fit(model)

Bagua relies on its own launcher to schedule jobs. Below, find examples using bagua.distributed.launch which follows torch.distributed.launch API:

# start training with 8 GPUs on a single node
python -m bagua.distributed.launch --nproc_per_node=8 train.py

If the ssh service is available with passwordless login on each node, you can launch the distributed job on a single node with baguarun, which has a similar syntax to mpirun. When starting the job, baguarun will automatically spawn new processes on each of the training nodes provided by the --host_list option, where each node is described as an IP address followed by an ssh port.

# Run on node1 (or node2) to start training on two nodes (node1 and node2), 8 GPUs per node
baguarun --host_list hostname1:ssh_port1,hostname2:ssh_port2 --nproc_per_node=8 --master_port=port1 train.py

Note

You can also start training in the same way as Distributed Data Parallel. However, system optimizations like Bagua-Net and Performance autotuning can only be enabled through bagua launcher. It is worth noting that with Bagua-Net, Distributed Data Parallel can also achieve better performance without modifying the training script.

See Bagua Tutorials for more details on installation and advanced features.

DP/DDP2 caveats

In DP and DDP2, each GPU within a machine sees a portion of a batch. DP and DDP2 roughly do the following:

def distributed_forward(batch, model):
    batch = torch.Tensor(32, 8)
    gpu_0_batch = batch[:8]
    gpu_1_batch = batch[8:16]
    gpu_2_batch = batch[16:24]
    gpu_3_batch = batch[24:]

    y_0 = model_copy_gpu_0(gpu_0_batch)
    y_1 = model_copy_gpu_1(gpu_1_batch)
    y_2 = model_copy_gpu_2(gpu_2_batch)
    y_3 = model_copy_gpu_3(gpu_3_batch)

    return [y_0, y_1, y_2, y_3]

So, when Lightning calls any of the training_step, validation_step, test_step you will only be operating on one of those pieces.

# the batch here is a portion of the FULL batch
def training_step(self, batch, batch_idx):
    y_0 = batch

For most metrics, this doesn’t really matter. However, if you want to add something to your computational graph (like softmax) using all batch parts you can use the training_step_end step.

def training_step_end(self, outputs):
    # only use when on dp
    outputs = torch.cat(outputs, dim=1)
    # use a distinct name so the softmax function is not shadowed
    probs = torch.softmax(outputs, dim=1)
    out = probs.mean()
    return out

In pseudocode, the full sequence is:

# get data
batch = next(dataloader)

# copy model and data to each gpu
batch_splits = split_batch(batch, num_gpus)
models = copy_model_to_gpus(model)

# in parallel, operate on each batch chunk
all_results = []
for gpu_num in gpus:
    batch_split = batch_splits[gpu_num]
    gpu_model = models[gpu_num]
    out = gpu_model(batch_split)
    all_results.append(out)

# use the full batch for something like softmax
full_out = model.training_step_end(all_results)

To illustrate why this is needed, let’s look at DataParallel:

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)

    # on dp or ddp2 if we did softmax now it would be wrong
    # because batch is actually a piece of the full batch
    return y_hat


def training_step_end(self, step_output):
    # step_output has outputs of each part of the batch

    # do softmax here
    outputs = torch.cat(step_output, dim=1)
    probs = torch.softmax(outputs, dim=1)
    out = probs.mean()

    return out
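The effect of operating on a piece of the batch can be reproduced in plain Python. The sketch below assumes a quantity that is normalized over the whole batch, here a softmax taken across samples, as a stand-in for the softmax mentioned above. The scores and helper names are made up, and the batch is assumed to divide evenly across the GPUs.

```python
import math


def softmax(values):
    """Numerically stable softmax over a flat list of scores."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def split_batch(batch, num_gpus):
    """Split a batch into ``num_gpus`` contiguous, equally sized chunks."""
    size = len(batch) // num_gpus
    return [batch[i * size:(i + 1) * size] for i in range(num_gpus)]


scores = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]  # hypothetical per-sample outputs
chunks = split_batch(scores, num_gpus=4)

# Normalizing inside each step only ever sees one chunk at a time.
per_chunk = [p for chunk in chunks for p in softmax(chunk)]

# Normalizing in a *_step_end style hook sees the concatenated outputs.
gathered = [s for chunk in chunks for s in chunk]
full_batch = softmax(gathered)

print("sum of per-chunk softmax:", sum(per_chunk))
print("sum of full-batch softmax:", sum(full_batch))
```

Each chunk normalizes to 1 on its own, so the per-chunk probabilities sum to the number of chunks, while the full-batch version sums to 1.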

If training_step_end is defined it will be called regardless of TPU, DP, DDP, etc… which means it will behave the same regardless of the backend.

Validation and test step have the same option when using DP.

def validation_step_end(self, step_output):
    ...


def test_step_end(self, step_output):
    ...

Distributed and 16-bit precision

Due to an issue with Apex and DataParallel (PyTorch and NVIDIA issue), Lightning does not allow 16-bit and DP training. We tried to get this to work, but it’s an issue on their end.

Below are the possible configurations we support.

1 GPU   1+ GPUs   DP   DDP   16-bit   command
Y       -         -    -     -        Trainer(accelerator="gpu", devices=1)
Y       -         -    -     Y        Trainer(accelerator="gpu", devices=1, precision=16)
-       Y         Y    -     -        Trainer(accelerator="gpu", devices=k, strategy='dp')
-       Y         -    Y     -        Trainer(accelerator="gpu", devices=k, strategy='ddp')
-       Y         -    Y     Y        Trainer(accelerator="gpu", devices=k, strategy='ddp', precision=16)

Implement Your Own Distributed (DDP) training

If you need your own way to init PyTorch DDP you can override pytorch_lightning.strategies.ddp.DDPStrategy.init_dist_connection().

If you also need to use your own DDP implementation, override pytorch_lightning.strategies.ddp.DDPStrategy.configure_ddp().

Batch size

When using distributed training make sure to modify your learning rate according to your effective batch size.

Let’s say you have a batch size of 7 in your dataloader.

class LitModel(LightningModule):
    def train_dataloader(self):
        return DataLoader(..., batch_size=7)

In DDP, DDP_SPAWN, DeepSpeed, DDP_SHARDED, or Horovod, your effective batch size will be 7 * devices * num_nodes.

# effective batch size = 7 * 8
Trainer(accelerator="gpu", devices=8, strategy="ddp")
Trainer(accelerator="gpu", devices=8, strategy="ddp_spawn")
Trainer(accelerator="gpu", devices=8, strategy="ddp_sharded")
Trainer(accelerator="gpu", devices=8, strategy="horovod")

# effective batch size = 7 * 8 * 10
Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy="ddp")
Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy="ddp_spawn")
Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy="ddp_sharded")
Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy="horovod")

In DDP2 or DP, your effective batch size will be 7 * num_nodes. The reason is that the full batch is visible to all GPUs on the node when using DDP2.

# effective batch size = 7
Trainer(accelerator="gpu", devices=8, strategy="ddp2")
Trainer(accelerator="gpu", devices=8, strategy="dp")

# effective batch size = 7 * 10
# (DP runs on a single machine, so it has no multi-node variant)
Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy="ddp2")
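The arithmetic above can be captured in a small helper. This is a sketch of the rules as stated in this section, not a Lightning API; the function names are hypothetical, and the linear learning-rate scaling is only one common heuristic (the one discussed in the paper referenced in the note below).

```python
def effective_batch_size(batch_size, devices, num_nodes=1, strategy="ddp"):
    """Effective batch size for the strategies described in this section."""
    if strategy in {"ddp", "ddp_spawn", "ddp_sharded", "deepspeed", "horovod"}:
        # Every process loads its own batch.
        return batch_size * devices * num_nodes
    if strategy == "ddp2":
        # One batch per node, split across that node's GPUs.
        return batch_size * num_nodes
    if strategy == "dp":
        # Single machine: one batch, split across the GPUs.
        return batch_size
    raise ValueError(f"unknown strategy: {strategy!r}")


def linearly_scaled_lr(base_lr, base_batch_size, new_batch_size):
    """Scale the learning rate in proportion to the batch size (a heuristic)."""
    return base_lr * new_batch_size / base_batch_size


for strategy, nodes in (("ddp", 1), ("ddp", 10), ("ddp2", 10), ("dp", 1)):
    size = effective_batch_size(7, devices=8, num_nodes=nodes, strategy=strategy)
    print(f"{strategy:5s} num_nodes={nodes:2d} -> effective batch size {size}")
```

For example, a learning rate tuned for a batch of 7 on one GPU would be multiplied by 8 under this heuristic when moving to 8 GPUs with ddp.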

Note

Huge batch sizes are actually really bad for convergence. Check out: Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour


Torch Distributed Elastic

Lightning supports the use of Torch Distributed Elastic to enable fault-tolerant and elastic distributed job scheduling. To use it, specify the ‘ddp’ or ‘ddp2’ backend and the number of GPUs you want to use in the trainer.

Trainer(accelerator="gpu", devices=8, strategy="ddp")

To launch a fault-tolerant job, run the following on all nodes.

python -m torch.distributed.run \
        --nnodes=NUM_NODES \
        --nproc_per_node=TRAINERS_PER_NODE \
        --rdzv_id=JOB_ID \
        --rdzv_backend=c10d \
        --rdzv_endpoint=HOST_NODE_ADDR \
        YOUR_LIGHTNING_TRAINING_SCRIPT.py (--arg1 ... train script args...)

To launch an elastic job, run the following on at least MIN_SIZE nodes and at most MAX_SIZE nodes.

python -m torch.distributed.run \
        --nnodes=MIN_SIZE:MAX_SIZE \
        --nproc_per_node=TRAINERS_PER_NODE \
        --rdzv_id=JOB_ID \
        --rdzv_backend=c10d \
        --rdzv_endpoint=HOST_NODE_ADDR \
        YOUR_LIGHTNING_TRAINING_SCRIPT.py (--arg1 ... train script args...)

See the official Torch Distributed Elastic documentation for details on installation and more use cases.


Jupyter Notebooks

Unfortunately, none of the ddp_ strategies are supported in Jupyter notebooks. Please use dp for multiple GPUs. This is a known Jupyter issue. If you feel like taking a stab at adding this support, feel free to submit a PR!


Pickle Errors

Multi-GPU training sometimes requires your model to be pickled. If you run into an issue with pickling, try the following to find the cause:

import pickle

model = YourModel()
pickle.dumps(model)

However, if you use ddp the pickling requirement is not there and you should be fine. If you use ddp_spawn the pickling requirement remains. This is a limitation of Python.