Tuner

class pytorch_lightning.tuner.tuning.Tuner(trainer)

Bases: object

Tuner class for tuning your model's learning rate and batch size.

lr_find(model, train_dataloader=None, val_dataloaders=None, datamodule=None, min_lr=1e-08, max_lr=1, num_training=100, mode='exponential', early_stop_threshold=4.0, update_attr=False)

Runs a learning rate range test to reduce the guesswork in picking a good initial learning rate.

Parameters
  • model (LightningModule) – Model to tune.

  • train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.

  • val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloader method, this will be skipped.

  • datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.

  • min_lr (float) – minimum learning rate to investigate

  • max_lr (float) – maximum learning rate to investigate

  • num_training (int) – number of learning rates to test

  • mode (str) –

    Search strategy to update learning rate after each batch:

    • 'exponential' (default): Will increase the learning rate exponentially.

    • 'linear': Will increase the learning rate linearly.

  • early_stop_threshold (float) – threshold for stopping the search. If the loss at any point is larger than early_stop_threshold*best_loss then the search is stopped. To disable, set to None.

  • update_attr (bool) – Whether to update the learning rate attribute or not.

Raises

MisconfigurationException – If learning rate/lr in model or model.hparams isn’t overridden when auto_lr_find=True, or if you are using more than one optimizer.

Return type

Optional[_LRFinder]
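
The sweep and the early-stop rule described in the parameters above can be illustrated with a small, library-free sketch. The helper names candidate_lrs and run_range_test are hypothetical, and the interpolation formulas are an assumption about what "exponential" and "linear" mean here, not a copy of the library's internals. The losses are passed in as a plain list instead of being produced by real training steps.

```python
import math


def candidate_lrs(min_lr=1e-8, max_lr=1.0, num_training=100, mode="exponential"):
    """Return the learning rates a range test would sweep through (illustrative only)."""
    if num_training < 2:
        raise ValueError("num_training must be at least 2")
    lrs = []
    for step in range(num_training):
        frac = step / (num_training - 1)  # 0.0 at the first step, 1.0 at the last
        if mode == "exponential":
            # Assumed form: geometric interpolation between min_lr and max_lr.
            lrs.append(min_lr * (max_lr / min_lr) ** frac)
        elif mode == "linear":
            # Assumed form: evenly spaced steps between min_lr and max_lr.
            lrs.append(min_lr + frac * (max_lr - min_lr))
        else:
            raise ValueError(f"unknown mode: {mode!r}")
    return lrs


def run_range_test(losses, lrs, early_stop_threshold=4.0):
    """Pair each learning rate with its loss and stop once the loss diverges."""
    best_loss = math.inf
    tested = []
    for lr, loss in zip(lrs, losses):
        tested.append((lr, loss))
        best_loss = min(best_loss, loss)
        # Stop when the loss exceeds early_stop_threshold * best_loss;
        # passing None disables the check, as the parameter description states.
        if early_stop_threshold is not None and loss > early_stop_threshold * best_loss:
            break
    return tested
```

With the default threshold of 4.0, a loss sequence such as 1.0, 0.5, 0.4, 2.5 stops at the fourth value, because 2.5 is larger than 4.0 * 0.4.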

scale_batch_size(model, train_dataloader=None, val_dataloaders=None, datamodule=None, mode='power', steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name='batch_size')

Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) error.

Parameters
  • model (LightningModule) – Model to tune.

  • train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.

  • val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloader method, this will be skipped.

  • datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.

  • mode (str) –

    Search strategy to update the batch size:

    • 'power' (default): Keep multiplying the batch size by 2, until we get an OOM error.

    • 'binsearch': Initially keep multiplying by 2 and, after encountering an OOM error, do a binary search between the last successful batch size and the batch size that failed.

  • steps_per_trial (int) – number of steps to run with a given batch size. Ideally 1 should be enough to test whether an OOM error occurs; in practice, however, a few are needed.

  • init_val (int) – initial batch size to start the search with

  • max_trials (int) – maximum number of batch size increases before the algorithm terminates

  • batch_arg_name (str) –

    name of the attribute that stores the batch size. It is expected that the user has provided a model or datamodule that has a hyperparameter with that name. We will look for this attribute name in the following places:

    • model

    • model.hparams

    • model.datamodule

    • trainer.datamodule (the datamodule passed to the tune method)
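
The lookup described for batch_arg_name can be sketched as below. The function find_batch_size_owner is hypothetical, and two things are assumptions rather than documented behavior: that the places are checked in the order listed, and that plain attribute access is enough for each of them. Stand-in objects replace a real model and trainer.

```python
from types import SimpleNamespace


def find_batch_size_owner(model, trainer, batch_arg_name="batch_size"):
    """Return the first object that carries `batch_arg_name`, or None (illustrative only)."""
    # Candidate locations, in the order the parameter description lists them.
    candidates = [
        model,
        getattr(model, "hparams", None),
        getattr(model, "datamodule", None),
        getattr(trainer, "datamodule", None),
    ]
    for obj in candidates:
        if obj is not None and hasattr(obj, batch_arg_name):
            return obj
    return None


# Stand-ins for a model whose hparams hold the batch size and a trainer
# whose datamodule also defines one.
model = SimpleNamespace(hparams=SimpleNamespace(batch_size=32))
trainer = SimpleNamespace(datamodule=SimpleNamespace(batch_size=64))
owner = find_batch_size_owner(model, trainer)
```

Under the assumed ordering, owner here is model.hparams, since it appears earlier in the list than trainer.datamodule.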

Return type

Optional[int]
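
The two search strategies can be compared with a minimal sketch that uses no GPU and no library calls. The callable fits is a hypothetical stand-in for "ran steps_per_trial steps without an OOM error", and scale_batch_size_sketch is an illustrative name. How the real method counts trials or handles a failing initial value is not specified above, so the choices made here are assumptions.

```python
def scale_batch_size_sketch(fits, init_val=2, max_trials=25, mode="power"):
    """Return the largest batch size for which `fits(batch_size)` is True, or None."""
    if mode not in ("power", "binsearch"):
        raise ValueError(f"unknown mode: {mode!r}")

    size = init_val
    last_ok = None    # largest batch size that succeeded so far
    first_bad = None  # smallest batch size that failed so far
    trials = 0

    # Doubling phase, shared by both modes: multiply by 2 until a failure.
    while trials < max_trials:
        trials += 1
        if fits(size):
            last_ok = size
            size *= 2
        else:
            first_bad = size
            break

    # 'binsearch' only: bisect between the last success and the first failure.
    if mode == "binsearch" and last_ok is not None and first_bad is not None:
        low, high = last_ok, first_bad
        while high - low > 1 and trials < max_trials:
            trials += 1
            mid = (low + high) // 2
            if fits(mid):
                low = mid
            else:
                high = mid
        last_ok = low

    return last_ok


def fits_up_to_100(batch_size):
    """Pretend that any batch size above 100 runs out of memory."""
    return batch_size <= 100


power_result = scale_batch_size_sketch(fits_up_to_100, mode="power")
binsearch_result = scale_batch_size_sketch(fits_up_to_100, mode="binsearch")
```

In this toy setting 'power' settles on 64, the last power of two that fits, while 'binsearch' refines the answer to 100 at the cost of a few extra trials.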