GradientAccumulationScheduler

class pytorch_lightning.callbacks.GradientAccumulationScheduler(scheduling)[source]

Bases: pytorch_lightning.callbacks.base.Callback

Change gradient accumulation factor according to scheduling.

Parameters

scheduling (Dict[int, int]) – scheduling in format {epoch: accumulation_factor}

Raises
  • TypeError – If scheduling is an empty dict, or not all keys and values of scheduling are integers.

  • IndexError – If minimal_epoch is less than 0.

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import GradientAccumulationScheduler

# at epoch 5 start accumulating every 2 batches
>>> accumulator = GradientAccumulationScheduler(scheduling={5: 2})
>>> trainer = Trainer(callbacks=[accumulator])

# alternatively, pass the scheduling dict directly to the Trainer
>>> trainer = Trainer(accumulate_grad_batches={5: 2})
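
The scheduling dict may contain several milestones; each key is the epoch at which the corresponding accumulation factor takes effect and stays active until the next milestone. A sketch with illustrative epoch numbers and factors (not taken from the library's docstring):

# illustrative multi-stage schedule: no accumulation until epoch 4,
# accumulate every 2 batches from epoch 4, every 4 batches from epoch 8
>>> accumulator = GradientAccumulationScheduler(scheduling={0: 1, 4: 2, 8: 4})
>>> trainer = Trainer(callbacks=[accumulator])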

on_train_epoch_start(trainer, *_)[source]

Called when the train epoch begins.
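
This hook is where the callback applies the schedule: at the start of each training epoch it looks up the accumulation factor active for the current epoch and writes it to the trainer. A minimal sketch of an equivalent hook, assuming the trainer exposes current_epoch and accumulate_grad_batches; it does not reproduce the library's exact implementation:

    # minimal sketch of an equivalent hook, not the library's exact code
    def on_train_epoch_start(self, trainer, *_):
        # pick the largest scheduled epoch that has already been reached
        reached = [epoch for epoch in self.scheduling if epoch <= trainer.current_epoch]
        if reached:
            trainer.accumulate_grad_batches = self.scheduling[max(reached)]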