This repository has been archived by the owner on Nov 3, 2020. It is now read-only.

Deprecated MNIST data loader #4

Open
zhanghuimeng opened this issue Oct 9, 2018 · 2 comments

Comments

@zhanghuimeng

In the MLP example, this line

mnist = tf.contrib.learn.datasets.load_dataset("mnist")

prints a flood of deprecation warnings:

WARNING:tensorflow:From /home/zhanghuimeng/Documents/learnTensorFlow/simple_introduction/multilayer_perceptron.py:9: load_dataset (from tensorflow.contrib.learn.python.learn.datasets) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data.
...

It also fails to download anything. (The reason might be...) In the end, you may have to download MNIST by hand. (see this)

A better (not deprecated) alternative is:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

But it could not download anything either. In the end I had to download mnist.npz from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz and load it with NumPy.
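For reference, loading the downloaded archive with NumPy could look roughly like the sketch below. It assumes the file was saved as mnist.npz in the working directory and that the archive holds the arrays x_train, y_train, x_test and y_test (the keys tf.keras.datasets.mnist.load_data() reads). load_mnist_npz is only an illustrative helper name, not part of any library.

```python
import numpy as np


def load_mnist_npz(path="mnist.npz"):
    """Load MNIST from a local .npz archive (hypothetical helper).

    Returns the same nested tuple layout as
    tf.keras.datasets.mnist.load_data().
    """
    # np.load on an .npz file returns a lazy, dict-like archive;
    # the context manager closes the file once the arrays are read.
    with np.load(path) as f:
        x_train, y_train = f["x_train"], f["y_train"]
        x_test, y_test = f["x_test"], f["y_test"]
    return (x_train, y_train), (x_test, y_test)


# Usage (assumes mnist.npz was downloaded by hand beforehand):
# (x_train, y_train), (x_test, y_test) = load_mnist_npz("mnist.npz")
```

This avoids any network access at load time, so it works even when the built-in downloader fails.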

There might be a better alternative, but I still suggest using something that is not deprecated.

@Ubpa

Ubpa commented Jan 17, 2019

You can put mnist.npz in user-root/.keras/datasets/ (the .keras/datasets/ folder under your home directory).
tf.keras.datasets.mnist.load_data() will then try to load it from there instead of downloading it.
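A small sketch of that step, assuming the default Keras cache location (the .keras/datasets folder under the home directory, with no custom KERAS_HOME set) and a manually downloaded mnist.npz. install_mnist_cache and its parameters are hypothetical names used for illustration only.

```python
import os
import shutil


def install_mnist_cache(src="mnist.npz", cache_dir=None):
    """Copy a manually downloaded mnist.npz into the Keras dataset cache.

    Hypothetical helper; returns the destination path.
    """
    if cache_dir is None:
        # Assumption: the default cache is <home>/.keras/datasets
        cache_dir = os.path.join(os.path.expanduser("~"), ".keras", "datasets")
    # Create the cache folder if Keras has never been used on this machine.
    os.makedirs(cache_dir, exist_ok=True)
    dst = os.path.join(cache_dir, "mnist.npz")
    shutil.copyfile(src, dst)
    return dst


# Usage (assumes mnist.npz is in the current directory):
# install_mnist_cache("mnist.npz")
```

After the copy, tf.keras.datasets.mnist.load_data() should pick up the cached file, provided it is the intact original archive. As far as I know, a truncated or modified file may trigger another download attempt.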

@Ubpa

Ubpa commented Jan 17, 2019

import numpy as np
import tensorflow as tf


class DataLoader:
    def __init__(self):
        # Reads mnist.npz from the Keras cache, downloading it first if needed.
        (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
        # Flatten each 28x28 image to a 784-vector and scale pixels to [0, 1].
        train_images = train_images.reshape(-1, 28 * 28) / 255.0
        test_images = test_images.reshape(-1, 28 * 28) / 255.0
        self.train_data = train_images                     # np.array [60000, 784]
        self.train_labels = train_labels.astype(np.int32)  # np.array [60000] of int32
        self.eval_data = test_images                       # np.array [10000, 784]
        self.eval_labels = test_labels.astype(np.int32)    # np.array [10000] of int32

    def get_batch(self, batch_size):
        # Draw batch_size random training indices (sampling with replacement).
        index = np.random.randint(0, np.shape(self.train_data)[0], batch_size)
        return self.train_data[index, :], self.train_labels[index]
