pytorch_lightning.trainer.training_tricks module
class pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin

Bases: abc.ABC
abstract get_model()

Warning: this is just an empty shell for code implemented in another class.

Return type: LightningModule
abstract restore(*args)

Warning: this is just an empty shell for code implemented in another class.
abstract save_checkpoint(*args)

Warning: this is just an empty shell for code implemented in another class.
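These stubs exist because the mixin's concrete methods call them while the Trainer class supplies the implementations. A minimal, self-contained sketch of that pattern (all names here are illustrative, not the library's):

from abc import ABC, abstractmethod

class TricksMixin(ABC):
    @abstractmethod
    def get_model(self):
        """Empty shell; implemented by the class that mixes this in."""

    def use_model(self):
        # The mixin can rely on the hook even though it does not define it.
        return self.get_model()

class Trainer(TricksMixin):
    def __init__(self, model):
        self._model = model

    def get_model(self):
        return self._model

print(Trainer(model='net').use_model())  # -> 'net'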
scale_batch_size(model, mode='power', steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name='batch_size')

Will iteratively try to find the largest batch size for a given model that does not give an out-of-memory (OOM) error. A usage sketch follows the parameter list below.
Parameters

- model (LightningModule) – Model to fit.
- mode (str) – String setting the search mode. Either 'power' or 'binsearch'. If mode is 'power', the batch size is repeatedly multiplied by 2 until an OOM error occurs. If mode is 'binsearch', the batch size is also initially doubled, but after encountering an OOM error a binary search is run between the last successful batch size and the batch size that failed.
- steps_per_trial (int) – Number of steps to run with a given batch size. Ideally 1 should be enough to test whether an OOM error occurs, but in practice a few are needed.
- init_val (int) – Initial batch size to start the search with.
- max_trials (int) – Maximum number of batch size increases before the algorithm terminates.
- batch_arg_name (str) – Field where the batch size is stored in model.hparams (see _adjust_batch_size below).
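A hedged usage sketch, assuming the pre-1.0 API documented here (scale_batch_size called on the trainer, returning the new batch size) and a hypothetical module whose hparams carry a batch_size field:

import pytorch_lightning as pl

model = MyLightningModule(hparams)  # hypothetical model; hparams.batch_size must exist
trainer = pl.Trainer()

# Doubles the batch size until OOM; in 'binsearch' mode, then refines by binary search.
new_size = trainer.scale_batch_size(model, mode='binsearch', init_val=2)
model.hparams.batch_size = new_size  # adopt the largest batch size that fit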
pytorch_lightning.trainer.training_tricks._adjust_batch_size(trainer, batch_arg_name='batch_size', factor=1.0, value=None, desc=None)

Function for adjusting the batch size. It is expected that the user has provided a model that has an hparams field called batch_size, i.e. model.hparams.batch_size should exist. A sketch of the logic follows the parameter list below.
Parameters

- trainer – Instance of pytorch_lightning.Trainer.
- batch_arg_name (str) – Field where batch_size is stored in model.hparams.
- factor (float) – Value which the old batch size is multiplied by to get the new batch size.
- value (Optional[int]) – If a value is given, it will override the batch size with this value. Note that the value of factor will have no effect in this case.
- desc (Optional[str]) – Either 'succeeded' or 'failed'. Used purely for logging.
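A minimal sketch of the adjustment logic described above (a standalone approximation, not the library's exact implementation):

def adjust_batch_size(model, batch_arg_name='batch_size', factor=1.0,
                      value=None, desc=None):
    old_size = getattr(model.hparams, batch_arg_name)
    # An explicit value overrides scaling by factor.
    new_size = value if value is not None else int(old_size * factor)
    if desc:
        print(f'Batch size {old_size} {desc}, trying batch size {new_size}')
    setattr(model.hparams, batch_arg_name, new_size)
    return new_size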
pytorch_lightning.trainer.training_tricks._run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials)

Batch scaling mode where the size is initially doubled at each iteration until an OOM error is encountered. Thereafter, the batch size is further refined using a binary search.
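A hedged sketch of this doubling-then-binary-search strategy, with a hypothetical fits_in_memory(model, size) probe standing in for the trial fitting runs the library performs internally:

def run_binsearch_scaling(model, new_size, max_trials, fits_in_memory):
    low, high = None, None
    for _ in range(max_trials):
        if fits_in_memory(model, new_size):
            low = new_size
            if high is None:
                new_size *= 2          # power phase: no failure seen yet, keep doubling
            elif high - low <= 1:
                break                  # window closed: low is the largest size that fits
            else:
                new_size = (low + high) // 2  # binary search within (low, high)
        else:
            high = new_size
            new_size = (low + high) // 2 if low is not None else max(1, new_size // 2)
    return low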