
TPU support

Lightning supports running on TPUs. At the moment, TPUs are only available on Google Cloud (GCP). For more information on TPUs, watch this video.


Live demo

Check out this Google Colab to see how to train MNIST on TPUs.


TPU Terminology

A TPU is a Tensor Processing Unit. Each TPU has 8 cores, and each core is optimized for 128x128 matrix multiplications. In general, a single TPU is about as fast as 5 V100 GPUs!
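
To make the 128x128 figure concrete, the sketch below counts how many 128x128x128 blocks an (m x k) @ (k x n) matrix product occupies when every dimension is rounded up to a multiple of 128, and how much of that padded work is real. This is a simplified back-of-the-envelope model assumed for illustration, not how XLA actually schedules work; the helper names `count_128_tiles` and `utilization` are hypothetical.

```python
TILE = 128  # tile edge length, taken from the 128x128 figure above


def count_128_tiles(m: int, k: int, n: int) -> int:
    """Number of 128x128x128 blocks in an (m x k) @ (k x n) product.

    Each dimension is rounded up to the next multiple of 128 (simplified model).
    """
    def blocks(d: int) -> int:
        return (d + TILE - 1) // TILE  # integer ceiling division

    return blocks(m) * blocks(k) * blocks(n)


def utilization(m: int, k: int, n: int) -> float:
    """Fraction of the padded blocks that holds real (unpadded) work."""
    padded = count_128_tiles(m, k, n) * TILE ** 3
    return (m * k * n) / padded


print(count_128_tiles(256, 256, 256), utilization(256, 256, 256))  # 8 1.0
print(count_128_tiles(130, 128, 128), utilization(130, 128, 128))  # 2 0.5078125
```

In this model, sizes that are multiples of 128 waste nothing, while a dimension of 130 needs a second, mostly empty block.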

A TPU pod hosts many TPUs. Currently, a TPU v3 pod has up to 2048 cores! You can request a full pod from Google Cloud or a “slice”, which gives you a subset of those 2048 cores.
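
Slices are usually named after their core count (for example `v3-32`). Assuming that naming and 8 cores per TPU device as described above, the sketch below works out how many 8-core devices a slice spans. The helper `devices_in_slice` and its parsing of the slice name are illustrative assumptions.

```python
CORES_PER_DEVICE = 8   # each TPU device has 8 cores (see above)
FULL_POD_CORES = 2048  # largest pod size quoted above


def devices_in_slice(slice_name: str) -> int:
    """Return the number of 8-core TPU devices in a slice such as 'v3-32'.

    Assumes the name has the form '<version>-<cores>'.
    """
    _, cores_str = slice_name.rsplit("-", 1)
    cores = int(cores_str)
    if cores % CORES_PER_DEVICE != 0 or cores > FULL_POD_CORES:
        raise ValueError(f"unexpected core count in {slice_name!r}")
    return cores // CORES_PER_DEVICE


for name in ["v3-8", "v3-32", "v3-2048"]:
    print(name, devices_in_slice(name))
```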


How to access TPUs

There are two main ways to access TPUs:

  1. Using Google Colab.

  2. Using Google Cloud (GCP).


Colab TPUs

Colab is like a Jupyter notebook with a free GPU or TPU hosted on GCP.

To get a TPU on colab, follow these steps:

  1. Go to https://colab.research.google.com/.

  2. Click “new notebook” (bottom right of pop-up).

  3. Click Runtime > Change runtime type. Select Python 3 and the hardware accelerator “TPU”. This will give you a TPU with 8 cores.

  4. Next, insert this code into the first cell and execute it. This installs the XLA library (torch_xla) that interfaces between PyTorch and the TPU.

    import collections
    from datetime import datetime, timedelta
    import os
    import requests
    import threading
    
    _VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
    VERSION = "xrt==1.15.0"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
    CONFIG = {
        'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
        'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
            (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
    }[VERSION]
    DIST_BUCKET = 'gs://tpu-pytorch/wheels'
    TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
    TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
    TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
    
    # Update TPU XRT version
    def update_server_xrt():
      print('Updating server-side XRT to {} ...'.format(CONFIG.server))
      url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
          TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
          XRT_VERSION=CONFIG.server,
      )
      print('Done updating server-side XRT: {}'.format(requests.post(url)))
    
    update = threading.Thread(target=update_server_xrt)
    update.start()
    
    # Install Colab TPU compat PyTorch/TPU wheels and dependencies
    !pip uninstall -y torch torchvision
    !gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
    !gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
    !gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
    !pip install "$TORCH_WHEEL"
    !pip install "$TORCH_XLA_WHEEL"
    !pip install "$TORCHVISION_WHEEL"
    !sudo apt-get install libomp5
    update.join()
    
  5. Once the above is done, install PyTorch Lightning (v0.7.0+).

    !pip install pytorch-lightning
    
  6. Then set up your LightningModule as normal.
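
To confirm that the notebook is really attached to a TPU (for example, if the install step fails), you can check the `COLAB_TPU_ADDR` environment variable. The install script above assumes Colab sets it for TPU runtimes and takes the host as the part before the colon; the sketch below parses it the same way. The helper name `colab_tpu_host` is hypothetical.

```python
import os
from typing import Optional


def colab_tpu_host() -> Optional[str]:
    """Return the TPU host from COLAB_TPU_ADDR, or None if it is not set."""
    addr = os.environ.get("COLAB_TPU_ADDR")
    if not addr:
        return None
    return addr.split(":")[0]  # same split as in the install script


host = colab_tpu_host()
if host is None:
    print("No TPU found: set the hardware accelerator to TPU (step 3).")
else:
    print(f"TPU reachable at {host}")
```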


DistributedSamplers

Lightning automatically inserts the correct samplers, so there is no need to do this yourself!

Usually, with TPUs (and DDP), you would need to define a DistributedSampler to send the right chunk of data to each TPU core. As mentioned, this is not needed in Lightning.

Note

Don’t add DistributedSamplers. Lightning does this automatically.

If for some reason you still need to, this is how to construct the sampler for TPU use:

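import os

import torch_xla.core.xla_model as xm
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import transforms
from torchvision.datasets import MNIST

# Hypothetical flag: set it however your code tracks whether it runs on TPU
use_tpu = True


# Defined as a method of your LightningModule
def train_dataloader(self):
    dataset = MNIST(
        os.getcwd(),
        train=True,
        download=True,
        transform=transforms.ToTensor()
    )

    # required for TPU support: give each TPU core its own shard of the data
    sampler = None
    if use_tpu:
        sampler = DistributedSampler(
            dataset,
            num_replicas=xm.xrt_world_size(),  # total number of TPU cores
            rank=xm.get_ordinal(),             # index of the current core
            shuffle=True
        )

    loader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=32
    )

    return loader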
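
To see what the sampler hands to each core, the sketch below re-implements in plain Python the index-sharding scheme that PyTorch's `DistributedSampler` uses by default (`drop_last=False`): pad the index list by wrapping around so it divides evenly, then give replica `rank` every `num_replicas`-th index starting at `rank`. Shuffling and epochs are left out, so this is a simplified illustration rather than the library code; `shard_indices` is a hypothetical name.

```python
import math
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices that replica `rank` reads, without shuffling."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by wrapping around so every replica gets the same number of samples.
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]
    # Strided split: replica r takes r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


# 10 samples over 8 TPU cores: each core gets 2, and samples 0-5 are reused as padding.
for rank in range(8):
    print(rank, shard_indices(10, 8, rank))
```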

Configure the number of TPU cores in the trainer. You can only choose 1 or 8. To use a full TPU pod skip to the TPU pod section.

import pytorch_lightning as pl

my_model = MyLightningModule()
trainer = pl.Trainer(num_tpu_cores=8)
trainer.fit(my_model)

That’s it! Your model will train on all 8 TPU cores.
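
Because each core runs its own copy of the training loop with its own dataloader, the `batch_size` set in the dataloader is per core. The sketch below does that arithmetic for the two allowed `num_tpu_cores` values; it assumes every core sees a full batch on each step, and the helper name `effective_batch_size` is hypothetical.

```python
ALLOWED_TPU_CORES = (1, 8)  # the only values num_tpu_cores accepts (see above)


def effective_batch_size(per_core_batch: int, num_tpu_cores: int) -> int:
    """Samples processed per optimizer step across all cores."""
    if num_tpu_cores not in ALLOWED_TPU_CORES:
        raise ValueError("num_tpu_cores must be 1 or 8; use a TPU pod for more cores")
    return per_core_batch * num_tpu_cores


print(effective_batch_size(32, 8))  # 256
```

With `batch_size=32` in the dataloader and `num_tpu_cores=8`, each optimizer step therefore sees 256 samples in total.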


Distributed Backend with TPU

The `distributed_backend` option used for GPUs does not apply to TPUs. TPUs work in DDP mode by default (one process per core).


TPU Pod

To train on more than 8 cores, your code actually doesn’t change! All you need to do is submit the following command:

$ python -m torch_xla.distributed.xla_dist \
    --tpu=$TPU_POD_NAME \
    --conda-env=torch-xla-nightly \
    -- python /usr/share/torch-xla-0.5/pytorch/xla/test/test_train_imagenet.py --fake_data
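
If you launch pod jobs from a script, the same command can be assembled with Python's standard library. The sketch below only rebuilds the arguments shown above (`--tpu`, `--conda-env`, and the training command after `--`); reading the pod name from the `TPU_POD_NAME` environment variable, the fallback name `my-tpu-pod`, and the function name `xla_dist_command` are assumptions.

```python
import os
import shlex
from typing import List


def xla_dist_command(pod_name: str, conda_env: str, train_cmd: List[str]) -> List[str]:
    """Argument list for launching `train_cmd` on a TPU pod via xla_dist."""
    return [
        "python", "-m", "torch_xla.distributed.xla_dist",
        f"--tpu={pod_name}",
        f"--conda-env={conda_env}",
        "--",  # everything after this is the training command itself
        *train_cmd,
    ]


cmd = xla_dist_command(
    os.environ.get("TPU_POD_NAME", "my-tpu-pod"),  # hypothetical fallback name
    "torch-xla-nightly",
    ["python", "/usr/share/torch-xla-0.5/pytorch/xla/test/test_train_imagenet.py", "--fake_data"],
)
print(shlex.join(cmd))  # requires Python 3.8+
# subprocess.run(cmd, check=True) would start the job
```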

16 bit precision

Lightning also supports training in 16-bit precision with TPUs. By default, TPU training uses 32-bit precision. To enable 16-bit precision, also set `precision=16`:

import pytorch_lightning as pl

my_model = MyLightningModule()
trainer = pl.Trainer(num_tpu_cores=8, precision=16)
trainer.fit(my_model)

Under the hood, the XLA library uses the bfloat16 type.
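
bfloat16 keeps float32's 8-bit exponent but only 7 explicit mantissa bits, so it covers the same numeric range with less precision. The sketch below emulates the conversion in plain Python by rounding a float32 to its upper 16 bits (round-to-nearest-even; NaN, infinity and overflow are not handled specially). It illustrates the format only; it is not what XLA executes, and `to_bfloat16` is a hypothetical name.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value and return it as a Python float."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]  # float32 bit pattern
    # Round to nearest even on the 16 low bits that will be dropped.
    rounding = 0x7FFF + ((bits >> 16) & 1)
    bf16_bits = ((bits + rounding) >> 16) & 0xFFFF
    # Put the 16 kept bits back in the high half of a float32.
    return struct.unpack(">f", struct.pack(">I", bf16_bits << 16))[0]


print(to_bfloat16(3.14159265))  # 3.140625
print(to_bfloat16(1e38))        # still finite: same exponent range as float32
```

The second call shows why bfloat16 rarely needs the loss scaling used with float16: large values stay representable, and only precision is lost.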


About XLA

XLA is the library that interfaces PyTorch with TPUs. For more information, check out XLA.

Guide for troubleshooting XLA