DDPSpawnPlugin(parallel_devices=None, num_nodes=None, cluster_environment=None, sync_batchnorm=None, ddp_comm_state=None, ddp_comm_hook=None, ddp_comm_wrapper=None, **kwargs)
Spawns processes using the torch.multiprocessing.spawn() method and joins them after training finishes.
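The spawn-then-join pattern can be sketched with plain PyTorch. `torch.multiprocessing.spawn(fn, args, nprocs, join=True)` starts `nprocs` processes, calls `fn(rank, *args)` in each, and, with `join=True`, blocks until all of them exit. This is a minimal illustration, not Lightning's internal code; `worker` and `spawn_and_join` are hypothetical names, and the queue is only used to collect results in the parent.

```python
import torch.multiprocessing as mp


def worker(rank, world_size, results):
    # mp.spawn calls worker(rank, *args), with rank running from 0 to nprocs - 1
    results.put((rank, world_size))


def spawn_and_join(world_size=2):
    ctx = mp.get_context("spawn")
    results = ctx.SimpleQueue()
    # join=True waits for every child to exit and re-raises a child's exception
    mp.spawn(worker, args=(world_size, results), nprocs=world_size, join=True)
    return sorted(results.get() for _ in range(world_size))


if __name__ == "__main__":
    # The guard matters: spawned children re-import this module
    print(spawn_and_join())  # [(0, 2), (1, 2)]
```

Because the children are started with the `spawn` start method, anything that launches processes must sit behind the `if __name__ == "__main__":` guard.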
Forces all possibly joined processes to wait for each other.
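A minimal sketch of such a barrier, assuming that when no process group has been initialized (a single process) there is nothing to wait for and the call can be skipped. `torch.distributed.barrier()` is the real PyTorch call; the `barrier` wrapper itself is hypothetical.

```python
import torch.distributed as dist


def barrier() -> None:
    """Block until every process in the default group reaches this call.

    Assumption for this sketch: without an initialized process group there
    are no other processes, so the call is a no-op.
    """
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
```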
Moves the model to the correct device.
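A sketch of picking a per-process device and moving a model onto it, assuming the common DDP convention of one device per process indexed by the local rank. The `root_device` helper is hypothetical and not the plugin's actual implementation; `nn.Module.to` moves parameters and buffers in place.

```python
import torch
from torch import nn


def root_device(local_rank: int) -> torch.device:
    # Assumption: one GPU per process, selected by local rank; fall back to CPU
    if torch.cuda.is_available():
        return torch.device("cuda", local_rank)
    return torch.device("cpu")


model = nn.Linear(4, 2)
device = root_device(local_rank=0)
model.to(device)  # modifies the module in place (and also returns it)
```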
Hook to do something after the training/evaluation/prediction finishes.
pre_backward(closure_loss, should_accumulate, optimizer, opt_idx)
Runs before the precision plugin executes the backward pass.
reduce(tensor, group=None, reduce_op='mean')
Reduces a tensor from several distributed processes to one aggregated tensor.
Returns the reduced value; if the input is not a tensor, it is returned unchanged.
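A mean reduction like the default `reduce_op='mean'` can be sketched as an all-reduce sum followed by division by the world size, passing non-tensor inputs through unchanged. This is an illustration with plain `torch.distributed`, not the plugin's code; `reduce_mean` is a hypothetical name, and returning the input untouched when no process group exists is an assumption of the sketch. Summing and then dividing avoids relying on an averaging op that not every backend provides.

```python
import torch
import torch.distributed as dist


def reduce_mean(value, group=None):
    """Average a tensor across processes; return non-tensors unchanged."""
    if not isinstance(value, torch.Tensor):
        return value  # non-tensor inputs pass through as-is
    if not (dist.is_available() and dist.is_initialized()):
        return value  # assumption: a single process has nothing to aggregate
    out = value.clone()  # all_reduce works in place, so keep the caller's tensor intact
    dist.all_reduce(out, op=dist.ReduceOp.SUM, group=group)
    return out / dist.get_world_size(group=group)
```

For example, if two processes hold `torch.tensor([1.0])` and `torch.tensor([3.0])`, both receive `tensor([2.])`.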
Called by the accelerator to finish setup.
Returns the root device.