Plugins¶
Plugins allow custom integrations with the internals of the Trainer, such as a custom amp or ddp implementation.
For example, to customize your own DistributedDataParallel you could do something like this:

    class MyDDP(DDPPlugin):
        ...

    # use your own ddp algorithm
    my_ddp = MyDDP()
    trainer = Trainer(plugins=[my_ddp])
ApexPlugin¶
class pytorch_lightning.plugins.apex.ApexPlugin(trainer=None)[source]¶
Bases: pytorch_lightning.plugins.precision_plugin.PrecisionPlugin
clip_gradients(grad_clip_val, optimizer, norm_type)[source]¶
This code is a modification of torch.nn.utils.clip_grad_norm_() using a higher epsilon for fp16 weights. This is important when setting amp_level to O2, and the master weights are in fp16.

Parameters
    grad_clip_val (Union[int, float]) – Maximum norm of gradients.
    optimizer (Optimizer) – Optimizer with gradients that will be clipped.
    norm_type (float or int) – Type of the used p-norm. Can be 'inf' for infinity norm.
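The higher epsilon matters because dividing by a near-zero total norm is unstable in fp16. Below is a minimal pure-Python sketch of the underlying clipping logic, not the Apex implementation: the function name `clip_gradient_norm` and the flat-list gradient representation are illustrative stand-ins.

```python
import math

def clip_gradient_norm(grads, max_norm, norm_type=2.0, eps=1e-4):
    """Scale flat gradient lists in place so their total p-norm does not
    exceed max_norm. A larger eps keeps the division stable when the
    norm is computed from fp16 values."""
    flat = [g for grad in grads for g in grad]
    if norm_type == math.inf:
        total_norm = max(abs(g) for g in flat)
    else:
        total_norm = sum(abs(g) ** norm_type for g in flat) ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1:
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm

grads = [[3.0, 4.0]]                     # total 2-norm = 5.0
clip_gradient_norm(grads, max_norm=1.0)  # scales grads to (just under) unit norm
```

Note the coefficient `max_norm / (total_norm + eps)`: with fp16 master weights, a tiny `total_norm` and the default fp32-sized epsilon would make this division blow up, which is why the plugin uses a larger epsilon.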
configure_apex(amp, model, optimizers, amp_level)[source]¶
Override to init AMP your own way. Must return a model and a list of optimizers.

Return type
    Tuple[LightningModule, List[Optimizer]]
Returns
    Apex-wrapped model and optimizers.

Examples

    # Default implementation used by Trainer.
    def configure_apex(self, amp, model, optimizers, amp_level):
        model, optimizers = amp.initialize(
            model, optimizers, opt_level=amp_level,
        )
        return model, optimizers
NativeAMPPlugin¶
DDPPlugin¶
class pytorch_lightning.plugins.ddp_plugin.DDPPlugin(**kwargs)[source]¶
Bases: pytorch_lightning.plugins.plugin.LightningPlugin
Plugin to link a custom ddp implementation to any arbitrary accelerator.
This plugin forwards all constructor arguments to LightningDistributedDataParallel, which in turn forwards all args to DistributedDataParallel.

Example:

    class MyDDP(DDPPlugin):
        def configure_ddp(self, model, device_ids):
            model = MyDDPWrapper(model, device_ids)
            return model

    my_ddp = MyDDP()
    trainer = Trainer(accelerator='ddp_x', plugins=[my_ddp])
block_backward_sync(model)[source]¶
Blocks ddp sync gradients behaviour on backwards pass. This is useful for skipping sync when accumulating gradients, reducing communication overhead.
Returns: context manager with sync behaviour off.
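The pattern behind this hook is the same as PyTorch DDP's no_sync(): temporarily turn off gradient all-reduce while accumulating micro-batches. A minimal sketch with a hypothetical stand-in model class (FakeDDPModel is not part of the library; real DDP checks a flag of this kind before all-reducing):

```python
from contextlib import contextmanager

class FakeDDPModel:
    """Hypothetical stand-in for a DDP-wrapped model. Real DDP consults
    require_backward_grad_sync before all-reducing gradients."""
    def __init__(self):
        self.require_backward_grad_sync = True

@contextmanager
def block_backward_sync(model):
    # Disable gradient all-reduce for the duration of the block,
    # e.g. while accumulating gradients over several micro-batches.
    previous = model.require_backward_grad_sync
    model.require_backward_grad_sync = False
    try:
        yield model
    finally:
        model.require_backward_grad_sync = previous

model = FakeDDPModel()
with block_backward_sync(model):
    inside = model.require_backward_grad_sync   # False: no sync here
after = model.require_backward_grad_sync        # True: sync restored
```

Using try/finally guarantees the sync flag is restored even if the backward pass inside the block raises.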
configure_ddp(model, device_ids)[source]¶
Pass through all customizations from constructor to LightningDistributedDataParallel. Override to define a custom DDP implementation.

Note
    The only requirement is that your DDP implementation subclasses LightningDistributedDataParallel.

The default implementation is:

    def configure_ddp(self, model, device_ids):
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            find_unused_parameters=False
        )
        return model

Parameters
    model (LightningModule) – the LightningModule
Return type
    LightningDistributedDataParallel
Returns
    the model wrapped in LightningDistributedDataParallel
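The constructor-forwarding contract described above can be sketched in plain Python. The class names here are hypothetical stand-ins: Wrapper plays the role of LightningDistributedDataParallel, and ForwardingPlugin the role of DDPPlugin.

```python
class Wrapper:
    """Hypothetical stand-in for LightningDistributedDataParallel:
    accepts the wrapped module plus arbitrary DDP keyword arguments."""
    def __init__(self, module, device_ids=None, **kwargs):
        self.module = module
        self.device_ids = device_ids
        self.kwargs = kwargs

class ForwardingPlugin:
    """Sketch of the forwarding contract: constructor kwargs are stored
    and passed through untouched when the model is wrapped."""
    def __init__(self, **kwargs):
        self._ddp_kwargs = kwargs

    def configure_ddp(self, model, device_ids):
        return Wrapper(model, device_ids=device_ids, **self._ddp_kwargs)

plugin = ForwardingPlugin(find_unused_parameters=True)
wrapped = plugin.configure_ddp(model="net", device_ids=[0])
```

This is why passing keyword arguments to the plugin constructor is enough to tune the underlying DistributedDataParallel without overriding configure_ddp at all.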
get_model_from_plugin(model)[source]¶
Override to modify returning the base LightningModule when accessing variables and functions outside of the parallel wrapper.

Example:

    ref_model = ddp_plugin.get_model_from_plugin(model)
    ref_model.training_step(…)

Parameters
    model (Union[LightningDistributedDataParallel, LightningModule]) – Model with parallel wrapper.
Returns
    Reference LightningModule within parallel wrapper.
Return type
    LightningModule
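A minimal sketch of the unwrapping this hook performs, using a hypothetical ParallelWrapper stand-in (like real DDP wrappers, it exposes the inner model as .module):

```python
class ParallelWrapper:
    """Hypothetical stand-in for LightningDistributedDataParallel;
    exposes the wrapped model via the .module attribute."""
    def __init__(self, module):
        self.module = module

def get_model_from_plugin(model):
    # Unwrap if needed, so callers always get the plain LightningModule
    # whether or not the model has been wrapped for parallel training.
    if isinstance(model, ParallelWrapper):
        return model.module
    return model

inner = object()
wrapped = ParallelWrapper(inner)
```

Accepting both wrapped and unwrapped models is what lets callers use the same code path before and after DDP setup.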
on_after_setup_optimizers(trainer)[source]¶
Called after optimizers have been set up. This is useful for doing any configuration for RPC or state sharding.
on_before_forward(model, *args)[source]¶
Override to handle custom edge cases.

Parameters
    model (LightningModule) – Model to train.
    args – Inputs to the model.
Returns: args moved to the correct device, if needed.
property data_parallel_group¶
Return the process group that this process exists in. By default, this is the world (all processes). Useful when additional parallel groups have been created, to select certain processes.
Returns: The ProcessGroup this process exists in.