Shortcuts

Source code for pytorch_lightning.plugins.training_type.parallel

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, List, Optional

import torch
from torch.nn.parallel import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities import _XLA_AVAILABLE
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp


class ParallelPlugin(TrainingTypePlugin, ABC):
    """Plugin for training with multiple processes in parallel.

    Args:
        parallel_devices: Devices this plugin trains on. Stored as-is; may be ``None``.
        cluster_environment: Source of rank and world-size information. When ``None``,
            the rank properties fall back to single-process defaults.
        checkpoint_io: Forwarded unchanged to ``TrainingTypePlugin.__init__``.
        precision_plugin: Forwarded unchanged to ``TrainingTypePlugin.__init__``.
    """

    def __init__(
        self,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
        self.parallel_devices = parallel_devices
        self.cluster_environment = cluster_environment

    @property
    @abstractmethod
    def root_device(self) -> torch.device:
        """Return the root device."""

    @property
    def on_gpu(self) -> bool:
        """Whether the root device is of type ``"cuda"`` and CUDA is available."""
        return self.root_device.type == "cuda" and torch.cuda.is_available()

    @property
    def on_tpu(self) -> bool:
        """Whether the root device is of type ``"xla"`` and XLA is available."""
        return self.root_device.type == "xla" and _XLA_AVAILABLE

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        """Return the unwrapped ``LightningModule``, or ``None`` if no model is set."""
        return unwrap_lightning_module(self._model) if self._model is not None else None

    @property
    def global_rank(self) -> int:
        """Global rank from the cluster environment, or ``0`` without one."""
        return self.cluster_environment.global_rank() if self.cluster_environment is not None else 0

    @property
    def local_rank(self) -> int:
        """Local rank from the cluster environment, or ``0`` without one."""
        return self.cluster_environment.local_rank() if self.cluster_environment is not None else 0

    @property
    def node_rank(self) -> int:
        """Node rank from the cluster environment, or ``0`` without one."""
        return self.cluster_environment.node_rank() if self.cluster_environment is not None else 0

    @property
    def world_size(self) -> int:
        """World size from the cluster environment, or ``1`` without one."""
        return self.cluster_environment.world_size() if self.cluster_environment is not None else 1

    @property
    def is_global_zero(self) -> bool:
        """Whether this process has global rank ``0``."""
        return self.global_rank == 0

    @property
    def distributed_sampler_kwargs(self):
        """Return ``num_replicas`` and ``rank`` keyword arguments for a distributed sampler.

        ``num_replicas`` is the number of ``parallel_devices``; ``parallel_devices`` must
        therefore not be ``None`` when this property is read.
        """
        return dict(num_replicas=len(self.parallel_devices), rank=self.global_rank)

    def reconciliate_processes(self, trace: str):
        """Function to re-conciliate processes on failure."""

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Perform an all_gather on all processes.

        Args:
            tensor: The tensor to gather.
            group: Forwarded to ``all_gather_ddp_if_available`` as ``group``.
            sync_grads: Forwarded to ``all_gather_ddp_if_available`` as ``sync_grads``.

        Return:
            The result of ``all_gather_ddp_if_available``.
        """
        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

    def reduce_boolean_decision(self, decision: bool) -> bool:
        """Combine a per-process boolean into a single decision.

        The decision is encoded as a 0/1 tensor on the lightning module's device and
        reduced with ``ReduceOp.SUM`` through ``self.reduce``.

        Args:
            decision: This process's vote.

        Return:
            ``True`` only if the reduced sum equals ``world_size``.
        """
        decision = torch.tensor(int(decision), device=self.lightning_module.device)
        decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
        # The sum reaches world_size only when every contributing vote was 1.
        decision = bool(decision == self.world_size)
        return decision

    @property
    def torch_distributed_backend(self):
        """Name of the torch distributed backend to use.

        The ``PL_TORCH_DISTRIBUTED_BACKEND`` environment variable takes precedence;
        otherwise ``"nccl"`` is chosen on GPU and ``"gloo"`` elsewhere.
        """
        torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
        if torch_backend is None:
            torch_backend = "nccl" if self.on_gpu else "gloo"
        return torch_backend

    @staticmethod
    def configure_sync_batchnorm(model: "pl.LightningModule") -> "pl.LightningModule":
        """Add global batchnorm for a model spread across multiple GPUs and nodes.

        Override to synchronize batchnorm between specific process groups instead
        of the whole world or use a different sync_bn like `apex`'s version.

        Args:
            model: pointer to current :class:`LightningModule`.

        Return:
            LightningModule with batchnorm layers synchronized between process groups
        """
        return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    @contextmanager
    def block_backward_sync(self):
        """Blocks ddp sync gradients behaviour on backwards pass.

        This is useful for skipping sync when accumulating gradients, reducing
        communication overhead.

        Returns:
            context manager with sync behaviour off
        """
        if isinstance(self.model, DistributedDataParallel):
            with self.model.no_sync():
                yield None
        else:
            # Not wrapped in DistributedDataParallel: nothing to disable.
            yield None

    def teardown(self) -> None:
        """Tear down the cluster environment (if any), then the parent plugin."""
        # cluster_environment defaults to None in __init__; guard like the rank
        # properties do so the parent teardown still runs without one.
        if self.cluster_environment is not None:
            self.cluster_environment.teardown()
        super().teardown()

© Copyright (c) 2018-2021, William Falcon et al. Revision 46f718d2.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
1.5.4
1.5.3
1.5.2
1.5.1
1.5.0
1.4.9
1.4.8
1.4.7
1.4.6
1.4.5
1.4.4
1.4.3
1.4.2
1.4.1
1.4.0
1.3.8
1.3.7
1.3.6
1.3.5
1.3.4
1.3.3
1.3.2
1.3.1
1.3.0
1.2.10
1.2.8
1.2.7
1.2.6
1.2.5
1.2.4
1.2.3
1.2.2
1.2.1
1.2.0
1.1.8
1.1.7
1.1.6
1.1.5
1.1.4
1.1.3
1.1.2
1.1.1
1.1.0
1.0.8
1.0.7
1.0.6
1.0.5
1.0.4
1.0.3
1.0.2
1.0.1
1.0.0
0.10.0
0.9.0
0.8.5
0.8.4
0.8.3
0.8.2
0.8.1
0.8.0
0.7.6
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.3
0.4.9
ipynb-update
docs-search
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.