Source code for pytorch_lightning.strategies.single_tpu

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict, Optional

import pytorch_lightning as pl
from lightning_lite.accelerators.tpu import _XLA_AVAILABLE
from lightning_lite.plugins import CheckpointIO, XLACheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
from pytorch_lightning.utilities import find_shared_parameters, set_shared_parameters


class SingleTPUStrategy(SingleDeviceStrategy):
    """Strategy for training on a single TPU device."""

    strategy_name = "single_tpu"

    def __init__(
        self,
        device: int,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False,
    ):
        if not _XLA_AVAILABLE:
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        import torch_xla.core.xla_model as xm

        super().__init__(
            accelerator=accelerator,
            device=xm.xla_device(device),
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        self.debug = debug

    @property
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = XLACheckpointIO()
        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
            self._checkpoint_io.checkpoint_io = XLACheckpointIO()

        return self._checkpoint_io

    @checkpoint_io.setter
    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
        self._checkpoint_io = io

    @property
    def is_distributed(self) -> bool:
        return False
    def setup(self, trainer: "pl.Trainer") -> None:
        assert self.model, "self.model must be set before find_shared_parameters(self.model)"
        # Record parameter ties before moving the model to the XLA device (the
        # move can break tying), then restore them afterwards.
        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        set_shared_parameters(self.model, shared_params)
        super().setup(trainer)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = str(1)
    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            cls.strategy_name,
            cls,
            # cls.__name__, not cls.__class__.__name__: the latter names the
            # metaclass rather than this strategy.
            description=f"{cls.__name__}",
        )
    def teardown(self) -> None:
        super().teardown()

        os.environ.pop("PT_XLA_DEBUG", None)
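
A note on the checkpoint plugin: the checkpoint_io property above resolves lazily. If no plugin was passed to the constructor, an XLACheckpointIO is created on first access; if a _WrappingCheckpointIO was passed, its inner plugin is filled in with an XLACheckpointIO. A minimal sketch of the default path (this assumes a TPU runtime, since the constructor raises ModuleNotFoundError when XLA is unavailable):

from lightning_lite.plugins import XLACheckpointIO
from pytorch_lightning.strategies import SingleTPUStrategy

strategy = SingleTPUStrategy(device=0)  # no checkpoint_io argument given
# First access creates the XLA-aware default plugin.
assert isinstance(strategy.checkpoint_io, XLACheckpointIO)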
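
For context on setup(): find_shared_parameters and set_shared_parameters are used because moving a module to the XLA device can break parameter tying, so the strategy records the ties before the move and re-applies them after. A hedged sketch of the kind of module this protects (TiedLM and its sizes are made up for illustration):

import torch.nn as nn

from pytorch_lightning.utilities import find_shared_parameters


class TiedLM(nn.Module):
    # Toy module whose output head shares its weight with the input embedding.
    def __init__(self, vocab_size: int = 100, hidden: int = 16) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size, bias=False)
        self.head.weight = self.embed.weight  # the tied parameter


shared = find_shared_parameters(TiedLM())  # reports which parameters are tied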
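
Finally, a usage sketch. register_strategies publishes the class in the strategy registry under the name "single_tpu"; in typical use the Trainer selects the strategy rather than the user instantiating it directly. MyLightningModule is a placeholder, and a TPU runtime with torch_xla installed is assumed:

import pytorch_lightning as pl

# Requesting a single TPU device resolves to this strategy.
trainer = pl.Trainer(accelerator="tpu", devices=1)
# The registered name can also be passed explicitly:
# trainer = pl.Trainer(accelerator="tpu", devices=1, strategy="single_tpu")
# trainer.fit(MyLightningModule())

Passing debug=True to the constructor sets PT_XLA_DEBUG=1 for the run, enabling torch_xla's debugging output; teardown() removes the variable again.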
