
Source code for pytorch_lightning.callbacks.early_stopping

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Early Stopping
^^^^^^^^^^^^^^

Monitor a metric and stop training when it stops improving.

"""
from typing import Any, Dict

import numpy as np
import torch

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class EarlyStopping(Callback):
    r"""
    Monitor a metric and stop training when it stops improving.

    Args:
        monitor: quantity to be monitored. Default: ``'early_stop_on'``.
        min_delta: minimum change in the monitored quantity to qualify as an improvement,
            i.e. an absolute change of less than ``min_delta`` will count as no improvement.
            Default: ``0.0``.
        patience: number of validation epochs with no improvement after which training
            will be stopped. Default: ``3``.
        verbose: verbosity mode. Default: ``False``.
        mode: one of ``{'auto', 'min', 'max'}``. In ``'min'`` mode, training will stop when
            the quantity monitored has stopped decreasing; in ``'max'`` mode it will stop
            when the quantity monitored has stopped increasing; in ``'auto'`` mode, the
            direction is automatically inferred from the name of the monitored quantity.

            .. warning::
                Setting ``mode='auto'`` has been deprecated in v1.1 and will be removed in v1.3.

        strict: whether to crash the training if ``monitor`` is not found in the
            validation metrics. Default: ``True``.

    Raises:
        MisconfigurationException:
            If ``mode`` is none of ``"min"``, ``"max"``, and ``"auto"``.
        RuntimeError:
            If the metric ``monitor`` is not available.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import EarlyStopping
        >>> early_stopping = EarlyStopping('val_loss')
        >>> trainer = Trainer(callbacks=[early_stopping])
    """
    mode_dict = {
        'min': torch.lt,
        'max': torch.gt,
    }

    def __init__(
        self,
        monitor: str = 'early_stop_on',
        min_delta: float = 0.0,
        patience: int = 3,
        verbose: bool = False,
        mode: str = 'auto',
        strict: bool = True,
    ):
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.strict = strict
        self.min_delta = min_delta
        self.wait_count = 0
        self.stopped_epoch = 0
        self.mode = mode
        self.warned_result_obj = False

        self.__init_monitor_mode()

        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
        torch_inf = torch.tensor(np.Inf)
        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

    def __init_monitor_mode(self):
        if self.mode not in self.mode_dict and self.mode != 'auto':
            raise MisconfigurationException(f"`mode` can be auto, {', '.join(self.mode_dict.keys())}, got {self.mode}")

        # TODO: Update with MisconfigurationException when auto mode is removed in v1.3
        if self.mode == 'auto':
            rank_zero_warn(
                "mode='auto' is deprecated in v1.1 and will be removed in v1.3."
                " Default value for mode will be 'min' in v1.3.", DeprecationWarning
            )

            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
                self.mode = 'max'
            else:
                self.mode = 'min'

            if self.verbose > 0:
                rank_zero_info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.')

    def _validate_condition_metric(self, logs):
        monitor_val = logs.get(self.monitor)

        error_msg = (
            f'Early stopping conditioned on metric `{self.monitor}` which is not available.'
            ' Pass in or modify your `EarlyStopping` callback to use any of the following:'
            f' `{"`, `".join(list(logs.keys()))}`'
        )

        if monitor_val is None:
            if self.strict:
                raise RuntimeError(error_msg)
            if self.verbose > 0:
                rank_zero_warn(error_msg, RuntimeWarning)

            return False

        return True

    @property
    def monitor_op(self):
        return self.mode_dict[self.mode]
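
    # Editor's sketch (not part of the original source): in ``'min'`` mode
    # ``monitor_op`` is ``torch.lt`` and ``__init__`` flips the sign of
    # ``min_delta``, so the improvement test effectively reads
    # ``current + |min_delta| < best_score``. The values below are made up:
    #
    #   >>> op, delta = torch.lt, -0.01                    # 'min' mode after the sign flip
    #   >>> best = torch.tensor(0.50)
    #   >>> bool(op(torch.tensor(0.48) - delta, best))     # 0.49 < 0.50: improvement
    #   True
    #   >>> bool(op(torch.tensor(0.495) - delta, best))    # 0.505 < 0.50: no improvement
    #   False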

    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
        return {
            'wait_count': self.wait_count,
            'stopped_epoch': self.stopped_epoch,
            'best_score': self.best_score,
            'patience': self.patience
        }

    def on_load_checkpoint(self, callback_state: Dict[str, Any]):
        self.wait_count = callback_state['wait_count']
        self.stopped_epoch = callback_state['stopped_epoch']
        self.best_score = callback_state['best_score']
        self.patience = callback_state['patience']
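
    # Editor's sketch (not part of the original source): the two hooks above
    # round-trip the callback state through the checkpoint dict, so a restored
    # run resumes its patience counter. ``trainer``/``pl_module`` are unused by
    # ``on_save_checkpoint``, so ``None`` stands in for them here:
    #
    #   >>> es = EarlyStopping('val_loss', mode='min', patience=5)
    #   >>> es.wait_count = 2                          # pretend two bad epochs happened
    #   >>> state = es.on_save_checkpoint(None, None, {})
    #   >>> restored = EarlyStopping('val_loss', mode='min')
    #   >>> restored.on_load_checkpoint(state)
    #   >>> (restored.wait_count, restored.patience)
    #   (2, 5)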

    def on_validation_end(self, trainer, pl_module):
        if trainer.running_sanity_check:
            return

        self._run_early_stopping_check(trainer, pl_module)

    def _run_early_stopping_check(self, trainer, pl_module):
        """
        Checks whether the early stopping condition is met
        and if so tells the trainer to stop the training.
        """
        logs = trainer.callback_metrics

        if (
            trainer.fast_dev_run  # disable early_stopping with fast_dev_run
            or not self._validate_condition_metric(logs)  # short circuit if metric not present
        ):
            return

        current = logs.get(self.monitor)

        # when in dev debugging
        trainer.dev_debugger.track_early_stopping_history(self, current)

        if self.monitor_op(current - self.min_delta, self.best_score):
            self.best_score = current
            self.wait_count = 0
        else:
            self.wait_count += 1

            if self.wait_count >= self.patience:
                self.stopped_epoch = trainer.current_epoch
                trainer.should_stop = True

        # stop every ddp process if any world process decides to stop
        trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop)
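
A minimal sketch of the patience logic in ``_run_early_stopping_check``, driven by
hand instead of a ``Trainer`` (the loss values are illustrative; non-improving epochs
increment ``wait_count`` until it reaches ``patience``)::

    >>> es = EarlyStopping('val_loss', mode='min', patience=2)
    >>> for val_loss in (0.50, 0.40, 0.41, 0.42, 0.43):
    ...     current = torch.tensor(val_loss)
    ...     if es.monitor_op(current - es.min_delta, es.best_score):
    ...         es.best_score, es.wait_count = current, 0
    ...     else:
    ...         es.wait_count += 1
    >>> es.wait_count >= es.patience   # the check would now set trainer.should_stop = True
    True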
