TrainingTypePlugin

class pytorch_lightning.plugins.training_type.TrainingTypePlugin(checkpoint_io=None)[source]

Bases: abc.ABC

Base class for all training type plugins that change the behaviour of the training, validation and test loop.
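To illustrate the shape of a concrete plugin, here is a minimal sketch. The base class below is a simplified stand-in mirroring the abstract surface documented on this page (the real base lives in `pytorch_lightning.plugins.training_type`), and `SingleProcessPlugin` shows the degenerate single-process implementation where every collective is a no-op or identity:

```python
from abc import ABC, abstractmethod

# Simplified stand-in mirroring the abstract methods of TrainingTypePlugin;
# the real base class is pytorch_lightning.plugins.training_type.TrainingTypePlugin.
class TrainingTypePluginSketch(ABC):
    @abstractmethod
    def all_gather(self, tensor, group=None, sync_grads=False): ...
    @abstractmethod
    def barrier(self, name=None): ...
    @abstractmethod
    def broadcast(self, obj, src=0): ...
    @abstractmethod
    def model_to_device(self): ...
    @abstractmethod
    def reduce(self, tensor, group=None, reduce_op="mean"): ...
    @abstractmethod
    def teardown(self): ...

# A minimal single-process "plugin": with a world size of one,
# every collective operation is trivial.
class SingleProcessPlugin(TrainingTypePluginSketch):
    def all_gather(self, tensor, group=None, sync_grads=False):
        return [tensor]          # only one process, so the gather is trivial
    def barrier(self, name=None):
        pass                     # nothing to synchronize
    def broadcast(self, obj, src=0):
        return obj               # the source already holds the object
    def model_to_device(self):
        pass                     # model stays on the current device
    def reduce(self, tensor, group=None, reduce_op="mean"):
        return tensor            # reducing over one process is the identity
    def teardown(self):
        pass                     # no resources to release

plugin = SingleProcessPlugin()
```

Distributed plugins (DDP, sharded, TPU) override the same surface with real collectives.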

abstract all_gather(tensor, group=None, sync_grads=False)[source]

Perform an all_gather on all processes.

Parameters
  • tensor (Tensor) – the tensor to all_gather

  • group (Optional[Any]) – the process group to gather results from

  • sync_grads (bool) – flag that allows users to synchronize gradients for all_gather op

Return type

Tensor
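The semantics of all_gather can be simulated on a single host without `torch.distributed`: each of N ranks contributes one value, and every rank receives the full stacked result. This is an illustrative model, not the real implementation:

```python
# Single-host simulation of all_gather semantics: each "rank" contributes
# its value, and every rank receives the full list in rank order.
def simulated_all_gather(per_rank_values):
    gathered = list(per_rank_values)            # world-size entries, rank order preserved
    return [gathered for _ in per_rank_values]  # every rank sees the same stacked result

results = simulated_all_gather([10, 20, 30])    # world size 3
# ranks 0, 1 and 2 all receive [10, 20, 30]
```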

abstract barrier(name=None)[source]

Synchronizes all processes, blocking each process until the whole group has entered this function.

Parameters

name (Optional[str]) – an optional name to pass into barrier.

Return type

None

abstract broadcast(obj, src=0)[source]

Broadcasts an object to all processes.

Parameters
  • obj (object) – the object to broadcast

  • src (int) – source rank

Return type

object
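Broadcast semantics can likewise be sketched without a process group: after the call, every rank holds a copy of the object that the `src` rank started with. Again, this is an illustrative simulation:

```python
# Single-host simulation of broadcast semantics: every rank ends up
# holding the object owned by the source rank.
def simulated_broadcast(per_rank_objects, src=0):
    value = per_rank_objects[src]
    return [value for _ in per_rank_objects]

# rank 1 is the source, so all three ranks receive "b"
assert simulated_broadcast(["a", "b", "c"], src=1) == ["b", "b", "b"]
```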

connect(model)[source]

Called by the accelerator to connect the accelerator and the model with this plugin.

Return type

None

dispatch(trainer)[source]

Hook to do something when trainer.run_stage() starts.

Return type

None

lightning_module_state_dict()[source]

Returns model state.

Return type

Dict[str, Union[Any, Tensor]]

model_sharded_context()[source]

Provide a hook to create modules in a distributed-aware context. This is useful when we’d like to shard the model instantly, e.g. for extremely large models, where it can save memory and initialization time.

Returns: Model parallel context.

Return type

Generator

abstract model_to_device()[source]

Moves the model to the correct device.

Return type

None

on_predict_end()[source]

Called when predict ends.

Return type

None

on_predict_start()[source]

Called when predict begins.

Return type

None

on_test_end()[source]

Called when test ends.

Return type

None

on_test_start()[source]

Called when test begins.

Return type

None

on_train_batch_start(batch, batch_idx, dataloader_idx=0)[source]

Called in the training loop before anything happens for that batch.

Return type

None

on_train_end()[source]

Called when train ends.

Return type

None

on_train_start()[source]

Called when train begins.

Return type

None

on_validation_end()[source]

Called when validation ends.

Return type

None

on_validation_start()[source]

Called when validation begins.

Return type

None

post_backward(closure_loss)[source]

Run after precision plugin executes backward.

Return type

None

post_dispatch(trainer)[source]

Hook to do something after the training/evaluation/prediction finishes.

Return type

None

pre_backward(closure_loss)[source]

Run before precision plugin executes backward.

Return type

None

pre_dispatch()[source]

Hook to do something before the training/evaluation/prediction starts.

Return type

None

process_dataloader(dataloader)[source]

Wraps the dataloader if necessary.

Parameters

dataloader (Union[Iterable, DataLoader]) – the iterable to wrap, ideally a torch.utils.data.DataLoader

Return type

Union[Iterable, DataLoader]

abstract reduce(tensor, group=None, reduce_op='mean')[source]

Reduces the given tensor (e.g. across GPUs/processes).

Parameters
  • tensor (Union[Tensor, Any]) – the tensor to sync and reduce

  • group (Optional[Any]) – the process group to reduce

  • reduce_op (Union[ReduceOp, str, None]) – the reduction operation. Defaults to ‘mean’. Can also be a string ‘sum’ or ReduceOp.

Return type

Union[Tensor, Any]
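The two built-in reduction modes can be simulated on a single host: one value per rank is combined into a single result, either averaged or summed. This sketch models the semantics only, not the real collective:

```python
# Single-host simulation of reduce semantics: combine one value per rank
# into a single result using 'mean' (the default) or 'sum'.
def simulated_reduce(per_rank_values, reduce_op="mean"):
    total = sum(per_rank_values)
    if reduce_op == "mean":
        return total / len(per_rank_values)
    if reduce_op == "sum":
        return total
    raise ValueError(f"unsupported reduce_op: {reduce_op!r}")

assert simulated_reduce([1.0, 2.0, 3.0]) == 2.0          # mean across 3 ranks
assert simulated_reduce([1.0, 2.0, 3.0], "sum") == 6.0
```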

reduce_boolean_decision(decision)[source]

Reduce the early stopping decision across all processes.

Return type

bool
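Early stopping must be a joint decision: if ranks disagreed, some processes would stop training while others kept waiting at the next collective. A conservative reduction, sketched here, requires every rank to agree before stopping (the exact reduction strategy is plugin-specific):

```python
# Illustrative conservative reduction of per-rank early-stopping decisions:
# stop only if *every* rank voted to stop.
def simulated_reduce_boolean_decision(per_rank_decisions):
    return all(per_rank_decisions)

assert simulated_reduce_boolean_decision([True, True, True]) is True
assert simulated_reduce_boolean_decision([True, False, True]) is False
```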

remove_checkpoint(filepath)[source]

Remove checkpoint filepath from the filesystem.

Parameters

filepath (Union[str, Path]) – Path to checkpoint

Return type

None

save_checkpoint(checkpoint, filepath)[source]

Save model/training states as a checkpoint file through state-dump and file-write.

Parameters
  • checkpoint (Dict[str, Any]) – dict containing model and trainer state

  • filepath (Union[str, Path]) – write-target file’s path

Return type

None
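The save/remove pair above amounts to a state-dump plus a file-write, and its inverse. The sketch below uses `pickle` as a stand-in for `torch.save`/`torch.load` (the real plugin also consults `should_rank_save_checkpoint` so that only one rank writes):

```python
import os
import pickle
import tempfile

# Sketch of "state-dump and file-write" with pickle standing in for torch.save.
def save_checkpoint(checkpoint: dict, filepath: str) -> None:
    with open(filepath, "wb") as f:
        pickle.dump(checkpoint, f)

def remove_checkpoint(filepath: str) -> None:
    if os.path.exists(filepath):
        os.remove(filepath)

path = os.path.join(tempfile.mkdtemp(), "epoch=3.ckpt")
save_checkpoint({"epoch": 3, "state_dict": {"w": [0.1, 0.2]}}, path)

with open(path, "rb") as f:
    restored = pickle.load(f)     # round-trips the model/trainer state

remove_checkpoint(path)           # checkpoint file is gone afterwards
```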

setup()[source]

Called by the accelerator to finish setup.

Return type

None

setup_environment()[source]

Set up any processes or distributed connections.

This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator environment before setup is complete.

Return type

None

abstract teardown()[source]

This method is called to tear down the training process.

It is the right place to release memory and free other resources.

Return type

None

property handles_gradient_accumulation: bool

Whether the plugin handles gradient accumulation internally.

Return type

bool

abstract property is_global_zero: bool

Whether the current process is the rank zero process not only on the local node, but for all nodes.

Return type

bool

property lightning_module: pytorch_lightning.core.lightning.LightningModule

Returns the pure LightningModule without potential wrappers.

Return type

LightningModule

property lightning_restore_optimizer_and_schedulers: bool

Override to disable Lightning restoring optimizers/schedulers.

This is useful for plugins which manage restoring optimizers/schedulers.

Return type

bool

property model: Optional[torch.nn.Module]

Returns the potentially wrapped LightningModule.

Return type

Optional[Module]

abstract property on_gpu: bool

Returns whether the current process runs on a GPU.

Return type

bool

abstract property on_tpu: bool

Returns whether the current process runs on a TPU.

Return type

bool

property restore_checkpoint_after_pre_dispatch: bool

Override to delay restoring from checkpoint until after pre-dispatch. This is useful when the plugin requires all the setup hooks to run before loading a checkpoint.

Return type

bool

Returns

If true, restore checkpoint after pre_dispatch.

property results: Optional[Union[List[Dict[str, float]], List[Any], List[List[Any]]]]

Enables plugin-agnostic access to the result returned by the training/evaluation/prediction run.

The result is cached instead of returned directly, because some plugins require transmitting the results from one multiprocessing context to another in a separate step. For example, the plugins that use the “spawn” start-method send the result to the master process through a multiprocessing queue (shared memory).

Return type

Union[List[Dict[str, float]], List[Any], List[List[Any]], None]
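The queue-based hand-off described above can be sketched on a single host. A thread stands in for the spawned worker process, and a plain queue stands in for the multiprocessing queue; both substitutions are illustrative simplifications:

```python
import queue
import threading

# Illustrative: a "spawn"-style plugin cannot return results directly from the
# worker, so the worker pushes them through a queue and the main process
# caches them before exposing them via the `results` property.
result_queue: "queue.Queue" = queue.Queue()

def worker():
    run_results = [{"val_loss": 0.25}]  # produced inside the worker context
    result_queue.put(run_results)       # transmitted back through the shared queue

t = threading.Thread(target=worker)
t.start()
t.join()

cached_results = result_queue.get()     # main process caches the result
```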

abstract property root_device: torch.device

Returns the root device.

Return type

device

property setup_optimizers_in_pre_dispatch: bool

Override to delay setting up optimizers and schedulers until after dispatch. This is useful when the TrainingTypePlugin requires operating on the wrapped accelerator model. However, this may break certain precision plugins, such as APEX, which require the optimizers to be set up beforehand.

Return type

bool

Returns

If True, delay optimizer setup until pre_dispatch; otherwise, set up optimizers within setup.

property should_rank_save_checkpoint: bool

Returns whether this rank should save the checkpoint.

Return type

bool