# Source code for lightning.fabric.utilities.distributed
import contextlib
import logging
import os
import time
from contextlib import nullcontext
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union
import torch
import torch.nn.functional as F
from lightning_utilities.core.imports import package_available
from torch import Tensor
from torch.utils.data import Dataset, DistributedSampler, Sampler
from typing_extensions import override
from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.data import _num_cpus_available
from lightning.fabric.utilities.rank_zero import rank_zero_info
from lightning.fabric.utilities.types import _PATH, ReduceOp
# ``torch.distributed`` may be missing from the installed PyTorch build. In that case a stand-in ``group`` namespace
# is defined whose ``WORLD`` attribute is ``None``, so code importing ``group`` from this module can still reference
# ``group.WORLD``.
if torch.distributed.is_available():
    from torch.distributed import group
else:

    class group:  # type: ignore
        # Placeholder mirroring ``torch.distributed.group.WORLD`` when distributed support is unavailable
        WORLD = None


# These names are only used in string annotations in this module ("Strategy", "ClusterEnvironment"), so they are
# imported for static type checkers only (presumably to avoid runtime import cycles — verify).
if TYPE_CHECKING:
    from lightning.fabric.plugins import ClusterEnvironment
    from lightning.fabric.strategies import Strategy

# Module-level logger
log = logging.getLogger(__name__)
def is_shared_filesystem(strategy: "Strategy", path: Optional[_PATH] = None, timeout: int = 3) -> bool:
    """Checks whether the filesystem under the given path is shared across all processes.

    This function should only be used in a context where distributed is initialized.

    Args:
        strategy: The strategy being used, either from Fabric (``fabric.strategy``) or from Trainer
            (``trainer.strategy``).
        path: The path to check. Defaults to the current working directory. The user must have permissions to write
            to this path or the parent folder, and the filesystem must be writable.
        timeout: If any of the processes can't list the file created by rank 0 within this many seconds, the
            filesystem is determined to be not shared.

    Returns:
        ``True`` when the check is trivially satisfied (a non-local path such as S3, or a strategy without a
        ``world_size`` attribute or with ``world_size == 1``) or when every process saw the marker file created by
        rank 0. ``False`` when the resolved path differs between ranks or when at least one rank did not see the
        marker file within ``timeout`` seconds.

    Raises:
        FileNotFoundError: If the path does not exist on at least one process.
    """
    # Fast path: Any non-local filesystem is considered shared (e.g., S3)
    if path is not None and not _is_local_file_protocol(path):
        return True

    path = Path(Path.cwd() if path is None else path).resolve()

    # Fast path: Only distributed strategies can detect shared filesystems
    if not hasattr(strategy, "world_size") or strategy.world_size == 1:
        return True

    # Fast path: If the path is not the same on all ranks we know it's not a shared filesystem
    rank_zero_path = strategy.broadcast(path)
    if not strategy.reduce_boolean_decision(rank_zero_path == path, all=True):
        return False

    if not strategy.reduce_boolean_decision(path.exists(), all=True):
        raise FileNotFoundError(
            f"Unable to determine if the path belongs to a shared filesystem. The path does not exist: {path}"
        )

    path = path.parent if path.is_file() else path
    check_file = path / ".lightning_shared_fs_check"
    # Remove any stale marker so a leftover file cannot make a non-shared filesystem look shared
    check_file.unlink(missing_ok=True)

    strategy.barrier()
    if strategy.is_global_zero:
        # Rank 0 creates the file
        check_file.touch()
        found = True
    else:
        # All other ranks poll until they find the file or time out. Pause briefly between polls so the ranks do not
        # flood the (possibly network-backed) filesystem with back-to-back metadata lookups.
        poll_interval = 0.01
        start = time.perf_counter()
        found = False
        while not found and (time.perf_counter() - start) < timeout:
            found = check_file.exists()
            if not found:
                time.sleep(poll_interval)
    strategy.barrier()

    all_found = strategy.reduce_boolean_decision(found, all=True)

    with contextlib.suppress(OSError):  # handle race condition on deletion
        check_file.unlink()

    return all_found
def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
    """Function to gather all tensors from several DDP processes onto a list that is broadcasted to all processes.

    Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
    tensors are padded, gathered and then trimmed to secure equal workload for all processes.

    Every process in ``group`` must call this function, since it issues collective operations (``barrier`` and
    ``all_gather``).

    Args:
        result: The value to sync
        group: The process group to gather results from. Defaults to all processes (world)

    Return:
        gathered_result: List with size equal to the process group where
        gathered_result[i] corresponds to result tensor from process i
    """
    if group is None:
        group = torch.distributed.group.WORLD

    # Convert tensors to contiguous format
    result = result.contiguous()

    world_size = torch.distributed.get_world_size(group)

    torch.distributed.barrier(group=group)

    # If the tensor is scalar, things are easy
    if result.ndim == 0:
        return _simple_gather_all_tensors(result, group, world_size)

    # 1. Gather sizes of all tensors
    local_size = torch.tensor(result.shape, device=result.device)
    local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    torch.distributed.all_gather(local_sizes, local_size, group=group)
    all_sizes = torch.stack(local_sizes)  # (world_size, ndim)
    max_size = all_sizes.max(dim=0).values
    all_sizes_equal = bool((all_sizes == max_size).all())

    # 2. If shapes are all the same, then do a simple gather:
    if all_sizes_equal:
        return _simple_gather_all_tensors(result, group, world_size)

    # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate.
    # ``F.pad`` takes (left, right) pairs starting from the LAST dimension, hence ``reversed``. Padding is only added
    # on the right, so each rank's real data stays at the start of every dimension and can be sliced back out.
    pad_dims: List[int] = []
    for pad_amount in reversed((max_size - local_size).tolist()):
        pad_dims.extend((0, pad_amount))
    result_padded = F.pad(result, pad_dims)
    gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result_padded, group)
    # Copy the per-rank shapes to the host once instead of converting a (possibly GPU) 0-d tensor per dimension
    for idx, item_size in enumerate(all_sizes.tolist()):
        # Multi-dimensional indexing takes a tuple of slices
        gathered_result[idx] = gathered_result[idx][tuple(slice(dim_size) for dim_size in item_size)]
    return gathered_result
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
return gathered_result
def _sync_ddp_if_available(
    result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> Tensor:
    """Reduce ``result`` across worker processes, but only when a distributed process group is active.

    Args:
        result: The value to sync and reduce (typically tensor or number)
        group: The process group to gather results from. Defaults to all processes (world)
        reduce_op: The reduction operation, forwarded unchanged to ``_sync_ddp``.
            Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

    Return:
        The reduced value when ``torch.distributed`` is initialized; otherwise ``result`` itself, untouched.
    """
    if not _distributed_is_initialized():
        # Nothing to synchronize with: hand the input straight back
        return result
    return _sync_ddp(result, group=group, reduce_op=reduce_op)
def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
    """Reduces a tensor across several distributed processes.

    This operation is performed in-place, meaning the result will be placed back into the input tensor on all
    processes. The exception is the HPU path below, where a ``LongTensor`` is first cast to a new float tensor.

    Args:
        result: The value to sync and reduce (typically tensor or number)
        group: The process group to gather results from. Defaults to all processes (world)
        reduce_op: The reduction operation. Defaults to sum (``None`` is treated as ``ReduceOp.SUM``).
            Can also be a case-insensitive string naming a ``ReduceOp`` member, or 'avg'/'mean' to calculate
            the mean during reduction.

    Return:
        The reduced value.
    """
    divide_by_world_size = False
    group = torch.distributed.group.WORLD if group is None else group

    op: Optional[ReduceOp]
    if reduce_op is None:
        # Honour the documented default instead of forwarding ``op=None`` to ``all_reduce``
        op = ReduceOp.SUM  # type: ignore[assignment]
    elif isinstance(reduce_op, str):
        # Normalise case first so that e.g. "Mean"/"MEAN" map to the "avg" alias just like "mean" does
        reduce_op = reduce_op.lower()
        reduce_op = "avg" if reduce_op == "mean" else reduce_op
        if reduce_op == "avg" and torch.distributed.get_backend(group) == "gloo":
            # The GLOO backend does not support the `ReduceOp.AVG` operation
            op = ReduceOp.SUM  # type: ignore[assignment]
            divide_by_world_size = True
        else:
            op = getattr(ReduceOp, reduce_op.upper())
    else:
        op = reduce_op

    # HPU doesn't support Long types, forcefully set it to float
    # TODO: move this to the `lightning_habana` package
    if (
        package_available("habana_frameworks")
        and os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
        and result.type()
        in (
            "torch.LongTensor",
            "torch.hpu.LongTensor",
        )
    ):
        rank_zero_info("Long tensor unsupported on HPU, casting to float")
        result = result.float()

    # Sync all processes before reduction
    torch.distributed.barrier(group=group)
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)

    if not divide_by_world_size:
        return result

    world_size = torch.distributed.get_world_size(group)
    # `torch.distributed.all_reduce` is in-place, so we should do the division in-place to leave the modified tensors
    # with the expected value
    if not torch.is_floating_point(result):
        return result.copy_(result / world_size)
    return result.div_(world_size)
def _all_gather_ddp_if_available(
    tensor: Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
) -> Tensor:
    """Gather ``tensor`` from every process of ``group`` and stack the pieces along a new leading dimension.

    Args:
        tensor: Tensor of shape (batch, ...)
        group: The process group to gather results from. Defaults to all processes (world)
        sync_grads: When ``True`` the gather is recorded by autograd (allowing gradients to be synchronized);
            when ``False`` it runs under ``torch.no_grad()``.

    Return:
        A tensor of shape (world_size, batch, ...) when ``torch.distributed`` is initialized. Otherwise the input
        ``tensor`` object is returned unchanged, without a leading world-size dimension.
    """
    if _distributed_is_initialized():
        from torch.distributed.nn.functional import all_gather

        contiguous_tensor = tensor.contiguous()  # https://github.com/pytorch/pytorch/issues/73515
        grad_mode = nullcontext() if sync_grads else torch.no_grad()
        with grad_mode:
            pieces = all_gather(contiguous_tensor, group)
        return torch.stack(pieces)
    return tensor
def _init_dist_connection(
    cluster_environment: "ClusterEnvironment",
    torch_distributed_backend: str,
    global_rank: Optional[int] = None,
    world_size: Optional[int] = None,
    **kwargs: Any,
) -> None:
    """Set the rendezvous environment variables and initialize the default ``torch.distributed`` process group.

    Does nothing (besides a debug log) if a process group is already initialized.

    Args:
        cluster_environment: ``ClusterEnvironment`` instance providing rank, world size, address and port
        torch_distributed_backend: Backend to use (includes `nccl` and `gloo`)
        global_rank: Rank of the current process. Falls back to ``cluster_environment.global_rank()``
        world_size: Number of processes in the group. Falls back to ``cluster_environment.world_size()``
        kwargs: Kwargs for ``init_process_group``

    Raises:
        RuntimeError:
            If ``torch.distributed`` is not available
    """
    if not torch.distributed.is_available():
        raise RuntimeError("torch.distributed is not available. Cannot initialize distributed process group")
    if torch.distributed.is_initialized():
        log.debug("torch.distributed is already initialized. Exiting early")
        return

    rank = cluster_environment.global_rank() if global_rank is None else global_rank
    num_processes = cluster_environment.world_size() if world_size is None else world_size

    # Without an explicit ``init_method`` in ``kwargs``, ``init_process_group`` uses ``env://``, which reads these
    os.environ["MASTER_ADDR"] = cluster_environment.main_address
    os.environ["MASTER_PORT"] = str(cluster_environment.main_port)

    log.info(f"Initializing distributed: GLOBAL_RANK: {rank}, MEMBER: {rank + 1}/{num_processes}")
    torch.distributed.init_process_group(torch_distributed_backend, rank=rank, world_size=num_processes, **kwargs)

    # On rank=0 let everyone know training is starting
    separator = "-" * 100
    rank_zero_info(
        f"{separator}\n"
        f"distributed_backend={torch_distributed_backend}\n"
        f"All distributed processes registered. Starting with {num_processes} processes\n"
        f"{separator}\n"
    )
def _get_default_process_group_backend_for_device(device: torch.device) -> str:
return "nccl" if device.type == "cuda" else "gloo"
class _DatasetSamplerWrapper(Dataset):
    """Dataset to create indexes from `Sampler` or `Iterable`.

    Item ``i`` of this dataset is the ``i``-th index yielded by the wrapped sampler, so a ``DistributedSampler`` can
    shard positions in this dataset and map them back to the sampler's own indices.
    """

    def __init__(self, sampler: Union[Sampler, Iterable]) -> None:
        if not isinstance(sampler, Sized):
            raise TypeError(
                "You seem to have configured a sampler in your DataLoader which"
                " does not provide `__len__` method. The sampler was about to be"
                " replaced by `DistributedSamplerWrapper` since `use_distributed_sampler`"
                " is True and you are using distributed training. Either provide `__len__`"
                " method in your sampler, remove it from DataLoader or set `use_distributed_sampler=False`"
                " if you want to handle distributed sampling yourself."
            )
        # ``len()`` only accepts integer results, so for a sampler whose ``__len__`` returns ``float("inf")`` the
        # builtin would raise a generic TypeError before this comparison could run. Call ``__len__`` directly so the
        # explanatory message below is actually reachable.
        if sampler.__len__() == float("inf"):
            raise TypeError(
                "You seem to have configured a sampler in your DataLoader which"
                " does not provide finite `__len__` method. The sampler was about to be"
                " replaced by `DistributedSamplerWrapper` since `use_distributed_sampler`"
                " is True and you are using distributed training. Either provide `__len__`"
                " method in your sampler which returns a finite number, remove it from DataLoader"
                " or set `use_distributed_sampler=False` if you want to handle distributed sampling yourself."
            )
        self._sampler = sampler
        # defer materializing an iterator until it is necessary
        self._sampler_list: Optional[List[Any]] = None

    @override
    def __getitem__(self, index: int) -> Any:
        # Materialize the sampler's indices on first access; ``reset()`` re-draws them
        if self._sampler_list is None:
            self._sampler_list = list(self._sampler)
        return self._sampler_list[index]

    def __len__(self) -> int:
        return len(self._sampler)

    def reset(self) -> None:
        """Reset the sampler list in order to get new sampling."""
        self._sampler_list = list(self._sampler)
class DistributedSamplerWrapper(DistributedSampler):
    """Wrapper over ``Sampler`` for distributed training.

    Makes any sized sampler usable in distributed mode. Lightning applies it automatically in distributed mode
    when sampler replacement is enabled.

    Note:
        This wrapper only shards the indices produced by the wrapped sampler. Randomness and shuffling remain the
        responsibility of that sampler, so the ``shuffle`` and ``seed`` arguments of this wrapper have no effect.
    """

    def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None:
        # Present the sampler's indices as a dataset so ``DistributedSampler`` can shard positions into it
        indexed_sampler = _DatasetSamplerWrapper(sampler)
        super().__init__(indexed_sampler, *args, **kwargs)

    @override
    def __iter__(self) -> Iterator:
        # Re-draw the wrapped sampler's indices, then translate this rank's shard of positions into those indices
        self.dataset.reset()
        shard_positions = super().__iter__()
        return (self.dataset[position] for position in shard_positions)
def _suggested_max_num_threads(num_processes: int = 1) -> int:
if num_processes < 1:
raise ValueError(f"`num_processes` should be >= 1, got {num_processes}.")
return max(1, _num_cpus_available() // num_processes)
def _set_num_threads_if_needed(num_processes: int = 1) -> None:
if "OMP_NUM_THREADS" not in os.environ:
num_threads = _suggested_max_num_threads(num_processes)
torch.set_num_threads(num_threads)
os.environ["OMP_NUM_THREADS"] = str(num_threads)
def _distributed_is_initialized() -> bool:
# `is_initialized` is only defined conditionally
# https://github.com/pytorch/pytorch/blob/v2.1.0/torch/distributed/__init__.py#L25
# this might happen to MacOS builds from source (default) or any build from source that sets `USE_DISTRIBUTED=0`
return torch.distributed.is_available() and torch.distributed.is_initialized()