Shortcuts

Source code for pytorch_lightning.plugins.ddp_plugin

from typing import List, Dict, Any

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel


class DDPPlugin(object):
    """
    Plugin to link a custom ddp implementation to any arbitrary accelerator.

    Every keyword argument given to the constructor is kept and later passed on to
    `LightningDistributedDataParallel`, which in turn forwards all args to
    `DistributedDataParallel`.

    Example::

        class MyDDP(DDPPlugin):

            def configure_ddp(self, model, device_ids):
                model = MyDDPWrapper(model, device_ids)
                return model

        my_ddp = MyDDP()
        trainer = Trainer(accelerator='ddp_x', plugins=[my_ddp])
    """

    def __init__(self, **kwargs):
        # Stored as-is; unpacked into the DDP wrapper's constructor in `configure_ddp`.
        self._ddp_kwargs: Dict[str, Any] = kwargs

    def configure_ddp(
        self, model: LightningModule, device_ids: List[int]
    ) -> LightningDistributedDataParallel:
        """
        Wrap ``model`` in `LightningDistributedDataParallel` using the constructor kwargs.

        Override to define a custom DDP implementation.

        .. note:: Only requirement is that your DDP implementation subclasses
            LightningDistributedDataParallel

        The default implementation is::

            def configure_ddp(self, model, device_ids):
                model = LightningDistributedDataParallel(
                    model, device_ids=device_ids, find_unused_parameters=True
                )
                return model

        Args:
            model: the lightningModule
            device_ids: the list of devices available

        Returns:
            the model wrapped in LightningDistributedDataParallel
        """
        ddp_kwargs = self._ddp_kwargs
        # Fall back to ``find_unused_parameters=True`` when the user did not pass it.
        # This writes into the stored kwargs, so the default persists on the plugin.
        ddp_kwargs.setdefault("find_unused_parameters", True)

        wrapped = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            **ddp_kwargs,
        )
        return wrapped

© Copyright (c) 2018-2020, William Falcon et al. Revision 0979e2ce.

Built with Sphinx using a theme provided by Read the Docs.