TrainingTypePlugin¶
class pytorch_lightning.plugins.training_type.TrainingTypePlugin[source]¶
Bases: pytorch_lightning.plugins.base_plugin.Plugin, abc.ABC

Base class for all training type plugins that change the behaviour of the training, validation and test loops.
abstract all_gather(tensor, group=None, sync_grads=False)[source]¶
Perform an all_gather on all processes.
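For illustration, a distributed plugin might implement this hook on top of torch.distributed. The helper below is a minimal, hypothetical sketch (real plugins handle gradient syncing and process groups more carefully):

    import torch
    import torch.distributed as dist

    def all_gather_tensor(tensor, group=None, sync_grads=False):
        # Single-process fallback: nothing to gather.
        if not (dist.is_available() and dist.is_initialized()):
            return tensor.unsqueeze(0)
        world_size = dist.get_world_size(group)
        gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
        # Note: plain all_gather does not propagate gradients; honouring
        # sync_grads=True would need an autograd-aware variant, omitted here.
        dist.all_gather(gathered, tensor, group=group)
        return torch.stack(gathered)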
abstract barrier(name=None)[source]¶
Forces all possibly joined processes to wait for each other.
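A concrete implementation typically delegates to the process group's barrier; a minimal sketch assuming torch.distributed:

    import torch.distributed as dist

    def barrier(name=None):
        # Block until every process in the default group reaches this point.
        # `name` is only useful for logging/debugging and is ignored here.
        if dist.is_available() and dist.is_initialized():
            dist.barrier()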
connect(model)[source]¶
Called by the accelerator to connect the accelerator and the model with this plugin.
model_sharded_context()[source]¶
Provide hook to create modules in a distributed-aware context. This is useful when we'd like to shard the model instantly; for extremely large models this can save memory and initialization time.
Returns: Model parallel context.
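An override would typically return a context manager that puts the backend into a "shard on creation" mode. The sketch below is purely illustrative; enter_sharding_mode and exit_sharding_mode stand in for a real sharding backend's calls:

    from contextlib import contextmanager

    def enter_sharding_mode():
        pass  # placeholder: a real backend would start sharding new modules

    def exit_sharding_mode():
        pass  # placeholder: restore normal module creation

    @contextmanager
    def model_sharded_context():
        # Modules instantiated inside this context can be allocated
        # shard-by-shard instead of materializing fully on one device.
        enter_sharding_mode()
        try:
            yield
        finally:
            exit_sharding_mode()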
post_backward(closure_loss, should_accumulate, optimizer, opt_idx)[source]¶
Run after the precision plugin executes backward.
post_optimizer_step(optimizer, optimizer_idx, **kwargs)[source]¶
Hook to do something after each optimizer step.
pre_backward(closure_loss, should_accumulate, optimizer, opt_idx)[source]¶
Run before the precision plugin executes backward.
process_dataloader(dataloader)[source]¶
Wraps the dataloader if necessary.
Parameters: dataloader (Union[Iterable, DataLoader]) – iterable, ideally of type torch.utils.data.DataLoader.
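As an illustration, a distributed plugin could use this hook to re-create the loader with a DistributedSampler so each process sees its own shard of the dataset; num_replicas and rank below are hypothetical stand-ins for the plugin's world size and rank:

    from torch.utils.data import DataLoader, DistributedSampler

    def process_dataloader(dataloader, num_replicas=2, rank=0):
        # Plain iterables pass through untouched.
        if not isinstance(dataloader, DataLoader):
            return dataloader
        sampler = DistributedSampler(dataloader.dataset,
                                     num_replicas=num_replicas, rank=rank)
        # Rebuild the loader around the distributed sampler.
        return DataLoader(dataloader.dataset,
                          batch_size=dataloader.batch_size,
                          sampler=sampler,
                          num_workers=dataloader.num_workers)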
abstract reduce(tensor, *args, **kwargs)[source]¶
Reduces the given tensor (e.g. across GPUs/processes).
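A typical DDP-style implementation reduces with a sum and then averages; a minimal sketch assuming torch.distributed:

    import torch.distributed as dist

    def reduce_mean(tensor, group=None):
        # Sum the tensor across all processes, then average it.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
            tensor = tensor / dist.get_world_size(group)
        return tensor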
reduce_boolean_decision(decision)[source]¶
Reduce the early stopping decision across all processes.
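One common way to reduce a per-process bool is to sum 0/1 flags and require unanimity; this sketch shows one plausible scheme, not necessarily the library's exact rule:

    import torch
    import torch.distributed as dist

    def reduce_boolean_decision(decision: bool) -> bool:
        if not (dist.is_available() and dist.is_initialized()):
            return decision
        # Encode the local decision as 0/1, sum across processes, and
        # keep the decision only if every process agreed.
        flag = torch.tensor(int(decision))
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        return flag.item() == dist.get_world_size()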
restore_model_state_from_ckpt_path(ckpt_path, map_location=<function TrainingTypePlugin.<lambda>>)[source]¶
This function is used to load and restore the model state.
Returns: checkpoint: the loaded checkpoint; bool: whether to load optimizer / lr_scheduler states from the checkpoint.
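A straightforward implementation loads the checkpoint dict and restores the weights; a hypothetical sketch assuming a Lightning-style checkpoint with a "state_dict" key:

    import torch

    def restore_model_state_from_ckpt_path(model, ckpt_path, map_location="cpu"):
        checkpoint = torch.load(ckpt_path, map_location=map_location)
        # Push the saved weights back into the (unwrapped) model.
        model.load_state_dict(checkpoint["state_dict"])
        # The bool mirrors the documented return value: whether optimizer /
        # lr_scheduler states should also be restored from this checkpoint.
        return checkpoint, True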
save_checkpoint(checkpoint, filepath)[source]¶
Save model/training states as a checkpoint file through state-dump and file-write.
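In a multi-process setting the write is usually guarded so only one rank touches the file; a minimal sketch in which the is_global_zero argument stands in for the plugin's own property:

    import torch

    def save_checkpoint(checkpoint: dict, filepath: str, is_global_zero: bool = True):
        # Write from a single process so ranks don't race on the same file.
        if is_global_zero:
            torch.save(checkpoint, filepath)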
setup_environment()[source]¶
Set up any processes or distributed connections. This is called before the LightningModule/DataModule setup hook, which allows the user to access the accelerator environment before setup is complete.
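For example, a distributed plugin might bring up the default process group here from torchrun-style environment variables (MASTER_ADDR and MASTER_PORT must already be set); a hypothetical sketch:

    import os
    import torch.distributed as dist

    def setup_environment():
        # Initialize the default process group once, before any setup hooks.
        if dist.is_available() and not dist.is_initialized():
            dist.init_process_group(
                backend="nccl",
                rank=int(os.environ["RANK"]),
                world_size=int(os.environ["WORLD_SIZE"]),
            )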
update_global_step(total_batch_idx, current_global_step)[source]¶
Provide a hook to count optimizer step calls.
Returns: New optimizer step calls.
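The simplest behaviour is one step per call; a plugin that accumulates gradients might instead derive the step count from the batch index. A sketch of both, where accumulate_grad_batches is an assumed parameter, not part of the hook's signature:

    def update_global_step(total_batch_idx: int, current_global_step: int,
                           accumulate_grad_batches: int = 1) -> int:
        # One optimizer step per call when not accumulating; otherwise the
        # step counter advances once every `accumulate_grad_batches` batches.
        if accumulate_grad_batches <= 1:
            return current_global_step + 1
        return (total_batch_idx + 1) // accumulate_grad_batches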
property call_configure_sharded_model_hook¶
Allow the model parallel hook to be called in suitable environments determined by the training type plugin. This is useful when we want to shard the model once within fit.
Returns: True if we want to call the model parallel setup hook.
abstract property is_global_zero¶
Whether the current process is the rank zero process, not only on the local node but across all nodes.
property lightning_module¶
Returns the pure LightningModule without potential wrappers.
property results¶
Enables plugin-agnostic access to the result returned by the training/evaluation/prediction run. The result is cached instead of returned directly, because some plugins require transmitting the results from one multiprocessing context to another in a separate step. For example, plugins that use the "spawn" start method send the result to the master process through a multiprocessing queue (shared memory).
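The queue hand-off looks roughly like this; a self-contained sketch of the pattern, not the plugin's actual code:

    import torch.multiprocessing as mp

    def worker(rank, queue):
        result = {"test_acc": 0.91}  # stand-in for a real run's output
        if rank == 0:
            # Ship the result back to the master process via shared memory.
            queue.put(result)

    if __name__ == "__main__":
        smp = mp.get_context("spawn")
        queue = smp.SimpleQueue()
        mp.spawn(worker, args=(queue,), nprocs=2, join=True)
        print(queue.get())  # master process reads the cached result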
abstract property root_device¶
Returns the root device.
Return type: device
property setup_optimizers_in_pre_dispatch¶
Override to delay setting up optimizers and schedulers until after dispatch. This is useful when the TrainingTypePlugin requires operating on the wrapped accelerator model. However, this may break certain precision plugins such as APEX, which require optimizers to be set.
Returns: If True, delay optimizer setup until pre_dispatch; otherwise call it within setup.
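Overriding the property is a one-liner in a subclass; a hypothetical sketch (the remaining abstract methods are omitted):

    from pytorch_lightning.plugins.training_type import TrainingTypePlugin

    class MyPlugin(TrainingTypePlugin):  # hypothetical subclass
        @property
        def setup_optimizers_in_pre_dispatch(self) -> bool:
            # Delay optimizer creation until the model has been wrapped,
            # so the optimizers reference the wrapped parameters.
            return True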