-
Notifications
You must be signed in to change notification settings - Fork 4
/
preprocess.py
125 lines (99 loc) · 3.33 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import random

import cv2
import numpy as np
import pandas as pd

import params
# Target image geometry and filesystem locations, all taken from the
# project-wide ``params`` module so they are configured in a single place.
img_height, img_width, img_channels = (
    params.img_height,
    params.img_width,
    params.img_channels,
)
data_dir, out_dir, model_dir = (
    params.data_dir,
    params.out_dir,
    params.model_dir,
)
def preprocess(img, color_mode='RGB'):
    '''Crop the road region out of a camera frame and resize it.

    The top half of the frame and the bottom 150 px (which contain the head
    of the car) are cut off. The left/right edges are then trimmed
    symmetrically so the remaining region has the same aspect ratio as the
    ``img_width`` x ``img_height`` target. Finally the crop is resized to
    that target with ``cv2.INTER_AREA``.

    :img: The image to be processed, indexed ``[row, col, ...]``. It is
        presumably BGR as produced by ``cv2.VideoCapture.read``; confirm
        with callers.
    :color_mode: ``'YUV'`` converts the result with ``cv2.COLOR_BGR2YUV``.
        Any other value, including the default ``'RGB'``, leaves the channel
        order untouched.
    :return: Returns the processed image. Pixel values are not normalized.
    '''
    height, width = img.shape[0], img.shape[1]
    # Chop off 1/2 from the top and cut the bottom 150px (head of the car).
    top, bottom = height // 2, height - 150
    ratio = img_height / img_width
    crop_width = (bottom - top) / ratio
    # Clamp at 0. If the frame is already narrower than the target aspect
    # ratio, keep its full width instead of producing a bogus negative slice.
    padding = max(int(round((width - crop_width) / 2)), 0)
    # Use an explicit stop index. ``-padding`` would be ``-0`` (== 0) when
    # padding is 0, which yields an empty image that cv2.resize rejects.
    img = img[top:bottom, padding:width - padding]
    img = cv2.resize(img, (img_width, img_height),
                     interpolation=cv2.INTER_AREA)
    if color_mode == 'YUV':
        img = cv2.cvtColor(img, cv2.COLOR_BGR2YUV)
    return img
def frame_count_func(file_path):
    '''Return the frame count OpenCV reports for the video at ``file_path``.

    The value is read from the ``cv2.CAP_PROP_FRAME_COUNT`` property rather
    than by decoding every frame. It may therefore differ from the number of
    frames ``cap.read()`` actually yields for some containers; this has not
    been verified here.

    :file_path: Path to a video file readable by ``cv2.VideoCapture``.
    :return: The reported frame count as an ``int``.
    '''
    cap = cv2.VideoCapture(file_path)
    try:
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Always free the capture handle, even if the property query raises.
        cap.release()
def _load_epoch(epoch_id, color_mode):
    '''Return (preprocessed frames, wheel values) for one recorded epoch.'''
    vid_path = os.path.join(
        data_dir, 'epoch{:0>2}_front.mkv'.format(epoch_id))
    csv_path = os.path.join(
        data_dir, 'epoch{:0>2}_steering.csv'.format(epoch_id))
    wheel_values = pd.read_csv(csv_path)['wheel'].values

    frames = []
    cap = cv2.VideoCapture(vid_path)
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(preprocess(frame, color_mode))
    finally:
        # Release the capture even if decoding or preprocessing fails.
        cap.release()

    # Every decoded frame must pair with exactly one steering row.
    if len(frames) != len(wheel_values):
        raise ValueError(
            'epoch {}: {} video frames but {} steering rows'.format(
                epoch_id, len(frames), len(wheel_values)))
    return frames, wheel_values


def _flip_augment(imgs, wheels):
    '''Interleave each sample with its horizontal mirror and negated wheel.'''
    out_imgs, out_wheels = [], []
    for image, measurement in zip(imgs, wheels):
        out_imgs.append(image)
        out_wheels.append(measurement)
        # Mirroring the frame left/right corresponds to steering the other way.
        out_imgs.append(cv2.flip(image, 1))
        out_wheels.append(float(measurement) * -1.0)
    return out_imgs, out_wheels


def load_data(mode, color_mode='RGB', flip=True):
    '''Load preprocessed frames and steering-wheel values for a data split.

    For each epoch id in the split, frames are read from
    ``<data_dir>/epochNN_front.mkv`` and passed through ``preprocess``. The
    ``wheel`` column is read from ``<data_dir>/epochNN_steering.csv``.

    :mode: ``'train'`` (epochs 1-9) or ``'test'`` (epoch 10).
    :color_mode: Forwarded to ``preprocess`` (``'RGB'`` or ``'YUV'``).
    :flip: In ``'train'`` mode only, append a horizontally mirrored copy of
        every frame, paired with the negated steering value.
    :return: ``(X, y)``. ``X`` is ``np.array`` of the frames and ``y`` has
        shape ``(num_samples, 1)``.
    :raises ValueError: If ``mode`` is unsupported, or if an epoch's frame
        count does not match its number of steering rows.
    '''
    if mode == 'train':
        epoch_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    elif mode == 'test':
        epoch_ids = [10]
    else:
        raise ValueError(
            "Wrong mode input: {!r} (expected 'train' or 'test')".format(mode))

    imgs, wheels = [], []
    for epoch_id in epoch_ids:
        epoch_imgs, epoch_wheels = _load_epoch(epoch_id, color_mode)
        imgs.extend(epoch_imgs)
        wheels.extend(epoch_wheels)

    if mode == 'train' and flip:
        imgs, wheels = _flip_augment(imgs, wheels)

    X_train = np.array(imgs)
    y_train = np.array(wheels).reshape(-1, 1)
    return X_train, y_train
def load_batch(imgs, wheels, batch_size=None):
    '''Draw a random batch of (image, wheel) pairs without replacement.

    :imgs: Indexable sequence of images.
    :wheels: Steering values aligned index-for-index with ``imgs``.
    :batch_size: Number of samples to draw. Defaults to
        ``params.batch_size``.
    :return: ``(xx, yy)`` lists of length ``batch_size``, where ``yy[k]`` is
        the steering value paired with ``xx[k]``.
    :raises ValueError: If the inputs differ in length, are empty, or hold
        fewer than ``batch_size`` samples.
    '''
    if batch_size is None:
        batch_size = params.batch_size
    n = len(imgs)
    if n != len(wheels):
        raise ValueError(
            'imgs and wheels differ in length: {} vs {}'.format(
                n, len(wheels)))
    if n == 0:
        raise ValueError('cannot sample a batch from an empty dataset')
    # random.sample draws distinct indices, so a batch never repeats a
    # sample. It raises ValueError itself when batch_size exceeds n.
    indices = random.sample(range(n), batch_size)
    return [imgs[i] for i in indices], [wheels[i] for i in indices]