pytorch_lightning.callbacks.gradient_accumulation_scheduler module
Gradient Accumulator
Change gradient accumulation factor according to scheduling.
class pytorch_lightning.callbacks.gradient_accumulation_scheduler.GradientAccumulationScheduler(scheduling)

    Bases: pytorch_lightning.callbacks.base.Callback
Change gradient accumulation factor according to scheduling.
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import GradientAccumulationScheduler

# at epoch 5, start accumulating every 2 batches
>>> accumulator = GradientAccumulationScheduler(scheduling={5: 2})
>>> trainer = Trainer(callbacks=[accumulator])

# alternatively, pass the scheduling dict directly to the Trainer
>>> trainer = Trainer(accumulate_grad_batches={5: 2})
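To make the scheduling dict's semantics concrete, here is a minimal sketch (not the library's implementation) of how a mapping like {5: 2} can be resolved to an accumulation factor per epoch, assuming epochs before the first key use a factor of 1 and each key applies from that epoch onward until the next key:

```python
def accumulation_factor(scheduling, epoch):
    """Resolve the gradient accumulation factor for a given epoch.

    `scheduling` maps a starting epoch to an accumulation factor,
    e.g. {5: 2} means: from epoch 5 onward, step the optimizer
    only every 2 batches.
    """
    factor = 1  # default: update weights after every batch
    for start_epoch in sorted(scheduling):
        if epoch >= start_epoch:
            factor = scheduling[start_epoch]
    return factor

# Epochs 0-4 use the default factor; from epoch 5 on, every 2 batches.
print(accumulation_factor({5: 2}, 4))  # 1
print(accumulation_factor({5: 2}, 5))  # 2
```

Accumulating over N batches sums gradients across N backward passes before calling the optimizer, which approximates training with an N-times larger effective batch size without the extra memory cost.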