lr_finder

Functions

lr_find

lr_find runs a learning rate range test to help the user choose a good initial learning rate, reducing the guesswork involved.

pytorch_lightning.tuner.lr_finder.lr_find(trainer, model, train_dataloader=None, val_dataloaders=None, min_lr=1e-08, max_lr=1, num_training=100, mode='exponential', early_stop_threshold=4.0, datamodule=None, update_attr=False)

Parameters
  • model (LightningModule) – Model to do range testing for

  • train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.

  • min_lr (float) – minimum learning rate to investigate

  • max_lr (float) – maximum learning rate to investigate

  • num_training (int) – number of learning rates to test

  • mode (str) –

    Search strategy to update learning rate after each batch:

    • 'exponential' (default): Will increase the learning rate exponentially.

    • 'linear': Will increase the learning rate linearly.

  • early_stop_threshold (float) – threshold for stopping the search. The search stops as soon as the loss exceeds early_stop_threshold * best_loss. To disable, set to None.

  • datamodule (Optional[LightningDataModule]) – An optional LightningDataModule which holds the training and validation dataloader(s). Note that the train_dataloader and val_dataloaders parameters cannot be used at the same time as this parameter, or a MisconfigurationException will be raised.

  • update_attr (bool) – Whether to update the learning rate attribute or not.
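
To make the two mode options concrete, here is a minimal sketch of how the tested learning rates could be spaced between min_lr and max_lr. The function name candidate_learning_rates and the exact interpolation formulas are assumptions for illustration, not the library's internal scheduler.

```python
def candidate_learning_rates(min_lr=1e-8, max_lr=1.0, num_training=100, mode="exponential"):
    """Return one learning rate per training step (illustrative sketch only).

    Assumption: the rate is interpolated from min_lr towards max_lr and
    reaches max_lr on the final step. The real scheduler may differ in detail.
    """
    if num_training < 1:
        raise ValueError("num_training must be at least 1")

    rates = []
    for step in range(1, num_training + 1):
        fraction = step / num_training  # progress through the search, in (0, 1]
        if mode == "exponential":
            # Geometric interpolation: each step multiplies the rate by a constant factor.
            lr = min_lr * (max_lr / min_lr) ** fraction
        elif mode == "linear":
            # Arithmetic interpolation: each step adds a constant increment.
            lr = min_lr + fraction * (max_lr - min_lr)
        else:
            raise ValueError(f"unknown mode: {mode!r}")
        rates.append(lr)
    return rates


exponential_rates = candidate_learning_rates(mode="exponential")
linear_rates = candidate_learning_rates(mode="linear")
```

Under these assumptions, the exponential mode spends most of its steps on small learning rates and covers many orders of magnitude, while the linear mode spreads its steps evenly and is better suited to a narrow range.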

Raises

MisconfigurationException – If learning rate/lr in model or model.hparams isn’t overridden when auto_lr_find=True, or if you are using more than one optimizer with learning rate finder.

Example:

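import pytorch_lightning as pl

# Set up model and trainer.
# MyModelClass and hparams are placeholders for your own LightningModule
# and its hyperparameters (hparams must expose an lr attribute).
model = MyModelClass(hparams)
trainer = pl.Trainer()

# Run the learning rate finder
lr_finder = trainer.tuner.lr_find(model, ...)

# Inspect results
fig = lr_finder.plot()
fig.show()
suggested_lr = lr_finder.suggestion()

# Overwrite lr and create a new model
hparams.lr = suggested_lr
model = MyModelClass(hparams)

# Ready to train with the new learning rate
trainer.fit(model)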
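
The stopping rule described for early_stop_threshold can be illustrated without the library. The sketch below is a simplified stand-in, not the actual implementation: range_test and toy_loss are hypothetical names, the loss is a made-up function of the learning rate rather than a real training loss, and losses are assumed to be positive and compared without any smoothing.

```python
import math


def toy_loss(lr):
    """Made-up loss curve with its minimum at lr = 1e-3 (illustration only)."""
    return (math.log10(lr) + 3) ** 2 + 0.1


def range_test(loss_at, learning_rates, early_stop_threshold=4.0):
    """Evaluate loss_at for each rate, stopping early once the loss diverges.

    Returns the list of (learning_rate, loss) pairs that were evaluated.
    """
    best_loss = float("inf")
    history = []
    for lr in learning_rates:
        loss = loss_at(lr)
        history.append((lr, loss))
        # Stop once the loss exceeds early_stop_threshold * best_loss;
        # passing None disables the check.
        if early_stop_threshold is not None and loss > early_stop_threshold * best_loss:
            break
        best_loss = min(best_loss, loss)
    return history


rates_to_try = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0]
with_early_stop = range_test(toy_loss, rates_to_try)
without_early_stop = range_test(toy_loss, rates_to_try, early_stop_threshold=None)
```

With the default threshold the loop stops shortly after the loss starts climbing past the minimum, so the largest rates are never evaluated. With early_stop_threshold=None every rate in the list is tried.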