gradient_accumulation_scheduler

Classes

GradientAccumulationScheduler

Change gradient accumulation factor according to scheduling.

Gradient Accumulator

Change gradient accumulation factor according to scheduling. The Trainer also calls optimizer.step() for the last group of batches in an epoch even when that group is smaller than the accumulation factor (i.e. when the number of batches is not divisible by it).
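
To make the accumulation factor concrete, here is a minimal sketch of gradient accumulation in plain PyTorch (not Lightning's internal implementation); note how the last, incomplete group of batches still triggers an optimizer step:

import torch

accumulation_factor = 2
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [torch.randn(8, 10) for _ in range(5)]  # 5 batches: not divisible by 2

optimizer.zero_grad()
for i, batch in enumerate(batches):
    loss = model(batch).sum()
    (loss / accumulation_factor).backward()   # gradients accumulate across calls
    if (i + 1) % accumulation_factor == 0 or i == len(batches) - 1:
        optimizer.step()        # step every `accumulation_factor` batches ...
        optimizer.zero_grad()   # ... and once more for the trailing batch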

class pytorch_lightning.callbacks.gradient_accumulation_scheduler.GradientAccumulationScheduler(scheduling)[source]

Bases: pytorch_lightning.callbacks.base.Callback

Change gradient accumulation factor according to scheduling.

Parameters

scheduling (Dict[int, int]) – scheduling in format {epoch: accumulation_factor}

Note

The argument scheduling is a dictionary. Each key is an epoch and its associated value is the accumulation factor to use from that epoch on. Warning: epochs are zero-indexed, i.e. if you want to change the accumulation factor after 4 epochs, set Trainer(accumulate_grad_batches={4: factor}) or GradientAccumulationScheduler(scheduling={4: factor}). For more info check the example below.

Raises
  • TypeError – If scheduling is an empty dict, or not all keys and values of scheduling are integers.

  • IndexError – If minimal_epoch (the smallest epoch key in scheduling) is less than 0.

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import GradientAccumulationScheduler

# from epoch 5, it starts accumulating every 2 batches. Here we use 4 instead of 5
# because the epoch (key) is zero-indexed.
>>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
>>> trainer = Trainer(callbacks=[accumulator])

# alternatively, pass the scheduling dict directly to the Trainer
>>> trainer = Trainer(accumulate_grad_batches={4: 2})
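
A scheduling dict may also contain several epoch keys; each factor applies from its (zero-indexed) epoch until the next key takes over. The values below are illustrative, not defaults:

# a schedule with several keys: accumulate 8 batches for epochs 0-3,
# 4 batches for epochs 4-7, and stop accumulating from epoch 8 on
>>> accumulator = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})
>>> trainer = Trainer(callbacks=[accumulator])
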
on_train_epoch_start(trainer, *_)[source]

Called when the train epoch begins.

Return type

None
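
on_train_epoch_start is the hook where the scheduler applies the new factor. The callback below is a hypothetical illustration (not part of pytorch_lightning) that logs the factor in effect each epoch; it assumes the Trainer exposes the currently applied value as trainer.accumulate_grad_batches:

from pytorch_lightning.callbacks import Callback

class LogAccumulationFactor(Callback):
    """Hypothetical helper: print the accumulation factor at each epoch start."""

    def on_train_epoch_start(self, trainer, pl_module):
        # assumes the Trainer exposes the currently applied factor as
        # `trainer.accumulate_grad_batches`
        print(f"epoch {trainer.current_epoch}: "
              f"accumulating {trainer.accumulate_grad_batches} batches")

# usage (with the accumulator from the example above):
# trainer = Trainer(callbacks=[accumulator, LogAccumulationFactor()])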