Shortcuts

Source code for pytorch_lightning.tuner.lr_finder

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging
import os
import uuid
from functools import wraps
from typing import Optional, Sequence

import numpy as np
import torch
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr

# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed.
# ``importlib.util`` is a submodule: a plain ``import importlib`` does not guarantee it is loaded,
# so import it explicitly before calling ``find_spec``.
import importlib.util

if importlib.util.find_spec("ipywidgets") is not None:
    from tqdm.auto import tqdm
else:
    from tqdm import tqdm

log = logging.getLogger(__name__)


def _determine_lr_attr_name(trainer: "pl.Trainer", model: "pl.LightningModule") -> str:
    """Resolve the name of the model attribute that stores the learning rate.

    When ``trainer.auto_lr_find`` is a string it is used as the attribute name, and it must be found on
    ``model`` or ``model.hparams`` (via ``lightning_hasattr``). Otherwise the first of ``"lr"`` and
    ``"learning_rate"`` found that way is returned.

    Raises:
        MisconfigurationException: if the requested attribute, or any of the default names, is missing.
    """
    requested = trainer.auto_lr_find
    if isinstance(requested, str):
        if lightning_hasattr(model, requested):
            return requested
        raise MisconfigurationException(
            f"`auto_lr_find` was set to {requested}, however"
            " could not find this as a field in `model` or `model.hparams`."
        )

    candidates = ("lr", "learning_rate")
    found = next((name for name in candidates if lightning_hasattr(model, name)), None)
    if found is not None:
        return found

    raise MisconfigurationException(
        "When `auto_lr_find=True`, either `model` or `model.hparams` should"
        f" have one of these fields: {candidates} overridden."
    )


class _LRFinder:
    """LR finder object. This object stores the results of lr_find().

    Args:
        mode: either `linear` or `exponential`, how to increase lr after each step

        lr_min: lr to start search from

        lr_max: lr to stop search

        num_training: number of steps to take between lr_min and lr_max

    Example::
        # Run lr finder
        lr_finder = trainer.lr_find(model)

        # Results stored in
        lr_finder.results

        # Plot using
        lr_finder.plot()

        # Get suggestion
        lr = lr_finder.suggestion()
    """

    def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
        assert mode in ("linear", "exponential"), "mode should be either `linear` or `exponential`"

        self.mode = mode
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.num_training = num_training

        self.results = {}
        self._total_batch_idx = 0  # for debug purpose

    def _exchange_scheduler(self, trainer: "pl.Trainer"):
        """Decorate `trainer.init_optimizers` method such that it returns the users originally specified optimizer
        together with a new scheduler that that takes care of the learning rate search."""
        init_optimizers = trainer.init_optimizers

        @wraps(init_optimizers)
        def func(model):
            # Decide the structure of the output from init_optimizers
            optimizers, _, _ = init_optimizers(model)

            if len(optimizers) != 1:
                raise MisconfigurationException(
                    f"`model.configure_optimizers()` returned {len(optimizers)}, but"
                    " learning rate finder only works with single optimizer"
                )

            optimizer = optimizers[0]

            new_lrs = [self.lr_min] * len(optimizer.param_groups)
            for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
                param_group["lr"] = new_lr
                param_group["initial_lr"] = new_lr

            args = (optimizer, self.lr_max, self.num_training)
            scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
            sched_config = _get_default_scheduler_config()
            sched_config.update({"scheduler": scheduler, "interval": "step"})

            return [optimizer], [sched_config], []

        return func

    def plot(self, suggest: bool = False, show: bool = False):
        """Plot results from lr_find run
        Args:
            suggest: if True, will mark suggested lr to use with a red point

            show: if True, will show figure
        """
        import matplotlib.pyplot as plt

        lrs = self.results["lr"]
        losses = self.results["loss"]

        fig, ax = plt.subplots()

        # Plot loss as a function of the learning rate
        ax.plot(lrs, losses)
        if self.mode == "exponential":
            ax.set_xscale("log")
        ax.set_xlabel("Learning rate")
        ax.set_ylabel("Loss")

        if suggest:
            _ = self.suggestion()
            if self._optimal_idx:
                ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx], markersize=10, marker="o", color="red")

        if show:
            plt.show()

        return fig

    def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
        """This will propose a suggestion for choice of initial learning rate as the point with the steepest
        negative gradient.

        Returns:
            lr: suggested initial learning rate to use
            skip_begin: how many samples to skip in the beginning. Prevent too naive estimates
            skip_end: how many samples to skip in the end. Prevent too optimistic estimates
        """
        try:
            loss = np.array(self.results["loss"][skip_begin:-skip_end])
            loss = loss[np.isfinite(loss)]
            min_grad = np.gradient(loss).argmin()
            self._optimal_idx = min_grad + skip_begin
            return self.results["lr"][self._optimal_idx]
        # todo: specify the possible exception
        except Exception:
            log.exception("Failed to compute suggesting for `lr`. There might not be enough points.")
            self._optimal_idx = None


def lr_find(
    trainer: "pl.Trainer",
    model: "pl.LightningModule",
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = "exponential",
    early_stop_threshold: float = 4.0,
    update_attr: bool = False,
) -> Optional[_LRFinder]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`

    Temporarily reconfigures ``trainer`` (callbacks, logger, step limit, optimizer setup), runs a short fit
    while increasing the learning rate, then restores the trainer and the model weights from a temporary
    checkpoint.

    Args:
        trainer: trainer used for the search; its overridden attributes are restored afterwards.
        model: model to search a learning rate for.
        min_lr: learning rate the search starts from.
        max_lr: learning rate the search ends at.
        num_training: number of steps in the search.
        mode: ``"exponential"`` or ``"linear"`` lr schedule.
        early_stop_threshold: stop when the smoothed loss exceeds ``early_stop_threshold * best_loss``;
            ``None`` disables early stopping.
        update_attr: if True, write the suggested lr to the model attribute chosen by
            ``_determine_lr_attr_name`` (left untouched when no suggestion can be computed).

    Returns:
        The ``_LRFinder`` holding the results, or ``None`` if ``trainer.fast_dev_run`` is enabled.
    """
    if trainer.fast_dev_run:
        rank_zero_warn("Skipping learning rate finder since fast_dev_run is enabled.", UserWarning)
        return

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    save_path = os.path.join(trainer.default_root_dir, f"lr_find_temp_model_{uuid.uuid4()}.ckpt")

    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback. Keep a direct reference so reading the results back does not depend
    # on the callback's position in ``trainer.callbacks``.
    lr_callback = _LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)
    trainer.callbacks = [lr_callback]

    # No logging
    trainer.logger = DummyLogger() if trainer.logger is not None else None

    # Max step set to number of iterations
    trainer.fit_loop.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Required for saving the model
    trainer.optimizers, trainer.lr_schedulers = [], []
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    trainer.init_optimizers = lr_finder._exchange_scheduler(trainer)

    try:
        # Fit, lr & loss logged in callback
        trainer.tuner._run(model)

        # Prompt if we stopped early
        if trainer.global_step != num_training:
            log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")

        # Transfer results from callback to lr finder object
        lr_finder.results.update({"lr": lr_callback.lrs, "loss": lr_callback.losses})
        lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
    finally:
        # Reset model state and remove the temporary checkpoint, even if the search raised
        if trainer.is_global_zero:
            trainer.checkpoint_connector.restore(str(save_path))
            fs = get_filesystem(str(save_path))
            if fs.exists(save_path):
                fs.rm(save_path)

        # Finish by resetting variables so trainer is ready to fit model
        __lr_finder_restore_params(trainer, model)
        if trainer.progress_bar_callback:
            trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()
        # ``suggestion`` returns None when it cannot compute a value; never overwrite the user's lr with None.
        if lr is not None:
            # TODO: log lr.results to self.logger
            lightning_setattr(model, lr_attr_name, lr)
            log.info(f"Learning rate set to {lr}")

    return lr_finder
def __lr_finder_dump_params(trainer, model):
    """Save the trainer attributes that ``lr_find`` overrides so they can be put back afterwards.

    The snapshot is stored on ``trainer.__dumped_params``; this is a module-level function, so the
    double-underscore attribute name is not mangled.
    """
    # Prevent going into infinite loop
    trainer.__dumped_params = {
        "auto_lr_find": trainer.auto_lr_find,
        "callbacks": trainer.callbacks,
        "logger": trainer.logger,
        "global_step": trainer.global_step,
        "max_steps": trainer.max_steps,
        # NOTE(review): saved but never read back by ``__lr_finder_restore_params``.
        "checkpoint_callback": trainer.checkpoint_callback,
        "current_epoch": trainer.current_epoch,
        "init_optimizers": trainer.init_optimizers,
    }


def __lr_finder_restore_params(trainer, model):
    """Restore the trainer attributes saved by ``__lr_finder_dump_params`` and discard the snapshot."""
    trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"]
    trainer.logger = trainer.__dumped_params["logger"]
    trainer.callbacks = trainer.__dumped_params["callbacks"]
    # The step/epoch counters and the step limit are written back through ``trainer.fit_loop``.
    trainer.fit_loop.global_step = trainer.__dumped_params["global_step"]
    trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"]
    trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"]
    trainer.init_optimizers = trainer.__dumped_params["init_optimizers"]
    del trainer.__dumped_params


class _LRCallback(Callback):
    """Special callback used by the learning rate finder. This callback logs the learning rate before each
    batch and the corresponding loss after each batch.

    Args:
        num_training: number of iterations done by the learning rate finder
        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than ``early_stop_threshold*best_loss``
            then the search is stopped. To disable, set to ``None``.
        progress_bar_refresh_rate: rate to refresh the progress bar for
            the learning rate finder
        beta: smoothing value, the loss being logged is a running average of
            loss values logged until now. ``beta`` controls the forget rate i.e.
            if ``beta=0`` all past information is ignored.
    """

    def __init__(
        self,
        num_training: int,
        early_stop_threshold: float = 4.0,
        progress_bar_refresh_rate: int = 0,
        beta: float = 0.98,
    ):
        self.num_training = num_training
        self.early_stop_threshold = early_stop_threshold
        self.beta = beta
        self.losses = []
        self.lrs = []
        self.avg_loss = 0.0
        self.best_loss = 0.0
        self.progress_bar_refresh_rate = progress_bar_refresh_rate
        self.progress_bar = None

    def on_batch_start(self, trainer, pl_module):
        """Called before each training batch, logs the lr that will be used."""
        # Only act on every ``accumulate_grad_batches``-th batch (the end of each accumulation window).
        if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            return

        if self.progress_bar_refresh_rate and self.progress_bar is None:
            self.progress_bar = tqdm(desc="Finding best initial lr", total=self.num_training)

        self.lrs.append(trainer.lr_schedulers[0]["scheduler"].lr[0])

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Called when the training batch ends, logs the calculated loss."""
        # Same gating as ``on_batch_start`` so ``lrs`` and ``losses`` stay paired.
        if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            return

        if self.progress_bar:
            self.progress_bar.update()

        current_loss = trainer.fit_loop.running_loss.last().item()
        current_step = trainer.global_step

        # Avg loss (loss with momentum) + smoothing; dividing by (1 - beta**(step + 1)) debiases the
        # running average, which starts at 0.
        self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss
        smoothed_loss = self.avg_loss / (1 - self.beta ** (current_step + 1))

        # Check if we are diverging
        if self.early_stop_threshold is not None:
            if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss:
                trainer.fit_loop.max_steps = current_step  # stop signal
                if self.progress_bar:
                    self.progress_bar.close()

        # Save best loss for diverging checking; step 1 seeds the reference used by the check from step 2 on
        if smoothed_loss < self.best_loss or current_step == 1:
            self.best_loss = smoothed_loss

        self.losses.append(smoothed_loss)

    def on_train_end(self, trainer, pl_module):
        """Close the progress bar when the search finishes.

        Without this, the bar was only closed on early stop. If it was already closed there, tqdm's ``close``
        returns early on the second call.
        """
        if self.progress_bar:
            self.progress_bar.close()


class _LinearLR(_LRScheduler):
    """Linearly increases the learning rate between two boundaries over a number of iterations.

    Args:
        optimizer: wrapped optimizer.
        end_lr: the final learning rate.
        num_iter: the number of iterations over which the test occurs.
        last_epoch: the index of last epoch. Default: -1.
    """

    last_epoch: int
    base_lrs: Sequence

    def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
        # Set before ``super().__init__`` because the base constructor already computes the first lr.
        self.end_lr = end_lr
        self.num_iter = num_iter
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return per-group lrs: ``base_lr`` while ``last_epoch <= 0``, then ``base_lr + r * (end_lr - base_lr)``
        with ``r = (last_epoch + 1) / num_iter``."""
        curr_iter = self.last_epoch + 1
        r = curr_iter / self.num_iter
        if self.last_epoch > 0:
            val = [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs]
        else:
            val = list(self.base_lrs)
        self._lr = val
        return val

    @property
    def lr(self):
        """Learning rates most recently computed by ``get_lr`` (one per param group)."""
        return self._lr


class _ExponentialLR(_LRScheduler):
    """Exponentially increases the learning rate between two boundaries over a number of iterations.

    Arguments:
        optimizer: wrapped optimizer.
        end_lr: the final learning rate.
        num_iter: the number of iterations over which the test occurs.
        last_epoch: the index of last epoch. Default: -1.
    """

    last_epoch: int
    base_lrs: Sequence

    def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
        # Set before ``super().__init__`` because the base constructor already computes the first lr.
        self.end_lr = end_lr
        self.num_iter = num_iter
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return per-group lrs: ``base_lr`` while ``last_epoch <= 0``, then ``base_lr * (end_lr / base_lr) ** r``
        with ``r = (last_epoch + 1) / num_iter``."""
        curr_iter = self.last_epoch + 1
        r = curr_iter / self.num_iter
        if self.last_epoch > 0:
            val = [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
        else:
            val = list(self.base_lrs)
        self._lr = val
        return val

    @property
    def lr(self):
        """Learning rates most recently computed by ``get_lr`` (one per param group)."""
        return self._lr

© Copyright (c) 2018-2022, William Falcon et al. Revision 9ebdc52e.

Built with Sphinx using a theme provided by Read the Docs.