forked from timctho/VNect-tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
caffe_weights_to_pickle.py
53 lines (42 loc) · 1.88 KB
/
caffe_weights_to_pickle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import caffe
import numpy as np
import pickle
import argparse
from collections import OrderedDict
parser = argparse.ArgumentParser(
    description='Convert a Caffe VNect model into a pickled OrderedDict '
                'of numpy arrays keyed as "<layer>/<variable>".')
parser.add_argument('--prototxt',
                    default='models/vnect_net.prototxt',
                    help='Caffe network definition (.prototxt)')
parser.add_argument('--caffemodel',
                    default='models/vnect_model.caffemodel',
                    help='trained Caffe weights (.caffemodel)')
parser.add_argument('--output_file',
                    default='vnect.pkl',
                    help='destination pickle file')


def _caffe_kernel_to_tf(weights):
    """Reorder a Caffe weight blob into the TensorFlow kernel layout.

    4-D blobs are transposed with axes (2, 3, 1, 0), i.e. Caffe's
    (num_output, channels, kh, kw) convolution layout becomes
    (kh, kw, channels, num_output). 2-D blobs (Caffe InnerProduct,
    (num_output, inputs)) are transposed to (inputs, num_output). Any other
    rank is returned unchanged.
    """
    weights = np.asarray(weights)
    if weights.ndim == 4:
        return weights.transpose((2, 3, 1, 0))
    if weights.ndim == 2:
        return weights.T
    return weights


def _bn_statistics(blobs):
    """Return (moving_mean, moving_variance) from a Caffe BatchNorm blob list.

    Caffe's BatchNorm layer stores the running mean and variance
    unnormalized, together with a moving-average factor in blobs[2]. The
    usable statistics are blobs[0] and blobs[1] divided by that factor.
    """
    mean = np.asarray(blobs[0])
    variance = np.asarray(blobs[1])
    factor = np.asarray(blobs[2]).ravel()[0]
    if factor == 0:
        # Mirror Caffe's own forward pass, which uses a scale of 0 instead of
        # dividing by zero; dividing would emit NaN/inf statistics.
        return np.zeros_like(mean), np.zeros_like(variance)
    return mean / factor, variance / factor


def convert_caffe_params(params):
    """Convert Caffe parameter blobs into TensorFlow-style named arrays.

    Layers are classified by name prefix and blob count:

    * ``bn*``   -> ``<layer>/moving_mean``, ``<layer>/moving_variance``
    * ``scale*`` -> ``<bn>/gamma``, ``<bn>/beta``, where ``<bn>`` is the BN
      layer immediately preceding it. If the previous layer was not a BN
      layer, the Scale layer's own name is used. A missing bias blob
      (``bias_term: false``) gives an all-zero beta.
    * 2 blobs   -> ``<layer>/weights`` (reordered kernel), ``<layer>/biases``
    * 1 blob    -> ``<layer>/kernel`` (reordered kernel)
    * anything else is skipped, with a notice printed.

    Args:
        params: ordered mapping of layer name -> sequence of blob arrays, in
            network order (e.g. built from ``caffe.Net.params``).

    Returns:
        OrderedDict of ``'<name>/<variable>'`` -> numpy array, with keys
        inserted in the order of ``params``.
    """
    pkl_weights = OrderedDict()
    prev_bn = None  # BN layer directly before the current one, if any
    for layer, blobs in params.items():
        if layer.startswith('bn'):
            mean, variance = _bn_statistics(blobs)
            pkl_weights[layer + '/moving_mean'] = mean
            pkl_weights[layer + '/moving_variance'] = variance
            prev_bn = layer
            continue
        if layer.startswith('scale'):
            target = prev_bn if prev_bn is not None else layer
            gamma = np.asarray(blobs[0])
            if len(blobs) > 1:
                beta = np.asarray(blobs[1])
            else:
                beta = np.zeros_like(gamma)
            pkl_weights[target + '/gamma'] = gamma
            pkl_weights[target + '/beta'] = beta
        elif len(blobs) == 2:
            pkl_weights[layer + '/weights'] = _caffe_kernel_to_tf(blobs[0])
            pkl_weights[layer + '/biases'] = np.asarray(blobs[1])
        elif len(blobs) == 1:
            pkl_weights[layer + '/kernel'] = _caffe_kernel_to_tf(blobs[0])
        else:
            print('skipping', layer, 'with', len(blobs), 'parameter blobs')
        prev_bn = None
    return pkl_weights


def main(argv=None):
    """Load the Caffe net, convert its parameters and pickle them.

    Args:
        argv: command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    args = parser.parse_args(argv)
    net = caffe.Net(args.prototxt,
                    caffe.TEST,
                    weights=args.caffemodel)

    for layer in net.params.keys():
        print(layer)
    print('======')

    params = OrderedDict()
    for layer, blobs in net.params.items():
        print(layer, len(blobs))
        for blob in blobs:
            print(blob.data.shape)
        params[layer] = [blob.data for blob in blobs]

    pkl_weights = convert_caffe_params(params)
    for name, value in pkl_weights.items():
        print(name, value.shape)

    with open(args.output_file, 'wb') as f:
        pickle.dump(pkl_weights, f)


if __name__ == '__main__':
    main()