BatchSizeFinder

class pytorch_lightning.callbacks.BatchSizeFinder(mode='power', steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name='batch_size')[source]

Bases: pytorch_lightning.callbacks.callback.Callback

The BatchSizeFinder callback tries to find the largest batch size for a given model that does not cause an out-of-memory (OOM) error. To use it, add it as a callback to the Trainer and call trainer.{fit,validate,test,predict}. Internally, it calls the respective step function steps_per_trial times for each candidate batch size until one of the batch sizes triggers an OOM error.
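The search described above can be illustrated with a small, library-free sketch. This is not the callback's actual implementation: `SimulatedOOM`, `run_trial`, `find_batch_size`, and the `memory_limit` argument are hypothetical names that stand in for real training steps and a real device memory limit. As a simplification, `max_trials` here caps the total number of trials. The two `mode` values follow the strategies listed under Parameters.

```python
class SimulatedOOM(RuntimeError):
    """Hypothetical stand-in for a device out-of-memory error."""


def run_trial(batch_size, steps_per_trial, memory_limit):
    """Pretend to run `steps_per_trial` steps; fail if the batch does not fit."""
    for _ in range(steps_per_trial):
        if batch_size > memory_limit:
            raise SimulatedOOM(f"batch size {batch_size} does not fit")


def find_batch_size(memory_limit, mode="power", steps_per_trial=3,
                    init_val=2, max_trials=25):
    """Return the largest batch size that passed a trial, or None if none did."""
    if mode not in ("power", "binsearch"):
        raise ValueError(f"unknown mode: {mode!r}")

    size = init_val
    best = None    # largest batch size known to fit
    failed = None  # smallest batch size known to fail

    for _ in range(max_trials):
        try:
            run_trial(size, steps_per_trial, memory_limit)
        except SimulatedOOM:
            failed = size
            # 'power' stops at the first failure; so does 'binsearch'
            # when even the initial value does not fit.
            if mode == "power" or best is None:
                break
        else:
            best = size

        if failed is None:
            size = best * 2  # no failure seen yet: keep doubling
        elif failed - best <= 1:
            break            # the binary search has converged
        else:
            size = (best + failed) // 2  # probe the midpoint

    return best


print(find_batch_size(memory_limit=100, mode="power"))      # 64
print(find_batch_size(memory_limit=100, mode="binsearch"))  # 100
```

With a simulated limit of 100, 'power' stops at the last power of two that fits, while 'binsearch' spends a few extra trials narrowing the gap between 64 and 128.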

Parameters
  • mode (str) –

    search strategy to update the batch size:

    • 'power': Keep multiplying the batch size by 2, until we get an OOM error.

    • 'binsearch': Initially keep multiplying by 2; after encountering an OOM error, do a binary search between the last successful batch size and the batch size that failed.

  • steps_per_trial (int) – number of steps to run with a given batch size. In principle, one step should be enough to reveal an OOM error, but in practice a few are needed.

  • init_val (int) – initial batch size to start the search with.

  • max_trials (int) – maximum number of batch size increases before the algorithm terminates.

  • batch_arg_name (str) –

    name of the attribute that stores the batch size. The user is expected to provide a model or datamodule that has a hyperparameter with that name. The attribute is looked up in the following places:

    • model

    • model.hparams

    • trainer.datamodule (the datamodule passed to the tune method)
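The lookup for batch_arg_name can be pictured with the following sketch. It uses plain `SimpleNamespace` objects in place of a real model and datamodule, and the helper `locate_batch_size` is a hypothetical name, not part of the library. The sketch also assumes that the first matching place in the listed order wins, which the text above does not state.

```python
from types import SimpleNamespace


def locate_batch_size(model, datamodule=None, batch_arg_name="batch_size"):
    """Return (place, value) for the first place that defines the attribute."""
    candidates = [
        ("model", model),
        ("model.hparams", getattr(model, "hparams", None)),
        ("trainer.datamodule", datamodule),
    ]
    for place, holder in candidates:
        if holder is not None and hasattr(holder, batch_arg_name):
            return place, getattr(holder, batch_arg_name)
    raise AttributeError(
        f"{batch_arg_name!r} not found on the model, its hparams, or the datamodule"
    )


# Stand-in objects (assumptions for illustration only).
model_with_hparams = SimpleNamespace(hparams=SimpleNamespace(batch_size=32))
bare_model = SimpleNamespace()
datamodule = SimpleNamespace(batch_size=16)

print(locate_batch_size(model_with_hparams))      # ('model.hparams', 32)
print(locate_batch_size(bare_model, datamodule))  # ('trainer.datamodule', 16)
```

If none of the places defines the attribute, the sketch raises an AttributeError, so a different batch_arg_name must match an attribute that actually exists.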

Example:

# 1. Customize the BatchSizeFinder callback to run at different epochs. This feature is
# useful while fine-tuning models since you can't always use the same batch size after
# unfreezing the backbone.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BatchSizeFinder


class FineTuneBatchSizeFinder(BatchSizeFinder):
    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

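    def on_fit_start(self, *args, **kwargs):
        # Skip the default search at the start of fit; the search is
        # triggered from on_train_epoch_start instead.
        return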

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
            self.scale_batch_size(trainer, pl_module)


trainer = Trainer(callbacks=[FineTuneBatchSizeFinder(milestones=(5, 10))])
trainer.fit(...)

Example:

# 2. Run batch size finder for validate/test/predict.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BatchSizeFinder


class EvalBatchSizeFinder(BatchSizeFinder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def on_fit_start(self, *args, **kwargs):
        return

    def on_test_start(self, trainer, pl_module):
        self.scale_batch_size(trainer, pl_module)


trainer = Trainer(callbacks=[EvalBatchSizeFinder()])
trainer.test(...)

on_fit_start(trainer, pl_module)[source]

Called when fit begins.

Return type

None

on_predict_start(trainer, pl_module)[source]

Called when predict begins.

Return type

None

on_test_start(trainer, pl_module)[source]

Called when test begins.

Return type

None

on_validation_start(trainer, pl_module)[source]

Called when the validation loop begins.

Return type

None

setup(trainer, pl_module, stage=None)[source]

Called when fit, validate, test, predict, or tune begins.

Return type

None
