pytorch_lightning.core.decorators module
pytorch_lightning.core.decorators.auto_move_data(fn)

Decorator for LightningModule methods whose input arguments should be moved automatically to the correct device. It has no effect when applied to a method of an object that is not an instance of LightningModule. It is typically applied to __call__ or forward.

- Parameters
fn (Callable) – A LightningModule method whose arguments should be moved to the device that the module's parameters are on.
Example
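import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data

# directly in the source code
class LitModel(LightningModule):

    @auto_move_data
    def forward(self, x):
        return x

# or outside the class definition (use one form or the other, not both)
LitModel.forward = auto_move_data(LitModel.forward)

model = LitModel()
model = model.to('cuda')
model(torch.zeros(1, 3))  # input gets moved to device
# tensor([[0., 0., 0.]], device='cuda:0')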
- Return type
Callable
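The behaviour described for auto_move_data (move the arguments to the module's device, do nothing for other objects) can be sketched in plain Python. This is not the library's implementation: FakeTensor, DeviceModule, move_to_device, and auto_move_data_sketch are hypothetical stand-ins for torch.Tensor, LightningModule, Lightning's batch-transfer logic, and the real decorator, chosen so the sketch runs without PyTorch or a GPU. It assumes the owning object exposes its parameters' device as a device attribute and that only tensors nested in lists, tuples, and dicts need moving.

```python
from functools import wraps
from typing import Any, Callable


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only records its device."""

    def __init__(self, data: Any, device: str = "cpu") -> None:
        self.data = data
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like Tensor.to, return a copy placed on the target device.
        return FakeTensor(self.data, device)


class DeviceModule:
    """Hypothetical stand-in for LightningModule: knows where its parameters live."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device


def move_to_device(obj: Any, device: str) -> Any:
    """Recursively move FakeTensors nested inside lists, tuples and dicts."""
    if isinstance(obj, FakeTensor):
        return obj.to(device)
    if isinstance(obj, list):
        return [move_to_device(o, device) for o in obj]
    if isinstance(obj, tuple):
        return tuple(move_to_device(o, device) for o in obj)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    return obj  # anything else (ints, strings, None, ...) passes through unchanged


def auto_move_data_sketch(fn: Callable) -> Callable:
    """Hypothetical decorator mirroring the documented behaviour of auto_move_data."""

    @wraps(fn)  # keep the wrapped method's name and docstring
    def wrapper(self, *args, **kwargs):
        # No effect when the method belongs to an object that is not a DeviceModule.
        if not isinstance(self, DeviceModule):
            return fn(self, *args, **kwargs)
        # Move positional and keyword arguments to the module's device first.
        args = move_to_device(args, self.device)
        kwargs = move_to_device(kwargs, self.device)
        return fn(self, *args, **kwargs)

    return wrapper


class Model(DeviceModule):
    @auto_move_data_sketch
    def forward(self, x, extra=None):
        return x, extra


class NotAModule:
    @auto_move_data_sketch
    def forward(self, x):
        return x


model = Model(device="cuda:0")
out, extra = model.forward(FakeTensor([0.0, 0.0, 0.0]), extra={"mask": FakeTensor([1])})
print(out.device)                                    # cuda:0
print(extra["mask"].device)                          # cuda:0
print(NotAModule().forward(FakeTensor([1])).device)  # cpu
```

Doing the isinstance check at call time rather than at decoration time is what lets one decorator be attached to any method and silently pass arguments through untouched when the owner is not a module.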