pytorch_lightning.overrides.data_parallel module

class pytorch_lightning.overrides.data_parallel.LightningDataParallel(*args, **kwargs)[source]

Bases: torch.nn.DataParallel

Overrides the forward call in Lightning so that it is routed to training_step or validation_step as appropriate.

_LightningDataParallel__gather_structured_result(outputs)[source]
forward(*inputs, **kwargs)[source]
gather(outputs)[source]

Overrides the gather method to also support Python scalars.

parallel_apply(replicas, inputs, kwargs)[source]
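
A minimal sketch of how the wrapper is typically used, assuming at least two visible CUDA devices; MyLightningModule and the device ids are illustrative, not part of the library:

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl
    from pytorch_lightning.overrides.data_parallel import LightningDataParallel


    class MyLightningModule(pl.LightningModule):
        # Hypothetical module with the usual Lightning hooks.
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return F.cross_entropy(self(x), y)


    model = MyLightningModule()
    # Wrapping with LightningDataParallel keeps torch.nn.DataParallel's
    # scatter/gather machinery but routes each replica's call to the module's
    # step hooks instead of the plain forward(); gather() also handles scalars.
    dp_model = LightningDataParallel(model, device_ids=[0, 1])
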
class pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel(*args, **kwargs)[source]

Bases: torch.nn.parallel.DistributedDataParallel

Overrides the forward call in Lightning so that it is routed to training_step or validation_step as appropriate.

forward(*inputs, **kwargs)[source]
parallel_apply(replicas, inputs, kwargs)[source]
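
As with the DP wrapper above, a minimal sketch of wrapping a module once a process group exists; the single-process gloo group and TinyModule are only for illustration:

    import os

    import torch
    import torch.distributed as dist
    import pytorch_lightning as pl
    from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel


    class TinyModule(pl.LightningModule):
        # Hypothetical minimal LightningModule.
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 2)

        def forward(self, x):
            return self.layer(x)


    if __name__ == "__main__":
        # Single-process CPU group purely for illustration; real DDP training
        # launches one process per device.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend="gloo", rank=0, world_size=1)

        # Behaves like torch.nn.parallel.DistributedDataParallel, except that
        # forward() dispatches to the Lightning step hooks.
        ddp_model = LightningDistributedDataParallel(TinyModule())
        print(ddp_model)

        dist.destroy_process_group()
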
pytorch_lightning.overrides.data_parallel._find_tensors(obj)[source]

Recursively find all tensors contained in the specified object.
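
The helper walks arbitrarily nested containers; an illustrative re-implementation of the idea (not the library's exact code):

    import torch


    def find_tensors(obj):
        # Recursively collect tensors from tensors, lists/tuples and dicts;
        # anything else contributes nothing.
        if isinstance(obj, torch.Tensor):
            return [obj]
        if isinstance(obj, (list, tuple)):
            return [t for item in obj for t in find_tensors(item)]
        if isinstance(obj, dict):
            return [t for value in obj.values() for t in find_tensors(value)]
        return []


    batch = {"x": torch.randn(4, 3), "meta": [1, 2, (torch.ones(2),)]}
    print(len(find_tensors(batch)))  # 2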

pytorch_lightning.overrides.data_parallel.auto_squeeze_dim_zeros(output)[source]

In DP or DDP2 we need to unsqueeze dim 0 of the output.
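
The point is that zero-dimensional result tensors (for example a scalar loss) need a leading dimension before DataParallel's gather can concatenate them across replicas. A sketch of that idea, assuming tensor or dict outputs; it is not the library's exact implementation:

    import torch


    def unsqueeze_scalar_outputs(output):
        # Give 0-dim tensors a leading dim so per-replica results can be
        # concatenated along dim 0 during gather.
        if isinstance(output, torch.Tensor):
            return output.unsqueeze(0) if output.dim() == 0 else output
        if isinstance(output, dict):
            return {k: unsqueeze_scalar_outputs(v) for k, v in output.items()}
        return output


    result = unsqueeze_scalar_outputs({"loss": torch.tensor(0.25), "logits": torch.randn(4, 10)})
    print(result["loss"].shape)  # torch.Size([1])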

pytorch_lightning.overrides.data_parallel.get_a_var(obj)[source]
pytorch_lightning.overrides.data_parallel.parallel_apply(modules, inputs, kwargs_tup=None, devices=None)[source]

Applies each module in modules in parallel on arguments contained in inputs (positional) and kwargs_tup (keyword) on each of devices.

Parameters
  • modules (Module) – modules to be parallelized

  • inputs (tensor) – inputs to the modules

  • devices (list of int or torch.device) – CUDA devices

modules, inputs, kwargs_tup (if given), and devices (if given) should all have the same length. Moreover, each element of inputs can either be a single object passed as the only argument to a module, or a collection of positional arguments.
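
A sketch of the call shape, assuming two visible CUDA devices and replicas of a hypothetical Lightning-style module; in this override the worker is expected to dispatch into the replicas' step hooks rather than a bare forward(), so the example defines both:

    import torch
    import pytorch_lightning as pl
    from pytorch_lightning.overrides.data_parallel import parallel_apply


    class TinyStepModule(pl.LightningModule):
        # Hypothetical replica module; forward and training_step do the same
        # thing so the sketch works regardless of which hook the worker invokes.
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 2)

        def forward(self, batch, batch_idx=0):
            return self.layer(batch)

        def training_step(self, batch, batch_idx):
            return self(batch)


    # One replica per (assumed) CUDA device, each applied to its own input shard.
    modules = [TinyStepModule().cuda(0), TinyStepModule().cuda(1)]
    inputs = [
        (torch.randn(8, 4, device="cuda:0"), 0),
        (torch.randn(8, 4, device="cuda:1"), 0),
    ]

    outputs = parallel_apply(modules, inputs, devices=[0, 1])
    print([o.shape for o in outputs])  # two (8, 2) outputs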