pytorch_lightning.utilities.apply_func module
-
class pytorch_lightning.utilities.apply_func.TransferableDataType
Bases: abc.ABC
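A custom type for data that can be moved to a torch device via .to(…). No explicit registration is needed: isinstance() returns True for any object whose class defines a callable to attribute.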
Example
>>> isinstance(dict, TransferableDataType)
False
>>> isinstance(torch.rand(2, 3), TransferableDataType)
True
>>> class CustomObject:
...     def __init__(self):
...         self.x = torch.rand(2, 2)
...     def to(self, device):
...         self.x = self.x.to(device)
...         return self
>>> isinstance(CustomObject(), TransferableDataType)
True
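Membership in TransferableDataType is decided by duck typing rather than by inheritance. The sketch below shows one way to get that behaviour from abc.ABC, assuming a `__subclasshook__` that accepts any class with a callable `to` attribute. `DuckTransferable` and `Box` are hypothetical names, and a plain string stands in for a torch device so the code runs without torch.

```python
from abc import ABC


class DuckTransferable(ABC):
    """Hypothetical stand-in for TransferableDataType: any class with a callable `to` qualifies."""

    @classmethod
    def __subclasshook__(cls, subclass):
        # Only answer for this exact ABC; subclasses of it fall back to the normal rules.
        if cls is DuckTransferable:
            return callable(getattr(subclass, "to", None))
        return NotImplemented


class Box:
    """Hypothetical container that 'moves' itself by recording the target device."""

    def __init__(self):
        self.device = "cpu"

    def to(self, device):
        self.device = device
        return self


print(isinstance({}, DuckTransferable))     # → False (dict has no `to` method)
print(isinstance(Box(), DuckTransferable))  # → True (Box defines `to`)
print(isinstance(dict, DuckTransferable))   # → False (the class object's type is `type`)
```

Returning NotImplemented for anything other than the ABC itself keeps the hook from applying to classes that explicitly inherit from it.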
-
pytorch_lightning.utilities.apply_func.apply_to_collection(data, dtype, function, *args, **kwargs)
Recursively applies function to every element of data that is an instance of dtype. Any additional *args and **kwargs are forwarded to each call of function.
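The description above leaves the traversal implicit. Here is a minimal, torch-free sketch of such a recursive apply. It assumes the supported containers are mappings, namedtuples and non-string sequences. `apply_to_collection_sketch` is a hypothetical name, not the library function, and the real implementation may support more types or handle edge cases differently.

```python
from collections import namedtuple
from collections.abc import Mapping, Sequence
from typing import Any, Callable


def apply_to_collection_sketch(data: Any, dtype: type, function: Callable, *args: Any, **kwargs: Any) -> Any:
    """Hypothetical re-implementation: apply `function` to every `dtype` element, recursing into containers."""
    # Base case: a matching element is transformed; extra arguments are forwarded.
    if isinstance(data, dtype):
        return function(data, *args, **kwargs)
    # Mappings are rebuilt with the same type; only the values are processed.
    if isinstance(data, Mapping):
        return type(data)(
            {k: apply_to_collection_sketch(v, dtype, function, *args, **kwargs) for k, v in data.items()}
        )
    # Namedtuples take their fields positionally, so unpack instead of passing a list.
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        return type(data)(*(apply_to_collection_sketch(v, dtype, function, *args, **kwargs) for v in data))
    # Other sequences (lists, plain tuples); str and bytes are treated as leaves.
    if isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
        return type(data)([apply_to_collection_sketch(v, dtype, function, *args, **kwargs) for v in data])
    # Anything else is returned unchanged.
    return data


# Usage: multiply every int by 10; the 10 is forwarded through *args.
batch = {"a": 1, "b": [2, 3.5, "x"], "c": (4, {"d": 5})}
print(apply_to_collection_sketch(batch, int, lambda x, k: x * k, 10))
# → {'a': 10, 'b': [20, 3.5, 'x'], 'c': (40, {'d': 50})}

Pair = namedtuple("Pair", ["x", "y"])
print(apply_to_collection_sketch(Pair(1, "s"), int, lambda v: v + 1))  # → Pair(x=2, y='s')
```

Rebuilding containers via type(data)(...) is a simplification: types whose constructors need extra arguments, such as collections.defaultdict, would fail in this sketch.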
-
pytorch_lightning.utilities.apply_func.move_data_to_device(batch, device)
Transfers a collection of data to the given device. Any object that defines a method to(device) will be moved, and all other objects in the collection will be left untouched.
- Parameters
batch (Any) – A tensor, a collection of tensors, or anything that has a method .to(…). See apply_to_collection() for a list of supported collection types.
device (torch.device) – The device to which the data should be moved.
- Returns
The same collection, with all contained tensors (and any other objects that define .to()) residing on the new device.
See also
torch.device
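To make the move-or-leave-untouched rule concrete, here is a minimal torch-free sketch. It assumes the same duck-typed rule as TransferableDataType (a callable to attribute) and recursion into mappings, namedtuples and non-string sequences. `move_data_to_device_sketch` and `FakeTensor` are hypothetical names, and strings such as "cuda:0" stand in for torch.device objects so the example runs without torch or a GPU.

```python
from collections.abc import Mapping, Sequence
from typing import Any


def move_data_to_device_sketch(batch: Any, device: Any) -> Any:
    """Hypothetical sketch: call `.to(device)` on every object with a callable `to`, recursing into containers."""
    # Duck-typed check: anything exposing a callable `to` is moved as a whole.
    if callable(getattr(batch, "to", None)):
        return batch.to(device)
    if isinstance(batch, Mapping):
        return type(batch)({k: move_data_to_device_sketch(v, device) for k, v in batch.items()})
    # Namedtuples take their fields positionally.
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):
        return type(batch)(*(move_data_to_device_sketch(v, device) for v in batch))
    if isinstance(batch, Sequence) and not isinstance(batch, (str, bytes)):
        return type(batch)([move_data_to_device_sketch(v, device) for v in batch])
    # Everything else (ints, strings, ...) is left untouched.
    return batch


class FakeTensor:
    """Hypothetical tensor stand-in that records which device it lives on."""

    def __init__(self, device: str = "cpu"):
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Return a new object, as torch.Tensor.to does when the device changes.
        return FakeTensor(device)


batch = {"inputs": FakeTensor(), "meta": ["id-7", 3], "targets": (FakeTensor(),)}
moved = move_data_to_device_sketch(batch, "cuda:0")
print(moved["inputs"].device, moved["targets"][0].device, moved["meta"])
# → cuda:0 cuda:0 ['id-7', 3]
```

Dispatching on a callable to rather than on torch.Tensor is what lets objects like CustomObject from the TransferableDataType example be moved alongside ordinary tensors.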