
Source code for pytorch_lightning.metrics.classification.stat_scores

# Copyright The PyTorch Lightning team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Tuple

import torch

from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_compute, _stat_scores_update
from pytorch_lightning.metrics.metric import Metric

class StatScores(Metric):
    """Computes the number of true positives, false positives, true negatives, false negatives.

    Related to `Type I and Type II errors <https://en.wikipedia.org/wiki/Type_I_and_type_II_errors>`__
    and the `confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion>`__.

    The reduction method (how the statistics are aggregated) is controlled by the
    ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the
    multi-dimensional multi-class case. Accepts all inputs listed in
    :ref:`extensions/metrics:input types`.

    Args:
        threshold:
            Threshold probability value for transforming probability predictions to binary
            (0 or 1) predictions, in the case of binary or multi-label inputs.
        top_k:
            Number of highest probability entries for each sample to convert to 1s - relevant
            only for inputs with probability predictions. If this parameter is set for
            multi-label inputs, it will take precedence over ``threshold``. For (multi-dim)
            multi-class inputs, this parameter defaults to 1.

            Should be left unset (``None``) for inputs with label predictions.
        reduce:
            Defines the reduction that is applied. Should be one of the following:

            - ``'micro'`` [default]: Counts the statistics by summing over all [sample, class]
              combinations (globally). Each statistic is represented by a single integer.
            - ``'macro'``: Counts the statistics for each class separately (over all samples).
              Each statistic is represented by a ``(C,)`` tensor. Requires ``num_classes``
              to be set.
            - ``'samples'``: Counts the statistics for each sample separately (over all classes).
              Each statistic is represented by a ``(N, )`` 1d tensor.

            Note that what is considered a sample in the multi-dimensional multi-class case
            depends on the value of ``mdmc_reduce``.
        num_classes:
            Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data.
        ignore_index:
            Specify a class (label) to ignore. If given, this class index does not contribute
            to the returned score, regardless of reduction method. If an index is ignored, and
            ``reduce='macro'``, the class statistics for the ignored class will all be returned
            as ``-1``.
        mdmc_reduce:
            Defines how the multi-dimensional multi-class inputs are handled. Should be
            one of the following:

            - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
              multi-class (see :ref:`extensions/metrics:input types` for the definition of
              input types).
            - ``'samplewise'``: In this case, the statistics are computed separately for each
              sample on the ``N`` axis, and then the outputs are concatenated together. In each
              sample the extra axes ``...`` are flattened to become the sub-sample axis, and
              statistics for each sample are computed by treating the sub-sample axis as the
              ``N`` axis for that sample.
            - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs are
              flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
              were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual.
        is_multiclass:
            Used only in certain special cases, where you want to treat inputs as a different
            type than what they appear to be. See the parameter's
            :ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
            for a more detailed explanation and examples.
        compute_on_step:
            Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step
        process_group:
            Specify the process group on which synchronization is called.
            default: ``None`` (which selects the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``,
            DDP will be used to perform the allgather.

    Raises:
        ValueError:
            If ``threshold`` is not strictly between 0 and 1, if ``reduce`` or ``mdmc_reduce``
            is not one of the accepted values, if ``reduce='macro'`` without a positive
            ``num_classes``, or if ``ignore_index`` is out of range for ``num_classes``.

    Example:

        >>> from pytorch_lightning.metrics.classification import StatScores
        >>> preds  = torch.tensor([1, 0, 2, 1])
        >>> target = torch.tensor([1, 1, 2, 0])
        >>> stat_scores = StatScores(reduce='macro', num_classes=3)
        >>> stat_scores(preds, target)
        tensor([[0, 1, 2, 1, 1],
                [1, 1, 1, 1, 2],
                [1, 0, 3, 0, 1]])
        >>> stat_scores = StatScores(reduce='micro')
        >>> stat_scores(preds, target)
        tensor([2, 2, 6, 2, 4])

    """

    def __init__(
        self,
        threshold: float = 0.5,
        top_k: Optional[int] = None,
        reduce: str = "micro",
        num_classes: Optional[int] = None,
        ignore_index: Optional[int] = None,
        mdmc_reduce: Optional[str] = None,
        is_multiclass: Optional[bool] = None,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )

        self.reduce = reduce
        self.mdmc_reduce = mdmc_reduce
        self.num_classes = num_classes
        self.threshold = threshold
        self.is_multiclass = is_multiclass
        self.ignore_index = ignore_index
        self.top_k = top_k

        # Validate the configuration eagerly so misuse fails at construction time
        # rather than on the first ``update`` call.
        if not 0 < threshold < 1:
            raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")

        if reduce not in ["micro", "macro", "samples"]:
            raise ValueError(f"The `reduce` {reduce} is not valid.")

        if mdmc_reduce not in [None, "samplewise", "global"]:
            raise ValueError(f"The `mdmc_reduce` {mdmc_reduce} is not valid.")

        if reduce == "macro" and (not num_classes or num_classes < 1):
            raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.")

        if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1):
            raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes")

        if mdmc_reduce != "samplewise" and reduce != "samples":
            # Fixed-size statistics: accumulate in long tensors that are summed
            # across processes. ``reduce`` can only be 'micro' or 'macro' here.
            if reduce == "micro":
                zeros_shape = []
            elif reduce == "macro":
                zeros_shape = (num_classes, )
            default, reduce_fn = lambda: torch.zeros(zeros_shape, dtype=torch.long), "sum"
        else:
            # Per-sample statistics grow with every batch, so each state is a
            # list of per-batch tensors with no cross-process reduction function.
            default, reduce_fn = lambda: [], None

        for s in ("tp", "fp", "tn", "fn"):
            # ``default()`` is called per state so the states never share one object.
            self.add_state(s, default=default(), dist_reduce_fx=reduce_fn)

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        Update state with predictions and targets. See :ref:`extensions/metrics:input types`
        for more information on input types.

        Args:
            preds: Predictions from model (probabilities or labels)
            target: Ground truth values
        """

        tp, fp, tn, fn = _stat_scores_update(
            preds,
            target,
            reduce=self.reduce,
            mdmc_reduce=self.mdmc_reduce,
            threshold=self.threshold,
            num_classes=self.num_classes,
            top_k=self.top_k,
            is_multiclass=self.is_multiclass,
            ignore_index=self.ignore_index,
        )

        # Update states: the branch condition mirrors the one used in ``__init__``
        # to choose between tensor states (accumulate) and list states (append).
        if self.reduce != "samples" and self.mdmc_reduce != "samplewise":
            self.tp += tp
            self.fp += fp
            self.tn += tn
            self.fn += fn
        else:
            self.tp.append(tp)
            self.fp.append(fp)
            self.tn.append(tn)
            self.fn.append(fn)

    def _get_final_stats(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Performs concatenation on the stat scores if necessary,
        before passing them to a compute function.

        Return:
            The ``(tp, fp, tn, fn)`` states as tensors. List states are concatenated
            with ``torch.cat``; tensor states are returned as they are.
        """

        # All four states are registered with the same default in ``__init__``,
        # so checking ``tp`` is enough to know which kind of state is in use.
        if isinstance(self.tp, list):
            tp = torch.cat(self.tp)
            fp = torch.cat(self.fp)
            tn = torch.cat(self.tn)
            fn = torch.cat(self.fn)
        else:
            tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn

        return tp, fp, tn, fn

    def compute(self) -> torch.Tensor:
        """
        Computes the stat scores based on inputs passed in to ``update`` previously.

        Return:
            The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds
            to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The
            shape depends on the ``reduce`` and ``mdmc_reduce`` (in case of multi-dimensional
            multi-class data) parameters:

            - If the data is not multi-dimensional multi-class, then

              - If ``reduce='micro'``, the shape will be ``(5, )``
              - If ``reduce='macro'``, the shape will be ``(C, 5)``,
                where ``C`` stands for the number of classes
              - If ``reduce='samples'``, the shape will be ``(N, 5)``, where ``N`` stands for
                the number of samples

            - If the data is multi-dimensional multi-class and ``mdmc_reduce='global'``, then

              - If ``reduce='micro'``, the shape will be ``(5, )``
              - If ``reduce='macro'``, the shape will be ``(C, 5)``
              - If ``reduce='samples'``, the shape will be ``(N*X, 5)``, where ``X`` stands for
                the product of sizes of all "extra" dimensions of the data (i.e. all dimensions
                except for ``C`` and ``N``)

            - If the data is multi-dimensional multi-class and ``mdmc_reduce='samplewise'``, then

              - If ``reduce='micro'``, the shape will be ``(N, 5)``
              - If ``reduce='macro'``, the shape will be ``(N, C, 5)``
              - If ``reduce='samples'``, the shape will be ``(N, X, 5)``

        """
        tp, fp, tn, fn = self._get_final_stats()
        return _stat_scores_compute(tp, fp, tn, fn)

© Copyright (c) 2018-2021, William Falcon et al. Revision cf5dc04d.

Built with Sphinx using a theme provided by Read the Docs.