The trainer decouples the engineering code (16-bit, early stopping, GPU distribution, etc…) from the science code (GAN, BERT, your project, etc…). It bakes in many assumptions that are best practices in AI research today.

The trainer automates all parts of training except:

  • what happens in the training, test, and val loops
  • where the data come from
  • which optimizers to use
  • how to do the computations

The Trainer delegates those calls to your LightningModule which defines how to do those parts.
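To make the split above concrete, here is a toy, framework-free sketch of a trainer that owns the loop while a model object supplies the data, the step logic, and the optimizer. It is not Lightning's implementation: the classes `ToyModule` and `ToyTrainer` are hypothetical, and although the method names loosely echo LightningModule hooks, their signatures are simplified and do not match Lightning's API.

```python
class ToyModule:
    """Science code (hypothetical): data, step logic, and optimizer choice."""

    def __init__(self):
        self.w = 0.0  # single weight for the model y = w * x

    def train_dataloader(self):
        # where the data come from: (x, y) pairs following y = 2x
        return [(x, 2.0 * x) for x in [1.0, 2.0, 3.0, 4.0]]

    def training_step(self, batch):
        # what happens in the training loop: squared error and its gradient
        x, y = batch
        err = self.w * x - y
        return {"loss": err * err, "grad": 2.0 * err * x}

    def configure_optimizers(self):
        # which optimizer to use: plain SGD with a fixed learning rate
        lr = 0.01

        def sgd_step(grad):
            self.w -= lr * grad

        return sgd_step


class ToyTrainer:
    """Engineering code (hypothetical): owns the loop and epoch bookkeeping."""

    def __init__(self, max_epochs=100):
        self.max_epochs = max_epochs

    def fit(self, model):
        step = model.configure_optimizers()
        for _ in range(self.max_epochs):
            for batch in model.train_dataloader():
                out = model.training_step(batch)
                step(out["grad"])
        return model


model = ToyTrainer(max_epochs=200).fit(ToyModule())
print(model.w)  # converges toward 2.0
```

Changing the loss or the optimizer means editing only `ToyModule`; the loop in `ToyTrainer` stays untouched, which is the separation the Trainer is designed around.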

This is the basic use of the trainer:

from pytorch_lightning import Trainer

model = MyLightningModule()

trainer = Trainer()
trainer.fit(model)

class pytorch_lightning.trainer.Trainer(logger=True, checkpoint_callback=True, early_stop_callback=True, default_save_path=None, gradient_clip_val=0, gradient_clip=None, process_position=0, nb_gpu_nodes=None, num_nodes=1, gpus=None, log_gpu_memory=None, show_progress_bar=True, overfit_pct=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=1, max_nb_epochs=None, min_nb_epochs=None, max_epochs=1000, min_epochs=1, train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0, val_check_interval=1.0, log_save_interval=100, row_log_interval=10, add_row_log_interval=None, distributed_backend=None, use_amp=False, print_nan_grads=False, weights_summary='full', weights_save_path=None, amp_level='O1', nb_sanity_val_steps=None, num_sanity_val_steps=5, truncated_bptt_steps=None, resume_from_checkpoint=None)

Bases: pytorch_lightning.trainer.training_io.TrainerIOMixin, pytorch_lightning.trainer.distrib_parts.TrainerDPMixin, pytorch_lightning.trainer.distrib_data_parallel.TrainerDDPMixin, pytorch_lightning.trainer.logging.TrainerLoggingMixin, pytorch_lightning.trainer.model_hooks.TrainerModelHooksMixin, pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin, pytorch_lightning.trainer.data_loading.TrainerDataLoadingMixin, pytorch_lightning.trainer.auto_mix_precision.TrainerAMPMixin, pytorch_lightning.trainer.evaluation_loop.TrainerEvaluationLoopMixin, pytorch_lightning.trainer.training_loop.TrainerTrainLoopMixin, pytorch_lightning.trainer.callback_config.TrainerCallbackConfigMixin

Customize every aspect of training via flags:

  • logger (Logger) –

    Logger for experiment tracking. Example:

    import os
    from pytorch_lightning.logging import TensorBoardLogger
    # default logger used by trainer
    logger = TensorBoardLogger(save_dir=os.getcwd(), name='lightning_logs')
    trainer = Trainer(logger=logger)
  • checkpoint_callback (CheckpointCallback) –

    Callback for checkpointing. Example:

    import os
    from pytorch_lightning.callbacks import ModelCheckpoint
    # default used by the Trainer
    checkpoint_callback = ModelCheckpoint(
        filepath=os.getcwd(),
        verbose=True,
        monitor='val_loss',
        mode='min'
    )
    trainer = Trainer(checkpoint_callback=checkpoint_callback)
  • early_stop_callback (EarlyStopping) –

    Callback for early stopping. Example:

    from pytorch_lightning.callbacks import EarlyStopping
    # default used by the Trainer
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=3,
        mode='min'
    )
    trainer = Trainer(early_stop_callback=early_stop_callback)
  • default_save_path (str) –

    Default path for logs and weights when no logger or checkpoint callback is passed. Example:

    # default used by the Trainer
    trainer = Trainer(default_save_path=os.getcwd())
  • gradient_clip_val (float) –

    Gradient clipping value; 0 means don’t clip. Example:

    # default used by the Trainer
    trainer = Trainer(gradient_clip_val=0.0)
  • gradient_clip (int) –

    Deprecated since version 0.5.0: Use gradient_clip_val instead. Will be removed in 0.8.0.

  • process_position (int) –

    Orders the tqdm progress bar when running multiple models on the same machine. Example:

    # default used by the Trainer
    trainer = Trainer(process_position=0)
  • num_nodes (int) –

    Number of GPU nodes for distributed training. Example:

    # default used by the Trainer
    trainer = Trainer(num_nodes=1)
    # to train on 8 nodes
    trainer = Trainer(num_nodes=8)
  • nb_gpu_nodes (int) –

    Deprecated since version 0.5.0: Use num_nodes instead. Will be removed in 0.8.0.

  • gpus (list|str|int) –

    Which GPUs to train on. Example:

    # default used by the Trainer (i.e. train on CPU)
    trainer = Trainer(gpus=None)
    # int: train on 2 gpus
    trainer = Trainer(gpus=2)
    # list: train on GPUs 1, 4 (by bus ordering)
    trainer = Trainer(gpus=[1, 4])
    trainer = Trainer(gpus='1, 4') # equivalent
    # -1: train on all gpus
    trainer = Trainer(gpus=-1)
    trainer = Trainer(gpus='-1') # equivalent
    # combine with num_nodes to train on multiple GPUs across nodes
    trainer = Trainer(gpus=2, num_nodes=4) # uses 8 gpus in total
  • log_gpu_memory (str) –

    One of None, ‘min_max’, ‘all’. Logging may slow training because it reads the output of nvidia-smi. Example:

    # default used by the Trainer
    trainer = Trainer(log_gpu_memory=None)
    # log all the GPUs (on master node only)
    trainer = Trainer(log_gpu_memory='all')
    # log only the min and max memory on the master node
    trainer = Trainer(log_gpu_memory='min_max')
  • show_progress_bar (bool) –

    If True, shows a tqdm progress bar. Example:

    # default used by the Trainer
    trainer = Trainer(show_progress_bar=True)
  • overfit_pct (float) –

    Uses this fraction of the data from all datasets (train, val and test). Example:

    # default used by the Trainer
    trainer = Trainer(overfit_pct=0.0)
    # use only 1% of the train, test, val datasets
    trainer = Trainer(overfit_pct=0.01)
  • track_grad_norm (int) –

    -1 means no tracking; otherwise, tracks the gradient norm of that order. Example:

    # default used by the Trainer
    trainer = Trainer(track_grad_norm=-1)
    # track the 2-norm
    trainer = Trainer(track_grad_norm=2)
  • check_val_every_n_epoch (int) –

    Run the validation loop every n training epochs. Example:

    # default used by the Trainer
    trainer = Trainer(check_val_every_n_epoch=1)
    # run val loop every 10 training epochs
    trainer = Trainer(check_val_every_n_epoch=10)
  • fast_dev_run (bool) –

    Runs 1 batch of train, val, and test to find any bugs (i.e. a sort of unit test). Example:

    # default used by the Trainer
    trainer = Trainer(fast_dev_run=False)
    # runs 1 train, val, test batch and program ends
    trainer = Trainer(fast_dev_run=True)
  • accumulate_grad_batches (int|dict) –

    Accumulates grads every k batches or as set up in the dict. Example:

    # default used by the Trainer (no accumulation)
    trainer = Trainer(accumulate_grad_batches=1)
    # accumulate every 4 batches (effective batch size is batch*4)
    trainer = Trainer(accumulate_grad_batches=4)
    # no accumulation for epochs 1-4. accumulate 3 for epochs 5-9. accumulate 20 from epoch 10 on
    trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})
  • max_epochs (int) –

    Stop training once this number of epochs is reached. Example:

    # default used by the Trainer
    trainer = Trainer(max_epochs=1000)
  • max_nb_epochs (int) –

    Deprecated since version 0.5.0: Use max_epochs instead. Will be removed in 0.8.0.

  • min_epochs (int) –

    Force training for at least this many epochs. Example:

    # default used by the Trainer
    trainer = Trainer(min_epochs=1)
  • min_nb_epochs (int) –

    Deprecated since version 0.5.0: Use min_epochs instead. Will be removed in 0.8.0.

  • train_percent_check (float) –

    Fraction of the training dataset to check. Useful when debugging or testing something that happens at the end of an epoch. Example:

    # default used by the Trainer
    trainer = Trainer(train_percent_check=1.0)
    # run through only 25% of the training set each epoch
    trainer = Trainer(train_percent_check=0.25)
  • val_percent_check (float) –

    Fraction of the validation dataset to check. Useful when debugging or testing something that happens at the end of an epoch. Example:

    # default used by the Trainer
    trainer = Trainer(val_percent_check=1.0)
    # run through only 25% of the validation set each epoch
    trainer = Trainer(val_percent_check=0.25)
  • test_percent_check (float) –

    Fraction of the test dataset to check. Useful when debugging or testing something that happens at the end of an epoch. Example:

    # default used by the Trainer
    trainer = Trainer(test_percent_check=1.0)
    # run through only 25% of the test set
    trainer = Trainer(test_percent_check=0.25)
  • val_check_interval (float|int) –

    How often within one training epoch to check the validation set. If float, the fraction of the training epoch; if int, check every n training batches. Example:

    # default used by the Trainer
    trainer = Trainer(val_check_interval=1.0)
    # check validation set 4 times during a training epoch
    trainer = Trainer(val_check_interval=0.25)
    # check validation set every 1000 training batches
    # use this when using an IterableDataset and your dataset has no length
    # (i.e. production cases with streaming data)
    trainer = Trainer(val_check_interval=1000)
  • log_save_interval (int) –

    Writes logs to disk this often. Example:

    # default used by the Trainer
    trainer = Trainer(log_save_interval=100)
  • row_log_interval (int) –

    How often to add logging rows (does not write to disk). Example:

    # default used by the Trainer
    trainer = Trainer(row_log_interval=10)
  • add_row_log_interval (int) –

    Deprecated since version 0.5.0: Use row_log_interval instead. Will be removed in 0.8.0.

  • distributed_backend (str) –

    The distributed backend to use. Options: ‘dp’, ‘ddp’, ‘ddp2’. Example:

    # default used by the Trainer
    trainer = Trainer(distributed_backend=None)
    # dp = DataParallel (split a batch onto k gpus on same machine).
    trainer = Trainer(gpus=2, distributed_backend='dp')
    # ddp = DistributedDataParallel
    # Each gpu trains by itself on a subset of the data.
    # Gradients sync across all gpus and all machines.
    trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')
    # ddp2 = DistributedDataParallel + dp
    # behaves like dp on every node
    # syncs gradients across nodes like ddp
    # useful for things like increasing the number of negative samples
    trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')
  • use_amp (bool) –

    If True, uses NVIDIA Apex for 16-bit precision. Example:

    # default used by the Trainer
    trainer = Trainer(use_amp=False)
  • print_nan_grads (bool) –

    Prints gradients with NaN values. Example:

    # default used by the Trainer
    trainer = Trainer(print_nan_grads=False)
  • weights_summary (str) –

    Prints a summary of the weights when training begins. Options: ‘full’, ‘top’, None. Example:

    # default used by the Trainer (i.e. print all weights)
    trainer = Trainer(weights_summary='full')
    # print only the top level modules
    trainer = Trainer(weights_summary='top')
    # don't print a summary
    trainer = Trainer(weights_summary=None)
  • weights_save_path (str) –

    Where to save weights if specified. Example:

    # default used by the Trainer
    trainer = Trainer(weights_save_path=os.getcwd())
    # save to your custom path
    trainer = Trainer(weights_save_path='my/path')
    # if a checkpoint callback is used, it overrides the weights path
    # **NOTE: this saves weights to some/path NOT my/path
    checkpoint_callback = ModelCheckpoint(filepath='some/path')
    trainer = Trainer(
        checkpoint_callback=checkpoint_callback,
        weights_save_path='my/path'
    )
  • amp_level (str) –

    The optimization level to use (O1, O2, etc…). Check the NVIDIA Apex docs for the available levels. Example:

    # default used by the Trainer
    trainer = Trainer(amp_level='O1')
  • num_sanity_val_steps (int) –

    Sanity check runs n batches of val before starting the training routine. This catches any bugs in your validation without having to wait for the first validation check. The Trainer uses 5 steps by default. Turn it off or modify it here. Example:

    # default used by the Trainer
    trainer = Trainer(num_sanity_val_steps=5)
    # turn it off
    trainer = Trainer(num_sanity_val_steps=0)
  • nb_sanity_val_steps (int) –

    Deprecated since version 0.5.0: Use num_sanity_val_steps instead. Will be removed in 0.8.0.

  • truncated_bptt_steps (int) –

    Truncated backprop performs backprop every k steps of a much longer sequence. If this is enabled, your batches will automatically be truncated and the trainer will apply truncated backprop to them. Make sure your batches have a sequence dimension. (Williams et al. “An efficient gradient-based algorithm for on-line training of recurrent network trajectories.”) Example:

    # default used by the Trainer (i.e. disabled)
    trainer = Trainer(truncated_bptt_steps=None)
    # backprop every 5 steps in a batch
    trainer = Trainer(truncated_bptt_steps=5)
  • resume_from_checkpoint (str) –

    To resume training from a specific checkpoint, pass in the path here. Example:

    # default used by the Trainer
    trainer = Trainer(resume_from_checkpoint=None)
    # resume from a specific checkpoint
    trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')
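The early_stop_callback flag relies on patience-based early stopping. The sketch below shows that general idea on a plain list of monitored values; it is a simplified illustration, not EarlyStopping's code, and the helper name `early_stop_epoch` is hypothetical. Its `patience` and `mode` arguments mirror the EarlyStopping arguments shown above.

```python
def early_stop_epoch(history, patience=3, mode="min", min_delta=0.0):
    """Return the epoch index at which training would stop, or None.

    Hypothetical helper: stops once the monitored value has failed to
    improve for `patience` consecutive checks.
    """
    best = None
    wait = 0
    for epoch, value in enumerate(history):
        if best is None:
            improved = True
        elif mode == "min":
            improved = value < best - min_delta
        else:
            improved = value > best + min_delta
        if improved:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# val_loss stops improving after epoch 2, so training halts 3 checks later
print(early_stop_epoch([1.0, 0.8, 0.7, 0.72, 0.71, 0.75], patience=3))
```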
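For gradient_clip_val, a common approach is to rescale all gradients so that their combined L2 norm does not exceed the threshold, as `torch.nn.utils.clip_grad_norm_` does. Assuming the Trainer clips in that style, the idea can be sketched in plain Python; `clip_by_global_norm` is a hypothetical helper operating on lists of floats rather than tensors.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale gradient vectors so their combined L2 norm is at most max_norm.

    Hypothetical helper; a max_norm of 0 disables clipping, matching the
    meaning of gradient_clip_val=0.
    """
    if max_norm <= 0:
        return [list(g) for g in grads]
    total = math.sqrt(sum(x * x for g in grads for x in g))
    if total <= max_norm:
        return [list(g) for g in grads]
    scale = max_norm / total
    return [[x * scale for x in g] for g in grads]


# combined norm of [3] and [4] is 5, so both are scaled by 1/5
print(clip_by_global_norm([[3.0], [4.0]], 1.0))
```

Scaling every gradient by the same factor preserves the update direction and only shortens the step, which is why global-norm clipping is usually preferred over clipping each value independently.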
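The gpus flag accepts None, an int, a list, or a string. The hypothetical helper below sketches one way those forms could map to device indices, following the examples in the gpus entry; it is not Lightning's parser. Treating a comma-free string such as '2' as an index list is an assumption of this sketch, and `num_available` stands in for the number of visible GPUs.

```python
def resolve_gpus(gpus, num_available):
    """Map a `gpus` value to a list of device indices (hypothetical helper)."""
    if gpus is None:
        return []  # train on CPU
    if isinstance(gpus, str):
        gpus = gpus.strip()
        if gpus == "-1":
            gpus = -1
        else:
            # assumption: '1, 4' -> [1, 4]; a single '2' -> [2]
            gpus = [int(part) for part in gpus.split(",") if part.strip()]
    if isinstance(gpus, int):
        if gpus == -1:
            return list(range(num_available))  # all visible GPUs
        if gpus < 0:
            raise ValueError("gpus must be -1 or a non-negative count")
        return list(range(gpus))  # first n GPUs
    return list(gpus)  # explicit indices


def total_gpus(gpus, num_nodes, num_available):
    # every node contributes the same set of devices
    return len(resolve_gpus(gpus, num_available)) * num_nodes


print(resolve_gpus('1, 4', num_available=8))
print(total_gpus(2, num_nodes=4, num_available=8))
```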
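The dict form of accumulate_grad_batches maps a starting epoch to an accumulation factor. Assuming epochs are counted from 1 and that epochs before the smallest key use a factor of 1, the schedule resolves as in this sketch; `accumulation_for_epoch` is a hypothetical helper.

```python
def accumulation_for_epoch(accumulate_grad_batches, epoch):
    """Return the accumulation factor for a 1-indexed epoch (hypothetical)."""
    if isinstance(accumulate_grad_batches, int):
        return accumulate_grad_batches  # same factor for every epoch
    factor = 1  # no accumulation before the first scheduled epoch
    for start in sorted(accumulate_grad_batches):
        if epoch >= start:
            factor = accumulate_grad_batches[start]
    return factor


schedule = {5: 3, 10: 20}
for epoch in (1, 4, 5, 9, 10, 30):
    print(epoch, accumulation_for_epoch(schedule, epoch))
```

With a dataloader batch size of 32 and a factor of 3, each optimizer step sees gradients from 3 × 32 = 96 samples, which is the "effective batch size" mentioned in the example.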
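The train_percent_check, val_percent_check, and test_percent_check fractions, together with overfit_pct, ultimately decide how many batches of each dataloader are used. The sketch below assumes a nonzero overfit_pct replaces the per-split fraction (consistent with its description as applying to all datasets) and that the batch count is truncated toward zero; both are assumptions, and `batches_to_run` is a hypothetical helper.

```python
def batches_to_run(num_batches, percent_check, overfit_pct=0.0):
    """Number of batches used for one split (hypothetical helper).

    Assumes a nonzero overfit_pct overrides percent_check and that the
    result is truncated toward zero.
    """
    fraction = overfit_pct if overfit_pct > 0 else percent_check
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fractions must lie in [0, 1]")
    return int(num_batches * fraction)


print(batches_to_run(1000, 0.25))                    # 25% of the split
print(batches_to_run(1000, 1.0, overfit_pct=0.01))   # overfit_pct wins
```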
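The float and int forms of val_check_interval differ in what they need to know. A float is a fraction of the epoch, so it requires the epoch length; an int only needs a running batch counter, which is why it suits streaming data with no length. The hypothetical helper below lists the 1-indexed training batches after which validation would run; the epoch length is passed in only so the output stays finite.

```python
def val_check_batches(num_training_batches, val_check_interval):
    """Training batch numbers after which validation runs (hypothetical)."""
    if isinstance(val_check_interval, float):
        if not 0.0 < val_check_interval <= 1.0:
            raise ValueError("a float val_check_interval must lie in (0, 1]")
        # a fraction of the epoch needs the epoch length
        every = max(1, int(num_training_batches * val_check_interval))
    else:
        # a fixed batch count only needs a running counter
        every = val_check_interval
        if every <= 0:
            raise ValueError("an int val_check_interval must be positive")
    return [b for b in range(1, num_training_batches + 1) if b % every == 0]


print(val_check_batches(100, 0.25))   # four checks per epoch
print(val_check_batches(3000, 1000))  # every 1000 training batches
```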
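The comments under distributed_backend describe how 'dp', 'ddp', and 'ddp2' spread work. The sketch below turns that description into arithmetic: how many processes run, how many samples each GPU sees per step, and the global batch per optimizer step. It is a simplified mental model that assumes the batch divides evenly across GPUs and that 'dp' runs on a single machine; `plan_processes` is a hypothetical helper.

```python
def plan_processes(backend, gpus_per_node, num_nodes, batch_size):
    """Return (processes, samples_per_gpu, global_batch) per step (hypothetical)."""
    if backend == "dp":
        # one process on one machine splits each batch across its GPUs
        return 1, batch_size // gpus_per_node, batch_size
    if backend == "ddp":
        # one process per GPU on every node, each with its own full batch
        world = gpus_per_node * num_nodes
        return world, batch_size, batch_size * world
    if backend == "ddp2":
        # one process per node; each node splits its batch like dp
        return num_nodes, batch_size // gpus_per_node, batch_size * num_nodes
    raise ValueError(f"unknown backend: {backend!r}")


for backend in ("dp", "ddp", "ddp2"):
    nodes = 1 if backend == "dp" else 2
    print(backend, plan_processes(backend, 2, nodes, 32))
```

Under this model, 'ddp2' keeps a whole node's samples in one process, which is what makes it useful when a computation (such as drawing negative samples) benefits from seeing the full per-node batch at once.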
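With truncated_bptt_steps set, each batch is cut along its sequence dimension into windows of k steps, and backprop runs on one window at a time. The sketch below shows only the splitting step, on plain lists, assuming a batch-first layout with time as the second dimension; `split_for_tbptt` is a hypothetical helper, not the trainer's own splitting code.

```python
def split_for_tbptt(batch, k):
    """Split each sequence into consecutive windows of at most k steps.

    Hypothetical helper: `batch` is a list of equal-length sequences
    (batch dimension first, time second). Returns one sub-batch per window.
    """
    if k <= 0:
        raise ValueError("k must be positive")
    seq_len = len(batch[0])
    return [[seq[t:t + k] for seq in batch] for t in range(0, seq_len, k)]


batch = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
for window in split_for_tbptt(batch, 2):
    print(window)
```

In truncated backprop through time, a recurrent model typically carries its hidden state from one window to the next but detaches it, so gradients flow back at most k steps; that bounds memory use on very long sequences.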

fit(model)

Runs the full optimization routine.

trainer = Trainer()
model = MyLightningModule()
trainer.fit(model)

test(model=None)

Kept separate from fit so that you never run on your test set until you explicitly want to.

Parameters: model (LightningModule) – The model to test.

# Option 1
# run test after fitting
trainer = Trainer()
model = MyLightningModule()
trainer.fit(model)
trainer.test()

# Option 2
# run test from a loaded model
model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
trainer = Trainer()
trainer.test(model)
Read the Docs v: 0.6.0