Plugins¶
Plugins allow custom integrations with the internals of the Trainer, such as a custom AMP or DDP implementation.
For example, to customize your own DistributedDataParallel you could do something like this:
class MyDDP(DDPPlugin):
    # use your own ddp algorithm
    ...

my_ddp = MyDDP()
trainer = Trainer(plugins=[my_ddp])
ApexPlugin¶
class pytorch_lightning.plugins.apex.ApexPlugin(trainer=None)[source]¶
Bases: pytorch_lightning.plugins.precision_plugin.PrecisionPlugin
clip_gradients(grad_clip_val, optimizer, norm_type)[source]¶
This code is a modification of torch.nn.utils.clip_grad_norm_() using a higher epsilon for fp16 weights. This is important when setting amp_level to O2, where the master weights are in fp16.

Parameters
    grad_clip_val (Union[int, float]) – Maximum norm of gradients.
    optimizer (Optimizer) – Optimizer with gradients that will be clipped.
    norm_type (float) – type of the used p-norm. Can be 'inf' for infinity norm.
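If you need different clipping behaviour, you can override clip_gradients in a subclass and pass the plugin to the Trainer. A minimal sketch (the subclass name is hypothetical); it falls back to PyTorch's standard clip_grad_norm_ over the optimizer's parameters instead of the fp16-aware variant:

import torch

from pytorch_lightning.plugins.apex import ApexPlugin


class StandardClipApexPlugin(ApexPlugin):  # hypothetical example subclass
    def clip_gradients(self, grad_clip_val, optimizer, norm_type):
        # Collect every parameter this optimizer updates and clip their
        # gradients to the requested norm with the stock PyTorch utility.
        parameters = [p for group in optimizer.param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=grad_clip_val, norm_type=norm_type)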
configure_apex(amp, model, optimizers, amp_level)[source]¶
Override to init AMP your own way. Must return a model and list of optimizers.

Parameters
    amp – the Apex amp module used to wrap the model and optimizers.
    model (LightningModule) – the model to wrap.
    optimizers (List[Optimizer]) – the optimizers to wrap.
    amp_level (str) – the Apex optimization level, e.g. 'O1' or 'O2'.

Return type
    Tuple[LightningModule, List[Optimizer]]

Returns
    Apex wrapped model and optimizers

Examples

# Default implementation used by Trainer.
def configure_apex(self, amp, model, optimizers, amp_level):
    model, optimizers = amp.initialize(
        model, optimizers, opt_level=amp_level,
    )
    return model, optimizers
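As a sketch of a custom override, you could pass extra options to amp.initialize, such as a fixed loss scale. This assumes NVIDIA Apex is installed; the subclass name, the loss scale value, and the Trainer arguments shown are illustrative:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.apex import ApexPlugin


class FixedLossScaleApexPlugin(ApexPlugin):  # hypothetical example subclass
    def configure_apex(self, amp, model, optimizers, amp_level):
        # Same as the default implementation, but with a fixed loss scale
        # instead of Apex's dynamic loss scaling.
        model, optimizers = amp.initialize(
            model,
            optimizers,
            opt_level=amp_level,
            loss_scale=128.0,
        )
        return model, optimizers


# Illustrative usage: select the apex backend and register the plugin.
trainer = Trainer(precision=16, amp_backend='apex', plugins=[FixedLossScaleApexPlugin()])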
NativeAMPPlugin¶
DDPPlugin¶
class pytorch_lightning.plugins.ddp_plugin.DDPPlugin(**kwargs)[source]¶
Bases: object
Plugin to link a custom ddp implementation to any arbitrary accelerator.
This plugin forwards all constructor arguments to LightningDistributedDataParallel, which in turn forwards all args to DistributedDataParallel.
Example:

class MyDDP(DDPPlugin):

    def configure_ddp(self, model, device_ids):
        model = MyDDPWrapper(model, device_ids)
        return model

my_ddp = MyDDP()
trainer = Trainer(accelerator='ddp_x', plugins=[my_ddp])
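Because constructor arguments are forwarded as described above, simple DistributedDataParallel tweaks do not require a subclass at all. A minimal sketch; the accelerator and gpus settings are illustrative:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

# find_unused_parameters is forwarded through LightningDistributedDataParallel
# to torch.nn.parallel.DistributedDataParallel.
my_ddp = DDPPlugin(find_unused_parameters=False)
trainer = Trainer(accelerator='ddp', gpus=2, plugins=[my_ddp])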
configure_ddp(model, device_ids)[source]¶
Pass through all customizations from the constructor to LightningDistributedDataParallel. Override to define a custom DDP implementation.

Note
    The only requirement is that your DDP implementation subclasses LightningDistributedDataParallel.

The default implementation is:

def configure_ddp(self, model, device_ids):
    model = LightningDistributedDataParallel(
        model,
        device_ids=device_ids,
        find_unused_parameters=True
    )
    return model
Parameters
    model (LightningModule) – the LightningModule

Return type
    LightningDistributedDataParallel

Returns
    the model wrapped in LightningDistributedDataParallel
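If you override configure_ddp, the note above still applies: the returned wrapper should subclass LightningDistributedDataParallel. A minimal sketch; the subclass name, the import path, and the bucket_cap_mb value are assumptions for illustration:

from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class TunedDDPPlugin(DDPPlugin):  # hypothetical example subclass
    def configure_ddp(self, model, device_ids):
        # Keep the required LightningDistributedDataParallel wrapper, but tune
        # the underlying DistributedDataParallel (larger gradient buckets here).
        return LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            find_unused_parameters=True,
            bucket_cap_mb=50,
        )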