Shortcuts

Source code for pytorch_lightning.metrics.functional.auroc

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distutils.version import LooseVersion
from typing import Optional, Sequence, Tuple

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional.auc import auc
from pytorch_lightning.metrics.functional.roc import roc
from pytorch_lightning.utilities import LightningEnum


class AverageMethods(LightningEnum):
    """Reductions accepted by the ``average`` argument of the AUROC functions.

    - ``MACRO``: unweighted mean of the per-class AUROC scores.
    - ``WEIGHTED``: mean of the per-class scores, each weighted by its class support.
    - ``NONE``: no reduction; the per-class scores are returned unreduced.
    """
    MACRO = 'macro'
    WEIGHTED = 'weighted'
    # Intended to match ``average=None``. NOTE(review): this relies on
    # LightningEnum's equality comparing equal to ``None`` — confirm against
    # LightningEnum.__eq__.
    NONE = None


def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, str]:
    """Validate AUROC inputs and fold any extra dimensions into the sample dimension.

    ``_input_format_classification`` is called only to validate ``preds``/``target``
    and to detect the data mode; the formatted tensors it returns are discarded.

    Dimension 1 of ``preds`` is treated as the class dimension. For multi-dim
    multi-class inputs, and for multi-label inputs with more than two dimensions,
    every dimension after the class dimension is folded into the sample
    dimension, so ``preds`` becomes ``(N * extra, C)``. In the multi-dim
    multi-class case ``target`` is flattened to 1-D; in the multi-label case it is
    reshaped the same way as ``preds``.

    Args:
        preds: predictions from the model.
        target: ground-truth labels.

    Returns:
        Tuple ``(preds, target, mode)``: the possibly reshaped tensors and the
        detected data mode.
    """
    # use _input_format_classification for validating the input and get the mode of data
    _, _, mode = _input_format_classification(preds, target)

    # Compare against the DataType enum, not hand-written strings, so these
    # branches stay in sync with the values _input_format_classification returns.
    if mode == DataType.MULTIDIM_MULTICLASS:
        n_classes = preds.shape[1]
        # (N, C, ...) -> (C, N, ...) -> (C, N * extra) -> (N * extra, C)
        preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
        target = target.flatten()
    elif mode == DataType.MULTILABEL and preds.ndim > 2:
        n_classes = preds.shape[1]
        preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
        target = target.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)

    return preds, target, mode


def _auroc_compute(
    preds: torch.Tensor,
    target: torch.Tensor,
    mode: str,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
    average: Optional[str] = 'macro',
    max_fpr: Optional[float] = None,
    sample_weights: Optional[Sequence] = None,
) -> torch.Tensor:
    """Compute ROC AUC, or the standardized partial ROC AUC, from formatted inputs.

    Args:
        preds: predictions as returned by ``_auroc_update``.
        target: ground-truth labels as returned by ``_auroc_update``.
        mode: data mode detected by ``_auroc_update``.
        num_classes: number of classes. It is forced to 1 in binary mode. In
            multi-label mode it defaults to ``preds.shape[1]`` when not given.
        pos_label: positive class forwarded to ``roc``. In multi-label mode each
            column is evaluated with ``pos_label=1`` instead.
        average: ``'macro'``, ``'weighted'`` or ``None``. Only used when more
            than one class is scored.
        max_fpr: if given and below 1, compute the McClish-standardized partial
            AUC over ``[0, max_fpr]``. Only supported in binary mode.
        sample_weights: per-sample weights forwarded to ``roc``.

    Returns:
        A scalar tensor, or a list of per-class scores when ``average`` is ``None``.

    Raises:
        ValueError: if ``max_fpr`` is not a float in ``(0, 1]``, if ``max_fpr``
            is given outside binary mode, or if ``average`` is not recognized.
        RuntimeError: if ``max_fpr`` is given but ``torch.bucketize`` is unavailable.
    """
    # `max_fpr` must be a float in (0, 1]; anything else (out-of-range floats,
    # ints, bools) is rejected before any computation happens.
    if max_fpr is not None and not (isinstance(max_fpr, float) and 0 < max_fpr <= 1):
        raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")

    # binary mode override num_classes
    if mode == DataType.BINARY:
        num_classes = 1

    if max_fpr is not None:
        # The partial-AUC interpolation below needs `torch.bucketize` (PyTorch >= 1.6);
        # feature detection avoids parsing version strings.
        if not hasattr(torch, 'bucketize'):
            raise RuntimeError(
                "`max_fpr` argument requires `torch.bucketize` which"
                " is not available below PyTorch version 1.6"
            )

        # max_fpr parameter is only support for binary
        if mode != DataType.BINARY:
            raise ValueError(
                f"Partial AUC computation not available in"
                f" multilabel/multiclass setting, 'max_fpr' must be"
                f" set to `None`, received `{max_fpr}`."
            )

    # calculate fpr, tpr
    if mode == DataType.MULTILABEL:
        if num_classes is None:
            # each column of the (N, C) predictions is its own binary problem
            num_classes = preds.shape[1]
        # for multilabel we iteratively evaluate roc in a binary fashion
        output = [
            roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights)
            for i in range(num_classes)
        ]
        fpr = [o[0] for o in output]
        tpr = [o[1] for o in output]
    else:
        fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights)

    # calculate standard roc auc score
    if max_fpr is None or max_fpr == 1:
        if num_classes == 1:
            return auc(fpr, tpr)

        # calculate auc scores per class
        auc_scores = [auc(x, y) for x, y in zip(fpr, tpr)]

        if average == AverageMethods.NONE:
            return auc_scores
        if average == AverageMethods.MACRO:
            return torch.mean(torch.stack(auc_scores))
        if average == AverageMethods.WEIGHTED:
            if mode == DataType.MULTILABEL:
                support = torch.sum(target, dim=0)
            else:
                # one score per class, so this also works when `num_classes` was
                # left as None and inferred by `roc`
                support = torch.bincount(target.flatten(), minlength=len(auc_scores))
            return torch.sum(torch.stack(auc_scores) * support / support.sum())

        allowed_average = [e.value for e in AverageMethods]
        raise ValueError(
            f"Argument `average` expected to be one of the following:"
            f" {allowed_average} but got {average}"
        )

    # Partial AUC (binary only). Match fpr's dtype so bucketize/lerp see consistent dtypes.
    max_fpr = torch.tensor(max_fpr, device=fpr.device, dtype=fpr.dtype)
    # Add a single point at max_fpr and interpolate its tpr value.
    # With right=True, fpr[stop - 1] <= max_fpr < fpr[stop].
    stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True)
    weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
    interp_tpr = torch.lerp(tpr[stop - 1], tpr[stop], weight)
    tpr = torch.cat([tpr[:stop], interp_tpr.view(1)])
    fpr = torch.cat([fpr[:stop], max_fpr.view(1)])

    # Compute partial AUC
    partial_auc = auc(fpr, tpr)

    # McClish correction: standardize result to be 0.5 if non-discriminant
    # and 1 if maximal
    min_area = 0.5 * max_fpr**2
    max_area = max_fpr
    return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))


def auroc(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
    average: Optional[str] = 'macro',
    max_fpr: Optional[float] = None,
    sample_weights: Optional[Sequence] = None,
) -> torch.Tensor:
    """
    Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC)
    <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Further_interpretations>`_

    Args:
        preds: predictions from model (logits or probabilities)
        target: Ground truth labels
        num_classes: integer with number of classes. Not necessary to provide
            for binary problems.
        pos_label: integer determining the positive class. Default is ``None``
            which for binary problems is translated to 1. For multiclass problems
            this argument should not be set as we iteratively change it in the
            range [0, num_classes-1]
        average:
            - ``'macro'`` computes metric for each class and uniformly averages them
            - ``'weighted'`` computes metric for each class and does a weighted-average,
              where each class is weighted by their support (accounts for class imbalance)
            - ``None`` computes and returns the metric per class (as a list of scores)
        max_fpr:
            If not ``None``, calculates standardized partial AUC over the
            range [0, max_fpr]. Should be a float in (0, 1]. Only supported for
            binary problems.
        sample_weights: sample weights for each data point

    Example (binary case):

        >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
        >>> target = torch.tensor([0, 0, 1, 1, 1])
        >>> auroc(preds, target, pos_label=1)
        tensor(0.5000)

    Example (multiclass case):

        >>> preds = torch.tensor([[0.90, 0.05, 0.05],
        ...                       [0.05, 0.90, 0.05],
        ...                       [0.05, 0.05, 0.90],
        ...                       [0.85, 0.05, 0.10],
        ...                       [0.10, 0.10, 0.80]])
        >>> target = torch.tensor([0, 1, 1, 2, 2])
        >>> auroc(preds, target, num_classes=3)
        tensor(0.7778)
    """
    # Validate/reshape the inputs and detect the data mode, then score them.
    preds, target, mode = _auroc_update(preds, target)
    return _auroc_compute(preds, target, mode, num_classes, pos_label, average, max_fpr, sample_weights)

© Copyright 2018-2021, William Falcon et al. Revision b3b8f95e.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: stable
Versions
latest
stable
1.2.2
1.2.1
1.2.0
1.1.8
1.1.7
1.1.6
1.1.5
1.1.4
1.1.3
1.1.2
1.1.1
1.1.0
1.0.8
1.0.7
1.0.6
1.0.5
1.0.4
1.0.3
1.0.2
1.0.1
1.0.0
0.10.0
0.9.0
0.8.5
0.8.4
0.8.3
0.8.2
0.8.1
0.8.0
0.7.6
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.3.2
0.5.3
0.4.9
release-1.2-dev
release-1.0.x
Downloads
pdf
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.