This repository has been archived by the owner on Oct 23, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
create_localdataset.py
176 lines (162 loc) · 6.88 KB
/
create_localdataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#filename : create_localdataset.py
#author : PRAJWAL T R
#date last modified : Mon Jul 13 14:25:12 2020
#comments :
'''
python script to create dataset for local model in paper "Teaching Robots to Draw"
dataset structure :
dataset : {
lG_data : [
[X_env, X_con, X_diff] ,
... ,
... ,
] ,
lG_croppedimg : [
[cropped_img],
... ,
... ,
]
lG_extract : [
[begin, size],
... ,
... ,
]
lG_touch : [
[touch],
... ,
... ,
]
}
//end of structure
keywords :
X_env : visited region
X_con : continuously connected stroke
X_diff: remaining strokes to draw
lG : state of local model
'''
from drawing_utils import *
from os import walk
import pickle as pic
# output directory for verification plots written by plotImages()
test_dir_path = "./test_dir/local_pics/"
# input directory containing the font svg files to convert into samples
traverse_path = "./font_svgs/"
# output directory where pickled dataset batches are written
local_dataset_path = "./local_dataset/"
len_sample = 0 # counter
#dataset structure (see module docstring) : inputs lG_data / lG_extract,
#outputs lG_touch / lG_croppedimg; filled per file batch, then pickled + reset
dataset = {
'lG_data' : [],
'lG_extract' : [],
'lG_touch' : [], #output
'lG_croppedimg' : [] #output
}
def getCroppedImage(next_xy, current_xy):
    """
    Return a crop_img_size x crop_img_size crop, centred on current_xy, of an
    otherwise black HEIGHT x WIDTH image with a single pixel lit at next_xy.

    next_xy    : (x, y) coordinate to mark (the next point of the stroke)
    current_xy : (x, y) coordinate the crop window is anchored on
    """
    # create image with black background
    img = np.zeros((HEIGHT, WIDTH))
    # mark point with next_xy coordinates; numpy/opencv index as [row, col] == [y, x]
    img[next_xy[1], next_xy[0]] = COLOR
    # top-left corner of the crop window as [row, col]
    slice_begin = getSliceWindow(current_xy)[:-1]
    # clamp negative starts : a negative numpy slice index wraps around to the
    # opposite edge and would silently crop the wrong region near the border
    row0 = max(slice_begin[0], 0)
    col0 = max(slice_begin[1], 0)
    # use crop_img_size instead of a hard-coded 5 for consistency with the
    # padding below
    img = img[row0 : row0 + crop_img_size, col0 : col0 + crop_img_size]
    # zero-padding to ensure the result is always crop_img_size x crop_img_size
    if img.shape != (crop_img_size, crop_img_size):
        rem_x, rem_y = crop_img_size - img.shape[0], crop_img_size - img.shape[1]
        img = cv.copyMakeBorder(img, 0, rem_x, 0, rem_y, cv.BORDER_CONSTANT, None, 0)
    return img #return cropped image
def plotImages(*images):
    """
    Save a side-by-side plot of images to test_dir_path for visual checking.

    images[0] : integer index used in the output filename
    images[1] : list of images to display, one subplot each
    """
    sample_ind = images[0]
    image_list = images[1]  # just for convenience
    fig, axs = plt.subplots(1, len(image_list))
    for pos, image in enumerate(image_list):
        axs[pos].imshow(image)
        axs[pos].set_title("image :" + str(pos))
    plt.savefig(test_dir_path + "local_dataset " + str(sample_ind) + ".png")
def getSliceWindow(current_xy):
    """
    Compute the begin vector for dynamic tensor slicing with tf.slice.

    Returns np.array([y - 2, x - 2, 0]) : the row/col start of a window
    anchored two pixels up and left of current_xy, with a zero begin offset
    for the channel dimension.
    """
    # numpy and tf index as [row, col] == [y, x], hence the swap
    return np.array([current_xy[1] - 2, current_xy[0] - 2, 0])
def pickleLocalDataset(dataset, ind, collect_minority):
    """
    Pickle the accumulated dataset dict to disk as one batch file, then reset
    the dict's lists so the next batch can be collected in place.

    dataset          : dict with keys lG_data, lG_extract, lG_touch, lG_croppedimg
    ind              : batch index used in the output filename
    collect_minority : if True, prefix the filename with "m_" to mark a batch
                       of minority-class samples only (touch = 0)
    """
    # prefix m to indicate dataset with only minority classes touch = 0
    if collect_minority:
        out_path = local_dataset_path+"m_data_batch_"+str(ind)
    else:
        out_path = local_dataset_path+"data_batch_"+str(ind)
    #convert list to numpy array ie : compatible with tensorflow data adapter
    dataset['lG_data'] = np.array(dataset['lG_data'])
    dataset['lG_extract'] = np.array(dataset['lG_extract'])
    dataset['lG_touch'] = np.array(dataset['lG_touch'])
    dataset['lG_croppedimg'] = np.array(dataset['lG_croppedimg'])
    # context manager guarantees the file handle is closed even if dump
    # raises (the original opened fd and never closed it)
    with open(out_path, "wb") as fd:
        pic.dump(dataset, fd)
    print("dataset created at : ",out_path)
    #clear contents of dataset structure for the next batch
    dataset['lG_data'] = []
    dataset['lG_extract'] = []
    dataset['lG_touch'] = []
    dataset['lG_croppedimg'] = []
if __name__ == "__main__":
    import sys
    # optional second cli arg : number of svg files per pickled batch
    sample_rate = int(sys.argv[2]) if len(sys.argv) == 3 else 100
    _, _, filelist = next(walk(traverse_path))
    file_cap = len(filelist) # limit files to consider
    breaks = [i for i in range(0, len(filelist[:file_cap]), sample_rate)]
    # collect only samples with touch = 0 when invoked with arg "minority";
    # guarded so running with no cli args no longer raises IndexError
    collect_minority = len(sys.argv) > 1 and sys.argv[1] == "minority"
    # NOTE(review): files past the last multiple of sample_rate are never
    # processed (range stops at len(breaks) - 1) -- confirm this is intended
    for break_ind in range(len(breaks) - 1):
        for file in filelist[breaks[break_ind] : breaks[break_ind + 1]]:
            # with-block closes the svg file handle (original leaked it)
            with open(traverse_path + file) as svg_fd:
                svg_string = svg_fd.read()
            X_target, m_indices = getStrokesIndices(svg_string)
            #loop through all strokes
            for index in range(len(m_indices)):
                try:
                    #get current stroke
                    stroke = X_target[m_indices[index] : m_indices[index + 1]]
                except IndexError: # last stroke : no next marker index
                    stroke = X_target[m_indices[index] : ]
                #all points for given stroke ML,MLL,MLLLL
                points = getAllPoints(stroke)
                env_l = []      # visited points so far (X_env)
                diff_l = points # points remaining to draw (X_diff)
                touch = 1       # pen stays on paper for majority samples
                con_img = drawStroke(stroke)
                if not collect_minority:
                    for ind in range(len(points) - 1):
                        current_xy = points[ind] # crop at this coordinate
                        next_xy = points[ind + 1] # mark at this coordinate
                        # inputs
                        ext_inp = getSliceWindow(current_xy)
                        env_img = drawFromPoints(env_l)
                        diff_img = drawFromPoints(diff_l)
                        # outputs : 5 * 5 image with one point drawn, cropped at current_xy
                        next_xy_img = getCroppedImage(next_xy, current_xy)
                        # plot images for verification
                        # plotImages(ind,[con_img, env_img, diff_img, next_xy_img])
                        # update dataset
                        dataset['lG_data'].append(np.dstack((env_img, diff_img, con_img)))
                        dataset['lG_extract'].append(ext_inp)
                        dataset['lG_croppedimg'].append(np.reshape(next_xy_img, (crop_img_size * crop_img_size)))
                        dataset['lG_touch'].append(np.array([touch]))
                        # update env, diff
                        env_l = points[0 : ind + 2] # add two points for one complete stroke
                        diff_l = points[ind + 1 :]
                if collect_minority:
                    # last instance of the stroke : pen lifts, touch = 0
                    touch = 0
                    env_l = points
                    diff_l = []
                    current_xy = points[-1]
                    # inputs (con_img already drawn above)
                    ext_inp = getSliceWindow(current_xy)
                    env_img = drawFromPoints(env_l)
                    diff_img = drawFromPoints(diff_l)
                    # outputs : empty 5 * 5 image, there is no next point to mark
                    next_xy_img = np.zeros((crop_img_size, crop_img_size))
                    dataset['lG_data'].append(np.dstack((env_img, diff_img, con_img)))
                    dataset['lG_extract'].append(ext_inp)
                    dataset['lG_croppedimg'].append(np.reshape(next_xy_img, (crop_img_size * crop_img_size)))
                    dataset['lG_touch'].append(np.array([touch]))
        #save one batch of the dataset to disk
        pickleLocalDataset(dataset, break_ind, collect_minority)