QuantizationAwareTraining

class pytorch_lightning.callbacks.QuantizationAwareTraining(qconfig='fbgemm', observer_type='average', collect_quantization=None, modules_to_fuse=None, input_compatible=True, quantize_on_fit_end=True)[source]

Bases: pytorch_lightning.callbacks.base.Callback

Quantization allows speeding up inference and decreasing memory requirements by performing computations and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision. We use native PyTorch API so for more information see Quantization.
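As a rough illustration of what this callback automates, here is a minimal sketch of the native eager-mode QAT flow using torch.quantization directly (the TinyModel, its layer sizes, and the dummy data are hypothetical):

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Hypothetical toy model; QuantStub/DeQuantStub mark the
    float -> INT8 -> float boundaries, mirroring the quant/dequant
    layers the callback inserts around the LightningModule."""
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc1 = nn.Linear(4, 8)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(8, 2)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return self.dequant(x)

model = TinyModel().train()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)  # attach fake-quant observers

# A normal training loop would run here; each forward pass feeds the observers.
model(torch.randn(8, 4))

model.eval()
torch.quantization.convert(model, inplace=True)  # swap in INT8 modules
out = model(torch.randn(2, 4))  # float in, float out, INT8 arithmetic inside
```

The callback performs the prepare step in on_fit_start and (by default) the convert step in on_fit_end, so you do not call these APIs yourself.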

Warning

QuantizationAwareTraining is in beta and subject to change.

Parameters
  • qconfig (Union[str, QConfig]) –

    quantization configuration:

    • 'fbgemm' for server inference (default).

    • 'qnnpack' for mobile inference.

    • a custom torch.quantization.QConfig instance.

  • observer_type (str) – allows switching between MovingAverageMinMaxObserver as “average” (default) and HistogramObserver as “histogram”, which is more computationally expensive.

  • collect_quantization (Union[Callable, int, None]) –

    count or custom function to collect quantization statistics:

    • None (default). The quantization observer is called in each module forward
      (useful for collecting extended statistics when using image/data augmentation).

    • int. Use to set a fixed number of calls, starting from the beginning.

    • Callable. Custom function with single trainer argument.

      For example, to trigger collection only on the last epoch:

      def custom_trigger_last(trainer):
          return trainer.current_epoch == (trainer.max_epochs - 1)
      
      
      QuantizationAwareTraining(collect_quantization=custom_trigger_last)
      

  • modules_to_fuse (Optional[Sequence]) – allows you to fuse a few layers together. To find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.

  • input_compatible (bool) – preserve quant/dequant layers. This allows feeding the quantized model the same inputs as the original model, but breaks compatibility with TorchScript and export with torch.save.

  • quantize_on_fit_end (bool) – perform the quantization in the on_fit_end hook. Note that once converted, the model cannot be put in training mode again.

on_fit_end(trainer, pl_module)[source]

Called when fit ends.

on_fit_start(trainer, pl_module)[source]

Called when fit begins.