
pytorch_lightning.core.decorators module

pytorch_lightning.core.decorators.auto_move_data(fn)

Decorator for LightningModule methods whose input arguments should be moved automatically to the device the module's parameters are on. It has no effect when applied to a method of an object that is not an instance of LightningModule. It is typically applied to __call__ or forward.

Parameters

fn (Callable) – A LightningModule method for which the arguments should be moved to the device the parameters are on.
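The following is a minimal, dependency-free sketch of how a decorator with this behaviour can work. It is not Lightning's implementation. `FakeTensor`, `FakeLightningModule`, `move_to_device` and `auto_move_data_sketch` are hypothetical stand-ins for `torch.Tensor`, `LightningModule` and the real decorator, and the recursive handling of nested lists, tuples and dicts is an assumption of the sketch. It only illustrates the two behaviours described above: arguments are moved to the module's device, and the decorator does nothing on other objects.

```python
from functools import wraps
from typing import Any, Callable


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: records which device it lives on."""

    def __init__(self, data: Any, device: str = "cpu") -> None:
        self.data = data
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like Tensor.to, return a copy placed on the target device.
        return FakeTensor(self.data, device)


class FakeLightningModule:
    """Hypothetical stand-in for LightningModule: knows the device its parameters are on."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device


def move_to_device(obj: Any, device: str) -> Any:
    # Hypothetical helper: recursively walk lists, tuples and dicts, moving tensor-like leaves.
    if isinstance(obj, FakeTensor):
        return obj.to(device)
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(o, device) for o in obj)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    return obj  # non-tensor values pass through unchanged


def auto_move_data_sketch(fn: Callable) -> Callable:
    @wraps(fn)  # keep the wrapped method's name and docstring
    def wrapper(self, *args, **kwargs):
        # No effect on objects that are not (stand-in) LightningModules.
        if not isinstance(self, FakeLightningModule):
            return fn(self, *args, **kwargs)
        args = move_to_device(args, self.device)
        kwargs = move_to_device(kwargs, self.device)
        return fn(self, *args, **kwargs)

    return wrapper


class LitModel(FakeLightningModule):
    @auto_move_data_sketch
    def forward(self, x):
        return x


class PlainObject:
    @auto_move_data_sketch
    def forward(self, x):
        return x


model = LitModel(device="cuda:0")
print(model.forward(FakeTensor([0.0, 0.0, 0.0])).device)  # cuda:0
print(PlainObject().forward(FakeTensor([0.0])).device)  # cpu
```

Dispatching on the type of `self` inside the wrapper, rather than when the decorator is applied, is what lets the same decorator be attached to any method and still do nothing for objects that are not modules.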

Example

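import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data

# directly in the source code
class LitModel(LightningModule):

    @auto_move_data
    def forward(self, x):
        return x

# or, equivalently, outside the class body (use one form, not both)
LitModel.forward = auto_move_data(LitModel.forward)

model = LitModel()
model = model.to('cuda')  # requires a CUDA-capable GPU
model(torch.zeros(1, 3))

# the input is moved to the model's device:
# tensor([[0., 0., 0.]], device='cuda:0')

Return type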

Callable