DoublePrecisionPlugin

class pytorch_lightning.plugins.precision.DoublePrecisionPlugin
Bases: pytorch_lightning.plugins.precision.precision_plugin.PrecisionPlugin

Plugin for training with double (torch.float64) precision.
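To make the difference between the two precisions concrete, here is a minimal sketch in plain PyTorch, independent of the plugin. It compares machine epsilon for float32 and float64 and shows a small increment that float32 rounds away but float64 keeps.

```python
import torch

# Machine epsilon: the gap between 1.0 and the next representable value.
eps32 = torch.finfo(torch.float32).eps
eps64 = torch.finfo(torch.float64).eps

# A tiny increment that float32 cannot resolve but float64 can.
delta = 1e-10
one32 = torch.tensor(1.0, dtype=torch.float32)
one64 = torch.tensor(1.0, dtype=torch.float64)

lost_in_float32 = bool(one32 + delta == one32)  # the increment rounds away
kept_in_float64 = bool(one64 + delta != one64)  # the increment survives

print(eps32, eps64)
print(lost_in_float32, kept_in_float64)
```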

connect(model, optimizers, lr_schedulers)
Converts the model to double precision and wraps the training_step, validation_step, test_step, predict_step, and forward methods so that incoming floating-point data is converted to double. Does not alter optimizers or lr_schedulers.
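The contract described for connect can be sketched in plain PyTorch. This is a minimal sketch under stated assumptions, not Lightning's implementation: to_double, DoubleForwardWrapper, and connect_sketch are hypothetical names, and only forward is wrapped here rather than the individual step methods.

```python
from typing import Any

import torch
from torch import nn


def to_double(data: Any) -> Any:
    """Recursively cast floating-point tensors to float64; leave everything else alone."""
    if isinstance(data, torch.Tensor):
        return data.double() if data.is_floating_point() else data
    if isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
        return type(data)(*(to_double(x) for x in data))
    if isinstance(data, (list, tuple)):
        return type(data)(to_double(x) for x in data)
    if isinstance(data, dict):
        return {k: to_double(v) for k, v in data.items()}
    return data


class DoubleForwardWrapper(nn.Module):
    """Hypothetical wrapper: converts the model to float64 and casts its inputs."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # Module.double() converts parameters and buffers in place and returns the module.
        self.model = model.double()

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return self.model(*to_double(args), **to_double(kwargs))


def connect_sketch(model: nn.Module, optimizers: list, lr_schedulers: list):
    # Hypothetical stand-in for connect: only the model changes.
    return DoubleForwardWrapper(model), optimizers, lr_schedulers


# Usage: a float32 batch goes in, a float64 result comes out.
net = nn.Linear(3, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
wrapped, optimizers, schedulers = connect_sketch(net, [optimizer], [])
out = wrapped(torch.randn(4, 3))
```

Because Module.double() converts parameters in place (the Parameter objects stay the same), an optimizer built before the conversion still references the model's parameters, which is consistent with connect leaving optimizers untouched. Integer tensors, such as class labels, pass through unchanged.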

post_dispatch()
Hook that runs after training, evaluation, or prediction finishes.
Return type: None

predict_step_context()
A context manager that changes the default tensor type. See: torch.set_default_tensor_type()

test_step_context()
A context manager that changes the default tensor type. See: torch.set_default_tensor_type()
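Both predict_step_context and test_step_context are documented as context managers that switch the default tensor type for the duration of a step. A minimal sketch of that pattern follows. It swaps in torch.set_default_dtype for torch.set_default_tensor_type (which PyTorch 2.1 and later deprecate), and double_default_dtype is a hypothetical name.

```python
from contextlib import contextmanager
from typing import Iterator

import torch


@contextmanager
def double_default_dtype() -> Iterator[None]:
    """Temporarily make float64 the default floating-point dtype."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        yield
    finally:
        # Restore whatever was active before, even if the step raised.
        torch.set_default_dtype(previous)


# Usage: tensors created inside the block pick up the temporary default.
with double_default_dtype():
    inside = torch.zeros(2)
outside = torch.zeros(2)
```

Restoring the previous dtype in a finally block means an exception raised inside the step does not leave float64 as the global default.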