
Source code for pytorch_lightning.loops.dataloader.prediction_loop

from typing import Any, List, Optional, Sequence

from deprecate.utils import void
from torch.utils.data import DataLoader

from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
from pytorch_lightning.loops.utilities import _set_sampler_epoch
from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT


class PredictionLoop(DataLoaderLoop):
    """Loop to run over dataloaders for prediction."""

    def __init__(self) -> None:
        super().__init__()
        self.predictions: List[List[Any]] = []
        self.epoch_batch_indices: List[List[int]] = []
        self.epoch_loop = PredictionEpochLoop()

        self._results = None  # for `trainer._results` access
        self._return_predictions: bool = False

    @property
    def return_predictions(self) -> bool:
        """Whether to return the predictions or not."""
        return self._return_predictions

    @return_predictions.setter
    def return_predictions(self, return_predictions: Optional[bool] = None) -> None:
        # `DDPSpawnStrategy` plugins and derivatives don't support returning predictions.
        is_ddp_spawn = isinstance(self.trainer.strategy, DDPSpawnStrategy)
        if return_predictions and is_ddp_spawn:
            raise MisconfigurationException(
                "`return_predictions` should be set to `False` when using the `DDPSpawnStrategy` or children class. "
                f"Found {return_predictions} with strategy {type(self.trainer.strategy)}."
            )
        # For non-`DDPSpawnStrategy` plugins, `return_predictions` is True by default unless the user decides otherwise.
        self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions
        self.epoch_loop.return_predictions = self._return_predictions

    @property
    def num_dataloaders(self) -> int:
        """Returns the number of prediction dataloaders."""
        # case where user does:
        # return dl1, dl2
        dataloaders = self.dataloaders
        length = len(dataloaders)
        if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
            length = len(dataloaders[0])
        return length

    @property
    def max_batches(self) -> List[int]:
        """The max number of batches this loop will run for each dataloader."""
        return self.trainer.num_predict_batches

    @property
    def dataloaders(self) -> Sequence[DataLoader]:
        """Returns all prediction dataloaders."""
        return self.trainer.predict_dataloaders

    @property
    def skip(self) -> bool:
        return sum(self.max_batches) == 0

    def connect(self, epoch_loop: PredictionEpochLoop) -> None:  # type: ignore[override]
        """Connect the prediction epoch loop with this loop."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Resets the internal state of the loop for a new run."""
        self.predictions = []
        self.epoch_batch_indices = []

        super().reset()
        # when restarting, if we are running twice, since there's no concept of `max_epochs` we need to reset the
        # current state when the loop has finished running
        if self.done:
            self.dataloader_progress.reset_on_run()

    def on_run_start(self) -> None:  # type: ignore[override]
        """Calls ``_on_predict_model_eval``, ``_on_predict_start`` and ``_on_predict_epoch_start`` hooks."""
        self.trainer._call_lightning_module_hook("on_predict_model_eval")
        self.trainer.lightning_module.zero_grad()
        self._on_predict_start()
        self._on_predict_epoch_start()

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Predicts one entire dataloader."""
        void(*args, **kwargs)
        dataloader = self.current_dataloader
        if dataloader is not None:
            _set_sampler_epoch(dataloader, self.trainer.fit_loop.epoch_progress.current.processed)
        dataloader = self.trainer.strategy.process_dataloader(dataloader)
        dataloader_iter = enumerate(dataloader)
        dl_max_batches = self.max_batches[self.current_dataloader_idx]
        dl_predictions, dl_batch_indices = self.epoch_loop.run(
            dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders
        )
        self.predictions.append(dl_predictions)
        self.epoch_batch_indices.append(dl_batch_indices)

    def on_run_end(self) -> Optional[_PREDICT_OUTPUT]:
        """Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns results from all dataloaders."""
        results = self._on_predict_epoch_end()
        self._on_predict_end()
        return results

    def _on_predict_start(self) -> None:
        """Calls ``on_predict_start`` hooks."""
        self.trainer._call_callback_hooks("on_predict_start")
        self.trainer._call_lightning_module_hook("on_predict_start")
        self.trainer._call_strategy_hook("on_predict_start")

    def _on_predict_epoch_start(self) -> None:
        """Calls ``on_predict_epoch_start`` hooks."""
        self.trainer._call_callback_hooks("on_predict_epoch_start")
        self.trainer._call_lightning_module_hook("on_predict_epoch_start")

    def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
        """Calls the ``on_predict_epoch_end`` hook.

        Returns:
            the results for all dataloaders
        """
        results = self.predictions

        self.trainer._call_callback_hooks("on_predict_epoch_end", results)
        self.trainer._call_lightning_module_hook("on_predict_epoch_end", results)

        if self.return_predictions:
            return results[0] if self.num_dataloaders == 1 else results

    def _on_predict_end(self) -> None:
        """Resets previous gradient status and calls the ``on_predict_end`` hook."""
        # clear memory. the predictions are extracted in `on_predict_epoch_end`.
        self.predictions = []
        self.epoch_batch_indices = []

        # hook
        self.trainer._call_callback_hooks("on_predict_end")
        self.trainer._call_lightning_module_hook("on_predict_end")
        self.trainer._call_strategy_hook("on_predict_end")
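
For orientation, the sketch below shows how this loop is typically exercised from user code. It is a minimal, hypothetical example (the `TinyModel` module and the tensor shapes are invented for illustration): `Trainer.predict` runs `PredictionLoop.run()`, which advances one `PredictionEpochLoop` per dataloader and, because `return_predictions` defaults to `True` outside `DDPSpawnStrategy`, returns the collected outputs.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl


    class TinyModel(pl.LightningModule):
        # hypothetical module used only to drive the prediction loop
        def __init__(self) -> None:
            super().__init__()
            self.layer = torch.nn.Linear(4, 2)

        def forward(self, x):
            return self.layer(x)

        def predict_step(self, batch, batch_idx, dataloader_idx=0):
            (x,) = batch
            return self(x)


    loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
    trainer = pl.Trainer(accelerator="cpu", devices=1, logger=False)

    # `trainer.predict` drives `PredictionLoop`; with a single dataloader,
    # `_on_predict_epoch_end` unwraps the outer list, so `predictions` is a
    # list with one output tensor per batch (two batches here).
    predictions = trainer.predict(TinyModel(), dataloaders=loader, return_predictions=True)

With multiple prediction dataloaders, `on_run_end` would instead return a list of per-dataloader prediction lists, mirroring the structure of `self.predictions`.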

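The `connect` method above is the intended hook for swapping in a customized epoch loop. The sketch below illustrates that pattern under the loop-customization API of this release line; `MyPredictionEpochLoop` is a hypothetical subclass, and the example assumes that assigning `trainer.predict_loop` attaches the `Trainer` to the loop and its children, as the setter-based loop properties are documented to do.

    import pytorch_lightning as pl
    from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
    from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop


    class MyPredictionEpochLoop(PredictionEpochLoop):
        # hypothetical subclass; a real customization would override loop hooks here
        def on_run_start(self, *args, **kwargs):
            print("starting a prediction epoch loop")
            return super().on_run_start(*args, **kwargs)


    predict_loop = PredictionLoop()
    predict_loop.connect(MyPredictionEpochLoop())  # replaces `self.epoch_loop`

    trainer = pl.Trainer(accelerator="cpu", devices=1, logger=False)
    trainer.predict_loop = predict_loop  # the setter propagates the trainer reference
    # trainer.predict(model, dataloaders=...) would now drive the custom epoch loop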