pytorch_lightning.core.decorators module
pytorch_lightning.core.decorators.auto_move_data(fn)

Decorator for LightningModule methods whose input arguments should be moved automatically to the correct device. It has no effect if applied to a method of an object that is not an instance of LightningModule, and it is typically applied to __call__ or forward.

Parameters
    fn (Callable) – A LightningModule method whose arguments should be moved to the device the parameters are on.
Example
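```python
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data


# apply the decorator directly in the source code
class LitModel(LightningModule):

    @auto_move_data
    def forward(self, x):
        return x


# or, alternatively, wrap the method outside the class definition:
# LitModel.forward = auto_move_data(LitModel.forward)

model = LitModel()
model = model.to('cuda')  # requires a CUDA-capable GPU
model(torch.zeros(1, 3))  # the CPU input gets moved to the model's device
# tensor([[0., 0., 0.]], device='cuda:0')
```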
Return type
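To make the mechanism described above concrete, here is a minimal, framework-free sketch of what a decorator like auto_move_data does. It is not Lightning's implementation. DeviceAwareModule, FakeTensor, move_args_to_device, Model and NotAModule are hypothetical names: DeviceAwareModule stands in for LightningModule, and any argument exposing a .to(device) method (as torch tensors do) is treated as movable. Only top-level positional and keyword arguments are handled; nested containers are passed through untouched.

```python
from functools import wraps
from typing import Any, Callable


class DeviceAwareModule:
    """Hypothetical stand-in for LightningModule: it only tracks a `device` attribute."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device


def move_args_to_device(fn: Callable) -> Callable:
    """Hypothetical sketch of an auto_move_data-style method decorator."""

    @wraps(fn)
    def wrapper(self, *args: Any, **kwargs: Any) -> Any:
        # No effect on methods of objects that are not DeviceAwareModule instances.
        if not isinstance(self, DeviceAwareModule):
            return fn(self, *args, **kwargs)

        def move(value: Any) -> Any:
            # Anything with a .to(device) method is moved; everything else passes through.
            return value.to(self.device) if hasattr(value, "to") else value

        args = tuple(move(a) for a in args)
        kwargs = {k: move(v) for k, v in kwargs.items()}
        return fn(self, *args, **kwargs)

    return wrapper


class FakeTensor:
    """Hypothetical tensor-like object that records which device it lives on."""

    def __init__(self, data: list, device: str = "cpu") -> None:
        self.data = data
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Return a copy "placed" on the requested device, like torch.Tensor.to.
        return FakeTensor(self.data, device)


class Model(DeviceAwareModule):
    @move_args_to_device
    def forward(self, x: FakeTensor) -> FakeTensor:
        return x


class NotAModule:
    @move_args_to_device
    def forward(self, x: FakeTensor) -> FakeTensor:
        return x


model = Model(device="cuda:0")
print(model.forward(FakeTensor([0.0, 0.0, 0.0])).device)  # cuda:0
print(NotAModule().forward(FakeTensor([1.0])).device)  # cpu (decorator is a no-op)
```

Using functools.wraps keeps the wrapped method's name and docstring, so the decorated forward still looks like forward to introspection tools.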