import numpy as np
import os


def mnist(dataset='mnist.pkl.gz', onehot=True):
    import six.moves.cPickle as pickle
    import gzip

    # Download the MNIST dataset if it is not present.
    data_dir, data_file = os.path.split(dataset)
    if data_dir == "" and not os.path.isfile(dataset):
        # Check if dataset is in the data directory.
        new_path = os.path.join(
            os.path.split(__file__)[0],
            dataset
        )
        if os.path.isfile(new_path) or data_file == 'mnist.pkl.gz':
            dataset = new_path

    if (not os.path.isfile(dataset)) and data_file == 'mnist.pkl.gz':
        from six.moves import urllib
        origin = (
            'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'
        )
        print('Downloading data from %s' % origin)
        urllib.request.urlretrieve(origin, dataset)

    # Load the dataset.
    with gzip.open(dataset, 'rb') as f:
        try:
            # Python 3: the pickle was written by Python 2, so decode as latin1.
            train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
        except TypeError:
            # Python 2: pickle.load takes no encoding argument.
            train_set, valid_set, test_set = pickle.load(f)

    # train_set, valid_set and test_set are tuples (input, target).
    # input is a 2-D numpy.ndarray of np.float32 where each row is one example;
    # target is a 1-D numpy.ndarray of np.int64 with one label per row of input,
    # giving the target of the example with the same index.
    if onehot:
        train_set = (train_set[0], convert_to_one_hot(train_set[1], 10))
        valid_set = (valid_set[0], convert_to_one_hot(valid_set[1], 10))
        test_set = (test_set[0], convert_to_one_hot(test_set[1], 10))
    return train_set, valid_set, test_set


def cifar10(directory='CIFAR_10', onehot=True):
    import six.moves.cPickle as pickle
    file_lists = [os.path.join(directory, 'cifar-10-batches-py', 'data_batch_%d' % i)
                  for i in range(1, 6)] + \
                 [os.path.join(directory, 'cifar-10-batches-py', 'test_batch')]

    # Download and extract the archive if any batch file is missing.
    if not all([os.path.exists(fl) for fl in file_lists]):
        from tqdm import tqdm
        from six.moves import urllib
        import tarfile
        filename = "cifar-10-python.tar.gz"
        if not os.path.exists(filename):
            def gen_bar_updater():
                pbar = tqdm(total=None)

                def bar_update(count, block_size, total_size):
                    if pbar.total is None and total_size:
                        pbar.total = total_size
                    progress_bytes = count * block_size
                    pbar.update(progress_bytes - pbar.n)
                return bar_update

            print('Downloading CIFAR 10 dataset...')
            url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
            urllib.request.urlretrieve(
                url, filename, reporthook=gen_bar_updater())
        with tarfile.open(filename, 'r:gz') as tar:
            tar.extractall(path=directory)

    # The first five files are training batches, the last one is the test batch.
    images, labels = [], []
    for batch_file in file_lists[:5]:
        with open(batch_file, 'rb') as fo:
            batch = pickle.load(fo, encoding='latin1')
            for i in range(len(batch["labels"])):
                image = batch["data"][i].astype(float)
                images.append(image)
            labels += batch["labels"]
    train_images = np.array(images, dtype='float')
    train_labels = np.array(labels, dtype='int')

    images, labels = [], []
    for batch_file in file_lists[5:]:
        with open(batch_file, 'rb') as fo:
            batch = pickle.load(fo, encoding='latin1')
            for i in range(len(batch["labels"])):
                image = batch["data"][i].astype(float)
                images.append(image)
            labels += batch["labels"]
    test_images = np.array(images, dtype='float')
    test_labels = np.array(labels, dtype='int')

    if onehot:
        train_labels = convert_to_one_hot(train_labels, 10)
        test_labels = convert_to_one_hot(test_labels, 10)
    return train_images, train_labels, test_images, test_labels


def cifar100(directory='CIFAR_100', onehot=True):
    import six.moves.cPickle as pickle
    file_lists = [os.path.join(directory, 'cifar-100-python', 'train'),
                  os.path.join(directory, 'cifar-100-python', 'test')]

    # Download and extract the archive if either file is missing.
    if not all([os.path.exists(fl) for fl in file_lists]):
        from tqdm import tqdm
        from six.moves import urllib
        import tarfile
        filename = "cifar-100-python.tar.gz"
        if not os.path.exists(filename):
            def gen_bar_updater():
                pbar = tqdm(total=None)

                def bar_update(count, block_size, total_size):
                    if pbar.total is None and total_size:
                        pbar.total = total_size
                    progress_bytes = count * block_size
                    pbar.update(progress_bytes - pbar.n)
                return bar_update

            print('Downloading CIFAR 100 dataset...')
            url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
            urllib.request.urlretrieve(
                url, filename, reporthook=gen_bar_updater())
        with tarfile.open(filename, 'r:gz') as tar:
            tar.extractall(path=directory)

    with open(file_lists[0], 'rb') as input_file:
        train_file = pickle.load(input_file, encoding='latin1')
        train_images = train_file['data']
        train_labels = train_file['fine_labels']
        train_images = np.array(train_images, dtype='float').reshape(
            train_images.shape[0], 3, 32, 32)
        train_labels = np.array(train_labels, dtype='int')

    with open(file_lists[1], 'rb') as input_file:
        test_file = pickle.load(input_file, encoding='latin1')
        test_images = test_file['data']
        test_labels = test_file['fine_labels']
        test_images = np.array(test_images, dtype='float').reshape(
            test_images.shape[0], 3, 32, 32)
        test_labels = np.array(test_labels, dtype='int')

    if onehot:
        train_labels = convert_to_one_hot(train_labels, 100)
        test_labels = convert_to_one_hot(test_labels, 100)
    return train_images, train_labels, test_images, test_labels


def normalize_cifar(num_class=10, onehot=True):
    if num_class == 10:
        x_train, y_train, x_test, y_test = cifar10(onehot=onehot)
    elif num_class == 100:
        x_train, y_train, x_test, y_test = cifar100(onehot=onehot)
    else:
        raise NotImplementedError
    x_train = x_train.reshape((-1, 3, 32, 32)).astype('float32')
    x_test = x_test.reshape((-1, 3, 32, 32)).astype('float32')
    # Standardize each channel (NCHW layout) to zero mean and unit variance.
    for c in range(3):
        x_train[:, c, :, :] = (
            x_train[:, c, :, :] - np.mean(x_train[:, c, :, :])) / np.std(x_train[:, c, :, :])
        x_test[:, c, :, :] = (
            x_test[:, c, :, :] - np.mean(x_test[:, c, :, :])) / np.std(x_test[:, c, :, :])
    return x_train, y_train, x_test, y_test


def tf_normalize_cifar(num_class=10, onehot=True):
    if num_class == 10:
        x_train, y_train, x_test, y_test = cifar10(onehot=onehot)
    elif num_class == 100:
        x_train, y_train, x_test, y_test = cifar100(onehot=onehot)
    else:
        raise NotImplementedError
    x_train = x_train.reshape((-1, 3, 32, 32))
    x_test = x_test.reshape((-1, 3, 32, 32))
    # Convert NCHW to NHWC (TensorFlow layout), then standardize each channel.
    x_train = x_train.transpose([0, 2, 3, 1]).astype('float32')
    x_test = x_test.transpose([0, 2, 3, 1]).astype('float32')
    for c in range(3):
        x_train[:, :, :, c] = (
            x_train[:, :, :, c] - np.mean(x_train[:, :, :, c])) / np.std(x_train[:, :, :, c])
        x_test[:, :, :, c] = (
            x_test[:, :, :, c] - np.mean(x_test[:, :, :, c])) / np.std(x_test[:, :, :, c])
    return x_train, y_train, x_test, y_test


def convert_to_one_hot(vals, max_val=0):
    """Helper method to convert a label array to a one-hot array."""
    if max_val == 0:
        max_val = vals.max() + 1
    one_hot_vals = np.zeros((vals.size, max_val))
    one_hot_vals[np.arange(vals.size), vals] = 1
    return one_hot_vals
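
# Illustrative example (not part of the original file): with the helper above,
# convert_to_one_hot(np.array([0, 2, 1]), 3) produces
#     [[1., 0., 0.],
#      [0., 0., 1.],
#      [0., 1., 0.]]
# i.e. one row per label with a single 1 in the column of that label.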


########################
# Not in use currently #
########################


def data_augmentation(images, mode='train', flip=False,
                      crop=False, crop_shape=(24, 24, 3), whiten=False,
                      noise=False, noise_mean=0, noise_std=0.01):
    # These helpers are module-level functions, not methods, so no 'self'.
    if crop:
        if mode == 'train':
            images = _image_crop(images, shape=crop_shape)
        elif mode == 'test':
            images = _image_crop_test(images, shape=crop_shape)
    if flip:
        images = _image_flip(images)
    if whiten:
        images = _image_whitening(images)
    if noise:
        images = _image_noise(images, mean=noise_mean, std=noise_std)
    return images


def _image_crop(images, shape):
    # Pad each image by 4 pixels on every side, then take a random crop.
    new_images = []
    for i in range(images.shape[0]):
        old_image = images[i, :, :, :]
        old_image = np.pad(old_image, [[4, 4], [4, 4], [0, 0]], 'constant')
        left = np.random.randint(old_image.shape[0] - shape[0] + 1)
        top = np.random.randint(old_image.shape[1] - shape[1] + 1)
        new_image = old_image[left: left + shape[0], top: top + shape[1], :]
        new_images.append(new_image)
    return np.array(new_images)


def _image_crop_test(images, shape):
    # Pad each image by 4 pixels on every side, then take a center crop.
    new_images = []
    for i in range(images.shape[0]):
        old_image = images[i, :, :, :]
        old_image = np.pad(old_image, [[4, 4], [4, 4], [0, 0]], 'constant')
        left = int((old_image.shape[0] - shape[0]) / 2)
        top = int((old_image.shape[1] - shape[1]) / 2)
        new_image = old_image[left: left + shape[0], top: top + shape[1], :]
        new_images.append(new_image)
    return np.array(new_images)


def _image_flip(images):
    # Flip each image horizontally with probability 0.5.
    import cv2
    for i in range(images.shape[0]):
        old_image = images[i, :, :, :]
        if np.random.random() < 0.5:
            new_image = cv2.flip(old_image, 1)
        else:
            new_image = old_image
        images[i, :, :, :] = new_image
    return images


def _image_whitening(images):
    # Standardize each image to zero mean and unit variance.
    for i in range(images.shape[0]):
        old_image = images[i, :, :, :]
        new_image = (old_image - np.mean(old_image)) / np.std(old_image)
        images[i, :, :, :] = new_image
    return images


def _image_noise(images, mean=0, std=0.01):
    # Add i.i.d. Gaussian noise to every pixel of every image.
    for i in range(images.shape[0]):
        old_image = images[i, :, :, :]
        new_image = old_image + np.random.normal(mean, std, size=old_image.shape)
        images[i, :, :, :] = new_image
    return images
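

# Minimal usage sketch (illustrative, not part of the original module): shows one
# way the loaders above might be called. Running it downloads MNIST and CIFAR-10
# into the working directory if they are not already present.
if __name__ == '__main__':
    train_set, valid_set, test_set = mnist()
    print('MNIST train:', train_set[0].shape, train_set[1].shape)

    x_train, y_train, x_test, y_test = normalize_cifar(num_class=10)
    print('CIFAR-10 train (NCHW, per-channel standardized):', x_train.shape, y_train.shape)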