lr_finder

Functions

lr_find

lr_find runs a learning rate range test to help identify a good initial learning rate, reducing the guesswork in choosing one.

pytorch_lightning.tuner.lr_finder.lr_find(trainer, model, train_dataloader=None, val_dataloaders=None, min_lr=1e-08, max_lr=1, num_training=100, mode='exponential', early_stop_threshold=4.0, datamodule=None, update_attr=False)

Parameters
  • model (LightningModule) – Model to do range testing for

  • train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.

  • min_lr (float) – minimum learning rate to investigate

  • max_lr (float) – maximum learning rate to investigate

  • num_training (int) – number of learning rates to test

  • mode (str) –

    Search strategy to update learning rate after each batch:

    • 'exponential' (default): Will increase the learning rate exponentially.

    • 'linear': Will increase the learning rate linearly.

  • early_stop_threshold (float) – threshold for stopping the search. If the loss at any point is larger than early_stop_threshold*best_loss then the search is stopped. To disable, set to None.

  • datamodule (Optional[LightningDataModule]) – An optional LightningDataModule which holds the training and validation dataloader(s). Note that the train_dataloader and val_dataloaders parameters cannot be used at the same time as this parameter, or a MisconfigurationException will be raised.

  • update_attr (bool) – Whether to update the model's learning rate attribute with the suggested learning rate.
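The interaction of min_lr, max_lr, num_training, mode and early_stop_threshold can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the helper names candidate_learning_rates and should_stop are hypothetical, and the even spacing of steps between min_lr and max_lr is an assumption (the library's exact step indexing may differ).

```python
import math


def candidate_learning_rates(min_lr=1e-8, max_lr=1.0, num_training=100,
                             mode="exponential"):
    """Return num_training learning rates between min_lr and max_lr.

    Hypothetical helper: assumes evenly spaced steps, which may not match
    the library's exact indexing.
    """
    if mode not in ("exponential", "linear"):
        raise ValueError("mode must be 'exponential' or 'linear'")
    if num_training < 2:
        raise ValueError("num_training must be at least 2")

    rates = []
    for step in range(num_training):
        r = step / (num_training - 1)  # fraction of the search completed
        if mode == "exponential":
            # Geometric interpolation: equal ratios between consecutive rates
            lr = min_lr * (max_lr / min_lr) ** r
        else:
            # Arithmetic interpolation: equal differences between rates
            lr = min_lr + r * (max_lr - min_lr)
        rates.append(lr)
    return rates


def should_stop(current_loss, best_loss, early_stop_threshold=4.0):
    """Apply the documented rule: stop once loss > threshold * best_loss."""
    if early_stop_threshold is None:  # None disables early stopping
        return False
    return current_loss > early_stop_threshold * best_loss


exp_rates = candidate_learning_rates(num_training=5)
lin_rates = candidate_learning_rates(num_training=5, mode="linear")
```

With five steps, the exponential mode covers each order of magnitude pair evenly (roughly 1e-8, 1e-6, 1e-4, 1e-2, 1), while the linear mode spends almost all of its steps above 0.25. That difference is a plausible reason for 'exponential' being the default when the search range spans many orders of magnitude.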

Example:

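# Setup model and trainer
# (MyModelClass and hparams are placeholders for your own LightningModule and its hyperparameters)
import pytorch_lightning as pl

model = MyModelClass(hparams)
trainer = pl.Trainer()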

# Run lr finder
lr_finder = trainer.tuner.lr_find(model, ...)

# Inspect results
fig = lr_finder.plot()
fig.show()
suggested_lr = lr_finder.suggestion()

# Overwrite lr and create new model
hparams.lr = suggested_lr
model = MyModelClass(hparams)

# Ready to train with new learning rate
trainer.fit(model)
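
The example relies on lr_finder.suggestion() without describing how a suggestion is chosen. One common heuristic for range tests is to pick the learning rate at which the loss falls most steeply. The sketch below illustrates that heuristic only. The function suggest_lr and its skip_begin and skip_end parameters are hypothetical names, and this page does not confirm that suggestion() works exactly this way.

```python
def suggest_lr(lrs, losses, skip_begin=10, skip_end=1):
    """Return the learning rate where the loss decreases fastest.

    Hypothetical helper: skip_begin and skip_end drop the noisy first and
    last points of the range test before looking for the steepest drop.
    """
    if len(lrs) != len(losses):
        raise ValueError("lrs and losses must have the same length")
    end = len(losses) - skip_end
    if end - skip_begin < 2:
        raise ValueError("not enough points left after skipping")

    best_idx = None
    best_slope = None
    for i in range(skip_begin, end - 1):
        # Forward difference between consecutive recorded losses
        slope = losses[i + 1] - losses[i]
        if best_slope is None or slope < best_slope:
            best_idx, best_slope = i, slope
    return lrs[best_idx]


# Toy range-test results: the loss improves, then diverges at high rates
lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0]
losses = [2.0, 1.9, 1.5, 0.8, 0.9, 5.0]
picked = suggest_lr(lrs, losses, skip_begin=0, skip_end=1)
```

In this toy data the steepest drop is between 1e-3 and 1e-2, so the sketch returns 1e-3, the rate at the start of that drop, not the rate with the lowest loss.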