
BaseFinetuning

class pytorch_lightning.callbacks.BaseFinetuning

Bases: pytorch_lightning.callbacks.base.Callback

This class implements the base logic for writing your own finetuning callback.

Override the freeze_before_training and finetune_function methods with your own logic:

  • freeze_before_training: called before configure_optimizers; use it to freeze the parameters of any modules.

  • finetune_function: called at the start of every training epoch; use it to unfreeze parameters. The unfrozen parameters need to be added to the optimizer in a new param_group.

Note

Make sure to filter the parameters based on requires_grad.

Example:

>>> import pytorch_lightning as pl
>>> from torch.optim import Adam
>>> from pytorch_lightning.callbacks import BaseFinetuning
>>> class MyModel(pl.LightningModule):
...     def configure_optimizers(self):
...         # Make sure to filter the parameters based on `requires_grad`
...         return Adam(filter(lambda p: p.requires_grad, self.parameters()))
...
>>> class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
...     def __init__(self, unfreeze_at_epoch=10):
...         # Initialize the base class so its internal state is set up
...         super().__init__()
...         self._unfreeze_at_epoch = unfreeze_at_epoch
...
...     def freeze_before_training(self, pl_module):
...         # Freeze any module you want.
...         # Here, we are freezing `feature_extractor`.
...         self.freeze(pl_module.feature_extractor)
...
...     def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
...         # When `current_epoch` reaches `unfreeze_at_epoch` (10 by default),
...         # `feature_extractor` starts training.
...         if current_epoch == self._unfreeze_at_epoch:
...             self.unfreeze_and_add_param_group(
...                 modules=pl_module.feature_extractor,
...                 optimizer=optimizer,
...                 train_bn=True,
...             )

static filter_on_optimizer(optimizer, params)

Excludes any parameter that already exists in one of the optimizer's param groups.

Parameters
  • optimizer (Optimizer) – Optimizer used for parameter exclusion

  • params (Iterable) – Iterable of parameters used to check against the provided optimizer

Return type

List

Returns

List of parameters not contained in any of the optimizer's param groups
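The exclusion step can be illustrated with plain PyTorch. This is a minimal sketch, not Lightning's implementation: filter_on_optimizer_sketch is a hypothetical helper, and it assumes that "already exists" means the very same parameter object appears in one of optimizer.param_groups (an identity check; the library's own comparison may differ).

```python
import torch
from torch import nn


def filter_on_optimizer_sketch(optimizer, params):
    """Return the parameters in `params` that no param group of `optimizer` holds yet."""
    # Every parameter the optimizer already tracks, across all of its param groups.
    tracked = {p for group in optimizer.param_groups for p in group["params"]}
    # Tensors hash by identity, so this is an identity-based membership check.
    return [p for p in params if p not in tracked]


backbone = nn.Linear(4, 4)
head = nn.Linear(4, 2)
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)

candidates = list(head.parameters()) + list(backbone.parameters())
new_params = filter_on_optimizer_sketch(optimizer, candidates)
print(len(new_params))  # 2 (backbone.weight and backbone.bias)
```

Only the backbone parameters survive, because the head parameters were handed to the optimizer at construction time.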

static filter_params(modules, train_bn=True, requires_grad=True)

Yields the parameters of a given module or list of modules whose requires_grad flag equals the requires_grad argument.

Parameters
  • modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules

  • train_bn (bool) – Whether to train BatchNorm modules; if False, their parameters are skipped.

  • requires_grad (bool) – Whether to create a generator for trainable or non-trainable parameters.

Return type

Generator

Returns

Generator over the matching parameters
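A minimal sketch of the filtering described above, in plain PyTorch. The helpers flatten_sketch and filter_params_sketch are hypothetical, and BatchNorm detection through torch's private _BatchNorm base class is an assumption of the sketch rather than a documented Lightning detail.

```python
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm  # private base class of BatchNorm1d/2d/3d


def flatten_sketch(modules):
    # Leaf modules plus parents that own parameters directly.
    if isinstance(modules, nn.Module):
        candidates = modules.modules()
    else:
        candidates = [m for item in modules for m in flatten_sketch(item)]
    return [m for m in candidates if not list(m.children()) or list(m.parameters(recurse=False))]


def filter_params_sketch(modules, train_bn=True, requires_grad=True):
    for mod in flatten_sketch(modules):
        # With train_bn=False, BatchNorm parameters are never yielded.
        if isinstance(mod, _BatchNorm) and not train_bn:
            continue
        # recurse=False: each parameter is visited once, via the module that owns it.
        for param in mod.parameters(recurse=False):
            if param.requires_grad == requires_grad:
                yield param


model = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
model[0].weight.requires_grad_(False)

frozen = list(filter_params_sketch(model, requires_grad=False))  # [Linear weight]
trainable = list(filter_params_sketch(model, train_bn=False))    # [Linear bias]
```

Passing requires_grad=False selects the frozen parameters, while train_bn=False drops the BatchNorm parameters from either selection.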

finetune_function(pl_module, epoch, optimizer, opt_idx)

Override to add your unfreeze logic.

Return type

None

static flatten_modules(modules)

Flattens a module or an iterable of modules into a list of its leaf modules (modules with no children) and of any parent modules that own parameters directly.

Parameters

modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules

Return type

List[Module]

Returns

List of modules
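The flattening rule can be sketched in a few lines of PyTorch. flatten_modules_sketch is a hypothetical helper that reproduces the behavior described above (leaf modules plus parameter-owning parents), not Lightning's code; it assumes that anything other than an nn.Module is an iterable to recurse into.

```python
from torch import nn


def flatten_modules_sketch(modules):
    """Flatten a module, or a (nested) iterable of modules, into a list of modules."""
    if isinstance(modules, nn.Module):
        # The module itself followed by all of its descendants.
        candidates = modules.modules()
    else:
        candidates = [m for item in modules for m in flatten_modules_sketch(item)]
    # Keep leaf modules, plus any parent module that owns parameters directly.
    return [
        m for m in candidates
        if not list(m.children()) or list(m.parameters(recurse=False))
    ]


model = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4), nn.ReLU())
print([type(m).__name__ for m in flatten_modules_sketch(model)])
# ['Linear', 'BatchNorm1d', 'ReLU']  (the Sequential container owns no parameters itself)
print(len(flatten_modules_sketch([model, [nn.Linear(2, 2)]])))  # 4
```

Keeping parents that own parameters directly matters because the other helpers iterate with parameters(recurse=False); dropping such a parent would silently skip its parameters.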

static freeze(modules, train_bn=True)

Freezes the parameters of the provided modules.

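Parameters
  • modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules

  • train_bn (bool) – Whether to keep the BatchNorm layers trainable.

Return type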

None

Returns

None
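A minimal PyTorch sketch of freezing with a train_bn switch. freeze_sketch is hypothetical, accepts a single nn.Module only (the iterable case is omitted), and assumes train_bn=True keeps BatchNorm parameters trainable, mirroring the train_bn flag of filter_params and unfreeze_and_add_param_group.

```python
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm


def freeze_sketch(module, train_bn=True):
    # Visit each submodule; parameters(recurse=False) touches only its own parameters.
    for mod in module.modules():
        keep_trainable = train_bn and isinstance(mod, _BatchNorm)
        for param in mod.parameters(recurse=False):
            # BatchNorm parameters stay trainable when train_bn=True; everything else is frozen.
            param.requires_grad = keep_trainable


model = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
freeze_sketch(model)
print([p.requires_grad for p in model.parameters()])  # [False, False, True, True]
```

Calling freeze_sketch(model, train_bn=False) instead freezes every parameter, including the BatchNorm affine weights.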

freeze_before_training(pl_module)

Override to add your freeze logic.

Return type

None

static make_trainable(modules)

Unfreezes the parameters of the provided modules.

Parameters

modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules

Return type

None
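Unfreezing is the inverse operation. The sketch below is a simplified, hypothetical make_trainable_sketch in plain PyTorch: it accepts one module or a flat list of modules (nested iterables are not handled) and re-enables gradients with a recursive parameters() walk.

```python
from torch import nn


def make_trainable_sketch(modules):
    # Accept a single module or a flat iterable of modules.
    if isinstance(modules, nn.Module):
        modules = [modules]
    for module in modules:
        # Re-enable gradients for every parameter, including those of submodules.
        for param in module.parameters():
            param.requires_grad = True


encoder = nn.Linear(3, 4)
decoder = nn.Linear(4, 3)
encoder.requires_grad_(False)
decoder.requires_grad_(False)

make_trainable_sketch([encoder, decoder])
print(all(p.requires_grad for p in encoder.parameters()))  # True
```

Note that this only flips requires_grad; the parameters still have to reach an optimizer, which is what unfreeze_and_add_param_group takes care of.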

on_before_accelerator_backend_setup(trainer, pl_module)

Called before the accelerator is set up.

on_fit_start(trainer, pl_module)

Called when fit begins.

Return type

None

on_load_checkpoint(trainer, pl_module, callback_state)

Called when loading a model checkpoint; use it to reload state.

Parameters
  • trainer (Trainer) – the current Trainer instance.

  • pl_module (LightningModule) – the current LightningModule instance.

  • callback_state – the callback state returned by on_save_checkpoint.

Note

on_load_checkpoint won’t be called with an undefined state. If your on_load_checkpoint hook doesn’t rely on a state, you still need to override on_save_checkpoint to return a dummy state.

Return type

None

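on_save_checkpoint(trainer, pl_module, checkpoint)

Called when saving a model checkpoint; use it to persist state.

Parameters
  • trainer (Trainer) – the current Trainer instance.

  • pl_module (LightningModule) – the current LightningModule instance.

  • checkpoint – the checkpoint dictionary that will be saved.

Return type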

Dict[int, List[Dict[str, Any]]]

Returns

The callback state.

on_train_epoch_start(trainer, pl_module)

Called when the train epoch begins.

Return type

None

static unfreeze_and_add_param_group(modules, optimizer, lr=None, initial_denom_lr=10.0, train_bn=True)

Unfreezes a module and adds its parameters to an optimizer.

Parameters
  • modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A module or iterable of modules to unfreeze. Their parameters will be added to an optimizer as a new param group.

  • optimizer (Optimizer) – The optimizer that receives the new parameters, which are added through its add_param_group method.

  • lr (Optional[float]) – Learning rate for the new param group.

  • initial_denom_lr (float) – If no lr is provided, the learning rate of the first param group is divided by initial_denom_lr and used for the new param group.

  • train_bn (bool) – Whether to train the BatchNormalization layers.

Return type

None
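The learning-rate rule for the new param group can be illustrated with plain PyTorch. unfreeze_and_add_param_group_sketch is a hypothetical helper that accepts a single module and omits the train_bn handling; it follows the parameter descriptions above (explicit lr, otherwise first group's lr divided by initial_denom_lr) and skips parameters the optimizer already holds.

```python
import torch
from torch import nn


def unfreeze_and_add_param_group_sketch(module, optimizer, lr=None, initial_denom_lr=10.0):
    # 1. Unfreeze all parameters of the module.
    for param in module.parameters():
        param.requires_grad = True
    # 2. Without an explicit lr, scale down the first param group's lr.
    group_lr = optimizer.param_groups[0]["lr"] / initial_denom_lr if lr is None else float(lr)
    # 3. Skip parameters the optimizer already holds.
    tracked = {p for group in optimizer.param_groups for p in group["params"]}
    new_params = [p for p in module.parameters() if p not in tracked]
    # 4. Register the remaining parameters as a new param group.
    if new_params:
        optimizer.add_param_group({"params": new_params, "lr": group_lr})


backbone = nn.Linear(8, 8)
head = nn.Linear(8, 2)
backbone.requires_grad_(False)  # frozen before training
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)

unfreeze_and_add_param_group_sketch(backbone, optimizer)
print(len(optimizer.param_groups))  # 2
print(optimizer.param_groups[1]["lr"])  # 0.01
```

Starting the unfrozen backbone at a tenth of the head's learning rate is a common finetuning choice, since it limits how far pretrained weights move right after they are unfrozen.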
