ModelPruning

class pytorch_lightning.callbacks.ModelPruning(pruning_fn, parameters_to_prune=None, parameter_names=None, use_global_unstructured=True, amount=0.5, apply_pruning=True, make_pruning_permanent=True, use_lottery_ticket_hypothesis=True, resample_parameters=False, pruning_dim=None, pruning_norm=None, verbose=0)[source]

Bases: pytorch_lightning.callbacks.base.Callback
Model pruning callback, using PyTorch's prune utilities. This callback is responsible for pruning network parameters during training.

To learn more about pruning with PyTorch, please take a look at this tutorial.

Warning

ModelPruning is in beta and subject to change.

parameters_to_prune = [(model.mlp_1, "weight"), (model.mlp_2, "weight")]

trainer = Trainer(callbacks=[
    ModelPruning(
        pruning_fn="l1_unstructured",
        parameters_to_prune=parameters_to_prune,
        amount=0.01,
        use_global_unstructured=True,
    )
])
When parameters_to_prune is None, parameters_to_prune will contain all parameters from the model. The user can override filter_parameters_to_prune to filter any nn.Module to be pruned.

Parameters
pruning_fn (Union[Callable, str]) – Function from the torch.nn.utils.prune module or your own PyTorch BasePruningMethod subclass. Can also be a string, e.g. "l1_unstructured". See the PyTorch docs for more details.

parameters_to_prune (Union[List[Tuple[Module, str]], Tuple[Tuple[Module, str]], None]) – List of tuples (nn.Module, "parameter_name_string").

parameter_names (Optional[List[str]]) – List of parameter names to be pruned from the nn.Module. Can either be "weight" or "bias".

use_global_unstructured (bool) – Whether to apply pruning globally on the model. If parameters_to_prune is provided, global unstructured pruning will be restricted to them.

amount (Union[int, float, Callable[[int], Union[int, float]]]) – Quantity of parameters to prune:

- float. Between 0.0 and 1.0. Represents the fraction of parameters to prune.
- int. Represents the absolute number of parameters to prune.
- Callable. For dynamic values. Will be called every epoch. Should return a value.

apply_pruning (Union[bool, Callable[[int], bool]]) – Whether to apply pruning:

- bool. Always apply it or not.
- Callable[[epoch], bool]. For dynamic values. Will be called every epoch.

make_pruning_permanent (bool) – Whether to remove all reparametrization pre-hooks and apply masks when training ends or the model is saved.

use_lottery_ticket_hypothesis (Union[bool, Callable[[int], bool]]) – See The lottery ticket hypothesis:

- bool. Whether to apply it or not.
- Callable[[epoch], bool]. For dynamic values. Will be called every epoch.

resample_parameters (bool) – Used with use_lottery_ticket_hypothesis. If True, the model parameters will be resampled, otherwise, the exact original parameters will be used.

pruning_dim (Optional[int]) – If you are using a structured pruning method you need to specify the dimension.

pruning_norm (Optional[int]) – If you are using ln_structured you need to specify the norm.

verbose (int) – Verbosity level. 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity.
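Since amount and apply_pruning both accept epoch-indexed callables, pruning can follow a schedule. A minimal sketch, assuming an illustrative ramp (the function names and thresholds below are hypothetical, not library defaults):

```python
# Illustrative epoch-dependent schedules for `amount` and `apply_pruning`.
# Both callables receive the current epoch; ModelPruning calls them every epoch.

def amount_schedule(epoch):
    """Fraction of parameters to prune this epoch: 0 until epoch 5,
    then a ramp of 0.1 per epoch, capped at 0.5."""
    return min(0.5, max(0.0, (epoch - 4) * 0.1))

def should_prune(epoch):
    """Skip pruning entirely during the first five epochs (0-4)."""
    return epoch >= 5

# These would be passed to the callback, e.g.:
# ModelPruning(pruning_fn="l1_unstructured",
#              amount=amount_schedule,
#              apply_pruning=should_prune)
```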
Raises

MisconfigurationException – If parameter_names is neither "weight" nor "bias", if the provided pruning_fn is not supported, if pruning_dim is not provided when using a structured pruning method, if pruning_norm is not provided when using "ln_structured", if pruning_fn is neither a str nor a torch.nn.utils.prune.BasePruningMethod, or if amount is none of int, float, or Callable.
apply_lottery_ticket_hypothesis()[source]

Lottery ticket hypothesis algorithm (see page 2 of the paper):

1. Randomly initialize a neural network f(x; θ₀) (where θ₀ ~ D_θ).
2. Train the network for j iterations, arriving at parameters θ_j.
3. Prune p% of the parameters in θ_j, creating a mask m.
4. Reset the remaining parameters to their values in θ₀, creating the winning ticket f(x; m ⊙ θ₀).

This function implements step 4.

The resample_parameters argument can be used to reset the parameters with a new θ_z ~ D_θ.
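Step 4 can be sketched in plain PyTorch. This is an illustrative reconstruction of the reset, not the callback's exact implementation:

```python
import copy

import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
layer = nn.Linear(8, 4)
theta_0 = copy.deepcopy(layer.weight.data)  # parameters at initialization

# Stand-in for training: perturb the weights.
with torch.no_grad():
    layer.weight.add_(torch.randn_like(layer.weight))

# Step 3: prune 50% of the weights, creating the mask m (weight_mask).
prune.l1_unstructured(layer, name="weight", amount=0.5)

# Step 4: reset the surviving parameters to their initial values theta_0.
# After pruning, the trainable tensor is `weight_orig`; the effective
# weight is weight_orig * weight_mask, recomputed on each forward pass.
layer.weight_orig.data = theta_0
_ = layer(torch.zeros(1, 8))  # forward refreshes the masked weight

assert torch.equal(layer.weight, theta_0 * layer.weight_mask)
```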
filter_parameters_to_prune(parameters_to_prune=None)[source]

This function can be overridden to control which modules to prune.
make_pruning_permanent(pl_module)[source]

Removes pruning buffers from any pruned modules.

Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180
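In plain torch.nn.utils.prune terms, making pruning permanent corresponds to prune.remove, which folds the mask into the weight and drops the auxiliary tensors. A minimal sketch, not the callback's exact code:

```python
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(4, 2)
prune.l1_unstructured(layer, name="weight", amount=0.5)

# While pruning is active, the real parameter is `weight_orig` and the
# mask lives in the `weight_mask` buffer.
print("weight_orig" in dict(layer.named_parameters()))  # True
print("weight_mask" in dict(layer.named_buffers()))     # True

# prune.remove makes the pruning permanent: the masked weight becomes an
# ordinary parameter again and the reparametrization pre-hook is dropped.
prune.remove(layer, name="weight")
print("weight_orig" in dict(layer.named_parameters()))  # False
print("weight_mask" in dict(layer.named_buffers()))     # False
```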
on_before_accelerator_backend_setup(trainer, pl_module)[source]

Called before the accelerator is set up.
on_save_checkpoint(trainer, pl_module, checkpoint)[source]

Called when saving a model checkpoint; use to persist state.