About torchmetrics.classification.MulticlassAccuracy: why torch.long is required when 'multidim_average' set to 'samplewise'? #1969
mifan002 asked this question in Classification
Hi all,
I'm using torchmetrics.classification.MulticlassAccuracy to calculate the per-class pixel accuracy for semantic segmentation. When I leave 'multidim_average' at its default ('global'), everything works fine whichever dtype I choose for the input data (preds and target) of MulticlassAccuracy.forward(), torch.uint8 or torch.long. However, when I set 'multidim_average' to 'samplewise', the dtype of the input data HAS to be torch.long, otherwise I get the following error:
Traceback (most recent call last):
  File "C:\Users\fanmi\anaconda3\envs\env_thesis\lib\site-packages\torchmetrics\metric.py", line 405, in wrapped_func
    raise err
  File "C:\Users\fanmi\anaconda3\envs\env_thesis\lib\site-packages\torchmetrics\metric.py", line 395, in wrapped_func
    update(*args, **kwargs)
  File "C:\Users\fanmi\anaconda3\envs\env_thesis\lib\site-packages\torchmetrics\classification\stat_scores.py", line 317, in update
    tp, fp, tn, fn = _multiclass_stat_scores_update(
  File "C:\Users\fanmi\anaconda3\envs\env_thesis\lib\site-packages\torchmetrics\functional\classification\stat_scores.py", line 378, in _multiclass_stat_scores_update
    preds_oh = torch.nn.functional.one_hot(
RuntimeError: one_hot is only applicable to index tensor.
Why is this the case? And how should I interpret the error message here? It seems unrelated to what is actually happening...
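For context, if I read the traceback correctly the error is raised by torch.nn.functional.one_hot inside _multiclass_stat_scores_update, and as far as I know one_hot only accepts index (torch.long) tensors. The same message can be reproduced on its own, without torchmetrics (a minimal sketch, just to show where the wording comes from):

import torch
import torch.nn.functional as F

# one_hot expects an index tensor (torch.long); a uint8 tensor raises the same error as above
t = torch.randint(low=0, high=4, size=(4,)).to(torch.uint8)
F.one_hot(t, num_classes=4)  # RuntimeError: one_hot is only applicable to index tensor.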
Version Info:
torchmetrics==0.10.1
torch==1.12.1
Code snippet for issue reproduction:
import torch
from torchmetrics.classification import MulticlassAccuracy

metric = MulticlassAccuracy(num_classes=4, average="none", multidim_average="samplewise", ignore_index=0)
# uint8 inputs trigger the error; with multidim_average="global" the same dtypes work fine
prediction = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)
label = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)
score = metric(preds=prediction, target=label)  # raises RuntimeError
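For what it's worth, casting the inputs to torch.long before calling the metric makes the error go away on my side. A minimal sketch of that workaround (I'm not sure whether this is the intended usage):

import torch
from torchmetrics.classification import MulticlassAccuracy

metric = MulticlassAccuracy(num_classes=4, average="none", multidim_average="samplewise", ignore_index=0)
prediction = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)
label = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)

# Casting to torch.long (int64) before the call avoids the one_hot dtype error.
score = metric(preds=prediction.long(), target=label.long())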