Source code for pytorch_lightning.accelerators.tpu

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union

import torch

from lightning_lite.accelerators.tpu import _parse_tpu_cores, _XLA_AVAILABLE
from lightning_lite.accelerators.tpu import TPUAccelerator as LiteTPUAccelerator
from lightning_lite.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator


class TPUAccelerator(Accelerator):
    """Accelerator for TPU devices."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        if not _XLA_AVAILABLE:
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        super().__init__(*args, **kwargs)

    def setup_device(self, device: torch.device) -> None:
        pass

    def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
        """Gets stats for the given TPU device.

        Args:
            device: TPU device for which to get stats

        Returns:
            A dictionary mapping the metrics (free memory and peak memory) to their values.
        """
        import torch_xla.core.xla_model as xm

        memory_info = xm.get_memory_info(device)
        free_memory = memory_info["kb_free"]
        peak_memory = memory_info["kb_total"] - free_memory
        device_stats = {
            "avg. free memory (MB)": free_memory,
            "avg. peak memory (MB)": peak_memory,
        }
        return device_stats

    def teardown(self) -> None:
        pass

    @staticmethod
    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]:
        """Accelerator device parsing logic."""
        return _parse_tpu_cores(devices)

    @staticmethod
    def get_parallel_devices(devices: Union[int, List[int]]) -> List[int]:
        """Gets parallel devices for the Accelerator."""
        if isinstance(devices, int):
            return list(range(devices))
        return devices

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        return 8

    @staticmethod
    def is_available() -> bool:
        return LiteTPUAccelerator.is_available()

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        accelerator_registry.register(
            "tpu",
            cls,
            description=cls.__class__.__name__,
        )

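The static helpers can also be exercised on their own, without TPU hardware, since they only normalize the devices argument. A small sketch of the expected behaviour; the results shown in comments are assumptions based on the logic above, not verified output:

from pytorch_lightning.accelerators.tpu import TPUAccelerator

# Device parsing: an int/str core count, or a single-element list selecting one core index.
cores = TPUAccelerator.parse_devices(8)        # assumed -> 8
single = TPUAccelerator.parse_devices([5])     # assumed -> [5]

# Parallel-device expansion: a core count becomes a range of core indices.
print(TPUAccelerator.get_parallel_devices(cores))   # assumed -> [0, 1, 2, 3, 4, 5, 6, 7]
print(TPUAccelerator.get_parallel_devices(single))  # assumed -> [5]

# Default device count used when devices="auto" resolves to this accelerator.
print(TPUAccelerator.auto_device_count())            # 8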