|
- #! /usr/bin/python
- # -*- coding: utf-8 -*-
-
- import os
- import pickle
- import sys
-
- import numpy as np
-
- from tensorlayer import logging
- from tensorlayer.files.utils import maybe_download_and_extract
-
- __all__ = ['load_cifar10_dataset']
-
-
def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='data', plotable=False):
    """Load CIFAR-10 dataset.

    It consists of 60000 32x32 colour images in 10 classes, with
    6000 images per class. There are 50000 training images and 10000 test images.

    The dataset is divided into five training batches and one test batch, each with
    10000 images. The test batch contains exactly 1000 randomly-selected images from
    each class. The training batches contain the remaining images in random order,
    but some training batches may contain more images from one class than another.
    Between them, the training batches contain exactly 5000 images from each class.

    Parameters
    ----------
    shape : tuple
        The shape of the returned images, e.g. (-1, 3, 32, 32) and (-1, 32, 32, 3).
        A list with the same values is accepted as well. Any other shape is
        passed to ``reshape`` unchanged (no channel re-ordering is performed).
    path : str
        The path that the data is downloaded to, defaults is ``data/cifar10/``.
    plotable : boolean
        Whether to plot some image examples, False as default.
        Only the two shapes listed above can be plotted.

    Returns
    -------
    X_train, y_train, X_test, y_test : tuple of numpy.ndarray
        Images are returned as ``float32`` arrays reshaped to ``shape``;
        labels are returned as ``int32`` arrays.

    Raises
    ------
    RuntimeError
        If the running Python major version is neither 2 nor 3.
    Exception
        If ``plotable`` is True and ``shape`` is not one of the two plottable layouts.

    Examples
    --------
    >>> X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3))

    References
    ----------
    - `CIFAR website <https://www.cs.toronto.edu/~kriz/cifar.html>`__
    - `Data download link <https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz>`__
    - `<https://teratail.com/questions/28932>`__

    """
    channels_first = (-1, 3, 32, 32)
    channels_last = (-1, 32, 32, 3)

    # Accept lists as well as tuples: a list never compares equal to a tuple,
    # which previously sent e.g. [-1, 32, 32, 3] down the plain reshape branch
    # and produced channel-scrambled images.
    layout = tuple(shape) if isinstance(shape, (list, tuple)) else shape

    path = os.path.join(path, 'cifar10')
    logging.info("Load or Download cifar10 > {}".format(path))

    def unpickle(file_path):
        """Load one pickled CIFAR batch file and return the stored object."""
        major = sys.version_info.major
        if major not in (2, 3):
            raise RuntimeError("Sys Version Unsupported")
        # The batches were pickled under Python 2; Python 3 needs an explicit
        # encoding to decode them.
        load_kwargs = {} if major == 2 else {'encoding': 'latin-1'}
        # NOTE(security): pickle executes arbitrary code from the file. The
        # archive comes from the download URL below; only use trusted mirrors.
        with open(file_path, 'rb') as fp:
            return pickle.load(fp, **load_kwargs)

    filename = 'cifar-10-python.tar.gz'
    url = 'https://www.cs.toronto.edu/~kriz/'
    # Download and uncompress file
    maybe_download_and_extract(filename, path, url, extract=True)

    # Unpickle the five training batches; stack the image rows once at the end
    # instead of re-copying the growing array on every iteration.
    batch_dir = os.path.join(path, 'cifar-10-batches-py')
    train_chunks = []
    train_labels = []
    for i in range(1, 6):
        batch = unpickle(os.path.join(batch_dir, "data_batch_{}".format(i)))
        train_chunks.append(batch['data'])
        train_labels += batch['labels']
    X_train = np.vstack(train_chunks)
    y_train = np.array(train_labels)

    test_batch = unpickle(os.path.join(batch_dir, "test_batch"))
    X_test = test_batch['data']
    y_test = np.array(test_batch['labels'])

    if layout == channels_last:
        # Fortran-order reshape followed by swapping the two spatial axes
        # converts the flat per-channel rows into channels-last images.
        X_test = np.transpose(X_test.reshape(layout, order='F'), (0, 2, 1, 3))
        X_train = np.transpose(X_train.reshape(layout, order='F'), (0, 2, 1, 3))
    else:
        # Covers (-1, 3, 32, 32) and any custom shape: plain C-order reshape.
        X_test = X_test.reshape(shape)
        X_train = X_train.reshape(shape)

    if plotable:
        logging.info('\nCIFAR-10')
        import matplotlib.pyplot as plt

        # Fail before any figure is created when the layout cannot be shown.
        if layout != channels_first and layout != channels_last:
            raise Exception("Do not support the given 'shape' to plot the image examples")

        fig = plt.figure(1)

        # ``.format`` is required here: ``'%s' % some_tuple`` raises TypeError
        # whenever the shape tuple has more than one element.
        logging.info('Shape of a training image: X_train[0] {}'.format(X_train[0].shape))

        plt.ion()  # interactive mode
        # 10 x 10 grid showing the first 100 training images.
        for idx in range(100):
            ax = fig.add_subplot(10, 10, idx + 1)
            image = X_train[idx]
            if layout == channels_first:
                # imshow expects height x width x channels.
                image = np.transpose(image, (1, 2, 0))
            ax.imshow(image, interpolation='nearest')
            ax.xaxis.set_major_locator(plt.NullLocator())
            ax.yaxis.set_major_locator(plt.NullLocator())
        plt.draw()  # interactive mode
        plt.pause(3)  # interactive mode

        logging.info("X_train: {}".format(X_train.shape))
        logging.info("y_train: {}".format(y_train.shape))
        logging.info("X_test: {}".format(X_test.shape))
        logging.info("y_test: {}".format(y_test.shape))

    X_train = np.asarray(X_train, dtype=np.float32)
    X_test = np.asarray(X_test, dtype=np.float32)
    y_train = np.asarray(y_train, dtype=np.int32)
    y_test = np.asarray(y_test, dtype=np.int32)

    return X_train, y_train, X_test, y_test
|