pytorch_lightning.trainer.training_tricks module

class pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin[source]

Bases: abc.ABC

_TrainerTrainingTricksMixin__scale_batch_dump_params()[source]
_TrainerTrainingTricksMixin__scale_batch_reset_params(model, steps_per_trial)[source]
_TrainerTrainingTricksMixin__scale_batch_restore_params()[source]
clip_gradients()[source]
configure_accumulated_gradients(accumulate_grad_batches)[source]
detect_nan_tensors(loss)[source]
Return type

None

abstract fit(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

abstract get_model()[source]

Warning: this is just an empty shell for code implemented in another class.

Return type

LightningModule

print_nan_gradients()[source]
Return type

None

abstract restore(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

abstract save_checkpoint(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

scale_batch_size(model, mode='power', steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name='batch_size')[source]

Iteratively tries to find the largest batch size for a given model that does not trigger an out of memory (OOM) error.

Parameters
  • model (LightningModule) – Model to fit.

  • mode (str) – string setting the search mode. Either 'power' or 'binsearch'. If mode is 'power', the batch size is repeatedly doubled until an OOM error is encountered. If mode is 'binsearch', the batch size is also initially doubled, but after the first OOM error a binary search is performed between the last successful batch size and the batch size that failed.

  • steps_per_trial (int) – number of steps to run with a given batch size. Ideally 1 should be enough to test whether an OOM error occurs; however, in practice a few are needed.

  • init_val (int) – initial batch size to start the search with.

  • max_trials (int) – max number of batch size increases performed before the algorithm is terminated.

  • batch_arg_name (str) – name of the field in model.hparams where the batch size is stored.
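
A minimal usage sketch (not taken from the docs above), assuming a hypothetical MyLightningModule whose hparams contains batch_size and whose train_dataloader() reads it; returning the found size and assigning it back is an assumption, not documented behaviour:

    import pytorch_lightning as pl

    # Hypothetical module: hparams.batch_size must exist and be used by
    # train_dataloader() for the search to have any effect.
    model = MyLightningModule(hparams)
    trainer = pl.Trainer()

    # Search for the largest batch size that fits in memory (assumption:
    # the found value is returned), then train with it.
    new_size = trainer.scale_batch_size(model, mode='power', init_val=2, max_trials=25)
    model.hparams.batch_size = new_size
    trainer.fit(model)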

default_root_dir: str = None[source]
gradient_clip_val: ... = None[source]
on_gpu: bool = None[source]
precision: int = None[source]
progress_bar_callback: ... = None[source]
pytorch_lightning.trainer.training_tricks._adjust_batch_size(trainer, batch_arg_name='batch_size', factor=1.0, value=None, desc=None)[source]
Function for adjusting the batch size. It is expected that the user has provided a model with an hparams field called batch_size, i.e. model.hparams.batch_size should exist.

Parameters
  • trainer – instance of pytorch_lightning.Trainer

  • batch_arg_name (str) – field where batch_size is stored in model.hparams

  • factor (float) – value which the old batch size is multiplied by to get the new batch size

  • value (Optional[int]) – if a value is given, will override the batch size with this value. Note that the value of factor will not have an effect in this case

  • desc (Optional[str]) – either succeeded or failed. Used purely for logging
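
A minimal sketch of the adjustment this helper describes, based only on the parameter list above; the function name, logging, and exact body are illustrative, not the library source:

    import logging

    log = logging.getLogger(__name__)

    def _adjust_batch_size_sketch(trainer, batch_arg_name='batch_size',
                                  factor=1.0, value=None, desc=None):
        # Read the current size from model.hparams, then either scale it by
        # `factor` or override it with `value` (which takes precedence).
        model = trainer.get_model()
        batch_size = getattr(model.hparams, batch_arg_name)
        new_size = value if value is not None else int(batch_size * factor)
        if desc:
            log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
        setattr(model.hparams, batch_arg_name, new_size)
        return new_size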

pytorch_lightning.trainer.training_tricks._run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials)[source]

Batch scaling mode where the batch size is initially doubled at each iteration until an OOM error is encountered. Thereafter, the batch size is further refined using a binary search.
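
A rough sketch of the binsearch strategy described here; the control flow, OOM detection, and convergence check are assumptions, not the library implementation:

    def _run_binsearch_scaling_sketch(trainer, model, new_size, batch_arg_name, max_trials):
        # Phase 1: double the batch size until an OOM error occurs.
        # Phase 2: binary-search between the last size that fit (low) and
        # the first size that failed (high).
        low, high = 1, None
        for _ in range(max_trials):
            try:
                trainer.fit(model)                    # short trial run
                low = new_size
                if high is None:
                    new_size = _adjust_batch_size(trainer, batch_arg_name,
                                                  factor=2.0, desc='succeeded')
                else:
                    midval = (low + high) // 2
                    if midval == low:                 # search interval closed
                        break
                    new_size = _adjust_batch_size(trainer, batch_arg_name,
                                                  value=midval, desc='succeeded')
            except RuntimeError as exc:
                if 'out of memory' not in str(exc):
                    raise                             # not an OOM -> real failure
                high = new_size
                new_size = _adjust_batch_size(trainer, batch_arg_name,
                                              value=(low + high) // 2, desc='failed')
        return new_size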

pytorch_lightning.trainer.training_tricks._run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials)[source]

Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered.
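
A rough sketch of the power strategy described here; again illustrative only, with the OOM handling an assumption:

    def _run_power_scaling_sketch(trainer, model, new_size, batch_arg_name, max_trials):
        # Keep doubling the batch size until fitting raises an out-of-memory
        # error, then back off once and stop.
        for _ in range(max_trials):
            try:
                trainer.fit(model)                    # short trial run
                new_size = _adjust_batch_size(trainer, batch_arg_name,
                                              factor=2.0, desc='succeeded')
            except RuntimeError as exc:
                if 'out of memory' not in str(exc):
                    raise                             # not an OOM -> real failure
                new_size = _adjust_batch_size(trainer, batch_arg_name,
                                              factor=0.5, desc='failed')
                break
        return new_size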