Shortcuts

Source code for pytorch_lightning.callbacks.progress

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Progress Bars
=============

Use or override one of the progress bar callbacks.

"""
import importlib
import io
import math
import os
import sys

# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
from typing import Optional, Union

# `import importlib` alone does not guarantee that the `importlib.util`
# submodule has been loaded, so import it explicitly before using `find_spec`.
import importlib.util

# Use the notebook-aware `tqdm.auto` only when ipywidgets is installed;
# otherwise fall back to the plain console implementation.
if importlib.util.find_spec('ipywidgets') is not None:
    from tqdm.auto import tqdm as _tqdm
else:
    from tqdm import tqdm as _tqdm

from pytorch_lightning.callbacks import Callback

# Target character width used by `tqdm.format_num` below: float/str values
# shorter than this are right-padded with zeros up to this length.
_PAD_SIZE = 5


class tqdm(_tqdm):
    """
    tqdm progress bar variant that zero-pads floats and numeric strings so the
    displayed width stays stable and the bar does not flicker.
    """

    @staticmethod
    def format_num(n) -> str:
        """
        Format ``n`` for display, padding float-like values with trailing zeros.

        Non-string values are first formatted by the stock ``tqdm.format_num``.
        Only values that came in as ``float`` or ``str`` are padded, and never
        when the text contains an ``'e'`` (e.g. scientific notation).
        """
        pad_candidate = isinstance(n, (float, str))
        text = n if isinstance(n, str) else _tqdm.format_num(n)

        if not pad_candidate or 'e' in text:
            return text

        if '.' not in text and len(text) < _PAD_SIZE:
            # Only numeric-looking strings get a decimal point appended.
            try:
                float(text)
            except ValueError:
                return text
            text += '.'

        # A non-positive repeat count yields '', so long values pass through unchanged.
        return text + "0" * (_PAD_SIZE - len(text))
class ProgressBarBase(Callback):
    r"""
    Base class for progress bars in Lightning.

    It is a :class:`~pytorch_lightning.callbacks.Callback` that counts the batches processed by the
    :class:`~pytorch_lightning.trainer.trainer.Trainer` during training, validation, testing and
    predicting. Use it as the base class when you need a highly customized progress bar.

    Example::

        class LitProgressBar(ProgressBarBase):

            def __init__(self):
                super().__init__()  # don't forget this :)
                self.enable = True

            def disable(self):
                self.enable = False

            def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
                # don't forget this :)
                super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
                percent = (self.train_batch_idx / self.total_train_batches) * 100
                sys.stdout.flush()
                sys.stdout.write(f'{percent:.01f} percent complete \r')

        bar = LitProgressBar()
        trainer = Trainer(callbacks=[bar])

    """

    def __init__(self):
        # The trainer reference is filled in by `on_init_end`.
        self._trainer = None
        # Per-stage batch counters, advanced by the `*_batch_end` hooks.
        self._train_batch_idx = 0
        self._val_batch_idx = 0
        self._test_batch_idx = 0
        self._predict_batch_idx = 0

    @property
    def trainer(self):
        """The trainer stored by :meth:`on_init_end` (``None`` before that hook runs)."""
        return self._trainer

    @property
    def train_batch_idx(self) -> int:
        """
        Number of the training batch currently being processed.
        Use this to update your progress bar.
        """
        return self._train_batch_idx

    @property
    def val_batch_idx(self) -> int:
        """
        Number of the validation batch currently being processed.
        Use this to update your progress bar.
        """
        return self._val_batch_idx

    @property
    def test_batch_idx(self) -> int:
        """
        Number of the test batch currently being processed.
        Use this to update your progress bar.
        """
        return self._test_batch_idx

    @property
    def predict_batch_idx(self) -> int:
        """
        Number of the predict batch currently being processed.
        Use this to update your progress bar.
        """
        return self._predict_batch_idx

    @property
    def total_train_batches(self) -> int:
        """
        Total number of training batches, which may change from epoch to epoch.
        Use this to set the total number of iterations in the progress bar.
        Can return ``inf`` if the training dataloader is of infinite size.
        """
        return self.trainer.num_training_batches

    @property
    def total_val_batches(self) -> int:
        """
        Total number of validation batches, which may change from epoch to epoch.
        Use this to set the total number of iterations in the progress bar.
        Can return ``inf`` if the validation dataloader is of infinite size.

        Returns ``0`` when validation is disabled or when the current epoch is not
        one selected by ``check_val_every_n_epoch``.
        """
        trainer = self.trainer
        if not trainer.enable_validation:
            return 0
        runs_this_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0
        if runs_this_epoch:
            return sum(trainer.num_val_batches)
        return 0

    @property
    def total_test_batches(self) -> int:
        """
        Total number of test batches summed over all test dataloaders.
        Use this to set the total number of iterations in the progress bar.
        Can return ``inf`` if the test dataloader is of infinite size.
        """
        return sum(self.trainer.num_test_batches)

    @property
    def total_predict_batches(self) -> int:
        """
        Total number of predict batches summed over all predict dataloaders.
        Use this to set the total number of iterations in the progress bar.
        Can return ``inf`` if the predict dataloader is of infinite size.
        """
        return sum(self.trainer.num_predict_batches)

    def disable(self):
        """
        You should provide a way to disable the progress bar.
        The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the
        output on processes that have a rank different from 0, e.g., in multi-node training.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    def enable(self):
        """
        You should provide a way to enable the progress bar.
        The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training
        routines like the :ref:`learning rate finder <advanced/lr_finder:Learning Rate Finder>`
        to temporarily enable and disable the main progress bar.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    def print(self, *args, **kwargs):
        """
        You should provide a way to print without breaking the progress bar.
        The base implementation simply forwards to the builtin :func:`print`.
        """
        print(*args, **kwargs)

    def on_init_end(self, trainer):
        """Remember the trainer so the ``total_*_batches`` properties can query it."""
        self._trainer = trainer

    def on_train_start(self, trainer, pl_module):
        """Start the training counter from the train loop's current batch index."""
        self._train_batch_idx = trainer.train_loop.batch_idx

    def on_train_epoch_start(self, trainer, pl_module):
        """Restart the training counter for the new epoch."""
        self._train_batch_idx = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Count one more finished training batch."""
        self._train_batch_idx += 1

    def on_validation_start(self, trainer, pl_module):
        """Restart the validation counter."""
        self._val_batch_idx = 0

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Count one more finished validation batch."""
        self._val_batch_idx += 1

    def on_test_start(self, trainer, pl_module):
        """Restart the test counter."""
        self._test_batch_idx = 0

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Count one more finished test batch."""
        self._test_batch_idx += 1

    def on_predict_epoch_start(self, trainer, pl_module):
        """Restart the predict counter."""
        self._predict_batch_idx = 0

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Count one more finished predict batch."""
        self._predict_batch_idx += 1
class ProgressBar(ProgressBarBase):
    r"""
    This is the default progress bar used by Lightning. It prints to `stdout` using the
    :mod:`tqdm` package and shows the following bars:

    - **sanity check progress:** the progress during the sanity check run
    - **main progress:** shows training + validation progress combined. It also accounts for
      multiple validation runs during training when
      :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used.
    - **validation progress:** only visible during validation;
      shows total progress over all validation datasets.
    - **test progress:** only active when testing; shows total progress over all test datasets.
    - **predict progress:** only active when predicting; shows total progress over all predict datasets.

    For infinite datasets, the progress bar never ends.

    If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override
    specific methods of the callback class and pass your custom implementation to the
    :class:`~pytorch_lightning.trainer.trainer.Trainer`:

    Example::

        class LitProgressBar(ProgressBar):

            def init_validation_tqdm(self):
                bar = super().init_validation_tqdm()
                bar.set_description('running validation ...')
                return bar

        bar = LitProgressBar()
        trainer = Trainer(callbacks=[bar])

    Args:
        refresh_rate:
            Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display. By default, the
            :class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress
            bar and sets the refresh rate to the value provided to the
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the
            :class:`~pytorch_lightning.trainer.trainer.Trainer`.
        process_position:
            Set this to a value greater than ``0`` to offset the progress bars by this many lines.
            This is useful when you have progress bars defined elsewhere and want to show all of them
            together. This corresponds to
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the
            :class:`~pytorch_lightning.trainer.trainer.Trainer`.

    """

    def __init__(self, refresh_rate: int = 1, process_position: int = 0):
        super().__init__()
        self._refresh_rate = refresh_rate
        self._process_position = process_position
        self._enabled = True
        # Every bar starts out as None and is created lazily by the matching hook.
        self.main_progress_bar = None
        self.val_progress_bar = None
        self.test_progress_bar = None
        # Initialised here so `print`, `on_predict_end` and pickling never hit a
        # missing attribute when no prediction has been started.
        self.predict_progress_bar = None

    def __getstate__(self):
        """Return the instance state with all tqdm bars dropped, since they can't be pickled."""
        state = self.__dict__.copy()
        state['main_progress_bar'] = None
        state['val_progress_bar'] = None
        state['test_progress_bar'] = None
        state['predict_progress_bar'] = None
        return state

    @property
    def refresh_rate(self) -> int:
        """Number of batches between two bar updates; ``0`` disables the display."""
        return self._refresh_rate

    @property
    def process_position(self) -> int:
        """Offset used to compute the ``position`` of the tqdm bars."""
        return self._process_position

    @property
    def is_enabled(self) -> bool:
        """``True`` when the bar has not been disabled and the refresh rate is positive."""
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        """Negation of :attr:`is_enabled`."""
        return not self.is_enabled

    def disable(self) -> None:
        """Turn the progress bar output off."""
        self._enabled = False

    def enable(self) -> None:
        """Turn the progress bar output back on."""
        self._enabled = True

    def init_sanity_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for the validation sanity run. """
        bar = tqdm(
            desc='Validation sanity check',
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def init_train_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for training. """
        bar = tqdm(
            desc='Training',
            initial=self.train_batch_idx,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )
        return bar

    def init_predict_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for predicting. """
        bar = tqdm(
            desc='Predicting',
            # Seed from the predict counter; the training counter is unrelated to this bar.
            initial=self.predict_batch_idx,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )
        return bar

    def init_validation_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for validation. """
        # The main progress bar doesn't exist in `trainer.validate()`
        has_main_bar = self.main_progress_bar is not None
        bar = tqdm(
            desc='Validating',
            # The bool is added as 0/1 so the bar sits one line below the main bar when it exists.
            position=(2 * self.process_position + has_main_bar),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout
        )
        return bar

    def init_test_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for testing. """
        bar = tqdm(
            desc="Testing",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout
        )
        return bar

    def on_sanity_check_start(self, trainer, pl_module):
        """Create the sanity-check bar and a disabled placeholder for the main bar."""
        super().on_sanity_check_start(trainer, pl_module)
        self.val_progress_bar = self.init_sanity_tqdm()
        self.main_progress_bar = tqdm(disable=True)  # dummy progress bar

    def on_sanity_check_end(self, trainer, pl_module):
        """Close the bars created for the sanity check."""
        super().on_sanity_check_end(trainer, pl_module)
        self.main_progress_bar.close()
        self.val_progress_bar.close()

    def on_train_start(self, trainer, pl_module):
        """Create the main (training) progress bar."""
        super().on_train_start(trainer, pl_module)
        self.main_progress_bar = self.init_train_tqdm()

    def on_train_epoch_start(self, trainer, pl_module):
        """Reset the main bar to the combined train + validation total of the new epoch."""
        super().on_train_epoch_start(trainer, pl_module)
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float('inf') and total_val_batches != float('inf'):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch
        total_batches = total_train_batches + total_val_batches
        reset(self.main_progress_bar, total_batches)
        self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Advance the main bar and refresh its postfix when an update is due."""
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        total_batches = self.total_train_batches + self.total_val_batches
        total_batches = convert_inf(total_batches)
        if self._should_update(self.train_batch_idx, total_batches):
            self._update_bar(self.main_progress_bar)
            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

    def on_validation_start(self, trainer, pl_module):
        """Prepare the validation bar (reusing the sanity-check bar while sanity checking)."""
        super().on_validation_start(trainer, pl_module)
        if trainer.sanity_checking:
            reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
        else:
            self._update_bar(self.main_progress_bar)  # fill up remaining
            self.val_progress_bar = self.init_validation_tqdm()
            reset(self.val_progress_bar, self.total_val_batches)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Advance both the validation bar and the main bar when an update is due."""
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)):
            self._update_bar(self.val_progress_bar)
            self._update_bar(self.main_progress_bar)

    def on_validation_end(self, trainer, pl_module):
        """Refresh the main bar's postfix (if the bar exists) and close the validation bar."""
        super().on_validation_end(trainer, pl_module)
        if self.main_progress_bar is not None:
            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
        self.val_progress_bar.close()

    def on_train_end(self, trainer, pl_module):
        """Close the main progress bar."""
        super().on_train_end(trainer, pl_module)
        self.main_progress_bar.close()

    def on_test_start(self, trainer, pl_module):
        """Create the test bar and set its total (``None`` for infinite datasets)."""
        super().on_test_start(trainer, pl_module)
        self.test_progress_bar = self.init_test_tqdm()
        self.test_progress_bar.total = convert_inf(self.total_test_batches)

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Advance the test bar when an update is due."""
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._should_update(self.test_batch_idx, self.total_test_batches):
            self._update_bar(self.test_progress_bar)

    def on_test_end(self, trainer, pl_module):
        """Close the test progress bar."""
        super().on_test_end(trainer, pl_module)
        self.test_progress_bar.close()

    def on_predict_epoch_start(self, trainer, pl_module):
        """Create the predict bar and set its total (``None`` for infinite datasets)."""
        super().on_predict_epoch_start(trainer, pl_module)
        self.predict_progress_bar = self.init_predict_tqdm()
        self.predict_progress_bar.total = convert_inf(self.total_predict_batches)

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Advance the predict bar when an update is due."""
        super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._should_update(self.predict_batch_idx, self.total_predict_batches):
            self._update_bar(self.predict_progress_bar)

    def on_predict_end(self, trainer, pl_module):
        """Close the predict progress bar, if one was created."""
        if self.predict_progress_bar is not None:
            self.predict_progress_bar.close()

    def print(
        self, *args, sep: str = ' ', end: str = os.linesep, file: Optional[io.TextIOBase] = None, nolock: bool = False
    ):
        """
        Write a message through the first active progress bar so the bars are not broken.

        Bars are tried in the order main, validation, test, predict; a bar counts as
        active when it exists and is not disabled. If no bar is active nothing is written.

        Args:
            *args: objects to print; each is converted with ``str`` and joined by ``sep``.
            sep: separator placed between the printed objects.
            end: string appended after the message.
            file: forwarded to the bar's ``write`` call.
            nolock: forwarded to the bar's ``write`` call.
        """
        candidates = (
            self.main_progress_bar,
            self.val_progress_bar,
            self.test_progress_bar,
            self.predict_progress_bar,
        )
        # A bar may still be None (e.g. no main bar when only testing), so check
        # for existence before reading `.disable`.
        active_progress_bar = next((bar for bar in candidates if bar is not None and not bar.disable), None)

        if active_progress_bar is not None:
            s = sep.join(map(str, args))
            active_progress_bar.write(s, end=end, file=file, nolock=nolock)

    def _should_update(self, current, total) -> bool:
        """Return ``True`` when enabled and ``current`` hits a refresh step or equals ``total``."""
        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

    def _update_bar(self, bar: Optional[tqdm]) -> None:
        """ Updates the bar by the refresh rate without overshooting. """
        if bar is None:
            return
        if bar.total is not None:
            # Never advance past the bar's total.
            delta = min(self.refresh_rate, bar.total - bar.n)
        else:
            # infinite / unknown size
            delta = self.refresh_rate
        if delta > 0:
            bar.update(delta)
def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
    """
    Map values tqdm cannot display (``None``, infinity, NaN) to ``None``.

    Any finite number is returned unchanged.
    """
    # `math.isfinite` is False exactly for +/-inf and NaN.
    unsupported = x is None or not math.isfinite(x)
    return None if unsupported else x
def reset(bar: tqdm, total: Optional[int] = None) -> None:
    """
    Restart ``bar`` at zero progress with a new ``total``.

    A disabled bar is left untouched. ``total`` is passed through ``convert_inf``
    before being handed to the bar's ``reset``.
    """
    if bar.disable:
        return
    new_total = convert_inf(total)
    bar.reset(total=new_total)

© Copyright (c) 2018-2021, William Falcon et al. Revision 233f252b.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
1.3.1
1.3.0
1.2.10
1.2.9_a
1.2.8
1.2.7
1.2.6
1.2.5
1.2.4
1.2.3
1.2.2
1.2.1
1.2.0
1.1.8
1.1.7
1.1.6
1.1.5
1.1.4
1.1.3
1.1.2
1.1.1
1.1.0
1.0.8
1.0.7
1.0.6
1.0.5
1.0.4
1.0.3
1.0.2
1.0.1
1.0.0
0.10.0
0.9.0
0.8.5
0.8.4
0.8.3
0.8.2
0.8.1
0.8.0
0.7.6
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.3
0.4.9
docs-robots
Downloads
pdf
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.