lr_finder
Functions
lr_find enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate.
pytorch_lightning.tuner.lr_finder.lr_find(trainer, model, train_dataloader=None, val_dataloaders=None, min_lr=1e-08, max_lr=1, num_training=100, mode='exponential', early_stop_threshold=4.0, datamodule=None)
lr_find enables the user to run a range test over candidate initial learning rates, reducing the guesswork in picking a good starting learning rate.
- Parameters
  - model (LightningModule) – Model to do range testing for.
  - train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.
  - mode (str) – Search strategy, either ‘linear’ or ‘exponential’. If set to ‘linear’, the learning rate is increased linearly after each batch. If set to ‘exponential’, the learning rate is increased exponentially.
  - early_stop_threshold (float) – Threshold for stopping the search. If the loss at any point is larger than early_stop_threshold*best_loss, the search is stopped. To disable, set to None.
  - datamodule (Optional[LightningDataModule]) – An optional LightningDataModule which holds the training and validation dataloader(s). Note that the train_dataloader and val_dataloaders parameters cannot be used at the same time as this parameter, or a MisconfigurationException will be raised.
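The mode and early_stop_threshold parameters above can be illustrated with a small standalone sketch. It is not the library's implementation: lr_schedule and should_stop are hypothetical helpers, and the exact interpolation (min_lr at the first step, max_lr at the last) is an assumption made for the illustration. Only the linear-versus-exponential growth and the early_stop_threshold*best_loss rule are taken from the parameter descriptions.

```python
def lr_schedule(min_lr, max_lr, num_training, mode="exponential"):
    """Return the learning rate to try at each of num_training steps.

    Hypothetical helper: the endpoints and interpolation are assumptions,
    not the behaviour of pytorch_lightning itself.
    """
    if mode not in ("linear", "exponential"):
        raise ValueError("mode must be 'linear' or 'exponential'")
    lrs = []
    for step in range(num_training):
        # Fraction of the sweep completed, from 0.0 to 1.0 inclusive.
        r = step / (num_training - 1) if num_training > 1 else 0.0
        if mode == "linear":
            # Equal additive increments between min_lr and max_lr.
            lrs.append(min_lr + r * (max_lr - min_lr))
        else:
            # Equal multiplicative increments between min_lr and max_lr.
            lrs.append(min_lr * (max_lr / min_lr) ** r)
    return lrs


def should_stop(loss, best_loss, early_stop_threshold=4.0):
    """Apply the documented rule: stop once loss > threshold * best_loss."""
    if early_stop_threshold is None:
        # None disables early stopping, as described for the parameter.
        return False
    return loss > early_stop_threshold * best_loss


if __name__ == "__main__":
    print(lr_schedule(1e-8, 1.0, 5, mode="exponential"))
    print(lr_schedule(0.0, 1.0, 5, mode="linear"))
    print(should_stop(loss=5.0, best_loss=1.0))
```

With the default range of 1e-08 to 1, an exponential sweep spends the same number of steps on each order of magnitude, whereas a linear sweep spends almost all of its steps on the largest learning rates.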
Example:
    import pytorch_lightning as pl

    # Set up the model and trainer (MyModelClass and hparams are user-defined)
    model = MyModelClass(hparams)
    trainer = pl.Trainer()

    # Run the learning rate finder
    lr_finder = trainer.tuner.lr_find(model, ...)

    # Inspect the results
    fig = lr_finder.plot()
    fig.show()
    suggested_lr = lr_finder.suggestion()

    # Overwrite lr and create a new model
    hparams.lr = suggested_lr
    model = MyModelClass(hparams)

    # Ready to train with the new learning rate
    trainer.fit(model)
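The example reads a recommendation from lr_finder.suggestion(), but this page does not say how that value is chosen. One common heuristic for range tests is to take the learning rate at which the recorded loss falls most steeply. The sketch below illustrates that heuristic only; suggest_lr is a hypothetical helper, the input lists are made-up values, and none of it should be read as the library's own logic.

```python
def suggest_lr(lrs, losses):
    """Return the learning rate at the steepest drop in loss.

    Hypothetical helper illustrating a common range-test heuristic;
    it is an assumption, not the behaviour of lr_finder.suggestion().
    """
    if len(lrs) != len(losses):
        raise ValueError("lrs and losses must have the same length")
    if len(lrs) < 2:
        raise ValueError("need at least two recorded steps")
    # Index of the step whose loss change from the previous step is most negative.
    steepest_idx = min(
        range(1, len(losses)),
        key=lambda i: losses[i] - losses[i - 1],
    )
    return lrs[steepest_idx]


if __name__ == "__main__":
    # Made-up sweep: loss improves, then diverges at the largest rate.
    lrs = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]
    losses = [2.0, 1.9, 1.2, 1.0, 5.0]
    print(suggest_lr(lrs, losses))
```

In practice the recorded losses are noisy, so it is worth checking the plot from the example before trusting a single suggested value.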