Shortcuts

Source code for pytorch_lightning.core.mixins.hparams_mixin

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import types
from argparse import Namespace
from typing import Any, MutableMapping, Optional, Sequence, Union

from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.parsing import save_hyperparameters


class HyperparametersMixin:
    """Mixin that keeps a collection of hyperparameters on the instance.

    The mutable collection is exposed through :attr:`hparams` (backed by
    ``self._hparams``), and a read-only copy of the initial values through
    :attr:`hparams_initial` (backed by ``self._hparams_initial``).

    NOTE(review): ``_hparams_initial`` is never assigned inside this class; it is
    presumably populated by ``pytorch_lightning.utilities.parsing.save_hyperparameters``
    -- confirm against that helper.
    """

    # NOTE(review): presumably read by the TorchScript export path so that these
    # properties are skipped when scripting -- confirm against the consumer.
    __jit_unused_properties__ = ["hparams", "hparams_initial"]

    def __init__(self) -> None:
        super().__init__()
        # Updated by ``save_hyperparameters(logger=...)``; off until that is called.
        self._log_hyperparams = False

    def save_hyperparameters(
        self,
        *args: Any,
        ignore: Optional[Union[Sequence[str], str]] = None,
        frame: Optional[types.FrameType] = None,
        logger: bool = True,
    ) -> None:
        """Save arguments to ``hparams`` attribute.

        Args:
            args: single object of `dict`, `Namespace` or `OmegaConf`
                or string names or arguments from class ``__init__``
            ignore: an argument name or a list of argument names from
                class ``__init__`` to be ignored
            frame: a frame object. Default is None
            logger: Whether to send the hyperparameters to the logger. Default: True

        Example::
            >>> class ManuallyArgsModel(HyperparametersMixin):
            ...     def __init__(self, arg1, arg2, arg3):
            ...         super().__init__()
            ...         # manually assign arguments
            ...         self.save_hyperparameters('arg1', 'arg3')
            ...     def forward(self, *args, **kwargs):
            ...         ...
            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
            >>> model.hparams
            "arg1": 1
            "arg3": 3.14

            >>> class AutomaticArgsModel(HyperparametersMixin):
            ...     def __init__(self, arg1, arg2, arg3):
            ...         super().__init__()
            ...         # equivalent automatic
            ...         self.save_hyperparameters()
            ...     def forward(self, *args, **kwargs):
            ...         ...
            >>> model = AutomaticArgsModel(1, 'abc', 3.14)
            >>> model.hparams
            "arg1": 1
            "arg2": abc
            "arg3": 3.14

            >>> class SingleArgModel(HyperparametersMixin):
            ...     def __init__(self, params):
            ...         super().__init__()
            ...         # manually assign single argument
            ...         self.save_hyperparameters(params)
            ...     def forward(self, *args, **kwargs):
            ...         ...
            >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
            >>> model.hparams
            "p1": 1
            "p2": abc
            "p3": 3.14

            >>> class ManuallyArgsModel(HyperparametersMixin):
            ...     def __init__(self, arg1, arg2, arg3):
            ...         super().__init__()
            ...         # pass argument(s) to ignore as a string or in a list
            ...         self.save_hyperparameters(ignore='arg2')
            ...     def forward(self, *args, **kwargs):
            ...         ...
            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
            >>> model.hparams
            "arg1": 1
            "arg3": 3.14
        """
        self._log_hyperparams = logger
        # the frame needs to be created in this file.
        # ``f_back`` of the frame executing this method is the caller's frame;
        # that is what gets handed to the parsing helper when no frame is given.
        if not frame:
            current_frame = inspect.currentframe()
            # ``inspect.currentframe()`` may return None on interpreters without
            # stack-frame support; in that case ``frame`` is passed on unchanged.
            if current_frame:
                frame = current_frame.f_back
        save_hyperparameters(self, *args, ignore=ignore, frame=frame)

    def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
        """Merge ``hp`` into the stored hyperparameters, or replace them.

        ``hp`` is first normalised with :meth:`_to_hparams_dict`. When both the
        normalised value and the current :attr:`hparams` are ``dict`` instances,
        the current collection is updated in place; otherwise ``self._hparams``
        is rebound to the normalised value.

        Args:
            hp: the hyperparameters to store.

        Raises:
            ValueError: propagated from :meth:`_to_hparams_dict`.
        """
        hp = self._to_hparams_dict(hp)

        if isinstance(hp, dict) and isinstance(self.hparams, dict):
            self.hparams.update(hp)
        else:
            self._hparams = hp

    @staticmethod
    def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[MutableMapping, AttributeDict]:
        """Normalise ``hp`` into the container type stored on the instance.

        A ``Namespace`` is converted with ``vars`` and, like any ``dict``, wrapped
        in an ``AttributeDict``. Any other value is returned unchanged provided it
        passes the two type checks below.

        Args:
            hp: the raw hyperparameters object.

        Returns:
            An ``AttributeDict`` for ``Namespace``/``dict`` input, otherwise ``hp`` itself.

        Raises:
            ValueError: if ``hp`` is an instance of ``PRIMITIVE_TYPES``, or is not
                an instance of ``ALLOWED_CONFIG_TYPES``.
        """
        if isinstance(hp, Namespace):
            hp = vars(hp)
        if isinstance(hp, dict):
            hp = AttributeDict(hp)
        elif isinstance(hp, PRIMITIVE_TYPES):
            raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.")
        elif not isinstance(hp, ALLOWED_CONFIG_TYPES):
            raise ValueError(f"Unsupported config type of {type(hp)}.")
        return hp

    @property
    def hparams(self) -> Union[AttributeDict, MutableMapping]:
        """The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user.
        For the frozen set of initial hyperparameters, use :attr:`hparams_initial`.

        Returns:
            Mutable hyperparameters dictionary
        """
        # Created lazily so the property also works before anything has been saved.
        if not hasattr(self, "_hparams"):
            self._hparams = AttributeDict()
        return self._hparams

    @property
    def hparams_initial(self) -> AttributeDict:
        """The collection of hyperparameters saved with :meth:`save_hyperparameters`. These contents are read-only.
        Manual updates to the saved hyperparameters can instead be performed through :attr:`hparams`.

        Returns:
            AttributeDict: immutable initial hyperparameters
        """
        # Nothing recorded yet: hand back a fresh empty container (not cached on self).
        if not hasattr(self, "_hparams_initial"):
            return AttributeDict()
        # prevent any change
        return copy.deepcopy(self._hparams_initial)

© Copyright (c) 2018-2022, William Falcon et al. Revision 74b1317e.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: stable
Versions
latest
stable
1.6.4
1.6.3
1.6.2
1.6.1
1.6.0
1.5.10
1.5.9
1.5.8
1.5.7
1.5.6
1.5.5
1.5.4
1.5.3
1.5.2
1.5.1
1.5.0
1.4.9
1.4.8
1.4.7
1.4.6
1.4.5
1.4.4
1.4.3
1.4.2
1.4.1
1.4.0
1.3.8
1.3.7
1.3.6
1.3.5
1.3.4
1.3.3
1.3.2
1.3.1
1.3.0
1.2.10
1.2.8
1.2.7
1.2.6
1.2.5
1.2.4
1.2.3
1.2.2
1.2.1
1.2.0
1.1.8
1.1.7
1.1.6
1.1.5
1.1.4
1.1.3
1.1.2
1.1.1
1.1.0
1.0.8
1.0.7
1.0.6
1.0.5
1.0.4
1.0.3
1.0.2
1.0.1
1.0.0
0.10.0
0.9.0
0.8.5
0.8.4
0.8.3
0.8.2
0.8.1
0.8.0
0.7.6
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.3
0.4.9
docs_2
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.