Model Hooks
In some cases you may want to run custom code at specific points of the training or validation loop. To enable a hook, override the corresponding method in your LightningModule; the Trainer calls it at the appropriate time.
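The override mechanism can be pictured with a small, library-free sketch. `ToyHooks`, `ToyTrainer` and `MyModel` are hypothetical stand-ins, not Lightning's API. The idea is that the base class defines every hook as a no-op and the trainer always calls it, so a subclass only overrides the hooks it cares about.

```python
class ToyHooks:
    """Hypothetical stand-in for ModelHooks: every hook is a no-op by default."""

    def on_epoch_start(self):
        pass

    def on_epoch_end(self):
        pass


class ToyTrainer:
    """Hypothetical stand-in for the Trainer: it calls the hooks unconditionally."""

    def fit(self, model, max_epochs=2):
        for _ in range(max_epochs):
            model.on_epoch_start()
            # ... the actual training work for one epoch would happen here ...
            model.on_epoch_end()


class MyModel(ToyHooks):
    def __init__(self):
        self.epochs_started = 0

    # Overriding the method is all it takes to "enable" the hook.
    def on_epoch_start(self):
        self.epochs_started += 1


model = MyModel()
ToyTrainer().fit(model, max_epochs=3)
print(model.epochs_started)  # 3
```

Because the defaults do nothing, a model that overrides no hooks still runs through the same loop unchanged.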
Contributing
If there’s a hook you’d like to add, simply:
1. Fork PyTorchLightning.
2. Add the hook to pytorch_lightning.core.hooks.ModelHooks.
3. Add it in the correct place in pytorch_lightning.trainer where it should be called.
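The two code changes in steps 2 and 3 can be sketched without the library. Everything below is hypothetical: `ToyModelHooks` stands in for ModelHooks, `ToyTrainer` for the trainer, and `on_before_step` is an invented hook name used only for illustration. The new hook gets a no-op default so existing models keep working, and the trainer gains one call site.

```python
class ToyModelHooks:
    """Hypothetical, simplified ModelHooks used only for illustration."""

    def on_batch_start(self, batch):
        pass

    # Step 2: the new hook is added with a no-op default,
    # so models that do not override it keep working unchanged.
    def on_before_step(self, step):  # hypothetical new hook
        pass


class ToyTrainer:
    """Hypothetical trainer; step 3 adds the call site for the new hook."""

    def run(self, model, batches):
        for step, batch in enumerate(batches):
            model.on_batch_start(batch)
            model.on_before_step(step)  # new hook, called where it belongs
            # ... forward, backward and optimizer step would follow ...


class LoggingModel(ToyModelHooks):
    def __init__(self):
        self.seen_steps = []

    def on_before_step(self, step):
        self.seen_steps.append(step)


ToyTrainer().run(ToyModelHooks(), batches=[1, 2])  # default no-op: nothing breaks
model = LoggingModel()
ToyTrainer().run(model, batches=["a", "b", "c"])
print(model.seen_steps)  # [0, 1, 2]
```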
Hooks lifecycle
Training set-up
Training loop
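The order in which the training-loop hooks fire can be sketched from the hook descriptions on this page. This is a simplified, library-free sketch: `RecordingModel` and `toy_fit` are hypothetical, and it assumes a single optimizer while omitting the sanity check, validation, and the -1 early exit of on_batch_start.

```python
class RecordingModel:
    """Hypothetical model whose hooks only record the order in which they run."""

    def __init__(self):
        self.calls = []

    def on_train_start(self):
        self.calls.append("on_train_start")

    def on_epoch_start(self):
        self.calls.append("on_epoch_start")

    def on_batch_start(self, batch):
        self.calls.append("on_batch_start")

    def on_after_backward(self):
        self.calls.append("on_after_backward")

    def on_before_zero_grad(self, optimizer):
        self.calls.append("on_before_zero_grad")

    def on_batch_end(self):
        self.calls.append("on_batch_end")

    def on_epoch_end(self):
        self.calls.append("on_epoch_end")

    def on_train_end(self):
        self.calls.append("on_train_end")


def toy_fit(model, batches, max_epochs=1):
    """Simplified training loop; comments mark where the real work would go."""
    model.on_train_start()
    for _ in range(max_epochs):
        model.on_epoch_start()
        for batch in batches:
            model.on_batch_start(batch)
            # forward pass and loss.backward() would run here
            model.on_after_backward()
            # optimizer.step() would run here
            model.on_before_zero_grad(optimizer=None)
            # optimizer.zero_grad() would run here
            model.on_batch_end()
        model.on_epoch_end()
    model.on_train_end()


model = RecordingModel()
toy_fit(model, batches=[0])
print(model.calls)
```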
Validation loop
model.zero_grad()
model.eval()
torch.set_grad_enabled(False)
model.train()
torch.set_grad_enabled(True)
Test loop
model.zero_grad()
model.eval()
torch.set_grad_enabled(False)
model.train()
torch.set_grad_enabled(True)
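The evaluation sequence listed for the validation and test loops can be traced with a library-free sketch. `ToyModel`, `toy_validation_loop` and the global `grad_enabled` flag are hypothetical; `set_grad_enabled` here only mimics `torch.set_grad_enabled` with a flag. Placing on_pre_performance_check first and on_post_performance_check last is an assumption based on their "very beginning" and "very end" descriptions.

```python
grad_enabled = True  # stand-in for torch's global grad mode (not torch itself)


def set_grad_enabled(mode):
    """Stand-in for torch.set_grad_enabled; here it only flips a global flag."""
    global grad_enabled
    grad_enabled = mode


class ToyModel:
    """Hypothetical model that records calls; mimics the nn.Module methods used."""

    def __init__(self):
        self.calls = []
        self.training = True

    def zero_grad(self):
        self.calls.append("zero_grad")

    def eval(self):
        self.training = False
        self.calls.append("eval")

    def train(self):
        self.training = True
        self.calls.append("train")

    def on_pre_performance_check(self):
        self.calls.append("on_pre_performance_check")

    def on_post_performance_check(self):
        self.calls.append("on_post_performance_check")


def toy_validation_loop(model, batches):
    """Returns (grad_enabled, model.training) as seen while each batch runs."""
    model.on_pre_performance_check()
    model.zero_grad()
    model.eval()
    set_grad_enabled(False)
    states = []
    for _ in batches:
        # the validation step would run here, in eval mode with gradients off
        states.append((grad_enabled, model.training))
    model.train()
    set_grad_enabled(True)
    model.on_post_performance_check()
    return states


model = ToyModel()
print(toy_validation_loop(model, [1, 2]))  # [(False, False), (False, False)]
```

Restoring train mode and grad mode at the end is what lets training resume normally after each evaluation pass.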
-
class pytorch_lightning.core.hooks.ModelHooks(*args, **kwargs)
Bases: torch.nn.Module
-
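backward(trainer, loss, optimizer, optimizer_idx) Override backward with your own implementation if you need to.
- Parameters
trainer – The Trainer instance.
loss – The loss to backpropagate.
optimizer – The optimizer currently in use.
optimizer_idx – The index of the optimizer currently in use.
Called to perform the backward step. Feel free to override as needed.
The loss passed in has already been scaled for accumulated gradients if requested.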
Example:
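from apex import amp  # NVIDIA Apex; only needed for 16-bit training

def backward(self, trainer, loss, optimizer, optimizer_idx):
    # trainer.precision is 16 when mixed-precision training was requested
    if trainer.precision == 16:
        # let Apex scale the loss so fp16 gradients do not underflow
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()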
- Return type
None
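"Already scaled for accumulated gradients" means that, when gradients are accumulated over k batches, each batch's loss is divided by k before backward() runs, so the summed gradients equal the average gradient rather than k times it. A pure-Python sketch of that arithmetic follows; the function is hypothetical, the name `accumulate_grad_batches` mirrors the Trainer argument, and the per-batch gradients are made-up numbers.

```python
def accumulated_gradient(per_batch_grads, accumulate_grad_batches):
    """Sum the gradients of losses that were each divided by accumulate_grad_batches."""
    total = 0.0
    for grad in per_batch_grads:
        # d(loss / k)/dw == (d loss/dw) / k, so scaling the loss scales its gradient
        total += grad / accumulate_grad_batches
    return total


grads = [2.0, 4.0, 6.0, 8.0]  # made-up per-batch gradients for a single weight
print(accumulated_gradient(grads, accumulate_grad_batches=4))  # 5.0, the mean
```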
-
on_after_backward() Called in the training loop after loss.backward() and before the optimizers do anything. This is the ideal place to inspect or log gradient information.
Example:
def on_after_backward(self):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        for name, param in self.named_parameters():
            # log the gradients themselves, not the weights
            if param.grad is not None:
                self.logger.experiment.add_histogram(
                    tag=name, values=param.grad, global_step=self.trainer.global_step
                )
- Return type
None
-
on_batch_end() Called in the training loop after the batch.
- Return type
None
-
on_batch_start(batch) Called in the training loop before anything happens for that batch.
If you return -1 here, you will skip training for the rest of the current epoch.
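The -1 convention can be illustrated with a toy epoch loop. `toy_train_epoch` and `StopAtNegative` are hypothetical; the sketch only assumes what the description states, namely that a return value of -1 ends the current epoch early, while any other value lets the batch proceed.

```python
def toy_train_epoch(model, batches):
    """Hypothetical epoch loop honoring the -1 convention of on_batch_start."""
    trained = []
    for batch in batches:
        if model.on_batch_start(batch) == -1:
            break  # skip the rest of this epoch
        trained.append(batch)  # the training step would run here
    return trained


class StopAtNegative:
    """Hypothetical model that ends the epoch at the first negative batch."""

    def on_batch_start(self, batch):
        if batch < 0:
            return -1
        return None


print(toy_train_epoch(StopAtNegative(), [3, 1, -2, 5]))  # [3, 1]
```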
-
on_before_zero_grad(optimizer) Called after optimizer.step() and before optimizer.zero_grad().
Called in the training loop after taking an optimizer step and before zeroing grads. This is a good place to inspect weight information, since the weights have just been updated.
This is where it is called:
for optimizer in optimizers:
    optimizer.step()
    model.on_before_zero_grad(optimizer)  # < ---- called here
    optimizer.zero_grad()
- Parameters
optimizer (Optimizer) – The optimizer for which grads should be zeroed.
- Return type
None
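What the hook sees at that point can be shown with a plain-Python toy optimizer. `ToySGD` and `InspectingModel` are hypothetical (each parameter is stored as a `[value, grad]` pair); the point is that inside on_before_zero_grad the weight already reflects the step while the gradient has not yet been cleared.

```python
class ToySGD:
    """Hypothetical plain-Python SGD over a dict of {name: [value, grad]}."""

    def __init__(self, params, lr):
        self.params = params
        self.lr = lr

    def step(self):
        for entry in self.params.values():
            entry[0] -= self.lr * entry[1]

    def zero_grad(self):
        for entry in self.params.values():
            entry[1] = 0.0


class InspectingModel:
    def __init__(self):
        self.params = {"w": [1.0, 0.5]}  # value 1.0, gradient 0.5
        self.snapshots = []

    def on_before_zero_grad(self, optimizer):
        # Weights are already updated and gradients are still present here.
        self.snapshots.append({k: tuple(v) for k, v in self.params.items()})


model = InspectingModel()
optimizers = [ToySGD(model.params, lr=0.5)]
for optimizer in optimizers:
    optimizer.step()
    model.on_before_zero_grad(optimizer)
    optimizer.zero_grad()
print(model.snapshots)  # [{'w': (0.75, 0.5)}]
```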
-
on_epoch_end() Called in the training loop at the very end of the epoch.
- Return type
None
-
on_epoch_start() Called in the training loop at the very beginning of the epoch.
- Return type
None
-
on_post_performance_check() Called at the very end of the validation loop.
- Return type
None
-
on_pre_performance_check() Called at the very beginning of the validation loop.
- Return type
None
-
on_sanity_check_start() Called before starting evaluation.
Warning
Deprecated. Will be removed in v0.9.0.
-
on_train_end() Called at the end of training, before the logger experiment is closed.
- Return type
None
-
on_train_start() Called at the beginning of training, before the sanity check.
- Return type
None
-