Plugins
Plugins let you customize the internals of the Trainer, for example by supplying your own AMP or DDP implementation.
For example, to customize your own DistributedDataParallel you could do something like this:
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

class MyDDP(DDPPlugin):
    ...

# use your own ddp algorithm
my_ddp = MyDDP()
trainer = Trainer(plugins=[my_ddp])
ApexPlugin
-
class pytorch_lightning.plugins.apex.ApexPlugin(trainer=None)
Bases: pytorch_lightning.plugins.precision_plugin.PrecisionPlugin
-
clip_gradients(grad_clip_val, optimizer, norm_type)
This code is a modification of torch.nn.utils.clip_grad_norm_() using a higher epsilon for fp16 weights. This is important when setting amp_level to O2, and the master weights are in fp16.
- Parameters
grad_clip_val (Union[int, float]) – Maximum norm of gradients.
optimizer (Optimizer) – Optimizer with gradients that will be clipped.
norm_type (float or int) – Type of the used p-norm. Can be 'inf' for infinity norm.
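To make the role of the epsilon concrete, here is a minimal pure-Python sketch of norm-based clipping. It is not the plugin's code: the helper name clip_grad_norm, the flat list of floats standing in for gradient tensors, and the epsilon values are all assumptions chosen for illustration.

```python
import math


def clip_grad_norm(grads, max_norm, norm_type=2.0, eps=1e-6):
    """Scale `grads` in place so their p-norm does not exceed `max_norm`.

    `grads` is a flat list of floats standing in for gradient tensors
    (an assumption of this sketch). Returns the norm before clipping.
    """
    norm_type = float(norm_type)  # also accepts the string 'inf'
    if norm_type == math.inf:
        total_norm = max(abs(g) for g in grads)
    else:
        total_norm = sum(abs(g) ** norm_type for g in grads) ** (1.0 / norm_type)

    # eps keeps the division finite when the norm is zero or tiny;
    # a larger eps is the more conservative choice for fp16 weights.
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for i, g in enumerate(grads):
            grads[i] = g * clip_coef
    return total_norm


grads = [3.0, 4.0]
norm_before = clip_grad_norm(grads, max_norm=1.0, norm_type=2.0, eps=1e-5)
norm_after = math.sqrt(sum(g * g for g in grads))
```

Gradients are only ever scaled down: when the measured norm is already below the limit, the coefficient is at least 1 and the values are left untouched.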
-
configure_apex(amp, model, optimizers, amp_level)
Override to initialize AMP your own way. Must return a model and a list of optimizers.
- Return type
Tuple[LightningModule, List[Optimizer]]
- Returns
Apex wrapped model and optimizers
Examples
# Default implementation used by Trainer.
def configure_apex(self, amp, model, optimizers, amp_level):
    model, optimizers = amp.initialize(
        model, optimizers, opt_level=amp_level,
    )
    return model, optimizers
-
NativeAMPPlugin
DDPPlugin
-
class pytorch_lightning.plugins.ddp_plugin.DDPPlugin(**kwargs)
Bases: pytorch_lightning.plugins.plugin.LightningPlugin
Plugin to link a custom ddp implementation to any arbitrary accelerator.
This plugin forwards all constructor arguments to LightningDistributedDataParallel, which in turn forwards all args to DistributedDataParallel.
Example:
class MyDDP(DDPPlugin):
    def configure_ddp(self, model, device_ids):
        model = MyDDPWrapper(model, device_ids)
        return model

my_ddp = MyDDP()
trainer = Trainer(accelerator='ddp_x', plugins=[my_ddp])
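The constructor-argument forwarding described above can be pictured with a small stand-alone sketch. Both classes here are hypothetical stand-ins rather than Lightning's own code: FakeDistributedDataParallel only records what it receives, and SketchDDPPlugin shows one plausible way to keep keyword arguments and pass them on.

```python
class FakeDistributedDataParallel:
    """Hypothetical stand-in for the wrapper class; it only records its arguments."""

    def __init__(self, module, device_ids=None, **kwargs):
        self.module = module
        self.device_ids = device_ids
        self.kwargs = kwargs


class SketchDDPPlugin:
    """Hypothetical stand-in for a plugin that forwards constructor kwargs."""

    def __init__(self, **kwargs):
        # Keep whatever the user passed so it can be handed to the wrapper later.
        self._ddp_kwargs = kwargs

    def configure_ddp(self, model, device_ids):
        # Forward the stored kwargs unchanged to the wrapper.
        return FakeDistributedDataParallel(
            model, device_ids=device_ids, **self._ddp_kwargs
        )


plugin = SketchDDPPlugin(find_unused_parameters=True, bucket_cap_mb=50)
wrapped = plugin.configure_ddp(model="my_model", device_ids=[0])
```

With this pattern, options meant for the underlying wrapper can be set once at plugin construction instead of by overriding configure_ddp.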
-
block_backward_sync(model)
Blocks DDP gradient synchronization on the backward pass. This is useful for skipping the sync while accumulating gradients, which reduces communication overhead.
Returns: context manager with sync behaviour off
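As a rough illustration of why this helps with gradient accumulation, the sketch below counts how often a synchronization would happen. FakeDDPModel and its sync counter are hypothetical stand-ins that assume the wrapped model offers a no_sync()-style context manager; no real communication takes place.

```python
from contextlib import contextmanager, nullcontext


class FakeDDPModel:
    """Hypothetical stand-in for a DDP-wrapped model; it counts gradient syncs."""

    def __init__(self):
        self.sync_enabled = True
        self.sync_count = 0

    @contextmanager
    def no_sync(self):
        # Temporarily disable syncing, restoring the previous state on exit.
        previous = self.sync_enabled
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = previous

    def backward(self):
        # A real wrapper would all-reduce gradients here when syncing is on.
        if self.sync_enabled:
            self.sync_count += 1


@contextmanager
def block_backward_sync(model):
    """Sketch of a context manager that turns gradient sync off."""
    with model.no_sync():
        yield


model = FakeDDPModel()
accumulate_grad_batches = 4
for step in range(8):
    is_last_micro_batch = (step + 1) % accumulate_grad_batches == 0
    # Sync only on the last micro-batch of each accumulation window.
    ctx = nullcontext() if is_last_micro_batch else block_backward_sync(model)
    with ctx:
        model.backward()
```

Over eight backward passes with an accumulation window of four, only two of them trigger a sync in this sketch.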
-
configure_ddp(model, device_ids)
Pass through all customizations from constructor to LightningDistributedDataParallel. Override to define a custom DDP implementation.
Note
The only requirement is that your DDP implementation subclasses LightningDistributedDataParallel.
The default implementation is:
def configure_ddp(self, model, device_ids):
    model = LightningDistributedDataParallel(
        model, device_ids=device_ids, find_unused_parameters=False
    )
    return model
- Parameters
model (LightningModule) – the LightningModule
- Return type
LightningDistributedDataParallel
- Returns
the model wrapped in LightningDistributedDataParallel
-
get_model_from_plugin(model)
Override to change how the base LightningModule is returned when accessing variables and functions outside of the parallel wrapper.
Example:
ref_model = ddp_plugin.get_model_from_plugin(model)
ref_model.training_step(…)
- Parameters
model (Union[LightningDistributedDataParallel, LightningModule]) – Model with parallel wrapper.
Returns: Reference LightningModule within parallel wrapper.
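The unwrapping idea can be sketched without any distributed setup. FakeModule and FakeParallelWrapper below are hypothetical stand-ins, and the assumption that the wrapper keeps the original model in a .module attribute is part of the sketch, not a statement about the plugin's implementation.

```python
class FakeModule:
    """Hypothetical stand-in for a LightningModule."""

    def training_step(self, batch):
        return sum(batch)


class FakeParallelWrapper:
    """Hypothetical stand-in for a parallel wrapper holding the model in .module."""

    def __init__(self, module):
        self.module = module


def get_model_from_plugin(model):
    """Return the underlying module if `model` is wrapped, else `model` itself."""
    if isinstance(model, FakeParallelWrapper):
        return model.module
    return model


base = FakeModule()
wrapped = FakeParallelWrapper(base)

# Works the same whether or not the model has been wrapped.
ref_model = get_model_from_plugin(wrapped)
loss = ref_model.training_step([1, 2, 3])
```

Passing an unwrapped model returns it unchanged, so calling code does not need to know whether wrapping has happened.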
-
on_after_setup_optimizers(trainer)
Called after optimizers have been set up. This is useful for any additional configuration, such as for RPC or state sharding.
-
on_before_forward(model, *args)
Override to handle custom edge cases.
- Parameters
args – Inputs to the model.
model (LightningModule) – Model to train.
Returns: args moved to correct device if needed.
-
property data_parallel_group
Return the group that this process exists in. By default, this is the world group, which contains all processes. Useful for when additional parallel groups have been created, to select certain processes.
Returns: The ProcessGroup this process exists in.
-