-
Notifications
You must be signed in to change notification settings - Fork 24
/
tf_data.py
41 lines (34 loc) · 1001 Bytes
/
tf_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import numpy as np
import cv2
import tensorflow as tf
def read_image(x):
    """Load a colour image from disk as a median-centred float32 array.

    The pixel values are shifted so that the image median lands on 127,
    clipped to [0, 255], and then divided by 255 so the result lies in
    [0, 1].

    Args:
        x: Path of the image file. ``bytes`` (as passed by the
            ``tf.numpy_function`` call in ``parse_data``) is decoded to
            ``str``; a ``str`` is used as-is.

    Returns:
        A float32 array holding the image decoded with
        ``cv2.IMREAD_COLOR`` (3 channels, in OpenCV's channel order),
        values in [0, 1]. The spatial size is whatever the file contains;
        no resizing is done here.

    Raises:
        OSError: If OpenCV could not read or decode the file.
    """
    path = x.decode() if isinstance(x, bytes) else x
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than with an opaque
        # TypeError in the arithmetic below.
        raise OSError("Could not read image file: {!r}".format(path))
    # Centre brightness on the median, then clamp back into the 8-bit range.
    image = np.clip(image - np.median(image) + 127, 0, 255)
    image = image / 255.0
    return image.astype(np.float32)
def read_mask(y):
    """Load a single-channel mask from disk as a float32 array in [0, 1].

    Args:
        y: Path of the mask file. ``bytes`` (as passed by the
            ``tf.numpy_function`` call in ``parse_data``) is decoded to
            ``str``; a ``str`` is used as-is.

    Returns:
        A float32 array of shape (height, width, 1): the file decoded with
        ``cv2.IMREAD_GRAYSCALE``, divided by 255, with a trailing channel
        axis appended. The spatial size is whatever the file contains.

    Raises:
        OSError: If OpenCV could not read or decode the file.
    """
    path = y.decode() if isinstance(y, bytes) else y
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread reports failure by returning None instead of raising;
        # surface the offending path instead of a TypeError on the division.
        raise OSError("Could not read mask file: {!r}".format(path))
    mask = (mask / 255.0).astype(np.float32)
    # Append a channel axis so the mask is (H, W, 1) rather than (H, W).
    return np.expand_dims(mask, axis=-1)
def _parse(x, y):
    """Load one (image, mask) pair from their file paths.

    ``x`` is handed to ``read_image`` and ``y`` to ``read_mask``; the two
    results are returned as a tuple, image first.
    """
    # The tuple is evaluated left to right, so the image is read before the mask.
    return read_image(x), read_mask(y)
def parse_data(x, y):
    """Run ``_parse`` through ``tf.numpy_function`` and declare output shapes.

    Args:
        x: Image path element passed on to ``_parse``.
        y: Mask path element passed on to ``_parse``.

    Returns:
        An ``(image, mask)`` pair of float32 tensors whose static shapes are
        declared as (256, 256, 3) and (256, 256, 1) respectively.
    """
    image, mask = tf.numpy_function(
        func=_parse,
        inp=[x, y],
        Tout=[tf.float32, tf.float32],
    )
    # set_shape only declares the static shape; nothing in this pipeline
    # resizes the data, so the files are presumably already 256x256 --
    # TODO confirm against the dataset.
    image.set_shape([256, 256, 3])
    mask.set_shape([256, 256, 1])
    return image, mask
def tf_dataset(x, y, batch=8):
    """Build a shuffled, batched, endlessly repeating ``tf.data`` pipeline.

    Args:
        x: Image entries, sliced element-wise together with ``y``.
        y: Mask entries, paired one-to-one with ``x``.
        batch: Number of (image, mask) pairs per batch. Defaults to 8.

    Returns:
        A ``tf.data.Dataset`` that slices ``(x, y)``, shuffles with a
        32-element buffer, maps each pair through ``parse_data``, groups
        the results into batches of ``batch`` and repeats indefinitely.
    """
    # Shuffling happens on the raw (x, y) elements, before they are parsed.
    return (
        tf.data.Dataset.from_tensor_slices((x, y))
        .shuffle(buffer_size=32)
        .map(map_func=parse_data)
        .batch(batch)
        .repeat()
    )