
pytorch_lightning.lite.LightningLite

class pytorch_lightning.lite.LightningLite(accelerator=None, strategy=None, devices=None, num_nodes=1, precision=32, plugins=None, gpus=None, tpu_cores=None)

Bases: abc.ABC

Lite accelerates your PyTorch training or inference code with minimal changes required.

  • Automatic placement of models and data onto the device.

  • Automatic support for mixed precision (smaller memory footprint) and double precision.

  • Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies (data-parallel training, sharded training, etc.).

  • Automated spawning of processes, no launch utilities required.

  • Multi-node support.

Parameters
__init__(accelerator=None, strategy=None, devices=None, num_nodes=1, precision=32, plugins=None, gpus=None, tpu_cores=None)

Initialize self. See help(type(self)) for accurate signature.

Methods

__init__([accelerator, strategy, devices, …])

Initialize self.

all_gather(data[, group, sync_grads])

Gather tensors or collections of tensors from multiple processes.

autocast()

A context manager to automatically convert operations for the chosen precision.

backward(tensor, *args[, model])

Replaces loss.backward() in your training loop.

barrier([name])

Wait for all processes to enter this call.

broadcast(obj[, src])

Broadcast an object from the process with rank src to all other processes.

load(filepath)

Load a checkpoint from a file.

print(*args, **kwargs)

Print something only on the first process.

run()

All the code inside this run method gets accelerated by Lite.

save(content, filepath)

Save checkpoint contents to a file.

seed_everything([seed, workers])

Helper function to seed everything without explicitly importing Lightning.

setup(model, *optimizers[, move_to_device])

Set up a model and its optimizers for accelerated training.

setup_dataloaders(*dataloaders[, …])

Set up one or multiple dataloaders for accelerated training.

to_device(obj)

Move a torch.nn.Module or a collection of tensors to the current device, if it is not already on that device.

Attributes

device

The current device this process runs on.

global_rank

The global index of the current process across all devices and nodes.

is_global_zero

Whether this rank is rank zero.

local_rank

The index of the current process among the processes running on the local node.

node_rank

The index of the current node.

world_size

The total number of processes running across all devices and nodes.

all_gather(data, group=None, sync_grads=False)

Gather tensors or collections of tensors from multiple processes.

Parameters
  • data (Union[Tensor, Dict, List, Tuple]) – int, float, tensor of shape (batch, …), or a (possibly nested) collection thereof.

  • group (Optional[Any]) – the process group to gather results from. Defaults to all processes (world)

  • sync_grads (bool) – If True, gradients are synchronized across processes and flow back through the all_gather operation.

Return type

Union[Tensor, Dict, List, Tuple]

Returns

A tensor of shape (world_size, batch, …), or if the input was a collection the output will also be a collection with tensors of this shape.
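To make the shape contract concrete without launching several processes, the sketch below only simulates the result in one process: it builds one tensor per hypothetical rank and stacks them along a new leading dimension with torch.stack, which gives the (world_size, batch, …) layout described above. It does not call all_gather itself; the world size of 4 and the dictionary keys are assumptions for illustration. In a real run, each process contributes only its own tensor.

```python
import torch

world_size = 4  # assumed number of processes
batch = 3

# what each rank would pass to all_gather: a tensor of shape (batch,)
per_rank = [torch.full((batch,), float(rank)) for rank in range(world_size)]

# the gathered result holds one entry per process along a new dim 0
gathered = torch.stack(per_rank)
print(gathered.shape)  # torch.Size([4, 3])

# for a collection input, every tensor inside it gets the same treatment
per_rank_dicts = [{"preds": t, "count": torch.tensor(rank)} for rank, t in enumerate(per_rank)]
gathered_dict = {key: torch.stack([d[key] for d in per_rank_dicts]) for key in per_rank_dicts[0]}
print({k: tuple(v.shape) for k, v in gathered_dict.items()})  # {'preds': (4, 3), 'count': (4,)}
```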

autocast()

A context manager to automatically convert operations for the chosen precision.

Use this only if the forward method of your model does not cover all operations you wish to run with the chosen precision setting.

Return type

Generator[None, None, None]

backward(tensor, *args, model=None, **kwargs)

Replaces loss.backward() in your training loop. Handles precision-specific details, such as gradient scaling for 16-bit mixed precision, automatically for you.

Parameters
  • tensor (Tensor) – The tensor (loss) to back-propagate gradients from.

  • *args – Optional positional arguments passed to the underlying backward function.

  • model (Optional[_LiteModule]) – Optional model instance for plugins that require the model for backward().

  • **kwargs – Optional named keyword arguments passed to the underlying backward function.

Note

When using strategy="deepspeed" and multiple models were set up, you must pass the model as an argument here.

Return type

None

barrier(name=None)

Wait for all processes to enter this call. Use this to synchronize all parallel processes, but only when necessary; otherwise, the synchronization overhead will slow down your program.

Example:

if self.global_rank == 0:
    # let process 0 download the dataset
    dataset.download_files()

# let all processes wait before reading the dataset
self.barrier()

# now all processes can read the files and start training

Return type

None

load(filepath)

Load a checkpoint from a file.

The strategy determines how, and on which processes, the checkpoint is loaded.

Parameters

filepath (Union[str, Path]) – A path to where the file is located

Return type

Any

print(*args, **kwargs)

Print something only on the first process.

Arguments passed to this method are forwarded to the Python built-in print() function.

Return type

None

abstract run()

All the code inside this run method gets accelerated by Lite.

You can pass arbitrary arguments to this function when overriding it.

Return type

Any

save(content, filepath)

Save checkpoint contents to a file.

The strategy determines how, and on which processes, the checkpoint is saved. For example, the ddp strategy saves checkpoints only on process 0.

Parameters
  • content (Dict[str, Any]) – A dictionary with contents, e.g., the state dict of your model

  • filepath (Union[str, Path]) – A path to where the file should be saved

Return type

None

static seed_everything(seed=None, workers=None)

Helper function to seed everything without explicitly importing Lightning.

See pytorch_lightning.seed_everything() for more details.

Return type

int

setup(model, *optimizers, move_to_device=True)

Set up a model and its optimizers for accelerated training.

Parameters
  • model (Module) – A model to setup

  • *optimizers – The optimizer(s) to set up (passing no optimizers is also possible).

  • move_to_device (bool) – If True (default), moves the model to the correct device. Set this to False and use to_device() manually instead.

Return type

Any

Returns

The wrapped model followed by the wrapped optimizers, in the same order they were passed in. If no optimizers are passed, only the wrapped model is returned.

setup_dataloaders(*dataloaders, replace_sampler=True, move_to_device=True)

Set up one or multiple dataloaders for accelerated training. If you need different settings for each dataloader, call this method individually for each one.

Parameters
  • *dataloaders – A single dataloader or a sequence of dataloaders.

  • replace_sampler (bool) – If True (default), automatically wraps or replaces the sampler on the dataloader(s) for distributed training. If you have a custom sampler defined, set this argument to False.

  • move_to_device (bool) – If True (default), automatically moves the data returned by the dataloader(s) to the correct device. Set this to False and use to_device() manually on the returned data instead.

Return type

Union[Iterable, List[Iterable]]

Returns

The wrapped dataloaders, in the same order they were passed in.

to_device(obj)

Move a torch.nn.Module or a collection of tensors to the current device, if it is not already on that device.

Parameters

obj (Union[Module, Tensor, Any]) – An object to move to the device. Can be an instance of torch.nn.Module, a tensor, or a (nested) collection of tensors (e.g., a dictionary).

Return type

Union[Module, Tensor, Any]

Returns

A reference to the object that was moved to the new device.

property device: torch.device

The current device this process runs on.

Use this to create tensors directly on the device if needed.

Return type

device

property global_rank: int

The global index of the current process across all devices and nodes.

Return type

int

property is_global_zero: bool

Whether this rank is rank zero.

Return type

bool

property local_rank: int

The index of the current process among the processes running on the local node.

Return type

int

property node_rank: int

The index of the current node.

Return type

int

property world_size: int

The total number of processes running across all devices and nodes.

Return type

int
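For the common case of one process per device and the same number of devices on every node, as in data-parallel strategies such as ddp, the rank properties relate as sketched below. RankInfo is a hypothetical helper that only reproduces the arithmetic; inside run() you would read self.node_rank, self.local_rank, self.global_rank, self.world_size, and self.is_global_zero directly.

```python
from dataclasses import dataclass


@dataclass
class RankInfo:  # hypothetical helper, not part of LightningLite
    node_rank: int
    local_rank: int
    devices_per_node: int
    num_nodes: int

    @property
    def global_rank(self) -> int:
        # processes are numbered node by node, then by local index
        return self.node_rank * self.devices_per_node + self.local_rank

    @property
    def world_size(self) -> int:
        # one process per device on every node
        return self.num_nodes * self.devices_per_node

    @property
    def is_global_zero(self) -> bool:
        return self.global_rank == 0


# 2 nodes with 4 devices each: list the ranks of every process
for node in range(2):
    for local in range(4):
        info = RankInfo(node_rank=node, local_rank=local, devices_per_node=4, num_nodes=2)
        print(node, local, info.global_rank, info.world_size, info.is_global_zero)
```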