Source code for pytorch_lightning.plugins.training_type.single_device

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities import _XLA_AVAILABLE


class SingleDevicePlugin(TrainingTypePlugin):
    """Plugin that handles communication on a single device."""

    def __init__(
        self,
        device: torch.device,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
        self.device: torch.device = device
        self.global_rank = 0
        self.local_rank = 0
        self.world_size = 1

    @property
    def on_tpu(self) -> bool:
        return self.root_device.type == "xla" and _XLA_AVAILABLE

    @property
    def on_gpu(self) -> bool:
        return self.root_device.type == "cuda" and torch.cuda.is_available()

    def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
        """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only
        operates with a single device, the reduction is simply the identity.

        Args:
            tensor: the tensor to sync and reduce
            *args: ignored
            **kwargs: ignored

        Return:
            the unmodified input as reduction is not needed for single process operation
        """
        return tensor

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Perform an all_gather on all processes."""
        return tensor

    @property
    def root_device(self) -> torch.device:
        return self.device

    def model_to_device(self) -> None:
        self._model.to(self.root_device)

    def setup(self, trainer: "pl.Trainer") -> None:
        self.model_to_device()
        super().setup(trainer)

    @property
    def is_global_zero(self) -> bool:
        return True

    def barrier(self, *args, **kwargs) -> None:
        pass

    def broadcast(self, obj: object, src: int = 0) -> object:
        return obj

    def teardown(self) -> None:
        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
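For context, a minimal usage sketch (not part of the module source above): it assumes PyTorch Lightning 1.5.x, where a TrainingTypePlugin instance can be passed to the Trainer through the strategy argument; MyModel and train_loader are hypothetical placeholders for a LightningModule and a DataLoader.

# Minimal sketch, assuming PyTorch Lightning 1.5.x.
# MyModel and train_loader are hypothetical placeholders.
import torch
import pytorch_lightning as pl
from pytorch_lightning.plugins import SingleDevicePlugin

# Pin all computation to one device; no inter-process communication is set up,
# so reduce/broadcast/all_gather are identity operations and barrier is a no-op.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trainer = pl.Trainer(strategy=SingleDevicePlugin(device), max_epochs=1)
# trainer.fit(MyModel(), train_loader)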
