Shortcuts

Source code for pytorch_lightning.utilities.seed

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions to help with reproducibility of models. """

import os
import random
from typing import Optional

import numpy as np
import torch

from pytorch_lightning import _logger as log


def seed_everything(seed: Optional[int] = None) -> int:
    """
    Function that sets seed for pseudo-random number generators in:
    pytorch, numpy, python.random
    In addition, sets the env variable `PL_GLOBAL_SEED` which will be passed to
    spawned subprocesses (e.g. ddp_spawn backend).

    Args:
        seed: the integer value seed for global random state in Lightning.
            If `None`, will read seed from `PL_GLOBAL_SEED` env variable
            or select it randomly. Values convertible with ``int()`` (e.g. ``"7"``)
            are accepted; unconvertible values fall back to a random seed.

    Returns:
        The seed that was actually applied. If the requested seed lies outside the
        range numpy accepts (``0`` to ``2**32 - 1``), a warning is logged and a random
        in-range seed is used instead.
    """
    max_seed_value = np.iinfo(np.uint32).max
    min_seed_value = np.iinfo(np.uint32).min

    if seed is None:
        # Look the env variable up WITHOUT a computed default: passing
        # ``_select_seed_randomly(...)`` as the ``.get`` default would run it eagerly and
        # log a spurious "No correct seed found" warning even when PL_GLOBAL_SEED is set.
        seed = os.environ.get("PL_GLOBAL_SEED")

    try:
        # ``int(None)`` raises TypeError, so a missing env variable lands in the
        # random fallback below, exactly like an unparsable value does.
        seed = int(seed)
    except (TypeError, ValueError):
        seed = _select_seed_randomly(min_seed_value, max_seed_value)

    if not min_seed_value <= seed <= max_seed_value:
        log.warning(
            f"{seed} is not in bounds, "
            f"numpy accepts from {min_seed_value} to {max_seed_value}"
        )
        seed = _select_seed_randomly(min_seed_value, max_seed_value)

    # Export the seed so spawned subprocesses can pick it up.
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed
def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
    """
    Draw a seed uniformly from the inclusive range ``[min_seed_value, max_seed_value]``
    and log a warning reporting the chosen value.

    The draw uses Python's global ``random`` generator, so the result depends on its
    current state.

    Args:
        min_seed_value: smallest seed that may be returned (inclusive).
        max_seed_value: largest seed that may be returned (inclusive).

    Returns:
        The randomly chosen seed.
    """
    # ``random.randint(a, b)`` is documented as an alias for ``randrange(a, b + 1)``,
    # so the upper bound stays inclusive and the generator is consumed identically.
    picked = random.randrange(min_seed_value, max_seed_value + 1)
    message = f"No correct seed found, seed set to {picked}"
    log.warning(message)
    return picked

© Copyright 2018-2020, William Falcon et al. Revision 0979e2ce.

Built with Sphinx using a theme provided by Read the Docs.