
pytorch_lightning.trainer.trainer module

class pytorch_lightning.trainer.trainer.Trainer(logger=True, checkpoint_callback=True, early_stop_callback=False, callbacks=None, default_root_dir=None, gradient_clip_val=0, process_position=0, num_nodes=1, num_processes=1, gpus=None, auto_select_gpus=False, num_tpu_cores=None, log_gpu_memory=None, progress_bar_refresh_rate=1, overfit_pct=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=1, max_epochs=1000, min_epochs=1, max_steps=None, min_steps=None, train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0, val_check_interval=1.0, log_save_interval=100, row_log_interval=10, add_row_log_interval=None, distributed_backend=None, precision=32, print_nan_grads=False, weights_summary='full', weights_save_path=None, num_sanity_val_steps=2, truncated_bptt_steps=None, resume_from_checkpoint=None, profiler=None, benchmark=False, deterministic=False, reload_dataloaders_every_epoch=False, auto_lr_find=False, replace_sampler_ddp=True, progress_bar_callback=True, terminate_on_nan=False, auto_scale_batch_size=False, amp_level='O1', default_save_path=None, gradient_clip=None, nb_gpu_nodes=None, max_nb_epochs=None, min_nb_epochs=None, use_amp=None, show_progress_bar=None, nb_sanity_val_steps=None, **kwargs)

Bases: pytorch_lightning.trainer.training_io.TrainerIOMixin, pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin, pytorch_lightning.trainer.auto_mix_precision.TrainerAMPMixin, pytorch_lightning.trainer.distrib_parts.TrainerDPMixin, pytorch_lightning.trainer.distrib_data_parallel.TrainerDDPMixin, pytorch_lightning.trainer.logging.TrainerLoggingMixin, pytorch_lightning.trainer.model_hooks.TrainerModelHooksMixin, pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin, pytorch_lightning.trainer.data_loading.TrainerDataLoadingMixin, pytorch_lightning.trainer.evaluation_loop.TrainerEvaluationLoopMixin, pytorch_lightning.trainer.training_loop.TrainerTrainLoopMixin, pytorch_lightning.trainer.callback_config.TrainerCallbackConfigMixin, pytorch_lightning.trainer.callback_hook.TrainerCallbackHookMixin, pytorch_lightning.trainer.lr_finder.TrainerLRFinderMixin, pytorch_lightning.trainer.deprecated_api.TrainerDeprecatedAPITillVer0_8, pytorch_lightning.trainer.deprecated_api.TrainerDeprecatedAPITillVer0_9

Customize every aspect of training via flags.

Parameters
  • logger (Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]) – Logger (or iterable collection of loggers) for experiment tracking.

  • checkpoint_callback (Union[ModelCheckpoint, bool]) – Callback for checkpointing.

  • early_stop_callback (pytorch_lightning.callbacks.EarlyStopping) – Callback for early stopping.

  • callbacks (Optional[List[Callback]]) – Add a list of callbacks.

  • default_root_dir (Optional[str]) – Default path for logs and weights when no logger or checkpoint_callback is passed.

  • default_save_path

    Warning

    Deprecated since version 0.7.3.

    Use default_root_dir instead. Will be removed in 0.9.0.

  • gradient_clip_val (float) – Gradient clipping value; 0 means don’t clip.

  • gradient_clip

    Warning

    Deprecated since version 0.7.0.

    Use gradient_clip_val instead. Will be removed in 0.9.0.

  • process_position (int) – Orders the progress bar when running multiple models on the same machine.

  • num_nodes (int) – Number of GPU nodes for distributed training.

  • nb_gpu_nodes

    Warning

    Deprecated since version 0.7.0.

    Use num_nodes instead. Will be removed in 0.9.0.

  • gpus (Union[List[int], str, int, None]) – Which GPUs to train on.

  • auto_select_gpus (bool) – If enabled and gpus is an integer, pick available GPUs automatically. This is especially useful when GPUs are configured to be in “exclusive mode”, such that only one process at a time can access them.

  • num_tpu_cores (Optional[int]) – How many TPU cores to train on (1 or 8).

  • log_gpu_memory (Optional[str]) – One of None, ‘min_max’, or ‘all’. Logging GPU memory might slow performance.

  • show_progress_bar

    Warning

    Deprecated since version 0.7.2.

    Set progress_bar_refresh_rate to a positive integer to enable the progress bar. Will be removed in 0.9.0.

  • progress_bar_refresh_rate (int) – How often to refresh the progress bar (in steps). A value of 0 disables the progress bar. Ignored when a custom progress bar callback is passed to callbacks.

  • overfit_pct (float) – Fraction of the training, validation, and test datasets to use, e.g. to check that the model can overfit a small subset.

  • track_grad_norm (int) – -1 means no tracking. Otherwise, tracks the gradient norm of that order (e.g. 2 for the L2 norm).

  • check_val_every_n_epoch (int) – Run validation every n training epochs.

  • fast_dev_run (bool) – Runs 1 batch of training, validation, and test to find any bugs (i.e. a sort of unit test).

  • accumulate_grad_batches (Union[int, Dict[int, int], List[list]]) – Accumulates gradients every k batches, or as set up in the dict, which maps an epoch to an accumulation factor.

  • max_epochs (int) – Stop training once this number of epochs is reached.

  • max_nb_epochs

    Warning

    Deprecated since version 0.7.0.

    Use max_epochs instead. Will be removed in 0.9.0.

  • min_epochs (int) – Force training for at least this many epochs.

  • min_nb_epochs

    Warning

    Deprecated since version 0.7.0.

    Use min_epochs instead. Will be removed in 0.9.0.

  • max_steps (Optional[int]) – Stop training after this number of steps. Disabled by default (None).

  • min_steps (Optional[int]) – Force training for at least this number of steps. Disabled by default (None).

  • train_percent_check (float) – Fraction of the training dataset to use.

  • val_percent_check (float) – Fraction of the validation dataset to use.

  • test_percent_check (float) – Fraction of the test dataset to use.

  • val_check_interval (float) – How often within one training epoch to check the validation set. A float is interpreted as a fraction of the training epoch; an int as a number of training batches.

  • log_save_interval (int) – How often (every n batches) to write logs to disk.

  • row_log_interval (int) – How often (every n batches) to add logging rows (does not write to disk).

  • add_row_log_interval

    Warning

    Deprecated since version 0.7.0.

    Use row_log_interval instead. Will be removed in 0.9.0.

  • distributed_backend (Optional[str]) – The distributed backend to use (e.g. ‘dp’, ‘ddp’, ‘ddp2’).

  • use_amp

    Warning

    Deprecated since version 0.7.0.

    Use precision instead. Will be removed in 0.9.0.

  • precision (int) – Either full precision (32) or half precision (16).

  • print_nan_grads (bool) –

    Warning

    Deprecated since version 0.7.2.

    Has no effect. When detected, NaN gradients will be printed automatically. Will be removed in 0.9.0.

  • weights_summary (Optional[str]) – Prints a summary of the weights when training begins. Options: ‘full’, ‘top’, or None.

  • weights_save_path (Optional[str]) – Where to save weights, if specified. Overrides default_root_dir for checkpoints only. Use this if you need the checkpoints stored in a different place than the logs written to default_root_dir.

  • amp_level (str) – The NVIDIA Apex AMP optimization level to use (O1, O2, etc.).

  • num_sanity_val_steps (int) – Number of validation batches to run as a sanity check before starting the training routine.

  • nb_sanity_val_steps

    Warning

    Deprecated since version 0.7.0.

    Use num_sanity_val_steps instead. Will be removed in 0.8.0.

  • truncated_bptt_steps (Optional[int]) – Truncated backpropagation through time: performs backprop every k steps of a much longer sequence.

  • resume_from_checkpoint (Optional[str]) – To resume training from a specific checkpoint, pass in the path here.

  • profiler (Union[BaseProfiler, bool, None]) – To profile individual steps during training and assist in identifying bottlenecks.

  • reload_dataloaders_every_epoch (bool) – Set to True to reload dataloaders every epoch.

  • auto_lr_find (Union[bool, str]) – If set to True, will initially run a learning rate finder, trying to optimize the initial learning rate for faster convergence. Sets the learning rate in self.hparams.lr or self.hparams.learning_rate in the LightningModule. To use a different key, set a string with the key name instead of True.

  • replace_sampler_ddp (bool) – Explicitly enables or disables sampler replacement. If not specified, this is toggled automatically when DDP is used.

  • benchmark (bool) – If True, enables cudnn.benchmark.

  • deterministic (bool) – If True, enables cudnn.deterministic.

  • terminate_on_nan (bool) – If set to True, will terminate training (by raising a ValueError) at the end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

  • auto_scale_batch_size (Union[str, bool]) – If set to True, will initially run a batch size finder trying to find the largest batch size that fits into memory. The result will be stored in self.hparams.batch_size in the LightningModule. Additionally, can be set to either ‘power’, which estimates the batch size through a power search, or ‘binsearch’, which estimates it through a binary search.
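When accumulate_grad_batches is a dict, it maps epochs to accumulation factors. A minimal sketch of how such a schedule could be resolved for a given epoch, assuming (1) epochs are counted from 1, (2) the factor attached to the largest key not exceeding the current epoch applies, and (3) the factor is 1 before the first key. `accumulation_factor` is a hypothetical helper, not part of the Lightning API.

```python
from typing import Dict


def accumulation_factor(schedule: Dict[int, int], epoch: int) -> int:
    """Return the gradient-accumulation factor for a 1-indexed epoch.

    Hypothetical helper: the factor of the largest scheduled epoch that is
    <= `epoch` applies; before the first scheduled epoch the factor is 1.
    """
    factor = 1
    for start_epoch in sorted(schedule):
        if epoch >= start_epoch:
            factor = schedule[start_epoch]
        else:
            break  # keys are sorted, so no later key can apply
    return factor


schedule = {5: 3, 10: 20}
print([accumulation_factor(schedule, e) for e in (1, 4, 5, 9, 10, 12)])
# [1, 1, 3, 3, 20, 20]
```

Under these assumptions, {5: 3, 10: 20} means no accumulation for epochs 1 to 4, accumulation over 3 batches for epochs 5 to 9, and over 20 batches from epoch 10 on.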
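The two search modes of auto_scale_batch_size can be illustrated with a predicate-driven search. In this sketch `fits(batch_size)` is a hypothetical stand-in for "a trial step at this batch size completes without running out of memory", and fitting is assumed to be monotonic (if a size fits, every smaller size fits). The starting size of 2 and the trial limit are arbitrary assumptions; this illustrates power versus binary search and is not Lightning's implementation.

```python
from typing import Callable, Optional, Tuple


def _double_until_failure(fits: Callable[[int], bool], init: int,
                          max_trials: int) -> Tuple[int, Optional[int]]:
    """Return (largest size that fit or 0, first size that failed or None)."""
    size, last_ok = init, 0
    for _ in range(max_trials):
        if not fits(size):
            return last_ok, size
        last_ok = size
        size *= 2
    return last_ok, None  # trial limit reached without a failure


def power_search(fits: Callable[[int], bool], init: int = 2, max_trials: int = 25) -> int:
    """'power' mode sketch: keep doubling; return the last size that fit."""
    last_ok, _ = _double_until_failure(fits, init, max_trials)
    return last_ok


def binsearch(fits: Callable[[int], bool], init: int = 2, max_trials: int = 25) -> int:
    """'binsearch' mode sketch: double until failure, then bisect the gap."""
    low, high = _double_until_failure(fits, init, max_trials)
    if high is None or low == 0:
        return low  # no failure observed, or even `init` did not fit
    # Invariant: fits(low) is True and fits(high) is False.
    while high - low > 1:
        mid = (low + high) // 2
        if fits(mid):
            low = mid
        else:
            high = mid
    return low


def fits_in_memory(batch_size: int) -> bool:
    """Pretend that memory allows at most 100 samples per batch."""
    return batch_size <= 100


print(power_search(fits_in_memory), binsearch(fits_in_memory))  # 64 100
```

The power search needs fewer trial runs but can leave up to half of the usable batch size unused; the binary search spends a few more trials to close that gap.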

_Trainer__attach_dataloaders(model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None)
_Trainer__set_random_port()

Sets a random port, since ports might collide when running DDP that is not managed by SLURM.
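A minimal sketch of the idea, assuming the port is passed to the processes through the MASTER_PORT environment variable (read by torch.distributed's env:// initialization) and that a port drawn at random from 10000 to 19000 is unlikely to collide with another job on the same machine. The helper name and the port range are assumptions for illustration.

```python
import os
import random


def set_random_master_port(low: int = 10000, high: int = 19000) -> int:
    """Hypothetical helper: keep an existing MASTER_PORT, otherwise pick one at random.

    Reusing a user-provided port keeps all processes of one job in agreement;
    a random port lowers the chance that two independent jobs on the same
    machine try to bind the same address.
    """
    if "MASTER_PORT" in os.environ:
        return int(os.environ["MASTER_PORT"])
    port = random.randint(low, high)  # inclusive on both ends
    os.environ["MASTER_PORT"] = str(port)
    return port
```

With this approach the helper would be called once in the launching process, before workers are started, so that every worker inherits the same MASTER_PORT.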

_allowed_type()
Return type

Union[int, str]

_arg_default()
Return type

Union[int, str]

classmethod add_argparse_args(parent_parser)

Extends an existing argparse parser with the default Trainer attributes.

Parameters

parent_parser (ArgumentParser) – The custom cli arguments parser, which will be extended by the Trainer default arguments.

Only arguments of the allowed types (str, float, int, bool) will extend the parent_parser.

Examples

>>> import argparse
>>> import pprint
>>> parser = argparse.ArgumentParser()
>>> parser = Trainer.add_argparse_args(parser)
>>> args = parser.parse_args([])
>>> pprint.pprint(vars(args))  
{...
 'check_val_every_n_epoch': 1,
 'checkpoint_callback': True,
 'default_root_dir': None,
 'deterministic': False,
 'distributed_backend': None,
 'early_stop_callback': False,
 ...
 'logger': True,
 'max_epochs': 1000,
 'max_steps': None,
 'min_epochs': 1,
 'min_steps': None,
 ...
 'profiler': None,
 'progress_bar_callback': True,
 'progress_bar_refresh_rate': 1,
 ...}
Return type

ArgumentParser
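The mechanism described above can be sketched with the standard library alone: inspect a constructor signature and add one --flag per parameter whose default is of an allowed type. This is a simplified illustration, not Lightning's code: `ToyTrainer`, `add_init_args` and `str_to_bool` are hypothetical, types are inferred from defaults rather than from annotations, and booleans are parsed from strings such as "true"/"false" because argparse's `type=bool` would treat any non-empty string as True.

```python
import argparse
import inspect


def str_to_bool(value: str) -> bool:
    """Parse common textual booleans for argparse."""
    lowered = value.lower()
    if lowered in ("1", "true", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


class ToyTrainer:
    """Hypothetical stand-in for a class configured by keyword flags."""

    def __init__(self, max_epochs=1000, precision=32, fast_dev_run=False,
                 gradient_clip_val=0.0, amp_level="O1", callbacks=None):
        self.max_epochs = max_epochs


def add_init_args(parser: argparse.ArgumentParser, cls) -> argparse.ArgumentParser:
    """Add a --flag for every __init__ parameter whose default is str, float, int or bool."""
    for name, param in inspect.signature(cls.__init__).parameters.items():
        default = param.default
        if name == "self" or default is inspect.Parameter.empty:
            continue
        # bool must be checked before int, because bool is a subclass of int.
        if isinstance(default, bool):
            arg_type = str_to_bool
        elif isinstance(default, (str, float, int)):
            arg_type = type(default)
        else:
            continue  # e.g. callbacks=None is not an allowed type, so it is skipped
        parser.add_argument(f"--{name}", type=arg_type, default=default)
    return parser


parser = add_init_args(argparse.ArgumentParser(), ToyTrainer)
args = parser.parse_args(["--max_epochs", "5", "--fast_dev_run", "true"])
print(vars(args))
# {'max_epochs': 5, 'precision': 32, 'fast_dev_run': True, 'gradient_clip_val': 0.0, 'amp_level': 'O1'}
```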

check_model_configuration(model)

Checks that the model is configured correctly before training is started.

Parameters

model (LightningModule) – The model to test.

check_testing_model_configuration(model)
classmethod default_attributes()
fit(model, train_dataloader=None, val_dataloaders=None)

Runs the full optimization routine.

Parameters
  • model (LightningModule) – Model to fit.

  • train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.

  • val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloader method, this will be skipped.

Example:

# Option 1
# Define the train_dataloader() and val_dataloader() methods
# in the LightningModule
# RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer

trainer = Trainer()
model = LightningModule()  # in practice, an instance of your LightningModule subclass
trainer.fit(model)

# Option 2
# In production cases we might want to pass different datasets to the same model
# Recommended for PRODUCTION SYSTEMS
train, val = DataLoader(...), DataLoader(...)
trainer = Trainer()
model = LightningModule()  # in practice, an instance of your LightningModule subclass
trainer.fit(model, train_dataloader=train, val_dataloaders=val)

# Options 1 & 2 can be mixed: for example, the training set can be
# defined as part of the model, and validation data can then be fed to .fit()
classmethod from_argparse_args(args, **kwargs)

Create an instance from CLI arguments.

Example

>>> from argparse import ArgumentParser
>>> parser = ArgumentParser(add_help=False)
>>> parser = Trainer.add_argparse_args(parser)
>>> args = Trainer.parse_argparser(parser.parse_args([]))
>>> trainer = Trainer.from_argparse_args(args)
Return type

Trainer

classmethod get_deprecated_arg_names()

Returns a list with deprecated Trainer arguments.

Return type

List
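The renamed arguments documented in the Parameters list above can be handled with a small remapping step before construction. The sketch below is hypothetical (`RENAMED_ARGS` and `remap_deprecated_kwargs` are not Lightning names): it covers only the one-to-one renames listed in this section, leaves out use_amp and show_progress_bar because their replacements are not simple renames, and emits a DeprecationWarning for each old name that is used.

```python
import warnings
from typing import Any, Dict

# One-to-one renames documented in the Parameters list: old name -> replacement.
RENAMED_ARGS = {
    "default_save_path": "default_root_dir",
    "gradient_clip": "gradient_clip_val",
    "nb_gpu_nodes": "num_nodes",
    "max_nb_epochs": "max_epochs",
    "min_nb_epochs": "min_epochs",
    "add_row_log_interval": "row_log_interval",
    "nb_sanity_val_steps": "num_sanity_val_steps",
}


def remap_deprecated_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of kwargs with deprecated names replaced by their new names."""
    remapped = dict(kwargs)
    for old, new in RENAMED_ARGS.items():
        if old in remapped:
            warnings.warn(f"`{old}` is deprecated, use `{new}` instead", DeprecationWarning)
            value = remapped.pop(old)
            # An explicitly passed new-style argument wins over the deprecated one.
            remapped.setdefault(new, value)
    return remapped
```

For example, remap_deprecated_kwargs({"max_nb_epochs": 10, "gpus": 1}) returns {"gpus": 1, "max_epochs": 10} and warns once.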

classmethod get_init_arguments_and_types()

Scans the Trainer signature and returns argument names, types and default values.

Returns

(argument name, set with argument types, argument default value).

Return type

List with tuples of 3 values

Examples

>>> args = Trainer.get_init_arguments_and_types()
>>> import pprint
>>> pprint.pprint(sorted(args))  
[('accumulate_grad_batches',
  (<class 'int'>, typing.Dict[int, int], typing.List[list]),
  1),
 ...
 ('callbacks',
  (typing.List[pytorch_lightning.callbacks.base.Callback],
   <class 'NoneType'>),
   None),
 ('check_val_every_n_epoch', (<class 'int'>,), 1),
 ...
 ('max_epochs', (<class 'int'>,), 1000),
 ...
 ('precision', (<class 'int'>,), 32),
 ('print_nan_grads', (<class 'bool'>,), False),
 ('process_position', (<class 'int'>,), 0),
 ('profiler',
  (<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
   <class 'bool'>,
   <class 'NoneType'>),
  None),
 ...
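Roughly how such a scan can be done with the standard library: read each parameter's default from inspect.signature and split its type annotation into member types with typing.get_args. This is a hedged sketch on a hypothetical `toy_init` function, not Lightning's implementation; note that Optional[X] shows up as (X, NoneType) because Optional[X] is Union[X, None].

```python
import inspect
from typing import Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints


def toy_init(accumulate_grad_batches: Union[int, Dict[int, int]] = 1,
             max_epochs: int = 1000,
             max_steps: Optional[int] = None) -> None:
    """Hypothetical function whose signature mimics three Trainer flags."""


def init_arguments_and_types(func) -> List[Tuple[str, tuple, object]]:
    """Return (name, member types, default) for every annotated parameter of `func`."""
    hints = get_type_hints(func)
    result = []
    for name, param in inspect.signature(func).parameters.items():
        if name not in hints:
            continue  # unannotated parameters are skipped in this sketch
        annotation = hints[name]
        # Union[...] (and therefore Optional[...]) expands to its members;
        # any other annotation becomes a one-element tuple.
        members = get_args(annotation) if get_origin(annotation) is Union else (annotation,)
        result.append((name, members, param.default))
    return result


for entry in init_arguments_and_types(toy_init):
    print(entry)
# ('accumulate_grad_batches', (<class 'int'>, typing.Dict[int, int]), 1)
# ('max_epochs', (<class 'int'>,), 1000)
# ('max_steps', (<class 'int'>, <class 'NoneType'>), None)
```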
static parse_argparser(arg_parser)

Parse CLI arguments, required for custom bool types.

Return type

Namespace

run_pretrain_routine(model)

Sanity check a few things before starting actual training.

Parameters

model (LightningModule) – The model to run sanity test on.

test(model=None, test_dataloaders=None)

Separates from fit to make sure you never run on your test set until you want to.

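Parameters
  • model (Optional[LightningModule]) – The model to test. If None, the model from the preceding fit() call is tested.

  • test_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying test samples.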

Example:

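# Option 1
# run test after fitting
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer

test = DataLoader(...)
trainer = Trainer()
model = LightningModule()  # in practice, an instance of your LightningModule subclass

trainer.fit(model)
trainer.test(test_dataloaders=test)

# Option 2
# run test from a loaded model
test = DataLoader(...)
# call load_from_checkpoint on your LightningModule subclass
model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
trainer = Trainer()
trainer.test(model, test_dataloaders=test)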
DEPRECATED_IN_0_8 = ('gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs', 'add_row_log_interval', 'nb_sanity_val_steps', 'tng_tqdm_dic')
DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict')
accumulate_grad_batches = None
checkpoint_callback = None
property data_parallel
Return type

bool

early_stop_callback = None
logger = None
lr_schedulers = None
model = None
property num_gpus

Warning: this is just an empty shell for code implemented in another class.

Return type

int

num_training_batches = None
on_gpu = None
on_tpu = None
optimizers = None
proc_rank = None
property progress_bar_dict

Read-only for progress bar metrics.

Return type

dict

resume_from_checkpoint = None
root_gpu = None
property slurm_job_id

Warning: this is just an empty shell for code implemented in another class.

Return type

int

use_ddp = None
use_ddp2 = None
use_horovod = None
weights_save_path = None
class pytorch_lightning.trainer.trainer._PatchDataLoader(dataloader)

Bases: object

Callable object for patching dataloaders passed into trainer.fit(). Use this class to override model.*_dataloader() and be pickle-compatible.

Parameters

dataloader (Union[List[DataLoader], DataLoader]) – Dataloader object to return when called.

__call__()

Call self as a function.

Return type

Union[List[DataLoader], DataLoader]
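The pickle-compatibility point can be seen with the standard library alone: the pickle module cannot serialize a lambda, but it can serialize an instance of a module-level class that stores the object and returns it when called. A hedged sketch with a hypothetical `PatchedLoader` class, using a plain list as a stand-in for a DataLoader:

```python
import pickle


class PatchedLoader:
    """Hypothetical analogue of _PatchDataLoader: a picklable callable returning a stored object."""

    def __init__(self, dataloader):
        self.dataloader = dataloader

    def __call__(self):
        return self.dataloader


patched = PatchedLoader([1, 2, 3])  # a list stands in for a DataLoader here
restored = pickle.loads(pickle.dumps(patched))
print(restored())  # [1, 2, 3]

try:
    pickle.dumps(lambda: [1, 2, 3])
except (pickle.PicklingError, AttributeError) as err:
    print("lambda is not picklable:", type(err).__name__)
```

Picklability matters whenever the model, together with its patched *_dataloader methods, has to be sent to other processes, for example when training processes are spawned with multiprocessing.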