Fabric Arguments

accelerator

Choose one of "cpu", "gpu", "tpu", "auto".

from lightning.fabric import Fabric

# CPU accelerator
fabric = Fabric(accelerator="cpu")

# Running with GPU Accelerator using 2 GPUs
fabric = Fabric(devices=2, accelerator="gpu")

# Running with TPU Accelerator using 8 TPU cores
fabric = Fabric(devices=8, accelerator="tpu")

# Running with GPU Accelerator using the DistributedDataParallel strategy
fabric = Fabric(devices=4, accelerator="gpu", strategy="ddp")

The "auto" option recognizes the machine you are on and selects the available accelerator.

# If your machine has GPUs, it will use the GPU Accelerator
fabric = Fabric(devices=2, accelerator="auto")

See also: Accelerate your code with Fabric

strategy

Choose a training strategy: "dp", "ddp", "ddp_spawn", "xla", "deepspeed", "fsdp".

# Running with the DistributedDataParallel strategy on 4 GPUs
fabric = Fabric(strategy="ddp", accelerator="gpu", devices=4)

# Running with the DDP Spawn strategy using 4 CPU processes
fabric = Fabric(strategy="ddp_spawn", accelerator="cpu", devices=4)

Additionally, you can pass in a strategy object configured with additional parameters:

from lightning.fabric.strategies import DeepSpeedStrategy

fabric = Fabric(strategy=DeepSpeedStrategy(stage=2), accelerator="gpu", devices=2)

See also: Launch distributed training

devices

Configure the devices to run on. Can be of type:

  • int: the number of devices (e.g., GPUs) to train on

  • list of int: which device index (e.g., GPU ID) to train on (0-indexed)

  • str: a string representation of one of the above

# default used by Fabric, i.e., use the CPU
fabric = Fabric(devices=None)

# equivalent
fabric = Fabric(devices=0)

# int: run on two GPUs
fabric = Fabric(devices=2, accelerator="gpu")

# list: run on the 2nd (idx 1) and 5th (idx 4) GPUs (by bus ordering)
fabric = Fabric(devices=[1, 4], accelerator="gpu")
fabric = Fabric(devices="1, 4", accelerator="gpu")  # equivalent

# -1: run on all GPUs
fabric = Fabric(devices=-1, accelerator="gpu")
fabric = Fabric(devices="-1", accelerator="gpu")  # equivalent
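To make the accepted forms concrete, here is a minimal pure-Python sketch that resolves a `devices` value into a list of device indices, given how many devices the machine has. `normalize_devices` is a hypothetical helper written for illustration: it follows the forms listed above (int, list, string, -1), but it is not Lightning's own parsing code and may differ from it in edge cases.

```python
def normalize_devices(devices, num_available):
    """Resolve a `devices` value to a list of 0-based device indices (hypothetical sketch)."""
    # Strings: "-1" means all devices, "1, 4" is a list of indices, "2" is a count
    if isinstance(devices, str):
        text = devices.strip()
        if "," in text:
            devices = [int(part) for part in text.split(",") if part.strip()]
        else:
            devices = int(text)
    # None or 0: no accelerator devices (the CPU default described above)
    if devices is None or devices == 0:
        return []
    # -1: every available device
    if devices == -1:
        return list(range(num_available))
    # Positive int: the first `devices` devices
    if isinstance(devices, int):
        if not 0 < devices <= num_available:
            raise ValueError(f"requested {devices} devices, only {num_available} available")
        return list(range(devices))
    # List of ints: explicit indices, checked against what is available
    indices = list(devices)
    for idx in indices:
        if not 0 <= idx < num_available:
            raise ValueError(f"device index {idx} is out of range")
    return indices


print(normalize_devices(2, 8))        # first two devices
print(normalize_devices("1, 4", 8))   # explicit indices 1 and 4
print(normalize_devices(-1, 4))       # all four devices
```

The sketch only models integer indices; how indices map to physical GPUs (for example, bus ordering) is left to the accelerator.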

See also: Launch distributed training

num_nodes

The number of cluster nodes for distributed operation.

# Default used by Fabric
fabric = Fabric(num_nodes=1)

# Run on 8 nodes
fabric = Fabric(num_nodes=8)

Learn more about distributed multi-node training on clusters.
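With strategies that launch one process per device, such as DDP, `num_nodes` combines with `devices` multiplicatively. The arithmetic below is a minimal sketch of that bookkeeping, assuming every node runs the same number of devices; the helper names are hypothetical, and the node-major rank ordering is a common convention rather than something this page specifies.

```python
def world_size(num_nodes, devices_per_node):
    """Total number of processes when one process runs per device."""
    return num_nodes * devices_per_node


def global_rank(node_rank, local_rank, devices_per_node):
    """Cluster-wide rank of a process, given its node rank and local rank (node-major order)."""
    if not 0 <= local_rank < devices_per_node:
        raise ValueError("local_rank must be in [0, devices_per_node)")
    return node_rank * devices_per_node + local_rank


# Example: Fabric(num_nodes=8, devices=4) under a one-process-per-device strategy
print(world_size(8, 4))      # 32 processes in total
print(global_rank(1, 2, 4))  # 3rd process on the 2nd node has global rank 6
```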

precision

There are two different techniques to set the precision: “true” precision and “mixed” precision. For an extensive guide to their differences, see: Save memory with mixed precision

Fabric supports floating-point operations in 64-bit precision (“double”), 32-bit precision (“full”), or 16-bit precision (“half”), with both regular float16 and bfloat16. The selected precision directly affects performance and memory usage, depending on your hardware. Automatic mixed precision settings are denoted by a "-mixed" suffix, while “true” precision settings have a "-true" suffix:

# Default used by Fabric
fabric = Fabric(precision="32-true", devices=1)

# the same as:
fabric = Fabric(precision="32", devices=1)

# 16-bit mixed precision (model weights remain in torch.float32)
fabric = Fabric(precision="16-mixed", devices=1)

# 16-bit bfloat mixed precision (model weights remain in torch.float32)
fabric = Fabric(precision="bf16-mixed", devices=1)

# 8-bit mixed precision via TransformerEngine (model weights get cast to torch.bfloat16)
fabric = Fabric(precision="transformer-engine", devices=1)

# 16-bit precision (model weights get cast to torch.float16)
fabric = Fabric(precision="16-true", devices=1)

# 16-bit bfloat precision (model weights get cast to torch.bfloat16)
fabric = Fabric(precision="bf16-true", devices=1)

# 64-bit (double) precision (model weights get cast to torch.float64)
fabric = Fabric(precision="64-true", devices=1)
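As a rough illustration of these trade-offs, the plain-Python sketch below (no GPU needed) does two things. First, it estimates the memory taken by the model weights alone under each setting, using the weight dtypes stated in the comments above; activations, gradients, and optimizer state are ignored, so real usage is higher. Second, it contrasts float16 and bfloat16 using the standard library's `struct` module: float16 keeps more mantissa bits (finer precision), while bfloat16 keeps float32's exponent range. The bfloat16 conversion is emulated by truncating a float32, an assumption that can differ in the last bit from hardware rounding. All names are hypothetical.

```python
import struct

# Bytes per stored weight, following the comments above:
# "-mixed" keeps weights in float32; "-true" and TransformerEngine cast them.
WEIGHT_BYTES = {
    "64-true": 8,
    "32-true": 4,
    "16-mixed": 4,
    "bf16-mixed": 4,
    "transformer-engine": 2,
    "16-true": 2,
    "bf16-true": 2,
}


def weight_memory_gb(num_params, precision):
    """Memory for the model weights alone, in GB (1e9 bytes)."""
    return num_params * WEIGHT_BYTES[precision] / 1e9


def to_float16(x):
    """Round-trip through IEEE half precision; fails for values beyond float16's max (65504)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x):
    """Emulate bfloat16 by keeping the top 16 bits of a float32 (truncation, not round-to-nearest)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


for setting in ("32-true", "bf16-mixed", "bf16-true"):
    print(setting, weight_memory_gb(7_000_000_000, setting), "GB")

print(to_float16(0.1), to_bfloat16(0.1))  # float16 is closer to 0.1
print(to_bfloat16(70000.0))               # representable, but coarsely
try:
    to_float16(70000.0)
except (OverflowError, struct.error):
    print("70000.0 does not fit in float16")
```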

Precision settings can also be enabled via the plugins argument (see the section on plugins below). One example is the Bitsandbytes weight-quantization plugin for 4-bit and 8-bit precision:

import torch
from lightning.fabric.plugins import BitsandbytesPrecision

precision = BitsandbytesPrecision(mode="nf4-dq", dtype=torch.bfloat16)
fabric = Fabric(plugins=precision)
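For intuition about what weight quantization does, here is a minimal sketch of absmax 8-bit quantization in plain Python: each weight is divided by a shared scale, rounded to an integer in [-127, 127], and multiplied back by the scale when used. This is a generic textbook technique chosen for illustration, not the bitsandbytes implementation; the `nf4-dq` mode above uses a different 4-bit data type. The function names are hypothetical.

```python
def absmax_quantize(weights):
    """Quantize floats to int8 codes sharing one absmax scale."""
    max_abs = max(abs(w) for w in weights)
    # Map the largest magnitude to 127; guard against an all-zero input
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [round(w / scale) for w in weights]
    return codes, scale


def dequantize(codes, scale):
    """Recover approximate float weights from the int8 codes."""
    return [c * scale for c in codes]


weights = [0.3, -1.0, 0.6]
codes, scale = absmax_quantize(weights)
print(codes)                     # small integers in [-127, 127], one byte each
print(dequantize(codes, scale))  # close to the original weights, but not identical
```

Each reconstructed weight is off by at most half a quantization step (`scale / 2`), which is the precision traded for storing 1 byte instead of 4.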

plugins

Plugins allow you to connect arbitrary backends, precision libraries, clusters, etc. To define your own behavior, subclass the relevant class and pass it in. Here’s an example that links up your own ClusterEnvironment:

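from lightning.fabric import Fabric
from lightning.fabric.plugins.environments import ClusterEnvironment


class MyCluster(ClusterEnvironment):
    # Placeholders: replace the returned values with your cluster's settings.
    # ClusterEnvironment also declares other abstract methods (e.g., global_rank,
    # local_rank, node_rank) that a concrete subclass must implement.

    @property
    def main_address(self):
        # Hostname or IP address of the main (rank 0) node
        return your_main_address

    @property
    def main_port(self):
        # Open port on the main node used for process-group communication
        return your_main_port

    def world_size(self):
        # Total number of processes across all nodes
        return the_world_size


# Pass any other Fabric arguments alongside the plugin as needed
fabric = Fabric(plugins=[MyCluster()])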
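The placeholder values in `MyCluster` usually come from whatever your job scheduler exports. As a sketch under that assumption, the hypothetical class below reads them from environment variables. It deliberately does not subclass `ClusterEnvironment`, so it runs without Lightning installed; in a real plugin, these method bodies would go into your `ClusterEnvironment` subclass alongside the other required methods. The variable names `MASTER_ADDR`, `MASTER_PORT`, and `WORLD_SIZE` follow the torch.distributed convention (your scheduler may use different ones), and the address and port below are made up.

```python
import os


class EnvVarCluster:
    """Hypothetical stand-in showing where MyCluster's values could come from."""

    def __init__(self, env=None):
        # Default to the real process environment; a plain dict works for testing
        self.env = os.environ if env is None else env

    @property
    def main_address(self):
        return self.env["MASTER_ADDR"]

    @property
    def main_port(self):
        return int(self.env["MASTER_PORT"])

    def world_size(self):
        return int(self.env["WORLD_SIZE"])


cluster = EnvVarCluster({"MASTER_ADDR": "10.0.0.1", "MASTER_PORT": "29500", "WORLD_SIZE": "16"})
print(cluster.main_address, cluster.main_port, cluster.world_size())
```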

callbacks

A callback class is a collection of methods that the training loop can call at a specific time, for example, at the end of an epoch. Add callbacks to Fabric to inject logic into your training loop from an external callback class.

class MyCallback:
    def on_train_epoch_end(self, results):
        ...

You can then register this callback or multiple ones directly in Fabric:

fabric = Fabric(callbacks=[MyCallback()])

Then, in your training loop, you can call a hook by its name. Any callback objects that have this hook will execute it:

# Call any hook by name
fabric.call("on_train_epoch_end", results={...})
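The dispatch rule above, where every registered callback that defines the named hook executes it, can be modeled in a few lines of plain Python. This is a simplified sketch for intuition, not Fabric's implementation; `CallbackDispatcher`, `RecordLoss`, and `Unrelated` are hypothetical names.

```python
class CallbackDispatcher:
    """Simplified model of calling a hook by name on registered callbacks."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def call(self, hook_name, *args, **kwargs):
        for callback in self.callbacks:
            method = getattr(callback, hook_name, None)
            # Callbacks that don't define this hook are skipped
            if callable(method):
                method(*args, **kwargs)


class RecordLoss:
    def __init__(self):
        self.losses = []

    def on_train_epoch_end(self, results):
        self.losses.append(results["loss"])


class Unrelated:
    def on_validation_end(self):
        pass


recorder = RecordLoss()
dispatcher = CallbackDispatcher([recorder, Unrelated()])
dispatcher.call("on_train_epoch_end", results={"loss": 0.25})  # only RecordLoss reacts
print(recorder.losses)
```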

See also: Callbacks

loggers

Attach one or several loggers/experiment trackers to Fabric for convenient metrics logging.

# Default used by Fabric; no loggers are active
fabric = Fabric(loggers=[])

# Log to a single logger
from lightning.fabric.loggers import TensorBoardLogger

fabric = Fabric(loggers=TensorBoardLogger(...))

# Or multiple instances
fabric = Fabric(loggers=[logger1, logger2, ...])

Anywhere in your training loop, you can log metrics to all loggers at once:

fabric.log("loss", loss)
fabric.log_dict({"loss": loss, "accuracy": acc})
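Logging to all loggers at once is a fan-out: each metrics dictionary is forwarded to every attached logger. The sketch below models this in plain Python with a hypothetical `InMemoryLogger`; its `log_metrics(metrics, step)` method shape is an assumption chosen to resemble common logger interfaces, and `log`/`log_dict` here are free functions, not Fabric's methods.

```python
class InMemoryLogger:
    """Hypothetical logger that keeps metrics in a list instead of writing files."""

    def __init__(self):
        self.records = []

    def log_metrics(self, metrics, step=None):
        self.records.append((step, dict(metrics)))


def log_dict(loggers, metrics, step=None):
    """Send the same metrics to every attached logger."""
    for logger in loggers:
        logger.log_metrics(metrics, step=step)


def log(loggers, name, value, step=None):
    """Single-metric convenience wrapper around log_dict."""
    log_dict(loggers, {name: value}, step=step)


logger1, logger2 = InMemoryLogger(), InMemoryLogger()
log([logger1, logger2], "loss", 0.5, step=0)
log_dict([logger1, logger2], {"loss": 0.4, "accuracy": 0.9}, step=1)
print(logger1.records == logger2.records)  # both loggers received identical records
```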

See also: Track and Visualize Experiments