Training Tricks

Lightning implements various tricks to help during training.

Accumulate gradients

Gradient accumulation runs K small batches of size N and accumulates their gradients before performing a single optimizer step. The effect is a large effective batch size of K×N.

# DEFAULT (i.e., no accumulated grads)
trainer = Trainer(accumulate_grad_batches=1)
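The K×N equivalence can be checked with a small dependency-free sketch. It assumes a toy 1-D least-squares model and that each micro-batch loss is scaled by 1/K before its gradient is accumulated. The helpers mse_grad and accumulated_grad are hypothetical names for illustration, not Lightning code.

```python
def mse_grad(w, xs, ys):
    """Gradient of the mean squared error of the 1-D model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, k):
    """Accumulate gradients over k micro-batches of size n = len(xs) // k."""
    n = len(xs) // k  # assumes len(xs) is divisible by k
    grad = 0.0
    for i in range(k):
        xb, yb = xs[i * n:(i + 1) * n], ys[i * n:(i + 1) * n]
        # each micro-batch contribution is scaled by 1/k so the accumulated
        # sum equals the mean over all k * n samples
        grad += mse_grad(w, xb, yb) / k
    return grad


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.3, 13.9, 16.2]
w = 0.5

full = mse_grad(w, xs, ys)                # one batch of K*N = 8 samples
accum = accumulated_grad(w, xs, ys, k=4)  # K = 4 micro-batches of N = 2
print(abs(full - accum) < 1e-9)           # True
```

Up to floating-point rounding, one optimizer step on the accumulated gradient behaves like one step on a batch of K×N samples, while only N samples need to be in memory at a time.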

Gradient Clipping

Gradient clipping may be enabled to avoid exploding gradients. Specifically, this will clip the gradient norm computed over all model parameters together.

# DEFAULT (i.e., don't clip)
trainer = Trainer(gradient_clip_val=0)

# clip gradients with norm above 0.5
trainer = Trainer(gradient_clip_val=0.5)
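Clipping by a norm computed over all parameters together can be sketched without PyTorch. This assumes gradients are plain lists of floats, one list per parameter. It follows the same global-norm idea as torch.nn.utils.clip_grad_norm_, but the helper clip_grad_norm and its eps value are illustrative choices, not Lightning's code.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Rescale gradients in place so their global L2 norm is at most max_norm.

    grads: list of per-parameter gradient lists (hypothetical representation).
    Returns the total norm measured before clipping.
    """
    # a single norm over all parameters together, not one norm per parameter
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:  # only shrink, never enlarge, the gradients
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm


grads = [[3.0, 0.0], [4.0]]  # global norm = 5.0
before = clip_grad_norm(grads, max_norm=0.5)
after = math.sqrt(sum(g * g for grad in grads for g in grad))
print(before, round(after, 4))  # 5.0 0.5
```

Because every gradient is multiplied by the same factor, the update direction is preserved and only its length is capped.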

Auto scaling of batch size

Auto scaling of batch size may be enabled to find the largest batch size that fits into memory. A larger batch size often yields better gradient estimates, but may also result in longer training time. Inspired by

# DEFAULT (i.e., don't scale batch size automatically)
trainer = Trainer(auto_scale_batch_size=None)

# Autoscale batch size: set to 'power' or 'binsearch'
trainer = Trainer(auto_scale_batch_size='binsearch')

Currently, this feature supports two modes: ‘power’ scaling and ‘binsearch’ scaling. In ‘power’ scaling, starting from an initial batch size (init_val, 2 by default), the batch size is doubled until an out-of-memory (OOM) error is encountered. With ‘binsearch’, the search starts the same way, but after the first OOM error it fine-tunes the batch size with a binary search between the last successful batch size and the one that failed.
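The two search modes can be sketched in plain Python. This is a simplified illustration, not Lightning's implementation. The hypothetical predicate fits(batch_size) stands in for running a few training steps and checking for an OOM error. The function name find_batch_size is also hypothetical; its init_val and max_trials parameters only mirror those of scale_batch_size.

```python
def find_batch_size(fits, mode="power", init_val=2, max_trials=25):
    """Search for the largest batch size for which fits(batch_size) is True.

    fits: hypothetical predicate standing in for "run a few training steps
    with this batch size and report whether no OOM error occurred".
    """
    low, high = 0, None  # low: largest size that fit; high: smallest that failed
    size = init_val
    for _ in range(max_trials):
        if fits(size):
            low = size
            if high is None:
                size *= 2             # no failure yet: keep doubling
                continue
        else:
            high = size
            if mode == "power":
                break                 # 'power' stops at the first OOM
        if high - low <= 1:
            break                     # 'binsearch' has converged
        size = (low + high) // 2      # 'binsearch': try the midpoint


    return low


def fits_in_memory(bs):
    return bs <= 100  # pretend anything above 100 runs out of memory


print(find_batch_size(fits_in_memory, mode="power"))      # 64
print(find_batch_size(fits_in_memory, mode="binsearch"))  # 100
```

In this toy setting ‘power’ stops at 64, the last power of two below the limit, while ‘binsearch’ narrows the gap between 64 and 128 down to the true limit of 100.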


This feature expects a batch_size field in the hparams of your model, i.e., model.hparams.batch_size must exist; it will be overridden by the result of this algorithm. Additionally, your train_dataloader() method must read this field for the feature to work, i.e.:

from torch.utils.data import DataLoader

def train_dataloader(self):
    return DataLoader(train_dataset, batch_size=self.hparams.batch_size)


Due to these constraints, this feature does NOT work when passing dataloaders directly to .fit().
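A minimal sketch of why this constraint exists. It uses a hypothetical ToyModule in place of a LightningModule and a dict in place of a DataLoader. A loader built once keeps the batch size it was created with, while train_dataloader() re-reads hparams every time it is called.

```python
from types import SimpleNamespace


class ToyModule:
    """Hypothetical stand-in for a LightningModule (no Lightning import needed)."""

    def __init__(self, batch_size):
        self.hparams = SimpleNamespace(batch_size=batch_size)

    def train_dataloader(self):
        # Rebuilt on every call, so it always reads the current hparams value.
        # A dict stands in for torch.utils.data.DataLoader here.
        return {"batch_size": self.hparams.batch_size}


model = ToyModule(batch_size=32)
prebuilt = model.train_dataloader()  # like a loader passed directly to .fit()
model.hparams.batch_size = 256       # what the batch size finder would write
print(prebuilt["batch_size"], model.train_dataloader()["batch_size"])  # 32 256
```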

The scaling algorithm has a number of parameters that you can control by calling the trainer method .scale_batch_size yourself (see the description below).

from pytorch_lightning import Trainer

# Use default in trainer construction
trainer = Trainer()

# Invoke method (optionally pass mode, steps_per_trial, init_val, max_trials)
new_batch_size = trainer.scale_batch_size(model)

# Override old batch size
model.hparams.batch_size = new_batch_size

# Fit as normal
trainer.fit(model)

The algorithm in short works by:
  1. Dumping the current state of the model and trainer

  2. Iterating until convergence or until the maximum number of trials, max_trials (default 25), has been reached:
    • Call the trainer's fit() method. This runs steps_per_trial (default 3) training steps. Each training step can trigger an OOM error if the tensors allocated during the step (training batch, weights, gradients, etc.) have too large a memory footprint.

    • If an OOM error is encountered, decrease the batch size; otherwise increase it. How much the batch size is increased or decreased is determined by the chosen strategy.

  3. Saving the found batch size to model.hparams.batch_size

  4. Restoring the initial state of the model and trainer
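The four steps can be sketched without Lightning. Everything here is a hypothetical illustration. FakeOOM stands in for the out-of-memory RuntimeError that PyTorch raises on CUDA, and a plain dict stands in for the model and trainer state. Only the ‘power’ strategy is shown. The result is written after the restore so that restoring does not overwrite it.

```python
import copy


class FakeOOM(RuntimeError):
    """Hypothetical stand-in for a CUDA out-of-memory RuntimeError."""


def run_steps(state, batch_size, steps_per_trial=3, memory_limit=100):
    """Pretend training: each step needs `batch_size` units of memory."""
    for _ in range(steps_per_trial):
        if batch_size > memory_limit:
            raise FakeOOM(f"batch size {batch_size} does not fit")
        state["global_step"] += 1  # trials mutate the trainer state


def scale_batch_size_sketch(state, init_val=2, max_trials=25):
    saved = copy.deepcopy(state)   # 1. dump the current state
    size, best = init_val, None
    for _ in range(max_trials):    # 2. iterate until OOM or max_trials
        try:
            run_steps(state, size)
            best = size            # largest size that ran without OOM
            size *= 2              # 'power' strategy: double on success
        except FakeOOM:
            break                  # 'power' strategy: stop at first OOM
    state.clear()
    state.update(saved)            # 4. restore the initial state
    state["batch_size"] = best     # 3. keep the found batch size
    return best


state = {"global_step": 0, "batch_size": 1}
print(scale_batch_size_sketch(state), state)  # 64 {'global_step': 0, 'batch_size': 64}
```

The trial steps advance global_step, but the restore puts it back to 0, so the subsequent real fit() starts from a clean state with only the batch size changed.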

class pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin

Bases: abc.ABC

abstract fit(*args)

Warning: this is just an empty shell for code implemented in another class.

abstract get_model()

Warning: this is just an empty shell for code implemented in another class.

abstract restore(*args)

Warning: this is just an empty shell for code implemented in another class.

abstract save_checkpoint(*args)

Warning: this is just an empty shell for code implemented in another class.

scale_batch_size(model, mode='power', steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name='batch_size')

Iteratively tries to find the largest batch size for a given model that does not cause an out-of-memory (OOM) error.

Parameters

  • model (LightningModule) – Model to fit.

  • mode (str) – string setting the search mode, either ‘power’ or ‘binsearch’. If mode is ‘power’, the batch size is repeatedly multiplied by 2 until an OOM error occurs. If mode is ‘binsearch’, the batch size is also initially multiplied by 2, and after an OOM error a binary search is performed between the last successful batch size and the batch size that failed.

  • steps_per_trial (int) – number of steps to run with a given batch size. Ideally, 1 should be enough to test whether an OOM error occurs; in practice, a few are needed.

  • init_val (int) – initial batch size to start the search with

  • max_trials (int) – maximum number of batch size increases performed before the algorithm terminates

The batch size finder is not yet supported for DDP; support is coming soon.

Read the Docs v: 0.8.1