Shortcuts

Source code for pytorch_lightning.plugins.precision.fully_sharded_native_amp

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

import torch

from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12

# `MixedPrecision` is imported only when `_TORCH_GREATER_EQUAL_1_12` is true. Otherwise the name is bound
# to `None`, which keeps it defined at module level; code using it must check for `None` first.
if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
else:
    MixedPrecision = None


class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
    """Native AMP for Fully Sharded Training."""

    def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
        """Reject gradient clipping by norm, which this plugin does not support.

        All positional and keyword arguments are accepted and ignored.

        Raises:
            MisconfigurationException: Always.
        """
        # see https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html
        # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect
        # for FSDP module. To overcome this, needs to call sharded_module.clip_grad_norm(clip_val)
        # however we rely on LightningModule's configure_sharded_model to wrap FSDP, it would be hard to
        # trace back the root FSDP. Now we only support clip by value.
        raise MisconfigurationException(
            f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
        )

    @property
    def mixed_precision_config(self) -> Optional[MixedPrecision]:
        """Build the ``MixedPrecision`` config matching ``self.precision``.

        The same dtype is used for ``param_dtype``, ``reduce_dtype`` and ``buffer_dtype``:
        ``torch.float16`` for ``PrecisionType.HALF`` and ``torch.bfloat16`` for ``PrecisionType.BFLOAT``.

        Raises:
            MisconfigurationException: If ``MixedPrecision`` is unavailable (this module only imports it when
                ``_TORCH_GREATER_EQUAL_1_12`` is true), or if ``self.precision`` is neither of the two
                supported precision types.
        """
        if MixedPrecision is None:
            # Explicit check instead of `assert`: assertions are stripped under `python -O`, which would
            # let execution fall through to an opaque "'NoneType' object is not callable" error below.
            raise MisconfigurationException(
                f"`MixedPrecision` is unavailable; `{self.__class__.__name__}` requires PyTorch >= 1.12."
            )

        if self.precision == PrecisionType.HALF:
            dtype = torch.float16
        elif self.precision == PrecisionType.BFLOAT:
            dtype = torch.bfloat16
        else:
            raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")

        return MixedPrecision(
            param_dtype=dtype,
            reduce_dtype=dtype,
            buffer_dtype=dtype,
        )

© Copyright (c) 2018-2022, Lightning AI et al. Revision be581598.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: stable
Versions
latest
stable
1.7.1
1.7.0
1.6.5
1.6.4
1.6.3
1.6.2
1.6.1
1.6.0
1.5.10
1.5.9
1.5.8
1.5.7
1.5.6
1.5.5
1.5.4
1.5.3
1.5.2
1.5.1
1.5.0
1.4.9
1.4.8
1.4.7
1.4.6
1.4.5
1.4.4
1.4.3
1.4.2
1.4.1
1.4.0
1.3.8
1.3.7
1.3.6
1.3.5
1.3.4
1.3.3
1.3.2
1.3.1
1.3.0
1.2.10
1.2.8
1.2.7
1.2.6
1.2.5
1.2.4
1.2.3
1.2.2
1.2.1
1.2.0
1.1.8
1.1.7
1.1.6
1.1.5
1.1.4
1.1.3
1.1.2
1.1.1
1.1.0
1.0.8
1.0.7
1.0.6
1.0.5
1.0.4
1.0.3
1.0.2
1.0.1
1.0.0
0.10.0
0.9.0
0.8.5
0.8.4
0.8.3
0.8.2
0.8.1
0.8.0
0.7.6
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.3
0.4.9
future-structure
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.