-
Notifications
You must be signed in to change notification settings - Fork 1
/
data.py
91 lines (76 loc) · 3.47 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from __future__ import absolute_import
from __future__ import print_function
from future.standard_library import install_aliases
install_aliases()
import numpy as np
import os
import gzip
import struct
import array
import matplotlib.pyplot as plt
import matplotlib.image
from urllib.request import urlretrieve
def download(url, filename):
    """Fetch *url* into ``data/<filename>`` unless that file already exists.

    The ``data`` directory is created relative to the current working
    directory when it is missing.  The payload is first written to a sibling
    ``<filename>.part`` file and only renamed to its final name once the
    transfer has finished, so an interrupted download never leaves behind a
    truncated file that later calls would mistake for a complete one.

    Args:
        url: Location handed to ``urlretrieve``.
        filename: Name of the file to create inside ``data``.

    Raises:
        OSError: If the ``data`` directory cannot be created.
        Exception: Whatever ``urlretrieve`` raises when the transfer fails.
    """
    data_dir = 'data'
    try:
        os.makedirs(data_dir)
    except OSError:
        # "Already exists" is fine (this also covers losing a creation race
        # with another process); anything else, e.g. a permission error, is
        # a real failure and must propagate.
        if not os.path.isdir(data_dir):
            raise

    out_file = os.path.join(data_dir, filename)
    if os.path.isfile(out_file):
        return

    partial_file = out_file + '.part'
    try:
        urlretrieve(url, partial_file)
        os.rename(partial_file, out_file)
    finally:
        # After a successful rename the partial file is gone; if it is still
        # here the transfer (or the rename) failed, so clean it up.
        if os.path.exists(partial_file):
            os.remove(partial_file)
def mnist(base_url='http://yann.lecun.com/exdb/mnist/'):
    """Download (when missing) and parse the four gzipped MNIST IDX files.

    Each archive is fetched with ``download`` into the ``data`` directory and
    then decoded.  The header of every file is validated so that a wrong or
    truncated file raises instead of silently yielding bad arrays.

    Args:
        base_url: URL prefix that the four archive names are appended to.
            NOTE(review): the default host is kept for backward
            compatibility; if it is unreachable, pass a mirror that serves
            the same four file names.

    Returns:
        Tuple ``(train_images, train_labels, test_images, test_labels)``.
        Image arrays are ``uint8`` reshaped to ``(num_data, rows, cols)`` as
        given by each file header; label arrays are flat ``uint8``.

    Raises:
        ValueError: If a file's magic number is not the expected IDX value,
            or a label file holds a different number of labels than its
            header announces.
    """
    data_dir = 'data'    # must match the directory ``download`` writes to
    label_magic = 2049   # 0x00000801: unsigned-byte data, one dimension
    image_magic = 2051   # 0x00000803: unsigned-byte data, three dimensions

    def parse_labels(filename):
        # Header: big-endian magic number and item count, then one byte per label.
        with gzip.open(filename, 'rb') as fh:
            magic, num_data = struct.unpack(">II", fh.read(8))
            if magic != label_magic:
                raise ValueError('%s: bad label-file magic number %d (expected %d)'
                                 % (filename, magic, label_magic))
            labels = np.array(array.array("B", fh.read()), dtype=np.uint8)
        if labels.shape[0] != num_data:
            raise ValueError('%s: header announces %d labels but %d were read'
                             % (filename, num_data, labels.shape[0]))
        return labels

    def parse_images(filename):
        # Header: big-endian magic, image count, rows, cols; then raw pixels.
        with gzip.open(filename, 'rb') as fh:
            magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
            if magic != image_magic:
                raise ValueError('%s: bad image-file magic number %d (expected %d)'
                                 % (filename, magic, image_magic))
            pixels = np.array(array.array("B", fh.read()), dtype=np.uint8)
        # reshape raises if the payload size disagrees with the header.
        return pixels.reshape(num_data, rows, cols)

    train_images_name = 'train-images-idx3-ubyte.gz'
    train_labels_name = 'train-labels-idx1-ubyte.gz'
    test_images_name = 't10k-images-idx3-ubyte.gz'
    test_labels_name = 't10k-labels-idx1-ubyte.gz'

    for filename in (train_images_name, train_labels_name,
                     test_images_name, test_labels_name):
        download(base_url + filename, filename)

    train_images = parse_images(os.path.join(data_dir, train_images_name))
    train_labels = parse_labels(os.path.join(data_dir, train_labels_name))
    test_images = parse_images(os.path.join(data_dir, test_images_name))
    test_labels = parse_labels(os.path.join(data_dir, test_labels_name))

    return train_images, train_labels, test_images, test_labels
def load_mnist():
    """Load MNIST with flattened, rescaled images and one-hot labels.

    The arrays returned by ``mnist()`` are post-processed: every image is
    flattened to a single row and divided by 255.0, and every label vector is
    expanded into a 10-column integer indicator matrix.

    Returns:
        Tuple ``(N_data, train_images, train_labels, test_images,
        test_labels)`` where ``N_data`` is the number of training images.
    """
    num_classes = 10

    def flatten_trailing(batch):
        # Keep the first axis, collapse all remaining axes into one.
        return np.reshape(batch, (batch.shape[0], np.prod(batch.shape[1:])))

    def to_one_hot(labels, k):
        # Compare each label against 0..k-1 to get a row of indicators.
        matches = labels[:, None] == np.arange(k)[None, :]
        return matches.astype(int)

    raw_train_x, raw_train_y, raw_test_x, raw_test_y = mnist()

    train_images = flatten_trailing(raw_train_x) / 255.0
    test_images = flatten_trailing(raw_test_x) / 255.0
    train_labels = to_one_hot(raw_train_y, num_classes)
    test_labels = to_one_hot(raw_test_y, num_classes)

    return (train_images.shape[0], train_images, train_labels,
            test_images, test_labels)
def plot_images(images, ax, ims_per_row=5, padding=5, digit_dimensions=(28, 28),
                cmap=matplotlib.cm.binary, vmin=None, vmax=None):
    """Draw a padded grid of images on *ax*.

    Images should be a (N_images x pixels) matrix.  Each row is reshaped to
    ``digit_dimensions`` and pasted into one large canvas, ``ims_per_row``
    images per row, separated (and surrounded) by ``padding`` pixels filled
    with the smallest value found in ``images``.

    Args:
        images: 2-D array, one flattened image per row.
        ax: Axes-like object to draw on; its ``matshow`` is used.
        ims_per_row: Number of images placed in each grid row.
        padding: Width in pixels of the border around every image.
        digit_dimensions: ``(height, width)`` of a single image.
        cmap, vmin, vmax: Forwarded to ``ax.matshow``.

    Returns:
        The object returned by ``ax.matshow``.

    Raises:
        ValueError: Propagated from ``np.min`` when ``images`` is empty.
    """
    height, width = digit_dimensions[0], digit_dimensions[1]
    N_images = images.shape[0]
    N_rows = int(np.ceil(float(N_images) / ims_per_row))
    pad_value = np.min(images.ravel())
    concat_images = np.full(((height + padding) * N_rows + padding,
                             (width + padding) * ims_per_row + padding),
                            pad_value)
    for i in range(N_images):
        cur_image = np.reshape(images[i, :], digit_dimensions)
        row_ix, col_ix = divmod(i, ims_per_row)
        row_start = padding + (padding + height) * row_ix
        col_start = padding + (padding + width) * col_ix
        concat_images[row_start: row_start + height,
                      col_start: col_start + width] = cur_image
    cax = ax.matshow(concat_images, cmap=cmap, vmin=vmin, vmax=vmax)
    # Clear the ticks on the axes we were given; pyplot's xticks/yticks would
    # act on whatever axes happens to be current, which need not be ``ax``.
    ax.set_xticks([])
    ax.set_yticks([])
    return cax
def save_images(images, filename, **kwargs):
    """Render *images* as a grid with ``plot_images`` and save it to *filename*.

    Pyplot figure number 1 is fetched (or created), cleared and reused for
    the drawing.  The background patches of both the figure and the axes are
    hidden before the figure is written out.  Any extra keyword arguments
    are passed straight through to ``plot_images``.
    """
    figure = plt.figure(1)
    figure.clf()
    axes = figure.add_subplot(111)
    plot_images(images, axes, **kwargs)
    # Hide the figure and axes background patches.
    for background in (figure.patch, axes.patch):
        background.set_visible(False)
    plt.savefig(filename)