
Source code for pytorch_lightning.callbacks.gradient_accumulation_scheduler

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Gradient Accumulator
====================

Change gradient accumulation factor according to scheduling.
The Trainer also calls ``optimizer.step()`` for the final, smaller batch group when the number of batches in an epoch is not divisible by the accumulation factor.

"""

from typing import Dict

from pytorch_lightning.callbacks.base import Callback


class GradientAccumulationScheduler(Callback):
    r"""
    Change gradient accumulation factor according to scheduling.

    Args:
        scheduling: scheduling in format {epoch: accumulation_factor}

    Raises:
        TypeError:
            If ``scheduling`` is an empty ``dict``,
            or not all keys and values of ``scheduling`` are integers.
        IndexError:
            If ``minimal_epoch`` is less than 0.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import GradientAccumulationScheduler

        # at epoch 5 start accumulating every 2 batches
        >>> accumulator = GradientAccumulationScheduler(scheduling={5: 2})
        >>> trainer = Trainer(callbacks=[accumulator])

        # alternatively, pass the scheduling dict directly to the Trainer
        >>> trainer = Trainer(accumulate_grad_batches={5: 2})
    """

    def __init__(self, scheduling: Dict[int, int]):
        super().__init__()

        if not scheduling:  # empty dict error
            raise TypeError("Empty dict cannot be interpreted correctly")

        for key in scheduling:
            if not isinstance(key, int) or not isinstance(scheduling[key], int):
                raise TypeError("All epochs and accumulation factors must be integers")

        minimal_epoch = min(scheduling.keys())
        if minimal_epoch < 0:
            raise IndexError(f"Epochs are indexed from 0, epoch {minimal_epoch} cannot be interpreted correctly")
        if minimal_epoch != 0:  # if the user didn't define an accumulation factor for the first epoch, default to 1
            scheduling.update({0: 1})

        self.scheduling = scheduling
        self.epochs = sorted(scheduling.keys())

    def going_to_accumulate_grad_batches(self):
        # True if any scheduled factor actually accumulates over more than one batch
        return any(v > 1 for v in self.scheduling.values())

    def on_train_epoch_start(self, trainer, pl_module):
        # Find the largest scheduled epoch that has already been reached and
        # apply its accumulation factor to the Trainer.
        epoch = trainer.current_epoch
        for i in reversed(range(len(self.epochs))):
            if epoch >= self.epochs[i]:
                trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
                break
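
The following sketch (an editorial illustration, not part of the library) mirrors the epoch-to-factor lookup performed in ``on_train_epoch_start``: milestones are walked from largest to smallest, and the first one already reached supplies the accumulation factor.

from typing import Dict

def resolve_accumulation_factor(scheduling: Dict[int, int], epoch: int) -> int:
    # Mirror of the constructor: if epoch 0 has no explicit factor, default to 1.
    scheduling = {0: 1, **scheduling}
    epochs = sorted(scheduling.keys())
    # Walk the milestones from largest to smallest and return the factor of
    # the first milestone that has already been reached.
    for milestone in reversed(epochs):
        if epoch >= milestone:
            return scheduling[milestone]

print([resolve_accumulation_factor({5: 2, 10: 4}, e) for e in range(12)])
# [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 4, 4]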
