import os
import math
import glob
import tqdm
import random
import numpy as np
from config import *
from chatgpt import *

# support six types of inputs: image, evoke, micro, video, text, and multi
def func_get_response(batch, emos, modality, sleeptime):
    if modality == 'image':
        response = get_image_emotion_batch(batch, emos, sleeptime)
    elif modality == 'evoke':
        response = get_evoke_emotion_batch(batch, emos, sleeptime)
    elif modality == 'micro':
        response = get_micro_emotion_batch(batch, emos, sleeptime)
    elif modality == 'video':
        response = get_video_emotion_batch(batch, emos, sleeptime)
    elif modality == 'text':
        response = get_text_emotion_batch(batch, emos, sleeptime)
    elif modality == 'multi':
        response = get_multi_emotion_batch(batch, emos, sleeptime)
    else:
        raise ValueError(f'unsupported modality: {modality}')
    return response
# split one batch into 'xishu' smaller segments (used to re-request failed batches with fewer samples per call)
def func_get_segment_batch(batch, savename, xishu=2):
    segment_num = math.ceil(len(batch)/xishu) # number of samples per segment
    store = []
    for ii in range(xishu):
        segbatch = batch[ii*segment_num:(ii+1)*segment_num]
        segsave = savename[:-4] + f"_segment_{ii+1}.npz"
        if not isinstance(segbatch, list):
            segbatch = [segbatch]
        if len(segbatch) > 0:
            store.append((segbatch, segsave))
    return store
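
# Illustrative example (values chosen for demonstration only): a batch of 4 paths split with
# xishu=2 produces two segments of 2 paths each, and the save name gains a "_segment_{i}" suffix:
#   func_get_segment_batch(['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg'], 'batch_1.npz', xishu=2)
#   -> [(['a.jpg', 'b.jpg'], 'batch_1_segment_1.npz'),
#       (['c.jpg', 'd.jpg'], 'batch_1_segment_2.npz')]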
# main process
def evaluate_performance_using_gpt4v(image_root, save_root, save_order, modality, bsize, xishus, batch_flag='flag1', sleeptime=0):
    # note: 'emos' used below is the module-level label list defined in __main__
    # parameter checks: deeper split levels require more entries in xishus
    if len(xishus) == 1: assert batch_flag in ['flag1', 'flag2']
    if len(xishus) == 2: assert batch_flag in ['flag1', 'flag2', 'flag3']
    if len(xishus) == 3: assert batch_flag in ['flag1', 'flag2', 'flag3', 'flag4']
    multiple = 1
    for item in xishus: multiple *= item
    assert multiple == bsize, 'the product of xishus should equal bsize'

    # create the output folder
    if not os.path.exists(save_root):
        os.makedirs(save_root)

    # preprocess for 'multi': multimodal inputs read from the video folder
    if modality == 'multi':
        image_root = os.path.split(image_root)[0] + '/video'

    # shuffle the sample order once and cache it, so every request processes samples in the same order
    if not os.path.exists(save_order):
        image_paths = glob.glob(image_root + '/*')
        indices = np.arange(len(image_paths))
        random.shuffle(indices)
        image_paths = np.array(image_paths)[indices]
        np.savez_compressed(save_order, image_paths=image_paths)
    else:
        image_paths = np.load(save_order, allow_pickle=True)['image_paths'].tolist()
    print(f'process sample numbers: {len(image_paths)}') # 981

    # split into batches [bsize samples per batch]
    batches = []
    splitnum = math.ceil(len(image_paths) / bsize)
    for ii in range(splitnum):
        batches.append(image_paths[ii*bsize:(ii+1)*bsize])
    print(f'process batch number: {len(batches)}') # 50 batches
    print(f'process sample number: {sum([len(batch) for batch in batches])}')
    # generate predictions for each batch and store them
    for ii, batch in tqdm.tqdm(enumerate(batches)):
        save_path = os.path.join(save_root, f'batch_{ii+1}.npz')
        if os.path.exists(save_path): continue
        ## this batch has no saved result yet -> request it, splitting it according to batch_flag
        if batch_flag == 'flag1': # process the whole batch in one request # 20 samples
            response = func_get_response(batch, emos, modality, sleeptime)
            np.savez_compressed(save_path, gpt4v=response, names=batch)
        elif batch_flag == 'flag2': # split once and process # 10 samples
            stores = func_get_segment_batch(batch, save_path, xishu=xishus[0])
            for (segbatch, segsave) in stores:
                if os.path.exists(segsave): continue
                response = func_get_response(segbatch, emos, modality, sleeptime)
                np.savez_compressed(segsave, gpt4v=response, names=segbatch)
        elif batch_flag == 'flag3': # split twice and process # 5 samples
            stores = func_get_segment_batch(batch, save_path, xishu=xishus[0])
            for (segbatch, segsave) in stores:
                if os.path.exists(segsave): continue
                newstores = func_get_segment_batch(segbatch, segsave, xishu=xishus[1])
                for (newbatch, newsave) in newstores:
                    if os.path.exists(newsave): continue
                    response = func_get_response(newbatch, emos, modality, sleeptime)
                    np.savez_compressed(newsave, gpt4v=response, names=newbatch)
        elif batch_flag == 'flag4': # split three times and process
            stores = func_get_segment_batch(batch, save_path, xishu=xishus[0])
            for (segbatch, segsave) in stores:
                if os.path.exists(segsave): continue
                newstores = func_get_segment_batch(segbatch, segsave, xishu=xishus[1])
                for (newbatch, newsave) in newstores:
                    if os.path.exists(newsave): continue
                    new2stores = func_get_segment_batch(newbatch, newsave, xishu=xishus[2])
                    for (new2batch, new2save) in new2stores:
                        if os.path.exists(new2save): continue
                        response = func_get_response(new2batch, emos, modality, sleeptime)
                        np.savez_compressed(new2save, gpt4v=response, names=new2batch)
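
# Illustrative note on the retry cascade (the numbers below are example settings, not fixed by
# this script): with bsize=20 and xishus=[2, 2, 5], 'flag1' requests a failed batch of 20
# samples in one call; 'flag2' splits it into 2 segments of 10; 'flag3' splits each segment
# again into chunks of 5; 'flag4' goes down to single-sample requests. Because every level
# checks os.path.exists() before sending a request, already-saved results are never re-queried,
# so re-running evaluate_performance_using_gpt4v with progressively finer flags only
# re-requests the samples that previously failed.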
def func_analyze_gpt4v_outputs(gpt_path):
    names = np.load(gpt_path, allow_pickle=True)['names'].tolist()
    ## analyze the raw gpt-4v response text
    store_results = []
    gpt4v = np.load(gpt_path, allow_pickle=True)['gpt4v'].tolist()
    # use the "name"/"result" fields of the response as delimiters
    gpt4v = gpt4v.replace("name", "==========")
    gpt4v = gpt4v.replace("result", "==========")
    gpt4v = gpt4v.split("==========")
    for line in gpt4v:
        # keep only the bracketed label list of each "result" field
        if line.find('[') != -1:
            res = line.split('[')[1]
            res = res.split(']')[0]
            store_results.append(res)
    return names, store_results
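
# Parsing sketch for func_analyze_gpt4v_outputs, assuming the stored response text looks roughly
# like '"name": "sample1.jpg", "result": ["happy"], "name": "sample2.jpg", "result": ["sad"]'
# (the exact format depends on the prompts defined in chatgpt.py). Replacing "name"/"result"
# with a delimiter, splitting on it, and keeping only the bracketed spans yields one label
# string per sample, e.g. '"happy"' and '"sad"'.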
# check whether each saved response parses into matching (names, results) pairs;
# unparsable batches are deleted so that the next evaluation pass re-requests them
def check_gpt4_performance(gpt4v_root):
    error_number = 0
    whole_names, whole_gpt4vs = [], []
    for gpt_path in sorted(glob.glob(gpt4v_root + '/*')):
        names, gpt4vs = func_analyze_gpt4v_outputs(gpt_path)
        print(f'number of samples: {len(names)}  number of results: {len(gpt4vs)}')
        if len(names) == len(gpt4vs):
            names = [os.path.basename(name) for name in names]
            whole_names.extend(names)
            whole_gpt4vs.extend(gpt4vs)
        else:
            print(f'error batch: {gpt_path}. Need re-test!!')
            os.system(f'rm -rf {gpt_path}')
            error_number += 1
    print(f'error number: {error_number}')
    return whole_names, whole_gpt4vs
if __name__ == '__main__':
    # -------------- defined by users --------------- #
    dataset = 'mer2023'
    save_root = '/root/dataset/' + dataset
    # ----------------------------------------------- #

    # please pre-define the dataset-related params in config.py (a placeholder sketch is given at the end of this file)
    emos = dataset2emos[dataset]
    modalities = dataset2modality[dataset]
    for modality in modalities:
        bsize, xishus = modality2params[modality]

        # flags: request multiple times, falling back to smaller and smaller segments
        flags = ['flag1', 'flag1', 'flag1', 'flag2', 'flag2', 'flag3']
        if len(xishus) == 3:
            flags.append('flag4')

        # process each modality
        image_root = os.path.join(save_root, modality)
        gpt4v_root = os.path.join(save_root, f'{modality}-gpt4v') # store results
        save_order = os.path.join(save_root, f'{modality}-order.npz') # ensure each request uses the same sample order
        for flag in flags:
            evaluate_performance_using_gpt4v(image_root, gpt4v_root, save_order, modality, bsize, xishus, batch_flag=flag, sleeptime=20)
            check_gpt4_performance(gpt4v_root) # remove wrongly formatted predictions so that they are re-requested
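
# Hypothetical sketch of the config.py entries this script expects (the real definitions live in
# the repo's config.py; the values below are placeholders, not the actual settings):
#   dataset2emos     = {'mer2023': ['neutral', 'angry', 'happy', 'sad', 'worried', 'surprise']}
#   dataset2modality = {'mer2023': ['image', 'video', 'text', 'multi']}
#   modality2params  = {'image': (20, [2, 2, 5]), 'video': (20, [2, 2, 5]), 'multi': (20, [2, 2, 5])}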