-
Notifications
You must be signed in to change notification settings - Fork 25
/
triplet_selector.py
41 lines (33 loc) · 1.27 KB
/
triplet_selector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import numpy as np
from utils import pdist_torch as pdist
class BatchHardTripletSelector(object):
    '''
    Batch-hard triplet miner.

    For every sample (anchor) of an embedded batch, selects the farthest
    sample carrying the same label (hardest positive) and the closest sample
    carrying a different label (hardest negative).
    '''

    def __init__(self, *args, **kwargs):
        # The selector keeps no state; any arguments are accepted and ignored.
        super(BatchHardTripletSelector, self).__init__()

    def __call__(self, embeds, labels):
        '''
        Mine one (anchor, positive, negative) triplet per sample.

        Args:
            embeds: tensor of embeddings, one sample per index along dim 0.
            labels: tensor with one label per sample, in the same order.

        Returns:
            Tuple ``(embeds, pos, neg)``: the anchors exactly as passed in,
            followed by the selected positives and negatives, each reshaped
            to ``(num_samples, -1)``.

        Notes:
            * An anchor whose label appears only once in the batch has no
              real positive; the anchor itself is returned as its positive
              (zero-distance pair) instead of an arbitrary other sample.
            * An anchor with no differently-labelled sample in the batch has
              no real negative; ``np.argmin`` over its all-``inf`` row yields
              index 0, so batches should contain at least two labels.
        '''
        # Mining only needs indices, so it runs on a detached NumPy copy.
        # Assumes `pdist` returns an (N, N) pairwise distance matrix -- it is
        # masked below with (N, N) boolean arrays. TODO confirm against utils.
        dist_mtx = pdist(embeds, embeds).detach().cpu().numpy()
        labels = labels.contiguous().cpu().numpy().reshape((-1, 1))
        num = labels.shape[0]
        dia_inds = np.diag_indices(num)

        # same_label[i, j] is True when samples i and j share a label.
        same_label = labels == labels.T

        # Hardest positive: farthest same-label sample, anchor excluded.
        pos_mask = same_label.copy()
        pos_mask[dia_inds] = False
        dist_same = np.where(pos_mask, dist_mtx, -np.inf)
        hardest_pos = np.argmax(dist_same, axis=1)
        # Rows without any candidate are all -inf, where argmax would return
        # index 0 (an unrelated sample); use the anchor's own index instead.
        own_idxs = np.arange(num, dtype=hardest_pos.dtype)
        pos_idxs = np.where(pos_mask.any(axis=1), hardest_pos, own_idxs)

        # Hardest negative: closest sample with a different label. The
        # diagonal is forced into the exclusion mask so that an anchor is
        # never its own negative.
        neg_excluded = same_label.copy()
        neg_excluded[dia_inds] = True
        dist_diff = np.where(neg_excluded, np.inf, dist_mtx)
        neg_idxs = np.argmin(dist_diff, axis=1)

        pos = embeds[pos_idxs].contiguous().view(num, -1)
        neg = embeds[neg_idxs].contiguous().view(num, -1)
        return embeds, pos, neg
if __name__ == '__main__':
    # Smoke run: mine triplets from 10 random 128-d embeddings over 3 labels.
    demo_labels = torch.tensor([0, 1, 2, 2, 0, 1, 2, 1, 1, 0])
    demo_embeds = torch.randn(10, 128)
    miner = BatchHardTripletSelector()
    anchor, pos, neg = miner(demo_embeds, demo_labels)