class pytorch_lightning.callbacks.QuantizationAwareTraining(qconfig='fbgemm', observer_type='average', collect_quantization=None, modules_to_fuse=None, input_compatible=True, quantize_on_fit_end=True, observer_enabled_stages=('train',))
Quantization allows speeding up inference and decreasing memory requirements by performing computations and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating-point precision. We use the native PyTorch API, so for more information see PyTorch Quantization.
QuantizationAwareTraining is in beta and subject to change.
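A minimal end-to-end sketch of attaching the callback to a Trainer; LitModel below is an illustrative placeholder, not part of the API, and the dataloaders are omitted:

import torch
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import QuantizationAwareTraining

class LitModel(pl.LightningModule):
    # Minimal illustrative module; real models will be more involved.
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

# With default arguments the callback uses the 'fbgemm' qconfig and
# converts the model to its quantized form when fitting ends.
trainer = pl.Trainer(max_epochs=2, callbacks=[QuantizationAwareTraining()])
# trainer.fit(LitModel(), ...)  # supply your usual dataloaders here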
qconfig – quantization configuration:
'fbgemm' for server inference.
'qnnpack' for mobile inference.
a custom torch.quantization.QConfig (a sketch of one follows this list).
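A sketch of what a custom QConfig might look like, assuming the standard torch.quantization fake-quantize helpers; the observer choice and quantization range below are illustrative, not a recommendation:

import torch
from pytorch_lightning.callbacks import QuantizationAwareTraining

# Illustrative configuration: fake-quantized activations driven by a
# moving-average min/max observer, and the default weight fake-quantizer.
custom_qconfig = torch.quantization.QConfig(
    activation=torch.quantization.FakeQuantize.with_args(
        observer=torch.quantization.MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=255,
    ),
    weight=torch.quantization.default_weight_fake_quant,
)

QuantizationAwareTraining(qconfig=custom_qconfig)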
observer_type – allows switching between MovingAverageMinMaxObserver as 'average' (default) and HistogramObserver as 'histogram', which is more computationally expensive.
collect_quantization – count or custom function to collect quantization statistics:
None (default). The quantization observer is called in each module forward (useful for collecting extended statistics when using image/data augmentation).
int. Use to set a fixed number of calls, starting from the beginning (a short sketch follows the example below).
Callable. Custom function with a single trainer argument. See this example, which triggers collection only on the last epoch:
def custom_trigger_last(trainer):
    return trainer.current_epoch == (trainer.max_epochs - 1)

QuantizationAwareTraining(collect_quantization=custom_trigger_last)
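For the int form, a short sketch with an arbitrary count; statistics are then only collected for that many forward calls from the start of training:

from pytorch_lightning.callbacks import QuantizationAwareTraining

# Stop calling the observers after the first 128 module forwards (arbitrary count).
QuantizationAwareTraining(collect_quantization=128)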
observer_enabled_stages – allow fake-quantization modules’ observers to do calibration during the provided stages:
'train': the observers can do calibration during training.
'validate': the observers can do calibration during validating. Note that we don’t disable observers during the sanity check as the model hasn’t been calibrated with training data yet. After the sanity check, the fake-quantization modules are restored to initial states.
'test': the observers can do calibration during testing.
'predict': the observers can do calibration during predicting.
Note that we only handle observers belonging to fake-quantization modules. When observer_type is 'histogram', the observers won’t belong to any fake-quantization modules and will not be controlled by the callback.
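A short sketch of restricting calibration to particular stages; enabling 'validate' alongside the default 'train' is purely illustrative:

from pytorch_lightning.callbacks import QuantizationAwareTraining

# Observers may calibrate during both training and validation runs.
QuantizationAwareTraining(observer_enabled_stages=("train", "validate"))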