
pytorch_lightning.core.grads module

Module to describe gradients

class pytorch_lightning.core.grads.GradInformation(*args, **kwargs)

Bases: torch.nn.Module

grad_norm(norm_type)

Compute the norm of each parameter’s gradient, as well as the overall norm of all gradients.

The overall norm is computed over all gradients together, as if they were concatenated into a single vector.
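A short derivation of the overall norm, written with notation introduced here for illustration (not from the original): let g_1, ..., g_k be the individual parameter gradients and g their concatenation into a single vector. For finite p > 0 and for p = ∞:

```latex
\|g\|_p = \Big( \sum_{i=1}^{k} \sum_{j} |g_{i,j}|^p \Big)^{1/p}
        = \Big( \sum_{i=1}^{k} \|g_i\|_p^{\,p} \Big)^{1/p},
\qquad
\|g\|_\infty = \max_{i} \max_{j} |g_{i,j}| = \max_{i} \|g_i\|_\infty .
```

In both cases the overall norm equals the p-norm of the vector (‖g_1‖_p, ..., ‖g_k‖_p), so it can be obtained from the per-parameter norms without building the concatenated vector.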

Parameters

norm_type (Union[float, int, str]) – The type of p-norm to use; cast to float if necessary. Pass 'inf' for the infinity norm.
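Because norm_type is cast to float, values such as 2, 2.0, '2' and 'inf' are all accepted. The plain-Python sketch below (no PyTorch dependency) mirrors only the math of that cast and of a p-norm over a flat list of gradient values; vector_p_norm is a hypothetical helper, not part of Lightning, and it deliberately rejects p <= 0.

```python
import math
from typing import Iterable, Union


def vector_p_norm(values: Iterable[float], norm_type: Union[float, int, str]) -> float:
    """Hypothetical helper: p-norm of a flat sequence of numbers."""
    p = float(norm_type)  # 2, 2.0, "2" and "inf" all become floats
    abs_vals = [abs(v) for v in values]
    if not abs_vals:
        return 0.0
    if p == math.inf:
        # Infinity norm: the largest absolute value.
        return max(abs_vals)
    if p <= 0:
        raise ValueError("this sketch only handles p > 0")
    # Finite p: (sum of |x|^p) ** (1/p)
    return sum(v ** p for v in abs_vals) ** (1.0 / p)


print(vector_p_norm([3.0, -4.0], 2))      # 5.0
print(vector_p_norm([3.0, -4.0], "inf"))  # 4.0
print(vector_p_norm([3.0, -4.0], 1))      # 7.0
```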

Returns

The dictionary of p-norms of each parameter’s gradient, plus a special entry for the total p-norm of the gradients viewed as a single vector.

Return type

Dict[str, float]