You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

backend.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. import collections
  2. import os
  3. import time
  4. import os
  5. import matplotlib.pyplot as plt
  6. import numpy as np
  7. import nn
# Module-wide switch read by the dataset classes below: when truthy they open
# matplotlib figures (or print progress) while batches are being consumed.
use_graphics = True
  9. def maybe_sleep_and_close(seconds):
  10. if use_graphics and plt.get_fignums():
  11. time.sleep(seconds)
  12. for fignum in plt.get_fignums():
  13. fig = plt.figure(fignum)
  14. plt.close(fig)
  15. try:
  16. # This raises a TclError on some Windows machines
  17. fig.canvas.start_event_loop(1e-3)
  18. except:
  19. pass
  20. def get_data_path(filename):
  21. path = os.path.join(
  22. os.path.dirname(__file__), os.pardir, "data", filename)
  23. if not os.path.exists(path):
  24. path = os.path.join(
  25. os.path.dirname(__file__), "data", filename)
  26. if not os.path.exists(path):
  27. path = os.path.join(
  28. os.path.dirname(__file__), filename)
  29. if not os.path.exists(path):
  30. raise Exception("Could not find data file: {}".format(filename))
  31. return path
  32. class Dataset(object):
  33. def __init__(self, x, y):
  34. assert isinstance(x, np.ndarray)
  35. assert isinstance(y, np.ndarray)
  36. assert np.issubdtype(x.dtype, np.floating)
  37. assert np.issubdtype(y.dtype, np.floating)
  38. assert x.ndim == 2
  39. assert y.ndim == 2
  40. assert x.shape[0] == y.shape[0]
  41. self.x = x
  42. self.y = y
  43. def iterate_once(self, batch_size):
  44. assert isinstance(batch_size, int) and batch_size > 0, (
  45. "Batch size should be a positive integer, got {!r}".format(
  46. batch_size))
  47. assert self.x.shape[0] % batch_size == 0, (
  48. "Dataset size {:d} is not divisible by batch size {:d}".format(
  49. self.x.shape[0], batch_size))
  50. index = 0
  51. while index < self.x.shape[0]:
  52. x = self.x[index:index + batch_size]
  53. y = self.y[index:index + batch_size]
  54. yield nn.Constant(x), nn.Constant(y)
  55. index += batch_size
  56. def iterate_forever(self, batch_size):
  57. while True:
  58. yield from self.iterate_once(batch_size)
  59. def get_validation_accuracy(self):
  60. raise NotImplementedError(
  61. "No validation data is available for this dataset. "
  62. "In this assignment, only the Digit Classification and Language "
  63. "Identification datasets have validation data.")
  64. class PerceptronDataset(Dataset):
  65. def __init__(self, model):
  66. points = 500
  67. x = np.hstack([np.random.randn(points, 2), np.ones((points, 1))])
  68. y = np.where(x[:, 0] + 2 * x[:, 1] - 1 >= 0, 1.0, -1.0)
  69. super().__init__(x, np.expand_dims(y, axis=1))
  70. self.model = model
  71. self.epoch = 0
  72. if use_graphics:
  73. fig, ax = plt.subplots(1, 1)
  74. limits = np.array([-3.0, 3.0])
  75. ax.set_xlim(limits)
  76. ax.set_ylim(limits)
  77. positive = ax.scatter(*x[y == 1, :-1].T, color="red", marker="+")
  78. negative = ax.scatter(*x[y == -1, :-1].T, color="blue", marker="_")
  79. line, = ax.plot([], [], color="black")
  80. text = ax.text(0.03, 0.97, "", transform=ax.transAxes, va="top")
  81. ax.legend([positive, negative], [1, -1])
  82. plt.show(block=False)
  83. self.fig = fig
  84. self.limits = limits
  85. self.line = line
  86. self.text = text
  87. self.last_update = time.time()
  88. def iterate_once(self, batch_size):
  89. self.epoch += 1
  90. for i, (x, y) in enumerate(super().iterate_once(batch_size)):
  91. yield x, y
  92. if use_graphics and time.time() - self.last_update > 0.01:
  93. w = self.model.get_weights().data.flatten()
  94. limits = self.limits
  95. if w[1] != 0:
  96. self.line.set_data(limits, (-w[0] * limits - w[2]) / w[1])
  97. elif w[0] != 0:
  98. self.line.set_data(np.full(2, -w[2] / w[0]), limits)
  99. else:
  100. self.line.set_data([], [])
  101. self.text.set_text(
  102. "epoch: {:,}\npoint: {:,}/{:,}\nweights: {}".format(
  103. self.epoch, i * batch_size + 1, len(self.x), w))
  104. self.fig.canvas.draw_idle()
  105. self.fig.canvas.start_event_loop(1e-3)
  106. self.last_update = time.time()
  107. class RegressionDataset(Dataset):
  108. def __init__(self, model):
  109. x = np.expand_dims(np.linspace(-2 * np.pi, 2 * np.pi, num=200), axis=1)
  110. np.random.RandomState(0).shuffle(x)
  111. self.argsort_x = np.argsort(x.flatten())
  112. y = np.sin(x)
  113. super().__init__(x, y)
  114. self.model = model
  115. self.processed = 0
  116. if use_graphics:
  117. fig, ax = plt.subplots(1, 1)
  118. ax.set_xlim(-2 * np.pi, 2 * np.pi)
  119. ax.set_ylim(-1.4, 1.4)
  120. real, = ax.plot(x[self.argsort_x], y[self.argsort_x], color="blue")
  121. learned, = ax.plot([], [], color="red")
  122. text = ax.text(0.03, 0.97, "", transform=ax.transAxes, va="top")
  123. ax.legend([real, learned], ["real", "learned"])
  124. plt.show(block=False)
  125. self.fig = fig
  126. self.learned = learned
  127. self.text = text
  128. self.last_update = time.time()
  129. def iterate_once(self, batch_size):
  130. for x, y in super().iterate_once(batch_size):
  131. yield x, y
  132. self.processed += batch_size
  133. if use_graphics and time.time() - self.last_update > 0.1:
  134. predicted = self.model.run(nn.Constant(self.x)).data
  135. loss = self.model.get_loss(
  136. nn.Constant(self.x), nn.Constant(self.y)).data
  137. self.learned.set_data(self.x[self.argsort_x], predicted[self.argsort_x])
  138. self.text.set_text("processed: {:,}\nloss: {:.6f}".format(
  139. self.processed, loss))
  140. self.fig.canvas.draw_idle()
  141. self.fig.canvas.start_event_loop(1e-3)
  142. self.last_update = time.time()
  143. class DigitClassificationDataset(Dataset):
  144. def __init__(self, model):
  145. mnist_path = get_data_path("mnist.npz")
  146. with np.load(mnist_path) as data:
  147. train_images = data["train_images"]
  148. train_labels = data["train_labels"]
  149. test_images = data["test_images"]
  150. test_labels = data["test_labels"]
  151. assert len(train_images) == len(train_labels) == 60000
  152. assert len(test_images) == len(test_labels) == 10000
  153. self.dev_images = test_images[0::2]
  154. self.dev_labels = test_labels[0::2]
  155. self.test_images = test_images[1::2]
  156. self.test_labels = test_labels[1::2]
  157. train_labels_one_hot = np.zeros((len(train_images), 10))
  158. train_labels_one_hot[range(len(train_images)), train_labels] = 1
  159. super().__init__(train_images, train_labels_one_hot)
  160. self.model = model
  161. self.epoch = 0
  162. if use_graphics:
  163. width = 20 # Width of each row expressed as a multiple of image width
  164. samples = 100 # Number of images to display per label
  165. fig = plt.figure()
  166. ax = {}
  167. images = collections.defaultdict(list)
  168. texts = collections.defaultdict(list)
  169. for i in reversed(range(10)):
  170. ax[i] = plt.subplot2grid((30, 1), (3 * i, 0), 2, 1,
  171. sharex=ax.get(9))
  172. plt.setp(ax[i].get_xticklabels(), visible=i == 9)
  173. ax[i].set_yticks([])
  174. ax[i].text(-0.03, 0.5, i, transform=ax[i].transAxes,
  175. va="center")
  176. ax[i].set_xlim(0, 28 * width)
  177. ax[i].set_ylim(0, 28)
  178. for j in range(samples):
  179. images[i].append(ax[i].imshow(
  180. np.zeros((28, 28)), vmin=0, vmax=1, cmap="Greens",
  181. alpha=0.3))
  182. texts[i].append(ax[i].text(
  183. 0, 0, "", ha="center", va="top", fontsize="smaller"))
  184. ax[9].set_xticks(np.linspace(0, 28 * width, 11))
  185. ax[9].set_xticklabels(
  186. ["{:.1f}".format(num) for num in np.linspace(0, 1, 11)])
  187. ax[9].tick_params(axis="x", pad=16)
  188. ax[9].set_xlabel("Probability of Correct Label")
  189. status = ax[0].text(
  190. 0.5, 1.5, "", transform=ax[0].transAxes, ha="center",
  191. va="bottom")
  192. plt.show(block=False)
  193. self.width = width
  194. self.samples = samples
  195. self.fig = fig
  196. self.images = images
  197. self.texts = texts
  198. self.status = status
  199. self.last_update = time.time()
  200. def iterate_once(self, batch_size):
  201. self.epoch += 1
  202. for i, (x, y) in enumerate(super().iterate_once(batch_size)):
  203. yield x, y
  204. if use_graphics and time.time() - self.last_update > 1:
  205. dev_logits = self.model.run(nn.Constant(self.dev_images)).data
  206. dev_predicted = np.argmax(dev_logits, axis=1)
  207. dev_probs = np.exp(nn.SoftmaxLoss.log_softmax(dev_logits))
  208. dev_accuracy = np.mean(dev_predicted == self.dev_labels)
  209. self.status.set_text(
  210. "epoch: {:d}, batch: {:d}/{:d}, validation accuracy: "
  211. "{:.2%}".format(
  212. self.epoch, i, len(self.x) // batch_size, dev_accuracy))
  213. for i in range(10):
  214. predicted = dev_predicted[self.dev_labels == i]
  215. probs = dev_probs[self.dev_labels == i][:, i]
  216. linspace = np.linspace(
  217. 0, len(probs) - 1, self.samples).astype(int)
  218. indices = probs.argsort()[linspace]
  219. for j, (prob, image) in enumerate(zip(
  220. probs[indices],
  221. self.dev_images[self.dev_labels == i][indices])):
  222. self.images[i][j].set_data(image.reshape((28, 28)))
  223. left = prob * (self.width - 1) * 28
  224. if predicted[indices[j]] == i:
  225. self.images[i][j].set_cmap("Greens")
  226. self.texts[i][j].set_text("")
  227. else:
  228. self.images[i][j].set_cmap("Reds")
  229. self.texts[i][j].set_text(predicted[indices[j]])
  230. self.texts[i][j].set_x(left + 14)
  231. self.images[i][j].set_extent([left, left + 28, 0, 28])
  232. self.fig.canvas.draw_idle()
  233. self.fig.canvas.start_event_loop(1e-3)
  234. self.last_update = time.time()
  235. def get_validation_accuracy(self):
  236. # print(self.dev_images[:2].tolist())
  237. dev_logits = self.model.run(nn.Constant(self.dev_images)).data
  238. # print(f"dev logits: {dev_logits.flatten()[10:20]}")
  239. dev_predicted = np.argmax(dev_logits, axis=1)
  240. dev_accuracy = np.mean(dev_predicted == self.dev_labels)
  241. return dev_accuracy
  242. class LanguageIDDataset(Dataset):
  243. def __init__(self, model):
  244. self.model = model
  245. data_path = get_data_path("lang_id.npz")
  246. with np.load(data_path) as data:
  247. self.chars = data['chars']
  248. self.language_codes = data['language_codes']
  249. self.language_names = data['language_names']
  250. self.train_x = data['train_x']
  251. self.train_y = data['train_y']
  252. self.train_buckets = data['train_buckets']
  253. self.dev_x = data['dev_x']
  254. self.dev_y = data['dev_y']
  255. self.dev_buckets = data['dev_buckets']
  256. self.test_x = data['test_x']
  257. self.test_y = data['test_y']
  258. self.test_buckets = data['test_buckets']
  259. self.epoch = 0
  260. self.bucket_weights = self.train_buckets[:,1] - self.train_buckets[:,0]
  261. self.bucket_weights = self.bucket_weights / float(self.bucket_weights.sum())
  262. self.chars_print = self.chars
  263. try:
  264. print(u"Alphabet: {}".format(u"".join(self.chars)))
  265. except UnicodeEncodeError:
  266. self.chars_print = "abcdefghijklmnopqrstuvwxyzaaeeeeiinoouuacelnszz"
  267. print("Alphabet: " + self.chars_print)
  268. self.chars_print = list(self.chars_print)
  269. print("""
  270. NOTE: Your terminal does not appear to support printing Unicode characters.
  271. For the purposes of printing to the terminal, some of the letters in the
  272. alphabet above have been substituted with ASCII symbols.""".strip())
  273. print("")
  274. # Select some examples to spotlight in the monitoring phase (3 per language)
  275. spotlight_idxs = []
  276. for i in range(len(self.language_names)):
  277. idxs_lang_i = np.nonzero(self.dev_y == i)[0]
  278. idxs_lang_i = np.random.choice(idxs_lang_i, size=3, replace=False)
  279. spotlight_idxs.extend(list(idxs_lang_i))
  280. self.spotlight_idxs = np.array(spotlight_idxs, dtype=int)
  281. # Templates for printing updates as training progresses
  282. max_word_len = self.dev_x.shape[1]
  283. max_lang_len = max([len(x) for x in self.language_names])
  284. self.predicted_template = u"Pred: {:<NUM}".replace('NUM',
  285. str(max_lang_len))
  286. self.word_template = u" "
  287. self.word_template += u"{:<NUM} ".replace('NUM', str(max_word_len))
  288. self.word_template += u"{:<NUM} ({:6.1%})".replace('NUM', str(max_lang_len))
  289. self.word_template += u" {:<NUM} ".replace('NUM',
  290. str(max_lang_len + len('Pred: ')))
  291. for i in range(len(self.language_names)):
  292. self.word_template += u"|{}".format(self.language_codes[i])
  293. self.word_template += "{probs[" + str(i) + "]:4.0%}"
  294. self.last_update = time.time()
  295. def _encode(self, inp_x, inp_y):
  296. xs = []
  297. for i in range(inp_x.shape[1]):
  298. if np.all(inp_x[:,i] == -1):
  299. break
  300. assert not np.any(inp_x[:,i] == -1), (
  301. "Please report this error in the project: batching by length was done incorrectly in the provided code")
  302. x = np.eye(len(self.chars))[inp_x[:,i]]
  303. xs.append(nn.Constant(x))
  304. y = np.eye(len(self.language_names))[inp_y]
  305. y = nn.Constant(y)
  306. return xs, y
  307. def _softmax(self, x):
  308. exp = np.exp(x - np.max(x, axis=-1, keepdims=True))
  309. return exp / np.sum(exp, axis=-1, keepdims=True)
  310. def _predict(self, split='dev'):
  311. if split == 'dev':
  312. data_x = self.dev_x
  313. data_y = self.dev_y
  314. buckets = self.dev_buckets
  315. else:
  316. data_x = self.test_x
  317. data_y = self.test_y
  318. buckets = self.test_buckets
  319. all_predicted = []
  320. all_correct = []
  321. for bucket_id in range(buckets.shape[0]):
  322. start, end = buckets[bucket_id]
  323. xs, y = self._encode(data_x[start:end], data_y[start:end])
  324. predicted = self.model.run(xs)
  325. all_predicted.extend(list(predicted.data))
  326. all_correct.extend(list(data_y[start:end]))
  327. all_predicted_probs = self._softmax(np.asarray(all_predicted))
  328. all_predicted = np.asarray(all_predicted).argmax(axis=-1)
  329. all_correct = np.asarray(all_correct)
  330. return all_predicted_probs, all_predicted, all_correct
  331. def iterate_once(self, batch_size):
  332. assert isinstance(batch_size, int) and batch_size > 0, (
  333. "Batch size should be a positive integer, got {!r}".format(
  334. batch_size))
  335. assert self.train_x.shape[0] >= batch_size, (
  336. "Dataset size {:d} is smaller than the batch size {:d}".format(
  337. self.train_x.shape[0], batch_size))
  338. self.epoch += 1
  339. for iteration in range(self.train_x.shape[0] // batch_size):
  340. bucket_id = np.random.choice(self.bucket_weights.shape[0], p=self.bucket_weights)
  341. example_ids = self.train_buckets[bucket_id, 0] + np.random.choice(
  342. self.train_buckets[bucket_id, 1] - self.train_buckets[bucket_id, 0],
  343. size=batch_size)
  344. yield self._encode(self.train_x[example_ids], self.train_y[example_ids])
  345. if use_graphics and time.time() - self.last_update > 0.5:
  346. dev_predicted_probs, dev_predicted, dev_correct = self._predict()
  347. dev_accuracy = np.mean(dev_predicted == dev_correct)
  348. print("epoch {:,} iteration {:,} validation-accuracy {:.1%}".format(
  349. self.epoch, iteration, dev_accuracy))
  350. for idx in self.spotlight_idxs:
  351. correct = (dev_predicted[idx] == dev_correct[idx])
  352. word = u"".join([self.chars_print[ch] for ch in self.dev_x[idx] if ch != -1])
  353. print(self.word_template.format(
  354. word,
  355. self.language_names[dev_correct[idx]],
  356. dev_predicted_probs[idx, dev_correct[idx]],
  357. "" if correct else self.predicted_template.format(
  358. self.language_names[dev_predicted[idx]]),
  359. probs=dev_predicted_probs[idx,:],
  360. ))
  361. self.last_update = time.time()
  362. def get_validation_accuracy(self):
  363. dev_predicted_probs, dev_predicted, dev_correct = self._predict()
  364. dev_accuracy = np.mean(dev_predicted == dev_correct)
  365. return dev_accuracy
  366. def main():
  367. import models
  368. # model = models.PerceptronModel(3)
  369. # dataset = PerceptronDataset(model)
  370. # model.train(dataset)
  371. # model = models.RegressionModel()
  372. # dataset = RegressionDataset(model)
  373. # model.train(dataset)
  374. model = models.DigitClassificationModel()
  375. dataset = DigitClassificationDataset(model)
  376. model.train(dataset)
  377. # model = models.LanguageIDModel()
  378. # dataset = LanguageIDDataset(model)
  379. # model.train(dataset)
# Run the training demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()