TPU training (Intermediate)

Audience: Users looking to use cloud TPUs.

Warning

This is an experimental feature.


DistributedSamplers

Lightning automatically inserts the correct samplers - no need to do this yourself!

Usually, with TPUs (and DDP), you would need to define a DistributedSampler to move the right chunk of data to the appropriate TPU. As mentioned, this is not needed in Lightning.

Note

Don’t add a DistributedSampler yourself. Lightning adds the correct one automatically.

If for some reason you still need to, this is how to construct the sampler for TPU use:

import os

import torch
import torch_xla.core.xla_model as xm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


def train_dataloader(self):
    dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())

    # required for TPU support: shard the dataset across TPU cores
    sampler = None
    if use_tpu:
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True
        )

    loader = DataLoader(dataset, sampler=sampler, batch_size=32)

    return loader

Configure the number of TPU cores in the trainer. You can only choose 1 or 8. To use a full TPU pod, skip to the TPU pod section.

import lightning as L

my_model = MyLightningModule()
trainer = L.Trainer(accelerator="tpu", devices=8)
trainer.fit(my_model)

That’s it! Your model will train on all 8 TPU cores.
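
To train on a single TPU core instead, for example while debugging, set devices=1. A minimal sketch:

import lightning as L

my_model = MyLightningModule()
# train on a single TPU core rather than all 8
trainer = L.Trainer(accelerator="tpu", devices=1)
trainer.fit(my_model)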


16-bit precision

Lightning also supports training in 16-bit precision on TPUs. By default, TPU training uses 32-bit precision. To enable 16-bit precision, set the precision argument:

import lightning as L

my_model = MyLightningModule()
trainer = L.Trainer(accelerator="tpu", precision="16-true")
trainer.fit(my_model)

Under the hood, the xla library will use the bfloat16 type.
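
Since TPUs compute natively in bfloat16, you can also request it explicitly. A minimal sketch, assuming the "bf16-true" precision setting is available in your Lightning version:

import lightning as L

my_model = MyLightningModule()
# request bfloat16 directly instead of relying on the float16 -> bfloat16 mapping
trainer = L.Trainer(accelerator="tpu", devices=8, precision="bf16-true")
trainer.fit(my_model)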