- class pytorch_lightning.strategies.DDPFullyShardedStrategy(accelerator=None, cpu_offload=False, flatten_parameters=True, reshard_after_forward=True, move_grads_to_cpu=None, fp32_reduce_scatter=None, compute_dtype=None, bucket_cap_mb=25, min_num_params=100000000, state_dict_to_cpu=True, parallel_devices=None, cluster_environment=None, checkpoint_io=None, precision_plugin=None, process_group_backend=None)
Plugin for Fully Sharded Data Parallel provided by FairScale.
Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model size whilst using efficient communication to reduce overhead. In practice, this means we can remain at parity with PyTorch DDP whilst scaling our model sizes dramatically. The technique is similar to ZeRO Stage 3 but has been built for upstreaming to PyTorch. For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html.
Warning: FullyShardedPlugin is in beta and subject to change.
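To make "shards the entire model" concrete, here is a minimal pure-Python sketch, not FairScale's implementation. It assumes the parameters have already been flattened into one vector (loosely what flatten_parameters=True refers to), zero-pads that vector, and splits it so each rank stores only 1/world_size of it; before a layer runs, the shards are gathered back. The helper names shard_flat_params and gather_full_params are hypothetical.

```python
from typing import List


def shard_flat_params(flat_params: List[float], world_size: int) -> List[List[float]]:
    """Split a flattened parameter vector into equal shards, one per rank.

    The vector is zero-padded so its length divides evenly by world_size,
    so every rank owns exactly 1/world_size of the (padded) parameters.
    """
    if world_size <= 0:
        raise ValueError("world_size must be positive")
    remainder = len(flat_params) % world_size
    padding = (world_size - remainder) % world_size
    padded = flat_params + [0.0] * padding
    shard_size = len(padded) // world_size
    return [padded[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]


def gather_full_params(shards: List[List[float]], numel: int) -> List[float]:
    """Simulated all-gather: concatenate every rank's shard and drop the padding."""
    full = [x for shard in shards for x in shard]
    return full[:numel]


if __name__ == "__main__":
    params = [float(i) for i in range(10)]
    shards = shard_flat_params(params, world_size=4)
    print([len(s) for s in shards])  # 10 values padded to 12, so 3 per rank
    assert gather_full_params(shards, len(params)) == params
```

Per-rank parameter memory drops roughly by a factor of world_size; the price is the gather step before computation, which is the communication overhead the paragraph above refers to.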
Defaults have been set and options have been exposed, but they may require configuration depending on the memory/speed trade-off you need. We suggest having a look at this PR for more information: https://github.com/facebookresearch/fairscale/pull/413
Many of the helpful doc strings below came from the original FairScale documentation: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html
bucket_cap_mb (int) – Bucket parameters so that gradient reduction can potentially overlap with backward computation. bucket_cap_mb controls the bucket size in megabytes (MB). Buckets are subdivided by world_size, so the maximum shard size is roughly bucket_cap_mb / world_size. Values <= 0 disable bucketing. (Default: 25)
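The bucket_cap_mb rule of thumb is simple arithmetic; the sketch below just encodes it. The function name approx_max_shard_mb is hypothetical, and returning None for a disabled bucket (bucket_cap_mb <= 0) is an illustrative convention, not FairScale behavior.

```python
from typing import Optional


def approx_max_shard_mb(bucket_cap_mb: float, world_size: int) -> Optional[float]:
    """Rough per-rank shard size of one gradient bucket, in MB.

    Encodes the rule of thumb from the bucket_cap_mb description: buckets
    are subdivided by world_size, so each shard is about
    bucket_cap_mb / world_size. Returns None when bucketing is disabled
    (bucket_cap_mb <= 0); that convention is our own, not the library's.
    """
    if world_size <= 0:
        raise ValueError("world_size must be positive")
    if bucket_cap_mb <= 0:
        return None  # bucketing disabled
    return bucket_cap_mb / world_size


if __name__ == "__main__":
    # Default bucket_cap_mb=25 on 8 GPUs gives shards of roughly 3.125 MB.
    print(approx_max_shard_mb(25, 8))
```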
Provides a hook to create modules in a distributed-aware context. This is useful when we'd like to shard the model immediately as it is created, which for extremely large models can save memory and initialization time.
Returns: Model parallel context.
- Return type
- predict_step(*args, **kwargs)
The actual predict step.
See predict_step() for more details.
Sets up plugins for the trainer fit and creates optimizers.
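With fully sharded training, the order of setup matters: FairScale's FSDP documentation notes that optimizers must be created after the model is wrapped, because wrapping changes the parameters and would break an optimizer built beforehand. The pure-Python sketch below assumes, as a simplification, that wrapping replaces the individual parameters with one new flat parameter object. Param, ToyOptimizer, flatten_and_wrap and setup_sketch are hypothetical names.

```python
from typing import List, Tuple


class Param:
    """Minimal stand-in for a parameter tensor."""

    def __init__(self, data: List[float]) -> None:
        self.data = data


class ToyOptimizer:
    """Holds references to the exact parameter objects it will update."""

    def __init__(self, params: List[Param]) -> None:
        self.params = params


def flatten_and_wrap(params: List[Param]) -> List[Param]:
    """Hypothetical wrap step: replace the parameters with one new flat parameter."""
    flat = [x for p in params for x in p.data]
    return [Param(flat)]


def setup_sketch(params: List[Param]) -> Tuple[List[Param], ToyOptimizer]:
    # 1) Wrap (flatten/shard) the model first ...
    wrapped = flatten_and_wrap(params)
    # 2) ... then create the optimizer over the wrapped parameters.
    return wrapped, ToyOptimizer(wrapped)


if __name__ == "__main__":
    original = [Param([1.0, 2.0]), Param([3.0])]
    wrapped, optimizer = setup_sketch(original)
    print(optimizer.params[0].data)  # the single flat parameter
    # An optimizer created before wrapping points at stale objects:
    stale = ToyOptimizer(original)
    print(any(p in wrapped for p in stale.params))  # False
```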
- test_step(*args, **kwargs)
The actual test step.
See test_step() for more details.
- training_step(*args, **kwargs)
The actual training step.
See training_step() for more details.
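The three step methods take *args and **kwargs so that whatever the loop passes in reaches the user's LightningModule hook of the same name. The sketch below shows only that forwarding pattern with hypothetical toy classes (ToyLightningModule, ToyStrategy); a real strategy may also do extra work around the call, such as entering precision contexts, which is omitted here.

```python
from typing import Any, Dict, List


class ToyLightningModule:
    """Hypothetical user model with the three step hooks."""

    def training_step(self, batch: List[float], batch_idx: int) -> Dict[str, Any]:
        return {"loss": sum(batch) / len(batch), "idx": batch_idx}

    def test_step(self, batch: List[float], batch_idx: int) -> Dict[str, Any]:
        return {"test_sum": sum(batch)}

    def predict_step(self, batch: List[float], batch_idx: int) -> List[float]:
        return [x * 2 for x in batch]


class ToyStrategy:
    """Hypothetical strategy that forwards each step to the wrapped model unchanged."""

    def __init__(self, model: ToyLightningModule) -> None:
        self.model = model

    def training_step(self, *args: Any, **kwargs: Any) -> Any:
        return self.model.training_step(*args, **kwargs)

    def test_step(self, *args: Any, **kwargs: Any) -> Any:
        return self.model.test_step(*args, **kwargs)

    def predict_step(self, *args: Any, **kwargs: Any) -> Any:
        return self.model.predict_step(*args, **kwargs)


if __name__ == "__main__":
    strategy = ToyStrategy(ToyLightningModule())
    print(strategy.training_step([1.0, 3.0], 0))
    print(strategy.predict_step([1, 2], batch_idx=0))
```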