Shortcuts

Source code for pytorch_lightning.plugins.io.async_plugin

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional

from lightning_lite.plugins import CheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO


class AsyncCheckpointIO(_WrappingCheckpointIO):
    """``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread.

    .. warning::

        This is currently an experimental plugin/feature and API changes are to be expected.

    Args:
        checkpoint_io: A checkpoint IO plugin that is used as the basis for async checkpointing.
    """

    def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
        super().__init__(checkpoint_io)

        # A single worker runs the scheduled saves one at a time, in submission order.
        self._executor = ThreadPoolExecutor(max_workers=1)
        # Most recent failure of a background save that has not been re-raised to the caller yet.
        self._error: Optional[BaseException] = None
        # ``_error`` is written on the worker thread and read-and-cleared on the caller's thread;
        # the lock makes that read-and-clear atomic so no failure is lost in between.
        self._error_lock = threading.Lock()

    def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
        """Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``.

        The call returns as soon as the save is scheduled; ``args``/``kwargs`` are forwarded unchanged to the
        wrapped ``checkpoint_io.save_checkpoint`` on the worker thread. They are passed by reference, not copied.
        NOTE(review): presumably callers do not mutate the checkpoint objects after this returns; confirm.

        Raises:
            BaseException: the error of an earlier background save that has not been reported yet. In that case
                the new save is not scheduled, so a raising call never leaves a write in flight.
        """
        # Surface a failure from a previous save (``executor.submit`` is non-blocking, so errors
        # can only be observed on a later call) before queueing more work.
        self._raise_pending_error()

        def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
            try:
                assert self.checkpoint_io is not None
                self.checkpoint_io.save_checkpoint(*args, **kwargs)
            except BaseException as e:
                # Record the failure; it is re-raised on the caller's thread by the next
                # ``save_checkpoint`` call or by ``teardown``.
                with self._error_lock:
                    self._error = e

        self._executor.submit(_save_checkpoint, *args, **kwargs)

    def teardown(self) -> None:
        """This method is called to close the threads.

        Blocks until every scheduled save has finished, then re-raises the error of any background save that was
        not already reported by ``save_checkpoint``.
        """
        self._executor.shutdown(wait=True)
        # Any error raised in any of the ``executor.submit`` calls and not yet reported.
        self._raise_pending_error()

    def _raise_pending_error(self) -> None:
        """Re-raise the recorded background-save error, if any, clearing it so it is reported only once."""
        with self._error_lock:
            error, self._error = self._error, None
        if error is not None:
            raise error

© Copyright (c) 2018-2022, Lightning AI et al. Revision f4fcad36.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
1.8.3post1
1.8.3.post0
1.8.3
1.8.2
1.8.1
1.8.0.post1
1.8.0
1.7.7
1.7.6
1.7.5
1.7.4
1.7.3
1.7.2
1.7.1
1.7.0
1.6.5
1.6.4
1.6.3
1.6.2
1.6.1
1.6.0
1.5.10
1.5.9
1.5.8
1.5.7
1.5.6
1.5.5
1.5.4
1.5.3
1.5.2
1.5.1
1.5.0
1.4.9
1.4.8
1.4.7
1.4.6
1.4.5
1.4.4
1.4.3
1.4.2
1.4.1
1.4.0
1.3.8
1.3.7
1.3.6
1.3.5
1.3.4
1.3.3
1.3.2
1.3.1
1.3.0
1.2.10
1.2.8
1.2.7
1.2.6
1.2.5
1.2.4
1.2.3
1.2.2
1.2.1
1.2.0
1.1.8
1.1.7
1.1.6
1.1.5
1.1.4
1.1.3
1.1.2
1.1.1
1.1.0
1.0.8
1.0.7
1.0.6
1.0.5
1.0.4
1.0.3
1.0.2
1.0.1
1.0.0
0.10.0
0.9.0
0.8.5
0.8.4
0.8.3
0.8.2
0.8.1
0.8.0
0.7.6
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.3
0.4.9
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.