BinaryPrecisionRecallCurve computes wrong value if used with logits, even though the docstring says this is supported
#2329
🐛 Bug
The object-oriented BinaryPrecisionRecallCurve can compute substantially incorrect values if logits are passed as predictions instead of probabilities, even though the docstring says this is OK. The underlying functional version, binary_precision_recall_curve, seems to work correctly in both cases. Both versions attempt to convert the logits to probabilities by passing them through a sigmoid if any values are outside of the range [0, 1]. In the object-oriented case this condition is incorrectly checked independently for each batch, rather than for the metric as a whole. Consequently, some batches may have sigmoid applied to their scores while others do not, resulting in an incorrect curve for the dataset as a whole.
To Reproduce
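The per-batch check can be illustrated without torchmetrics. The sketch below is a pure-Python stand-in, not the library code: format_scores is a hypothetical helper that mimics the "apply sigmoid if anything is outside [0, 1]" rule, pairwise_auc is a simple rank-based ROC AUC, and the two batches are made-up logits chosen so that only the second batch triggers the sigmoid.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def format_scores(scores):
    """Hypothetical stand-in for the auto-detection: apply sigmoid only if
    at least one score lies outside [0, 1]."""
    if any(s < 0.0 or s > 1.0 for s in scores):
        return [sigmoid(s) for s in scores]
    return list(scores)


def pairwise_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs that are ranked
    correctly; ties count as half."""
    pos = [s for l, s in zip(labels, scores) if l == 1]
    neg = [s for l, s in zip(labels, scores) if l == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Two batches of logits (assumed values). The first batch happens to fall
# entirely inside [0, 1], so a per-batch check leaves it untouched.
batches = [
    ([1, 0], [0.2, 0.9]),
    ([0, 1], [-1.0, 2.0]),
]
labels = [l for ls, _ in batches for l in ls]
raw_scores = [s for _, ss in batches for s in ss]

# Check applied per batch (the behavior described for the class interface).
per_batch_scores = [s for _, ss in batches for s in format_scores(ss)]
# Check applied once to the full dataset (the behavior of the functional form).
whole_scores = format_scores(raw_scores)

auc_per_batch = pairwise_auc(labels, per_batch_scores)
auc_whole = pairwise_auc(labels, whole_scores)
print(auc_per_batch, auc_whole)  # 0.25 0.75
```

In this toy data the raw 0.9 from the first batch ends up ranked above sigmoid(2.0) ≈ 0.88 from the second, which flips the ordering of the scores and changes the resulting metric.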
Expected behavior
I would expect both AUCs to equal 0.5, as computed with scikit-learn using sklearn.metrics.roc_auc_score(label, score).
Environment
TorchMetrics version (conda, pip, build from source): 1.3.0.post0
Additional context
The bug is on line https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/classification/precision_recall_curve.py#L165 -- as it is currently written, the function _binary_precision_recall_curve_format should only be called on the full dataset, not on individual batches. Otherwise the behavior is wrong if some batches have all scores in the range [0, 1] but other batches do not.
Some possible solutions are: (1) update the docs not to allow for logits in the object-oriented interface, since the behavior is correct for probabilities; (2) don't try to automatically infer whether to apply sigmoid (my choice, but it would be a breaking change); (3) refactor _binary_precision_recall_curve_format and accept that if any values are found which require sigmoid, then all values from past batches need to have sigmoid applied (this would be tricky in the case where thresholds are specified, because the scores are not kept around).
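For the case without fixed thresholds, option (3) could look roughly like the sketch below. This is pure Python with a hypothetical DeferredFormatState class, not torchmetrics code: it assumes raw scores are kept across updates, so the sigmoid decision can be made once at compute time over everything seen so far. It does not address the binned-thresholds case, where raw scores are not stored.

```python
import math


class DeferredFormatState:
    """Hypothetical accumulator: keeps raw scores from every batch and
    decides whether to apply sigmoid only once, at compute time."""

    def __init__(self):
        self.labels = []
        self.raw_scores = []

    def update(self, labels, scores):
        # No formatting here; just remember what was passed in.
        self.labels.extend(labels)
        self.raw_scores.extend(scores)

    def compute(self):
        # One decision for the metric as a whole, not one per batch.
        needs_sigmoid = any(s < 0.0 or s > 1.0 for s in self.raw_scores)
        if needs_sigmoid:
            scores = [1.0 / (1.0 + math.exp(-s)) for s in self.raw_scores]
        else:
            scores = list(self.raw_scores)
        return self.labels, scores


state = DeferredFormatState()
state.update([1, 0], [0.2, 0.9])    # batch that looks like probabilities
state.update([0, 1], [-1.0, 2.0])   # batch that is clearly logits
final_labels, final_scores = state.compute()
```

Because sigmoid is monotonic, applying it to all accumulated scores preserves their relative order, so the curve matches what a single call on the full dataset would give.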