Shortcuts

Source code for pytorch_lightning.plugins.training_type.ddp2

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.types import _METRIC_COLLECTION


class DDP2Plugin(DDPPlugin):
    """Training-type plugin that acts like DP inside one node and like DDP between nodes.

    Ranks and world size are counted in nodes here: ``global_rank`` is the
    node rank and ``world_size`` is the number of nodes.
    """

    distributed_backend = _StrategyType.DDP2

    @property
    def global_rank(self) -> int:
        """Global rank, which for DDP2 is simply ``node_rank``."""
        return self.node_rank

    @property
    def world_size(self) -> int:
        """World size, which for DDP2 is simply ``num_nodes``."""
        return self.num_nodes

    def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
        """Replace every tensor in ``collection`` by its mean.

        A single tensor may be passed instead of a collection. In DDP2 this
        reduction only covers the local devices within the node.

        Args:
            collection: The tensor, or collection of tensors, to reduce.
            *args: ignored for DDP2
            **kwargs: ignored for DDP2

        Return:
            The reduced tensor values, or the same value if the input was not
            or did not contain a tensor.
        """

        def _mean_keep_dtype(tensor: torch.Tensor) -> torch.Tensor:
            # Average in floating point, then cast back to the incoming dtype.
            averaged = tensor.float().mean()
            return averaged.to(tensor.dtype)

        return apply_to_collection(collection, torch.Tensor, _mean_keep_dtype)

    @property
    def root_device(self):
        """The first entry of ``parallel_devices``."""
        return self.parallel_devices[0]

    def model_to_device(self):
        """Intentionally a no-op.

        Nothing has to be moved when the model is wrapped in
        ``torch.nn.DataParallel``.
        """

    @property
    def distributed_sampler_kwargs(self):
        """Sampler keyword arguments: ``num_replicas`` is the node count, ``rank`` the global rank."""
        return {"num_replicas": self.num_nodes, "rank": self.global_rank}

    @property
    def _is_single_process_single_device(self) -> bool:
        """Always ``False`` for this plugin."""
        return False

    def set_world_ranks(self) -> None:
        """Pass the node rank and node count to the cluster environment, if one is set."""
        environment = self.cluster_environment
        if environment is not None:
            environment.set_global_rank(self.node_rank)
            environment.set_world_size(self.num_nodes)

© Copyright (c) 2018-2021, William Falcon et al. Revision 46f718d2.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
1.5.4
1.5.3
1.5.2
1.5.1
1.5.0
1.4.9
1.4.8
1.4.7
1.4.6
1.4.5
1.4.4
1.4.3
1.4.2
1.4.1
1.4.0
1.3.8
1.3.7
1.3.6
1.3.5
1.3.4
1.3.3
1.3.2
1.3.1
1.3.0
1.2.10
1.2.8
1.2.7
1.2.6
1.2.5
1.2.4
1.2.3
1.2.2
1.2.1
1.2.0
1.1.8
1.1.7
1.1.6
1.1.5
1.1.4
1.1.3
1.1.2
1.1.1
1.1.0
1.0.8
1.0.7
1.0.6
1.0.5
1.0.4
1.0.3
1.0.2
1.0.1
1.0.0
0.10.0
0.9.0
0.8.5
0.8.4
0.8.3
0.8.2
0.8.1
0.8.0
0.7.6
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.3
0.4.9
ipynb-update
docs-search
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.