BatchSizeFinder

class lightning.pytorch.callbacks.BatchSizeFinder(mode='power', steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name='batch_size')

Bases: Callback

Finds the largest batch size that a given model supports without triggering an out-of-memory (OOM) error.

To use it, pass it as a callback to Trainer and call trainer.{fit,validate,test,predict}. Internally, it calls the corresponding step function steps_per_trial times for each candidate batch size until one of them triggers an OOM error.
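
The trial loop described above can be illustrated with a small pure-Python simulation. This is only a sketch under stated assumptions, not Lightning's implementation: `fake_step`, `largest_power_batch_size`, and the `limit` argument are hypothetical names, and `MemoryError` stands in for a real OOM error.

```python
def fake_step(batch_size, limit=100):
    """Hypothetical stand-in for a step function: fails above `limit` samples."""
    if batch_size > limit:
        raise MemoryError(f"simulated OOM at batch size {batch_size}")


def largest_power_batch_size(step, init_val=2, steps_per_trial=3, max_trials=25):
    """Double the batch size until a trial fails; return the last size that worked.

    Simplified illustration of the 'power' strategy (assumption: every trial
    counts toward `max_trials`).
    """
    batch_size, last_ok = init_val, None
    for _ in range(max_trials):
        try:
            # Run a few steps, since a single step may not expose an OOM error.
            for _ in range(steps_per_trial):
                step(batch_size)
        except MemoryError:
            break  # this size does not fit; keep the previous one
        last_ok = batch_size
        batch_size *= 2
    return last_ok


print(largest_power_batch_size(fake_step))  # 64
```

With a simulated limit of 100 samples, the sizes 2 through 64 succeed and 128 fails, so the sketch settles on 64.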

Warning

This is an experimental feature.

Parameters:
  • mode (str) –

    search strategy to update the batch size:

    • 'power': Keep multiplying the batch size by 2 until an OOM error occurs.

    • 'binsearch': Initially keep multiplying by 2; after encountering an OOM error, do a binary search between the last successful batch size and the batch size that failed.

  • steps_per_trial (int) – number of steps to run with a given batch size. Ideally, 1 step should be enough to test whether an OOM error occurs; in practice, a few are needed.

  • init_val (int) – initial batch size to start the search with.

  • max_trials (int) – maximum number of batch size increases performed before the algorithm is terminated.

  • batch_arg_name (str) –

    name of the attribute that stores the batch size. The user is expected to provide a model or datamodule that has a hyperparameter with that name. The attribute name is looked up in the following places:

    • model

    • model.hparams

    • trainer.datamodule (the datamodule passed to the tune method)
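
The 'binsearch' strategy listed under mode can be sketched in plain Python as follows. This is a simplified illustration rather than Lightning's own code: `fake_step`, `trial_ok`, `binsearch_batch_size`, and `limit` are hypothetical names, `MemoryError` stands in for an OOM error, and every trial (not only increases) counts toward `max_trials`.

```python
def fake_step(batch_size, limit=100):
    """Hypothetical stand-in for a step function: fails above `limit` samples."""
    if batch_size > limit:
        raise MemoryError(f"simulated OOM at batch size {batch_size}")


def trial_ok(step, batch_size, steps_per_trial):
    """Return True if `steps_per_trial` steps run without a simulated OOM."""
    try:
        for _ in range(steps_per_trial):
            step(batch_size)
    except MemoryError:
        return False
    return True


def binsearch_batch_size(step, init_val=2, steps_per_trial=3, max_trials=25):
    """Double until the first failure, then bisect between success and failure."""
    low, high = None, None  # largest size known to fit, smallest known to fail
    size = init_val
    for _ in range(max_trials):
        if trial_ok(step, size, steps_per_trial):
            low = size
        else:
            high = size
        if high is None:
            size = low * 2  # no failure seen yet: keep doubling
        elif low is None:
            size = high // 2  # even the starting size failed: shrink
            if size < 1:
                break
        else:
            if high - low <= 1:
                break  # the bounds are adjacent, so `low` is the answer
            size = (low + high) // 2
    return low


print(binsearch_batch_size(fake_step))  # 100
```

Compared with pure doubling, which would stop at 64 for the same simulated limit, the bisection phase recovers the exact limit of 100 at the cost of a few extra trials.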

Example:

# 1. Customize the BatchSizeFinder callback to run at different epochs. This feature is
# useful while fine-tuning models since you can't always use the same batch size after
# unfreezing the backbone.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder


class FineTuneBatchSizeFinder(BatchSizeFinder):
    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

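    def on_fit_start(self, *args, **kwargs):
        # Skip the default search that would otherwise run once when fit begins.
        return

    def on_train_epoch_start(self, trainer, pl_module):
        # Re-run the search at epoch 0 and at each milestone epoch.
        if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
            self.scale_batch_size(trainer, pl_module)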


trainer = Trainer(callbacks=[FineTuneBatchSizeFinder(milestones=(5, 10))])
trainer.fit(...)

Example:

# 2. Run batch size finder for validate/test/predict.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder


class EvalBatchSizeFinder(BatchSizeFinder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

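    def on_fit_start(self, *args, **kwargs):
        # Disable the search during fit; this callback targets evaluation only.
        return

    def on_test_start(self, trainer, pl_module):
        # Run the search when testing begins.
        self.scale_batch_size(trainer, pl_module)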


trainer = Trainer(callbacks=[EvalBatchSizeFinder()])
trainer.test(...)

on_fit_start(trainer, pl_module)

Called when fit begins.

Return type:

None

on_predict_start(trainer, pl_module)

Called when predict begins.

Return type:

None

on_test_start(trainer, pl_module)

Called when test begins.

Return type:

None

on_validation_start(trainer, pl_module)

Called when the validation loop begins.

Return type:

None

setup(trainer, pl_module, stage=None)

Called when fit, validate, test, predict, or tune begins.

Return type:

None