
BackboneFinetuning

class pytorch_lightning.callbacks.BackboneFinetuning(unfreeze_backbone_at_epoch=10, lambda_func=<function multiplicative>, backbone_initial_ratio_lr=0.1, backbone_initial_lr=None, should_align=True, initial_denom_lr=10.0, train_bn=True, verbose=False, rounding=12)

Bases: pytorch_lightning.callbacks.finetuning.BaseFinetuning

Finetune a backbone model according to a user-defined learning rate schedule.

When should_align is True and the backbone learning rate reaches the current model learning rate, the backbone learning rate is kept equal to the model learning rate for the rest of training.
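The alignment rule can be illustrated with a plain-Python simulation. This is a sketch, not Lightning's code: `simulate_backbone_lrs` is a hypothetical helper, and it assumes the model learning rate stays constant (no scheduler), that the backbone starts at `model_lr * backbone_initial_ratio_lr`, and that on each later epoch it is multiplied by `lambda_func(epoch + 1)`.

```python
from typing import Callable, List


def simulate_backbone_lrs(
    model_lr: float,
    lambda_func: Callable[[int], float],
    unfreeze_at: int,
    num_epochs: int,
    backbone_initial_ratio_lr: float = 0.1,
    should_align: bool = True,
) -> List[float]:
    """Hypothetical helper: backbone learning rate for each epoch from unfreezing on.

    Assumes the model learning rate stays fixed at ``model_lr`` (no scheduler).
    """
    lr = model_lr * backbone_initial_ratio_lr  # rate at the unfreeze epoch
    lrs = [lr]
    for epoch in range(unfreeze_at + 1, unfreeze_at + num_epochs):
        # Grow the backbone rate by the user-defined factor for this epoch.
        lr = lambda_func(epoch + 1) * lr
        # Once it would overtake the model rate, pin it to the model rate.
        if should_align and lr > model_lr:
            lr = model_lr
        lrs.append(lr)
    return lrs


double = lambda epoch: 2.0
print(simulate_backbone_lrs(1.0, double, 0, 6, backbone_initial_ratio_lr=0.125))
# [0.125, 0.25, 0.5, 1.0, 1.0, 1.0]
print(simulate_backbone_lrs(1.0, double, 0, 6, backbone_initial_ratio_lr=0.125, should_align=False))
# [0.125, 0.25, 0.5, 1.0, 2.0, 4.0]
```

With alignment the backbone rate stops at the model rate; without it, the same factor keeps pushing it past the model rate.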

Parameters
  • unfreeze_backbone_at_epoch (int) – Epoch at which the backbone will be unfrozen.

  • lambda_func (Callable) – Scheduling function that takes an epoch number and returns the multiplicative factor used to increase the backbone learning rate.

  • backbone_initial_ratio_lr (float) – Used to scale down the backbone learning rate compared to the rest of the model.

  • backbone_initial_lr (Optional[float]) – Initial learning rate for the backbone. By default, current_learning_rate * backbone_initial_ratio_lr is used.

  • should_align (bool) – Whether to align with the current learning rate when the backbone learning rate reaches it.

  • initial_denom_lr (float) – When unfreezing the backbone, the initial learning rate will be current_learning_rate / initial_denom_lr.

  • train_bn (bool) – Whether to make Batch Normalization trainable.

  • verbose (bool) – Whether to display the current learning rates of the model and the backbone.

  • rounding (int) – Number of decimal places used when displaying learning rates.
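How the starting backbone learning rate and the displayed values follow from these parameters can be sketched as below. Both helper names are hypothetical; the sketch assumes an explicit backbone_initial_lr takes precedence over the ratio, and that rounding is passed to Python's built-in round as the number of decimal places.

```python
from typing import Optional


def initial_backbone_lr(
    current_lr: float,
    backbone_initial_lr: Optional[float] = None,
    backbone_initial_ratio_lr: float = 0.1,
) -> float:
    """Hypothetical helper: learning rate the backbone starts with when unfrozen."""
    if backbone_initial_lr is not None:
        # An explicit value wins over the ratio.
        return backbone_initial_lr
    # Otherwise scale the current model learning rate down by the ratio.
    return current_lr * backbone_initial_ratio_lr


def format_lr(lr: float, rounding: int = 12) -> str:
    """Hypothetical helper: round a learning rate for display, as verbose output might."""
    return str(round(lr, rounding))


print(initial_backbone_lr(0.5, backbone_initial_ratio_lr=0.25))  # 0.125
print(initial_backbone_lr(0.5, backbone_initial_lr=0.01))  # 0.01
print(format_lr(1 / 3, rounding=4))  # 0.3333
```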

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import BackboneFinetuning
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
>>> trainer = Trainer(callbacks=[backbone_finetuning])
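For the configuration in this example, a short calculation shows roughly when the backbone catches up with the model. This sketch assumes the default backbone_initial_ratio_lr=0.1, no backbone_initial_lr, and a constant model learning rate; epochs_until_aligned is a hypothetical helper. Starting at 0.1 times the model rate and growing 1.5 times per epoch, the backbone needs the smallest n with 0.1 * 1.5**n >= 1, which is n = 6.

```python
def epochs_until_aligned(ratio: float, factor: float) -> int:
    """Hypothetical helper: epochs after unfreezing until ratio * factor**n >= 1.

    Assumes a constant model learning rate and a constant growth factor.
    """
    if ratio <= 0.0 or factor <= 1.0:
        raise ValueError("need ratio > 0 and factor > 1 for the backbone to catch up")
    relative_lr, epochs = ratio, 0
    while relative_lr < 1.0:
        relative_lr *= factor  # one more epoch of growth
        epochs += 1
    return epochs


n = epochs_until_aligned(ratio=0.1, factor=1.5)
print(n)        # 6
print(200 + n)  # 206: first epoch at which the backbone rate is aligned
```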

finetune_function(pl_module, epoch, optimizer, opt_idx)

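Called when each training epoch begins. At unfreeze_backbone_at_epoch it unfreezes the backbone and adds it to the optimizer as a new parameter group; on later epochs it updates that group's learning rate using lambda_func, aligning it with the model learning rate if should_align is True.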

Return type

None
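The per-epoch logic of finetune_function can be sketched against a plain list of dicts standing in for optimizer.param_groups. This is an illustration under assumptions, not Lightning's implementation: finetune_step is hypothetical, group 0 is assumed to hold the model learning rate, and the backbone is assumed to be appended as the last group when it is unfrozen.

```python
from typing import Callable, Dict, List


def finetune_step(
    param_groups: List[Dict[str, float]],
    epoch: int,
    unfreeze_at: int,
    lambda_func: Callable[[int], float],
    ratio: float = 0.1,
    should_align: bool = True,
) -> None:
    """Hypothetical stand-in for one call of finetune_function at the start of `epoch`."""
    model_lr = param_groups[0]["lr"]
    if epoch == unfreeze_at:
        # Unfreeze: add the backbone as a new group with a scaled-down learning rate.
        param_groups.append({"lr": model_lr * ratio})
    elif epoch > unfreeze_at:
        # Later epochs: grow the backbone group's learning rate.
        next_lr = lambda_func(epoch + 1) * param_groups[-1]["lr"]
        if should_align and next_lr > model_lr:
            next_lr = model_lr  # never overtake the model learning rate
        param_groups[-1]["lr"] = next_lr


groups = [{"lr": 1.0}]
for epoch in range(5):
    finetune_step(groups, epoch, unfreeze_at=2, lambda_func=lambda e: 4.0, ratio=0.125)
print(groups)  # [{'lr': 1.0}, {'lr': 1.0}]
```

Before the unfreeze epoch the list is left untouched, so only the model group exists during the frozen phase.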

freeze_before_training(pl_module)

Freezes pl_module.backbone before training starts. Override in a subclass to add your own freeze logic.

Return type

None

load_state_dict(state_dict)

Called when loading a checkpoint; implement it to reload the callback state from the callback’s state_dict.

Parameters

state_dict (Dict[str, Any]) – the callback state returned by state_dict.

Return type

None

on_fit_start(trainer, pl_module)

Raises

MisconfigurationException – If the LightningModule has no backbone attribute of type nn.Module.

Return type

None

state_dict()

Called when saving a checkpoint; implement it to generate the callback’s state_dict.

Return type

Dict[str, Any]

Returns

A dictionary containing callback state.
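The state_dict and load_state_dict hooks form a round trip: whatever state_dict returns, load_state_dict must be able to restore. A minimal sketch with a hypothetical class follows; the key name previous_backbone_lr is an illustrative assumption, not a documented part of the checkpoint format.

```python
from typing import Any, Dict, Optional


class LrStateSketch:
    """Hypothetical callback-like object that tracks the last backbone learning rate."""

    def __init__(self) -> None:
        self.previous_backbone_lr: Optional[float] = None

    def state_dict(self) -> Dict[str, Any]:
        # Called when saving a checkpoint: return everything needed to resume.
        return {"previous_backbone_lr": self.previous_backbone_lr}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Called when loading a checkpoint: restore from what state_dict() produced.
        self.previous_backbone_lr = state_dict["previous_backbone_lr"]


saved = LrStateSketch()
saved.previous_backbone_lr = 0.25
restored = LrStateSketch()
restored.load_state_dict(saved.state_dict())
print(restored.previous_backbone_lr)  # 0.25
```

Keeping the returned dictionary to plain Python values makes it straightforward to store inside a checkpoint.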
