batch_size_scaling

Functions

scale_batch_size

Iteratively searches for the largest batch size for a given model that does not trigger an out-of-memory (OOM) error.

pytorch_lightning.tuner.batch_size_scaling.scale_batch_size(trainer, model, mode='power', steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name='batch_size', **fit_kwargs)[source]

Iteratively searches for the largest batch size for a given model that does not trigger an out-of-memory (OOM) error.

Parameters
  • trainer – The Trainer

  • model (LightningModule) – Model to fit.

  • mode (str) – string setting the search mode. Either 'power' or 'binsearch'. In 'power' mode the batch size is repeatedly multiplied by 2 until an OOM error occurs. In 'binsearch' mode the batch size is also initially multiplied by 2, and after the first OOM error a binary search is performed between the last successful batch size and the batch size that failed (see the sketch in the Example section below).

  • steps_per_trial (int) – number of steps to run with a given batch size. Ideally 1 should be enough to test whether an OOM error occurs, but in practice a few are needed.

  • init_val (int) – initial batch size to start the search with

  • max_trials (int) – maximum number of batch size increases performed before the search is terminated.

  • batch_arg_name (str) –

    name of the attribute that stores the batch size. It is expected that the user has provided a model or datamodule that has a hyperparameter with that name. We will look for this attribute name in the following places (illustrated in the lookup sketch in the Example section below):

    • model

    • model.hparams

    • model.datamodule

    • trainer.datamodule (the datamodule passed to the tune method)

  • **fit_kwargs – remaining arguments to be passed to .fit(), e.g., dataloader or datamodule.
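
Example

The difference between the two search modes can be outlined with a short sketch. This is an illustrative outline only, not the library's implementation: the callable run_trial is a hypothetical stand-in for the internal fitting loop and returns True when a trial with the given batch size fits in memory.

    def _search(run_trial, init_val=2, max_trials=25, mode="power"):
        # Illustrative sketch only; `run_trial` is a hypothetical helper that
        # returns True when the given batch size fits in memory.
        size = init_val
        last_good = None
        for _ in range(max_trials):
            if run_trial(size):
                last_good = size
                size *= 2  # both modes start by repeatedly doubling
            else:
                if mode == "power" or last_good is None:
                    # 'power': stop at the last size that fit (None if none did)
                    return last_good
                # 'binsearch': refine between the last success and the failure
                low, high = last_good, size
                while high - low > 1:
                    mid = (low + high) // 2
                    if run_trial(mid):
                        low = mid
                    else:
                        high = mid
                return low
        return last_good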
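
The lookup order for batch_arg_name can likewise be sketched as a simple attribute search. The helper below is hypothetical and written only to illustrate the order listed above; the actual Lightning internals may differ (for example, dict-style hparams are handled as well).

    def _find_batch_size_owner(trainer, model, batch_arg_name="batch_size"):
        # Hypothetical helper; assumes attribute-style access on each candidate.
        candidates = [
            model,                                 # 1. model
            getattr(model, "hparams", None),       # 2. model.hparams
            getattr(model, "datamodule", None),    # 3. model.datamodule
            getattr(trainer, "datamodule", None),  # 4. trainer.datamodule
        ]
        for obj in candidates:
            if obj is not None and hasattr(obj, batch_arg_name):
                return obj
        raise AttributeError(f"no attribute named {batch_arg_name!r} was found")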
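
Finally, a minimal usage sketch based on the signature documented above. ToyModel and its random dataset are placeholders, and it is assumed here that the function returns the suggested batch size and writes it back to the attribute named by batch_arg_name; verify the return value against your installed version.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl
    from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size


    class ToyModel(pl.LightningModule):
        def __init__(self, batch_size=2):
            super().__init__()
            self.batch_size = batch_size  # attribute the tuner looks for and adjusts
            self.layer = torch.nn.Linear(32, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

        def train_dataloader(self):
            dataset = TensorDataset(torch.randn(1024, 32), torch.randn(1024, 1))
            return DataLoader(dataset, batch_size=self.batch_size)


    model = ToyModel()
    trainer = pl.Trainer(max_epochs=1)

    # Search for the largest batch size that fits in memory (binary search after
    # the first OOM error), then train with the result.
    new_batch_size = scale_batch_size(trainer, model, mode="binsearch", max_trials=10)
    trainer.fit(model)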