-
Notifications
You must be signed in to change notification settings - Fork 0
/
prepare_lrw_video.py
101 lines (79 loc) · 2.91 KB
/
prepare_lrw_video.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import cv2
from turbojpeg import TurboJPEG, TJPF_GRAY, TJSAMP_GRAY, TJFLAG_PROGRESSIVE
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import glob
import time
import os
from tqdm import tqdm
# Module-level TurboJPEG codec shared by every call to extract_opencv.
jpeg = TurboJPEG()


def extract_opencv(filename):
    """Read a video and return its frames as cropped, JPEG-encoded images.

    Every decoded frame is cropped to rows 115:211 and columns 79:175 (a
    96x96 region) and then encoded with ``jpeg.encode`` using TurboJPEG's
    default pixel format.

    Args:
        filename: path of a video file readable by ``cv2.VideoCapture``.

    Returns:
        A list holding one encoded frame (as returned by ``TurboJPEG.encode``)
        per decoded frame, in read order. The list is empty if the file
        cannot be opened.
    """
    video = []
    cap = cv2.VideoCapture(filename)
    try:
        while cap.isOpened():
            ret, frame = cap.read()  # OpenCV returns frames in BGR order
            if not ret:
                break
            # Fixed 96x96 crop. NOTE(review): the output directory name
            # (80_116_175_211) suggests these are 1-based ROI coords shifted
            # to 0-based here; confirm.
            roi = frame[115:211, 79:175]
            video.append(jpeg.encode(roi))
    finally:
        # Release the capture even if reading or encoding raised.
        cap.release()
    return video
# Output root for the per-clip .pkl files. Output paths mirror DATA_PATH's
# layout; see the DATA_PATH -> target_dir replacement in LRWDataset.
target_dir = '/data/zhangyk/data/lrw_roi_80_116_175_211_npy_gray_pkl_jpeg'
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(target_dir, exist_ok=True)
# Root of the source clips, globbed as DATA_PATH/<label>/*/*.mp4.
DATA_PATH = "/data/zhangyk/data/lipread_mp4"
class LRWDataset(Dataset):
    """Dataset whose indexing converts one clip and saves it as a ``.pkl``.

    ``__getitem__`` decodes the clip, builds a result dict and writes it with
    ``torch.save`` to the mirrored path under ``target_dir``. Iterating the
    dataset (e.g. through a DataLoader) therefore performs the conversion as
    a side effect.
    """

    def __init__(self):
        """Collect ``(mp4 path, label index)`` pairs and create output dirs.

        Labels are read one per line from ``../label_sorted.txt`` (relative
        to the working directory). A label's index is its 0-based line number.
        """
        with open('../label_sorted.txt') as myfile:
            self.labels = myfile.read().splitlines()

        self.list = []
        for (i, label) in enumerate(tqdm(self.labels)):
            files = sorted(glob.glob(os.path.join(DATA_PATH, label, '*', '*.mp4')))
            # Create each distinct output directory once, not once per file.
            for savepath in {os.path.dirname(self._save_name(f)) for f in files}:
                os.makedirs(savepath, exist_ok=True)
            self.list += [(file, i) for file in files]

    @staticmethod
    def _save_name(filename):
        """Map an input ``.mp4`` path to its output ``.pkl`` path under target_dir."""
        return filename.replace(DATA_PATH, target_dir).replace('.mp4', '.pkl')

    def __getitem__(self, idx):
        """Convert sample ``idx``, save it to disk, and return the result dict.

        The dict has keys ``'video'`` (list of encoded frames from
        ``extract_opencv``), ``'label'`` (int label index) and
        ``'duration'`` (boolean per-frame mask from ``load_duration``, read
        from the ``.txt`` file next to the clip).
        """
        filename, label = self.list[idx]
        result = {
            'video': extract_opencv(filename),
            'label': int(label),
            # np.bool was removed in NumPy 1.24. The builtin bool gives the
            # same dtype on every NumPy version.
            'duration': self.load_duration(filename.replace('.mp4', '.txt')).astype(bool),
        }
        torch.save(result, self._save_name(filename))
        return result

    def __len__(self):
        return len(self.list)

    def load_duration(self, file, num_frames=29, fps=25):
        """Build a per-frame mask of the labelled span from a metadata file.

        The span is assumed to be centred on the clip: frames within
        ``duration / 2`` seconds of the clip midpoint are set to 1.0.

        Args:
            file: path of a text file containing a line such as
                ``Duration: 0.53 seconds``. If several lines mention
                ``Duration``, the last one wins.
            num_frames: number of frames in the clip (default 29).
            fps: frame rate used to convert seconds to frames (default 25).

        Returns:
            A float ``np.ndarray`` of shape ``(num_frames,)`` with 1.0 on the
            span's frames and 0.0 elsewhere.

        Raises:
            ValueError: if no line in ``file`` contains ``'Duration'``.
        """
        duration = None
        with open(file, 'r') as f:
            for line in f:
                if 'Duration' in line:
                    # split() tolerates tabs and repeated spaces.
                    duration = float(line.split()[1])
        if duration is None:
            raise ValueError(f"no 'Duration' line found in {file}")

        tensor = np.zeros(num_frames)
        mid = num_frames / 2
        half_span = duration / 2 * fps
        # Clamp start: a duration longer than the clip would otherwise give a
        # negative start, which NumPy slicing counts from the end. That would
        # yield an almost-empty mask instead of a full one.
        start = max(0, int(mid - half_span))
        end = min(num_frames, int(mid + half_span))
        tensor[start:end] = 1.0
        return tensor
if __name__ == '__main__':
    # Samples are converted and written to disk inside LRWDataset.__getitem__.
    # Batches are consumed only to drive the workers, so collate_fn=list skips
    # default_collate's stacking. default_collate needs equal-length frame
    # lists within a batch and would otherwise be wasted work.
    loader = DataLoader(LRWDataset(),
                        batch_size=96,
                        num_workers=16,
                        shuffle=False,
                        drop_last=False,
                        collate_fn=list)
    tic = time.time()
    for i, _ in enumerate(loader):
        toc = time.time()
        done = i + 1
        # Average seconds per finished batch x batches still remaining -> hours.
        eta = (toc - tic) / done * (len(loader) - done) / 3600.0
        print(f'eta:{eta:.5f}')