Shortcuts

Source code for pytorch_lightning.strategies.tpu_spawn

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.environments import XLAEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.strategies.launchers.xla import _XLALauncher
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.utilities.types import _PATH, EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS

if _TPU_AVAILABLE:
    import torch_xla.core.xla_env_vars as xenv
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    from torch_xla.core.xla_model import rendezvous
    from torch_xla.distributed.parallel_loader import MpDeviceLoader
else:
    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4


class TPUSpawnStrategy(DDPSpawnStrategy):
    """Strategy for training multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn`
    method."""

    strategy_name = "tpu_spawn"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False,
        **_: Any,
    ) -> None:
        """
        Args:
            accelerator: accelerator forwarded to :class:`DDPSpawnStrategy`.
            parallel_devices: devices forwarded to :class:`DDPSpawnStrategy`.
            checkpoint_io: checkpoint plugin; when ``None``, :attr:`checkpoint_io` lazily falls back to
                :class:`XLACheckpointIO`.
            precision_plugin: precision plugin forwarded to :class:`DDPSpawnStrategy`.
            debug: when ``True``, :meth:`setup` sets the ``PT_XLA_DEBUG=1`` environment variable.
            **_: extra keyword arguments are accepted and ignored.
        """
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=XLAEnvironment(),
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
            start_method="fork",
        )
        self._checkpoint_io: Optional[CheckpointIO]
        self.debug = debug
        # Set to True in `_worker_setup`; `root_device` refuses to resolve the XLA device before that.
        self._launched = False

    @property
    def checkpoint_io(self) -> CheckpointIO:
        """The checkpoint plugin, defaulting to :class:`XLACheckpointIO`.

        If the configured plugin is a ``_WrappingCheckpointIO``, its inner ``checkpoint_io`` is set to a new
        :class:`XLACheckpointIO` on every access.
        """
        if self._checkpoint_io is None:
            self._checkpoint_io = XLACheckpointIO()
        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
            self._checkpoint_io.checkpoint_io = XLACheckpointIO()

        return self._checkpoint_io

    @checkpoint_io.setter
    def checkpoint_io(self, plugin: Optional[CheckpointIO]) -> None:
        self._checkpoint_io = plugin

    @property
    def root_device(self) -> torch.device:
        """The XLA device of the current process.

        Raises:
            RuntimeError: if accessed before the worker processes have been launched.
        """
        if not self._launched:
            raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
        return xm.xla_device()

    @staticmethod
    def _validate_dataloader(dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
        """Raise if any dataloader in the (possibly nested) collection does not define a length.

        Raises:
            MisconfigurationException: if a dataloader has no ``__len__``.
        """

        def check_has_len(dataloader: DataLoader) -> None:
            if not has_len(dataloader):
                raise MisconfigurationException(
                    "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
                    " HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
                )

        # Apply the check to every leaf that is not itself a Sequence/Mapping container.
        apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len)

    @staticmethod
    def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
        """Validate and fail fast if the dataloaders were passed directly to fit."""
        connector: DataConnector = model.trainer._data_connector
        sources = (
            connector._train_dataloader_source,
            connector._val_dataloader_source,
            connector._test_dataloader_source,
            connector._predict_dataloader_source,
        )
        for source in sources:
            if not source.is_module():
                assert source.instance is not None
                assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule))
                TPUSpawnStrategy._validate_dataloader(source.instance)

    def connect(self, model: "pl.LightningModule") -> None:  # type: ignore
        """Validate directly-passed dataloaders, wrap the model for XLA multiprocessing and connect it."""
        TPUSpawnStrategy._validate_patched_dataloaders(model)
        self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
        return super().connect(model)

    def _configure_launcher(self) -> None:
        self._launcher = _XLALauncher(self)

    def setup(self, trainer: "pl.Trainer") -> None:
        """Set up the accelerator, move the model to the XLA device and prepare precision/optimizers.

        Shared parameters are recorded before the move and re-tied afterwards, because moving the wrapped model
        to the device can break the parameter sharing.
        """
        assert self.accelerator
        self.accelerator.setup(trainer)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = "1"

        assert self.model
        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        assert isinstance(self.model.module, Module)
        set_shared_parameters(self.model.module, shared_params)
        self.setup_precision_plugin()

        if trainer.state.fn == TrainerFn.FITTING:
            self.setup_optimizers(trainer)
            optimizers_to_device(self.optimizers, self.root_device)

    def _setup_model(self, model: Module) -> Module:  # type: ignore
        # No DDP wrapping on TPU: the model is returned unchanged.
        return model

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, int]:
        return dict(num_replicas=self.world_size, rank=self.global_rank)

    @property
    def is_distributed(self) -> bool:
        # HOST_WORLD_SIZE is not set outside the xmp.spawn process
        return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1

    def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
        """Validate ``dataloader`` and wrap it in an :class:`MpDeviceLoader` bound to :attr:`root_device`."""
        TPUSpawnStrategy._validate_dataloader(dataloader)
        dataloader = MpDeviceLoader(dataloader, self.root_device)
        # Mimic interface to torch.utils.data.DataLoader
        dataloader.dataset = dataloader._loader.dataset
        return dataloader

    def configure_ddp(self) -> None:
        pass

    def model_to_device(self) -> None:
        self.model = self.wrapped_model.to(self.root_device)

    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
        """Block until all processes reach this point. This is a no-op when not running distributed.

        Args:
            name: rendezvous tag; ``None`` is mapped to the empty string.
        """
        if self.is_distributed:
            # `rendezvous` expects a string tag; forwarding `None` raises a `TypeError` in torch_xla.
            rendezvous("" if name is None else name)

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        """Send ``obj`` from the process with global rank ``src`` to every process.

        The object is serialized with :func:`torch.save` and carried through ``xm.all_gather`` as a float
        tensor (every byte value 0-255 is exactly representable as a float). Only the bytes contributed by
        ``src`` are deserialized with :func:`torch.load`.

        Args:
            obj: object to broadcast; it must be serializable with :func:`torch.save`.
            src: global rank whose ``obj`` is returned on every process.

        Returns:
            ``obj`` unchanged when not running distributed, otherwise the object sent by rank ``src``.
        """
        if not self.is_distributed:
            return obj
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        payload = bytearray(buffer.getbuffer())

        # `xm.all_gather` needs identically shaped inputs on every rank, but serialized payloads can differ in
        # length. Exchange the lengths first, then pad every payload to the longest one.
        local_size = torch.tensor([len(payload)], device=self.root_device, dtype=torch.int32)
        sizes = xm.all_gather(local_size).cpu()
        max_size = int(sizes.max())
        src_size = int(sizes[src])
        payload.extend(bytes(max_size - len(payload)))

        data_tensor = torch.tensor(payload, device=self.root_device, dtype=torch.float)
        # The gathered tensor is every rank's padded payload concatenated along dim 0.
        # NOTE(review): this assumes gather order == global rank (replica ordinal) — confirm with XLAEnvironment.
        gathered = xm.all_gather(data_tensor).cpu()
        start = src * max_size
        src_data = gathered[start : start + src_size]
        buffer = io.BytesIO(src_data.byte().numpy())
        return torch.load(buffer)

    def reduce(
        self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
    ) -> Tensor:
        """Sum ``output`` across all processes, optionally dividing by the world size.

        Args:
            output: tensor (or value convertible with :func:`torch.tensor`) to reduce.
            group: unused.
            reduce_op: ``None`` or ``ReduceOp.SUM`` / ``"sum"`` sums. ``"mean"`` / ``"avg"`` sum and then divide
                by :attr:`world_size`.

        Raises:
            ValueError: for any other reduce operation.
        """
        if not isinstance(output, Tensor):
            output = torch.tensor(output, device=self.root_device)

        invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
        invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
        if invalid_reduce_op or invalid_reduce_op_str:
            raise ValueError(
                "Currently, the TPUSpawnStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
                f" {reduce_op}"
            )

        output = xm.mesh_reduce("reduce", output, sum)

        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
            output = output / self.world_size

        return output

    def _worker_setup(self, process_idx: int) -> None:
        # Runs inside the spawned process: from now on `root_device` may be resolved.
        self._launched = True
        self.set_world_ranks(process_idx)
        rank_zero_only.rank = self.global_rank

    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        assert self.model is not None
        with self.precision_plugin.val_step_context():
            return self.model(*args, **kwargs)

    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        assert self.model is not None
        with self.precision_plugin.test_step_context():
            return self.model(*args, **kwargs)

    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        assert self.model is not None
        with self.precision_plugin.predict_step_context():
            return self.model(*args, **kwargs)

    def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def _pod_progress_bar_force_stdout(self) -> None:
        # Why is it required? The way `pytorch_xla.distributed` streams logs
        # from different vms to the main worker doesn't work well with tqdm
        # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
        # The print statement seems to force tqdm to flush stdout.
        if self.global_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
            print()

    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
        """
        # `xla_model.save` needs to be called on all ranks. It internally checks if the local rank is 0
        self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)

    def remove_checkpoint(self, filepath: _PATH) -> None:
        """Remove checkpoint filepath from the filesystem. Only the process with local rank 0 deletes the file.

        Args:
            filepath: Path to checkpoint
        """
        if self.local_rank == 0:
            self.checkpoint_io.remove_checkpoint(filepath)

    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        """Gather a tensor from all distributed processes.

        Args:
            tensor: tensor of shape (batch, ...); a 0-dim tensor is first unsqueezed to shape (1,)
            group: not available with TPUs
            sync_grads: not available with TPUs

        Return:
            The per-process tensors concatenated along dim 0, i.e. a tensor of shape (world_size * batch, ...)
            (shape (world_size,) for a 0-dim input).
        """
        if isinstance(tensor, Tensor) and tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)
        return xm.all_gather(tensor)

    def teardown(self) -> None:
        super().teardown()
        os.environ.pop("PT_XLA_DEBUG", None)

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        """Register ``"tpu_spawn_debug"`` (with ``debug=True``) and :attr:`strategy_name` in ``strategy_registry``."""
        strategy_registry.register(
            "tpu_spawn_debug", cls, description="TPUSpawn Strategy with `debug` as True", debug=True
        )

        strategy_registry.register(
            cls.strategy_name,
            cls,
            # `cls` is already the class; `cls.__class__.__name__` would be the metaclass name ("type").
            description=f"{cls.__name__}",
        )

© Copyright (c) 2018-2022, Lightning AI et al. Revision dbb5ca8d.

Built with Sphinx using a theme provided by Read the Docs.