TransformerEnginePrecision

class lightning.pytorch.plugins.precision.TransformerEnginePrecision(*, weights_dtype, recipe=None, replace_layers=None, fallback_compute_dtype=None)[source]

Bases: Precision, TransformerEnginePrecision (the latter from lightning.fabric.plugins.precision)

Plugin for training with FP8 precision via NVIDIA's Transformer Engine.

Warning

This is an experimental feature.
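A minimal usage sketch, assuming transformer_engine is installed and an FP8-capable GPU is available (the Trainer settings here are illustrative):

    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.plugins.precision import TransformerEnginePrecision

    # Keep weights in bf16; FP8 is used for the supported operations.
    precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, replace_layers=True)
    trainer = Trainer(accelerator="cuda", devices=1, plugins=[precision])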

Parameters:
  • weights_dtype (dtype) – The weights dtype to use.

  • recipe (Union[Mapping[str, Any], DelayedScaling, None]) – Recipe for the DelayedScaling configuration, given either in dict form or as the dataclass itself (see the first sketch after this list).

  • replace_layers (Optional[bool]) – Whether to automatically replace Linear and LayerNorm layers with their Transformer Engine alternatives. Note that these do not subclass the torch equivalents, so checks such as isinstance(l, torch.nn.Linear) will not pass (the second sketch after this list demonstrates this).

  • fallback_compute_dtype (Optional[dtype]) – The compute dtype to use for operations that don't support FP8 autocast. Defaults to the same as weights_dtype.
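A sketch of configuring the recipe in both forms; the field values shown are illustrative, and the field names follow transformer_engine.common.recipe.DelayedScaling:

    import torch
    from transformer_engine.common.recipe import DelayedScaling, Format
    from lightning.pytorch.plugins.precision import TransformerEnginePrecision

    # Dict form: the keys are forwarded to the DelayedScaling configuration.
    precision = TransformerEnginePrecision(
        weights_dtype=torch.bfloat16,
        recipe={"amax_history_len": 16, "amax_compute_algo": "max"},
    )

    # Dataclass form: construct DelayedScaling directly.
    recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
    precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, recipe=recipe)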
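And a small sketch of the replace_layers caveat: the swapped-in modules come from transformer_engine.pytorch and are not subclasses of the torch layers they replace.

    import torch.nn as nn
    import transformer_engine.pytorch as te

    layer = te.Linear(16, 16)  # what replace_layers swaps in for torch.nn.Linear
    print(isinstance(layer, te.Linear))  # True
    print(isinstance(layer, nn.Linear))  # False: it does not subclass torch.nn.Linear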

Note

Support for FP8 in the linear layers with this plugin is currently limited to tensors whose dimensions are divisible by 8 and 16, respectively. You might need to pad your inputs to conform to this restriction; a sketch of one way to do so follows.
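A hypothetical padding helper; pad_for_fp8 is not part of Lightning, and this sketch assumes the last two dimensions are the ones that must be divisible by 8 and 16:

    import torch
    import torch.nn.functional as F

    def pad_for_fp8(x, row_multiple=8, col_multiple=16):
        # Pad the last two dimensions up to the next multiples of 8 and 16.
        pad_rows = -x.shape[-2] % row_multiple
        pad_cols = -x.shape[-1] % col_multiple
        # F.pad pairs start from the last dimension: (left, right, top, bottom).
        return F.pad(x, (0, pad_cols, 0, pad_rows))

    x = torch.randn(4, 30, 100)
    print(pad_for_fp8(x).shape)  # torch.Size([4, 32, 112])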