[Bug] torchmetrics.functional.classification.binary_auroc gives wrong results when logits are large #2819

Open
zbingsong opened this issue Nov 2, 2024 · 1 comment
Labels
bug / fix Something isn't working help wanted Extra attention is needed v1.4.x

@zbingsong
🐛 Bug

torchmetrics.functional.classification.binary_auroc returns 0.5 whenever all logits are large and positive. This appears to be caused by sigmoid rounding every such logit to exactly 1.0 in floating point, so all predictions look tied.

To Reproduce

Code sample
import torch
import torchmetrics.functional.classification

preds = torch.tensor([98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, 99.5014, 99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768])
labels = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

torchmetrics.functional.classification.binary_auroc(preds, labels)

Output:

tensor(0.5000)

Expected behavior

AUROC of the above example should be 0.9286, as computed by sklearn.

import sklearn.metrics

sklearn.metrics.roc_auc_score(labels, preds)

Output:

0.9285714285714286
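
For reference, AUROC depends only on how the scores rank positives against negatives, so it can be checked from the raw logits without any sigmoid. The sketch below is a plain-Python pairwise (Mann-Whitney style) computation, not sklearn's or torchmetrics' implementation; pairwise_auroc is a hypothetical helper name used only for illustration.

```python
def pairwise_auroc(scores, labels):
    """AUROC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair. Runs in O(n_pos * n_neg), so it is
    meant for small sanity checks only.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


preds = [98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, 99.5014,
         99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768]
labels = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

auroc = pairwise_auroc(preds, labels)
print(round(auroc, 4))  # 0.9286: 13 of the 14 positives outrank the single negative
```

Only the positive with logit 97.6037 scores below the single negative (98.0950), which gives 13/14.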

Environment

  • Windows 11 24H2
  • Python version 3.10.11
  • TorchMetrics version 1.4.3
  • PyTorch version 2.4.1+cu124

Additional context

This appears to be a problem of floating point precision with sigmoid at line 185 in function _binary_precision_recall_curve_format in file torchmetrics/src/torchmetrics/functional/classification/precision_recall_curve.py.

I extracted all the necessary functions and made a miniature binary_auroc function that uses exactly the same algorithm (works for the above example, did not test for other examples):

def binary_auroc(
    preds: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    preds = preds.sigmoid()
    print(preds)

    desc_score_indices = torch.argsort(preds, descending=True)
    preds = preds[desc_score_indices]
    target = target[desc_score_indices]
    # print(preds, target)

    # pred typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = torch.nonzero(preds[1:] - preds[:-1], as_tuple=True)[0]
    # print(distinct_value_indices)
    threshold_idxs = torch.nn.functional.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1)
    # print(threshold_idxs)
    tps = torch.cumsum(target, dim=0)[threshold_idxs]
    fps = 1 + threshold_idxs - tps
    # print(tps, fps)

    # Add an extra threshold position to make sure that the curve starts at (0, 0)
    tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
    fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
    tpr = tps / tps[-1]
    fpr = fps / fps[-1]
    # print(fpr, tpr)
    return torch.trapezoid(tpr, fpr, dim=-1)

Output:

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

tensor(0.5000)
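
The 0.5 follows directly from the steps in the function above: when every score is tied, there are no distinct-value boundaries, so the only threshold is the padded last index and the ROC curve collapses to the two points (0, 0) and (1, 1). A small plain-Python trace of that case, with list arithmetic standing in for the tensor operations (an illustrative sketch, not the library code):

```python
# Trace of the all-tied case using plain lists instead of tensors.
target = [0] + [1] * 14      # order is irrelevant here because all scores are tied
n = len(target)

# No distinct-value boundaries were found, so only the padded last index remains.
threshold_idxs = [n - 1]

# Running count of positives, equivalent to torch.cumsum(target, dim=0).
cum_tp = []
running = 0
for y in target:
    running += y
    cum_tp.append(running)

# Prepend 0 so the curve starts at (0, 0).
tps = [0] + [cum_tp[i] for i in threshold_idxs]           # [0, 14]
fps = [0] + [1 + i - cum_tp[i] for i in threshold_idxs]   # [0, 1]
tpr = [t / tps[-1] for t in tps]                          # [0.0, 1.0]
fpr = [f / fps[-1] for f in fps]                          # [0.0, 1.0]

# Trapezoidal rule over the two remaining points.
area = sum(
    (fpr[i + 1] - fpr[i]) * (tpr[i + 1] + tpr[i]) / 2
    for i in range(len(fpr) - 1)
)
print(area)  # 0.5
```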

preds = preds.sigmoid() maps every logit to exactly 1, so the predictions look tied even though the raw logits are all different. To avoid rounding to exactly 1, a logit must stay below roughly 36.74 for float64 or 16.64 for float32.
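
The saturation is easy to reproduce outside of torch. Python floats are float64, so a minimal sketch with the standard math module shows the same effect at the float64 threshold; sigmoid here is just a local helper written for this example, not a torchmetrics or torch function.

```python
import math


def sigmoid(x: float) -> float:
    """Naive logistic function evaluated in float64."""
    return 1.0 / (1.0 + math.exp(-x))


# Well below the float64 threshold, the result is still strictly less than 1.
below_threshold = sigmoid(30.0)
print(below_threshold < 1.0)

# The smallest and largest logits from the report both round to exactly 1.0,
# so they can no longer be told apart after sigmoid.
low, high = sigmoid(97.6037), sigmoid(99.9623)
print(low == 1.0, high == 1.0, low == high)
```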

Suggested fix

It's probably a good idea to scale the raw logits before sigmoid, something like below:

preds = preds / torch.max(torch.abs(preds)) # out-of-place so the caller's tensor is not modified; scales the largest |logit| to 1
preds = preds.sigmoid()

All functions that apply sigmoid to raw logits will need such a fix.
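
Because AUROC depends only on the ordering of the scores, any strictly increasing transform that keeps distinct logits distinct should leave the result unchanged. The sketch below checks this in plain Python (float64) for the rescaling idea; pairwise_auroc and sigmoid are hypothetical helpers for illustration, not torchmetrics functions.

```python
import math


def sigmoid(x: float) -> float:
    """Naive logistic function evaluated in float64."""
    return 1.0 / (1.0 + math.exp(-x))


def pairwise_auroc(scores, labels):
    """Fraction of (positive, negative) pairs ranked correctly; ties count half."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


preds = [98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, 99.5014,
         99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768]
labels = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

# Sigmoid on the raw logits: every value saturates to 1.0 and the ranking is lost.
saturated = [sigmoid(p) for p in preds]

# Rescale so the largest |logit| is 1 before sigmoid: the ranking survives.
scale = max(abs(p) for p in preds)
scaled = [sigmoid(p / scale) for p in preds]

auroc_saturated = pairwise_auroc(saturated, labels)
auroc_scaled = pairwise_auroc(scaled, labels)
print(auroc_saturated, round(auroc_scaled, 4))
```

Two caveats apply to the rescaling: the division is undefined when every logit is zero, and logits that are extremely close together may still collapse to the same value after rescaling in float32. Ranking the raw logits directly, without sigmoid, would sidestep both.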

@zbingsong zbingsong added bug / fix Something isn't working help wanted Extra attention is needed labels Nov 2, 2024
github-actions bot commented Nov 2, 2024

Hi! thanks for your contribution!, great first issue!

@Borda Borda added the v1.4.x label Nov 3, 2024