Shortcuts

Source code for pytorch_lightning.strategies.dp

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import DataParallel, Module

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO
from lightning_lite.utilities.distributed import ReduceOp
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast, TReduce
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import STEP_OUTPUT


class DataParallelStrategy(ParallelStrategy):
    """Implements data-parallel training in a single process, i.e., the model gets replicated to each device and
    each gets a split of the data.

    The module is wrapped in :class:`~torch.nn.parallel.DataParallel`. Because everything runs in one process,
    the rank properties are constant and the collective operations (``barrier``, ``broadcast``,
    ``reduce_boolean_decision``) are no-ops.
    """

    strategy_name = "dp"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ) -> None:
        """Creates the strategy.

        Args:
            accelerator: Forwarded unchanged to the parent strategy.
            parallel_devices: The devices to run on; the first entry is used as the root device.
            checkpoint_io: Forwarded unchanged to the parent strategy.
            precision_plugin: Forwarded unchanged to the parent strategy.
        """
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            # this strategy never uses a cluster environment
            cluster_environment=None,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )

    @property
    def global_rank(self) -> int:
        """Always ``0``."""
        return 0

    @property
    def local_rank(self) -> int:
        """Always ``0``."""
        return 0

    @property
    def node_rank(self) -> int:
        """Always ``0``."""
        return 0

    @property
    def world_size(self) -> int:
        """Always ``1``."""
        return 1

    def setup(self, trainer: "pl.Trainer") -> None:
        """Moves the model to the root device, wraps it for data-parallel execution and runs the parent setup.

        Args:
            trainer: The trainer this strategy is being set up for.
        """
        # model needs to be moved to the device before it is wrapped
        self.model_to_device()
        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
        self.model = self._setup_model(LightningParallelModule(self.model))
        super().setup(trainer)

    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
        """Moves the batch to the correct device.

        The input and the output is the same type. For this strategy the batch is returned untouched.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device (unused here)
            dataloader_idx: The index of the dataloader to which the batch belongs (unused here).

        Return:
            The ``batch`` object that was passed in.
        """
        # DataParallel handles the transfer of batch to the device
        return batch

    def _setup_model(self, model: Module) -> DataParallel:
        """Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module."""
        return DataParallel(module=model, device_ids=self.parallel_devices)

    def reduce(
        self, collection: TReduce, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
    ) -> TReduce:
        """Reduces a collection of tensors from all processes. It can be applied to just a single tensor.

        Every tensor found in ``collection`` is replaced by its mean. The mean is computed after converting
        with ``.float()`` and the result is cast back to the tensor's original dtype.

        Args:
            collection: The collection of tensors to sync and reduce.
            group: ignored for DP
            reduce_op: ignored for DP

        Return:
            Reduced tensor values or the same value if it was not or did not contain a tensor.
        """

        def mean(t: Tensor) -> Tensor:
            original_dtype = t.dtype
            return t.float().mean().to(original_dtype)

        return apply_to_collection(collection, Tensor, mean)

    @property
    def root_device(self) -> torch.device:
        """The first entry of ``parallel_devices``."""
        assert self.parallel_devices is not None
        return self.parallel_devices[0]

    def model_to_device(self) -> None:
        """Moves the model to :attr:`root_device`."""
        assert self.model is not None
        self.model.to(self.root_device)

    def barrier(self, *args: Any, **kwargs: Any) -> None:
        """Does nothing; all arguments are ignored."""
        pass

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        """Returns ``obj`` unchanged; ``src`` is ignored."""
        return obj

    def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
        """Returns ``decision`` unchanged; ``all`` is ignored."""
        return decision

    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Calls the wrapped model with the given arguments inside the precision plugin's train-step context."""
        with self.precision_plugin.train_step_context():
            assert self.model is not None
            return self.model(*args, **kwargs)

    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """Calls the wrapped model with the given arguments inside the precision plugin's val-step context."""
        with self.precision_plugin.val_step_context():
            assert self.model is not None
            return self.model(*args, **kwargs)

    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """Calls the wrapped model with the given arguments inside the precision plugin's test-step context."""
        with self.precision_plugin.test_step_context():
            assert self.model is not None
            return self.model(*args, **kwargs)

    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Calls the wrapped model with the given arguments inside the precision plugin's predict-step context."""
        with self.precision_plugin.predict_step_context():
            assert self.model is not None
            return self.model(*args, **kwargs)

    def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        """Reduces the training-step output unless the user handles it themselves.

        Args:
            output: The value returned from the training step.

        Return:
            ``output`` untouched if the LightningModule overrides ``training_step_end``. Otherwise, a dict has
            its ``"loss"`` entry replaced in place by the reduced value, and a bare tensor is replaced by its
            reduced value. Any other output is returned as is.
        """
        # a user-defined `training_step_end` is responsible for its own reduction
        if is_overridden("training_step_end", self.lightning_module):
            return output

        if isinstance(output, dict) and "loss" in output:
            output["loss"] = self.reduce(output["loss"])
        elif isinstance(output, Tensor):
            output = self.reduce(output)

        return output

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        """Registers this strategy under ``cls.strategy_name``, using the class name as its description.

        Args:
            strategy_registry: The registry to add this strategy to.
        """
        strategy_registry.register(
            cls.strategy_name,
            cls,
            # `cls` is already the class here; `cls.__class__.__name__` would be the name of its metaclass
            description=cls.__name__,
        )

© Copyright (c) 2018-2022, Lightning AI et al. Revision f4fcad36.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
1.8.3post1
1.8.3.post0
1.8.3
1.8.2
1.8.1
1.8.0.post1
1.8.0
1.7.7
1.7.6
1.7.5
1.7.4
1.7.3
1.7.2
1.7.1
1.7.0
1.6.5
1.6.4
1.6.3
1.6.2
1.6.1
1.6.0
1.5.10
1.5.9
1.5.8
1.5.7
1.5.6
1.5.5
1.5.4
1.5.3
1.5.2
1.5.1
1.5.0
1.4.9
1.4.8
1.4.7
1.4.6
1.4.5
1.4.4
1.4.3
1.4.2
1.4.1
1.4.0
1.3.8
1.3.7
1.3.6
1.3.5
1.3.4
1.3.3
1.3.2
1.3.1
1.3.0
1.2.10
1.2.8
1.2.7
1.2.6
1.2.5
1.2.4
1.2.3
1.2.2
1.2.1
1.2.0
1.1.8
1.1.7
1.1.6
1.1.5
1.1.4
1.1.3
1.1.2
1.1.1
1.1.0
1.0.8
1.0.7
1.0.6
1.0.5
1.0.4
1.0.3
1.0.2
1.0.1
1.0.0
0.10.0
0.9.0
0.8.5
0.8.4
0.8.3
0.8.2
0.8.1
0.8.0
0.7.6
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.3
0.4.9
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.