lr_finder¶
Functions
pytorch_lightning.tuner.lr_finder.lr_find(trainer, model, train_dataloader=None, val_dataloaders=None, min_lr=1e-08, max_lr=1, num_training=100, mode='exponential', early_stop_threshold=4.0, datamodule=None, update_attr=False)[source]¶

lr_find enables the user to run a range test over candidate initial learning rates, reducing the guesswork in picking a good starting learning rate.

Parameters
- model (LightningModule) – Model to run the range test for.
- train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.
- min_lr (float) – Minimum learning rate to investigate.
- max_lr (float) – Maximum learning rate to investigate.
- num_training (int) – Number of learning rates to test.
- mode (str) – Search strategy for updating the learning rate after each batch (a short sketch of both schedules follows this list):
  - 'exponential' (default): increases the learning rate exponentially.
  - 'linear': increases the learning rate linearly.
- early_stop_threshold (float) – Threshold for stopping the search. If the loss at any point is larger than early_stop_threshold * best_loss, the search is stopped. To disable, set to None.
- datamodule (Optional[LightningDataModule]) – An optional LightningDataModule which holds the training and validation dataloader(s). Note that the train_dataloader and val_dataloaders parameters cannot be used at the same time as this parameter, or a MisconfigurationException will be raised.
- update_attr (bool) – Whether to update the model's learning rate attribute with the suggested value or not.
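To make the two search strategies concrete, below is a minimal sketch (not the library's internal implementation) of the learning rates each mode sweeps through, using the default min_lr, max_lr, and num_training values:

    import numpy as np

    min_lr, max_lr, num_training = 1e-8, 1.0, 100

    # 'exponential' (default): points spaced evenly on a log scale, so small
    # learning rates are sampled densely and large ones sparsely
    exponential_lrs = np.logspace(np.log10(min_lr), np.log10(max_lr), num_training)

    # 'linear': points spaced evenly on a linear scale between min_lr and max_lr
    linear_lrs = np.linspace(min_lr, max_lr, num_training)

    # early_stop_threshold aborts the sweep once the loss blows up, i.e. when
    # loss > early_stop_threshold * best_loss_seen_so_far

The exponential schedule is usually preferable because useful learning rates tend to span several orders of magnitude.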
Example:

    # Setup model and trainer
    model = MyModelClass(hparams)
    trainer = pl.Trainer()

    # Run lr finder
    lr_finder = trainer.tuner.lr_find(model, ...)

    # Inspect results
    fig = lr_finder.plot()
    fig.show()
    suggested_lr = lr_finder.suggestion()

    # Overwrite lr and create new model
    hparams.lr = suggested_lr
    model = MyModelClass(hparams)

    # Ready to train with new learning rate
    trainer.fit(model)
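As a follow-up, here is a sketch of the two alternative call patterns described in the parameter list. MyDataModule is a hypothetical LightningDataModule subclass, not part of the library, and the sketch assumes update_attr=True writes the suggestion back to the model's learning rate attribute, as its description suggests:

    # Hypothetical LightningDataModule subclass holding the dataloaders
    dm = MyDataModule()

    # Pass dataloaders via the datamodule parameter instead of train_dataloader
    # (using both at once raises a MisconfigurationException)
    lr_finder = trainer.tuner.lr_find(model, datamodule=dm)

    # With update_attr=True the suggested learning rate is written back to the
    # model's learning rate attribute, so no new model has to be constructed
    trainer.tuner.lr_find(model, datamodule=dm, update_attr=True)
    trainer.fit(model, datamodule=dm)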