model_summary

Functions

get_formatted_model_size

Return type

str

get_human_readable_count

Abbreviates an integer number with K, M, B, T for thousands, millions, billions and trillions, respectively.

parse_batch_shape

Return type

Union[str, List]

summarize

Summarize the LightningModule specified by lightning_module.

Classes

LayerSummary

Summary class for a single layer in a LightningModule.

ModelSummary

Generates a summary of all layers in a LightningModule.

Utilities related to model weights summary.

class pytorch_lightning.utilities.model_summary.LayerSummary(module)[source]

Bases: object

Summary class for a single layer in a LightningModule. It collects the following information:

  • Type of the layer (e.g. Linear, BatchNorm1d, …)

  • Input shape

  • Output shape

  • Number of parameters

The input and output shapes are only known after the example input array was passed through the model.

Example:

>>> import torch
>>> from pytorch_lightning.utilities.model_summary import LayerSummary
>>> model = torch.nn.Conv2d(3, 8, 3)
>>> summary = LayerSummary(model)
>>> summary.num_parameters
224
>>> summary.layer_type
'Conv2d'
>>> output = model(torch.rand(1, 3, 5, 5))
>>> summary.in_size
[1, 3, 5, 5]
>>> summary.out_size
[1, 8, 3, 3]

Parameters

module (Module) – A module to summarize

detach_hook()[source]

Removes the forward hook if it was not already removed in the forward pass.

Will be called after the summary is created.

Return type

None
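
The hook lifecycle described above can be pictured with a plain PyTorch forward hook. This is only a minimal sketch, not the code LayerSummary uses: ShapeRecorder and its attributes are hypothetical names, and it assumes the module receives a single tensor and returns a single tensor.

```python
import torch
from torch import nn


class ShapeRecorder:
    """Hypothetical helper: records the sizes seen by one module via a forward hook."""

    def __init__(self, module: nn.Module) -> None:
        self.in_size = None   # unknown until the module has been called once
        self.out_size = None
        # register_forward_hook returns a handle that can remove the hook later
        self._handle = module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output) -> None:
        # `inputs` is a tuple of the positional arguments passed to forward()
        self.in_size = list(inputs[0].shape)
        self.out_size = list(output.shape)
        # One-shot behaviour: remove the hook right after the first forward pass
        self.detach()

    def detach(self) -> None:
        # Safe to call repeatedly, e.g. when the module was never called
        if self._handle is not None:
            self._handle.remove()
            self._handle = None


layer = nn.Conv2d(3, 8, 3)
recorder = ShapeRecorder(layer)
print(recorder.in_size)                     # None: no forward pass yet
layer(torch.rand(1, 3, 5, 5))
print(recorder.in_size, recorder.out_size)
recorder.detach()                           # no-op, the hook already removed itself
```

Calling detach() once more after the summary is built covers layers that never ran in the forward pass, whose hooks would otherwise stay registered.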

property layer_type: str

Returns the class name of the module.

Return type

str

property num_parameters: int

Returns the number of parameters in this module.

Return type

int

class pytorch_lightning.utilities.model_summary.ModelSummary(model, max_depth=1)[source]

Bases: object

Generates a summary of all layers in a LightningModule.

Parameters
  • model (LightningModule) – The model to summarize (also referred to as the root module).

  • max_depth (int) – Maximum depth of modules to show. Use -1 to show all modules or 0 to show no summary. Defaults to 1.

The string representation of this summary prints a table with columns containing the name, type and number of parameters for each layer.

The root module may also have an attribute example_input_array, as shown in the example below. If present, the root module is called with it as input to determine the intermediate input and output shapes of all layers. Tensors, as well as nested lists and tuples of tensors, are supported. All other input types are skipped and shown as ? in the summary table. The summary also displays ? for layers not used in the forward pass.
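
The rule for supported inputs can be sketched in plain Python. The names below (describe_shape, FakeTensor, UNKNOWN) are hypothetical and only illustrate the behaviour described above; the library's own helper may differ in detail.

```python
from dataclasses import dataclass
from typing import Any, List, Tuple, Union

UNKNOWN = "?"  # placeholder shown in the summary table for unsupported inputs


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: any object exposing a ``shape``."""

    shape: Tuple[int, ...]


def describe_shape(batch: Any) -> Union[str, List]:
    """Return a (possibly nested) list of sizes, or "?" for unsupported types."""
    if hasattr(batch, "shape"):
        return list(batch.shape)
    if isinstance(batch, (list, tuple)):
        # Recurse so nested lists/tuples keep their structure
        return [describe_shape(item) for item in batch]
    return UNKNOWN


single = describe_shape(FakeTensor((10, 256)))
nested = describe_shape([FakeTensor((10, 256)), (FakeTensor((10,)), FakeTensor((10, 2)))])
skipped = describe_shape({"x": FakeTensor((10, 256))})  # dicts are not handled here

print(single)   # [10, 256]
print(nested)   # [[10, 256], [[10], [10, 2]]]
print(skipped)  # ?
```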

Example:

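>>> import torch
>>> from torch import nn
>>> import pytorch_lightning as pl
>>> from pytorch_lightning.utilities.model_summary import ModelSummary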
>>> class LitModel(pl.LightningModule):
...
...     def __init__(self):
...         super().__init__()
...         self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512))
...         self.example_input_array = torch.zeros(10, 256)  # optional
...
...     def forward(self, x):
...         return self.net(x)
...
>>> model = LitModel()
>>> ModelSummary(model, max_depth=1)  
  | Name | Type       | Params | In sizes  | Out sizes
------------------------------------------------------------
0 | net  | Sequential | 132 K  | [10, 256] | [10, 512]
------------------------------------------------------------
132 K     Trainable params
0         Non-trainable params
132 K     Total params
0.530     Total estimated model params size (MB)
>>> ModelSummary(model, max_depth=-1)  
  | Name  | Type        | Params | In sizes  | Out sizes
--------------------------------------------------------------
0 | net   | Sequential  | 132 K  | [10, 256] | [10, 512]
1 | net.0 | Linear      | 131 K  | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K  | [10, 512] | [10, 512]
--------------------------------------------------------------
132 K     Trainable params
0         Non-trainable params
132 K     Total params
0.530     Total estimated model params size (MB)
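
The totals in the tables above can be checked by hand. The sketch below assumes 32-bit (4-byte) parameters and 1 MB = 10^6 bytes; under those assumptions it reproduces the 0.530 figure. The helper names are hypothetical and not part of the library.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weight matrix plus one bias per output feature."""
    return in_features * out_features + out_features


def batchnorm1d_params(num_features: int) -> int:
    """Learnable scale and shift; running statistics are buffers, not parameters."""
    return 2 * num_features


BYTES_PER_PARAM = 4  # assumption: 32-bit floating point weights

linear = linear_params(256, 512)      # net.0
batchnorm = batchnorm1d_params(512)   # net.1
total = linear + batchnorm            # net
size_mb = total * BYTES_PER_PARAM / 1e6

print(linear, batchnorm, total)  # 131584 1024 132608
print(f"{size_mb:.3f}")          # 0.530
```
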

pytorch_lightning.utilities.model_summary.get_human_readable_count(number)[source]

Abbreviates an integer number with K, M, B, T for thousands, millions, billions and trillions, respectively.

Examples

>>> get_human_readable_count(123)
'123  '
>>> get_human_readable_count(1234)  # (one thousand)
'1.2 K'
>>> get_human_readable_count(2e6)   # (two million)
'2.0 M'
>>> get_human_readable_count(3e9)   # (three billion)
'3.0 B'
>>> get_human_readable_count(4e14)  # (four hundred trillion)
'400 T'
>>> get_human_readable_count(5e15)  # (five thousand trillion)
'5,000 T'

Parameters

number (int) – a positive integer number

Return type

str

Returns

A string formatted according to the pattern described above.
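
The pattern in the examples above can be reproduced with a short standalone function. This is a sketch inferred from those examples, not the library source: abbreviate_count and UNITS are hypothetical names, and the function assumes a non-negative Python int (not a float such as 2e6).

```python
UNITS = [" ", "K", "M", "B", "T"]  # no unit beyond trillions


def abbreviate_count(number: int) -> str:
    """Abbreviate a non-negative integer with K, M, B or T."""
    if number < 0:
        raise ValueError("number must be non-negative")
    num_digits = len(str(number))
    # One group per three digits, capped so anything larger stays in trillions
    num_groups = min((num_digits + 2) // 3, len(UNITS))
    index = num_groups - 1
    scaled = number / 10 ** (3 * index)
    if index == 0 or scaled >= 100:
        # Plain counts and three-digit values are shown without decimals
        return f"{int(scaled):,d} {UNITS[index]}"
    return f"{scaled:,.1f} {UNITS[index]}"


for value in (123, 1234, 2_000_000, 3_000_000_000, 4 * 10**14, 5 * 10**15):
    print(repr(abbreviate_count(value)))
```

The loop prints the same six strings as the examples above, including the trailing space that stands in for the missing unit in '123  '.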

pytorch_lightning.utilities.model_summary.summarize(lightning_module, max_depth=1)[source]

Summarize the LightningModule specified by lightning_module.

Parameters
  • lightning_module (LightningModule) – LightningModule to summarize.

  • max_depth (int) – The maximum depth of layer nesting that the summary will include. A value of 0 turns the layer summary off. Default: 1.

Return type

ModelSummary

Returns

The model summary object
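
One way to picture max_depth is as a filter on the dotted layer names from the ModelSummary example. The sketch below assumes that a layer's depth is the number of dots in its name plus one; visible_layers is a hypothetical helper, not a library function.

```python
from typing import List


def visible_layers(names: List[str], max_depth: int) -> List[str]:
    """Keep the layer names that a summary with the given max_depth would list."""
    if max_depth < -1:
        raise ValueError("max_depth must be -1 or a non-negative integer")
    if max_depth == -1:
        return list(names)  # -1 shows every module
    # "net" has depth 1, "net.0" has depth 2, and so on
    return [name for name in names if name.count(".") < max_depth]


names = ["net", "net.0", "net.1"]
print(visible_layers(names, 1))   # ['net']
print(visible_layers(names, -1))  # ['net', 'net.0', 'net.1']
print(visible_layers(names, 0))   # []
```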