-
Notifications
You must be signed in to change notification settings - Fork 10
/
losses.py
48 lines (39 loc) · 1.43 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.nn as nn
import torch.nn.functional as F
class ATLoss(nn.Module):
    """Adaptive Thresholding loss for multi-label classification.

    Class index 0 is a learned threshold ("TH") class: training pushes
    each gold class's logit above the TH logit and every non-gold
    class's logit below it. At inference (`get_label`) a class is
    predicted iff its logit beats the per-example TH logit.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels):
        """Compute the adaptive-thresholding loss.

        Args:
            logits: (batch, num_classes) raw scores; column 0 is the TH class.
            labels: (batch, num_classes) multi-hot gold labels; column 0 is
                reserved for the TH class and is ignored as a gold label.

        Returns:
            Scalar loss averaged over the batch.
        """
        # BUGFIX: work on a copy — the original code zeroed column 0 of the
        # caller's `labels` tensor in place, mutating a shared argument.
        labels = labels.clone()
        # TH label: one-hot on the threshold class (column 0).
        th_label = torch.zeros_like(labels, dtype=torch.float).to(labels)
        th_label[:, 0] = 1.0
        labels[:, 0] = 0.0
        # Positive mask keeps gold classes plus TH; negative mask keeps
        # non-gold classes (TH included, since labels[:, 0] was zeroed).
        p_mask = labels + th_label
        n_mask = 1 - labels
        # Rank positive classes above TH: softmax restricted to
        # {positives, TH} (-1e30 masks everything else out), then maximize
        # the probability of each gold class.
        logit1 = logits - (1 - p_mask) * 1e30
        loss1 = -(F.log_softmax(logit1, dim=-1) * labels).sum(1)
        # Rank TH above negative classes: softmax restricted to
        # {negatives, TH}, then maximize the probability of the TH class.
        logit2 = logits - (1 - n_mask) * 1e30
        loss2 = -(F.log_softmax(logit2, dim=-1) * th_label).sum(1)
        # Sum the two ranking terms and average over the batch.
        loss = loss1 + loss2
        loss = loss.mean()
        return loss

    def get_label(self, logits, num_labels=-1):
        """Decode multi-hot predictions from logits.

        A class is predicted iff its logit exceeds the row's TH logit
        (column 0); if `num_labels > 0`, only the `num_labels` highest
        logits remain eligible. Rows with no surviving prediction get the
        TH class set instead.
        """
        th_logit = logits[:, 0].unsqueeze(1)  # per-row adaptive threshold
        output = torch.zeros_like(logits).to(logits)
        mask = (logits > th_logit)
        if num_labels > 0:
            # Restrict to classes inside the top-`num_labels` logits.
            top_v, _ = torch.topk(logits, num_labels, dim=1)
            top_v = top_v[:, -1]
            mask = (logits >= top_v.unsqueeze(1)) & mask
        output[mask] = 1.0
        # Fall back to the TH ("no prediction") class when nothing beats it.
        output[:, 0] = (output.sum(1) == 0.).to(logits)
        return output

    def get_score(self, logits, num_labels=-1):
        """Return ranking scores.

        With `num_labels > 0`: the (values, indices) pair from `torch.topk`.
        Otherwise: (margin of class 1 over the TH class, 0).
        """
        if num_labels > 0:
            return torch.topk(logits, num_labels, dim=1)
        else:
            return logits[:, 1] - logits[:, 0], 0