TrainingTypePlugin

class pytorch_lightning.plugins.training_type.TrainingTypePlugin[source]

Bases: pytorch_lightning.plugins.base_plugin.Plugin, abc.ABC

Base class for all training type plugins that change the behaviour of the training, validation and test-loop.

abstract all_gather(tensor, group=None, sync_grads=False)[source]

Performs an all_gather on all processes.

Return type

Tensor
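The semantics can be illustrated with a single-process stand-in (a sketch only — the real method gathers torch.Tensors across a distributed process group; the function name and dict return shape here are assumptions for illustration):

```python
def all_gather_sim(per_rank_values):
    # Each "process" contributes one value; every rank receives the full,
    # rank-ordered collection, and all ranks see an identical result.
    world = list(per_rank_values)  # gather, preserving rank order
    return {rank: list(world) for rank in range(len(world))}

gathered = all_gather_sim([10.0, 20.0, 30.0])
```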

abstract barrier(name=None)[source]

Forces all possibly joined processes to wait for each other.

Return type

None

abstract broadcast(obj, src=0)[source]

Broadcasts an object to all processes.

Return type

~T

connect(model)[source]

Called by the accelerator to connect itself and the model with this plugin.

Return type

None

model_sharded_context()[source]

Provides a hook to create modules in a distributed-aware context. This is useful when we'd like to shard the model instantly, e.g. for extremely large models, where doing so can save memory and initialization time.

Returns: Model parallel context.

Return type

Generator

abstract model_to_device()[source]

Moves the model to the correct device.

Return type

None

post_backward(closure_loss, should_accumulate, optimizer, opt_idx)[source]

Runs after the precision plugin executes backward.

post_optimizer_step(optimizer, optimizer_idx, **kwargs)[source]

Hook to do something after each optimizer step.

Return type

None

pre_backward(closure_loss, should_accumulate, optimizer, opt_idx)[source]

Runs before the precision plugin executes backward.

process_dataloader(dataloader)[source]

Wraps the dataloader if necessary.

Parameters

dataloader (Union[Iterable, DataLoader]) – iterable, ideally of type torch.utils.data.DataLoader

Return type

Union[Iterable, DataLoader]
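To illustrate why a plugin might wrap the dataloader, here is a hypothetical round-robin sharding sketch (real distributed plugins attach a torch.utils.data.DistributedSampler rather than filtering items like this):

```python
def process_dataloader_sim(dataloader, rank=0, world_size=1):
    # Hypothetical sharding: rank r sees items r, r + world_size, ...
    # so each rank processes a disjoint slice of the data.
    return [item for i, item in enumerate(dataloader) if i % world_size == rank]
```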

abstract reduce(tensor, *args, **kwargs)[source]

Reduces the given tensor (e.g. across GPUs/processes).

Parameters
  • tensor (Union[Tensor, Any]) – the tensor to sync and reduce

  • *args – plugin-specific positional arguments

  • **kwargs – plugin-specific keyword arguments

Return type

Union[Tensor, Any]
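A single-process stand-in for the reduction semantics (a sketch only — real plugins reduce torch.Tensors across the process group, e.g. via all_reduce; the `reduce_op` strings here are assumptions for illustration):

```python
def reduce_sim(per_rank_values, reduce_op="mean"):
    # Collapse one value per rank into a single synchronized value.
    total = sum(per_rank_values)
    if reduce_op == "mean":
        return total / len(per_rank_values)
    if reduce_op == "sum":
        return total
    raise ValueError(f"unsupported reduce_op: {reduce_op}")
```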

reduce_boolean_decision(decision)[source]

Reduces the early-stopping decision across all processes.

Return type

bool
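One common policy (an assumption here, not mandated by the base class) is to require unanimity, so training stops early only when every rank agrees:

```python
def reduce_boolean_decision_sim(per_rank_decisions):
    # Hypothetical unanimous reduction: the global decision is True only
    # if every rank voted True (e.g. all ranks want to stop early).
    return all(per_rank_decisions)
```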

restore_model_state_from_ckpt_path(ckpt_path, map_location=<function TrainingTypePlugin.<lambda>>)[source]

Loads a checkpoint and restores the model state from it.

Parameters
  • ckpt_path (str) – Path to a checkpoint

  • map_location (Callable) – lambda function to map checkpoint location

Returns

checkpoint: the loaded checkpoint. bool: whether to load optimizer / lr_scheduler states from the checkpoint.

Return type

Tuple[Dict, bool]

save_checkpoint(checkpoint, filepath)[source]

Save model/training states as a checkpoint file through state-dump and file-write.

Parameters
  • checkpoint (Dict[str, Any]) – dict containing model and trainer state

  • filepath (str) – write-target file’s path

Return type

None
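The state-dump-then-file-write pattern can be sketched in plain Python (pickle stands in for torch.save, and the `is_global_zero` guard is an assumption for illustration — real plugins typically write only on the global-zero rank):

```python
import os
import pickle
import tempfile

def save_checkpoint_sim(checkpoint, filepath, is_global_zero=True):
    # Only one rank writes, so parallel processes don't clobber the file.
    if not is_global_zero:
        return
    with open(filepath, "wb") as f:
        pickle.dump(checkpoint, f)  # state-dump + file-write

# round-trip a checkpoint dict through a temporary file
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
save_checkpoint_sim({"epoch": 3, "global_step": 1200}, path)
with open(path, "rb") as f:
    restored = pickle.load(f)
```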

setup(model)[source]

Called by the accelerator to finish setup.

Return type

None

setup_environment()[source]

Sets up any processes or distributed connections. This is called before the LightningModule/DataModule setup hook, which allows the user to access the accelerator environment before setup is complete.

Return type

None

update_global_step(total_batch_idx, current_global_step)[source]

Provide a hook to count optimizer step calls.

Parameters
  • total_batch_idx (int) – Total number of batches seen for training

  • current_global_step (int) – Current number of optimizer step calls

Returns: New optimizer step calls

Return type

int
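A hypothetical implementation with gradient accumulation (the `accumulate_grad_batches` parameter is an assumption for illustration; the base class simply maps the batch count to an optimizer-step count):

```python
def update_global_step_sim(total_batch_idx, current_global_step,
                           accumulate_grad_batches=1):
    # One optimizer step completes every `accumulate_grad_batches` batches,
    # so the global step only advances on accumulation boundaries.
    if (total_batch_idx + 1) % accumulate_grad_batches == 0:
        return current_global_step + 1
    return current_global_step
```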

property call_configure_sharded_model_hook

Allows the model parallel hook to be called in suitable environments determined by the training type plugin. This is useful when we want to shard the model once within fit.

Returns: True if we want to call the model parallel setup hook.

Return type

bool

abstract property is_global_zero

Whether the current process is the rank zero process not only on the local node, but for all nodes.

Return type

bool

property lightning_module

Returns the pure LightningModule without potential wrappers.

Return type

LightningModule

property model

Returns the potentially wrapped LightningModule.

Return type

Module

abstract property on_gpu

Returns whether the current process runs on a GPU.

Return type

bool

property results

Enables plugin-agnostic access to the result returned by the training/evaluation/prediction run. The result is cached instead of returned directly, because some plugins require transmitting the results from one multiprocessing context to another in a separate step. For example, the plugins that use the “spawn” start-method send the result to the master process through a multiprocessing queue (shared memory).

Return type

Union[List[Dict[str, float]], List[Any], List[List[Any]], None]

abstract property root_device

Returns the root device.

Return type

device

property setup_optimizers_in_pre_dispatch

Override to delay setting up optimizers and schedulers until after dispatch. This is useful when the TrainingTypePlugin requires operating on the wrapped accelerator model. However, this may break certain precision plugins, such as APEX, which require the optimizers to be set beforehand.

Returns: If True, delay optimizer setup until pre_dispatch; otherwise call it within setup.

Return type

bool