BackboneFinetuning

class pytorch_lightning.callbacks.BackboneFinetuning(unfreeze_backbone_at_epoch=10, lambda_func=<function multiplicative>, backbone_initial_ratio_lr=0.1, backbone_initial_lr=None, should_align=True, initial_denom_lr=10.0, train_bn=True, verbose=False, round=12)

Bases: pytorch_lightning.callbacks.finetuning.BaseFinetuning

Finetune a backbone model according to a user-defined learning rate schedule. When the backbone learning rate reaches the current model learning rate and should_align is set to True, the backbone learning rate stays aligned with the model learning rate for the rest of training.

Parameters
  • unfreeze_backbone_at_epoch (int) – Epoch at which the backbone will be unfrozen.

  • lambda_func (Callable) – Scheduling function for increasing the backbone learning rate.

  • backbone_initial_ratio_lr (float) – Used to scale down the backbone learning rate compared to the rest of the model.

  • backbone_initial_lr (Optional[float]) – Optional initial learning rate for the backbone. By default, current_learning / backbone_initial_ratio_lr is used.

  • should_align (bool) – Whether to align with the current learning rate when the backbone learning rate reaches it.

  • initial_denom_lr (float) – When unfreezing the backbone, the initial learning rate will be current_learning_rate / initial_denom_lr.

  • train_bn (bool) – Whether to make Batch Normalization trainable.

  • verbose (bool) – Display the current learning rate for the model and the backbone.

  • round (int) – Precision for displaying the learning rate.
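
The interplay of unfreeze_backbone_at_epoch, lambda_func, backbone_initial_lr and should_align described above can be sketched in plain Python. This is a simplified illustration, not the callback's implementation: the helper name simulate_backbone_lr is hypothetical, the model learning rate is assumed to stay constant, and the initial backbone learning rate is passed explicitly rather than derived from backbone_initial_ratio_lr.

```python
from typing import Callable, List, Optional


def simulate_backbone_lr(
    model_lr: float,
    backbone_initial_lr: float,
    lambda_func: Callable[[int], float],
    unfreeze_backbone_at_epoch: int,
    num_epochs: int,
    should_align: bool = True,
) -> List[Optional[float]]:
    """Return the backbone learning rate for each epoch (None while frozen).

    Hypothetical helper: it assumes a constant model learning rate and only
    mirrors the behaviour described in the parameter list above.
    """
    schedule: List[Optional[float]] = []
    backbone_lr = backbone_initial_lr
    for epoch in range(num_epochs):
        if epoch < unfreeze_backbone_at_epoch:
            # The backbone is still frozen, so it has no learning rate yet.
            schedule.append(None)
            continue
        if epoch > unfreeze_backbone_at_epoch:
            # Grow the backbone learning rate by the scheduled factor.
            backbone_lr = backbone_lr * lambda_func(epoch)
            # Once it would overshoot the model rate, align with it instead.
            if should_align and backbone_lr > model_lr:
                backbone_lr = model_lr
        schedule.append(backbone_lr)
    return schedule


# Doubling schedule with values chosen to be exact in floating point.
aligned = simulate_backbone_lr(1.0, 0.125, lambda epoch: 2.0, 2, 8)
unaligned = simulate_backbone_lr(1.0, 0.125, lambda epoch: 2.0, 2, 8, should_align=False)
```

With should_align=True the simulated backbone rate stops growing once it reaches the model rate; with should_align=False it keeps following lambda_func.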

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import BackboneFinetuning
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
>>> trainer = Trainer(callbacks=[backbone_finetuning])

finetune_function(pl_module, epoch, optimizer, opt_idx)

Called when the epoch begins.

freeze_before_training(pl_module)

Override to add your freeze logic.

on_fit_start(trainer, pl_module)

Raises

MisconfigurationException – If LightningModule has no nn.Module backbone attribute.
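
The requirement above amounts to a check along the following lines. This is a sketch only: Module and MisconfigurationException are stand-ins defined here so the example runs without PyTorch or Lightning installed, and check_backbone, WithBackbone and WithoutBackbone are hypothetical names, not part of the library.

```python
class Module:
    """Stand-in for torch.nn.Module (assumption, to keep the sketch self-contained)."""


class MisconfigurationException(Exception):
    """Stand-in for the Lightning exception of the same name."""


def check_backbone(pl_module: object) -> None:
    """Raise unless pl_module has a `backbone` attribute that is a Module."""
    backbone = getattr(pl_module, "backbone", None)
    if not isinstance(backbone, Module):
        raise MisconfigurationException(
            "The LightningModule should have a nn.Module `backbone` attribute"
        )


class WithBackbone:
    """Hypothetical module that satisfies the requirement."""

    def __init__(self) -> None:
        self.backbone = Module()


class WithoutBackbone:
    """Hypothetical module that would trigger the exception."""


check_backbone(WithBackbone())  # passes silently
```

In real code the LightningModule would assign an nn.Module (for example a pretrained feature extractor) to self.backbone before training starts.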

on_load_checkpoint(trainer, pl_module, callback_state)

Called when loading a model checkpoint. Use it to reload state.

Parameters

Note

The on_load_checkpoint won’t be called with an undefined state. If your on_load_checkpoint hook behavior doesn’t rely on a state, you will still need to override on_save_checkpoint to return a dummy state.

Return type

None
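
The note above can be illustrated with a minimal sketch. StatelessCallback and restore are hypothetical names; restore is a simplified stand-in for the trainer's restore step, written under the assumption that the load hook is skipped whenever no state was saved. It is not Lightning's code.

```python
from typing import Any, Dict, Optional


class StatelessCallback:
    """Hypothetical callback whose load hook does not depend on saved state."""

    def __init__(self) -> None:
        self.load_hook_ran = False

    def on_save_checkpoint(
        self, trainer: Any, pl_module: Any, checkpoint: Dict[str, Any]
    ) -> Dict[str, Any]:
        # Return a dummy, non-empty state so that a state exists at load time.
        return {"dummy": True}

    def on_load_checkpoint(
        self, trainer: Any, pl_module: Any, callback_state: Dict[str, Any]
    ) -> None:
        # The hook's work does not use callback_state; it only needs to run.
        self.load_hook_ran = True


def restore(callback: StatelessCallback, saved_state: Optional[Dict[str, Any]]) -> None:
    """Simplified stand-in for the trainer: skip the hook when no state was saved."""
    if saved_state:
        callback.on_load_checkpoint(None, None, saved_state)


callback = StatelessCallback()
restore(callback, None)  # no saved state: the load hook is skipped
skipped = not callback.load_hook_ran
restore(callback, callback.on_save_checkpoint(None, None, {}))  # dummy state: hook runs
```

Returning any non-empty dummy state from on_save_checkpoint is what lets the load hook run in this sketch.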

on_save_checkpoint(trainer, pl_module, checkpoint)

Called when saving a model checkpoint. Use it to persist state.

Parameters

Return type

Dict[int, Any]

Returns

The callback state.