
Source code for pytorch_lightning.core.decorators

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decorator for LightningModule methods."""

from functools import wraps
from typing import Callable

from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn


def auto_move_data(fn: Callable) -> Callable:
    """
    Decorator for :class:`~pytorch_lightning.core.lightning.LightningModule` methods for which
    input arguments should be moved automatically to the correct device.
    It has no effect if applied to a method of an object that is not an instance of
    :class:`~pytorch_lightning.core.lightning.LightningModule` and is typically applied to ``__call__``
    or ``forward``.

    Args:
        fn: A LightningModule method for which the arguments should be moved to the device
            the parameters are on.

    Example::

        # directly in the source code
        class LitModel(LightningModule):

            @auto_move_data
            def forward(self, x):
                return x

        # or outside
        LitModel.forward = auto_move_data(LitModel.forward)

        model = LitModel()
        model = model.to('cuda')
        model(torch.zeros(1, 3))
        # input gets moved to device
        # tensor([[0., 0., 0.]], device='cuda:0')

    """

    @wraps(fn)
    def auto_transfer_args(self, *args, **kwargs):
        from pytorch_lightning.core.lightning import LightningModule

        if not isinstance(self, LightningModule):
            return fn(self, *args, **kwargs)

        args, kwargs = self.transfer_batch_to_device((args, kwargs), device=self.device, dataloader_idx=None)
        return fn(self, *args, **kwargs)

    rank_zero_deprecation(
        "The `@auto_move_data` decorator is deprecated in v1.3 and will be removed in v1.5."
        f" Please use `trainer.predict` instead for inference. The decorator was applied to `{fn.__name__}`"
    )

    return auto_transfer_args
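The deprecation notice above points to ``trainer.predict`` as the replacement workflow for device-aware inference. Below is a minimal sketch of that workflow; the ``LitModel`` definition and the zero-filled dataset are illustrative assumptions and are not part of this module.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning import LightningModule, Trainer


    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(3, 1)

        def forward(self, x):
            # no @auto_move_data needed: the Trainer moves each batch to the device
            return self.layer(x)

        def predict_step(self, batch, batch_idx, dataloader_idx=None):
            (x,) = batch  # TensorDataset yields one-element batches
            return self(x)


    model = LitModel()
    loader = DataLoader(TensorDataset(torch.zeros(8, 3)), batch_size=4)

    trainer = Trainer(gpus=1)  # assumes a single GPU; use gpus=0 for CPU-only inference
    predictions = trainer.predict(model, dataloaders=loader)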
def parameter_validation(fn: Callable) -> Callable:
    """
    Validates that the module parameter lengths match after moving to the device. It is useful
    when tying weights on TPUs.

    Args:
        fn: ``model_to_device`` method

    Note:
        TPUs require weights to be tied/shared after moving the module to the device.
        Failure to do this results in the initialization of new weights which are not tied.
        To overcome this issue, weights should be tied using the ``on_post_move_to_device``
        model hook, which is called after the module has been moved to the device.

    See Also:
        - `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
    """

    @wraps(fn)
    def inner_fn(self, *args, **kwargs):
        pre_layer_count = len(list(self.model.parameters()))
        module = fn(self, *args, **kwargs)
        self.model.on_post_move_to_device()
        post_layer_count = len(list(self.model.parameters()))

        if pre_layer_count != post_layer_count:
            rank_zero_warn(
                "The model layers do not match after moving to the target device."
                " If your model employs weight sharing on TPU,"
                " please tie your weights using the `on_post_move_to_device` model hook.\n"
                f"Layer count: [Before: {pre_layer_count} After: {post_layer_count}]"
            )

        return module

    return inner_fn
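As a rough illustration of where ``parameter_validation`` sits, the sketch below wraps a ``model_to_device`` method on a hypothetical ``DummyAccelerator`` holder; the class name and the ``TiedModule`` example are assumptions for demonstration, not part of Lightning. Because the ``on_post_move_to_device`` hook re-ties the shared weight, the parameter counts match and no warning is emitted.

    import torch

    from pytorch_lightning.core.decorators import parameter_validation


    class TiedModule(torch.nn.Module):
        """Toy module whose encoder and decoder share one weight tensor."""

        def __init__(self):
            super().__init__()
            self.encoder = torch.nn.Linear(4, 4)
            self.decoder = torch.nn.Linear(4, 4)
            self.decoder.weight = self.encoder.weight  # tie at construction time

        def on_post_move_to_device(self):
            # re-tie the shared weight after the module has been moved (e.g. to a TPU)
            self.decoder.weight = self.encoder.weight


    class DummyAccelerator:
        """Hypothetical holder exposing ``self.model`` and a ``model_to_device`` method."""

        def __init__(self, model, device):
            self.model = model
            self.device = device

        @parameter_validation
        def model_to_device(self):
            self.model.to(self.device)


    accelerator = DummyAccelerator(TiedModule(), torch.device("cpu"))
    accelerator.model_to_device()  # parameter count unchanged, so no warning is raised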
