Shortcuts

Source code for pytorch_lightning.callbacks.progress.rich_progress

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Optional, Union

from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE

Task, Style = None, None
if _RICH_AVAILABLE:
    from rich.console import Console, RenderableType
    from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn
    from rich.progress_bar import ProgressBar
    from rich.style import Style
    from rich.text import Text

    class CustomBarColumn(BarColumn):
        """``BarColumn`` variant that can draw tasks whose total is not finite.

        Needed for dataloaders that do not define a size (infinite size), such as ``IterableDataset``.
        """

        def render(self, task: "Task") -> ProgressBar:
            """Build the progress bar widget for ``task``.

            The bar pulses when the task has not started yet or when its remaining amount is not finite.
            """
            bar_total = max(0, task.total)
            bar_completed = max(0, task.completed)
            if self.bar_width is None:
                bar_width = None
            else:
                # never hand a zero/negative width to the widget
                bar_width = max(1, self.bar_width)
            should_pulse = not (task.started and math.isfinite(task.remaining))
            return ProgressBar(
                total=bar_total,
                completed=bar_completed,
                width=bar_width,
                pulse=should_pulse,
                animation_time=task.get_time(),
                style=self.style,
                complete_style=self.complete_style,
                finished_style=self.finished_style,
                pulse_style=self.pulse_style,
            )

    @dataclass
    class CustomInfiniteTask(Task):
        """Overrides ``Task`` to define an infinite task.

        This is useful for datasets that do not define a size (infinite size) such as ``IterableDataset``.
        The only change from ``Task`` visible here is that no remaining-time estimate is reported.
        """

        @property
        def time_remaining(self) -> Optional[float]:
            """Always ``None``: this task type never reports an estimated time remaining."""
            return None

    class CustomProgress(Progress):
        """Overrides ``Progress`` to support adding tasks that have an infinite total size."""

        def add_task(
            self,
            description: str,
            start: bool = True,
            total: float = 100.0,
            completed: int = 0,
            visible: bool = True,
            **fields: Any,
        ) -> TaskID:
            """Add a task, using ``CustomInfiniteTask`` when ``total`` is not finite.

            Args:
                description: Text shown for the task.
                start: Whether the task is started immediately.
                total: Total number of steps; may be ``float("inf")``.
                completed: Number of steps already completed.
                visible: Whether the task is displayed.
                **fields: Extra fields stored on the task.

            Returns:
                The id of the newly added task.
            """
            if math.isfinite(total):
                # Finite totals are handled entirely by the parent implementation.
                return super().add_task(description, start, total, completed, visible, **fields)

            task = CustomInfiniteTask(
                self._task_index,
                description,
                total,
                completed,
                visible=visible,
                fields=fields,
                _get_time=self.get_time,
                _lock=self._lock,
            )
            # Forward ``start`` so that ``start=False`` is honoured for infinite tasks as well
            # (previously the argument was dropped and the task was always started).
            return self.add_custom_task(task, start=start)

        def add_custom_task(self, task: CustomInfiniteTask, start: bool = True) -> TaskID:
            """Register an already constructed ``task`` under the current task index.

            Args:
                task: The task object to store.
                start: Whether to start the task right after registering it.

            Returns:
                The id under which the task was registered.
            """
            with self._lock:
                self._tasks[self._task_index] = task
                if start:
                    self.start_task(self._task_index)
                new_task_index = self._task_index
                # reserve the next index for the following task
                self._task_index = TaskID(int(self._task_index) + 1)
            self.refresh()
            return new_task_index

    class CustomTimeColumn(ProgressColumn):
        """Column showing the elapsed time and the estimated time remaining of a task."""

        # Only refresh twice a second to prevent jitter
        max_refresh = 0.5

        def __init__(self, style: Union[str, Style]) -> None:
            """Store the ``style`` applied to the rendered text."""
            self.style = style
            super().__init__()

        def render(self, task) -> Text:
            """Render ``<elapsed> • <remaining>`` for ``task``.

            Either value is shown as ``-:--:--`` when it is not available. For a finished task
            the elapsed value is its ``finished_time``.
            """
            elapsed = task.finished_time if task.finished else task.elapsed
            remaining = task.time_remaining
            elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed)))
            remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining)))
            # Separate the two durations; without a separator they run together
            # (e.g. ``0:00:050:01:20``) and cannot be read.
            return Text(f"{elapsed_delta} • {remaining_delta}", style=self.style)

    class BatchesProcessedColumn(ProgressColumn):
        """Column showing ``completed/total`` for a task, with ``--`` standing in for an infinite total."""

        def __init__(self, style: Union[str, Style]):
            """Store the ``style`` applied to the rendered text."""
            self.style = style
            super().__init__()

        def render(self, task) -> RenderableType:
            """Render the completed count (as an integer) followed by the task total."""
            if task.total != float("inf"):
                shown_total = task.total
            else:
                # size is unknown for infinite tasks
                shown_total = "--"
            return Text(f"{int(task.completed)}/{shown_total}", style=self.style)

    class ProcessingSpeedColumn(ProgressColumn):
        """Column showing the task speed in iterations per second."""

        def __init__(self, style: Union[str, Style]):
            """Store the ``style`` applied to the rendered text."""
            self.style = style
            super().__init__()

        def render(self, task) -> RenderableType:
            """Render the speed with two decimals, or ``0.00`` while no speed is available."""
            if task.speed is None:
                formatted_speed = "0.00"
            else:
                formatted_speed = f"{task.speed:>.2f}"
            return Text(f"{formatted_speed}it/s", style=self.style)

    class MetricsTextColumn(ProgressColumn):
        """Column that prints the most recently supplied metrics as ``name: value`` pairs."""

        def __init__(self, trainer, style):
            """Keep a reference to ``trainer`` and the text ``style``; start with no metrics."""
            self._trainer = trainer
            self._tasks = {}
            self._current_task_id = 0
            self._metrics = {}
            self._style = style
            super().__init__()

        def update(self, metrics):
            """Store ``metrics`` for the next render.

            Metrics are pushed in here rather than requested inside ``render``: requesting them
            from separate threads could cause deadlock issues.
            """
            self._metrics = metrics

        def render(self, task) -> Text:
            """Render the metrics text for ``task``.

            Returns an empty text unless the trainer is fitting and not sanity checking. While
            training, a task other than the current one gets its stored renderable instead of
            the live metrics.
            """
            from pytorch_lightning.trainer.states import TrainerFn

            trainer = self._trainer
            if trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking:
                return Text("")

            if trainer.training and task.id not in self._tasks:
                # First time this task is seen while training: give it a placeholder entry,
                # snapshot what was last rendered for the previous current task, then make
                # this task the current one.
                self._tasks[task.id] = "None"
                if self._renderable_cache:
                    self._tasks[self._current_task_id] = self._renderable_cache[self._current_task_id][1]
                self._current_task_id = task.id

            if trainer.training and task.id != self._current_task_id:
                return self._tasks[task.id]

            # floats are rounded to three decimals; every entry is followed by a single space
            rendered = "".join(
                f"{name}: {round(value, 3) if isinstance(value, float) else value} "
                for name, value in self._metrics.items()
            )
            return Text(rendered, justify="left", style=self._style)


@dataclass
class RichProgressBarTheme:
    """Styles to associate to different base components.

    Each field accepts either a style string or a ``Style`` object.

    Args:
        description: Style for the progress bar description, e.g. "Epoch x", "Testing", etc.
        progress_bar: Style for the bar in progress.
        progress_bar_finished: Style for the finished progress bar.
        progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed.
        batch_progress: Style for the progress tracker (i.e. 10/50 batches completed).
        time: Style for the processed time and estimated time remaining.
        processing_speed: Style for the speed of the batches being processed.
        metrics: Style for the metrics.

    https://rich.readthedocs.io/en/stable/style.html
    """

    description: Union[str, Style] = "white"
    progress_bar: Union[str, Style] = "#6206E0"
    progress_bar_finished: Union[str, Style] = "#6206E0"
    progress_bar_pulse: Union[str, Style] = "#6206E0"
    batch_progress: Union[str, Style] = "white"
    time: Union[str, Style] = "grey54"
    processing_speed: Union[str, Style] = "grey70"
    metrics: Union[str, Style] = "white"


class RichProgressBar(ProgressBarBase):
    """Create a progress bar with `rich text formatting <https://github.com/willmcgugan/rich>`_.

    Install it with pip:

    .. code-block:: bash

        pip install rich

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import RichProgressBar

        trainer = Trainer(callbacks=RichProgressBar())

    Args:
        refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display.
        leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
        theme: Contains styles used to stylize the progress bar.

    Raises:
        MisconfigurationException:
            If required `rich` package is not installed on the device.

    Note:
        PyCharm users will need to enable “emulate terminal” in output console option in
        run/debug configuration to see styled output.
        Reference: https://rich.readthedocs.io/en/latest/introduction.html#requirements
    """

    def __init__(
        self,
        refresh_rate: int = 1,
        leave: bool = False,
        theme: RichProgressBarTheme = RichProgressBarTheme(),
    ) -> None:
        if not _RICH_AVAILABLE:
            raise MisconfigurationException(
                "`RichProgressBar` requires `rich` >= 10.2.2. Install it by running `pip install -U rich`."
            )
        super().__init__()
        self._refresh_rate: int = refresh_rate
        self._leave: bool = leave
        self._enabled: bool = True
        self.progress: Optional[Progress] = None
        self.val_sanity_progress_bar_id: Optional[int] = None
        self._reset_progress_bar_ids()
        self._metric_component = None
        self._progress_stopped: bool = False
        self.theme = theme

    @property
    def refresh_rate(self) -> float:
        """Number of batches between two updates of the bars, as given to the constructor."""
        return self._refresh_rate

    @property
    def is_enabled(self) -> bool:
        """``True`` when the bar has not been disabled and the refresh rate is positive."""
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        """Negation of :attr:`is_enabled`."""
        return not self.is_enabled

    def disable(self) -> None:
        """Turn the progress bar off."""
        self._enabled = False

    def enable(self) -> None:
        """Turn the progress bar back on."""
        self._enabled = True

    @property
    def sanity_check_description(self) -> str:
        return "Validation Sanity Check"

    @property
    def validation_description(self) -> str:
        return "Validation"

    @property
    def test_description(self) -> str:
        return "Testing"

    @property
    def predict_description(self) -> str:
        return "Predicting"

    def _init_progress(self, trainer):
        """Create and start the rich progress display if enabled and none is currently running."""
        if self.is_enabled and (self.progress is None or self._progress_stopped):
            self._reset_progress_bar_ids()
            self._console: Console = Console()
            self._console.clear_live()
            self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
            self.progress = CustomProgress(
                *self.configure_columns(trainer),
                self._metric_component,
                auto_refresh=False,
                disable=self.is_disabled,
                console=self._console,
            )
            self.progress.start()
            # progress has started
            self._progress_stopped = False

    def refresh(self) -> None:
        """Redraw the progress display, if there is one."""
        if self.progress:
            self.progress.refresh()

    def on_train_start(self, trainer, pl_module):
        super().on_train_start(trainer, pl_module)
        self._init_progress(trainer)

    def on_predict_start(self, trainer, pl_module):
        super().on_predict_start(trainer, pl_module)
        self._init_progress(trainer)

    def on_test_start(self, trainer, pl_module):
        super().on_test_start(trainer, pl_module)
        self._init_progress(trainer)

    def on_validation_start(self, trainer, pl_module):
        super().on_validation_start(trainer, pl_module)
        self._init_progress(trainer)

    def __getstate__(self):
        # can't pickle the rich progress objects
        state = self.__dict__.copy()
        state["progress"] = None
        state["_console"] = None
        return state

    def __setstate__(self, state):
        self.__dict__ = state
        # ``state`` is now the instance dict, so this recreates ``self._console``
        state["_console"] = Console()

    def on_sanity_check_start(self, trainer, pl_module):
        super().on_sanity_check_start(trainer, pl_module)
        self._init_progress(trainer)
        self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
        self.refresh()

    def on_sanity_check_end(self, trainer, pl_module):
        super().on_sanity_check_end(trainer, pl_module)
        # Hide the sanity-check bar without advancing it. ``_update`` cannot be used here:
        # it requires ``current`` and ``total``, and calling it without them raised ``TypeError``.
        if self.progress is not None and self.val_sanity_progress_bar_id is not None:
            self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
        self.refresh()

    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float("inf"):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch

        total_batches = total_train_batches + total_val_batches

        train_description = self._get_train_description(trainer.current_epoch)
        if self.main_progress_bar_id is not None and self._leave:
            # keep the finished bar on screen and start a fresh display for the new epoch
            self._stop_progress()
            self._init_progress(trainer)
        if self.main_progress_bar_id is None:
            self.main_progress_bar_id = self._add_task(total_batches, train_description)
        elif self.progress is not None:
            self.progress.reset(
                self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
            )
        self.refresh()

    def on_validation_epoch_start(self, trainer, pl_module):
        super().on_validation_epoch_start(trainer, pl_module)
        if self.total_val_batches > 0:
            total_val_batches = self.total_val_batches
            if self.total_train_batches != float("inf") and hasattr(trainer, "val_check_batch"):
                # val can be checked multiple times per epoch
                val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch
                total_val_batches = self.total_val_batches * val_checks_per_epoch
            self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False)
            self.refresh()

    def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
        """Add a task to the progress display and return its id, or ``None`` if there is no display."""
        if self.progress is None:
            return None
        return self.progress.add_task(
            f"[{self.theme.description}]{description}", total=total_batches, visible=visible
        )

    def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None:
        """Advance the given bar when ``current`` falls on a refresh step or equals ``total``.

        Args:
            progress_bar_id: Id of the task to advance.
            current: Number of batches processed so far.
            total: Total number of batches for the running loop.
            visible: Visibility to apply to the task.
        """
        if self.progress is not None and self._should_update(current, total):
            # ``refresh_rate`` is positive here because ``_should_update`` requires ``is_enabled``.
            leftover = current % self.refresh_rate
            # On the last, partial step advance only by the remainder so the bar does not
            # overshoot the total (e.g. refresh_rate=3, total=10 must end at 10, not 12).
            advance = leftover if (current == total and leftover != 0) else self.refresh_rate
            self.progress.update(progress_bar_id, advance=advance, visible=visible)
            self.refresh()

    def _should_update(self, current: int, total: int) -> bool:
        """Whether the bars should be updated for batch number ``current`` out of ``total``."""
        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

    def on_validation_epoch_end(self, trainer, pl_module):
        super().on_validation_epoch_end(trainer, pl_module)
        if self.val_progress_bar_id is not None:
            self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False)

    def on_test_epoch_start(self, trainer, pl_module):
        # call the parent hook like the other ``*_epoch_start`` overrides do
        super().on_test_epoch_start(trainer, pl_module)
        self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
        self.refresh()

    def on_predict_epoch_start(self, trainer, pl_module):
        super().on_predict_epoch_start(trainer, pl_module)
        self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
        self.refresh()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
        self._update_metrics(trainer, pl_module)
        self.refresh()

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if trainer.sanity_checking:
            self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches)
        elif self.val_progress_bar_id is not None:
            # check to see if we should update the main training progress bar
            if self.main_progress_bar_id is not None:
                self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches)
            self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches)
        self.refresh()

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches)
        self.refresh()

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches)
        self.refresh()

    def _get_train_description(self, current_epoch: int) -> str:
        """Return ``"Epoch <n>"``, right-padded with spaces when it is shorter than the validation description."""
        train_description = f"Epoch {current_epoch}"
        if len(self.validation_description) > len(train_description):
            # Padding is required to avoid flickering due to the uneven lengths of the
            # "Epoch X" and "Validation" bar descriptions
            num_digits = len(str(current_epoch))
            required_padding = (len(self.validation_description) - len(train_description) + 1) - num_digits
            # a non-positive count yields an empty string, i.e. no padding
            train_description += " " * required_padding
        return train_description

    def _stop_progress(self) -> None:
        """Stop the progress display and flag it for re-initialization."""
        if self.progress is not None:
            self.progress.stop()
            # signals for progress to be re-initialized for next stages
            self._progress_stopped = True

    def _reset_progress_bar_ids(self):
        """Forget the task ids of the main, validation, test and predict bars."""
        self.main_progress_bar_id: Optional[int] = None
        self.val_progress_bar_id: Optional[int] = None
        self.test_progress_bar_id: Optional[int] = None
        self.predict_progress_bar_id: Optional[int] = None

    def _update_metrics(self, trainer, pl_module) -> None:
        """Fetch the current metrics and hand them to the metrics column, if one exists."""
        metrics = self.get_metrics(trainer, pl_module)
        if self._metric_component:
            self._metric_component.update(metrics)

    def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
        self._stop_progress()

    def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
        self._stop_progress()

    @property
    def val_progress_bar(self) -> Task:
        return self.progress.tasks[self.val_progress_bar_id]

    @property
    def val_sanity_check_bar(self) -> Task:
        return self.progress.tasks[self.val_sanity_progress_bar_id]

    @property
    def main_progress_bar(self) -> Task:
        return self.progress.tasks[self.main_progress_bar_id]

    @property
    def test_progress_bar(self) -> Task:
        return self.progress.tasks[self.test_progress_bar_id]

    def configure_columns(self, trainer) -> list:
        """Return the columns of the progress display; the metrics column is appended separately."""
        return [
            TextColumn("[progress.description]{task.description}"),
            CustomBarColumn(
                complete_style=self.theme.progress_bar,
                finished_style=self.theme.progress_bar_finished,
                pulse_style=self.theme.progress_bar_pulse,
            ),
            BatchesProcessedColumn(style=self.theme.batch_progress),
            CustomTimeColumn(style=self.theme.time),
            ProcessingSpeedColumn(style=self.theme.processing_speed),
        ]

© Copyright (c) 2018-2021, William Falcon et al. Revision 46f718d2.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
1.5.4
1.5.3
1.5.2
1.5.1
1.5.0
1.4.9
1.4.8
1.4.7
1.4.6
1.4.5
1.4.4
1.4.3
1.4.2
1.4.1
1.4.0
1.3.8
1.3.7
1.3.6
1.3.5
1.3.4
1.3.3
1.3.2
1.3.1
1.3.0
1.2.10
1.2.8
1.2.7
1.2.6
1.2.5
1.2.4
1.2.3
1.2.2
1.2.1
1.2.0
1.1.8
1.1.7
1.1.6
1.1.5
1.1.4
1.1.3
1.1.2
1.1.1
1.1.0
1.0.8
1.0.7
1.0.6
1.0.5
1.0.4
1.0.3
1.0.2
1.0.1
1.0.0
0.10.0
0.9.0
0.8.5
0.8.4
0.8.3
0.8.2
0.8.1
0.8.0
0.7.6
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.3
0.4.9
ipynb-update
docs-search
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.