Source code for pytorch_lightning.callbacks.progress.base

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_warn


class ProgressBarBase(Callback):
    r"""
    The base class for progress bars in Lightning. It is a :class:`~pytorch_lightning.callbacks.Callback`
    that keeps track of the batch progress in the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
    You should implement your highly custom progress bars with this as the base class.

    Example::

        import sys

        class LitProgressBar(ProgressBarBase):

            def __init__(self):
                super().__init__()  # don't forget this :)
                self._enabled = True

            def disable(self):
                self._enabled = False

            def enable(self):
                self._enabled = True

            def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
                super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)  # don't forget this :)
                percent = (self.train_batch_idx / self.total_train_batches) * 100
                sys.stdout.flush()
                sys.stdout.write(f'{percent:.01f} percent complete \r')

        bar = LitProgressBar()
        trainer = Trainer(callbacks=[bar])
    """

    def __init__(self):
        self._trainer = None
        self._train_batch_idx = 0
        self._val_batch_idx = 0
        self._test_batch_idx = 0
        self._predict_batch_idx = 0

    @property
    def trainer(self):
        return self._trainer

    @property
    def train_batch_idx(self) -> int:
        """The current batch index being processed during training.

        Use this to update your progress bar.
        """
        return self._train_batch_idx

    @property
    def val_batch_idx(self) -> int:
        """The current batch index being processed during validation.

        Use this to update your progress bar.
        """
        return self._val_batch_idx

    @property
    def test_batch_idx(self) -> int:
        """The current batch index being processed during testing.

        Use this to update your progress bar.
        """
        return self._test_batch_idx

    @property
    def predict_batch_idx(self) -> int:
        """The current batch index being processed during prediction.

        Use this to update your progress bar.
        """
        return self._predict_batch_idx

    @property
    def total_train_batches(self) -> int:
        """The total number of training batches, which may change from epoch to epoch.

        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
        training dataloader is of infinite size.
        """
        return self.trainer.num_training_batches

    @property
    def total_val_batches(self) -> int:
        """The total number of validation batches, which may change from epoch to epoch.

        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
        validation dataloader is of infinite size. Returns 0 when validation is disabled or the current
        epoch is not a validation epoch (see ``check_val_every_n_epoch``).
        """
        total_val_batches = 0
        if self.trainer.enable_validation:
            is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
            total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0
        return total_val_batches

    @property
    def total_test_batches(self) -> int:
        """The total number of testing batches, which may change from epoch to epoch.

        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
        test dataloader is of infinite size.
        """
        return sum(self.trainer.num_test_batches)

    @property
    def total_predict_batches(self) -> int:
        """The total number of prediction batches, which may change from epoch to epoch.

        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
        predict dataloader is of infinite size.
        """
        return sum(self.trainer.num_predict_batches)
    def disable(self):
        """You should provide a way to disable the progress bar.

        The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the
        output on processes that have a rank different from 0, e.g., in multi-node training.
        """
        raise NotImplementedError

    def enable(self):
        """You should provide a way to enable the progress bar.

        The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training
        routines like the :ref:`learning rate finder <advanced/lr_finder:Learning Rate Finder>`
        to temporarily enable and disable the main progress bar.
        """
        raise NotImplementedError

    def print(self, *args, **kwargs):
        """You should provide a way to print without breaking the progress bar."""
        print(*args, **kwargs)

    def on_init_end(self, trainer):
        self._trainer = trainer

    def on_train_start(self, trainer, pl_module):
        self._train_batch_idx = 0

    def on_train_epoch_start(self, trainer, pl_module):
        self._train_batch_idx = trainer.fit_loop.epoch_loop.batch_progress.current.completed

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._train_batch_idx += 1

    def on_validation_start(self, trainer, pl_module):
        self._val_batch_idx = 0

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._val_batch_idx += 1

    def on_test_start(self, trainer, pl_module):
        self._test_batch_idx = 0

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._test_batch_idx += 1

    def on_predict_epoch_start(self, trainer, pl_module):
        self._predict_batch_idx = 0

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._predict_batch_idx += 1
    def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
        r"""
        Combines progress bar metrics collected from the trainer with standard metrics from
        ``get_standard_metrics``. Implement this to override the items displayed in the progress bar.

        Here is an example of how to override the defaults:

        .. code-block:: python

            def get_metrics(self, trainer, model):
                # don't show the version number
                items = super().get_metrics(trainer, model)
                items.pop("v_num", None)
                return items

        Return:
            Dictionary with the items to be displayed in the progress bar.
        """
        standard_metrics = pl_module.get_progress_bar_dict()
        pbar_metrics = trainer.progress_bar_metrics
        duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
        if duplicates:
            rank_zero_warn(
                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value."
                " If this is undesired, change the name or override `get_metrics()` in the progress bar callback.",
                UserWarning,
            )
        return {**standard_metrics, **pbar_metrics}
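
# Editor's sketch (not part of the original module): a minimal subclass that
# drives a text readout from the counters maintained by the hooks above. The
# class name ``TrainValPercentBar`` and the ``_write`` helper are hypothetical;
# the hook signatures mirror the ones ProgressBarBase defines.
import sys


class TrainValPercentBar(ProgressBarBase):
    def __init__(self):
        super().__init__()
        self._enabled = True

    def disable(self):
        self._enabled = False  # the Trainer calls this on non-zero ranks

    def enable(self):
        self._enabled = True

    def _write(self, phase, idx, total):
        # guard against a zero total (e.g., no validation batches this epoch)
        if self._enabled and total:
            sys.stdout.write(f"{phase}: {100 * idx / total:.01f}%\r")
            sys.stdout.flush()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        self._write("train", self.train_batch_idx, self.total_train_batches)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        self._write("val", self.val_batch_idx, self.total_val_batches)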
def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
    r"""
    Returns several standard metrics displayed in the progress bar, including the average loss value,
    split index of BPTT (if used) and the version of the experiment when using a logger.

    .. code-block::

        Epoch 1:   4%|▎         | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]

    Return:
        Dictionary with the standard metrics to be displayed in the progress bar.
    """
    # call .item() only once but store elements without graphs
    running_train_loss = trainer.fit_loop.running_loss.mean()
    avg_training_loss = None
    if running_train_loss is not None:
        avg_training_loss = running_train_loss.cpu().item()
    elif pl_module.automatic_optimization:
        avg_training_loss = float("NaN")

    items_dict = {}
    if avg_training_loss is not None:
        items_dict["loss"] = f"{avg_training_loss:.3g}"

    if pl_module.truncated_bptt_steps > 0:
        items_dict["split_idx"] = trainer.fit_loop.split_idx

    if trainer.logger is not None and trainer.logger.version is not None:
        version = trainer.logger.version
        # show last 4 places of long version strings
        version = version[-4:] if isinstance(version, str) else version
        items_dict["v_num"] = version

    return items_dict
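
# Editor's sketch (not part of the original module): overriding ``get_metrics``
# to append a custom item next to the standard ones. "lr" is an arbitrary key;
# reading ``trainer.optimizers[0].param_groups[0]["lr"]`` assumes at least one
# configured optimizer.


class LearningRateBar(TrainValPercentBar):  # reuses the sketch above
    def get_metrics(self, trainer, pl_module):
        items = super().get_metrics(trainer, pl_module)
        if trainer.optimizers:
            items["lr"] = f"{trainer.optimizers[0].param_groups[0]['lr']:.3g}"
        return items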
