
Source code for pytorch_lightning.strategies.single_device

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Any

import torch
from torch import Tensor

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO
from lightning_lite.utilities.types import _DEVICE
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast


class SingleDeviceStrategy(Strategy):
    """Strategy that handles communication on a single device."""

    strategy_name = "single_device"

    def __init__(
        self,
        device: _DEVICE = "cpu",
        accelerator: pl.accelerators.accelerator.Accelerator | None = None,
        checkpoint_io: CheckpointIO | None = None,
        precision_plugin: PrecisionPlugin | None = None,
    ):
        super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
        self._root_device = torch.device(device)
        self.global_rank = 0
        self.local_rank = 0
        self.world_size = 1

    def reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only
        operates with a single device, the reduction is simply the identity.

        Args:
            tensor: the tensor to sync and reduce
            *args: ignored
            **kwargs: ignored

        Return:
            the unmodified input as reduction is not needed for single process operation
        """
        return tensor

    def all_gather(self, tensor: Tensor, group: Any | None = None, sync_grads: bool = False) -> Tensor:
        """Perform an all_gather on all processes."""
        return tensor

    @property
    def root_device(self) -> torch.device:
        return self._root_device

    def model_to_device(self) -> None:
        assert self.model is not None, "self.model must be set before self.model.to()"
        self.model.to(self.root_device)

    def setup(self, trainer: pl.Trainer) -> None:
        self.model_to_device()
        super().setup(trainer)

    @property
    def is_global_zero(self) -> bool:
        return True

    def barrier(self, *args: Any, **kwargs: Any) -> None:
        pass

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        return obj

    @classmethod
    def register_strategies(cls, strategy_registry: dict) -> None:
        strategy_registry.register(
            cls.strategy_name,
            cls,
            description=f"{cls.__class__.__name__}",
        )
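As an illustrative usage sketch (not part of the module source above), a Trainer can be handed an instance of this strategy directly; the device choice and Trainer arguments below are assumptions made for the example, not values required by the class:

import torch
import pytorch_lightning as pl
from pytorch_lightning.strategies import SingleDeviceStrategy

# Pin the whole run to one explicit device by passing a strategy instance.
# (Assumed example: a CPU device; swap in e.g. torch.device("cuda:0") if a
# GPU is available and the accelerator arguments are adjusted to match.)
strategy = SingleDeviceStrategy(device=torch.device("cpu"))
trainer = pl.Trainer(strategy=strategy, accelerator="cpu", devices=1, max_epochs=1)

# Because reduce(), all_gather(), broadcast() and barrier() are identities or
# no-ops on a single device, LightningModule code that calls them runs
# unchanged if it is later switched to a distributed strategy.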
