
pytorch_lightning.utilities.apply_func module

class pytorch_lightning.utilities.apply_func.TransferableDataType

Bases: abc.ABC

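A custom type for data that can be moved to a torch device via .to(…). Objects do not need to inherit from it: as the example shows, an instance of any class that defines a to() method is treated as a TransferableDataType.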

Example

>>> import torch
>>> from pytorch_lightning.utilities.apply_func import TransferableDataType
>>> isinstance(dict, TransferableDataType)
False
>>> isinstance(torch.rand(2, 3), TransferableDataType)
True
>>> class CustomObject:
...     def __init__(self):
...         self.x = torch.rand(2, 2)
...     def to(self, device):
...         self.x = self.x.to(device)
...         return self
>>> isinstance(CustomObject(), TransferableDataType)
True
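The isinstance checks above succeed even though CustomObject never inherits from TransferableDataType. One plausible mechanism, sketched here in plain Python as an assumption rather than the library's actual code, is an ABC whose __subclasshook__ accepts any class with a callable to attribute. TransferableSketch and FakeTensor are hypothetical names used only for illustration, and FakeTensor stands in for a tensor so the sketch runs without PyTorch.

```python
from abc import ABC


class TransferableSketch(ABC):
    """Hypothetical stand-in for TransferableDataType: an object counts as an
    instance if its class defines a callable ``to`` attribute, without
    inheriting from or registering with this ABC."""

    @classmethod
    def __subclasshook__(cls, subclass):
        # Apply the structural check only to this exact class.
        if cls is TransferableSketch:
            return callable(getattr(subclass, "to", None))
        # Otherwise defer to the normal ABC machinery.
        return NotImplemented


class FakeTensor:
    """Hypothetical object with a .to(device) method (no torch required)."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        self.device = device
        return self


print(isinstance(FakeTensor(), TransferableSketch))  # True: FakeTensor defines to()
print(isinstance({}, TransferableSketch))            # False: dict has no to()
print(isinstance(dict, TransferableSketch))          # False: type(dict) is type
```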

pytorch_lightning.utilities.apply_func.apply_to_collection(data, dtype, function, *args, **kwargs)

Recursively applies a function to all elements of a certain dtype. Elements of any other type are returned unchanged.

Parameters
  • data (Any) – the collection to apply the function to

  • dtype (Union[type, tuple]) – the given function will be applied to all elements of this dtype

  • function (Callable) – the function to apply

  • *args – positional arguments (will be forwarded to calls of function)

  • **kwargs – keyword arguments (will be forwarded to calls of function)

Return type

Any

Returns

the resulting collection
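A minimal sketch of this recursion, assuming the traversed collections are mappings, named tuples and other non-string sequences, each rebuilt with its original type. apply_to_collection_sketch and Pair are hypothetical names; the library function may differ in details, for example by accepting additional parameters or supporting further collection types.

```python
from collections.abc import Mapping, Sequence
from typing import Any, Callable, NamedTuple, Tuple, Type, Union


def apply_to_collection_sketch(
    data: Any,
    dtype: Union[Type, Tuple[Type, ...]],
    function: Callable,
    *args,
    **kwargs,
) -> Any:
    """Hypothetical re-implementation of the recursion described above."""
    # Base case: the element itself has the requested dtype.
    if isinstance(data, dtype):
        return function(data, *args, **kwargs)
    elem_type = type(data)
    # Mappings: rebuild with the same type, recursing into the values only.
    if isinstance(data, Mapping):
        return elem_type({k: apply_to_collection_sketch(v, dtype, function, *args, **kwargs)
                          for k, v in data.items()})
    # Named tuples: the constructor takes positional fields, so unpack them.
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        return elem_type(*(apply_to_collection_sketch(d, dtype, function, *args, **kwargs)
                           for d in data))
    # Other sequences (lists, plain tuples). Strings are skipped because each
    # character of a str is itself a str, so recursing would never terminate.
    if isinstance(data, Sequence) and not isinstance(data, str):
        return elem_type([apply_to_collection_sketch(d, dtype, function, *args, **kwargs)
                          for d in data])
    # Anything else is left untouched.
    return data


class Pair(NamedTuple):
    left: Any
    right: Any


batch = {"ids": [1, 2, 3], "meta": Pair("name", 4.0), "tag": "train"}
# The extra positional argument 10 is forwarded to every call of the lambda.
result = apply_to_collection_sketch(batch, int, lambda x, scale: x * scale, 10)
print(result)
# {'ids': [10, 20, 30], 'meta': Pair(left='name', right=4.0), 'tag': 'train'}
```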

pytorch_lightning.utilities.apply_func.move_data_to_device(batch, device)

Transfers a collection of data to the given device. Every object that defines a to(device) method is moved; all other objects in the collection are left untouched.

Parameters
  • batch (Any) – A tensor or collection of tensors or anything that has a method .to(…). See apply_to_collection() for a list of supported collection types.

  • device (torch.device) – The device to which the data should be moved

Returns

the same collection, with every contained tensor (and every other object that defines to()) residing on the new device.
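move_data_to_device can be read as apply_to_collection with dtype set to TransferableDataType and a function that calls .to(device) on each match. The sketch below assumes exactly that composition. It condenses both helpers into hypothetical stand-ins (_Transferable, _apply, move_data_to_device_sketch) and uses a FakeTensor class so it runs without PyTorch. The library version may also pass extra arguments to Tensor.to (for example non_blocking) that this sketch omits.

```python
from abc import ABC
from collections.abc import Mapping, Sequence
from typing import Any


class _Transferable(ABC):
    # Structural check: any object whose class has a callable .to counts.
    @classmethod
    def __subclasshook__(cls, subclass):
        if cls is _Transferable:
            return callable(getattr(subclass, "to", None))
        return NotImplemented


def _apply(data: Any, dtype, function) -> Any:
    # Condensed recursion over mappings, named tuples and non-str sequences.
    if isinstance(data, dtype):
        return function(data)
    if isinstance(data, Mapping):
        return type(data)({k: _apply(v, dtype, function) for k, v in data.items()})
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        return type(data)(*(_apply(d, dtype, function) for d in data))
    if isinstance(data, Sequence) and not isinstance(data, str):
        return type(data)([_apply(d, dtype, function) for d in data])
    return data


def move_data_to_device_sketch(batch: Any, device: Any) -> Any:
    """Hypothetical sketch: call .to(device) on every transferable element."""
    return _apply(batch, _Transferable, lambda obj: obj.to(device))


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, device: str = "cpu"):
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Return a new object, as torch.Tensor.to does when the device changes.
        return FakeTensor(device)


batch = {"x": FakeTensor(), "y": [FakeTensor(), 7], "name": "sample"}
moved = move_data_to_device_sketch(batch, "cuda:0")
print(moved["x"].device, moved["y"][0].device, moved["y"][1], moved["name"])
# cuda:0 cuda:0 7 sample
```

Because FakeTensor.to returns a new object here, the original batch is not modified; whether the real call mutates objects in place depends on what each object's to() does.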

See also