
pytorch_lightning.trainer.evaluation_loop module

Validation loop

The Lightning validation loop handles everything except the actual computations of your model. To decide what happens in your validation loop, define the validation_step function. Below are all the things Lightning automates for you in the validation loop.

Note

Lightning will run 5 steps of validation at the beginning of training as a sanity check, so you don’t have to wait until a full epoch to catch possible validation issues.

Check validation every n epochs

If you have a small dataset, you might want to check validation only every n epochs.

# DEFAULT
trainer = Trainer(check_val_every_n_epoch=1)

Set how much of the validation set to check

If you don’t want to check 100% of the validation set (for debugging or if it’s huge), set this flag.

limit_val_batches will be overwritten by overfit_batches if overfit_batches > 0

# DEFAULT
trainer = Trainer(limit_val_batches=1.0)

# check 10% only
trainer = Trainer(limit_val_batches=0.1)
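
As a rough sketch of what a fractional limit means (assuming a simple proportional cutoff; Lightning’s exact rounding may differ):

```python
def batches_to_check(total_batches: int, limit_val_batches: float) -> int:
    """Approximate number of validation batches run for a fractional limit."""
    return int(total_batches * limit_val_batches)
```

With 200 validation batches, limit_val_batches=0.1 checks roughly 20 of them per validation run.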

Set how much of the test set to check

If you don’t want to check 100% of the test set (for debugging or if it’s huge), set this flag.

limit_test_batches will be overwritten by overfit_batches if overfit_batches > 0

# DEFAULT
trainer = Trainer(limit_test_batches=1.0)

# check 10% only
trainer = Trainer(limit_test_batches=0.1)

Set validation check frequency within 1 training epoch

For large datasets it’s often desirable to check validation multiple times within a single training epoch.

Pass in a float to check that often within one training epoch. Pass in an int k to check every k training batches. You must use an int when using an IterableDataset.

# DEFAULT
trainer = Trainer(val_check_interval=0.95)

# check every .25 of an epoch
trainer = Trainer(val_check_interval=0.25)

# check every 100 train batches (ie: for IterableDatasets or fixed frequency)
trainer = Trainer(val_check_interval=100)
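
A small sketch of how the two argument types can be resolved into a single “validate every k training batches” value (a hypothetical helper, not Lightning’s internal code):

```python
from typing import Union

def resolve_val_check_batch(val_check_interval: Union[int, float],
                            num_training_batches: int) -> int:
    """Turn val_check_interval into a concrete batch interval.

    An int means "every k training batches"; a float means a fraction
    of one training epoch.
    """
    if isinstance(val_check_interval, int):
        return val_check_interval
    # fraction of the epoch, floored to a whole batch count, at least 1
    return max(1, int(num_training_batches * val_check_interval))
```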

Set the number of validation sanity steps

Lightning runs a few steps of validation at the beginning of training.

This avoids crashing in the validation loop somewhere deep into a lengthy training run.

# DEFAULT
trainer = Trainer(num_sanity_val_steps=5)

You can use Trainer(num_sanity_val_steps=0) to skip the sanity check.

Testing loop

To ensure you don’t accidentally use test data to guide training decisions, Lightning makes running the test set deliberate.

test

You have two options for running the test set. The first is to test right after a full training routine:

# run full training
trainer.fit(model)

# run test set
trainer.test()

The second is to load a model from a checkpoint and then run the test set:

model = MyLightningModule.load_from_checkpoint(
    checkpoint_path='/path/to/pytorch_checkpoint.ckpt',
    hparams_file='/path/to/test_tube/experiment/version/hparams.yaml',
    map_location=None
)

# init trainer with whatever options
trainer = Trainer(...)

# test (pass in the model)
trainer.test(model)

In this second case, the options you pass to the trainer are used when running the test set (e.g. 16-bit precision, dp, ddp, etc.).
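
The rule behind the two entry points can be sketched in plain Python (a hypothetical stand-in, not Lightning’s actual Trainer):

```python
class TrainerSketch:
    """Illustrates the trainer.test() vs trainer.test(model) dispatch."""

    def __init__(self):
        self._fitted_model = None

    def fit(self, model):
        # remember the model trained in this session
        self._fitted_model = model

    def test(self, model=None):
        # an explicitly passed model wins; otherwise fall back to fit()'s model
        target = model if model is not None else self._fitted_model
        if target is None:
            raise RuntimeError("No model to test: call fit() first or pass one in.")
        return target
```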

class pytorch_lightning.trainer.evaluation_loop.TrainerEvaluationLoopMixin[source]

Bases: abc.ABC

_evaluate(model, dataloaders, max_batches, test_mode=False)[source]

Run evaluation code.

Parameters
  • model (LightningModule) – The model to evaluate.

  • dataloaders (List[DataLoader]) – A list of PyTorch dataloaders.

  • max_batches (Union[int, List[int]]) – An integer or list of integers with length of the number of dataloaders. Each entry is the number of batches to process in the corresponding dataloader.

  • test_mode (bool) – If True, run the test loop; otherwise run the validation loop. Default: False.
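
Since max_batches may be a single int or one entry per dataloader, a normalization step like the following is implied (illustrative only; the helper name is invented):

```python
from typing import List, Union

def normalize_max_batches(max_batches: Union[int, List[int]],
                          num_dataloaders: int) -> List[int]:
    """Expand a scalar batch cap into one cap per dataloader."""
    if isinstance(max_batches, int):
        return [max_batches] * num_dataloaders
    if len(max_batches) != num_dataloaders:
        raise ValueError("need one max_batches entry per dataloader")
    return list(max_batches)
```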

abstract add_progress_bar_metrics(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

abstract copy_trainer_model_properties(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode=False)[source]

abstract get_model()[source]

Warning: this is just an empty shell for code implemented in another class.

Return type

LightningModule

abstract is_overridden(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

abstract log_metrics(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

abstract reset_test_dataloader(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

abstract reset_val_dataloader(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

run_evaluation(test_mode=False)[source]

abstract transfer_batch_to_gpu(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

abstract transfer_batch_to_tpu(*args)[source]

Warning: this is just an empty shell for code implemented in another class.

callback_metrics: ... = None[source]
current_epoch: int = None[source]
data_parallel_device_ids: ... = None[source]
fast_dev_run: ... = None[source]
global_rank: int = None[source]
model: LightningModule = None[source]
num_test_batches: List[int] = None[source]
num_val_batches: int = None[source]
on_gpu: bool = None[source]
on_test_batch_end: Callable = None[source]
on_test_batch_start: Callable = None[source]
on_test_end: Callable = None[source]
on_test_start: Callable = None[source]
on_validation_batch_end: Callable = None[source]
on_validation_batch_start: Callable = None[source]
on_validation_end: Callable = None[source]
on_validation_start: Callable = None[source]
process_output: ... = None[source]
progress_bar_dict: ... = None[source]
reload_dataloaders_every_epoch: ... = None[source]
single_gpu: bool = None[source]
test_dataloaders: DataLoader = None[source]
tpu_id: int = None[source]
use_ddp: bool = None[source]
use_ddp2: bool = None[source]
use_dp: bool = None[source]
use_horovod: bool = None[source]
use_tpu: bool = None[source]
val_dataloaders: DataLoader = None[source]
world_size: int = None[source]