DoublePrecisionPlugin
class pytorch_lightning.plugins.precision.DoublePrecisionPlugin
Bases: pytorch_lightning.plugins.precision.precision_plugin.PrecisionPlugin

Plugin for training with double (torch.float64) precision.
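What "training in double precision" amounts to can be shown without Lightning at all. The following is a minimal sketch in plain PyTorch, not the plugin's own code, of one optimization step with torch.float64 parameters and inputs. It also shows why incoming data must be converted: a float32 batch fed to a float64 layer raises an error.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Plain PyTorch sketch (no Lightning): one optimization step in float64.
model = nn.Linear(4, 1).double()  # .double() casts parameters to torch.float64
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4, dtype=torch.float64)
y = torch.randn(8, 1, dtype=torch.float64)

loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

param_dtypes = {p.dtype for p in model.parameters()}
grad_dtypes = {p.grad.dtype for p in model.parameters()}

# A float32 batch (PyTorch's usual default) does not match float64 weights.
try:
    model(torch.randn(8, 4))
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

print(param_dtypes, loss.dtype, mismatch_raised)
```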
connect(model, optimizers, lr_schedulers)
Converts the model to double (torch.float64) precision and wraps its training_step, validation_step, test_step, predict_step, and forward methods so that incoming floating-point data is converted to double. Does not alter optimizers or lr_schedulers.
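One way to implement the input conversion that connect describes is to shadow each listed method on the model instance with a wrapper that casts floating-point arguments before delegating. The sketch below does this for forward only. The names to_double and wrap_with_double_inputs are hypothetical, not Lightning API, and the recursion assumes batches are built from tensors, lists, tuples, and dicts.

```python
import functools
from typing import Any, Callable

import torch
from torch import nn


def to_double(data: Any) -> Any:
    """Hypothetical helper: recursively cast floating-point tensors to float64.

    Integer/bool tensors and non-tensor values are returned unchanged.
    Namedtuples come back as plain tuples in this simplified sketch.
    """
    if isinstance(data, torch.Tensor):
        return data.double() if data.is_floating_point() else data
    if isinstance(data, list):
        return [to_double(d) for d in data]
    if isinstance(data, tuple):
        return tuple(to_double(d) for d in data)
    if isinstance(data, dict):
        return {k: to_double(v) for k, v in data.items()}
    return data


def wrap_with_double_inputs(module: nn.Module, method_name: str) -> None:
    """Hypothetical helper: replace module.<method_name> with a version that
    casts its floating-point arguments to float64 before calling the original."""
    original: Callable[..., Any] = getattr(module, method_name)

    @functools.wraps(original)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return original(*to_double(args), **to_double(kwargs))

    # The instance attribute shadows the class method for this module only.
    setattr(module, method_name, wrapper)


model = nn.Linear(3, 2).double()           # step 1: convert the parameters
wrap_with_double_inputs(model, "forward")  # step 2: convert incoming data

out = model(torch.randn(5, 3))  # a float32 batch is cast on the way in
labels = torch.tensor([0, 1])
converted = to_double((torch.randn(2), labels))
```

Integer tensors such as labels keep their dtype, which matters for losses like cross-entropy that expect integer class indices.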

post_dispatch()
Hook to do something after the training/evaluation/prediction finishes.
Return type: None
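A plugin that patches model methods in connect needs somewhere to undo those patches once the run is over, and a post-run hook is a natural place for it. The sketch below assumes that design; it is not taken from Lightning, and MethodPatch and PatchingPluginSketch are hypothetical names. Patches are kept on a stack and unwound in reverse order.

```python
from typing import Any, Callable, List

import torch
from torch import nn


class MethodPatch:
    """Hypothetical record of one patched attribute that can undo itself."""

    def __init__(self, obj: Any, name: str, replacement: Callable[..., Any]) -> None:
        self.obj = obj
        self.name = name
        self.had_instance_attr = name in vars(obj)
        self.previous = vars(obj).get(name)
        setattr(obj, name, replacement)

    def teardown(self) -> None:
        if self.had_instance_attr:
            setattr(self.obj, self.name, self.previous)
        else:
            # Deleting the instance attribute re-exposes the class method.
            delattr(self.obj, self.name)


def _cast(a: Any) -> Any:
    return a.double() if torch.is_tensor(a) and a.is_floating_point() else a


class PatchingPluginSketch:
    """Hypothetical plugin: patches in connect(), restores in post_dispatch()."""

    def __init__(self) -> None:
        self.patches: List[MethodPatch] = []

    def connect(self, model: nn.Module) -> nn.Module:
        model = model.double()
        original = model.forward

        def forward(*args: Any, **kwargs: Any) -> Any:
            return original(*(_cast(a) for a in args), **kwargs)

        self.patches.append(MethodPatch(model, "forward", forward))
        return model

    def post_dispatch(self) -> None:
        # Pop in reverse order so stacked patches unwind correctly.
        while self.patches:
            self.patches.pop().teardown()


plugin = PatchingPluginSketch()
model = plugin.connect(nn.Linear(2, 1))
patched = "forward" in vars(model)
y = model(torch.randn(3, 2))
plugin.post_dispatch()
restored = "forward" not in vars(model)
```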
predict_step_context()
A context manager that switches the default tensor type to double precision for the duration of the prediction step. See:
torch.set_default_tensor_type()

test_step_context()
A context manager that switches the default tensor type to double precision for the duration of the test step. See:
torch.set_default_tensor_type()
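Both step contexts follow the same pattern: switch the default floating-point type on entry and switch it back on exit. The sketch below is an assumption, not the plugin's code. It swaps in torch.set_default_dtype and torch.get_default_dtype in place of the torch.set_default_tensor_type call referenced above, which is deprecated in recent PyTorch releases. The name double_default_dtype is hypothetical.

```python
import contextlib
from typing import Iterator

import torch
from torch import nn


@contextlib.contextmanager
def double_default_dtype() -> Iterator[None]:
    """Hypothetical helper: make float64 the default dtype inside the block,
    then restore whatever default was active before, even on error."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


model = nn.Linear(3, 1).double()
batch = torch.randn(4, 3, dtype=torch.float64)
before = torch.get_default_dtype()

# Shape of a predict/test step: no gradients, float64 as the default dtype.
with torch.no_grad(), double_default_dtype():
    offset = torch.zeros(1)  # created inside the block, so float64
    preds = model(batch) + offset
```

The context affects only tensors created inside the block (for example by torch.zeros or torch.tensor). Existing parameters and incoming batches are untouched, which is why connect converts the model and its inputs separately. Restoring the previous default, rather than hard-coding float32, keeps the helper safe to nest.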