Source code for pytorch_lightning.strategies.sharded_spawn

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Tuple

import torch
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
    from fairscale.optim import OSS

    from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded


class DDPSpawnShardedStrategy(DDPSpawnStrategy):
    """Optimizer sharded training provided by FairScale."""

    strategy_name = "ddp_sharded_spawn"

    def configure_ddp(self) -> None:
        # Wrap the LightningModule for sharded data-parallel training and replace
        # the optimizers with their OSS-wrapped counterparts.
        self.model, self.optimizers = self._setup_model_and_optimizers(
            model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
        )

    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
        """Wraps the model and optimizers with fairscale components.

        Return:
            The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
            and a list of optimizers wrapped in :class:`~fairscale.optim.OSS`.
        """
        optimizers = self._wrap_optimizers(optimizers)
        model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
        return model, optimizers

    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
        for x, optimizer in enumerate(optimizers):
            if not isinstance(optimizer, OSS):
                optim_class = type(optimizer)
                zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                optimizers[x] = zero_optimizer
                del optimizer
        return optimizers

    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
        # Only re-initialize the optimizers with OSS while fitting; other trainer
        # states keep the original optimizers.
        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
            return optimizers
        return self._reinit_optimizers_with_oss(optimizers)
    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
        if isinstance(optimizer, OSS):
            optimizer.consolidate_state_dict()
        return self._optim_state_dict(optimizer)

    @contextmanager
    def block_backward_sync(self) -> Generator:
        """Blocks syncing gradients behaviour on backwards pass.

        This is useful for skipping sync when accumulating gradients, reducing communication overhead.

        Returns:
            context manager with sync behaviour off
        """
        if isinstance(self.model, ShardedDataParallel):
            with self.model.no_sync():
                yield None
        else:
            yield None

    @rank_zero_only
    def _optim_state_dict(self, optimizer):
        """
        Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
        :meth:`consolidate_state_dict`.
        """
        return optimizer.state_dict()

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
            raise MisconfigurationException(
                "`DDPSpawnShardedStrategy` requires `fairscale` to be installed."
                " Install it by running `pip install fairscale`."
            )
        return unwrap_lightning_module_sharded(self.model) if self.model is not None else None

    def pre_backward(self, closure_loss: torch.Tensor) -> None:
        pass
    def post_training_step(self):
        pass

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            "ddp_sharded_spawn_find_unused_parameters_false",
            cls,
            description="DDP Spawn Sharded Strategy with `find_unused_parameters` as False",
            find_unused_parameters=False,
        )
        strategy_registry.register(
            cls.strategy_name,
            cls,
            # `cls.__name__` is the strategy class name; `cls.__class__.__name__`
            # would resolve to the metaclass instead.
            description=f"{cls.__name__}",
        )
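
The two names registered in ``register_strategies`` are how this strategy is normally selected from user code. Below is a minimal usage sketch, not part of this module: ``BoringLitModule`` is a made-up example module, and the snippet assumes fairscale and at least two GPUs are available.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class BoringLitModule(pl.LightningModule):
    """Tiny stand-in LightningModule used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy="ddp_sharded_spawn",  # or "ddp_sharded_spawn_find_unused_parameters_false"
        max_epochs=1,
    )
    trainer.fit(BoringLitModule(), train_dataloaders=data)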
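
For reference, the OSS wrapping performed in ``_reinit_optimizers_with_oss`` can be reproduced outside Lightning. This is an illustrative sketch only, assuming fairscale is installed; the single-process "gloo" group exists purely so OSS has a process group to shard optimizer state across.

import os

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# OSS shards optimizer state across a process group, so one must be initialized.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(32, 2)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Mirror the strategy above: reuse the param groups, the optimizer class and its
# defaults when constructing the sharded optimizer.
sharded_optimizer = OSS(
    params=base_optimizer.param_groups,
    optim=type(base_optimizer),
    **base_optimizer.defaults,
)

loss = model(torch.randn(4, 32)).sum()
loss.backward()
sharded_optimizer.step()

dist.destroy_process_group()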
