
Loggers

Lightning supports the most popular logging frameworks (TensorBoard, Comet, Weights and Biases, etc…). To use a logger, simply pass it into the Trainer. Lightning uses TensorBoard by default.

from pytorch_lightning import Trainer
from pytorch_lightning import loggers as pl_loggers

tb_logger = pl_loggers.TensorBoardLogger('logs/')
trainer = Trainer(logger=tb_logger)

Choose from any of the others such as MLflow, Comet, Neptune, WandB, …

comet_logger = pl_loggers.CometLogger(save_dir='logs/')
trainer = Trainer(logger=comet_logger)

To use multiple loggers, simply pass in a list or tuple of loggers …

tb_logger = pl_loggers.TensorBoardLogger('logs/')
comet_logger = pl_loggers.CometLogger(save_dir='logs/')
trainer = Trainer(logger=[tb_logger, comet_logger])
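When several loggers are passed, every logging call has to reach each of them. The plain-Python sketch below illustrates that fan-out idea. `RecordingLogger` and `FanOutLogger` are hypothetical stand-ins, not Lightning classes. Lightning's own wrapper (`LoggerCollection` in the versions we are aware of) may differ in detail.

```python
from typing import Dict, List, Optional


class RecordingLogger:
    """Hypothetical minimal logger that stores every call in memory."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.records: List[tuple] = []

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        self.records.append((step, dict(metrics)))


class FanOutLogger:
    """Hypothetical wrapper: forwards each call to every wrapped logger."""

    def __init__(self, loggers: List[RecordingLogger]) -> None:
        self.loggers = list(loggers)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        for logger in self.loggers:
            logger.log_metrics(metrics, step)

    @property
    def name(self) -> str:
        # Combine the wrapped loggers' names into a single string
        return "_".join(logger.name for logger in self.loggers)


tb_like = RecordingLogger("tb")
comet_like = RecordingLogger("comet")
combined = FanOutLogger([tb_like, comet_like])
combined.log_metrics({"train_loss": 0.25}, step=1)
print(combined.name)    # tb_comet
print(tb_like.records)  # [(1, {'train_loss': 0.25})]
```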
Note:

All loggers log by default to os.getcwd(). To change the path without creating a logger, set Trainer(default_root_dir='/your/path/to/save/checkpoints').


Custom Logger

You can implement your own logger by writing a class that inherits from LightningLoggerBase. Use the rank_zero_only() decorator to make sure that only the first process in DDP training logs data.

from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import LightningLoggerBase

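class MyLogger(LightningLoggerBase):

    @property
    def name(self):
        # Return the experiment name
        return 'MyLogger'

    @property
    def experiment(self):
        # Return the experiment object associated with this logger
        pass

    @property
    def version(self):
        # Return the experiment version, int or str
        return '0.1'

    @rank_zero_only
    def log_hyperparams(self, params):
        # params is an argparse.Namespace or a dict
        # your code to record hyperparameters goes here
        pass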

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        # your code to record metrics goes here
        pass

    def save(self):
        # Optional. Any code necessary to save logger data goes here
        pass

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to be run after training
        # finishes goes here
        pass
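The following minimal re-implementation shows why the rank_zero_only() decorator in the example above keeps DDP workers from writing duplicate logs. It assumes that the process rank is available in the `LOCAL_RANK` environment variable, defaulting to 0. `rank_zero_only_sketch` is hypothetical, and Lightning's own decorator may determine the rank differently.

```python
import functools
import os


def rank_zero_only_sketch(fn):
    """Hypothetical decorator: call `fn` only in the process whose rank is 0."""

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        # The rank is looked up at call time, so it can be set after import
        if rank_zero_only_sketch.rank == 0:
            return fn(*args, **kwargs)
        return None  # every other rank skips the call

    return wrapped


# Assumption: the launcher exports the process rank as LOCAL_RANK (default 0)
rank_zero_only_sketch.rank = int(os.environ.get("LOCAL_RANK", 0))


@rank_zero_only_sketch
def log_metrics(metrics):
    return f"logged {metrics}"


rank_zero_only_sketch.rank = 0
print(log_metrics({"loss": 0.1}))  # logged {'loss': 0.1}
rank_zero_only_sketch.rank = 1
print(log_metrics({"loss": 0.1}))  # None
```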

If you write a logger that may be useful to others, please send a pull request to add it to Lightning!


Using loggers

Call the logger anywhere except __init__ in your LightningModule by doing:

class LitModel(LightningModule):
    def training_step(self, batch, batch_idx):
        # example
        self.logger.experiment.whatever_method_summary_writer_supports(...)

        # example if logger is a TensorBoard logger (experiment is a SummaryWriter);
        # grid and images are placeholders for your own tensors
        self.logger.experiment.add_image('images', grid, 0)
        self.logger.experiment.add_graph(self, images)

    def any_lightning_module_function_or_hook(self):
        self.logger.experiment.add_histogram(...)

Read more in the Experiment Logging use case.


Supported Loggers

The following loggers are supported:

Comet

class pytorch_lightning.loggers.comet.CometLogger(api_key=None, save_dir=None, workspace=None, project_name=None, rest_api_key=None, experiment_name=None, experiment_key=None, **kwargs)

Bases: pytorch_lightning.loggers.base.LightningLoggerBase

Log using Comet.ml. Install it with pip:

pip install comet-ml

Comet requires either an API Key (online mode) or a local directory path (offline mode).
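The choice between the two modes works like a simple decision, sketched below. The sketch assumes, as the online example below suggests, that an API key selects online mode even when save_dir is also given. `select_comet_mode` is a hypothetical helper and is not part of Lightning or comet_ml.

```python
from typing import Optional


def select_comet_mode(api_key: Optional[str], save_dir: Optional[str]) -> str:
    """Hypothetical helper mirroring the documented CometLogger rule.

    - an API key selects online mode (save_dir may also be given)
    - otherwise a save directory selects offline mode
    - with neither, the logger cannot be configured
    """
    if api_key is not None:
        return "online"
    if save_dir is not None:
        return "offline"
    raise ValueError("CometLogger needs either api_key (online) or save_dir (offline)")


print(select_comet_mode("my-key", "."))  # online
print(select_comet_mode(None, "."))      # offline
```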

ONLINE MODE

Example

>>> import os
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.loggers import CometLogger
>>> # arguments made to CometLogger are passed on to the comet_ml.Experiment class
>>> comet_logger = CometLogger(
...     api_key=os.environ.get('COMET_API_KEY'),
...     workspace=os.environ.get('COMET_WORKSPACE'),  # Optional
...     save_dir='.',  # Optional
...     project_name='default_project',  # Optional
...     rest_api_key=os.environ.get('COMET_REST_API_KEY'),  # Optional
...     experiment_name='default'  # Optional
... )
>>> trainer = Trainer(logger=comet_logger)

OFFLINE MODE

Example

>>> import os
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.loggers import CometLogger
>>> # arguments made to CometLogger are passed on to the comet_ml.Experiment class
>>> comet_logger = CometLogger(
...     save_dir='.',
...     workspace=os.environ.get('COMET_WORKSPACE'),  # Optional
...     project_name='default_project',  # Optional
...     rest_api_key=os.environ.get('COMET_REST_API_KEY'),  # Optional
...     experiment_name='default'  # Optional
... )
>>> trainer = Trainer(logger=comet_logger)

Parameters
  • api_key (Optional[str]) – Required in online mode. API key, found on Comet.ml

  • save_dir (Optional[str]) – Required in offline mode. The path for the directory to save local comet logs

  • workspace (Optional[str]) – Optional. Name of workspace for this user

  • project_name (Optional[str]) – Optional. Send your experiment to a specific project. Otherwise will be sent to Uncategorized Experiments. If the project name does not already exist, Comet.ml will create a new project.

  • rest_api_key (Optional[str]) – Optional. REST API key found in Comet.ml settings. This is used to determine the version number.

  • experiment_name (Optional[str]) – Optional. String representing the name for this particular experiment on Comet.ml.

  • experiment_key (Optional[str]) – Optional. If set, restores from existing experiment.

finalize(status)

Once self.experiment.end() has been called, that experiment won’t log any more data to Comet. If you need to log more data afterwards, for example when testing your model after training (CometLogger.finalize() is called when training finishes), you need to create an ExistingCometExperiment.

This happens automatically in the experiment property when self._experiment has been set to None, i.e. after self.reset_experiment().

Return type

None

log_hyperparams(params)

Record hyperparameters.

Parameters

params (Union[Dict[str, Any], Namespace]) – Namespace or dictionary containing the hyperparameters

Return type

None

log_metrics(metrics, step=None)

Records metrics. This method logs metrics as soon as it receives them. If you want to aggregate metrics for one specific step, use the agg_and_log_metrics() method.

Parameters
  • metrics (Dict[str, Union[Tensor, float]]) – Dictionary with metric names as keys and measured quantities as values

  • step (Optional[int]) – Step number at which the metrics should be recorded

Return type

None
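log_metrics() writes values immediately, while agg_and_log_metrics() buffers them. The sketch below shows that buffering idea. It assumes that values reported for the same step are averaged and written once a new step arrives. `StepAggregator` is hypothetical, and Lightning's actual aggregation rules may differ, including which reducer is used.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, List, Optional, Tuple


class StepAggregator:
    """Hypothetical buffer that averages metrics logged for the same step."""

    def __init__(self) -> None:
        self._step: Optional[int] = None
        self._buffer: Dict[str, List[float]] = defaultdict(list)
        # Stands in for what a real logger's log_metrics() would receive
        self.written: List[Tuple[int, Dict[str, float]]] = []

    def agg_and_log_metrics(self, metrics: Dict[str, float], step: int) -> None:
        # A new step means the previous step is complete: write it out first
        if self._step is not None and step != self._step:
            self.flush()
        self._step = step
        for key, value in metrics.items():
            self._buffer[key].append(value)

    def flush(self) -> None:
        if self._step is None:
            return
        averaged = {key: mean(values) for key, values in self._buffer.items()}
        self.written.append((self._step, averaged))
        self._step = None
        self._buffer.clear()


agg = StepAggregator()
agg.agg_and_log_metrics({"loss": 1.0}, step=0)
agg.agg_and_log_metrics({"loss": 3.0}, step=0)
agg.agg_and_log_metrics({"loss": 0.5}, step=1)
agg.flush()
print(agg.written)  # [(0, {'loss': 2.0}), (1, {'loss': 0.5})]
```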

property experiment

Actual Comet object. To use Comet features in your LightningModule do the following.

Example:

self.logger.experiment.some_comet_function()

Return type

BaseExperiment

property name

Return the experiment name.

Return type

str

property save_dir

Return the root directory where experiment logs get saved, or None if the logger does not save data locally.

Return type

Optional[str]

property version

Return the experiment version.

Return type

str

MLFlow

class pytorch_lightning.loggers.mlflow.MLFlowLogger(experiment_name='default', tracking_uri=None, tags=None, save_dir='./mlruns')

Bases: pytorch_lightning.loggers.base.LightningLoggerBase

Log using MLflow. Install it with pip:

pip install mlflow

Example

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.loggers import MLFlowLogger
>>> mlf_logger = MLFlowLogger(
...     experiment_name="default",
...     tracking_uri="file:./ml-runs"
... )
>>> trainer = Trainer(logger=mlf_logger)

Use the logger anywhere in your LightningModule as follows:

>>> from pytorch_lightning import LightningModule
>>> class LitModel(LightningModule):
...     def training_step(self, batch, batch_idx):
...         # example
...         self.logger.experiment.whatever_ml_flow_supports(...)
...
...     def any_lightning_module_function_or_hook(self):
...         self.logger.experiment.whatever_ml_flow_supports(...)

Parameters
  • experiment_name (str) – The name of the experiment

  • tracking_uri (Optional[str]) – Address of local or remote tracking server. If not provided, defaults to file:<save_dir>.

  • tags (Optional[Dict[str, Any]]) – A dictionary of tags for the experiment.

  • save_dir (Optional[str]) – A path to a local directory where the MLflow runs get saved. Defaults to ./mlruns if tracking_uri is not provided. Has no effect if tracking_uri is provided.
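The relationship between tracking_uri and save_dir is a simple fallback. The sketch below uses the './mlruns' default from the signature above and assumes that the 'file:' prefix is prepended to save_dir verbatim. `resolve_tracking_uri` is a hypothetical helper and is not part of Lightning or MLflow.

```python
from typing import Optional


def resolve_tracking_uri(tracking_uri: Optional[str], save_dir: str = "./mlruns") -> str:
    """Hypothetical helper: an explicit tracking_uri wins, else use file:<save_dir>."""
    if tracking_uri is not None:
        return tracking_uri  # save_dir has no effect in this case
    return f"file:{save_dir}"


print(resolve_tracking_uri(None))                     # file:./mlruns
print(resolve_tracking_uri(None, "/tmp/runs"))        # file:/tmp/runs
print(resolve_tracking_uri("http://localhost:5000"))  # http://localhost:5000
```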

finalize(status='FINISHED')

Do any processing that is necessary to finalize an experiment.

Parameters

status (str) – Status that the experiment finished with (e.g. success, failed, aborted)

Return type

None

log_hyperparams(params)

Record hyperparameters.

Parameters

params (Union[Dict[str, Any], Namespace]) – Namespace or dictionary containing the hyperparameters

Return type

None

log_metrics(metrics, step=None)

Records metrics. This method logs metrics as soon as it receives them. If you want to aggregate metrics for one specific step, use the agg_and_log_metrics() method.

Parameters
  • metrics (Dict[str, float]) – Dictionary with metric names as keys and measured quantities as values

  • step (Optional[int]) – Step number at which the metrics should be recorded

Return type

None

property experiment

Actual MLflow object. To use MLflow features in your LightningModule do the following.

Example:

self.logger.experiment.some_mlflow_function()

Return type

MlflowClient

property name

Return the experiment name.

Return type

str

property save_dir

The root file directory in which MLflow experiments are saved.

Return type

Optional[str]

Returns

Local path to the root experiment directory if the tracking URI is local. Otherwise returns None.
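One way to decide whether the tracking URI is local is to check for the 'file:' scheme, as sketched below. This detection rule is an assumption, and `local_save_dir` is a hypothetical helper. Lightning's implementation may handle more URI forms. For example, a URI such as file:///tmp/runs would need extra normalization that this sketch does not do.

```python
from typing import Optional

LOCAL_FILE_URI_PREFIX = "file:"


def local_save_dir(tracking_uri: str) -> Optional[str]:
    """Hypothetical helper: return the local path for file: URIs, else None."""
    if tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
        # Strip only the leading scheme and keep the rest of the path as written
        return tracking_uri[len(LOCAL_FILE_URI_PREFIX):]
    return None  # remote tracking server: nothing is saved locally


print(local_save_dir("file:./ml-runs"))         # ./ml-runs
print(local_save_dir("http://localhost:5000"))  # None
```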

property version

Return the experiment version.

Return type

str

Neptune

class pytorch_lightning.loggers.neptune.NeptuneLogger(api_key=None, project_name=None, close_after_fit=True, offline_mode=False, experiment_name=None, upload_source_files=None, params=None, properties=None, tags=None, **kwargs)

Bases: pytorch_lightning.loggers.base.LightningLoggerBase

Log using Neptune. Install it with pip:

pip install neptune-client

The Neptune logger can be used in the online mode or offline (silent) mode. To log experiment data in online mode, NeptuneLogger requires an API key. In offline mode, the logger does not connect to Neptune.

ONLINE MODE

Example

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.loggers import NeptuneLogger
>>> # arguments made to NeptuneLogger are passed on to the neptune.experiments.Experiment class
>>> # We are using an api_key for the anonymous user "neptuner" but you can use your own.
>>> neptune_logger = NeptuneLogger(
...     api_key='ANONYMOUS',
...     project_name='shared/pytorch-lightning-integration',
...     experiment_name='default',  # Optional,
...     params={'max_epochs': 10},  # Optional,
...     tags=['pytorch-lightning', 'mlp']  # Optional,
... )
>>> trainer = Trainer(max_epochs=10, logger=neptune_logger)

OFFLINE MODE

Example

>>> from pytorch_lightning.loggers import NeptuneLogger
>>> # arguments made to NeptuneLogger are passed on to the neptune.experiments.Experiment class
>>> neptune_logger = NeptuneLogger(
...     offline_mode=True,
...     project_name='USER_NAME/PROJECT_NAME',
...     experiment_name='default',  # Optional,
...     params={'max_epochs': 10},  # Optional,
...     tags=['pytorch-lightning', 'mlp']  # Optional,
... )
>>> trainer = Trainer(max_epochs=10, logger=neptune_logger)

Use the logger anywhere in your LightningModule as follows:

>>> from pytorch_lightning import LightningModule
>>> class LitModel(LightningModule):
...     def training_step(self, batch, batch_idx):
...         # log metrics
...         self.logger.experiment.log_metric('acc_train', ...)
...         # log images
...         self.logger.experiment.log_image('worse_predictions', ...)
...         # log model checkpoint
...         self.logger.experiment.log_artifact('model_checkpoint.pt', ...)
...         self.logger.experiment.whatever_neptune_supports(...)
...
...     def any_lightning_module_function_or_hook(self):
...         self.logger.experiment.log_metric('acc_train', ...)
...         self.logger.experiment.log_image('worse_predictions', ...)
...         self.logger.experiment.log_artifact('model_checkpoint.pt', ...)
...         self.logger.experiment.whatever_neptune_supports(...)

If you want to log objects after training has finished, use close_after_fit=False:

neptune_logger = NeptuneLogger(
    ...
    close_after_fit=False,
    ...
)
trainer = Trainer(logger=neptune_logger)
trainer.fit(model)

# Log test metrics
trainer.test(model)

# Log additional metrics
from sklearn.metrics import accuracy_score

accuracy = accuracy_score(y_true, y_pred)
neptune_logger.experiment.log_metric('test_accuracy', accuracy)

# Log charts
from scikitplot.metrics import plot_confusion_matrix
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(16, 12))
plot_confusion_matrix(y_true, y_pred, ax=ax)
neptune_logger.experiment.log_image('confusion_matrix', fig)

# Save checkpoints folder
neptune_logger.experiment.log_artifact('my/checkpoints')

# When you are done, stop the experiment
neptune_logger.experiment.stop()


Parameters
  • api_key (Optional[str]) – Required in online mode. Neptune API token, found on https://neptune.ai. Read how to get your API key. It is recommended to keep it in the NEPTUNE_API_TOKEN environment variable, in which case you can leave api_key=None.

  • project_name (Optional[str]) – Required in online mode. Qualified name of a project in the form “namespace/project_name”, for example “tom/minst-classification”. If None, the value of the NEPTUNE_PROJECT environment variable is used. You need to create the project in https://neptune.ai first.

  • offline_mode (bool) – Optional, default False. If True, no logs will be sent to Neptune. Usually used for debugging purposes.

  • close_after_fit (Optional[bool]) – Optional, default True. If False, the experiment will not be closed after training, and additional metrics, images or artifacts can be logged. In that case, remember to close the experiment explicitly by running neptune_logger.experiment.stop().

  • experiment_name (Optional[str]) – Optional. Editable name of the experiment. Name is displayed in the experiment’s Details (Metadata section) and in experiments view as a column.

  • upload_source_files (Optional[List[str]]) – Optional. List of source files to be uploaded. Must be a list of str or a single str. Uploaded sources are displayed in the experiment’s Source code tab. If None is passed, the Python file from which the experiment was created will be uploaded. Pass an empty list ([]) to upload no files. Unix style pathname pattern expansion is supported. For example, you can pass '*.py' to upload all Python source files from the current directory. For recursive lookup, use '**/*.py' (for Python 3.5 and later). For more information, see the glob library.

  • params (Optional[Dict[str, Any]]) – Optional. Parameters of the experiment. After experiment creation, params are read-only. Parameters are displayed in the experiment’s Parameters section and each key-value pair can be viewed in the experiments view as a column.

  • properties (Optional[Dict[str, Any]]) – Optional. Default is {}. Properties of the experiment. They are editable after the experiment is created. Properties are displayed in the experiment’s Details section and each key-value pair can be viewed in the experiments view as a column.

  • tags (Optional[List[str]]) – Optional. Default is []. Must be list of str. Tags of the experiment. They are editable after the experiment is created (see: append_tag() and remove_tag()). Tags are displayed in the experiment’s Details section and can be viewed in the experiments view as a column.
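The patterns accepted by upload_source_files follow Python's glob module. The sketch below builds a throwaway directory and shows which files '*.py' and '**/*.py' would match. It only demonstrates the pattern semantics and does not call Neptune. The file names are made up for illustration, and passing recursive=True for '**' is an assumption about how the recursive lookup is enabled.

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # Hypothetical project layout: a top-level script, a nested module, a README
    os.makedirs(os.path.join(root, "models"))
    for rel_path in ("train.py", os.path.join("models", "net.py"), "README.md"):
        with open(os.path.join(root, rel_path), "w") as handle:
            handle.write("")

    # '*.py' matches Python files in the top-level directory only
    top_level = sorted(
        os.path.relpath(path, root) for path in glob.glob(os.path.join(root, "*.py"))
    )
    # '**/*.py' with recursive=True also descends into subdirectories
    recursive = sorted(
        os.path.relpath(path, root)
        for path in glob.glob(os.path.join(root, "**", "*.py"), recursive=True)
    )

print(top_level)  # ['train.py']
print(recursive)  # ['models/net.py', 'train.py'] on POSIX systems
```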

append_tags(tags)

Appends tags to the Neptune experiment.

Parameters

tags (Union[str, Iterable[str]]) – Tags to add to the current experiment. If a str is passed, a single tag is added. If multiple comma-separated str are passed, all of them are added as tags. If a list of str is passed, all elements of the list are added as tags.

Return type

None
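The accepted forms of tags can be normalized to a list before they are sent, as sketched below. The sketch assumes that a single string means a single tag and that any other iterable is expanded element by element. `normalize_tags` is hypothetical.

```python
from typing import Iterable, List, Union


def normalize_tags(tags: Union[str, Iterable[str]]) -> List[str]:
    """Hypothetical helper: turn a str or an iterable of str into a list of tags."""
    if isinstance(tags, str):
        # A bare string is one tag; iterating over it would split it into characters
        return [tags]
    return list(tags)


print(normalize_tags("mlp"))                         # ['mlp']
print(normalize_tags(["pytorch-lightning", "mlp"]))  # ['pytorch-lightning', 'mlp']
```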

finalize(status)

Do any processing that is necessary to finalize an experiment.

Parameters

status (str) – Status that the experiment finished with (e.g. success, failed, aborted)

Return type

None

log_artifact(artifact, destination=None)

Save an artifact (file) in Neptune experiment storage.

Parameters
  • artifact (str) – A path to the file in the local filesystem.

  • destination (Optional[str]) – Optional. Default is None. A destination path. If None is passed, an artifact file name will be used.

Return type

None

log_hyperparams(params)

Record hyperparameters.

Parameters

params (Union[Dict[str, Any], Namespace]) – Namespace or dictionary containing the hyperparameters

Return type

None

log_image(log_name, image, step=None)

Log image data in a Neptune experiment.

Parameters
  • log_name (str) – The name of the log, e.g. bboxes, visualisations, sample_images.

  • image (Union[str, Any]) – The value of the log (data-point). Can be one of the following types: PIL image, matplotlib.figure.Figure, path to image file (str)

  • step (Optional[int]) – Step number at which the metrics should be recorded, must be strictly increasing

Return type

None

log_metric(metric_name, metric_value, step=None)

Log metrics (numeric values) in Neptune experiments.

Parameters
  • metric_name (str) – The name of the log, e.g. mse, loss, accuracy.

  • metric_value (Union[Tensor, float, str]) – The value of the log (data-point).

  • step (Optional[int]) – Step number at which the metrics should be recorded, must be strictly increasing

Return type

None

log_metrics(metrics, step=None)

Log metrics (numeric values) in Neptune experiments.

Parameters
  • metrics (Dict[str, Union[Tensor, float]]) – Dictionary with metric names as keys and measured quantities as values

  • step (Optional[int]) – Step number at which the metrics should be recorded, must be strictly increasing

Return type

None
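Neptune requires step values to be strictly increasing, so it can help to validate steps before logging. The sketch below shows such a check. It assumes that the constraint applies separately to each log name, and it leaves calls without an explicit step unchecked. `StepGuard` is a hypothetical helper and is not part of Lightning or the Neptune client.

```python
from typing import Dict, Optional


class StepGuard:
    """Hypothetical check that the steps for each log name strictly increase."""

    def __init__(self) -> None:
        self._last_step: Dict[str, int] = {}

    def check(self, log_name: str, step: Optional[int]) -> bool:
        """Return True if `step` may be logged under `log_name`, and record it."""
        if step is None:
            return True  # no explicit step: nothing to validate here
        last = self._last_step.get(log_name)
        if last is not None and step <= last:
            return False  # an equal or smaller step would break strict increase
        self._last_step[log_name] = step
        return True


guard = StepGuard()
print(guard.check("acc_train", 0))  # True
print(guard.check("acc_train", 1))  # True
print(guard.check("acc_train", 1))  # False
print(guard.check("loss", 0))       # True (tracked separately per log name)
```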

log_text(log_name, text, step=None)

Log text data in Neptune experiments.

Parameters
  • log_name (str) – The name of the log, e.g. mse, my_text_data, timing_info.

  • text (str) – The value of the log (data-point).

  • step (Optional[int]) – Step number at which the metrics should be recorded, must be strictly increasing

Return type

None

set_property(key, value)

Set key-value pair as Neptune experiment property.

Parameters
  • key (str) – Property key.

  • value (Any) – New value of a property.

Return type

None

property experiment

Actual Neptune object. To use Neptune features in your LightningModule, do the following.

Example:

self.logger.experiment.some_neptune_function()

Return type

Experiment

property name

Return the experiment name.

Return type

str

property save_dir

Return the root directory where experiment logs get saved, or None if the logger does not save data locally.

Return type

Optional[str]

property version

Return the experiment version.

Return type

str

Tensorboard

class pytorch_lightning.loggers.tensorboard.TensorBoardLogger(save_dir, name='default', version=None, **kwargs)

Bases: pytorch_lightning.loggers.base.LightningLoggerBase

Log to the local file system in TensorBoard format. Implemented using SummaryWriter. Logs are saved to os.path.join(save_dir, name, version). This is the default logger in Lightning; it comes preinstalled.

Example

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.loggers import TensorBoardLogger
>>> logger = TensorBoardLogger("tb_logs", name="my_model")
>>> trainer = Trainer(logger=logger)

Parameters
  • save_dir (str) – Save directory

  • name (Optional[str]) – Experiment name. Defaults to 'default'. If it is the empty string then no per-experiment subdirectory is used.

  • version (Union[int, str, None]) – Experiment version. If version is not specified the logger inspects the save directory for existing versions, then automatically assigns the next available version. If it is a string then it is used as the run-specific subdirectory name, otherwise 'version_${version}' is used.

  • **kwargs – Other arguments are passed directly to the SummaryWriter constructor.
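Automatic versioning can be pictured as scanning save_dir/name for existing version_<n> folders and picking the number after the highest one. The sketch below works under that assumption. `next_version` is hypothetical, and Lightning's own lookup may differ in details such as which directory entries it considers.

```python
import os
import tempfile

VERSION_PREFIX = "version_"


def next_version(root_dir: str) -> int:
    """Hypothetical helper: 1 + the highest existing version_<n> folder, or 0 if none."""
    if not os.path.isdir(root_dir):
        return 0  # nothing has been logged under this experiment name yet
    existing = []
    for entry in os.listdir(root_dir):
        number = entry[len(VERSION_PREFIX):]
        is_dir = os.path.isdir(os.path.join(root_dir, entry))
        if is_dir and entry.startswith(VERSION_PREFIX) and number.isdigit():
            existing.append(int(number))
    return max(existing) + 1 if existing else 0


with tempfile.TemporaryDirectory() as save_dir:
    root = os.path.join(save_dir, "my_model")  # save_dir/name
    first = next_version(root)
    os.makedirs(os.path.join(root, "version_0"))
    os.makedirs(os.path.join(root, "version_3"))
    after_two_runs = next_version(root)

print(first, after_two_runs)  # 0 4
```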

finalize(status)

Do any processing that is necessary to finalize an experiment.

Parameters

status (str) – Status that the experiment finished with (e.g. success, failed, aborted)

Return type

None

log_hyperparams(params, metrics=None)

Record hyperparameters.

Parameters

params (Union[Dict[str, Any], Namespace]) – Namespace or dictionary containing the hyperparameters

Return type

None

log_metrics(metrics, step=None)

Records metrics. This method logs metrics as soon as it receives them. If you want to aggregate metrics for one specific step, use the agg_and_log_metrics() method.

Parameters
  • metrics (Dict[str, float]) – Dictionary with metric names as keys and measured quantities as values

  • step (Optional[int]) – Step number at which the metrics should be recorded

Return type

None

save()

Save log data.

Return type

None

property experiment

Actual TensorBoard object. To use TensorBoard features in your LightningModule, do the following.

Example:

self.logger.experiment.some_tensorboard_function()

Return type

SummaryWriter

property log_dir

The directory for this run’s tensorboard checkpoint. By default, it is named 'version_${self.version}' but it can be overridden by passing a string value for the constructor’s version parameter instead of None or an int.

Return type

str

property name

Return the experiment name.

Return type

str

property root_dir

Parent directory for all TensorBoard checkpoint subdirectories. If the experiment name parameter is None or the empty string, no experiment subdirectory is used and the checkpoint will be saved in “save_dir/version_dir”.

Return type

str
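The sketch below combines save_dir, name and version into the layout that root_dir and log_dir describe. It assumes that an integer version becomes version_<n>, that a string version is used verbatim, and that a None or empty name skips the experiment subdirectory. `tb_log_dir` is hypothetical.

```python
import os
from typing import Optional, Union


def tb_log_dir(save_dir: str, name: Optional[str], version: Union[int, str]) -> str:
    """Hypothetical helper mirroring the documented TensorBoardLogger layout."""
    # root_dir: save_dir/name, or just save_dir when name is None or ''
    root_dir = os.path.join(save_dir, name) if name else save_dir
    # log_dir: a str version is used as-is, an int version becomes version_<n>
    version_dir = version if isinstance(version, str) else f"version_{version}"
    return os.path.join(root_dir, version_dir)


print(tb_log_dir("tb_logs", "my_model", 0))   # tb_logs/my_model/version_0 on POSIX
print(tb_log_dir("tb_logs", "", "baseline"))  # tb_logs/baseline on POSIX
```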

property save_dir

Return the root directory where experiment logs get saved, or None if the logger does not save data locally.

Return type

Optional[str]

property version

Return the experiment version.

Return type

int

Test-tube

class pytorch_lightning.loggers.test_tube.TestTubeLogger(save_dir, name='default', description=None, debug=False, version=None, create_git_tag=False)

Bases: pytorch_lightning.loggers.base.LightningLoggerBase

Log to local file system in TensorBoard format but using a nicer folder structure (see full docs). Install it with pip:

pip install test_tube

Example

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.loggers import TestTubeLogger
>>> logger = TestTubeLogger("tt_logs", name="my_exp_name")
>>> trainer = Trainer(logger=logger)

Use the logger anywhere in your LightningModule as follows:

>>> from pytorch_lightning import LightningModule
>>> class LitModel(LightningModule):
...     def training_step(self, batch, batch_idx):
...         # example
...         self.logger.experiment.whatever_method_summary_writer_supports(...)
...
...     def any_lightning_module_function_or_hook(self):
...         self.logger.experiment.add_histogram(...)

Parameters
  • save_dir (str) – Save directory

  • name (str) – Experiment name. Defaults to 'default'.

  • description (Optional[str]) – A short snippet about this experiment

  • debug (bool) – If True, it doesn’t log anything.

  • version (Optional[int]) – Experiment version. If version is not specified the logger inspects the save directory for existing versions, then automatically assigns the next available version.

  • create_git_tag (bool) – If True, creates a git tag to save the code used in this experiment.

close()

Do any cleanup that is necessary to close an experiment.

Return type

None

finalize(status)

Do any processing that is necessary to finalize an experiment.

Parameters

status (str) – Status that the experiment finished with (e.g. success, failed, aborted)

Return type

None

log_hyperparams(params)

Record hyperparameters.

Parameters

params (Union[Dict[str, Any], Namespace]) – Namespace or dictionary containing the hyperparameters

Return type

None

log_metrics(metrics, step=None)

Records metrics. This method logs metrics as soon as it receives them. If you want to aggregate metrics for one specific step, use the agg_and_log_metrics() method.

Parameters
  • metrics (Dict[str, float]) – Dictionary with metric names as keys and measured quantities as values

  • step (Optional[int]) – Step number at which the metrics should be recorded

Return type

None

save()

Save log data.

Return type

None

property experiment

Actual TestTube object. To use TestTube features in your LightningModule do the following.

Example:

self.logger.experiment.some_test_tube_function()

Return type

Experiment

property name

Return the experiment name.

Return type

str

property save_dir

Return the root directory where experiment logs get saved, or None if the logger does not save data locally.

Return type

Optional[str]

property version

Return the experiment version.

Return type

int