Optimization
Lightning offers two modes for managing the optimization process:
automatic optimization (AutoOpt)
manual optimization
For most research cases, automatic optimization does the right thing for you, and it is what most users should use.
Advanced users who need esoteric optimization schedules or techniques should use manual optimization.
Manual optimization
For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable to manually manage the optimization process. To do so, do the following:
Disable automatic optimization in Trainer: Trainer(automatic_optimization=False)
Drop or ignore the optimizer_idx argument
Use self.manual_backward(loss, optimizer) instead of loss.backward() so that Lightning can scale the loss for you (for example, when using AMP)
def training_step(self, batch, batch_idx, optimizer_idx):
    # 1. ignore optimizer_idx
    # 2. `use_pl_optimizer=True` means `opt_g` and `opt_d` will be of type `LightningOptimizer`.
    #    `LightningOptimizer` simply wraps your optimizer and behaves the same way!
    #    When calling `optimizer.step`, `LightningOptimizer` handles TPU, AMP,
    #    accumulate_grad_batches, etc. for you.
    # Access the raw optimizers with `use_pl_optimizer=False`, or with `optimizer.optimizer`
    # when using `use_pl_optimizer=True` (the default).
    (opt_g, opt_d) = self.optimizers(use_pl_optimizer=True)

    # do anything you want
    loss_a = ...

    # use self.manual_backward, which also handles scaling the loss when using AMP
    self.manual_backward(loss_a, opt_g)
    opt_g.step()

    # do anything you want
    loss_b = ...

    # pass in any args that loss.backward() normally takes;
    # retain_graph=True keeps the graph alive for the second backward call
    self.manual_backward(loss_b, opt_d, retain_graph=True)
    self.manual_backward(loss_b, opt_d)
    opt_d.step()

    # log losses
    self.log('loss_a', loss_a)
    self.log('loss_b', loss_b)
Note
This is only recommended for experts who need ultimate flexibility.
Manual optimization does not yet support accumulated gradients, but support will be available in 1.1.0.
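Loss scaling, which self.manual_backward applies for you when AMP is enabled, multiplies the loss by a large factor before the backward pass so that small float16 gradients do not underflow, then divides the gradients by the same factor before the optimizer step. The pure-Python sketch below shows only that arithmetic, under stated assumptions: a hand-written gradient for a one-parameter squared-error loss stands in for autograd, and SCALE is a hypothetical fixed factor, whereas real AMP uses a dynamic scaler on tensors. It is not Lightning's implementation.

```python
# Pure-Python sketch of static loss scaling. Assumption: real AMP uses a
# dynamic gradient scaler on tensors; only the arithmetic is shown here.

SCALE = 2.0 ** 16  # hypothetical fixed loss-scale factor


def loss_fn(w, x, y):
    """Squared error of a one-parameter linear model: 0.5 * (w * x - y) ** 2."""
    return 0.5 * (w * x - y) ** 2


def grad_fn(w, x, y, loss_scale=1.0):
    """d(loss_scale * loss_fn)/dw, i.e. what a backward pass on the scaled loss yields."""
    return loss_scale * (w * x - y) * x


def plain_sgd_step(w, x, y, lr):
    """One SGD update using the unscaled gradient."""
    return w - lr * grad_fn(w, x, y)


def scaled_sgd_step(w, x, y, lr, loss_scale=SCALE):
    """One SGD update with loss scaling: scale, 'backward', unscale, then step."""
    scaled_grad = grad_fn(w, x, y, loss_scale)  # backward on the scaled loss
    grad = scaled_grad / loss_scale              # unscale before the optimizer step
    return w - lr * grad


w_plain = plain_sgd_step(0.5, 2.0, 3.0, lr=0.25)
w_scaled = scaled_sgd_step(0.5, 2.0, 3.0, lr=0.25)
```

Because the gradient is unscaled before the update, both paths produce the same parameter value, which is why the scaling can stay invisible to your training_step.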
Automatic optimization
With Lightning, most users don't have to think about when to call .backward(), .step(), or .zero_grad(), since Lightning automates that for you.
Under the hood Lightning does the following:
for epoch in epochs:
    for batch_idx, batch in enumerate(data):
        loss = model.training_step(batch, batch_idx, ...)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    for scheduler in schedulers:
        scheduler.step()
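To make the ordering above concrete, here is a self-contained, pure-Python sketch of the same loop for a single scalar parameter. The ToySGD and HalvingScheduler classes are hypothetical stand-ins for a torch optimizer and an epoch-level learning rate scheduler, and the gradient is computed by hand in place of loss.backward().

```python
class ToySGD:
    """Hypothetical stand-in for a torch optimizer over one scalar parameter."""

    def __init__(self, lr):
        self.lr = lr
        self.w = 0.0     # the parameter
        self.grad = 0.0  # accumulated gradient, like param.grad

    def step(self):
        self.w -= self.lr * self.grad

    def zero_grad(self):
        self.grad = 0.0


class HalvingScheduler:
    """Hypothetical epoch-level scheduler: halves the learning rate on each step()."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self):
        self.optimizer.lr *= 0.5


def training_step(w, target):
    """Return the loss (w - target) ** 2 and its hand-computed gradient."""
    return (w - target) ** 2, 2.0 * (w - target)


def fit(data, epochs, lr):
    optimizer = ToySGD(lr)
    schedulers = [HalvingScheduler(optimizer)]
    for epoch in range(epochs):
        for batch_idx, target in enumerate(data):
            loss, grad = training_step(optimizer.w, target)
            optimizer.grad += grad  # stands in for loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        for scheduler in schedulers:  # stepped once per epoch
            scheduler.step()
    return optimizer


optimizer = fit([3.0] * 4, epochs=3, lr=0.25)
```

After three epochs the parameter has moved close to the target of 3.0, the learning rate has been halved three times, and the gradient buffer is empty because zero_grad() runs after every step.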
In the case of multiple optimizers, Lightning does the following:
for epoch in epochs:
    for batch in data:
        for opt in optimizers:
            disable_grads_for_other_optimizers()
            train_step(opt)
            opt.step()

    for scheduler in schedulers:
        scheduler.step()
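The disable_grads_for_other_optimizers() step above means that, while one optimizer is being trained, only the parameters it owns receive gradients. The sketch below models that idea in plain Python and is not Lightning's code: each hypothetical optimizer name owns a set of parameter names, a requires_grad dictionary is toggled before each optimizer's turn, and a trace records the sequential order of the calls.

```python
def disable_grads_for_other_optimizers(requires_grad, owned, active):
    """Enable gradients only for the parameters owned by the active optimizer."""
    for name in requires_grad:
        requires_grad[name] = name in owned[active]


def run(num_epochs, num_batches, owned):
    # every parameter starts out trainable
    requires_grad = {name: True for params in owned.values() for name in params}
    trace = []
    for epoch in range(num_epochs):
        for batch_idx in range(num_batches):
            for opt in owned:  # optimizers are visited in order
                disable_grads_for_other_optimizers(requires_grad, owned, opt)
                # train_step(opt) and opt.step() would run here
                trainable = sorted(n for n, on in requires_grad.items() if on)
                trace.append((epoch, batch_idx, opt, tuple(trainable)))
        trace.append((epoch, "schedulers.step"))
    return trace


# hypothetical GAN-style ownership: each optimizer owns one parameter
OWNED = {"opt_g": {"generator.weight"}, "opt_d": {"discriminator.weight"}}
trace = run(1, 2, OWNED)
```

The trace shows each batch visiting the generator optimizer and then the discriminator optimizer, with only the active optimizer's parameter trainable at each visit.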
Learning rate scheduling
Every optimizer you use can be paired with any learning rate scheduler.
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

# no LR scheduler
def configure_optimizers(self):
    return Adam(...)

# Adam + LR scheduler
def configure_optimizers(self):
    optimizer = Adam(...)
    scheduler = LambdaLR(optimizer, ...)
    return [optimizer], [scheduler]

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        'optimizer': optimizer,
        'lr_scheduler': ReduceLROnPlateau(optimizer, ...),
        'monitor': 'metric_to_track'
    }

# Two optimizers each with a scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = LambdaLR(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return [optimizer1, optimizer2], [scheduler1, scheduler2]

# Alternatively
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {'optimizer': optimizer1, 'lr_scheduler': scheduler1, 'monitor': 'metric_to_track'},
        {'optimizer': optimizer2, 'lr_scheduler': scheduler2},
    )

# Same as above with additional params passed to the first scheduler
def configure_optimizers(self):
    optimizers = [Adam(...), SGD(...)]
    schedulers = [
        {
            'scheduler': ReduceLROnPlateau(optimizers[0], ...),
            'monitor': 'metric_to_track',
            'interval': 'epoch',
            'frequency': 1,
            'strict': True,
        },
        LambdaLR(optimizers[1], ...)
    ]
    return optimizers, schedulers
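The 'interval' and 'frequency' keys in the last example control how often scheduler.step() is called: with 'interval': 'epoch' the scheduler is stepped every `frequency` epochs, and with 'interval': 'step' every `frequency` batches. The helpers below are a hypothetical, pure-Python sketch of that rule (counting from 1 within the epoch or run, which is an assumption about the exact offset), not Lightning's internal code. They also show why a ReduceLROnPlateau-style scheduler needs 'monitor': the monitored value is what gets passed to its step().

```python
def scheduler_step_points(interval, frequency, num_epochs, batches_per_epoch):
    """Return the (epoch, batch_idx) points where scheduler.step() would run.

    batch_idx is None for end-of-epoch steps. Hypothetical sketch only.
    """
    if interval not in ("epoch", "step"):
        raise ValueError(f"unknown interval: {interval!r}")
    points = []
    for epoch in range(num_epochs):
        for batch_idx in range(batches_per_epoch):
            if interval == "step" and (batch_idx + 1) % frequency == 0:
                points.append((epoch, batch_idx))
        if interval == "epoch" and (epoch + 1) % frequency == 0:
            points.append((epoch, None))
    return points


def step_scheduler(config, metrics):
    """Call step(), passing the monitored metric if the config names one."""
    scheduler = config["scheduler"]
    if "monitor" in config:
        scheduler.step(metrics[config["monitor"]])
    else:
        scheduler.step()


class RecordingScheduler:
    """Hypothetical scheduler that records the arguments of each step() call."""

    def __init__(self):
        self.calls = []

    def step(self, *args):
        self.calls.append(args)


epoch_points = scheduler_step_points("epoch", frequency=2, num_epochs=4, batches_per_epoch=3)
step_points = scheduler_step_points("step", frequency=2, num_epochs=1, batches_per_epoch=5)

plateau_like = RecordingScheduler()
step_scheduler({"scheduler": plateau_like, "monitor": "metric_to_track"}, {"metric_to_track": 0.5})
step_scheduler({"scheduler": plateau_like}, {"metric_to_track": 0.5})
```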
Use multiple optimizers (like GANs)
To use multiple optimizers, return more than one optimizer from pytorch_lightning.core.LightningModule.configure_optimizers().
# one optimizer
def configure_optimizers(self):
    return Adam(...)

# two optimizers, no schedulers
def configure_optimizers(self):
    return Adam(...), SGD(...)

# two optimizers, one scheduler for Adam only
def configure_optimizers(self):
    adam = Adam(...)
    sgd = SGD(...)
    # ReduceLROnPlateau must be given the optimizer it schedules
    return [adam, sgd], [{'scheduler': ReduceLROnPlateau(adam, ...), 'monitor': 'metric_to_track'}]
Lightning will call each optimizer sequentially:
for epoch in epochs:
    for batch in data:
        for opt in optimizers:
            train_step(opt)
            opt.step()

    for scheduler in schedulers:
        scheduler.step()
Step optimizers at arbitrary intervals
To do more interesting things with your optimizers, such as learning rate warm-up or odd scheduling, override the optimizer_step() function.
For example, here we step optimizer A every 2 batches and optimizer B every 4 batches:
Note
When using Trainer(enable_pl_optimizer=True), there is no need to call .zero_grad().
def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
    optimizer.zero_grad()

# Alternating schedule for optimizer steps (e.g. GANs)
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
    # update generator opt every 2 batches
    if optimizer_idx == 0:
        if batch_nb % 2 == 0:
            optimizer.step(closure=closure)

    # update discriminator opt every 4 batches
    if optimizer_idx == 1:
        if batch_nb % 4 == 0:
            optimizer.step(closure=closure)
Note
When using Trainer(enable_pl_optimizer=True), .step accepts a boolean make_optimizer_step argument, which can be used as follows:
def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
    optimizer.zero_grad()

# Alternating schedule for optimizer steps (e.g. GANs)
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
    # update generator opt every 2 batches
    if optimizer_idx == 0:
        optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 2) == 0)

    # update discriminator opt every 4 batches
    if optimizer_idx == 1:
        optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 4) == 0)
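Both versions implement the same schedule: the generator optimizer (optimizer_idx 0) steps on even batch indices and the discriminator optimizer (optimizer_idx 1) on every fourth. The pure-Python sketch below uses a hypothetical helper that mirrors the modulo tests in the hooks to list which optimizers actually step on the first few batches, a quick way to sanity-check a custom schedule before training.

```python
# Hypothetical helper mirroring the `batch_nb % k == 0` checks in optimizer_step above.
EVERY_N_BATCHES = {0: 2, 1: 4}  # optimizer_idx -> step every n batches


def stepping_optimizers(batch_nb, every_n=EVERY_N_BATCHES):
    """Return the optimizer indices whose step() actually runs on this batch."""
    return [idx for idx, n in sorted(every_n.items()) if batch_nb % n == 0]


schedule = {batch_nb: stepping_optimizers(batch_nb) for batch_nb in range(6)}
```

Batch 0 steps both optimizers, batch 2 only the generator, and batch 4 both again; odd batches step neither.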
Here we add a learning-rate warm-up:
# learning rate warm-up
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
    # warm up lr
    if self.trainer.global_step < 500:
        lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
        for pg in optimizer.param_groups:
            pg['lr'] = lr_scale * self.hparams.learning_rate

    # update params
    optimizer.step(closure=closure)
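The warm-up rule above scales the learning rate linearly from learning_rate / 500 at global step 0 up to the full learning_rate at step 499. From step 500 on, the hook no longer touches param_groups, so the last value stays in place unless something else changes it. The pure-Python check below reproduces the formula with a hypothetical base learning rate of 0.1 and the same 500-step horizon.

```python
WARMUP_STEPS = 500  # same horizon as in the hook above
BASE_LR = 0.1       # hypothetical stand-in for self.hparams.learning_rate


def warmup_lr(global_step, base_lr=BASE_LR, warmup_steps=WARMUP_STEPS):
    """Learning rate the hook sets at `global_step`, or None once warm-up is over."""
    if global_step >= warmup_steps:
        return None  # the hook leaves param_groups unchanged
    lr_scale = min(1.0, float(global_step + 1) / warmup_steps)
    return lr_scale * base_lr


first, middle, last, after = warmup_lr(0), warmup_lr(249), warmup_lr(499), warmup_lr(500)
```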
Note
The default optimizer_step relies on the internal LightningOptimizer to perform a step correctly. It handles TPUs, AMP, accumulate_grad_batches, zero_grad, and more.
# function hook in LightningModule
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
    optimizer.step(closure=closure)
Note
To access the optimizer wrapped by LightningOptimizer, do as follows:
# function hook in LightningModule
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
    # `optimizer` is a `LightningOptimizer` wrapping your optimizer.
    # To access the wrapped optimizer, do as follows:
    optimizer = optimizer.optimizer

    # run step. However, this bypasses Lightning's handling of TPU, AMP, etc.
    optimizer.step(closure=closure)
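Unwrapping works because the LightningOptimizer keeps your original optimizer on its .optimizer attribute. The classes below are a simplified, hypothetical model of that wrapper pattern in plain Python, not Lightning's implementation, and they omit TPU, AMP, and accumulation handling. They also model the make_optimizer_step flag from the earlier note; that the closure still runs when the step is skipped (so gradients keep accumulating) is an assumption of this sketch.

```python
class ToyOptimizer:
    """Hypothetical bare optimizer: one scalar parameter and its gradient."""

    def __init__(self, lr):
        self.lr = lr
        self.w = 1.0
        self.grad = 0.0
        self.steps_taken = 0

    def step(self, closure=None):
        if closure is not None:
            closure()  # like torch optimizers, evaluate the closure inside step()
        self.w -= self.lr * self.grad
        self.steps_taken += 1


class WrappedOptimizer:
    """Simplified, hypothetical model of a LightningOptimizer-style wrapper."""

    def __init__(self, optimizer):
        self.optimizer = optimizer  # the raw optimizer stays reachable here

    def step(self, closure=None, make_optimizer_step=True):
        if make_optimizer_step:
            # Lightning-specific handling (TPU, AMP, ...) would surround this call
            self.optimizer.step(closure=closure)
        elif closure is not None:
            closure()  # assumption: still run forward/backward, skip the update


raw = ToyOptimizer(lr=0.25)
opt = WrappedOptimizer(raw)
closure_calls = []


def closure():
    """Hypothetical closure: the 'backward' of loss w ** 2 adds 2 * w to the gradient."""
    closure_calls.append(raw.w)
    raw.grad += 2.0 * raw.w


opt.step(closure=closure, make_optimizer_step=False)  # accumulate only
opt.step(closure=closure)                             # accumulate, then update
```

The first call only accumulates a gradient; the second accumulates again and then applies one update with the summed gradient.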
Using the closure functions for optimization
When using optimizers such as LBFGS, which need to re-evaluate the loss several times per step, the second_order_closure needs to be enabled. By default, this function wraps the training_step and the backward pass as follows:
def second_order_closure(pl_module, split_batch, batch_idx, opt_idx, optimizer, hidden):
    # Model training step on a given batch
    result = pl_module.training_step(split_batch, batch_idx, opt_idx, hidden)

    # Model backward pass
    pl_module.backward(result, optimizer, opt_idx)

    # on_after_backward callback
    pl_module.on_after_backward(result.training_step_output, batch_idx, result.loss)

    return result

# This default `second_order_closure` function can be enabled by passing it directly into the `optimizer.step`
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
    # update params
    optimizer.step(second_order_closure)
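The reason these optimizers take a closure is that they call it more than once per step(): torch.optim.LBFGS, for example, re-evaluates the loss and gradients during its inner iterations. The pure-Python sketch below uses a much simpler stand-in, gradient descent with an Armijo backtracking line search (not LBFGS), to show the pattern: step(closure) calls the closure repeatedly to try candidate step sizes. All names here are hypothetical.

```python
class BacktrackingGD:
    """Hypothetical closure-based optimizer over one scalar parameter.

    Not LBFGS: plain gradient descent with an Armijo backtracking line search,
    used only to show why step() needs to re-run forward and backward passes.
    """

    def __init__(self, w, lr=1.0, shrink=0.5, c=1e-4):
        self.w = w
        self.lr = lr          # initial trial step size
        self.shrink = shrink  # factor applied after each rejected trial
        self.c = c            # sufficient-decrease constant

    def step(self, closure):
        loss, grad = closure()  # evaluate at the current parameter
        start, t = self.w, self.lr
        while True:
            self.w = start - t * grad  # move to a trial point
            new_loss, _ = closure()    # re-evaluate the loss there
            if new_loss <= loss - self.c * t * grad * grad or t < 1e-12:
                return new_loss
            t *= self.shrink


opt = BacktrackingGD(w=0.0)
calls = []


def closure():
    """Hypothetical closure: loss (w - 3) ** 2 and its gradient at opt.w."""
    calls.append(opt.w)
    return (opt.w - 3.0) ** 2, 2.0 * (opt.w - 3.0)


final_loss = opt.step(closure)
```

A single step() call evaluates the closure three times here: once at the starting point, once at a rejected trial point, and once at the accepted one.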