Shortcuts

Source code for pytorch_lightning.metrics.functional.stat_scores

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification


def _del_column(tensor: torch.Tensor, index: int):
    """Return a new tensor equal to ``tensor`` with entry ``index`` of dimension 1 removed.

    The slices before and after ``index`` along dimension 1 are concatenated back
    together, so all other dimensions are left untouched. Callers are expected to
    pass ``0 <= index < tensor.shape[1]``; the index is not validated here.
    """
    before = tensor[:, :index]
    after = tensor[:, index + 1:]
    return torch.cat((before, after), dim=1)


def _stat_scores(
    preds: torch.Tensor,
    target: torch.Tensor,
    reduce: str = "micro",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Calculate the number of tp, fp, tn, fn.

    Args:
        preds:
            An ``(N, C)`` or ``(N, C, X)`` tensor of predictions (0 or 1)
        target:
            An ``(N, C)`` or ``(N, C, X)`` tensor of true labels (0 or 1)
        reduce:
            One of ``'micro'``, ``'macro'``, ``'samples'``

    Return:
        Returns a tuple of 4 ``long`` tensors: tp, fp, tn, fn.
        The shape of the returned tensors depends on the shape of the inputs
        and the ``reduce`` parameter:

        If inputs are of the shape ``(N, C)``, then
        - If ``reduce='micro'``, the returned tensors are 0-dim (scalar) tensors
        - If ``reduce='macro'``, the returned tensors are ``(C,)`` tensors
        - If ``reduce='samples'``, the returned tensors are ``(N,)`` tensors

        If inputs are of the shape ``(N, C, X)``, then
        - If ``reduce='micro'``, the returned tensors are ``(N,)`` tensors
        - If ``reduce='macro'``, the returned tensors are ``(N, C)`` tensors
        - If ``reduce='samples'``, the returned tensors are ``(N, X)`` tensors

    Raises:
        ValueError:
            If ``reduce`` is not one of ``'micro'``, ``'macro'``, ``'samples'``.
    """
    if reduce == "micro":
        # 2D: sum over everything; 3D: keep only the leading N axis.
        dim = [0, 1] if preds.ndim == 2 else [1, 2]
    elif reduce == "macro":
        # Keep the class axis: sum over samples (2D) or over the extra axis (3D).
        dim = 0 if preds.ndim == 2 else 2
    elif reduce == "samples":
        # Sum over the class axis only.
        dim = 1
    else:
        # Fail loudly instead of leaving ``dim`` unbound.
        raise ValueError(f"The `reduce` {reduce} is not valid.")

    true_pred, false_pred = target == preds, target != preds
    pos_pred, neg_pred = preds == 1, preds == 0

    # All four masks are boolean, so logical AND is the natural combinator.
    tp = (true_pred & pos_pred).sum(dim=dim)
    fp = (false_pred & pos_pred).sum(dim=dim)

    tn = (true_pred & neg_pred).sum(dim=dim)
    fn = (false_pred & neg_pred).sum(dim=dim)

    return tp.long(), fp.long(), tn.long(), fn.long()


def _stat_scores_update(
    preds: torch.Tensor,
    target: torch.Tensor,
    reduce: str = "micro",
    mdmc_reduce: Optional[str] = None,
    num_classes: Optional[int] = None,
    top_k: Optional[int] = None,
    threshold: float = 0.5,
    is_multiclass: Optional[bool] = None,
    ignore_index: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Format the inputs and compute the tp, fp, tn, fn statistics.

    Args:
        preds: Predictions from model (probabilities or labels).
        target: Ground truth values.
        reduce: One of ``'micro'``, ``'macro'``, ``'samples'``; forwarded to ``_stat_scores``.
        mdmc_reduce:
            Must be set when the formatted inputs are 3D (multi-dimensional multi-class).
            ``'global'`` flattens the ``N`` and extra axes into a single sample axis;
            any other truthy value keeps the inputs 3D so statistics are computed per sample.
        num_classes, top_k, threshold, is_multiclass:
            Forwarded to ``_input_format_classification``.
        ignore_index:
            Class index to ignore. With ``reduce='macro'`` its statistics are set to ``-1``;
            otherwise its column is removed before counting.

    Returns:
        The tuple ``(tp, fp, tn, fn)`` produced by ``_stat_scores``.

    Raises:
        ValueError:
            If ``ignore_index`` is outside ``[0, C)``, if ``ignore_index`` is used with
            single-column (binary) data, or if the inputs are 3D and ``mdmc_reduce`` is not set.
    """
    preds, target, _ = _input_format_classification(
        preds, target, threshold=threshold, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k
    )

    # The formatted inputs are (N, C) or (N, C, X) (see ``_stat_scores``), so dim 1 holds the classes.
    if ignore_index is not None and not 0 <= ignore_index < preds.shape[1]:
        raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {preds.shape[1]} classes")

    if ignore_index is not None and preds.shape[1] == 1:
        raise ValueError("You can not use `ignore_index` with binary data.")

    if preds.ndim == 3:
        if not mdmc_reduce:
            raise ValueError(
                "When your inputs are multi-dimensional multi-class, you have to set the `mdmc_reduce` parameter"
            )
        if mdmc_reduce == "global":
            # (N, C, X) -> (N, X, C) -> (N*X, C): each (sample, extra-position) pair becomes a row.
            preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
            target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])

    # Delete what is in ignore_index, if applicable (and classes don't matter):
    if ignore_index is not None and reduce != "macro":
        preds = _del_column(preds, ignore_index)
        target = _del_column(target, ignore_index)

    tp, fp, tn, fn = _stat_scores(preds, target, reduce=reduce)

    # Per-class results keep the ignored class, but mark it with -1.
    if ignore_index is not None and reduce == "macro":
        tp[..., ignore_index] = -1
        fp[..., ignore_index] = -1
        tn[..., ignore_index] = -1
        fn[..., ignore_index] = -1

    return tp, fp, tn, fn


def _stat_scores_compute(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor) -> torch.Tensor:
    """Pack the four statistics plus support into one tensor with a trailing axis of size 5.

    The last axis is ``[tp, fp, tn, fn, support]`` with ``support = tp + fn``.
    Every negative entry is normalised to ``-1``. This matters for the ``-1`` markers
    used for an ignored class, whose support would otherwise come out as ``-2``.
    """
    support = tp + fn
    stacked = torch.stack((tp, fp, tn, fn, support), dim=-1)
    return stacked.masked_fill(stacked < 0, -1)


def stat_scores(
    preds: torch.Tensor,
    target: torch.Tensor,
    reduce: str = "micro",
    mdmc_reduce: Optional[str] = None,
    num_classes: Optional[int] = None,
    top_k: Optional[int] = None,
    threshold: float = 0.5,
    is_multiclass: Optional[bool] = None,
    ignore_index: Optional[int] = None,
) -> torch.Tensor:
    """Computes the number of true positives, false positives, true negatives, false negatives.
    Related to `Type I and Type II errors <https://en.wikipedia.org/wiki/Type_I_and_type_II_errors>`__
    and the `confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion>`__.

    The reduction method (how the statistics are aggregated) is controlled by the
    ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the
    multi-dimensional multi-class case.

    Accepts all inputs listed in :ref:`extensions/metrics:input types`.

    Args:
        preds: Predictions from model (probabilities or labels)
        target: Ground truth values
        threshold:
            Threshold probability value for transforming probability predictions to binary
            (0 or 1) predictions, in the case of binary or multi-label inputs.
        top_k:
            Number of highest probability entries for each sample to convert to 1s - relevant
            only for inputs with probability predictions. If this parameter is set for multi-label
            inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
            this parameter defaults to 1.

            Should be left unset (``None``) for inputs with label predictions.
        reduce:
            Defines the reduction that is applied. Should be one of the following:

            - ``'micro'`` [default]: Counts the statistics by summing over all [sample, class]
              combinations (globally). Each statistic is represented by a single integer.
            - ``'macro'``: Counts the statistics for each class separately (over all samples).
              Each statistic is represented by a ``(C,)`` tensor. Requires ``num_classes``
              to be set.
            - ``'samples'``: Counts the statistics for each sample separately (over all classes).
              Each statistic is represented by a ``(N, )`` 1d tensor.

            Note that what is considered a sample in the multi-dimensional multi-class case
            depends on the value of ``mdmc_reduce``.

        num_classes:
            Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data.

        ignore_index:
            Specify a class (label) to ignore. If given, this class index does not contribute
            to the returned score, regardless of reduction method. If an index is ignored, and
            ``reduce='macro'``, the class statistics for the ignored class will all be returned
            as ``-1``.

        mdmc_reduce:
            Defines how the multi-dimensional multi-class inputs are handled. Should be
            one of the following:

            - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
              multi-class (see :ref:`extensions/metrics:input types` for the definition of input types).

            - ``'samplewise'``: In this case, the statistics are computed separately for each
              sample on the ``N`` axis, and then the outputs are concatenated together. In each
              sample the extra axes ``...`` are flattened to become the sub-sample axis, and
              statistics for each sample are computed by treating the sub-sample axis as the
              ``N`` axis for that sample.

            - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs are
              flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
              were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual.

        is_multiclass:
            Used only in certain special cases, where you want to treat inputs as a different type
            than what they appear to be. See the parameter's
            :ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
            for a more detailed explanation and examples.

    Return:
        The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds
        to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The
        shape depends on the ``reduce`` and ``mdmc_reduce`` (in case of multi-dimensional
        multi-class data) parameters:

        - If the data is not multi-dimensional multi-class, then

          - If ``reduce='micro'``, the shape will be ``(5, )``
          - If ``reduce='macro'``, the shape will be ``(C, 5)``,
            where ``C`` stands for the number of classes
          - If ``reduce='samples'``, the shape will be ``(N, 5)``, where ``N`` stands for
            the number of samples

        - If the data is multi-dimensional multi-class and ``mdmc_reduce='global'``, then

          - If ``reduce='micro'``, the shape will be ``(5, )``
          - If ``reduce='macro'``, the shape will be ``(C, 5)``
          - If ``reduce='samples'``, the shape will be ``(N*X, 5)``, where ``X`` stands for
            the product of sizes of all "extra" dimensions of the data (i.e. all dimensions
            except for ``C`` and ``N``)

        - If the data is multi-dimensional multi-class and ``mdmc_reduce='samplewise'``, then

          - If ``reduce='micro'``, the shape will be ``(N, 5)``
          - If ``reduce='macro'``, the shape will be ``(N, C, 5)``
          - If ``reduce='samples'``, the shape will be ``(N, X, 5)``

    Raises:
        ValueError:
            If ``reduce`` or ``mdmc_reduce`` is not a supported value, if ``reduce='macro'``
            is used without a positive ``num_classes``, or if ``ignore_index`` is out of range
            for ``num_classes`` (or ``num_classes == 1``). Further ``ValueError``s may be raised
            while validating the formatted inputs.

    Example:
        >>> from pytorch_lightning.metrics.functional import stat_scores
        >>> preds = torch.tensor([1, 0, 2, 1])
        >>> target = torch.tensor([1, 1, 2, 0])
        >>> stat_scores(preds, target, reduce='macro', num_classes=3)
        tensor([[0, 1, 2, 1, 1],
                [1, 1, 1, 1, 2],
                [1, 0, 3, 0, 1]])
        >>> stat_scores(preds, target, reduce='micro')
        tensor([2, 2, 6, 2, 4])
    """
    if reduce not in ["micro", "macro", "samples"]:
        raise ValueError(f"The `reduce` {reduce} is not valid.")

    if mdmc_reduce not in [None, "samplewise", "global"]:
        raise ValueError(f"The `mdmc_reduce` {mdmc_reduce} is not valid.")

    if reduce == "macro" and (not num_classes or num_classes < 1):
        raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.")

    # A single class leaves nothing to keep once it is ignored, so num_classes == 1 is rejected too.
    if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1):
        raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes")

    tp, fp, tn, fn = _stat_scores_update(
        preds,
        target,
        reduce=reduce,
        mdmc_reduce=mdmc_reduce,
        top_k=top_k,
        threshold=threshold,
        num_classes=num_classes,
        is_multiclass=is_multiclass,
        ignore_index=ignore_index,
    )
    return _stat_scores_compute(tp, fp, tn, fn)

© Copyright 2018-2021, William Falcon et al. Revision cf5dc04d.

Built with Sphinx using a theme provided by Read the Docs.