
Source code for pytorch_lightning.strategies.single_hpu

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, Optional, Union

from torch.nn import Module
from torch.optim.optimizer import Optimizer

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO
from lightning_lite.utilities.types import _DEVICE
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
from pytorch_lightning.utilities import _HPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _HPU_AVAILABLE:
    import habana_frameworks.torch.core as htcore
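
# Note: HPU devices run PyTorch in lazy mode by default: operations are
# accumulated into a graph and only executed when ``htcore.mark_step()`` is
# called. The strategy below therefore marks a step after backward, after the
# optimizer step, and after each validation/test step (see the ``mark_step``
# calls in the methods that follow).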


class SingleHPUStrategy(SingleDeviceStrategy):
    """Strategy for training on single HPU device."""

    strategy_name = "hpu_single"

    def __init__(
        self,
        device: _DEVICE = "hpu",
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        if not _HPU_AVAILABLE:
            raise MisconfigurationException("`SingleHPUStrategy` requires HPU devices to run")

        super().__init__(
            accelerator=accelerator,
            device=device,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )

    @property
    def checkpoint_io(self) -> CheckpointIO:
        if self._checkpoint_io is None:
            self._checkpoint_io = HPUCheckpointIO()
        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
            self._checkpoint_io.checkpoint_io = HPUCheckpointIO()

        return self._checkpoint_io

    @checkpoint_io.setter
    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
        self._checkpoint_io = io

    @property
    def is_distributed(self) -> bool:
        return False

    def setup(self, trainer: "pl.Trainer") -> None:
        self.model_to_device()
        super().setup(trainer)

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        super().setup_optimizers(trainer)

    def model_to_device(self) -> None:
        self.model.to(self.root_device)  # type: ignore

    def on_after_backward(self) -> None:
        # Break lazy accumulation of graph after fwd+bwd
        htcore.mark_step()

    def optimizer_step(
        self,
        optimizer: Optimizer,
        opt_idx: int,
        closure: Callable[[], Any],
        model: Optional[Union["pl.LightningModule", Module]] = None,
        **kwargs: Any,
    ) -> Any:
        optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
        # Break lazy accumulation of graph after optimizer
        htcore.mark_step()
        return optimizer_output

    def validation_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
        # Break lazy accumulation of graph after every step
        htcore.mark_step()
        return step_output

    def test_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
        # Break lazy accumulation of graph after every step
        htcore.mark_step()
        return step_output

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            cls.strategy_name,
            cls,
            description=f"{cls.__class__.__name__}",
        )
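
For context, ``register_strategies`` above registers this class under the name
``"hpu_single"``, so the strategy can be selected by name, by instance, or
resolved automatically from the accelerator/devices pair. A minimal usage
sketch, assuming a Habana-enabled environment and a hypothetical
``LightningModule`` subclass ``MyModel`` (not defined in this file):

    import pytorch_lightning as pl
    from pytorch_lightning.strategies import SingleHPUStrategy

    # Let the Trainer resolve the strategy from accelerator/devices...
    trainer = pl.Trainer(accelerator="hpu", devices=1)
    # ...or select it explicitly, by registered name or by instance:
    trainer = pl.Trainer(strategy="hpu_single", accelerator="hpu", devices=1)
    trainer = pl.Trainer(strategy=SingleHPUStrategy())

    trainer.fit(MyModel())  # MyModel is a placeholder LightningModule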

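The ``checkpoint_io`` property resolves lazily: if no plugin was passed, it
falls back to ``HPUCheckpointIO``; if a wrapping plugin was passed (such as
``AsyncCheckpointIO``), the HPU plugin is installed as its inner
implementation. A small illustrative sketch, assuming HPU hardware is
available (the constructor raises otherwise):

    from pytorch_lightning.plugins.io import AsyncCheckpointIO
    from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
    from pytorch_lightning.strategies import SingleHPUStrategy

    # No plugin given: the HPU checkpoint plugin is created on first access.
    strategy = SingleHPUStrategy()
    assert isinstance(strategy.checkpoint_io, HPUCheckpointIO)

    # Wrapping plugin given: the HPU plugin becomes its inner checkpoint_io.
    wrapped = SingleHPUStrategy(checkpoint_io=AsyncCheckpointIO())
    assert isinstance(wrapped.checkpoint_io.checkpoint_io, HPUCheckpointIO)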