Track and Visualize Experiments (intermediate)

Audience: Users who want to track more complex outputs and use third-party experiment managers.


Track audio and other artifacts

To track other artifacts, such as histograms or model topology graphs, first select one of the many loggers supported by Lightning:

from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers

tensorboard = pl_loggers.TensorBoardLogger(save_dir="")
trainer = Trainer(logger=tensorboard)

Then access the logger’s API directly through its experiment object:

def training_step(self, batch, batch_idx):
    # self.logger.experiment is the underlying TensorBoard SummaryWriter
    tensorboard = self.logger.experiment
    tensorboard.add_image(...)
    tensorboard.add_histogram(...)
    tensorboard.add_figure(...)

Comet.ml

To use Comet.ml, first install the comet-ml package:

pip install comet-ml

Configure the logger and pass it to the Trainer:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CometLogger

comet_logger = CometLogger(api_key="YOUR_COMET_API_KEY")
trainer = Trainer(logger=comet_logger)

Access the Comet logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts:

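import torch
from lightning.pytorch import LightningModule


class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        comet = self.logger.experiment  # the underlying comet_ml Experiment
        fake_images = torch.rand(32, 3, 28, 28)
        # The Experiment API uses log_image (not add_image); log one channels-first image
        comet.log_image(fake_images[0], name="generated_images", image_channels="first", step=0)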

Here’s the full documentation for the CometLogger.


MLflow

To use MLflow, first install the MLflow package:

pip install mlflow

Configure the logger and pass it to the Trainer:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import MLFlowLogger

mlf_logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")
trainer = Trainer(logger=mlf_logger)

Access the MLflow logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts:

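import torch
from lightning.pytorch import LightningModule


class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        mlf_client = self.logger.experiment  # the underlying MlflowClient
        fake_images = torch.rand(32, 3, 28, 28)
        image = fake_images[0].permute(1, 2, 0).numpy()  # MlflowClient.log_image expects HWC
        mlf_client.log_image(self.logger.run_id, image, "generated_images.png")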

Here’s the full documentation for the MLFlowLogger.


Neptune.ai

To use Neptune.ai, first install the neptune package:

pip install neptune

or with conda:

conda install -c conda-forge neptune

Configure the logger and pass it to the Trainer:

import neptune
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import NeptuneLogger

neptune_logger = NeptuneLogger(
    api_key=neptune.ANONYMOUS_API_TOKEN,  # replace with your own
    project="common/pytorch-lightning-integration",  # format "<WORKSPACE/PROJECT>"
)
trainer = Trainer(logger=neptune_logger)

Access the Neptune logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts:

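from lightning.pytorch import LightningModule


class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        # Index the run with a metadata path of your choice, then append values to it
        neptune_logger = self.logger.experiment["your/metadata/structure"]
        neptune_logger.append(metadata)  # `metadata` is a placeholder for the value to track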

Here’s the full documentation for the NeptuneLogger.


TensorBoard

TensorBoard can be installed with:

pip install tensorboard

Configure the logger and pass it to the Trainer:

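from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger

# save_dir is a required argument; "." writes logs under the current directory
logger = TensorBoardLogger(save_dir=".")
trainer = Trainer(logger=logger)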

Access the TensorBoard logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts:

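import torch
from lightning.pytorch import LightningModule


class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        tensorboard_logger = self.logger.experiment  # the underlying SummaryWriter
        fake_images = torch.rand(32, 3, 28, 28)
        # add_images takes an NCHW batch; add_image expects a single CHW image
        tensorboard_logger.add_images("generated_images", fake_images, 0)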

Here’s the full documentation for the TensorBoardLogger.


Weights and Biases

To use Weights and Biases (wandb), first install the wandb package:

pip install wandb

Configure the logger and pass it to the Trainer:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(project="MNIST", log_model="all")
trainer = Trainer(logger=wandb_logger)

# log gradients and model topology (`model` is your LightningModule instance)
wandb_logger.watch(model)

Access the wandb logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts:

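import torch
import wandb
from lightning.pytorch import LightningModule


class MyModule(LightningModule):
    def any_lightning_module_function_or_hook(self):
        wandb_run = self.logger.experiment  # the underlying wandb Run
        fake_images = torch.rand(32, 3, 28, 28)

        # Option 1: log through the wandb Run
        wandb_run.log({"generated_images": [wandb.Image(fake_images, caption="...")]})

        # Option 2 for specifically logging images: log_image is a WandbLogger
        # method, so call it on self.logger rather than on the Run
        self.logger.log_image(key="generated_images", images=[fake_images])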

Here’s the full documentation for the WandbLogger. Demo in Google Colab with hyperparameter search and model logging.


Use multiple experiment managers

To use multiple experiment managers at the same time, pass a list of loggers to the Trainer's logger argument:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

logger1 = TensorBoardLogger(save_dir=".")  # save_dir is required
logger2 = WandbLogger()
trainer = Trainer(logger=[logger1, logger2])

Access all loggers from any function (except the LightningModule init) to use their APIs for tracking advanced artifacts:

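import torch
import wandb
from lightning.pytorch import LightningModule


class MyModule(LightningModule):
    def any_lightning_module_function_or_hook(self):
        # self.loggers is a list, in the order the loggers were passed to the Trainer
        tensorboard_logger = self.loggers[0].experiment
        wandb_logger = self.loggers[1].experiment

        fake_images = torch.rand(32, 3, 28, 28)

        # Each experiment object has its own API: a SummaryWriter and a wandb Run
        tensorboard_logger.add_images("generated_images", fake_images, 0)
        wandb_logger.log({"generated_images": [wandb.Image(fake_images)]})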
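
Picking loggers by position depends on the order in which they were passed to the Trainer. One possible alternative is to select a logger by its type. The sketch below is only an illustration: `find_logger` and the two stub classes are hypothetical stand-ins so that it runs without Lightning installed. In a real module you would presumably pass `self.loggers` and a class such as `TensorBoardLogger`.

```python
from typing import Iterable, Optional, Type, TypeVar

T = TypeVar("T")


def find_logger(loggers: Iterable[object], logger_type: Type[T]) -> Optional[T]:
    """Return the first logger that is an instance of logger_type, or None."""
    for logger in loggers:
        if isinstance(logger, logger_type):
            return logger
    return None


# Hypothetical stand-ins for real logger classes (assumption, not Lightning API).
class StubTensorBoardLogger:
    name = "tensorboard"


class StubWandbLogger:
    name = "wandb"


# The order of the list no longer matters when selecting by type.
loggers = [StubWandbLogger(), StubTensorBoardLogger()]
tb_logger = find_logger(loggers, StubTensorBoardLogger)
print(tb_logger.name)  # tensorboard
```

Returning None when no logger matches lets the caller skip artifact logging instead of raising an IndexError when a logger is left out of the list.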

Track hyperparameters

To track hyperparameters, first call save_hyperparameters from the LightningModule init:

class MyLightningModule(LightningModule):
    def __init__(self, learning_rate, another_parameter, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()

If your logger supports tracked hyperparameters, the hyperparameters will automatically show up on the logger dashboard.
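
To make "tracked hyperparameters" more concrete: the named init arguments end up as a name-to-value mapping that a logger can display. The sketch below illustrates that idea using only the standard library. It is not Lightning's implementation; `collect_init_args` and `ToyModule` are hypothetical names, and the simplification deliberately skips `*args` and `**kwargs`.

```python
import inspect


def collect_init_args(obj, frame_locals):
    """Collect the named __init__ arguments of obj into a dict (simplified)."""
    params = inspect.signature(type(obj).__init__).parameters
    return {
        name: frame_locals[name]
        for name, param in params.items()
        if name != "self"
        and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
    }


class ToyModule:
    """Hypothetical stand-in for a LightningModule (assumption for illustration)."""

    def __init__(self, learning_rate, another_parameter=4, *args, **kwargs):
        # locals() here holds the values the constructor was called with
        self.hparams = collect_init_args(self, locals())


model = ToyModule(learning_rate=1e-3)
print(model.hparams)  # {'learning_rate': 0.001, 'another_parameter': 4}
```

A flat mapping like this is what a dashboard would typically render as one hyperparameter column per key.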


Track model topology

Multiple loggers support visualizing the model topology. Here’s an example that tracks the model topology using TensorBoard:

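def any_lightning_module_function_or_hook(self):
    # log_graph is a method of the logger itself, not of logger.experiment;
    # it only records the graph if TensorBoardLogger was created with log_graph=True
    tensorboard_logger = self.logger

    # Placeholder input; its shape must match what the model's forward expects
    prototype_array = torch.rand(32, 1, 28, 27)
    tensorboard_logger.log_graph(model=self, input_array=prototype_array)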
