ModelPruning

class pytorch_lightning.callbacks.ModelPruning(pruning_fn, parameters_to_prune=(), parameter_names=None, use_global_unstructured=True, amount=0.5, apply_pruning=True, make_pruning_permanent=True, use_lottery_ticket_hypothesis=True, resample_parameters=False, pruning_dim=None, pruning_norm=None, verbose=0, prune_on_train_epoch_end=True)[source]

Bases: pytorch_lightning.callbacks.base.Callback

Model pruning callback, using PyTorch's prune utilities. This callback is responsible for pruning a network's parameters during training.

To learn more about pruning with PyTorch, please take a look at the PyTorch pruning tutorial.

Warning

ModelPruning is in beta and subject to change.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelPruning

parameters_to_prune = [(model.mlp_1, "weight"), (model.mlp_2, "weight")]

trainer = Trainer(
    callbacks=[
        ModelPruning(
            pruning_fn="l1_unstructured",
            parameters_to_prune=parameters_to_prune,
            amount=0.01,
            use_global_unstructured=True,
        )
    ]
)

When parameters_to_prune is None, it will default to all parameters of the model. The user can override filter_parameters_to_prune to control which nn.Module parameters are pruned.
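
For example, a minimal sketch of such an override (the LinearOnlyPruning subclass name is illustrative, not part of the API) that restricts pruning to Linear layers:

import torch.nn as nn
from pytorch_lightning.callbacks import ModelPruning

class LinearOnlyPruning(ModelPruning):
    def filter_parameters_to_prune(self, parameters_to_prune=()):
        # Keep only (module, parameter_name) pairs whose module is an nn.Linear.
        return [(m, n) for m, n in parameters_to_prune if isinstance(m, nn.Linear)]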

Parameters
  • pruning_fn (Union[Callable, str]) – Function from the torch.nn.utils.prune module or your own PyTorch BasePruningMethod subclass. Can also be a string, e.g. "l1_unstructured". See the PyTorch docs for more details.

  • parameters_to_prune (Sequence[Tuple[Module, str]]) – List of tuples (nn.Module, "parameter_name_string").

  • parameter_names (Optional[List[str]]) – List of parameter names to be pruned from the nn.Module. Can either be "weight" or "bias".

  • use_global_unstructured (bool) – Whether to apply pruning globally on the model. If parameters_to_prune is provided, global unstructured pruning will be restricted to them.

  • amount (Union[int, float, Callable[[int], Union[int, float]]]) –

    Quantity of parameters to prune:

    • float. Between 0.0 and 1.0. Represents the fraction of parameters to prune.

    • int. Represents the absolute number of parameters to prune.

    • Callable. For dynamic values. Will be called every epoch. Should return an int or a float (see the sketch after this list).

  • apply_pruning (Union[bool, Callable[[int], bool]]) –

    Whether to apply pruning.

    • bool. Always apply pruning, or never.

    • Callable[[epoch], bool]. For dynamic values. Will be called every epoch.

  • make_pruning_permanent (bool) – Whether to remove all reparametrization pre-hooks and apply masks when training ends or the model is saved.

  • use_lottery_ticket_hypothesis (Union[bool, Callable[[int], bool]]) –

    See The lottery ticket hypothesis:

    • bool. Whether to apply it or not.

    • Callable[[epoch], bool]. For dynamic values. Will be called every epoch.

  • resample_parameters (bool) – Used with use_lottery_ticket_hypothesis. If True, the model parameters will be resampled, otherwise, the exact original parameters will be used.

  • pruning_dim (Optional[int]) – If you are using a structured pruning method, you need to specify the dimension.

  • pruning_norm (Optional[int]) – If you are using ln_structured, you need to specify the norm.

  • verbose (int) – Verbosity level. 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity.

  • prune_on_train_epoch_end (bool) – Whether to apply pruning at the end of the training epoch. If this is False, the check runs at the end of the validation epoch.
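
As referenced in the amount and apply_pruning descriptions above, a minimal sketch of dynamic pruning with epoch-based callables (the warm-up schedule below is an illustrative assumption):

from pytorch_lightning.callbacks import ModelPruning

def compute_amount(epoch):
    # Prune 10% of the remaining weights once warm-up is over.
    return 0.1 if epoch >= 5 else 0

pruning_callback = ModelPruning(
    pruning_fn="l1_unstructured",
    amount=compute_amount,
    apply_pruning=lambda epoch: epoch >= 5,  # skip pruning during warm-up
)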

Raises

MisconfigurationException – If parameter_names is neither "weight" nor "bias", if the provided pruning_fn is not supported, if pruning_dim is not provided when using a structured pruning method, if pruning_norm is not provided when using "ln_structured", if pruning_fn is neither a str nor a torch.nn.utils.prune.BasePruningMethod, or if amount is none of int, float, or Callable.

apply_lottery_ticket_hypothesis()[source]

Lottery ticket hypothesis algorithm (see page 2 of the paper):

  1. Randomly initialize a neural network f(x; \theta_0) (where \theta_0 \sim \mathcal{D}_\theta).

  2. Train the network for j iterations, arriving at parameters \theta_j.

  3. Prune p\% of the parameters in \theta_j, creating a mask m.

  4. Reset the remaining parameters to their values in \theta_0, creating the winning ticket f(x; m \odot \theta_0).

This function implements step 4.

The resample_parameters argument can be used to reset the parameters with a new \theta_z \sim \mathcal{D}_\theta.

Return type

None
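
A minimal sketch of enabling this behavior through the callback (the values are illustrative):

from pytorch_lightning.callbacks import ModelPruning

callback = ModelPruning(
    pruning_fn="l1_unstructured",
    amount=0.2,  # prune 20% each time pruning is applied
    use_lottery_ticket_hypothesis=True,  # reset surviving weights to their initial values (step 4)
    resample_parameters=False,  # True would draw fresh values from the init distribution instead
)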

apply_pruning(amount)[source]

Applies pruning to parameters_to_prune.

Return type

None

filter_parameters_to_prune(parameters_to_prune=())[source]

This function can be overridden to control which modules to prune.

Return type

Sequence[Tuple[Module, str]]

make_pruning_permanent(module)[source]

Removes pruning buffers from any pruned modules.

Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180

Return type

None
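
For context, a minimal sketch with raw torch.nn.utils.prune showing what making pruning permanent means (pruning first reparametrizes the module with a mask buffer; removing the reparametrization folds the mask into the weight):

import torch.nn as nn
from torch.nn.utils import prune

layer = nn.Linear(4, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)
print(hasattr(layer, "weight_mask"))  # True: weight is reparametrized via a mask buffer
prune.remove(layer, "weight")         # fold the mask into the weight tensor
print(hasattr(layer, "weight_mask"))  # False: the pruning is now permanent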

on_before_accelerator_backend_setup(trainer, pl_module)[source]

Called before the accelerator is set up.

Return type

None

on_save_checkpoint(trainer, pl_module, checkpoint)[source]

Called when saving a model checkpoint; use this hook to persist state.

Parameters

  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the checkpoint dictionary that will be saved.

Return type

Dict[str, Any]

Returns

The callback state.

on_train_end(trainer, pl_module)[source]

Called when the train ends.

Return type

None

on_train_epoch_end(trainer, pl_module)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, either:

  1. Implement training_epoch_end in the LightningModule and access outputs via the module OR

  2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.

Return type

None

on_validation_epoch_end(trainer, pl_module)[source]

Called when the val epoch ends.

Return type

None

static sanitize_parameters_to_prune(pl_module, parameters_to_prune=(), parameter_names=())[source]

This function is responsible for sanitizing parameters_to_prune and parameter_names. If parameters_to_prune is None, it will be generated with all parameters of the model.

Raises

MisconfigurationException – If parameters_to_prune doesn’t exist in the model, or if parameters_to_prune is neither a list nor a tuple.

Return type

Sequence[Tuple[Module, str]]
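
A minimal sketch of the accepted inputs (the model attribute names are illustrative):

from pytorch_lightning.callbacks import ModelPruning

# Explicit (module, parameter_name) pairs:
parameters_to_prune = [(model.mlp_1, "weight"), (model.mlp_2, "weight")]

# Or let the callback collect all parameters of the model, restricted by name:
callback = ModelPruning(pruning_fn="l1_unstructured", parameter_names=["weight"])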
