pytorch_lightning.overrides.data_parallel module
class pytorch_lightning.overrides.data_parallel.LightningDataParallel(*args, **kwargs)[source]

    Bases: torch.nn.DataParallel

    Overrides the forward call so that Lightning routes it to training_step or validation_step as appropriate.
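    For intuition, here is a rough sketch of the idea, not the actual Lightning implementation: a DataParallel subclass can reuse scatter/replicate/gather and invoke training_step or validation_step on each replica instead of forward. The StepRoutingDataParallel and _run_step names and the serial replica loop are illustrative assumptions; the dispatch rule (training mode vs. eval mode) follows the docstring above.

        from torch.nn.parallel import DataParallel

        class StepRoutingDataParallel(DataParallel):
            """Illustrative sketch only: route each replica's call to
            training_step/validation_step rather than forward."""

            def forward(self, *inputs, **kwargs):
                if not self.device_ids:
                    return self._run_step(self.module, inputs, kwargs)
                # Scatter arguments across devices exactly as DataParallel does.
                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                if len(self.device_ids) == 1:
                    return self._run_step(self.module, inputs[0], kwargs[0])
                replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
                # The real class runs replicas in parallel; a serial loop
                # keeps this sketch short.
                outputs = [self._run_step(r, i, k)
                           for r, i, k in zip(replicas, inputs, kwargs)]
                return self.gather(outputs, self.output_device)

            @staticmethod
            def _run_step(module, inputs, kwargs):
                # Assumes the wrapped module defines training_step and
                # validation_step, as a LightningModule does.
                step = module.training_step if module.training else module.validation_step
                return step(*inputs, **kwargs)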
class pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel(*args, **kwargs)[source]

    Bases: torch.nn.parallel.DistributedDataParallel

    Overrides the forward call so that Lightning routes it to training_step or validation_step as appropriate.
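    In practice the Trainer constructs this wrapper itself, but a hand-rolled usage sketch may clarify what the override does. The ToyModule class, the NCCL/LOCAL_RANK launch setup, and the training_step(batch, batch_idx) signature are all assumptions for illustration:

        import os
        import torch
        import torch.distributed as dist
        import pytorch_lightning as pl
        from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel

        class ToyModule(pl.LightningModule):
            # Minimal stand-in for a user LightningModule (assumed names).
            def __init__(self):
                super().__init__()
                self.layer = torch.nn.Linear(4, 2)

            def forward(self, x):
                return self.layer(x)

            def training_step(self, batch, batch_idx):
                return {"loss": self.layer(batch).sum()}

            def validation_step(self, batch, batch_idx):
                return {"val_loss": self.layer(batch).sum()}

        # Assumes a multi-process launch (e.g. torch.distributed.launch) that
        # sets LOCAL_RANK, and NCCL-capable GPUs.
        local_rank = int(os.environ["LOCAL_RANK"])
        dist.init_process_group(backend="nccl")

        model = ToyModule().to(local_rank)
        ddp_model = LightningDistributedDataParallel(model, device_ids=[local_rank])

        batch = torch.randn(8, 4, device=local_rank)
        # Because forward is overridden, calling the wrapper while the module
        # is in training mode runs training_step rather than forward.
        out = ddp_model(batch, 0)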
pytorch_lightning.overrides.data_parallel._find_tensors(obj)[source]

    Recursively find all tensors contained in the specified object.
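    A recursive tensor search of this shape can be written in a few lines. This sketch (function name assumed, not the module's code) descends into sequences and dict values, collecting every tensor and returning an empty list for anything else:

        import itertools
        import torch

        def find_tensors(obj):
            # Collect every torch.Tensor reachable through lists/tuples
            # and dict values.
            if isinstance(obj, torch.Tensor):
                return [obj]
            if isinstance(obj, (list, tuple)):
                return list(itertools.chain.from_iterable(map(find_tensors, obj)))
            if isinstance(obj, dict):
                return list(itertools.chain.from_iterable(map(find_tensors, obj.values())))
            return []

        # e.g. find_tensors({"loss": torch.ones(1), "extras": [torch.zeros(2)]})
        # returns both tensors.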
pytorch_lightning.overrides.data_parallel.auto_squeeze_dim_zeros(output)[source]

    In DP or DDP2 the outputs need an explicit dim 0 for gathering, so zero-dimensional tensors in output are unsqueezed.
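    A sketch of the behavior the docstring describes, assuming output is a dict of step results (the dict assumption and the helper name are illustrative):

        import torch

        def unsqueeze_scalar_outputs(output):
            # Give zero-dimensional tensors an explicit dim 0 so that
            # DP/DDP2 can gather the per-replica outputs.
            for key, value in output.items():
                if isinstance(value, torch.Tensor) and value.dim() == 0:
                    output[key] = value.unsqueeze(0)
            return output

        # unsqueeze_scalar_outputs({"loss": torch.tensor(1.5)})["loss"].shape
        # -> torch.Size([1])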
pytorch_lightning.overrides.data_parallel.parallel_apply(modules, inputs, kwargs_tup=None, devices=None)[source]

    Applies each module in modules in parallel on the arguments contained in inputs (positional) and kwargs_tup (keyword) on each of devices.

    Parameters
        modules, inputs, kwargs_tup (if given), and devices (if given) should all have the same length. Moreover, each element of inputs can either be a single object passed as the only argument to a module, or a collection of positional arguments.
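    The documented call pattern mirrors torch.nn.parallel.parallel_apply, so a usage sketch with the stock PyTorch function illustrates it (assumes two CUDA devices are available):

        import torch
        import torch.nn as nn
        from torch.nn.parallel import parallel_apply, replicate

        device_ids = [0, 1]
        module = nn.Linear(4, 2).to("cuda:0")
        replicas = replicate(module, device_ids)  # one replica per device

        # One positional-argument tuple per replica, each on that replica's device.
        inputs = [(torch.randn(8, 4, device=f"cuda:{i}"),) for i in device_ids]

        outputs = parallel_apply(replicas, inputs, devices=device_ids)
        # outputs[i] == replicas[i](*inputs[i]), computed on device i.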