
Source code for pytorch_lightning.plugins.training_type.dp

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional

import torch
from torch.nn import DataParallel, Module

import pytorch_lightning as pl
from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, STEP_OUTPUT


class DataParallelPlugin(ParallelPlugin):
    """Implements data-parallel training in a single process, i.e., the model gets replicated to each device and
    each gets a split of the data."""

    distributed_backend = _StrategyType.DP

    def __init__(
        self,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        super().__init__(
            parallel_devices=parallel_devices,
            cluster_environment=None,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )

    @property
    def global_rank(self) -> int:
        return 0

    @property
    def local_rank(self) -> int:
        return 0

    @property
    def node_rank(self) -> int:
        return 0

    @property
    def world_size(self) -> int:
        return 1

    def setup(self, trainer: "pl.Trainer") -> None:
        # model needs to be moved to the device before it is wrapped
        self.model_to_device()
        self._model = self._setup_model(LightningParallelModule(self._model))
        super().setup(trainer)
    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
        """Moves the batch to the correct device.

        The output has the same type as the input.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device
            dataloader_idx: The index of the dataloader to which the batch belongs.
        """
        return move_data_to_device(batch, device=device or self.root_device)
    def _setup_model(self, model: Module) -> DataParallel:
        """Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module."""
        return DataParallel(module=model, device_ids=self.parallel_devices)

    def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
        """Reduces a collection of tensors from all processes. It can be applied to just a single tensor.

        Args:
            collection: The collection of tensors to sync and reduce.
            *args: ignored for DP
            **kwargs: ignored for DP

        Return:
            Reduced tensor values or the same value if it was not or did not contain a tensor.
        """

        def mean(t: torch.Tensor) -> torch.Tensor:
            original_dtype = t.dtype
            return t.float().mean().to(original_dtype)

        return apply_to_collection(collection, torch.Tensor, mean)

    @property
    def root_device(self):
        return self.parallel_devices[0]

    def model_to_device(self) -> None:
        self._model.to(self.root_device)

    def barrier(self, *args, **kwargs):
        pass

    def broadcast(self, obj: object, src: int = 0) -> object:
        return obj

    def reduce_boolean_decision(self, decision: bool) -> bool:
        return decision

    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
        with self.precision_plugin.train_step_context():
            return self.model(*args, **kwargs)

    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        with self.precision_plugin.val_step_context():
            return self.model(*args, **kwargs)

    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        with self.precision_plugin.test_step_context():
            return self.model(*args, **kwargs)

    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
        with self.precision_plugin.predict_step_context():
            return self.model(*args, **kwargs)

    def training_step_end(self, output):
        if not is_overridden("training_step_end", self.lightning_module):
            return self.reduce(output)
        return output

    def validation_step_end(self, output):
        if not is_overridden("validation_step_end", self.lightning_module):
            return self.reduce(output)
        return output

    def test_step_end(self, output):
        if not is_overridden("test_step_end", self.lightning_module):
            return self.reduce(output)
        return output

    def teardown(self) -> None:
        super().teardown()
        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
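
The plugin above is normally not constructed by hand. The following is a minimal usage sketch, assuming the PyTorch Lightning ~1.5 Trainer API and a machine with at least two GPUs; the ToyModel class and the data below are placeholders written only for illustration.

# Minimal usage sketch (assumptions: PyTorch Lightning ~1.5, at least 2 GPUs available).
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # each device sees a split of the batch and returns its own loss
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(
    TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))),
    batch_size=16,
)

# strategy="dp" makes the Trainer construct a DataParallelPlugin internally;
# passing a configured DataParallelPlugin() instance instead should be equivalent.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="dp", max_epochs=1)
trainer.fit(ToyModel(), train_loader)

Because ToyModel does not override ``training_step_end``, the plugin's ``training_step_end`` falls back to ``reduce``, which averages the per-device losses; a LightningModule can override ``training_step_end`` to combine the per-device outputs itself.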
