pytorch_lightning.trainer.lr_finder module

Trainer Learning Rate Finder

class pytorch_lightning.trainer.lr_finder.TrainerLRFinderMixin[source]

Bases: abc.ABC

_TrainerLRFinderMixin__lr_finder_dump_params(model)[source]
_TrainerLRFinderMixin__lr_finder_restore_params(model)[source]
_run_lr_finder_internally(model)[source]

Call lr finder internally during Trainer.fit()
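
As a usage sketch (assuming this Trainer version exposes the auto_lr_find flag, which is how this internal hook is normally triggered), the finder can be run automatically at the start of fitting:

# Minimal sketch: with auto_lr_find enabled, the Trainer runs the lr finder
# internally before training and writes the suggested value back to the
# model's hparams.lr (or to the attribute named by auto_lr_find).
import pytorch_lightning as pl

model = MyModelClass(hparams)   # hypothetical LightningModule, as in the example below
trainer = pl.Trainer(auto_lr_find=True)
trainer.fit(model)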

abstract fit(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

abstract init_optimizers(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

Return type

Tuple[List, List, List]

lr_find(model, train_dataloader=None, val_dataloaders=None, min_lr=1e-08, max_lr=1, num_training=100, mode='exponential', early_stop_threshold=4.0, num_accumulation_steps=None)[source]

lr_find enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate.

Parameters
  • model (LightningModule) – Model to do range testing for

  • train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.

  • val_dataloaders (Optional[DataLoader]) – A PyTorch DataLoader with validation samples. If the model has a predefined val_dataloaders method, this will be skipped.

  • min_lr (float) – minimum learning rate to investigate

  • max_lr (float) – maximum learning rate to investigate

  • num_training (int) – number of learning rates to test

  • mode (str) – search strategy, either ‘linear’ or ‘exponential’. If set to ‘linear’, the learning rate is increased linearly after each batch; if set to ‘exponential’, it is increased exponentially.

  • early_stop_threshold (float) – threshold for stopping the search. If the loss at any point is larger than early_stop_threshold*best_loss then the search is stopped. To disable, set to None.

  • num_accumulation_steps – deprecated, number of batches to calculate loss over. Set the trainer argument accumulate_grad_batches instead.

Example:

import pytorch_lightning as pl

# Setup model and trainer
model = MyModelClass(hparams)
trainer = pl.Trainer()

# Run lr finder
lr_finder = trainer.lr_find(model, ...)

# Inspect results
fig = lr_finder.plot(); fig.show()
suggested_lr = lr_finder.suggestion()

# Overwrite lr and create new model
hparams.lr = suggested_lr
model = MyModelClass(hparams)

# Ready to train with new learning rate
trainer.fit(model)
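
The search range and strategy can also be set explicitly through the keyword arguments documented above; a hedged variation of the call from the example (values are illustrative):

lr_finder = trainer.lr_find(
    model,
    min_lr=1e-6,
    max_lr=1e-1,
    num_training=100,
    mode='exponential',
    early_stop_threshold=4.0,
)
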
abstract restore(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

abstract save_checkpoint(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

default_root_dir: str = None[source]
global_step: int = None[source]
on_gpu: bool = None[source]
progress_bar_callback: ... = None[source]
total_batch_idx: int = None[source]
class pytorch_lightning.trainer.lr_finder._ExponentialLR(optimizer, end_lr, num_iter, last_epoch=-1)[source]

Bases: torch.optim.lr_scheduler._LRScheduler

Exponentially increases the learning rate between two boundaries over a number of iterations.

Parameters
  • optimizer (Optimizer) – wrapped optimizer.

  • end_lr (float) – the final learning rate.

  • num_iter (int) – the number of iterations over which the test occurs.

  • last_epoch (int) – the index of last epoch. Default: -1.
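
The exact schedule is not spelled out above; a minimal sketch of the assumed geometric ramp (an illustration, not the library's code), where the lr starts at the optimizer's base lr and reaches end_lr after num_iter steps:

def exponential_lr(base_lr, end_lr, num_iter, step):
    """Assumed exponential ramp: base_lr at step 0, end_lr at step num_iter,
    geometric interpolation in between."""
    r = step / num_iter
    return base_lr * (end_lr / base_lr) ** r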

get_lr()[source]
base_lrs: Sequence = None[source]
last_epoch: int = None[source]
property lr[source]
class pytorch_lightning.trainer.lr_finder._LRCallback(num_training, early_stop_threshold=4.0, progress_bar_refresh_rate=0, beta=0.98)[source]

Bases: pytorch_lightning.callbacks.base.Callback

Special callback used by the learning rate finder. This callback logs the learning rate before each batch and the corresponding loss after each batch.

Parameters
  • num_training (int) – number of iterations done by the learning rate finder

  • early_stop_threshold (float) – threshold for stopping the search. If the loss at any point is larger than early_stop_threshold*best_loss then the search is stopped. To disable, set to None.

  • progress_bar_refresh_rate (int) – rate to refresh the progress bar for the learning rate finder

  • beta (float) – smoothing value; the loss being logged is a running average of the loss values logged so far. beta controls the forget rate, i.e. if beta=0 all past information is ignored (a sketch of this smoothing follows the parameter list).
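
A minimal sketch of the running-average smoothing described for beta (the bias correction term is an assumption, not taken from the documentation above):

def smooth_losses(raw_losses, beta=0.98):
    """Exponential running average of the raw losses; beta controls how
    quickly older values are forgotten (beta=0 keeps only the newest loss)."""
    avg, smoothed = 0.0, []
    for step, loss in enumerate(raw_losses, start=1):
        avg = beta * avg + (1 - beta) * loss
        smoothed.append(avg / (1 - beta ** step))  # assumed bias correction
    return smoothed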

on_batch_end(trainer, pl_module)[source]

Called when the training batch ends, logs the calculated loss

on_batch_start(trainer, pl_module)[source]

Called before each training batch, logs the lr that will be used

class pytorch_lightning.trainer.lr_finder._LRFinder(mode, lr_min, lr_max, num_training)[source]

Bases: object

LR finder object. This object stores the results of Trainer.lr_find().

Parameters
  • mode (str) – either linear or exponential, how to increase lr after each step

  • lr_min (float) – lr to start search from

  • lr_max (float) – lr to stop search

  • num_training (int) – number of steps to take between lr_min and lr_max

Example:

# Run lr finder
lr_finder = trainer.lr_find(model)

# Results stored in
lr_finder.results

# Plot using
lr_finder.plot()

# Get suggestion
lr = lr_finder.suggestion()
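
The results attribute referenced above is assumed (from how plot() and suggestion() consume it) to hold parallel lists of the learning rates tried and the losses recorded for them; a hedged access example:

# Assumed structure of lr_finder.results: {'lr': [...], 'loss': [...]}
lrs = lr_finder.results['lr']
losses = lr_finder.results['loss']
for lr, loss in zip(lrs[:5], losses[:5]):
    print(f'lr={lr:.2e}  loss={loss:.4f}')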

_get_new_optimizer(optimizer)[source]

Construct a new configure_optimizers() method that has an optimizer with its initial lr set to lr_min and a scheduler that will either linearly or exponentially increase the lr to lr_max over num_training steps (a sketch follows the parameters).

Parameters

optimizer (Optimizer) – instance of torch.optim.Optimizer
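
A minimal sketch of what the patched method could look like, assuming the scheduler is stepped once per batch; the helper name and structure are illustrative, not the library's exact code:

def make_configure_optimizers(lr_finder, optimizer):
    """Hypothetical helper: return a configure_optimizers() replacement that
    starts at lr_min and ramps to lr_max over num_training steps."""
    sched_cls = _LinearLR if lr_finder.mode == 'linear' else _ExponentialLR

    def configure_optimizers():
        # reset the wrapped optimizer to the lower bound of the search
        for group in optimizer.param_groups:
            group['lr'] = lr_finder.lr_min
        scheduler = sched_cls(optimizer, lr_finder.lr_max, lr_finder.num_training)
        # 'interval': 'step' asks Lightning to step the scheduler every batch
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    return configure_optimizers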

plot(suggest=False, show=False)[source]

Plot results from lr_find run.

Parameters
  • suggest (bool) – if True, will mark suggested lr to use with a red point

  • show (bool) – if True, will show figure
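
For instance, both flags can be combined (continuing the earlier example):

fig = lr_finder.plot(suggest=True, show=True)  # red point marks the suggested lr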

suggestion(skip_begin=10, skip_end=1)[source]

This proposes a suggested initial learning rate, chosen as the point with the steepest negative gradient of the loss curve (a numpy sketch follows the return description).

Parameters
  • skip_begin (int) – how many samples to skip at the beginning; prevents too naive estimates

  • skip_end (int) – how many samples to skip at the end; prevents too optimistic estimates

Returns

suggested initial learning rate to use

Return type

lr
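
The steepest-gradient criterion is not shown in code here; a rough numpy sketch of the assumed computation over the recorded sweep:

import numpy as np

def suggest_lr(lrs, losses, skip_begin=10, skip_end=1):
    """Pick the lr where the recorded loss falls fastest, ignoring the
    noisy ends of the sweep."""
    losses = np.array(losses[skip_begin:-skip_end])
    lrs = np.array(lrs[skip_begin:-skip_end])
    steepest = np.gradient(losses).argmin()  # index of the most negative slope
    return lrs[steepest]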

class pytorch_lightning.trainer.lr_finder._LinearLR(optimizer, end_lr, num_iter, last_epoch=-1)[source]

Bases: torch.optim.lr_scheduler._LRScheduler

Linearly increases the learning rate between two boundaries over a number of iterations.

Parameters
  • optimizer (Optimizer) – wrapped optimizer.

  • end_lr (float) – the final learning rate.

  • num_iter (int) – the number of iterations over which the test occurs.

  • last_epoch (int) – the index of last epoch. Default: -1.
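
As with the exponential variant above, the update rule is not spelled out here; a sketch of the assumed linear interpolation between the optimizer's base lr and end_lr:

def linear_lr(base_lr, end_lr, num_iter, step):
    """Assumed linear ramp: base_lr at step 0, end_lr after num_iter steps."""
    r = step / num_iter
    return base_lr + r * (end_lr - base_lr)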

get_lr()[source]
base_lrs: Sequence = None[source]
last_epoch: int = None[source]
property lr[source]
pytorch_lightning.trainer.lr_finder._nested_hasattr(obj, path)[source]
pytorch_lightning.trainer.lr_finder._nested_setattr(obj, path, val)[source]