
Strategy

The Strategy controls how the model is distributed across training, evaluation, and prediction, and is used by the Trainer. It can be selected by passing one of the built-in aliases ("ddp", "ddp_spawn", "deepspeed", and so on) or a custom strategy object to the strategy parameter of the Trainer.

The Strategy in PyTorch Lightning handles the following responsibilities:

  • Launches and tears down training processes (if applicable).

  • Sets up communication between processes (NCCL, GLOO, MPI, and so on).

  • Provides a unified communication interface for reduction, broadcast, and so on (see the sketch below).

  • Owns the LightningModule.

  • Owns and handles the optimizers and schedulers.

Strategy also manages the accelerator, precision, and checkpointing plugins.
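As an illustration of the unified communication interface, the following sketch (not taken from the official docs; the module and metric names are made up) logs a loss that is averaged across all processes by whichever strategy is active:

import torch
import pytorch_lightning as pl


class CommDemoModel(pl.LightningModule):
    """Hypothetical module used only to illustrate the strategy communication API."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # strategy.reduce() averages the value across all processes, whether we run
        # with "ddp", "ddp_spawn", or a single device, so this code stays backend-agnostic
        avg_loss = self.trainer.strategy.reduce(loss, reduce_op="mean")
        self.log("train_loss_avg", avg_loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)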

Training Strategies with Various Configs

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# Training with the DistributedDataParallel strategy on 4 GPUs
trainer = Trainer(strategy="ddp", accelerator="gpu", devices=4)

# Training with the custom DistributedDataParallel strategy on 4 GPUs
trainer = Trainer(strategy=DDPStrategy(...), accelerator="gpu", devices=4)

# Training with the DDP Spawn strategy using auto accelerator selection
trainer = Trainer(strategy="ddp_spawn", accelerator="auto", devices=4)

# Training with the DeepSpeed strategy on available GPUs
trainer = Trainer(strategy="deepspeed", accelerator="gpu", devices="auto")

# Training with the DDP strategy using 3 CPU processes
trainer = Trainer(strategy="ddp", accelerator="cpu", devices=3)

# Training with the DDP Spawn strategy on 8 TPU cores
trainer = Trainer(strategy="ddp_spawn", accelerator="tpu", devices=8)

# Training with the default IPU strategy on 8 IPUs
trainer = Trainer(accelerator="ipu", devices=8)
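The alias passed to strategy resolves to one of the strategy classes listed later in this document. A small sketch to check which class a given configuration resolves to (assuming the requested accelerator is available on the machine):

# the "ddp_spawn" alias resolves to the DDPSpawnStrategy class
trainer = Trainer(strategy="ddp_spawn", accelerator="cpu", devices=2)
print(type(trainer.strategy).__name__)  # DDPSpawnStrategy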

Create a Custom Strategy

Expert users may choose to extend an existing strategy by overriding its methods.

from pytorch_lightning.strategies import DDPStrategy


class CustomDDPStrategy(DDPStrategy):
    def configure_ddp(self):
        # wrap the LightningModule with a custom DistributedDataParallel
        # subclass instead of the default torch wrapper
        self.model = MyCustomDistributedDataParallel(
            self.model,
            device_ids=...,
        )

Alternatively, they can subclass the base Strategy class to create an entirely new strategy. Either way, the custom strategy can then be passed into the Trainer directly via the strategy parameter.

# custom strategy
trainer = Trainer(strategy=CustomDDPStrategy())

# fully custom accelerator, precision plugin, and strategy
accelerator = MyAccelerator()
precision_plugin = MyPrecisionPlugin()
training_strategy = CustomDDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin)
trainer = Trainer(strategy=training_strategy)
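For completeness, here is a minimal sketch of subclassing the base Strategy class directly. The name SingleCPUStrategy is illustrative, and only the abstract methods of Strategy are filled in, assuming a single process running on the CPU:

import torch
from pytorch_lightning.strategies import Strategy


class SingleCPUStrategy(Strategy):
    """Hypothetical strategy: one process on one CPU device."""

    @property
    def root_device(self) -> torch.device:
        return torch.device("cpu")

    @property
    def is_global_zero(self) -> bool:
        return True

    def model_to_device(self) -> None:
        # move the model owned by this strategy to its device
        self.model.to(self.root_device)

    def reduce(self, tensor, group=None, reduce_op="mean"):
        return tensor  # nothing to reduce with a single process

    def barrier(self, name=None):
        pass  # no other processes to wait for

    def broadcast(self, obj, src=0):
        return obj

    def all_gather(self, tensor, group=None, sync_grads=False):
        return tensor


trainer = Trainer(strategy=SingleCPUStrategy())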

The complete list of built-in strategies is provided below.


Built-In Training Strategies

BaguaStrategy

Strategy for training using the Bagua library, with advanced distributed training algorithms and system optimizations.

DDP2Strategy

DDP2 behaves like DP within a node, while synchronization across nodes behaves like DDP.

DDPFullyShardedStrategy

Plugin for Fully Sharded Data Parallel provided by FairScale.

DDPShardedStrategy

Optimizer and gradient sharded training provided by FairScale.

DDPSpawnShardedStrategy

Optimizer sharded training provided by FairScale.

DDPSpawnStrategy

Spawns processes using the torch.multiprocessing.spawn() method and joins processes after training finishes.

DDPStrategy

Strategy for multi-process single-device training on one or multiple nodes.

DataParallelStrategy

Implements data-parallel training in a single process, i.e., the model gets replicated to each device and each gets a split of the data.

DeepSpeedStrategy

Provides capabilities to run training using the DeepSpeed library, with training optimizations for large, billion-parameter models (see the sketch after this list).

HorovodStrategy

Plugin for Horovod distributed training integration.

HPUParallelStrategy

Strategy for distributed training on multiple HPU devices.

IPUStrategy

Plugin for training on IPU devices.

ParallelStrategy

Plugin for training with multiple processes in parallel.

SingleDeviceStrategy

Strategy that handles communication on a single device.

SingleHPUStrategy

Strategy for training on a single HPU device.

SingleTPUStrategy

Strategy for training on a single TPU device.

Strategy

Base class for all strategies that change the behaviour of the training, validation, and test loop.

TPUSpawnStrategy

Strategy for training on multiple TPU devices using the torch_xla.distributed.xla_multiprocessing.spawn() method.
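As referenced in the DeepSpeedStrategy entry above, a strategy object can also be configured explicitly instead of through its alias. A small sketch (the ZeRO stage and offload settings are illustrative, not recommendations):

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DeepSpeedStrategy

# ZeRO stage 3 shards parameters, gradients, and optimizer states across GPUs;
# offload_optimizer moves the optimizer states to CPU memory
trainer = Trainer(
    strategy=DeepSpeedStrategy(stage=3, offload_optimizer=True),
    accelerator="gpu",
    devices=4,
    precision=16,
)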