pytorch_lightning.trainer.trainer module

class pytorch_lightning.trainer.trainer.Trainer(logger=True, checkpoint_callback=True, early_stop_callback=False, callbacks=None, default_root_dir=None, gradient_clip_val=0, process_position=0, num_nodes=1, num_processes=1, gpus=None, auto_select_gpus=False, tpu_cores=None, log_gpu_memory=None, progress_bar_refresh_rate=1, overfit_batches=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=1, max_epochs=1000, min_epochs=1, max_steps=None, min_steps=None, limit_train_batches=1.0, limit_val_batches=1.0, limit_test_batches=1.0, val_check_interval=1.0, log_save_interval=100, row_log_interval=50, distributed_backend=None, sync_batchnorm=False, precision=32, weights_summary='top', weights_save_path=None, num_sanity_val_steps=2, truncated_bptt_steps=None, resume_from_checkpoint=None, profiler=None, benchmark=False, deterministic=False, reload_dataloaders_every_epoch=False, auto_lr_find=False, replace_sampler_ddp=True, terminate_on_nan=False, auto_scale_batch_size=False, prepare_data_per_node=True, amp_backend='native', amp_level='O2', val_percent_check=None, test_percent_check=None, train_percent_check=None, overfit_pct=None)[source]

Bases: pytorch_lightning.trainer.training_io.TrainerIOMixin, pytorch_lightning.trainer.callback_hook.TrainerCallbackHookMixin, pytorch_lightning.trainer.model_hooks.TrainerModelHooksMixin, pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin, pytorch_lightning.trainer.auto_mix_precision.TrainerAMPMixin, pytorch_lightning.trainer.distrib_parts.TrainerDPMixin, pytorch_lightning.trainer.distrib_data_parallel.TrainerDDPMixin, pytorch_lightning.trainer.logging.TrainerLoggingMixin, pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin, pytorch_lightning.trainer.data_loading.TrainerDataLoadingMixin, pytorch_lightning.trainer.evaluation_loop.TrainerEvaluationLoopMixin, pytorch_lightning.trainer.training_loop.TrainerTrainLoopMixin, pytorch_lightning.trainer.callback_config.TrainerCallbackConfigMixin, pytorch_lightning.trainer.lr_finder.TrainerLRFinderMixin, pytorch_lightning.trainer.deprecated_api.TrainerDeprecatedAPITillVer0_10

Example

>>> import torch
>>> from torch.nn import functional as F
>>> from torch.utils.data import Dataset, DataLoader
>>> from pytorch_lightning import LightningModule, Trainer
>>> # Define model
>>> class SimpleModel(LightningModule):
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
...
...     def training_step(self, batch, batch_nb):
...         x, y = batch
...         loss = F.cross_entropy(self(x), y)
...         return {'loss': loss, 'log': {'train_loss': loss}}
...
...     def test_step(self, batch, batch_nb):
...         x, y = batch
...         loss = F.cross_entropy(self(x), y)
...         return {'loss': loss, 'log': {'test_loss': loss}}
...
...     def configure_optimizers(self):
...         return torch.optim.Adam(self.parameters(), lr=0.02)
...
>>> # Define dataset
>>> class SimpleDataset(Dataset):
...     def __init__(self, num_samples=200):
...         self.input_seq = torch.randn(num_samples, 64)
...         self.output_seq = torch.randint(0, 4, (num_samples,))
...
...     def __len__(self):
...         return len(self.input_seq)
...
...     def __getitem__(self, item):
...         return self.input_seq[item], self.output_seq[item]
...
>>> train_loader = DataLoader(SimpleDataset(), batch_size=8)
>>> model = SimpleModel()
>>> # Define Trainer and fit model
>>> trainer = Trainer(max_epochs=1, progress_bar_refresh_rate=0)
>>> trainer.fit(model, train_loader)
1
>>> test_outputs = trainer.test(model, train_loader, verbose=False)
>>> len(test_outputs)
25

Customize every aspect of training via flags

Parameters
  • logger (Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]) – Logger (or iterable collection of loggers) for experiment tracking.

  • checkpoint_callback (Union[ModelCheckpoint, bool]) – Callback for checkpointing.

  • early_stop_callback (pytorch_lightning.callbacks.EarlyStopping) – Callback for early stopping.

  • callbacks (Optional[List[Callback]]) – Add a list of callbacks.

  • default_root_dir (Optional[str]) – Default path for logs and weights when no logger/ckpt_callback passed. Default: os.getcwd(). Can be remote file paths such as s3://mybucket/path or ‘hdfs://path/’

  • gradient_clip_val (float) – 0 means don’t clip.

  • process_position (int) – orders the progress bar when running multiple models on the same machine.

  • num_nodes (int) – number of GPU nodes for distributed training.

  • gpus (Union[int, str, List[int], None]) – Which GPUs to train on.

  • auto_select_gpus (bool) – If enabled and gpus is an integer, pick available gpus automatically. This is especially useful when GPUs are configured to be in “exclusive mode”, such that only one process at a time can access them.

  • tpu_cores (Union[int, str, List[int], None]) – How many TPU cores to train on (1 or 8), or which single TPU core to train on (given as a list, e.g. [1]).

  • log_gpu_memory (Optional[str]) – None, ‘min_max’, ‘all’. Might slow performance

  • progress_bar_refresh_rate (int) – How often to refresh progress bar (in steps). Value 0 disables progress bar. Ignored when a custom callback is passed to callbacks.

  • overfit_batches (Union[int, float]) – Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0

  • overfit_pct (Optional[float]) –

    Warning

    Deprecated since version 0.8.0.

    Use overfit_batches instead. Will be removed in 0.10.0.

  • track_grad_norm (Union[int, float, str]) – -1 means no tracking. Otherwise tracks that p-norm. May be set to ‘inf’ for the infinity-norm.

  • check_val_every_n_epoch (int) – Check val every n train epochs.

  • fast_dev_run (bool) – Runs 1 batch of train, test and val to find any bugs (i.e. a sort of unit test).

  • accumulate_grad_batches (Union[int, Dict[int, int], List[list]]) – Accumulates grads every k batches or as set up in the dict (see the example after this list).

  • max_epochs (int) – Stop training once this number of epochs is reached.

  • min_epochs (int) – Force training for at least this many epochs.

  • max_steps (Optional[int]) – Stop training after this number of steps. Disabled by default (None).

  • min_steps (Optional[int]) – Force training for at least this number of steps. Disabled by default (None).

  • limit_train_batches (Union[int, float]) – How much of training dataset to check (floats = percent, int = num_batches)

  • limit_val_batches (Union[int, float]) – How much of validation dataset to check (floats = percent, int = num_batches)

  • limit_test_batches (Union[int, float]) – How much of test dataset to check (floats = percent, int = num_batches)

  • train_percent_check (Optional[float]) –

    Warning

    Deprecated since version 0.8.0.

    Use limit_train_batches instead. Will be removed in 0.10.0.

  • val_percent_check (Optional[float]) –

    Warning

    Deprecated since version 0.8.0.

    Use limit_val_batches instead. Will be removed in 0.10.0.

  • test_percent_check (Optional[float]) –

    Warning

    Deprecated since version 0.8.0.

    Use limit_test_batches instead. Will be removed in 0.10.0.

  • val_check_interval (Union[int, float]) – How often to check the validation set. Use float to check within a training epoch, use int to check every n steps (batches).

  • log_save_interval (int) – Writes logs to disk this often

  • row_log_interval (int) – How often to add logging rows (does not write to disk)

  • distributed_backend (Optional[str]) – The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)

  • sync_batchnorm (bool) – Synchronize batch norm layers between process groups/whole world.

  • precision (int) – Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.

  • weights_summary (Optional[str]) – Prints a summary of the weights when training begins.

  • weights_save_path (Optional[str]) – Where to save weights if specified. Will override default_root_dir for checkpoints only. Use this if for whatever reason you need the checkpoints stored in a different place than the logs written in default_root_dir. Can be remote file paths such as s3://mybucket/path or ‘hdfs://path/’. Defaults to default_root_dir.

  • amp_backend (str) – The mixed precision backend to use (“native” or “apex”)

  • amp_level (str) – The optimization level to use (O1, O2, etc…).

  • num_sanity_val_steps (int) – Sanity check runs n validation batches before starting the training routine. Set it to -1 to run all batches in all validation dataloaders. Default: 2

  • truncated_bptt_steps (Optional[int]) – Truncated back-propagation through time performs backprop every k steps of a much longer sequence.

  • resume_from_checkpoint (Optional[str]) – To resume training from a specific checkpoint pass in the path here. This can be a URL.

  • profiler (Union[BaseProfiler, bool, None]) – To profile individual steps during training and assist in identifying bottlenecks.

  • reload_dataloaders_every_epoch (bool) – Set to True to reload dataloaders every epoch.

  • auto_lr_find (Union[bool, str]) – If set to True, will initially run a learning rate finder, trying to optimize the initial learning rate for faster convergence. Sets the learning rate in self.lr or self.learning_rate in the LightningModule. To use a different key, set a string instead of True with the key name.

  • replace_sampler_ddp (bool) – Explicitly enables or disables sampler replacement. If not specified, this is toggled automatically when DDP is used. By default it will add shuffle=True for the train sampler and shuffle=False for the val/test samplers. If you want to customize it, you can set replace_sampler_ddp=False and add your own distributed sampler.

  • benchmark (bool) – If true enables cudnn.benchmark.

  • deterministic (bool) – If true enables cudnn.deterministic.

  • terminate_on_nan (bool) – If set to True, will terminate training (by raising a ValueError) at the end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

  • auto_scale_batch_size (Union[str, bool]) – If set to True, will initially run a batch size finder trying to find the largest batch size that fits into memory. The result will be stored in self.batch_size in the LightningModule. Additionally, can be set to either power that estimates the batch size through a power search or binsearch that estimates the batch size through a binary search.

  • prepare_data_per_node (bool) – If True, each LOCAL_RANK=0 will call prepare_data(). Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.

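Many of these flags compose directly in the constructor. The following sketch is illustrative only (the values are arbitrary, not recommendations) and shows a handful of the options documented above, including the dict form of accumulate_grad_batches:

from pytorch_lightning import Trainer

# Illustrative values only; every flag used here is described in the list above.
trainer = Trainer(
    max_epochs=10,
    gpus=2,                                  # train on 2 GPUs
    distributed_backend='ddp',               # DistributedDataParallel
    precision=16,                            # half precision
    gradient_clip_val=0.5,                   # clip gradients at 0.5
    accumulate_grad_batches={5: 3, 10: 20},  # accumulate 3 batches starting at epoch 5, 20 starting at epoch 10
    limit_train_batches=0.25,                # use 25% of the training set each epoch
    val_check_interval=0.5,                  # run validation twice per training epoch
    track_grad_norm=2,                       # log the 2-norm of the gradients
)
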
_Trainer__attach_dataloaders(model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None)[source]
_Trainer__attach_datamodule(model, datamodule, stage)[source]
_Trainer__test_given_model(model, test_dataloaders)[source]
_Trainer__test_using_best_weights(ckpt_path, test_dataloaders)[source]
_gpus_allowed_type()[source]
Return type

Union[int, str]

_gpus_arg_default()[source]
Return type

Union[int, str]

_int_or_float_type()[source]
Return type

Union[int, float]

_run_sanity_check(ref_model, model)[source]
classmethod add_argparse_args(parent_parser)[source]

Extends existing argparse by default Trainer attributes.

Parameters

parent_parser (ArgumentParser) – The custom cli arguments parser, which will be extended by the Trainer default arguments.

Only arguments of the allowed types (str, float, int, bool) will extend the parent_parser.

Examples

>>> import argparse
>>> import pprint
>>> parser = argparse.ArgumentParser()
>>> parser = Trainer.add_argparse_args(parser)
>>> args = parser.parse_args([])
>>> pprint.pprint(vars(args))  
{...
 'check_val_every_n_epoch': 1,
 'checkpoint_callback': True,
 'default_root_dir': None,
 'deterministic': False,
 'distributed_backend': None,
 'early_stop_callback': False,
 ...
 'logger': True,
 'max_epochs': 1000,
 'max_steps': None,
 'min_epochs': 1,
 'min_steps': None,
 ...
 'profiler': None,
 'progress_bar_refresh_rate': 1,
 ...}
Return type

ArgumentParser

barrier(name)[source]
call_setup_hook(model)[source]

Warning: this is just an empty shell for code implemented in another class.

can_prepare_data()[source]
classmethod default_attributes()[source]
fit(model, train_dataloader=None, val_dataloaders=None, datamodule=None)[source]

Runs the full optimization routine.

Parameters
  • model (LightningModule) – Model to fit.

  • train_dataloader (Optional[DataLoader]) – A Pytorch DataLoader with training samples. If the model has a predefined train_dataloader method this will be skipped.

  • val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single Pytorch Dataloader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped

Example:

# Option 1
# Define the train_dataloader() and val_dataloader() functions
# in the LightningModule.
# RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
trainer = Trainer()
model = LightningModule()
trainer.fit(model)

# Option 2
# in production cases we might want to pass different datasets to the same model
# Recommended for PRODUCTION SYSTEMS
train, val = DataLoader(...), DataLoader(...)
trainer = Trainer()
model = LightningModule()
trainer.fit(model, train_dataloader=train, val_dataloaders=val)

# Option 1 & 2 can be mixed, for example the training set can be
# defined as part of the model, and validation can then be fed to .fit()
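
fit() also accepts a datamodule argument (see the signature above). A minimal sketch, where ``MyDataModule`` is a hypothetical LightningDataModule subclass that defines the dataloaders:

# Option 3
# bundle the loaders in a LightningDataModule and pass it to .fit()
# (``MyDataModule`` is a hypothetical LightningDataModule subclass)
dm = MyDataModule()
trainer = Trainer()
model = LightningModule()
trainer.fit(model, datamodule=dm)
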
classmethod from_argparse_args(args, **kwargs)[source]

Create an instance from CLI arguments.

Parameters
  • args (Union[Namespace, ArgumentParser]) – The parser or namespace to take arguments from. Only known arguments will be parsed and passed to the Trainer.

  • **kwargs – Additional keyword arguments that may override ones in the parser or namespace. These must be valid Trainer arguments.

Example

>>> parser = ArgumentParser(add_help=False)
>>> parser = Trainer.add_argparse_args(parser)
>>> parser.add_argument('--my_custom_arg', default='something')  
>>> args = Trainer.parse_argparser(parser.parse_args(""))
>>> trainer = Trainer.from_argparse_args(args, logger=False)
Return type

Trainer
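
Putting the two classmethods together, a minimal CLI script might look like the sketch below; ``MyModel`` and its --learning_rate argument are hypothetical and only included for illustration:

from argparse import ArgumentParser
from pytorch_lightning import Trainer

def main():
    parser = ArgumentParser()
    parser.add_argument('--learning_rate', type=float, default=0.02)  # model-specific arg (illustrative)
    parser = Trainer.add_argparse_args(parser)  # extend with the default Trainer flags
    args = parser.parse_args()

    model = MyModel(learning_rate=args.learning_rate)  # hypothetical LightningModule subclass
    trainer = Trainer.from_argparse_args(args)  # only known Trainer arguments are picked up
    trainer.fit(model)

if __name__ == '__main__':
    main()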

classmethod get_deprecated_arg_names()[source]

Returns a list with deprecated Trainer arguments.

Return type

List

classmethod get_init_arguments_and_types()[source]

Scans the Trainer signature and returns argument names, types and default values.

Returns

(argument name, set with argument types, argument default value).

Return type

List with tuples of 3 values

Examples

>>> args = Trainer.get_init_arguments_and_types()
>>> import pprint
>>> pprint.pprint(sorted(args))  
[('accumulate_grad_batches',
  (<class 'int'>, typing.Dict[int, int], typing.List[list]),
  1),
 ...
 ('callbacks',
  (typing.List[pytorch_lightning.callbacks.base.Callback],
   <class 'NoneType'>),
   None),
 ('check_val_every_n_epoch', (<class 'int'>,), 1),
 ...
 ('max_epochs', (<class 'int'>,), 1000),
 ...
 ('precision', (<class 'int'>,), 32),
 ('prepare_data_per_node', (<class 'bool'>,), True),
 ('process_position', (<class 'int'>,), 0),
 ('profiler',
  (<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
   <class 'bool'>,
   <class 'NoneType'>),
  None),
 ...
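
A short sketch of how these introspection helpers can be combined, for example to collect the current, non-deprecated defaults into a plain dict (only get_deprecated_arg_names() and get_init_arguments_and_types(), both documented on this page, are used):

deprecated = set(Trainer.get_deprecated_arg_names())
defaults = {
    name: default
    for name, types, default in Trainer.get_init_arguments_and_types()
    if name not in deprecated
}
print(defaults['max_epochs'])  # 1000
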
init_amp(amp_type)[source]
classmethod parse_argparser(arg_parser)[source]

Parse CLI arguments, required for custom bool types.

Return type

Namespace

run_pretrain_routine(model)[source]

Sanity check a few things before starting actual training.

Parameters

model (LightningModule) – The model to run sanity test on.

test(model=None, test_dataloaders=None, ckpt_path='best', verbose=True, datamodule=None)[source]

Separates from fit to make sure you never run on your test set until you want to.

Parameters
  • model (Optional[LightningModule]) – The model to test.

  • test_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single Pytorch Dataloader or a list of them, specifying test samples.

  • ckpt_path (Optional[str]) – Either best or the path to the checkpoint you wish to test. If None, uses the weights from the last epoch to test. Defaults to best.

  • verbose (bool) – If True, prints the test results

Returns

The final test result dictionary. If no test_epoch_end is defined, returns a list of dictionaries.

Example:

# Option 1
# run test with the best checkpoint from ``ModelCheckpoint`` after fitting.
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()

trainer.fit(model)
trainer.test(test_dataloaders=test)

# Option 2
# run test with the specified checkpoint after fitting
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()

trainer.fit(model)
trainer.test(test_dataloaders=test, ckpt_path='path/to/checkpoint.ckpt')

# Option 3
# run test with the weights from the end of training after fitting
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()

trainer.fit(model)
trainer.test(test_dataloaders=test, ckpt_path=None)

# Option 4
# run test from a loaded model. ``ckpt_path`` is ignored in this case.
test = DataLoader(...)
model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
trainer = Trainer()
trainer.test(model, test_dataloaders=test)
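
The return value described under "Returns" can be consumed directly. A short sketch; the exact keys depend on what your test_step / test_epoch_end log:

# ``results`` is the final test result dictionary, or a list of dictionaries
# if no ``test_epoch_end`` is defined (see "Returns" above)
results = trainer.test(model, test_dataloaders=test, verbose=False)
print(results)
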
accumulate_grad_batches = None[source]
amp_backend = None[source]
checkpoint_callback = None[source]
property data_parallel[source]
Return type

bool

property default_root_dir[source]

The default location to save artifacts of loggers, checkpoints etc. It is used as a fallback if logger or checkpoint callback do not define specific save paths.

Return type

str

property disable_validation[source]

Check if validation is disabled during training.

Return type

bool

early_stop_callback = None[source]
property enable_validation[source]

Check if we should run validation during training.

Return type

bool

global_rank = None[source]
property is_global_zero[source]

Warning: this is just an empty shell for code implemented in another class.

Return type

bool

logger = None[source]
lr_schedulers = None[source]
model = None[source]
property num_gpus[source]

Warning: this is just an empty shell for code implemented in another class.

Return type

int

num_training_batches = None[source]
on_gpu = None[source]
on_tpu = None[source]
optimizers = None[source]
property progress_bar_callback[source]
property progress_bar_dict[source]

Read-only for progress bar metrics.

Return type

dict

resume_from_checkpoint = None[source]
root_gpu = None[source]
scaler = None[source]
property slurm_job_id[source]

Warning: this is just an empty shell for code implemented in another class.

Return type

Optional[int]

use_ddp = None[source]
use_ddp2 = None[source]
use_horovod = None[source]
use_tpu = None[source]
property weights_save_path[source]

The default root location to save weights (checkpoints), e.g., when the ModelCheckpoint does not define a file path.

Return type

str

class pytorch_lightning.trainer.trainer._PatchDataLoader(dataloader)[source]

Bases: object

Callable object for patching dataloaders passed into trainer.fit(). Use this class to override model.*_dataloader() and be pickle-compatible.

Parameters

dataloader (Union[List[DataLoader], DataLoader]) – Dataloader object to return when called.

__call__()[source]

Call self as a function.

Return type

Union[List[DataLoader], DataLoader]
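
For context, the pattern _PatchDataLoader implements is a picklable, zero-argument callable that stands in for the model's *_dataloader() methods. The sketch below illustrates the idea and is not the library's implementation:

from torch.utils.data import DataLoader

class PatchedLoaderSketch:
    """Minimal sketch of the pattern: wrap a dataloader in a picklable,
    zero-argument callable so it can replace model.train_dataloader,
    model.val_dataloader or model.test_dataloader."""

    def __init__(self, dataloader):
        self.dataloader = dataloader

    def __call__(self):
        # Mimic the zero-argument *_dataloader() hooks on a LightningModule.
        return self.dataloader

# Hypothetical usage, roughly what Trainer.fit() does with the loaders you pass in:
# model.train_dataloader = PatchedLoaderSketch(train_loader)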

pytorch_lightning.trainer.trainer._determine_batch_limits(batches, name)[source]
Return type

Union[int, float]