LightningModule¶
A LightningModule
organizes your PyTorch code into the following sections:
Notice a few things.
It’s the SAME code.
The PyTorch code IS NOT abstracted - just organized.
All the other code that’s not in the LightningModule has been automated for you by the Trainer:
net = Net()
trainer = Trainer()
trainer.fit(net)
There are no .cuda() or .to() calls… Lightning does these for you.
# don't do this in Lightning
x = torch.Tensor(2, 3)
x = x.cuda()
x = x.to(device)
# do this instead
x = x  # leave it alone!
# or, to init a new tensor with the same type and device as x
new_x = torch.Tensor(2, 3)
new_x = new_x.type_as(x)
There are no samplers for distributed training; Lightning adds these for you as well.
# don't do this in Lightning
data = MNIST(...)
sampler = DistributedSampler(data)
DataLoader(data, sampler=sampler)
# do this instead
data = MNIST(...)
DataLoader(data)
A LightningModule is a torch.nn.Module but with added functionality. Use it as such!
net = Net.load_from_checkpoint(PATH)
net.freeze()
out = net(x)
Thus, to use Lightning, you just need to organize your code, which takes about 30 minutes (and let’s be real, you probably should do that anyhow).
Minimal Example¶
Here are the only required methods.
>>> import os
>>> import torch
>>> import torch.nn.functional as F
>>> from torch.utils.data import DataLoader
>>> from torchvision import transforms
>>> from torchvision.datasets import MNIST
>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(28 * 28, 10)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
...
...     def training_step(self, batch, batch_idx):
...         x, y = batch
...         y_hat = self(x)
...         return {'loss': F.cross_entropy(y_hat, y)}
...
...     def train_dataloader(self):
...         return DataLoader(MNIST(os.getcwd(), train=True, download=True,
...                                 transform=transforms.ToTensor()), batch_size=32)
...
...     def configure_optimizers(self):
...         return torch.optim.Adam(self.parameters(), lr=0.02)
Which you can train by doing:
trainer = pl.Trainer()
model = LitModel()
trainer.fit(model)
Training loop structure¶
The general pattern is that each loop (training, validation, test loop) has 3 methods:
___step
___step_end
___epoch_end
To show how Lightning calls these, let’s use the validation loop as an example:
val_outs = []
for val_batch in val_data:
    # do something with each batch
    out = validation_step(val_batch)
    val_outs.append(out)
# do something with the outputs for all batches
# like calculate validation set accuracy or loss
validation_epoch_end(val_outs)
If we use dp or ddp2 mode, we can also define the ___step_end method to operate on all parts of the batch:
val_outs = []
for val_batch in val_data:
    batches = split_batch(val_batch)
    dp_outs = []
    for sub_batch in batches:
        dp_out = validation_step(sub_batch)
        dp_outs.append(dp_out)
    out = validation_step_end(dp_outs)
    val_outs.append(out)
# do something with the outputs for all batches
# like calculate validation set accuracy or loss
validation_epoch_end(val_outs)
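The calling pattern above can be tried out with a small stand-alone sketch. It uses plain Python lists in place of tensors, and every function in it (`split_batch`, the toy `validation_step`, and so on) is a hypothetical stand-in for illustration, not Lightning's actual implementation.

```python
# Toy stand-ins for the three hooks; none of this is Lightning's real code.
def split_batch(batch, num_parts=2):
    """Split a batch (a list) into roughly equal sub-batches, as dp/ddp2 would."""
    size = (len(batch) + num_parts - 1) // num_parts
    return [batch[i:i + size] for i in range(0, len(batch), size)]


def validation_step(sub_batch):
    # per-device work: here just a sum and a count
    return {"sum": sum(sub_batch), "count": len(sub_batch)}


def validation_step_end(parts):
    # combine the results from all parts of one batch
    total = sum(p["sum"] for p in parts)
    count = sum(p["count"] for p in parts)
    return {"batch_mean": total / count}


def validation_epoch_end(outputs):
    # reduce the per-batch outputs to one epoch-level number
    return sum(o["batch_mean"] for o in outputs) / len(outputs)


def run_validation(val_data):
    val_outs = []
    for val_batch in val_data:
        dp_outs = [validation_step(sb) for sb in split_batch(val_batch)]
        val_outs.append(validation_step_end(dp_outs))
    return validation_epoch_end(val_outs)


epoch_mean = run_validation([[1, 2, 3, 4], [5, 6, 7, 8]])
```

The point of the sketch is only the order of calls: the step runs once per sub-batch, the step-end hook once per batch, and the epoch-end hook once per epoch.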
Add validation loop¶
Thus, to add a validation loop, add this to your LightningModule:
>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...     def validation_step(self, batch, batch_idx):
...         x, y = batch
...         y_hat = self(x)
...         return {'val_loss': F.cross_entropy(y_hat, y)}
...
...     def validation_epoch_end(self, outputs):
...         val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
...         return {'val_loss': val_loss_mean}
...
...     def val_dataloader(self):
...         # can also return a list of val dataloaders
...         return DataLoader(...)
Add test loop¶
>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...     def test_step(self, batch, batch_idx):
...         x, y = batch
...         y_hat = self(x)
...         return {'test_loss': F.cross_entropy(y_hat, y)}
...
...     def test_epoch_end(self, outputs):
...         test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
...         return {'test_loss': test_loss_mean}
...
...     def test_dataloader(self):
...         # can also return a list of test dataloaders
...         return DataLoader(...)
However, the test loop is never called automatically, so that you don’t run your test data by accident. Instead, you have to call it explicitly:
# call after training
trainer = Trainer()
trainer.fit(model)
trainer.test()
# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model)
Training_step_end method¶
When using LightningDataParallel or LightningDistributedDataParallel, the training_step() will be operating on a portion of the batch. This is normally OK, but in special cases like calculating NCE loss using negative samples, we might want to perform a softmax across all samples in the batch.
For these types of situations, each loop has an additional ___step_end method which allows you to operate on the pieces of the batch:
training_outs = []
for train_batch in train_data:
    # dp, ddp2 splits the batch
    sub_batches = split_batches_for_dp(train_batch)
    # run training_step on each piece of the batch
    batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
    # do softmax with all pieces
    out = training_step_end(batch_parts_outputs)
    training_outs.append(out)
# do something with the outputs for all batches
# like calculate training set accuracy or loss
training_epoch_end(training_outs)
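To see why the extra hook matters, here is a minimal sketch comparing a softmax computed inside each sub-batch with one computed after gathering all pieces. It uses plain Python floats rather than tensors, and the scores and the two-way split are made-up values for illustration.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


full_batch = [2.0, 1.0, 0.5, 3.0]
sub_batches = [full_batch[:2], full_batch[2:]]  # what two devices would each see

# softmax inside training_step: normalized separately per sub-batch
per_part = [p for sb in sub_batches for p in softmax(sb)]

# softmax inside training_step_end: gather the pieces first, normalize once
gathered = [s for sb in sub_batches for s in sb]
across_batch = softmax(gathered)
```

Each sub-batch softmax sums to 1 on its own, so the per-part probabilities are not comparable across devices; only the gathered version normalizes over every sample in the batch.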
Remove cuda calls¶
In a LightningModule, all calls to .cuda() and .to(device) should be removed. Lightning will do these automatically. This will allow your code to work on CPUs, TPUs and GPUs.
When you init a new tensor in your code, just use type_as():
def training_step(self, batch, batch_idx):
    x, y = batch
    # put z on the appropriate gpu or tpu core
    z = sample_noise()
    z = z.type_as(x)
Data preparation¶
Data preparation in PyTorch follows 5 steps:
1. Download
2. Clean and (maybe) save to disk
3. Load inside Dataset
4. Apply transforms (rotate, tokenize, etc…)
5. Wrap inside a DataLoader
When working in distributed settings, steps 1 and 2 have to be done from a single GPU, otherwise you will overwrite these files from every GPU. The LightningModule has the prepare_data method to allow for this:
>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...     def prepare_data(self):
...         # download
...         MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
...         MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
...
...     def setup(self, stage):
...         mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
...         mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
...         # train/val split
...         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
...
...         # assign to use in dataloaders
...         self.train_dataset = mnist_train
...         self.val_dataset = mnist_val
...         self.test_dataset = mnist_test
...
...     def train_dataloader(self):
...         return DataLoader(self.train_dataset, batch_size=64)
...
...     def val_dataloader(self):
...         return DataLoader(self.val_dataset, batch_size=64)
...
...     def test_dataloader(self):
...         return DataLoader(self.test_dataset, batch_size=64)
Note
prepare_data()
is called once.
Note
Do anything with data that needs to happen ONLY once here, like download, tokenize, etc…
Lifecycle¶
The methods in the LightningModule
are called in this order:
If you define a validation loop then
And if you define a test loop:
Note
test_dataloader() is only called with .test()
In every epoch, the loop methods are called at this frequency:
validation_step() called every batch
validation_epoch_end() called every epoch
LightningModule Class¶
-
class
pytorch_lightning.core.
LightningModule
(*args, **kwargs)[source] Bases:
abc.ABC
,pytorch_lightning.utilities.device_dtype_mixin.DeviceDtypeModuleMixin
,pytorch_lightning.core.grads.GradInformation
,pytorch_lightning.core.saving.ModelIO
,pytorch_lightning.core.hooks.ModelHooks
,torch.nn.Module
-
_LightningModule__get_hparams_assignment_variable
()[source] Looks at the code of the class to figure out what the user named self.hparams. This only happens when the user explicitly sets self.hparams.
-
classmethod
_auto_collect_arguments
(frame=None)[source] Collect all module arguments in the current constructor and all child constructors. The child constructors are all the __init__ methods that reach the current class through (chained) super().__init__() calls.
- Parameters
frame¶ – instance frame
- Returns
self_arguments: arguments dictionary of the first instance
parents_arguments: arguments dictionary of the parent’s instances
-
_init_slurm_connection
()[source] Sets up environment variables necessary for pytorch distributed communications based on slurm environment.
- Return type
None
-
configure_apex
(amp, model, optimizers, amp_level)[source] Override to init AMP your own way. Must return a model and list of optimizers.
- Parameters
model¶ (LightningModule) – pointer to current LightningModule.
optimizers¶ (List[Optimizer]) – list of optimizers passed in configure_optimizers().
- Return type
Tuple[LightningModule, List[Optimizer]]
- Returns
Apex wrapped model and optimizers
Examples
# Default implementation used by Trainer.
def configure_apex(self, amp, model, optimizers, amp_level):
    model, optimizers = amp.initialize(
        model, optimizers, opt_level=amp_level,
    )
    return model, optimizers
-
configure_ddp
(model, device_ids)[source] Override to init DDP in your own way or with your own wrapper. The only requirements are that:
On a validation batch, the call goes to model.validation_step.
On a training batch, the call goes to model.training_step.
On a testing batch, the call goes to model.test_step.
- Parameters
model¶ (LightningModule) – the LightningModule currently being optimized.
- Return type
- Returns
DDP wrapped model
Examples
# default implementation used in Trainer
def configure_ddp(self, model, device_ids):
    # Lightning DDP simply routes to test_step, val_step, etc...
    model = LightningDistributedDataParallel(
        model,
        device_ids=device_ids,
        find_unused_parameters=True
    )
    return model
-
configure_optimizers
()[source] Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple.
- Return type
Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List], None]
- Returns
Any of these 6 options.
Single optimizer.
List or Tuple - List of optimizers.
Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict).
Dictionary, with an ‘optimizer’ key, and (optionally) a ‘lr_scheduler’ key whose value is a single LR scheduler or lr_dict.
Tuple of dictionaries as described, with an optional ‘frequency’ key.
None - Fit will run without any optimizer.
Note
The ‘frequency’ value is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1: In the former case, all optimizers will operate on the given batch in each optimization step. In the latter, only one optimizer will operate on the given batch at every step.
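One way to picture the 'frequency' behaviour is as a repeating cycle over batches. The sketch below is an assumption about how such a schedule could be computed, based only on the description in this note; `optimizer_for_batch` is a hypothetical helper, not part of Lightning's API.

```python
from itertools import accumulate


def optimizer_for_batch(batch_idx, frequencies):
    """Return the index of the optimizer that would run on batch `batch_idx`.

    `frequencies` holds the 'frequency' value of each optimizer, in the
    order the optimizers were returned.
    """
    cumulative = list(accumulate(frequencies))  # e.g. [5, 1] -> [5, 6]
    position = batch_idx % cumulative[-1]       # place within one full cycle
    for opt_idx, bound in enumerate(cumulative):
        if position < bound:
            return opt_idx


# first optimizer for 5 sequential batches, then the second for 1, repeating
schedule = [optimizer_for_batch(i, [5, 1]) for i in range(12)]
```

With frequencies of 5 and 1, the first optimizer handles five batches in a row before the second handles one, which matches the n_critic pattern used for Wasserstein GAN training.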
The lr_dict is a dictionary which contains scheduler and its associated configuration. It has five keys. The default configuration is shown below.
{
    'scheduler': lr_scheduler,  # The LR scheduler
    'interval': 'epoch',  # The unit of the scheduler's step size
    'frequency': 1,  # The frequency of the scheduler
    'reduce_on_plateau': False,  # For ReduceLROnPlateau scheduler
    'monitor': 'val_loss'  # Metric to monitor
}
If the user only provides LR schedulers, their configuration will be set to the defaults shown above.
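The defaulting rule can be sketched as a dictionary merge. This is an illustration under stated assumptions: `normalize_scheduler` and `DEFAULT_LR_DICT` are hypothetical names, and a string stands in for a real scheduler object so the snippet runs without PyTorch.

```python
# Default values taken from the lr_dict shown above.
DEFAULT_LR_DICT = {
    "interval": "epoch",
    "frequency": 1,
    "reduce_on_plateau": False,
    "monitor": "val_loss",
}


def normalize_scheduler(entry):
    """Turn a bare scheduler or a partial lr_dict into a full lr_dict."""
    if isinstance(entry, dict):
        if "scheduler" not in entry:
            raise ValueError("lr_dict must contain a 'scheduler' key")
        # user-supplied keys override the defaults
        return {**DEFAULT_LR_DICT, **entry}
    return {**DEFAULT_LR_DICT, "scheduler": entry}


bare = normalize_scheduler("my_scheduler")  # placeholder for a scheduler object
stepwise = normalize_scheduler({"scheduler": "my_scheduler", "interval": "step"})
```

A bare scheduler ends up stepping once per epoch, while a partial dictionary only overrides the keys it names.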
Examples
# imports assumed by the examples below
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR

# most cases
def configure_optimizers(self):
    opt = Adam(self.parameters(), lr=1e-3)
    return opt

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
    return generator_opt, discriminator_opt

# example with learning rate schedulers
def configure_optimizers(self):
    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
    discriminator_sched = CosineAnnealingLR(discriminator_opt, T_max=10)
    return [generator_opt, discriminator_opt], [discriminator_sched]

# example with step-based learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
    gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99),
                 'interval': 'step'}  # called after each training step
    dis_sched = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sched, dis_sched]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1}
    )
Note
Some things to know:
Lightning calls .backward() and .step() on each optimizer and learning rate scheduler as needed.
If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers for you.
If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.
If you use LBFGS, Lightning handles the closure function automatically for you.
If you use multiple optimizers, gradients will be calculated only for the parameters of the current optimizer at each training step.
If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.
If you only want to call a learning rate scheduler every x step or epoch, or want to monitor a custom metric, you can specify these in a lr_dict:
{
    'scheduler': lr_scheduler,
    'interval': 'step',  # or 'epoch'
    'monitor': 'val_f1',
    'frequency': x,
}
-
abstract
forward
(*args, **kwargs)[source] Same as torch.nn.Module.forward(), however in Lightning you want this to define the operations you want to use for prediction (i.e.: on a server or as a feature extractor).
Normally you’d call self() from your training_step() method. This makes it easy to write a complex system for training with the outputs you’d want in a prediction setting.
You may also find the auto_move_data() decorator useful when using the module outside Lightning in a production setting.
- Parameters
- Returns
Predicted output
Examples
# example if we were using this model as a feature extractor
def forward(self, x):
    feature_maps = self.convnet(x)
    return feature_maps

def training_step(self, batch, batch_idx):
    x, y = batch
    feature_maps = self(x)
    logits = self.classifier(feature_maps)
    # ...
    return loss

# splitting it this way allows the model to be used as a feature extractor
model = MyModelAbove()
inputs = server.get_request()
results = model(inputs)
server.write_results(results)

# -------------
# This is in stark contrast to torch.nn.Module where normally you would have this:
def forward(self, batch):
    x, y = batch
    feature_maps = self.convnet(x)
    logits = self.classifier(feature_maps)
    return logits
-
freeze
()[source] Freeze all params for inference.
Example
model = MyLightningModule(...)
model.freeze()
- Return type
None
-
get_progress_bar_dict
()[source] Additional items to be displayed in the progress bar.
-
get_tqdm_dict
()[source] Additional items to be displayed in the progress bar.
- Return type
- Returns
Dictionary with the items to be displayed in the progress bar.
Warning
Deprecated since v0.7.3. Use
get_progress_bar_dict()
instead.
-
init_ddp_connection
(global_rank, world_size, is_slurm_managing_tasks=True)[source] Override to define your custom way of setting up a distributed environment.
Lightning’s implementation uses env:// init by default and sets the first node as root for SLURM managed cluster.
-
on_load_checkpoint
(checkpoint)[source] Called by Lightning to restore your model. If you saved something with on_save_checkpoint(), this is your chance to restore it.
Example
def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Note
Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.
- Return type
None
-
on_save_checkpoint
(checkpoint)[source] Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
Example
def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
- Return type
None
-
optimizer_step
(epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False)[source] Override this method to adjust the default way the Trainer calls each optimizer. By default, Lightning calls step() and zero_grad() as shown in the example once per optimizer.
- Parameters
Examples
# DEFAULT
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
                   second_order_closure, on_tpu, using_native_amp, using_lbfgs):
    optimizer.step()

# Alternating schedule for optimizer steps (i.e.: GANs)
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
                   second_order_closure, on_tpu, using_native_amp, using_lbfgs):
    # update generator opt every 2 steps
    if optimizer_idx == 0:
        if batch_idx % 2 == 0:
            optimizer.step()
            optimizer.zero_grad()
    # update discriminator opt every 4 steps
    if optimizer_idx == 1:
        if batch_idx % 4 == 0:
            optimizer.step()
            optimizer.zero_grad()
    # ...
    # add as many optimizers as you want
Here’s another example showing how to use this for more advanced things such as learning rate warm-up:
# learning rate warm-up
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
                   second_order_closure, on_tpu, using_native_amp, using_lbfgs):
    # warm up lr
    if self.trainer.global_step < 500:
        lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
        for pg in optimizer.param_groups:
            pg['lr'] = lr_scale * self.learning_rate
    # update params
    optimizer.step()
    optimizer.zero_grad()
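The warm-up arithmetic can be checked in isolation. The sketch below pulls the scaling rule out into a stand-alone function; `warmup_lr` is a hypothetical helper, and it assumes no other scheduler changes the learning rate once the 500 warm-up steps are over.

```python
def warmup_lr(global_step, base_lr, warmup_steps=500):
    """Learning rate under linear warm-up, mirroring the rule in the hook above."""
    if global_step < warmup_steps:
        lr_scale = min(1.0, float(global_step + 1) / warmup_steps)
        return lr_scale * base_lr
    # after warm-up the hook leaves the learning rate untouched
    return base_lr


# learning rates at the first step, mid warm-up, the last warm-up step, and later
lrs = [warmup_lr(step, base_lr=0.01) for step in (0, 249, 499, 1000)]
```

The rate climbs linearly from 1/500 of the base value at step 0 to the full base value at step 499, then stays there.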
Note
If you also override the on_before_zero_grad() model hook, don’t forget to add the call to it before optimizer.zero_grad() yourself.
- Return type
None
-
prepare_data
()[source] Use this to download and prepare data.
Warning
DO NOT set state on the model (use setup instead), since this is NOT called on every GPU in DDP/TPU.
Example:
def prepare_data(self):
    # good
    download_data()
    tokenize()
    etc()
    # bad
    self.split = data_split
    self.some_state = some_other_state()
In DDP prepare_data can be called in two ways (using Trainer(prepare_data_per_node)):
Once per node. This is the default and is only called on LOCAL_RANK=0.
Once in total. Only called on GLOBAL_RANK=0.
Example:
# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
Trainer(prepare_data_per_node=True)
# call on GLOBAL_RANK=0 (great for shared file systems)
Trainer(prepare_data_per_node=False)
This is called before requesting the dataloaders:
model.prepare_data()
if ddp/tpu: init()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
- Return type
None
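The two prepare_data_per_node modes can be expressed as a rank check. This is a sketch of the rule as described here, not Lightning's implementation: `should_prepare_data` is a hypothetical helper, and it assumes GLOBAL_RANK=0 is the process with node rank 0 and local rank 0.

```python
def should_prepare_data(node_rank, local_rank, prepare_data_per_node=True):
    """Decide whether this process is the one that runs prepare_data()."""
    if prepare_data_per_node:
        # once per node: the first process on every node
        return local_rank == 0
    # once in total: only the first process on the first node
    return node_rank == 0 and local_rank == 0


# a hypothetical cluster with 2 nodes and 2 GPUs per node
procs = [(node, local) for node in range(2) for local in range(2)]
per_node = [p for p in procs if should_prepare_data(*p, prepare_data_per_node=True)]
once = [p for p in procs if should_prepare_data(*p, prepare_data_per_node=False)]
```

Because the other processes never run prepare_data(), anything assigned to self there would be missing on them, which is why state belongs in setup().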
-
print
(*args, **kwargs)[source] Prints only from process 0. Use this in any distributed mode to log only once.
- Parameters
Example
def forward(self, x):
    self.print(x, 'in forward')
- Return type
None
-
save_hyperparameters
(*args, frame=None)[source] Save all model arguments.
- Parameters
args¶ – single object of dict, Namespace or OmegaConf, or string names of arguments from class __init__
>>> from collections import OrderedDict
>>> class ManuallyArgsModel(LightningModule):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # manually assign arguments
...         self.save_hyperparameters('arg1', 'arg3')
...     def forward(self, *args, **kwargs):
...         ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
>>> class AutomaticArgsModel(LightningModule):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # equivalent automatic
...         self.save_hyperparameters()
...     def forward(self, *args, **kwargs):
...         ...
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg2": abc
"arg3": 3.14
>>> from argparse import Namespace
>>> class SingleArgModel(LightningModule):
...     def __init__(self, params):
...         super().__init__()
...         # manually assign single argument
...         self.save_hyperparameters(params)
...     def forward(self, *args, **kwargs):
...         ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.hparams
"p1": 1
"p2": abc
"p3": 3.14
- Return type
None
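For intuition about the automatic case, constructor arguments can be read from the calling frame with the standard inspect module. The sketch below is an illustration of that general technique only; `collect_init_args` and `Demo` are hypothetical names, and Lightning's own collection logic may differ.

```python
import inspect


def collect_init_args(frame):
    """Return the named arguments of the function running in `frame`, minus self."""
    info = inspect.getargvalues(frame)
    return {name: info.locals[name] for name in info.args if name != "self"}


class Demo:
    def __init__(self, arg1, arg2, arg3):
        # hypothetical stand-in for self.save_hyperparameters()
        self.hparams = collect_init_args(inspect.currentframe())


demo = Demo(1, "abc", 3.14)
```

Reading the frame is what lets a call with no arguments still see every value passed to __init__.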
-
tbptt_split_batch
(batch, split_size)[source] When using truncated backpropagation through time, each batch must be split along the time dimension. Lightning handles this by default, but for custom behavior override this function.
- Parameters
- Return type
- Returns
List of batch splits. Each split will be passed to training_step() to enable truncated back propagation through time. The default implementation splits root level Tensors and Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length.
Examples
def tbptt_split_batch(self, batch, split_size):
    # length of the time dimension of each root-level tensor or sequence
    time_dims = [len(x[0]) for x in batch
                 if isinstance(x, (torch.Tensor, collections.abc.Sequence))]
    splits = []
    for t in range(0, time_dims[0], split_size):
        batch_split = []
        for i, x in enumerate(batch):
            if isinstance(x, torch.Tensor):
                split_x = x[:, t:t + split_size]
            elif isinstance(x, collections.abc.Sequence):
                split_x = [None] * len(x)
                for batch_idx in range(len(x)):
                    split_x[batch_idx] = x[batch_idx][t:t + split_size]
            batch_split.append(split_x)
        splits.append(batch_split)
    return splits
Note
Called in the training loop after on_batch_start() if truncated_bptt_steps > 0. Each returned batch split is passed separately to training_step().
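The splitting rule is easier to follow on a tiny batch. The sketch below covers only the Sequence branch, using nested Python lists so it runs without PyTorch; `split_along_time` is a hypothetical name and the batch values are made up.

```python
from collections.abc import Sequence


def split_along_time(batch, split_size):
    """Split every root-level sequence in `batch` into chunks of `split_size` steps."""
    time_len = len(batch[0][0])  # assumes all samples share one time length
    splits = []
    for t in range(0, time_len, split_size):
        batch_split = []
        for x in batch:
            if isinstance(x, Sequence):
                # slice each sample along its time dimension
                batch_split.append([sample[t:t + split_size] for sample in x])
        splits.append(batch_split)
    return splits


# batch = [inputs, targets], with 2 samples of 5 time steps each
batch = [
    [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]],
    [[0, 0, 1, 1, 0], [1, 1, 0, 0, 1]],
]
chunks = split_along_time(batch, split_size=2)
```

Five time steps with a split size of 2 give three chunks, the last one holding the single leftover step; each chunk keeps the [inputs, targets] layout of the original batch.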
-
test_dataloader
()[source] Implement one or multiple PyTorch DataLoaders for testing.
The dataloader you return will not be called every epoch unless you set reload_dataloaders_every_epoch to True.
For data processing use the following pattern:
download in prepare_data()
process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return type
- Returns
Single or multiple PyTorch DataLoaders.
Example
def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
-
test_end
(outputs)[source] Warning
Deprecated in v0.7.0. Use test_epoch_end() instead. Will be removed in 1.0.0.
-
test_epoch_end
(outputs)[source] Called at the end of a test epoch with the output of all test steps.
# the pseudocode for these calls
test_outs = []
for test_batch in test_data:
    out = test_step(test_batch)
    test_outs.append(out)
test_epoch_end(test_outs)
- Parameters
outputs¶ (Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]) – List of outputs you defined in test_step_end(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader
- Returns
Dict has the following optional keys:
progress_bar -> Dict for progress bar display. Must have only tensors.
log -> Dict of metrics to add to logger. Must have only tensors (no images, etc).
- Return type
Dict or OrderedDict
Note
If you didn’t define a test_step(), this won’t be called.
The outputs here are strictly for logging or progress bar.
If you don’t need to display anything, don’t return anything.
If you want to manually set current step, specify it with the ‘step’ key in the ‘log’ Dict
Examples
With a single dataloader:
def test_epoch_end(self, outputs):
    test_acc_mean = 0
    for output in outputs:
        test_acc_mean += output['test_acc']
    test_acc_mean /= len(outputs)
    tqdm_dict = {'test_acc': test_acc_mean.item()}
    # show test_acc in the progress bar and also log it
    results = {
        'progress_bar': tqdm_dict,
        'log': {'test_acc': test_acc_mean.item()}
    }
    return results
With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each test step for that dataloader.
def test_epoch_end(self, outputs):
    test_acc_mean = 0
    i = 0
    for dataloader_outputs in outputs:
        for output in dataloader_outputs:
            test_acc_mean += output['test_acc']
            i += 1
    test_acc_mean /= i
    tqdm_dict = {'test_acc': test_acc_mean.item()}
    # show test_acc in the progress bar and also log it
    results = {
        'progress_bar': tqdm_dict,
        'log': {'test_acc': test_acc_mean.item(), 'step': self.current_epoch}
    }
    return results
-
test_step
(*args, **kwargs)[source] Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.
# the pseudocode for these calls
test_outs = []
for test_batch in test_data:
    out = test_step(test_batch)
    test_outs.append(out)
test_epoch_end(test_outs)
- Parameters
- Return type
- Returns
Dict or OrderedDict - passed to the test_epoch_end() method. If you defined test_step_end() it will go to that first.
# if you have one test dataloader:
def test_step(self, batch, batch_idx)
# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx)
Examples
# CASE 1: A single test dataset
# (assumes torch, torchvision and collections.OrderedDict are imported)
def test_step(self, batch, batch_idx):
    x, y = batch
    # implement your own
    out = self(x)
    loss = self.loss(out, y)
    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)
    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
    # all optional...
    # return whatever you need for the collation function test_epoch_end
    output = OrderedDict({
        'test_loss': loss,
        'test_acc': torch.tensor(test_acc),  # everything must be a tensor
    })
    # return an optional dict
    return output
If you pass in multiple test datasets, test_step() will have an additional argument.
# CASE 2: multiple test datasets
def test_step(self, batch, batch_idx, dataset_idx):
    # dataset_idx tells you which dataset this is.
    ...
Note
If you don’t need to test you don’t need to implement this method.
Note
When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.
-
test_step_end
(*args, **kwargs)[source] Use this when testing with dp or ddp2 because test_step() will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.
Note
If you later switch to ddp or some other mode, this will still be called so that you don’t have to change your code.
# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [test_step(sub_batch) for sub_batch in sub_batches]
test_step_end(batch_parts_outputs)
- Parameters
batch_parts_outputs¶ – What you return in test_step() for each batch part.
- Return type
- Returns
Dict or OrderedDict - passed to the test_epoch_end().
Examples
# WITHOUT test_step_end
# if used in DP or DDP2, this batch is 1/num_gpus large
def test_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch
    out = self(x)
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return {'loss': loss}

# --------------
# with test_step_end to do softmax over the full batch
def test_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch
    out = self(x)
    return {'out': out}

def test_step_end(self, outputs):
    # this out is now the full size of the batch
    out = outputs['out']
    # this softmax now uses the full batch size
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return {'loss': loss}
See also
See the Multi-GPU training guide for more details.
-
tng_dataloader
()[source] Warning
Deprecated in v0.5.0. Use train_dataloader() instead. Will be removed in 1.0.0.
-
train_dataloader
()[source] Implement a PyTorch DataLoader for training.
- Return type
- Returns
Single PyTorch DataLoader.
The dataloader you return will not be called every epoch unless you set reload_dataloaders_every_epoch to True.
For data processing use the following pattern:
download in prepare_data()
process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
…
setup()
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Example
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader
-
training_end
(*args, **kwargs)[source] Warning
Deprecated in v0.7.0. Use training_step_end() instead.
-
training_epoch_end
(outputs)[source] Called at the end of the training epoch with the outputs of all training steps.
# the pseudocode for these calls
train_outs = []
for train_batch in train_data:
    out = training_step(train_batch)
    train_outs.append(out)
training_epoch_end(train_outs)
- Parameters
outputs¶ (Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]) – List of outputs you defined in training_step(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
- Return type
- Returns
Dict or OrderedDict. May contain the following optional keys:
log (metrics to be added to the logger; only tensors)
progress_bar (dict for progress bar display)
any metric used in a callback (e.g. early stopping).
Note
If this method is not overridden, this won’t be called.
The outputs here are strictly for logging or progress bar.
If you don’t need to display anything, don’t return anything.
If you want to manually set current step, you can specify the ‘step’ key in the ‘log’ dict.
Examples
With a single dataloader:
def training_epoch_end(self, outputs):
    train_acc_mean = 0
    for output in outputs:
        train_acc_mean += output['train_acc']
    train_acc_mean /= len(outputs)
    # log training accuracy at the end of an epoch
    results = {
        'log': {'train_acc': train_acc_mean.item()},
        'progress_bar': {'train_acc': train_acc_mean},
    }
    return results
With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each training step for that dataloader.
def training_epoch_end(self, outputs):
    train_acc_mean = 0
    i = 0
    for dataloader_outputs in outputs:
        for output in dataloader_outputs:
            train_acc_mean += output['train_acc']
            i += 1
    train_acc_mean /= i
    # log training accuracy at the end of an epoch
    results = {
        'log': {'train_acc': train_acc_mean.item(), 'step': self.current_epoch},
        'progress_bar': {'train_acc': train_acc_mean},
    }
    return results
-
training_step
(*args, **kwargs)[source] Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Parameters
batch¶ (
Tensor
| (Tensor
, …) | [Tensor
, …]) – The output of yourDataLoader
. A tensor, tuple or list.optimizer_idx¶ (int) – When using multiple optimizers, this argument will also be present.
hiddens¶ (
Tensor
) – Passed in if truncated_bptt_steps
> 0.
- Return type
- Returns
Dict with loss key and optional log or progress bar keys. When implementing
training_step()
, return whatever you need in that step:
loss -> tensor scalar REQUIRED
progress_bar -> Dict for progress bar display. Must have only tensors
log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Examples
def training_step(self, batch, batch_idx):
    x, y, z = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, x)

    logger_logs = {'training_loss': loss}  # optional (MUST ALL BE TENSORS)

    # if using TestTubeLogger or TensorBoardLogger you can nest scalars
    logger_logs = {'losses': logger_logs}  # optional (MUST ALL BE TENSORS)

    output = {
        'loss': loss,  # required
        'progress_bar': {'training_loss': loss},  # optional (MUST ALL BE TENSORS)
        'log': logger_logs
    }

    # return a dict
    return output
If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.
# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...
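The call pattern can be pictured as a loop over the optimizers for every batch. The sketch below is a simplified illustration, not the trainer's actual loop: `ToyModule` and `toy_fit` are hypothetical names, the "optimizers" are plain strings, and it assumes every optimizer is stepped on every batch.

```python
class ToyModule:
    def __init__(self):
        self.calls = []

    def configure_optimizers(self):
        # Stand-ins for two real optimizers.
        return ['opt_a', 'opt_b']

    def training_step(self, batch, batch_idx, optimizer_idx):
        # Record which (batch, optimizer) pair this call belongs to.
        self.calls.append((batch_idx, optimizer_idx))
        return {'loss': 0.0}


def toy_fit(module, batches):
    optimizers = module.configure_optimizers()
    for batch_idx, batch in enumerate(batches):
        # training_step runs once per optimizer for the same batch.
        for optimizer_idx, _ in enumerate(optimizers):
            module.training_step(batch, batch_idx, optimizer_idx)


module = ToyModule()
toy_fit(module, batches=['batch0', 'batch1'])
```

After the run, `module.calls` lists each batch index paired with optimizer index 0 and then 1.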
If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    ...
    out, hiddens = self.lstm(data, hiddens)
    ...
    return {
        "loss": ...,
        "hiddens": hiddens  # remember to detach() this
    }
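To make the hand-off of hiddens concrete, here is a toy sketch in plain Python. It assumes the sequence is cut along the time axis into chunks of at most `steps` items and that the hidden state returned for one chunk is passed in with the next. `split_sequence` and `toy_training_step` are hypothetical names, and a running sum stands in for a real recurrent state.

```python
def split_sequence(sequence, steps):
    """Cut a sequence into consecutive chunks of at most `steps` items."""
    return [sequence[i:i + steps] for i in range(0, len(sequence), steps)]


def toy_training_step(chunk, hiddens):
    # A running sum plays the role of the recurrent hidden state.
    hiddens = (hiddens or 0) + sum(chunk)
    return {'loss': float(len(chunk)), 'hiddens': hiddens}


sequence = [1, 2, 3, 4, 5]
hiddens = None  # there is no hidden state before the first chunk
for chunk in split_sequence(sequence, steps=2):
    out = toy_training_step(chunk, hiddens)
    hiddens = out['hiddens']  # carried into the next chunk
```

In a real model the carried state would be a tensor, which is why the example above reminds you to detach() it.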
Notes
The loss value shown in the progress bar is smoothed (averaged) over the most recent values, so it differs from the actual loss returned in the train/validation step.
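The effect of this smoothing can be seen with a simple running mean. This is only a sketch: `RunningMean` is a hypothetical class, and the window of 3 is an arbitrary choice for illustration, not the window Lightning uses.

```python
from collections import deque


class RunningMean:
    """Mean of the most recent `window` values."""

    def __init__(self, window):
        # deque with maxlen drops the oldest value automatically.
        self.values = deque(maxlen=window)

    def update(self, value):
        self.values.append(value)
        return sum(self.values) / len(self.values)


smoother = RunningMean(window=3)
raw_losses = [4.0, 2.0, 0.0, 2.0]
shown = [smoother.update(loss) for loss in raw_losses]
```

Each entry of `shown` lags behind the corresponding raw loss, which is the difference the note describes.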
-
training_step_end
(*args, **kwargs)[source] Use this when training with dp or ddp2 because
training_step()
will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.
Note
If you later switch to ddp or some other mode, this will still be called so that you don’t have to change your code.
# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
training_step_end(batch_parts_outputs)
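The pseudocode can be turned into a runnable toy that follows it literally. Everything here is an illustrative assumption: lists stand in for tensors, `split_batches_for_dp` is implemented as a plain slicing helper, and the "model" just doubles its inputs.

```python
def split_batches_for_dp(batch, num_gpus):
    """Split a batch (a list here) into num_gpus nearly equal parts."""
    size = -(-len(batch) // num_gpus)  # ceiling division
    return [batch[i:i + size] for i in range(0, len(batch), size)]


def training_step(sub_batch):
    # Each replica sees only its own slice of the batch.
    return {'out': [2 * x for x in sub_batch]}


def training_step_end(batch_parts_outputs):
    # Reassemble the per-replica outputs into one full-batch result.
    full = [v for part in batch_parts_outputs for v in part['out']]
    return {'loss': sum(full) / len(full)}


batch = [1, 2, 3, 4]
parts = split_batches_for_dp(batch, num_gpus=2)
result = training_step_end([training_step(p) for p in parts])
```

The point of the second hook is that it is the only place where values from all parts of the batch are visible at once.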
- Parameters
batch_parts_outputs¶ – What you return in training_step for each batch part.
- Return type
- Returns
Dict with loss key and optional log or progress bar keys.
loss -> tensor scalar REQUIRED
progress_bar -> Dict for progress bar display. Must have only tensors
log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
Examples
# WITHOUT training_step_end
# if used in DP or DDP2, this batch is 1/num_gpus large
def training_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self(x)
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return {'loss': loss}

# --------------
# with training_step_end to do softmax over the full batch
def training_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self(x)
    return {'out': out}

def training_step_end(self, outputs):
    # this out is now the full size of the batch
    out = outputs['out']

    # this softmax now uses the full batch size
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return {'loss': loss}
See also
See the Multi-GPU training guide for more details.
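A small numeric sketch shows why a loss that normalizes across the batch needs the full batch. It assumes a softmax taken over batch entries, as in contrastive or NCE-style losses; the scores are made up and plain Python stands in for tensor operations.

```python
import math


def softmax(scores):
    """Softmax over a list of scores."""
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [0.0, 0.0, math.log(2.0), math.log(2.0)]

# Normalizing inside each half of the batch separately (two devices).
per_part = softmax(scores[:2]) + softmax(scores[2:])

# Normalizing once over the whole batch.
full_batch = softmax(scores)
```

Within each half the scores are equal, so every per-part probability is 0.5, while the full-batch softmax gives the last two entries twice the weight of the first two.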
-
unfreeze
()[source] Unfreeze all parameters for training.
model = MyLightningModule(...)
model.unfreeze()
- Return type
None
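In plain PyTorch terms, freezing and unfreezing come down to toggling requires_grad on the parameters. The sketch below works on an ordinary torch.nn.Linear layer and assumes that this is what unfreeze amounts to (plus returning to training mode); it is not copied from the Lightning source.

```python
import torch

layer = torch.nn.Linear(4, 2)

# Freeze: stop gradient tracking for every parameter.
for param in layer.parameters():
    param.requires_grad = False
frozen = [p.requires_grad for p in layer.parameters()]

# Unfreeze: re-enable gradient tracking and go back to training mode.
for param in layer.parameters():
    param.requires_grad = True
layer.train()
unfrozen = [p.requires_grad for p in layer.parameters()]
```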
-
val_dataloader
()[source] Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded every epoch unless you set reload_dataloaders_every_epoch to True.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return type
- Returns
Single or multiple PyTorch DataLoaders.
Examples
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataset_idx which matches the order here.
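How the index lines up with the returned list can be shown with a toy loop. This is a sketch under simplifying assumptions: `ToyModule` is a hypothetical plain class, nested lists stand in for DataLoaders, and the loop is a simplification of what the trainer does.

```python
class ToyModule:
    def val_dataloader(self):
        # Two stand-in "dataloaders"; each inner list is one batch.
        return [[[1, 2], [3, 4]], [[10, 20]]]

    def validation_step(self, batch, batch_idx, dataset_idx):
        return {'dataset_idx': dataset_idx, 'batch_sum': sum(batch)}


module = ToyModule()
outputs = []
# The position of a loader in the returned list becomes its dataset_idx.
for dataset_idx, loader in enumerate(module.val_dataloader()):
    loader_outputs = [
        module.validation_step(batch, batch_idx, dataset_idx)
        for batch_idx, batch in enumerate(loader)
    ]
    outputs.append(loader_outputs)
```

The resulting `outputs` has one inner list per dataloader, in the same order as the list returned by val_dataloader.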
-
validation_end
(outputs)[source] Warning
Deprecated in v0.7.0. Use
validation_epoch_end()
instead. Will be removed in 1.0.0.
-
validation_epoch_end
(outputs)[source] Called at the end of the validation epoch with the outputs of all validation steps.
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
- Parameters
outputs¶ (
Union
[List
[Dict
[str
,Tensor
]],List
[List
[Dict
[str
,Tensor
]]]]) – List of outputs you defined in validation_step()
, or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
- Return type
- Returns
Dict or OrderedDict. May have the following optional keys:
progress_bar (dict for progress bar display; only tensors)
log (dict of metrics to add to logger; only tensors).
Note
If you didn’t define a validation_step(), this won’t be called.
The outputs here are strictly for logging or progress bar.
If you don’t need to display anything, don’t return anything.
If you want to manually set current step, you can specify the ‘step’ key in the ‘log’ dict.
Examples
With a single dataloader:
def validation_epoch_end(self, outputs):
    val_acc_mean = 0
    for output in outputs:
        val_acc_mean += output['val_acc']

    val_acc_mean /= len(outputs)
    tqdm_dict = {'val_acc': val_acc_mean.item()}

    # show val_acc in progress bar and also log it
    results = {
        'progress_bar': tqdm_dict,
        'log': {'val_acc': val_acc_mean.item()}
    }
    return results
With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each validation step for that dataloader.
def validation_epoch_end(self, outputs):
    val_acc_mean = 0
    i = 0
    for dataloader_outputs in outputs:
        for output in dataloader_outputs:
            val_acc_mean += output['val_acc']
            i += 1

    val_acc_mean /= i
    tqdm_dict = {'val_acc': val_acc_mean.item()}

    # show val_acc in progress bar and also log it
    results = {
        'progress_bar': tqdm_dict,
        'log': {'val_acc': val_acc_mean.item(), 'step': self.current_epoch}
    }
    return results
-
validation_step
(*args, **kwargs)[source] Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
- Parameters
- Return type
- Returns
Dict or OrderedDict - passed to validation_epoch_end(). If you defined validation_step_end() it will go to that first.
# pseudocode of order
out = validation_step()
if defined('validation_step_end'):
    out = validation_step_end(out)
out = validation_epoch_end(out)
# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx):
    ...
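The order of the hooks can be traced with a toy driver. This is a hedged sketch: `ToyModule` and `run_validation` are hypothetical, floats stand in for tensors, and `hasattr` on a plain class replaces whatever check the trainer really uses to detect an overridden hook.

```python
class ToyModule:
    def validation_step(self, batch, batch_idx):
        return {'val_acc': batch}

    def validation_step_end(self, out):
        # Mark the output so the call order is visible afterwards.
        out['seen_by_step_end'] = True
        return out

    def validation_epoch_end(self, outputs):
        mean = sum(o['val_acc'] for o in outputs) / len(outputs)
        return {'log': {'val_acc': mean}}


def run_validation(module, val_data):
    val_outs = []
    for batch_idx, batch in enumerate(val_data):
        out = module.validation_step(batch, batch_idx)
        # The optional hook runs only if the module defines it.
        if hasattr(module, 'validation_step_end'):
            out = module.validation_step_end(out)
        val_outs.append(out)
    # The epoch-level hook sees the outputs of every step.
    return val_outs, module.validation_epoch_end(val_outs)


val_outs, results = run_validation(ToyModule(), val_data=[0.5, 1.0])
```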
Examples
from collections import OrderedDict

import torch
import torchvision

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # all optional...
    # return whatever you need for the collation function validation_epoch_end
    output = OrderedDict({
        'val_loss': loss,
        'val_acc': torch.tensor(val_acc),  # everything must be a tensor
    })

    # return an optional dict
    return output
If you pass in multiple val datasets, validation_step will have an additional argument.
# CASE 2: multiple validation datasets
def validation_step(self, batch, batch_idx, dataset_idx):
    # dataset_idx tells you which dataset this is.
    ...
Note
If you don’t need to validate you don’t need to implement this method.
Note
When the
validation_step()
is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
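The same switch can be reproduced by hand in plain PyTorch. The sketch below uses an ordinary torch.nn.Linear layer as a stand-in model and shows only the mode and gradient toggling, not the trainer's own code.

```python
import torch

model = torch.nn.Linear(3, 1)
x = torch.ones(2, 3)

model.eval()  # dropout off, batch norm uses its running statistics
with torch.no_grad():  # no autograd graph is recorded in this block
    val_out = model(x)
    grad_enabled_inside = torch.is_grad_enabled()
model.train()  # back to training mode once validation is done

# Outside the no_grad block, outputs track gradients again.
train_out = model(x)
```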
-
validation_step_end
(*args, **kwargs)[source] Use this when validating with dp or ddp2 because
validation_step()
will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.
Note
If you later switch to ddp or some other mode, this will still be called so that you don’t have to change your code.
# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [validation_step(sub_batch) for sub_batch in sub_batches]
validation_step_end(batch_parts_outputs)
- Parameters
batch_parts_outputs¶ – What you return in validation_step() for each batch part.
- Return type
- Returns
Dict or OrderedDict - passed to the
validation_epoch_end()
method.
Examples
# WITHOUT validation_step_end
# if used in DP or DDP2, this batch is 1/num_gpus large
def validation_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self(x)
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return {'loss': loss}

# --------------
# with validation_step_end to do softmax over the full batch
def validation_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self(x)
    return {'out': out}

def validation_step_end(self, outputs):
    # this out is now the full size of the batch
    out = outputs['out']

    # this softmax now uses the full batch size
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return {'loss': loss}
See also
See the Multi-GPU training guide for more details.
-
_device
= None[source] device reference
-
_dtype
= None[source] Current dtype
-
current_epoch
= None[source] The current epoch
-
global_step
= None[source] Total training batches seen across all epochs
-
logger
= None[source] Pointer to the logger object
-
property
on_gpu
[source] True if your model is currently running on GPUs. Useful to set flags around the LightningModule for different CPU vs GPU behavior.
-
trainer
= None[source] Pointer to the trainer object
-
use_amp
= None[source] True if using amp
-
use_ddp
= None[source] True if using ddp
-
use_ddp2
= None[source] True if using ddp2
-
use_dp
= None[source] True if using dp
-
-
pytorch_lightning.core.
data_loader
(fn)[source] Decorator that makes any function it wraps use the lazy property.
Warning
This decorator was deprecated in v0.7.0 and will be removed in v0.9.0.
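The idea of lazy evaluation behind such a decorator can be sketched in a few lines. The code below is an illustration of the general pattern only: `lazy_loader` and `ToyModule` are hypothetical names, and this is not the implementation of data_loader.

```python
import functools


def lazy_loader(fn):
    """Call fn on first use, then reuse the cached result."""
    attr = '_lazy_' + fn.__name__

    @functools.wraps(fn)
    def wrapper(self):
        # Build the value only if it has not been cached on the instance.
        if not hasattr(self, attr):
            setattr(self, attr, fn(self))
        return getattr(self, attr)

    return wrapper


class ToyModule:
    build_count = 0

    @lazy_loader
    def train_dataloader(self):
        ToyModule.build_count += 1
        return [1, 2, 3]


module = ToyModule()
first = module.train_dataloader()
second = module.train_dataloader()
```

The second call returns the cached object, so the loader is built only once.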