
# Source code for pytorch_lightning.accelerators.tpu_accelerator

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import os
import re
from typing import Any, Callable, Optional, Union

import torch
import torch.multiprocessing as mp
from torch.optim import Optimizer

from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import (
    TPU_AVAILABLE,
    move_data_to_device,
    rank_zero_info,
    rank_zero_only,
    rank_zero_warn,
)
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException

log = logging.getLogger(__name__)

# torch_xla is an optional dependency: only import it when XLA support was
# detected, so this module can still be imported on machines without TPUs.
if TPU_AVAILABLE:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as xla_pl
    import torch_xla.distributed.xla_multiprocessing as xmp

class TPUAccelerator(Accelerator):
    """Accelerator that drives training on TPU cores through PyTorch XLA."""

    def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
        """
        Runs training using TPUs (colab, single machine or pod)

        Example::

            # default
            trainer = Trainer(accelerator=TPUAccelerator())

        """
        super().__init__(trainer, cluster_environment)
        self.start_method = None
        self.mp_queue = None
        self.nickname = None

    def setup(self, model):
        """
        Prepare the multiprocessing context and attach ``model`` to the trainer.

        Raises:
            MisconfigurationException: if XLA is not available.
        """
        rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

        # TODO: Move this check to Trainer __init__ or device parser
        if not TPU_AVAILABLE:
            raise MisconfigurationException('PyTorch XLA not installed.')

        # start method used both for the queue context and for xmp.spawn
        self.start_method = 'fork'

        # pass in a state q
        smp = mp.get_context(self.start_method)
        self.mp_queue = smp.SimpleQueue()

        self.trainer.model = model

    def teardown(self):
        """
        Collect the state pushed on ``mp_queue`` by the training process and
        restore it in the main process.

        Return:
            the ``results`` object that the training process put on the queue.
        """
        model = self.trainer.model

        # restore main state with best weights
        # (order must match the puts in transfer_distrib_spawn_state_on_fit_end)
        best_path = self.mp_queue.get()
        results = self.mp_queue.get()
        last_path = self.mp_queue.get()

        # transfer back the best path to the trainer
        if self.trainer.checkpoint_callback is not None:
            self.trainer.checkpoint_callback.best_model_path = best_path
        # todo, pass also best score

        # load last weights
        if last_path and not self.trainer.testing:
            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
            model.load_state_dict(ckpt)

        self.trainer.model = model

        # when training completes, load the weights back in main process
        self.__load_weights_on_main_process()
        return results

    def train(self):
        """Run training in-process for a fixed TPU core, otherwise spawn one process per core."""
        model = self.trainer.model

        # train
        if self.trainer.tpu_id is not None:
            self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue)
        else:
            xmp.spawn(
                self.tpu_train_in_process,
                args=(model, self.trainer, self.mp_queue),
                nprocs=self.trainer.tpu_cores,
                start_method=self.start_method,
            )

    def __load_weights_on_main_process(self):
        model = self.trainer.model

        # load weights if not interrupted
        if self.trainer.on_colab_kaggle and not self.trainer.testing:
            self.load_spawn_weights(model)

        self.trainer.model = model

    def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, trainer=None, mp_queue=None):
        """
        Here we are inside each individual process

        Args:
            tpu_core_idx: index passed by the caller/spawner; not read in this method.
            model: the module to train.
            trainer: trainer to use; falls back to ``self.trainer`` when falsy.
            mp_queue: queue used to hand state back to the main process.
        """
        if not trainer:
            trainer = self.trainer

        trainer.call_setup_hook(model)

        # setup TPU training
        self.__setup_tpu_training(model, trainer)

        self.trainer.setup_trainer(model)

        # train or test
        results = self.train_or_test()

        # save weights at the end of training
        self.__save_end_of_training_weights(model, trainer)

        # persist info in spawn
        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

    def _step(self, model_step: Callable, args):
        """Move the batch (``args[0]``) to the TPU in place, then call ``model_step(*args)``."""
        args[0] = self.to_device(args[0])
        return model_step(*args)

    def training_step(self, args):
        return self._step(self.trainer.model.training_step, args)

    def validation_step(self, args):
        return self._step(self.trainer.model.validation_step, args)

    def test_step(self, args):
        return self._step(self.trainer.model.test_step, args)

    def process_dataloader(self, dataloader):
        """Wrap ``dataloader`` in an XLA ``ParallelLoader`` and return its per-device loader."""
        device = xm.xla_device(self.trainer.tpu_id)
        dataloader = xla_pl.ParallelLoader(dataloader, [device])
        dataloader = dataloader.per_device_loader(device)

        return dataloader

    def to_device(self, batch):
        """
        Transfers the data to the TPU.

        The target device is ``xm.xla_device(self.trainer.tpu_id)``.

        Args:
            batch: A tensor or collection of tensors.

        Return:
            the result of ``self.batch_to_device(batch, device)``.

        Raises:
            MisconfigurationException: if XLA is not available.

        See Also:
            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
        """
        if not TPU_AVAILABLE:
            raise MisconfigurationException(
                'Requested to transfer batch to TPU but XLA is not available.'
                ' Are you sure this machine has TPUs?'
            )
        device = xm.xla_device(self.trainer.tpu_id)

        return self.batch_to_device(batch, device)

    def __save_end_of_training_weights(self, model: LightningModule, trainer):
        # when training ends on these platforms dump weights to get out of the main process
        if trainer.on_colab_kaggle:
            rank_zero_warn('cleaning up... please do not interrupt')
            self.save_spawn_weights(model)

    def __setup_tpu_training(self, model: LightningModule, trainer):
        # use the default device from the process
        # tpu_device = xm.xla_device()

        # if given an ordinal device, use this as the device
        if trainer.tpu_id is not None:
            tpu_device = xm.xla_device(trainer.tpu_id)
        else:
            tpu_device = xm.xla_device()
        # track the device and move model to it
        trainer._device = tpu_device
        model.to(trainer._device)

        # get the appropriate tpu ranks
        trainer.tpu_local_core_rank = xm.get_local_ordinal()
        trainer.tpu_global_core_rank = xm.get_ordinal()

        # avoid duplicating progress bar
        if trainer.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
            trainer.progress_bar_callback.disable()

        trainer.global_rank = trainer.tpu_local_core_rank
        rank_zero_only.rank = trainer.global_rank

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.setup_optimizers(model)

        # init 16 bit for TPU
        if trainer.precision == 16:
            os.environ['XLA_USE_BF16'] = str(1)

        log.info(f'INIT TPU local core: {trainer.tpu_local_core_rank},'
                 f' global rank: {trainer.tpu_global_core_rank}'
                 f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')

    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
        """Run the backward pass and return the detached loss."""
        # do backward pass
        if self.trainer.train_loop.automatic_optimization:
            model = self.trainer.get_model()
            model.backward(closure_loss, optimizer, opt_idx)
        else:
            closure_loss.backward(*args, **kwargs)

        # detach after backward
        closure_loss = closure_loss.detach()

        return closure_loss

    def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
        """
        Scale the model's gradients in place so their total norm is at most ``grad_clip_val``.

        Args:
            optimizer: not read by this implementation.
            grad_clip_val: maximum allowed total gradient norm.
            norm_type: order of the norm passed to ``torch.norm``.
        """
        # this code is a modification of torch.nn.utils.clip_grad_norm_
        # with TPU support
        model = self.trainer.get_model()
        parameters = model.parameters()
        max_norm = grad_clip_val

        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        parameters = [p for p in parameters if p.grad is not None]

        # nothing to clip (e.g. no backward pass has populated gradients yet);
        # indexing parameters[0] below would otherwise raise IndexError
        if not parameters:
            return

        device = parameters[0].device
        out = torch.empty(len(parameters), device=device)
        for i, p in enumerate(parameters):
            torch.norm(p.grad.data.to(device), norm_type, out=out[i])
        total_norm = torch.norm(out, norm_type)

        # epsilon keeps the division finite when the total norm is zero
        clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon)
        # never scale gradients up: cap the coefficient at 1
        clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
        for p in parameters:
            p.grad.data.mul_(clip_coef.to(p.grad.data.device))

    def barrier(self, name: Optional[str] = None):
        torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}")

    def early_stopping_should_stop(self, pl_module):
        """Return True only when the summed stop flags equal ``trainer.world_size``."""
        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device, dtype=torch.int32)
        stop = xm.mesh_reduce("stop_signal", stop, sum)
        torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
        should_stop = int(stop.item()) == self.trainer.world_size
        return should_stop

    def save_spawn_weights(self, model):
        """
        Dump a temporary checkpoint after ddp ends to get weights out of the process

        Return:
            the checkpoint path on the global-zero process, ``None`` elsewhere.
        """
        if self.trainer.is_global_zero:
            path = os.path.join(self.trainer.default_root_dir, '__temp_weight_distributed_end.ckpt')
            self.trainer.save_checkpoint(path)
            return path

    def load_spawn_weights(self, original_model):
        """
        Load the temp weights saved in the process
        To recover the trained model from the ddp process we load the saved weights

        Return:
            the model loaded from the temporary checkpoint on the global-zero
            process, otherwise ``original_model`` unchanged.
        """
        loaded_model = original_model

        if self.trainer.is_global_zero:
            # load weights saved in ddp
            path = os.path.join(self.trainer.default_root_dir, '__temp_weight_distributed_end.ckpt')
            loaded_model = original_model.__class__.load_from_checkpoint(path)

            # copy loaded weights to old model
            original_model.load_state_dict(loaded_model.state_dict())

            # remove ddp weights
            os.remove(path)

        return loaded_model

    @staticmethod
    def _tmp_end_checkpoint_path(best_model_path: str) -> str:
        """Derive the path for the end-of-fit weights from the best checkpoint path."""
        suffix = '.ckpt'
        if best_model_path.endswith(suffix):
            # swap only the trailing extension; leave any other "ckpt" in the path alone
            return best_model_path[:-len(suffix)] + '.tmp_end.ckpt'
        # never reuse the best checkpoint's own path, or it would be overwritten
        return best_model_path + '.tmp_end.ckpt'

    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
        """
        On global rank 0, put the best checkpoint path, ``results`` and the
        last-weights path on ``mp_queue``, in that order.
        """
        if self.trainer.distributed_backend not in ("ddp_spawn", "ddp_cpu", "tpu"):
            return

        # track the best model path
        best_model_path = None
        if self.trainer.checkpoint_callback is not None:
            best_model_path = self.trainer.checkpoint_callback.best_model_path

        if self.trainer.global_rank == 0 and mp_queue is not None:
            rank_zero_warn('cleaning up ddp environment...')
            # todo, pass complete checkpoint as state dictionary
            mp_queue.put(best_model_path)
            mp_queue.put(results)

            # save the last weights
            last_path = None
            if not self.trainer.testing and best_model_path:
                last_path = self._tmp_end_checkpoint_path(best_model_path)
                state_dict = move_data_to_device(model.state_dict(), torch.device("cpu"))
                atomic_save(state_dict, last_path)
            mp_queue.put(last_path)

    def broadcast(self, obj, src=0):
        """
        Serialize ``obj``, all-gather the bytes across XLA processes and
        deserialize the gathered buffer.

        NOTE(review): ``src`` is not consulted here; presumably the gathered
        buffer starts with the payload of ordinal 0 - confirm.
        """
        if self.trainer.tpu_id is not None:
            # running on a single core
            return obj
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        data_tensor = torch.tensor(data).to(xm.xla_device(), dtype=torch.float)
        data = xm.all_gather(data_tensor)
        buffer = io.BytesIO(data.cpu().byte().numpy())
        # torch.load unpickles: the payload must only ever come from our own processes
        obj = torch.load(buffer)
        return obj

    def sync_tensor(self,
                    tensor: Union[torch.Tensor],
                    group: Optional[Any] = None,
                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
        """Return ``tensor`` unchanged; no reduction is performed here."""
        return tensor

    @property
    def norm_clipping_epsilon(self):
        return 1e-6

    def on_save(self, checkpoint):
        """
        Move XLA tensors to CPU before saving
        """
        return move_data_to_device(checkpoint, torch.device("cpu"))

    @property
    def distributed_sampler_kwargs(self):
        return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())

    @property
    def require_distributed_sampler(self):
        return True

