-
Notifications
You must be signed in to change notification settings - Fork 1
/
validate-run-rerank.py
125 lines (106 loc) · 4.1 KB
/
validate-run-rerank.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python
import sys
import json
import argparse
def read_queries(fn):
    """Load the candidate documents of every query from a JSON-lines file.

    Each line is parsed as a JSON object.  Its ``"qid"`` value becomes a key
    of the result, and the ``"doc_id"`` of every entry in its ``"documents"``
    list is collected into a set.

    Args:
        fn: Path of the query file to read.

    Returns:
        A dict mapping each qid to the set of its document ids.  Lines that
        are not valid JSON are reported on stdout and left out of the result.

    Raises:
        KeyError: If a parsed object lacks ``"qid"`` or ``"documents"``, or
            one of its documents lacks ``"doc_id"``.
    """
    queries = {}
    with open(fn, "r") as fp:
        # enumerate() provides the line number for the error message; the
        # handler used to reference a name that was never defined in this
        # function and crashed with NameError on the first malformed line.
        for line_number, line in enumerate(fp, start=1):
            # Only the parse is guarded, so the message is printed solely
            # for genuine JSON syntax problems.
            try:
                data = json.loads(line.strip())
            except json.JSONDecodeError:
                print("illegal json at line %d" % line_number)
                continue
            queries[data["qid"]] = {doc["doc_id"] for doc in data["documents"]}
    return queries
def read_query_sequence(fn):
    """Read the query sequence file into a dict.

    Every line is stripped and split on ``","`` into exactly two fields: a
    sequence number, kept as a string, and a query id, converted with
    ``int``.  When a sequence number occurs more than once the last line
    wins.

    Args:
        fn: Path of the comma-separated sequence file.

    Returns:
        A dict mapping sequence-number strings to integer query ids.

    Raises:
        ValueError: If a line does not split into exactly two fields or its
            second field is not an integer.
    """
    with open(fn, "r") as handle:
        # Lazily split each row; the comprehension unpacks the two fields.
        rows = (raw.strip().split(",") for raw in handle)
        return {seq_no: int(query_id) for seq_no, query_id in rows}
def _fail(message):
    """Print a validation error and terminate the process.

    The exit status is 1 so that a calling script can tell a rejected run
    from a valid one; a bare ``sys.exit()`` would report success.
    """
    print(message)
    sys.exit(1)


def main():
    """Validate a TREC fair ranking run file from the command line.

    Reads the files named by ``--queries``, ``--query_sequence_file`` and
    ``--run_file``.  Every line of the run file must be a JSON object with
    ``q_num``, ``qid`` and ``ranking`` fields such that:

    * ``q_num`` is listed in the sequence file and occurs only once in the run,
    * ``qid`` is known to the query file and is the qid the sequence file
      assigns to ``q_num``,
    * ``ranking`` holds exactly the query's document ids, each one once.

    Finally the run must cover every sequence number.  The first problem found
    is printed to stdout and the process exits with status 1; a run that
    passes every check prints nothing and the function returns normally.

    Raises:
        SystemExit: On the first validation failure, or from argparse when a
            required option is missing.
    """
    parser = argparse.ArgumentParser(description='validate trec fair ranking run.')
    # All three files are opened unconditionally, so a missing option is
    # reported as a usage error rather than failing later inside open().
    parser.add_argument('--queries', required=True,
                        help='fair ranking query file')
    parser.add_argument('--query_sequence_file', required=True,
                        help='fair ranking query sequences file')
    parser.add_argument('--run_file', required=True,
                        help='fair ranking run file')
    args = parser.parse_args()

    queries = read_queries(args.queries)
    query_sequence = read_query_sequence(args.query_sequence_file)
    query_sequence_seen = set()

    with open(args.run_file, "r") as fp:
        for line_number, line in enumerate(fp, start=1):
            # Guard only the parse so that "illegal json" can never mask an
            # unrelated error raised by one of the checks below.
            # json.JSONDecodeError is a subclass of ValueError.
            try:
                data = json.loads(line.strip())
            except ValueError:
                _fail("illegal json at line %d" % line_number)

            #
            # 1. check fields in json object
            #
            for field in ("q_num", "qid", "ranking"):
                if field not in data:
                    _fail("missing %s in line %d" % (field, line_number))

            #
            # 2. validate query number
            #
            q_num = data["q_num"]
            if q_num not in query_sequence:
                _fail("%s not found in sequence file (line %d)" % (q_num, line_number))

            #
            # 3. validate qid
            #
            qid = data["qid"]
            if qid not in queries:
                _fail("%s not found in query file (line %d)" % (qid, line_number))

            #
            # 4. validate qid matches query number
            #
            expected_qid = query_sequence[q_num]
            if qid != expected_qid:
                print("%s is not the correct qid for sequence number %s (line %d)" % (qid, q_num, line_number))
                _fail("\tshould be %s" % expected_qid)

            #
            # 5. check for duplicate docids
            #
            ranking = data["ranking"]
            ranking_set = set(ranking)
            if len(ranking) != len(ranking_set):
                _fail("duplicate document ids (line %d)" % line_number)

            #
            # 6. check for extra documents
            #
            if ranking_set - queries[qid]:
                _fail("extra document ids (line %d)" % line_number)

            #
            # 7. check for missing documents
            #
            if queries[qid] - ranking_set:
                _fail("missing document ids (line %d)" % line_number)

            #
            # 8. check for duplicate query number
            #
            if q_num in query_sequence_seen:
                _fail("duplicate q_num (line %d)" % line_number)
            query_sequence_seen.add(q_num)

    #
    # 9. check missing query numbers
    #
    # Check 2 guarantees every seen q_num is a key of query_sequence, so equal
    # sizes mean every sequence number was covered.
    if len(query_sequence_seen) != len(query_sequence):
        _fail("missing query numbers")
# Run the validator only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()