pytorch_lightning.trainer.trainer module

class pytorch_lightning.trainer.trainer.Trainer(logger=True, checkpoint_callback=True, early_stop_callback=False, callbacks=None, default_root_dir=None, gradient_clip_val=0, process_position=0, num_nodes=1, num_processes=1, gpus=None, auto_select_gpus=False, tpu_cores=None, log_gpu_memory=None, progress_bar_refresh_rate=1, overfit_batches=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=1, max_epochs=1000, min_epochs=1, max_steps=None, min_steps=None, limit_train_batches=1.0, limit_val_batches=1.0, limit_test_batches=1.0, val_check_interval=1.0, log_save_interval=100, row_log_interval=50, distributed_backend=None, precision=32, print_nan_grads=False, weights_summary='top', weights_save_path=None, num_sanity_val_steps=2, truncated_bptt_steps=None, resume_from_checkpoint=None, profiler=None, benchmark=False, deterministic=False, reload_dataloaders_every_epoch=False, auto_lr_find=False, replace_sampler_ddp=True, terminate_on_nan=False, auto_scale_batch_size=False, prepare_data_per_node=True, amp_level='O2', num_tpu_cores=None, use_amp=None, show_progress_bar=None, val_percent_check=None, test_percent_check=None, train_percent_check=None, overfit_pct=None)[source]

Bases: pytorch_lightning.trainer.training_io.TrainerIOMixin, pytorch_lightning.trainer.callback_hook.TrainerCallbackHookMixin, pytorch_lightning.trainer.model_hooks.TrainerModelHooksMixin, pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin, pytorch_lightning.trainer.auto_mix_precision.TrainerAMPMixin, pytorch_lightning.trainer.distrib_parts.TrainerDPMixin, pytorch_lightning.trainer.distrib_data_parallel.TrainerDDPMixin, pytorch_lightning.trainer.logging.TrainerLoggingMixin, pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin, pytorch_lightning.trainer.data_loading.TrainerDataLoadingMixin, pytorch_lightning.trainer.evaluation_loop.TrainerEvaluationLoopMixin, pytorch_lightning.trainer.training_loop.TrainerTrainLoopMixin, pytorch_lightning.trainer.callback_config.TrainerCallbackConfigMixin, pytorch_lightning.trainer.lr_finder.TrainerLRFinderMixin, pytorch_lightning.trainer.deprecated_api.TrainerDeprecatedAPITillVer0_9, pytorch_lightning.trainer.deprecated_api.TrainerDeprecatedAPITillVer0_10

Example

>>> import torch
>>> from torch.nn import functional as F
>>> from torch.utils.data import Dataset, DataLoader
>>> from pytorch_lightning import LightningModule, Trainer
>>> # Define model
>>> class SimpleModel(LightningModule):
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
...
...     def training_step(self, batch, batch_nb):
...         x, y = batch
...         loss = F.cross_entropy(self(x), y)
...         return {'loss': loss, 'log': {'train_loss': loss}}
...
...     def test_step(self, batch, batch_nb):
...         x, y = batch
...         loss = F.cross_entropy(self(x), y)
...         return {'loss': loss, 'log': {'train_loss': loss}}
...
...     def configure_optimizers(self):
...         return torch.optim.Adam(self.parameters(), lr=0.02)
...
>>> # Define dataset
>>> class SimpleDataset(Dataset):
...     def __init__(self, num_samples=200):
...         self.input_seq = torch.randn(num_samples, 64)
...         self.output_seq = torch.randint(0, 4, (num_samples,))
...
...     def __len__(self):
...         return len(self.input_seq)
...
...     def __getitem__(self, item):
...         return self.input_seq[item], self.output_seq[item]
...
>>> train_loader = DataLoader(SimpleDataset(), batch_size=8)
>>> model = SimpleModel()
>>> # Define Trainer and fit model
>>> trainer = Trainer(max_epochs=1, progress_bar_refresh_rate=0)
>>> trainer.fit(model, train_loader)
1
>>> trainer.test(model, train_loader)  
1

Customize every aspect of training via flags; a combined usage sketch follows the parameter list below.

Parameters
  • logger (Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]) – Logger (or iterable collection of loggers) for experiment tracking.

  • checkpoint_callback (Union[ModelCheckpoint, bool]) – Callback for checkpointing.

  • early_stop_callback (pytorch_lightning.callbacks.EarlyStopping) – Callback for early stopping.

  • callbacks (Optional[List[Callback]]) – Add a list of callbacks.

  • default_root_dir (Optional[str]) – Default path for logs and weights when no logger/ckpt_callback passed

  • gradient_clip_val (float) – 0 means don’t clip.

  • gradient_clip

    Warning

    Deprecated since version 0.7.0.

    Use gradient_clip_val instead. Will be removed in 0.9.0.

  • process_position (int) – orders the progress bar when running multiple models on same machine.

  • num_nodes (int) – number of GPU nodes for distributed training.

  • nb_gpu_nodes

    Warning

    Deprecated since version 0.7.0.

    Use num_nodes instead. Will be removed in 0.9.0.

  • gpus (Union[int, str, List[int], None]) – Which GPUs to train on.

  • auto_select_gpus (bool) – If enabled and gpus is an integer, pick available gpus automatically. This is especially useful when GPUs are configured to be in “exclusive mode”, such that only one process at a time can access them.

  • tpu_cores (Union[int, str, List[int], None]) – How many TPU cores to train on (1 or 8), or which single TPU core to train on (e.g. [1])

  • num_tpu_cores (Optional[int]) – How many TPU cores to train on (1 or 8)

    Warning

    Deprecated since version 0.7.6.

    Use tpu_cores instead. Will be removed in 0.9.0.

  • log_gpu_memory (Optional[str]) – None, ‘min_max’, ‘all’. Might slow performance

  • show_progress_bar

    Warning

    Deprecated since version 0.7.2.

    Set progress_bar_refresh_rate to a positive integer to enable. Will be removed in 0.9.0.

  • progress_bar_refresh_rate (int) – How often to refresh progress bar (in steps). Value 0 disables progress bar. Ignored when a custom callback is passed to callbacks.

  • overfit_batches (Union[int, float]) – Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0

  • overfit_pct (Optional[float]) –

    Warning

    Deprecated since version 0.8.0.

    Use overfit_batches instead. Will be removed in 0.10.0.

  • track_grad_norm (Union[int, float, str]) – -1 means no tracking. Otherwise tracks that p-norm. May be set to ‘inf’ for the infinity norm.

  • check_val_every_n_epoch (int) – Check val every n train epochs.

  • fast_dev_run (bool) – Runs 1 batch of train, test and val to find any bugs (i.e. a sort of unit test).

  • accumulate_grad_batches (Union[int, Dict[int, int], List[list]]) – Accumulates grads every k batches or as set up in the dict.

  • max_epochs (int) – Stop training once this number of epochs is reached.

  • max_nb_epochs

    Warning

    Deprecated since version 0.7.0.

    Use max_epochs instead. Will be removed in 0.9.0.

  • min_epochs (int) – Force training for at least this many epochs.

  • min_nb_epochs

    Warning

    Deprecated since version 0.7.0.

    Use min_epochs instead. Will be removed in 0.9.0.

  • max_steps (Optional[int]) – Stop training after this number of steps. Disabled by default (None).

  • min_steps (Optional[int]) – Force training for at least this number of steps. Disabled by default (None).

  • limit_train_batches (Union[int, float]) – How much of the training dataset to check (float = fraction, int = number of batches)

  • limit_val_batches (Union[int, float]) – How much of the validation dataset to check (float = fraction, int = number of batches)

  • limit_test_batches (Union[int, float]) – How much of the test dataset to check (float = fraction, int = number of batches)

  • train_percent_check (Optional[float]) –

    Warning

    Deprecated since version 0.8.0.

    Use limit_train_batches instead. Will be removed in v0.10.0.

  • val_percent_check (Optional[float]) –

    Warning

    Deprecated since version 0.8.0.

    Use limit_val_batches instead. Will be removed in v0.10.0.

  • test_percent_check (Optional[float]) –

    Warning

    Deprecated since version 0.8.0.

    Use limit_test_batches instead. Will be removed in v0.10.0.

  • val_check_interval (Union[int, float]) – How often within one training epoch to check the validation set

  • log_save_interval (int) – Writes logs to disk this often

  • row_log_interval (int) – How often to add logging rows (does not write to disk)

  • add_row_log_interval

    Warning

    Deprecated since version 0.7.0.

    Use row_log_interval instead. Will be removed in 0.9.0.

  • distributed_backend (Optional[str]) – The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)

  • use_amp

    Warning

    Deprecated since version 0.7.0.

    Use precision instead. Will be removed in 0.9.0.

  • precision (int) – Full precision (32), half precision (16).

  • print_nan_grads (bool) –

    Warning

    Deprecated since version 0.7.2.

    Has no effect. When detected, NaN grads will be printed automatically. Will be removed in 0.9.0.

  • weights_summary (Optional[str]) – Prints a summary of the weights when training begins.

  • weights_save_path (Optional[str]) – Where to save weights if specified. Will override default_root_dir for checkpoints only. Use this if for whatever reason you need the checkpoints stored in a different place than the logs written in default_root_dir.

  • amp_level (str) – The optimization level to use (O1, O2, etc…).

  • num_sanity_val_steps (int) – Sanity check runs n batches of val before starting the training routine.

  • truncated_bptt_steps (Optional[int]) – Truncated backpropagation through time: performs backprop every k steps of a much longer sequence.

  • resume_from_checkpoint (Optional[str]) – To resume training from a specific checkpoint pass in the path here. This can be a URL.

  • profiler (Union[BaseProfiler, bool, None]) – To profile individual steps during training and assist in identifying bottlenecks.

  • reload_dataloaders_every_epoch (bool) – Set to True to reload dataloaders every epoch

  • auto_lr_find (Union[bool, str]) – If set to True, will initially run a learning rate finder, trying to optimize the initial learning rate for faster convergence. Sets the learning rate in self.lr or self.learning_rate in the LightningModule. To use a different key, set a string with the key name instead of True.

  • replace_sampler_ddp (bool) – Explicitly enables or disables sampler replacement. If not specified, this will be toggled automatically when DDP is used.

  • benchmark (bool) – If true enables cudnn.benchmark.

  • deterministic (bool) – If true enables cudnn.deterministic.

  • terminate_on_nan (bool) – If set to True, will terminate training (by raising a ValueError) at the end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

  • auto_scale_batch_size (Union[str, bool]) – If set to True, will initially run a batch size finder trying to find the largest batch size that fits into memory. The result will be stored in self.batch_size in the LightningModule. Additionally, can be set to either power that estimates the batch size through a power search or binsearch that estimates the batch size through a binary search.

  • prepare_data_per_node (bool) – If True, each LOCAL_RANK=0 will call prepare_data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
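
A minimal sketch combining several of the flags above (the values are illustrative only, not recommendations):

# illustrative settings; pick values appropriate for your hardware and data
trainer = Trainer(
    gradient_clip_val=0.5,                   # clip gradient norm at 0.5
    accumulate_grad_batches={5: 2, 10: 5},   # accumulate 2 batches from epoch 5, 5 from epoch 10
    limit_train_batches=0.25,                # use 25% of the training set each epoch (float = fraction)
    limit_val_batches=10,                    # run exactly 10 validation batches (int = batch count)
    precision=16,                            # half precision
    auto_lr_find=True,                       # run the learning rate finder before training
    max_epochs=20,
)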

_Trainer__attach_dataloaders(model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None)[source]
_Trainer__run_ddp_spawn(model, nprocs)[source]
_Trainer__test_given_model(model, test_dataloaders)[source]
_Trainer__test_using_best_weights(ckpt_path, test_dataloaders)[source]
_allowed_type()[source]
Return type

Union[int, str]

_arg_default()[source]
Return type

Union[int, str]

classmethod add_argparse_args(parent_parser)[source]

Extends existing argparse by default Trainer attributes.

Parameters

parent_parser (ArgumentParser) – The custom cli arguments parser, which will be extended by the Trainer default arguments.

Only arguments of the allowed types (str, float, int, bool) will extend the parent_parser.

Examples

>>> import argparse
>>> import pprint
>>> parser = argparse.ArgumentParser()
>>> parser = Trainer.add_argparse_args(parser)
>>> args = parser.parse_args([])
>>> pprint.pprint(vars(args))  
{...
 'check_val_every_n_epoch': 1,
 'checkpoint_callback': True,
 'default_root_dir': None,
 'deterministic': False,
 'distributed_backend': None,
 'early_stop_callback': False,
 ...
 'logger': True,
 'max_epochs': 1000,
 'max_steps': None,
 'min_epochs': 1,
 'min_steps': None,
 ...
 'profiler': None,
 'progress_bar_refresh_rate': 1,
 ...}
Return type

ArgumentParser

barrier(name)[source]
can_prepare_data()[source]
check_model_configuration(model)[source]

Checks that the model is configured correctly before training or testing is started.

Parameters

model (LightningModule) – The model whose configuration will be checked.
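
For example, a model that defines forward() but no training_step() would typically fail this check (an illustrative sketch; the exact exception type and message may differ):

class IncompleteModel(LightningModule):   # hypothetical example
    def forward(self, x):
        return x

trainer = Trainer()
trainer.check_model_configuration(IncompleteModel())  # expected to raise a configuration error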

classmethod default_attributes()[source]
fit(model, train_dataloader=None, val_dataloaders=None)[source]

Runs the full optimization routine.

Parameters
  • model (LightningModule) – Model to fit.

  • train_dataloader (Optional[DataLoader]) – A Pytorch DataLoader with training samples. If the model has a predefined train_dataloader method this will be skipped.

  • val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single Pytorch Dataloader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped

Example:

# Option 1,
# Define the train_dataloader() and val_dataloader() functions
# in the LightningModule.
# RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
trainer = Trainer()
model = LightningModule()
trainer.fit(model)

# Option 2
# in production cases we might want to pass different datasets to the same model
# Recommended for PRODUCTION SYSTEMS
train, val = DataLoader(...), DataLoader(...)
trainer = Trainer()
model = LightningModule()
trainer.fit(model, train_dataloader=train, val_dataloaders=val)

# Option 1 & 2 can be mixed, for example the training set can be
# defined as part of the model, and validation can then be fed to .fit(),
# as sketched below.
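# (an illustrative sketch of the mixed case, assuming the LightningModule
# defines train_dataloader() itself)
val = DataLoader(...)
trainer = Trainer()
model = LightningModule()
trainer.fit(model, val_dataloaders=val)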
classmethod from_argparse_args(args, **kwargs)[source]

Create an instance from CLI arguments.

Parameters
  • args (Union[Namespace, ArgumentParser]) – The parser or namespace to take arguments from. Only known arguments will be parsed and passed to the Trainer.

  • **kwargs – Additional keyword arguments that may override ones in the parser or namespace. These must be valid Trainer arguments.

Example

>>> parser = ArgumentParser(add_help=False)
>>> parser = Trainer.add_argparse_args(parser)
>>> parser.add_argument('--my_custom_arg', default='something')  
>>> args = Trainer.parse_argparser(parser.parse_args(""))
>>> trainer = Trainer.from_argparse_args(args, logger=False)
Return type

Trainer

classmethod get_deprecated_arg_names()[source]

Returns a list with deprecated Trainer arguments.

Return type

List

classmethod get_init_arguments_and_types()[source]

Scans the Trainer signature and returns argument names, types and default values.

Returns

(argument name, set with argument types, argument default value).

Return type

List with tuples of 3 values

Examples

>>> args = Trainer.get_init_arguments_and_types()
>>> import pprint
>>> pprint.pprint(sorted(args))  
[('accumulate_grad_batches',
  (<class 'int'>, typing.Dict[int, int], typing.List[list]),
  1),
 ...
 ('callbacks',
  (typing.List[pytorch_lightning.callbacks.base.Callback],
   <class 'NoneType'>),
   None),
 ('check_val_every_n_epoch', (<class 'int'>,), 1),
 ...
 ('max_epochs', (<class 'int'>,), 1000),
 ...
 ('precision', (<class 'int'>,), 32),
 ('prepare_data_per_node', (<class 'bool'>,), True),
 ('print_nan_grads', (<class 'bool'>,), False),
 ('process_position', (<class 'int'>,), 0),
 ('profiler',
  (<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
   <class 'bool'>,
   <class 'NoneType'>),
  None),
 ...
classmethod parse_argparser(arg_parser)[source]

Parse CLI arguments, required for custom bool types.

Return type

Namespace

run_pretrain_routine(model)[source]

Sanity check a few things before starting actual training.

Parameters

model (LightningModule) – The model to run sanity test on.

test(model=None, test_dataloaders=None, ckpt_path='best')[source]

Separates from fit to make sure you never run on your test set until you want to.

Parameters
  • model (Optional[LightningModule]) – The model to test.

  • test_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single Pytorch Dataloader or a list of them, specifying test samples.

  • ckpt_path (Optional[str]) – Either best or the path to the checkpoint you wish to test. If None, uses the weights from the last epoch to test. Defaults to best.

Example:

# Option 1
# run test with the best checkpoint from ``ModelCheckpoint`` after fitting.
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()

trainer.fit(model)
trainer.test(test_dataloaders=test)

# Option 2
# run test with the specified checkpoint after fitting
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()

trainer.fit(model)
trainer.test(test_dataloaders=test, ckpt_path='path/to/checkpoint.ckpt')

# Option 3
# run test with the weights from the end of training after fitting
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()

trainer.fit(model)
trainer.test(test_dataloaders=test, ckpt_path=None)

# Option 4
# run test from a loaded model. ``ckpt_path`` is ignored in this case.
test = DataLoader(...)
model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
trainer = Trainer()
trainer.test(model, test_dataloaders=test)
DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict', 'num_tpu_cores')[source]
accumulate_grad_batches = None[source]
checkpoint_callback = None[source]
property data_parallel[source]
Return type

bool

early_stop_callback = None[source]
global_rank = None[source]
property is_global_zero[source]

Warning

This is just an empty shell for code implemented in another class.

Return type

bool

logger = None[source]
lr_schedulers = None[source]
model = None[source]
property num_gpus[source]

Warning

This is just an empty shell for code implemented in another class.

Return type

int

num_training_batches = None[source]
on_gpu = None[source]
on_tpu = None[source]
optimizers = None[source]
property progress_bar_callback[source]
property progress_bar_dict[source]

Read-only for progress bar metrics.

Return type

dict

resume_from_checkpoint = None[source]
root_gpu = None[source]
scaler = None[source]
property slurm_job_id[source]

Warning

This is just an empty shell for code implemented in another class.

Return type

Optional[int]

use_ddp = None[source]
use_ddp2 = None[source]
use_horovod = None[source]
weights_save_path = None[source]
class pytorch_lightning.trainer.trainer._PatchDataLoader(dataloader)[source]

Bases: object

Callable object for patching dataloaders passed into trainer.fit(). Use this class to override model.*_dataloader() and be pickle-compatible.

Parameters

dataloader (Union[List[DataLoader], DataLoader]) – Dataloader object to return when called.

__call__()[source]

Call self as a function.

Return type

Union[List[DataLoader], DataLoader]
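
A minimal sketch of the pattern (ExamplePatch is a hypothetical stand-in; the real Trainer internals may differ):

import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)

class ExamplePatch:  # hypothetical stand-in illustrating the same idea
    def __init__(self, dataloader):
        self.dataloader = dataloader

    def __call__(self):
        # always return the dataloader captured at construction time
        return self.dataloader

# trainer.fit(model, train_dataloader=loader) roughly amounts to
# model.train_dataloader = _PatchDataLoader(loader), so that a later
# call to model.train_dataloader() returns the given loader.
patched = ExamplePatch(loader)
assert patched() is loader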

pytorch_lightning.trainer.trainer._determine_limit_batches(batches)[source]
Return type

Union[int, float]
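
A hedged sketch of the normalization this helper performs (the actual implementation may differ in details and error handling):

def determine_limit_batches(batches):   # illustrative only
    # floats in [0, 1] are kept as a fraction of the dataset
    if 0.0 <= batches <= 1.0:
        return batches
    # whole numbers greater than 1 are treated as an absolute batch count
    if batches > 1 and batches % 1.0 == 0:
        return int(batches)
    raise ValueError(f'limit_*_batches must be in [0, 1] or a positive int, got {batches}')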