pytorch_lightning.trainer.training_loop module

The Lightning training loop handles everything except the actual computations of your model.

To decide what will happen in your training loop, define the training_step method.

Below are all the things Lightning automates for you in the training loop.

Accumulated gradients

Gradient accumulation runs the backward pass on K small batches of size N, summing their gradients, and only then takes a single optimizer step.

The effect is a large effective batch size of K×N.

# DEFAULT (ie: no accumulated grads)
trainer = Trainer(accumulate_grad_batches=1)
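Why does this match a single batch of size K×N? A minimal pure-Python sketch (no PyTorch; the one-parameter model, the data, and the helper names grad_mse and accumulated_grad are all hypothetical) compares the gradient of a mean-squared-error loss over one batch of 8 samples with the gradient accumulated over K=4 micro-batches of N=2 samples. Each micro-batch gradient is scaled by 1/K, as it would be if each micro-batch loss were divided by K before its backward pass; without that scaling the accumulated gradient would be K times larger.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)**2) with respect to the single weight w."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, k):
    """Sum gradients over k equal micro-batches, each scaled by 1/k.

    Assumes len(xs) is divisible by k.
    """
    n = len(xs) // k  # micro-batch size N
    total = 0.0
    for i in range(k):
        mx = xs[i * n:(i + 1) * n]
        my = ys[i * n:(i + 1) * n]
        total += grad_mse(w, mx, my) / k  # like calling backward() on loss / k
    return total


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 16.2]
w = 0.5

full = grad_mse(w, xs, ys)                # one batch of size K*N = 8
accum = accumulated_grad(w, xs, ys, k=4)  # K=4 micro-batches of size N=2
print(abs(full - accum) < 1e-9)  # True
```

The two gradients agree up to floating-point rounding, which is why the optimizer step after K accumulated batches behaves like one step on a batch K times larger.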

Force training for min or max epochs

It can be useful to force training for a minimum number of epochs or to limit it to a maximum number.

# DEFAULT
trainer = Trainer(min_epochs=1, max_epochs=1000)
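The interaction between these limits and early stopping can be summarized with a simplified, hypothetical model (the helpers should_stop and run are illustrative, not Lightning internals): max_epochs acts as a hard cap, and an early-stopping request is only honored once min_epochs epochs have completed.

```python
def should_stop(epochs_completed, early_stop_requested, min_epochs=1, max_epochs=1000):
    """Hypothetical helper: decide whether training ends after an epoch."""
    if epochs_completed >= max_epochs:
        return True  # the upper limit always ends training
    # an early-stopping request only counts once the minimum is reached
    return early_stop_requested and epochs_completed >= min_epochs


def run(min_epochs, max_epochs, stop_requested_from):
    """Simulate training where early stopping fires from a given epoch onward."""
    completed = 0
    while True:
        completed += 1  # pretend one epoch of training happened
        requested = completed >= stop_requested_from
        if should_stop(completed, requested, min_epochs, max_epochs):
            return completed


print(run(min_epochs=5, max_epochs=10, stop_requested_from=2))   # 5
print(run(min_epochs=1, max_epochs=10, stop_requested_from=3))   # 3
print(run(min_epochs=1, max_epochs=10, stop_requested_from=99))  # 10
```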

Force disable early stop

To disable early stopping, pass None to early_stop_callback.

# DEFAULT
trainer = Trainer(early_stop_callback=None)

Gradient Clipping

Gradient clipping may be enabled to avoid exploding gradients.

Specifically, this will clip the gradient norm computed over all model parameters together.

# DEFAULT (ie: don't clip)
trainer = Trainer(gradient_clip_val=0)

# clip gradients with norm above 0.5
trainer = Trainer(gradient_clip_val=0.5)
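Clipping "over all model parameters together" means a single L2 norm is computed across every gradient entry, and all gradients are scaled by the same factor when that norm exceeds gradient_clip_val. The sketch below illustrates this on plain Python lists; it follows the same idea as torch.nn.utils.clip_grad_norm_ (including a small eps in the denominator), but the helper clip_grad_norm is hypothetical and is not the code Lightning runs.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients in place so their combined L2 norm is at most max_norm.

    `grads` is a list of gradient vectors (one per parameter), each a list of floats.
    Returns the total norm measured before clipping.
    """
    # one norm over ALL parameters together, not one norm per parameter
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1:
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef  # same factor for every parameter
    return total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # total L2 norm = 5
before = clip_grad_norm(grads, max_norm=0.5)
after = math.sqrt(sum(g * g for grad in grads for g in grad))
print(before, round(after, 4))  # 5.0 0.5
```

Because every gradient is multiplied by the same factor, clipping changes the step size but not the direction of the overall update.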

Inspect gradient norms

Looking at grad norms can help you figure out where training might be going wrong.

# DEFAULT (-1 doesn't track norms)
trainer = Trainer(track_grad_norm=-1)

# track the LP norm (P=2 here)
trainer = Trainer(track_grad_norm=2)
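As a rough illustration of what tracking a p-norm involves, this hypothetical helper (grad_norms, with made-up parameter names) computes the p-norm of each parameter's gradient plus a combined total over all parameters. It works on plain lists and does not reproduce the metric names Lightning logs.

```python
def grad_norms(named_grads, p=2):
    """Return the p-norm of each parameter's gradient plus a combined total.

    `named_grads` maps a parameter name to its gradient as a flat list of floats.
    """
    norms = {}
    for name, grad in named_grads.items():
        norms[name] = sum(abs(g) ** p for g in grad) ** (1.0 / p)
    # the p-norm of the per-parameter p-norms equals the p-norm over all entries
    total = sum(n ** p for n in norms.values()) ** (1.0 / p)
    norms["total"] = total
    return norms


norms = grad_norms({"layer.weight": [3.0, 4.0], "layer.bias": [0.0]})
print(norms)  # {'layer.weight': 5.0, 'layer.bias': 0.0, 'total': 5.0}
```

A per-parameter norm that grows by orders of magnitude, or collapses to zero, points to the layer where training is going wrong.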

Set how much of the training set to check

If you don’t want to check 100% of the training set (for debugging, or if it’s huge), set limit_train_batches. A float is read as a fraction of the training set and an int as a number of batches.

limit_train_batches will be overwritten by overfit_batches if overfit_batches > 0.

# DEFAULT
trainer = Trainer(limit_train_batches=1.0)

# check 10% only
trainer = Trainer(limit_train_batches=0.1)

# check 10 batches only
trainer = Trainer(limit_train_batches=10)
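The float/int distinction can be expressed as a small hypothetical helper (num_batches_to_run is illustrative, not a Lightning function): a float is treated as a fraction of the available batches, an int as an absolute number of batches capped at what the dataloader provides. The overfit_batches override is ignored here.

```python
def num_batches_to_run(num_batches, limit_train_batches):
    """Hypothetical helper: how many training batches a limit selects.

    A float is read as a fraction of the epoch, an int as an absolute count.
    """
    if isinstance(limit_train_batches, int):
        return min(num_batches, limit_train_batches)
    return int(num_batches * limit_train_batches)


print(num_batches_to_run(500, 1.0))  # 500
print(num_batches_to_run(500, 0.1))  # 50
print(num_batches_to_run(500, 10))   # 10
```

Note the consequence of the type check: 1.0 means the whole training set, while the int 1 means a single batch.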

Packed sequences as inputs

When using PackedSequence, do two things:

  1. Return either a padded tensor from the dataset or a list of variable-length tensors from the dataloader's collate_fn (the example below shows the list implementation).

  2. Pack the sequence in forward or in the training and validation steps, depending on the use case.

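from torch.nn.utils import rnn

# For use in dataloader: keep variable-length tensors as plain lists
def collate_fn(batch):
    x = [item[0] for item in batch]
    y = [item[1] for item in batch]
    return x, y

# In module
def training_step(self, batch, batch_idx):
    # enforce_sorted=False lets the sequences arrive in any length order
    x = rnn.pack_sequence(batch[0], enforce_sorted=False)
    y = rnn.pack_sequence(batch[1], enforce_sorted=False)
    # ... run your model on x, compute the loss against y, and return it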
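To make packing less opaque, this pure-Python sketch mimics the layout that torch.nn.utils.rnn.pack_sequence(..., enforce_sorted=False) produces: sequences are ordered longest first, their elements are interleaved time step by time step, and batch_sizes records how many sequences are still active at each step. The helper pack is hypothetical and uses plain lists; PyTorch stores these fields as tensors and may order sequences of equal length differently.

```python
def pack(sequences):
    """Mimic the PackedSequence layout with plain lists.

    Returns (data, batch_sizes, sorted_indices).
    """
    # indices of the sequences, longest first
    sorted_indices = sorted(
        range(len(sequences)), key=lambda i: len(sequences[i]), reverse=True
    )
    ordered = [sequences[i] for i in sorted_indices]
    data, batch_sizes = [], []
    max_len = len(ordered[0]) if ordered else 0
    for t in range(max_len):
        # every sequence still "alive" at time step t contributes one element
        step = [seq[t] for seq in ordered if len(seq) > t]
        data.extend(step)
        batch_sizes.append(len(step))
    return data, batch_sizes, sorted_indices


data, batch_sizes, order = pack([[6], [1, 2, 3], [4, 5]])
print(data)         # [1, 4, 6, 2, 5, 3]
print(batch_sizes)  # [3, 2, 1]
print(order)        # [1, 2, 0]
```

No padding is stored: the RNN processes 3 inputs at the first step, 2 at the second, and 1 at the third.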

Truncated Backpropagation Through Time

There are times when multiple backwards passes are needed for each batch.

For example, it may save memory to use Truncated Backpropagation Through Time when training RNNs.

When this flag is enabled, each batch is split along the time dimension into chunks of truncated_bptt_steps time steps, and each chunk is passed to training_step(…) separately. A default splitting function is provided; however, you can override it for more flexibility. See tbptt_split_batch.

# DEFAULT (single backwards pass per batch)
trainer = Trainer(truncated_bptt_steps=None)

# (split batch into sequences of size 2)
trainer = Trainer(truncated_bptt_steps=2)
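A minimal sketch of splitting a batch along the time dimension, assuming every item in the batch is a list of per-sample sequences of equal length. The helper tbptt_split is hypothetical and written for plain lists; it is modeled loosely on the default splitting behavior, not copied from tbptt_split_batch. When the sequence length is not a multiple of the split size, the last chunk is shorter.

```python
def tbptt_split(batch, split_size):
    """Split every (batch, time) item in `batch` into consecutive time chunks.

    `batch` is a list such as [x, y], where each item is a list of per-sample
    sequences. Returns one list [x_chunk, y_chunk, ...] per time window.
    """
    time_len = len(batch[0][0])  # assumes all items share the same time length
    splits = []
    for t in range(0, time_len, split_size):
        splits.append([[seq[t:t + split_size] for seq in item] for item in batch])
    return splits


x = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]  # 2 samples, 5 time steps
y = [[0, 0, 1, 1, 0], [1, 0, 1, 0, 1]]
splits = tbptt_split([x, y], split_size=2)
for i, (x_chunk, y_chunk) in enumerate(splits):
    print(i, x_chunk, y_chunk)
```

With truncated_bptt_steps=2, five time steps yield three chunks, so the loop would run three forward/backward passes for this one batch.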

NaN detection and intervention

When the terminate_on_nan flag is enabled, after every forward pass during training, Lightning will check that

  1. the loss you return in training_step is finite (not NaN and not +/-inf)

  2. the model parameters have finite values.

Lightning will terminate the training loop with an error message if NaN or infinite values are detected. If this happens, you should investigate numerically unstable operations in your model.

# DEFAULT (won't perform the NaN check)
trainer = Trainer(terminate_on_nan=False)

# (NaN check each batch and terminate on NaN or infinite values)
trainer = Trainer(terminate_on_nan=True)
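The two checks can be sketched in plain Python with math.isfinite, which returns False for NaN and for +/-inf. The helper check_finite, its parameter layout, and its error messages are hypothetical; they illustrate the idea rather than reproduce detect_nan_tensors.

```python
import math


def check_finite(loss, params):
    """Raise ValueError if the loss or any parameter value is NaN or +/-inf.

    `params` maps parameter names to flat lists of floats (hypothetical layout).
    """
    # check 1: the loss returned by training_step
    if not math.isfinite(loss):
        raise ValueError(f"The loss returned in training_step is {loss}.")
    # check 2: every model parameter value
    for name, values in params.items():
        if not all(math.isfinite(v) for v in values):
            raise ValueError(f"Detected NaN and/or inf values in parameter {name}.")


check_finite(0.25, {"w": [0.1, -0.3]})  # passes silently
try:
    check_finite(float("nan"), {"w": [0.1]})
except ValueError as err:
    print(err)  # The loss returned in training_step is nan.
```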
class pytorch_lightning.trainer.training_loop.TrainerTrainLoopMixin

Bases: abc.ABC

_get_optimizers_iterable()
abstract add_progress_bar_metrics(*args)

Warning: this is just an empty shell for code implemented in another class.

call_optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
check_checkpoint_callback(should_check_val)
abstract clip_gradients()

Warning: this is just an empty shell for code implemented in another class.

abstract detect_nan_tensors(*args)

Warning: this is just an empty shell for code implemented in another class.

abstract get_model()

Warning: this is just an empty shell for code implemented in another class.

Return type

LightningModule

abstract has_arg(*args)

Warning: this is just an empty shell for code implemented in another class.

increment_accumulated_grad_global_step()
abstract is_function_implemented(*args, **kwargs)

Warning: this is just an empty shell for code implemented in another class.

abstract is_overridden(*args)

Warning: this is just an empty shell for code implemented in another class.

abstract log_metrics(*args)

Warning: this is just an empty shell for code implemented in another class.

optimizer_closure(split_batch, batch_idx, opt_idx, optimizer, hiddens)

Wrap the forward step in a closure so that second-order methods work.
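Why a closure? Optimizers such as torch.optim.LBFGS call the closure more than once inside a single step() to re-evaluate the loss and gradients. The hypothetical toy optimizer below (a backtracking line search, which is not itself a second-order method) shows the same calling pattern on a one-parameter problem: step() runs the closure once at the current point and again at every trial point.

```python
class LineSearchSGD:
    """Hypothetical toy optimizer whose step() evaluates a closure several times."""

    def __init__(self, params, lr=1.0, shrink=0.5, max_tries=20):
        self.params = params  # list holding one float, updated in place
        self.lr = lr
        self.shrink = shrink
        self.max_tries = max_tries

    def step(self, closure):
        loss, grad = closure()  # evaluate loss and gradient at the current point
        if grad == 0.0:
            return loss  # already at a stationary point
        start = self.params[0]
        step_size = self.lr
        for _ in range(self.max_tries):
            self.params[0] = start - step_size * grad
            trial_loss, _ = closure()  # re-run the forward pass at the trial point
            if trial_loss < loss:
                return trial_loss  # accept the first trial that lowers the loss
            step_size *= self.shrink
        self.params[0] = start  # no improvement found: restore the parameter
        return loss


params = [3.0]
calls = []


def closure():
    """Stand-in for a forward + backward pass: loss = (w - 1)^2 and its gradient."""
    calls.append(1)
    w = params[0]
    return (w - 1.0) ** 2, 2.0 * (w - 1.0)


optimizer = LineSearchSGD(params)
for _ in range(5):
    optimizer.step(closure)
print(params[0], len(calls))  # 1.0 7
```

On the first step the closure runs three times (current point, a rejected trial, an accepted trial); a loss computed once before calling step() could not support this.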

prepare_train_loop_dataloader(train_dataloader)
abstract process_output(*args)

Warning: this is just an empty shell for code implemented in another class.

abstract reset_train_dataloader(*args)

Warning: this is just an empty shell for code implemented in another class.

abstract reset_val_dataloader(model)

Warning: this is just an empty shell for code implemented in another class.

run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)
abstract run_evaluation(*args, **kwargs)

Warning: this is just an empty shell for code implemented in another class.

run_on_epoch_end_hook(model)
run_on_epoch_start_hook(model)
run_training_batch(batch, batch_idx)
run_training_epoch()
run_training_epoch_end(epoch_output)
run_training_teardown()
save_loggers_in_training_loop(batch_idx)
save_train_loop_metrics_to_loggers(batch_idx, batch_output)
should_check_val(batch_idx, is_last_batch)
sync_horovod()
train()
training_forward(batch, batch_idx, opt_idx, hiddens)

Handle forward for each training case (distributed, single GPU, etc.).

Parameters

batch

batch_idx

abstract transfer_batch_to_gpu(*args)

Warning: this is just an empty shell for code implemented in another class.

abstract transfer_batch_to_tpu(*args)

Warning: this is just an empty shell for code implemented in another class.

update_learning_rates(interval)

Update learning rates.

Parameters

interval (str) – either ‘epoch’ or ‘step’.
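A simplified sketch of interval-based updating, assuming each scheduler is described by a dict with "scheduler" and "interval" keys (similar in spirit to the scheduler dictionaries configure_optimizers can return; frequency handling and ReduceLROnPlateau monitoring are omitted). CountingScheduler and step_schedulers are hypothetical stand-ins.

```python
class CountingScheduler:
    """Hypothetical stand-in for an LR scheduler that only counts step() calls."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1


def step_schedulers(lr_schedulers, interval):
    """Step only the schedulers whose configured interval matches `interval`."""
    for config in lr_schedulers:
        if config["interval"] == interval:
            config["scheduler"].step()


per_epoch = {"scheduler": CountingScheduler(), "interval": "epoch"}
per_step = {"scheduler": CountingScheduler(), "interval": "step"}
schedulers = [per_epoch, per_step]

for epoch in range(2):
    for batch_idx in range(3):
        step_schedulers(schedulers, "step")  # after every training batch
    step_schedulers(schedulers, "epoch")     # once at the end of each epoch

print(per_epoch["scheduler"].steps, per_step["scheduler"].steps)  # 2 6
```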

update_train_loop_lr_schedulers()
accumulate_grad_batches: int = None
accumulation_scheduler: ... = None
batch_idx: int = None
callback_metrics: ... = None
callbacks: List[Callback] = None
check_val_every_n_epoch: ... = None
data_parallel_device_ids: ... = None
disable_validation: bool = None
early_stop_callback: ... = None
fast_dev_run: ... = None
global_rank: int = None
global_step: int = None
interactive_ddp_procs: ... = None
interrupted: bool = None
log_save_interval: float = None
logger: Union[LightningLoggerBase, bool] = None
lr_schedulers: ... = None
max_epochs: int = None
max_steps: int = None
min_epochs: int = None
min_steps: int = None
model: LightningModule = None
num_training_batches: int = None
on_batch_end: Callable = None
on_batch_start: Callable = None
on_epoch_end: Callable = None
on_epoch_start: Callable = None
on_gpu: bool = None
on_keyboard_interrupt: Callable = None
on_train_end: Callable = None
on_train_start: Callable = None
on_validation_end: Callable = None
optimizer_frequencies: ... = None
optimizers: ... = None
precision: ... = None
profiler: ... = None
progress_bar_dict: ... = None
reduce_lr_on_plateau_scheduler: ... = None
reload_dataloaders_every_epoch: bool = None
row_log_interval: float = None
running_loss: ... = None
single_gpu: bool = None
terminate_on_nan: bool = None
testing: bool = None
total_batch_idx: int = None
tpu_id: int = None
track_grad_norm: ... = None
train_dataloader: DataLoader = None
truncated_bptt_steps: ... = None
use_ddp: bool = None
use_ddp2: bool = None
use_dp: bool = None
use_horovod: bool = None
use_tpu: bool = None
val_check_batch: ... = None
pytorch_lightning.trainer.training_loop._with_is_last(iterable)

Pass through values from the given iterable with an added boolean indicating if this is the last item. See https://stackoverflow.com/a/1630350
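The linked technique is a one-item lookahead: hold back the previous value and only emit it once you know whether another value follows. A minimal generator sketch, named with_is_last here to keep it distinct from the module's private function; this version simply yields nothing for an empty iterable.

```python
def with_is_last(iterable):
    """Yield (value, is_last) pairs, flagging the final item of `iterable`."""
    it = iter(iterable)
    try:
        last = next(it)
    except StopIteration:
        return  # empty input: nothing to yield
    for val in it:
        yield last, False  # another value exists, so `last` was not the final one
        last = val
    yield last, True


print(list(with_is_last(["a", "b", "c"])))
# [('a', False), ('b', False), ('c', True)]
```

This works with dataloaders that have no known length, since it never calls len() on the iterable.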