
Source code for pytorch_lightning.plugins.training_type.single_tpu

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, Optional

import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _PATH

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm


class SingleTPUPlugin(SingleDevicePlugin):
    """Plugin for training on a single TPU device."""

    def __init__(
        self,
        device: int,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False,
    ):
        device = xm.xla_device(device)
        checkpoint_io = checkpoint_io or XLACheckpointIO()
        super().__init__(device=device, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)

        self.debug = debug

        self.tpu_local_core_rank = 0
        self.tpu_global_core_rank = 0

    @property
    def is_distributed(self) -> bool:
        return False

    def setup(self, trainer: "pl.Trainer") -> None:
        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        if is_overridden("on_post_move_to_device", self.lightning_module):
            self.model.on_post_move_to_device()
        else:
            set_shared_parameters(self.model, shared_params)

        super().setup(trainer)

    def model_to_device(self) -> None:
        self.model.to(self.root_device)

    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
        super().pre_dispatch(trainer)
        if isinstance(self.device, int):
            self.device = xm.xla_device(self.device)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = str(1)

        self.tpu_local_core_rank = xm.get_local_ordinal()
        self.tpu_global_core_rank = xm.get_ordinal()

    def save(self, state_dict: Dict, path: _PATH) -> None:
        xm.save(state_dict, path)

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
        """
        return self.checkpoint_io.save_checkpoint(checkpoint, filepath)

    def teardown(self) -> None:
        # TPU teardown
        os.environ.pop("PT_XLA_DEBUG", None)

    @property
    def checkpoint_io(self) -> CheckpointIO:
        return self._checkpoint_io

    @checkpoint_io.setter
    def checkpoint_io(self, plugin: CheckpointIO) -> None:
        raise MisconfigurationException("TPU Plugin currently does not support custom checkpoint plugins.")
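
A minimal usage sketch follows (not part of the original source). It assumes a TPU runtime with torch_xla installed; TinyModel is a hypothetical LightningModule defined only for illustration. In typical code the plugin is selected automatically by the Trainer (for example via tpu_cores=[1] for one specific core), and the explicit construction shown in the comment is there only to relate the Trainer call to the constructor arguments above; the exact selection logic and argument names depend on the Lightning version.

# Minimal usage sketch: assumes a TPU runtime with torch_xla installed.
# TinyModel is a hypothetical LightningModule defined here for illustration only.
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.plugins import SingleTPUPlugin


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)

    # Usual entry point: the Trainer picks the single-TPU plugin for one specific core.
    trainer = pl.Trainer(tpu_cores=[1], max_epochs=1)

    # Roughly equivalent explicit form (argument name is version-dependent):
    # trainer = pl.Trainer(plugins=[SingleTPUPlugin(device=1, debug=True)], max_epochs=1)

    trainer.fit(TinyModel(), data)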
