You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

SHARE_MLSpring2021_HW2_1.ipynb 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644
  1. {
  2. "nbformat": 4,
  3. "nbformat_minor": 0,
  4. "metadata": {
  5. "accelerator": "GPU",
  6. "colab": {
  7. "name": "SHARE MLSpring2021 - HW2-1.ipynb",
  8. "provenance": [],
  9. "collapsed_sections": []
  10. },
  11. "kernelspec": {
  12. "display_name": "Python 3",
  13. "name": "python3"
  14. }
  15. },
  16. "cells": [
  17. {
  18. "cell_type": "markdown",
  19. "metadata": {
  20. "id": "OYlaRwNu7ojq"
  21. },
  22. "source": [
  23. "# **Homework 2-1 Phoneme Classification**"
  24. ]
  25. },
  26. {
  27. "cell_type": "markdown",
  28. "metadata": {
  29. "id": "emUd7uS7crTz"
  30. },
  31. "source": [
  32. "## The DARPA TIMIT Acoustic-Phonetic Continuous Speech Corpus (TIMIT)\n",
  33. "The TIMIT corpus of read speech has been designed to provide speech data for the acquisition of acoustic-phonetic knowledge and for the development and evaluation of automatic speech recognition systems.\n",
  34. "\n",
  35. "This homework is a multiclass classification task, \n",
  36. "we are going to train a deep neural network classifier to predict the phonemes for each frame from the speech corpus TIMIT.\n",
  37. "\n",
  38. "link: https://academictorrents.com/details/34e2b78745138186976cbc27939b1b34d18bd5b3"
  39. ]
  40. },
  41. {
  42. "cell_type": "markdown",
  43. "metadata": {
  44. "id": "KVUGfWTo7_Oj"
  45. },
  46. "source": [
  47. "## Download Data\n",
  48. "Download data from google drive, then unzip it.\n",
  49. "\n",
  50. "You should have `timit_11/train_11.npy`, `timit_11/train_label_11.npy`, and `timit_11/test_11.npy` after running this block.<br><br>\n",
  51. "`timit_11/`\n",
  52. "- `train_11.npy`: training data<br>\n",
  53. "- `train_label_11.npy`: training label<br>\n",
  54. "- `test_11.npy`: testing data<br><br>\n",
  55. "\n",
  56. "**notes: if the google drive link is dead, you can download the data directly from Kaggle and upload it to the workspace**\n",
  57. "\n",
  58. "\n"
  59. ]
  60. },
  61. {
  62. "cell_type": "code",
  63. "metadata": {
  64. "colab": {
  65. "base_uri": "https://localhost:8080/"
  66. },
  67. "id": "OzkiMEcC3Foq",
  68. "outputId": "4308c64c-6885-4d1c-8eb7-a2d9b8038401"
  69. },
  70. "source": [
  71. "!gdown --id '1HPkcmQmFGu-3OknddKIa5dNDsR05lIQR' --output data.zip\n",
  72. "!unzip data.zip\n",
  73. "!ls "
  74. ],
  75. "execution_count": null,
  76. "outputs": [
  77. {
  78. "output_type": "stream",
  79. "text": [
  80. "Downloading...\n",
  81. "From: https://drive.google.com/uc?id=1HPkcmQmFGu-3OknddKIa5dNDsR05lIQR\n",
  82. "To: /content/data.zip\n",
  83. "372MB [00:03, 121MB/s]\n",
  84. "Archive: data.zip\n",
  85. " creating: timit_11/\n",
  86. " inflating: timit_11/train_11.npy \n",
  87. " inflating: timit_11/test_11.npy \n",
  88. " inflating: timit_11/train_label_11.npy \n",
  89. "data.zip sample_data timit_11\n"
  90. ],
  91. "name": "stdout"
  92. }
  93. ]
  94. },
  95. {
  96. "cell_type": "markdown",
  97. "metadata": {
  98. "id": "_L_4anls8Drv"
  99. },
  100. "source": [
  101. "## Preparing Data\n",
  102. "Load the training and testing data from the `.npy` file (NumPy array)."
  103. ]
  104. },
  105. {
  106. "cell_type": "code",
  107. "metadata": {
  108. "colab": {
  109. "base_uri": "https://localhost:8080/"
  110. },
  111. "id": "IJjLT8em-y9G",
  112. "outputId": "8edc6bfe-7511-447f-f239-00b96dba6dcf"
  113. },
  114. "source": [
  115. "import numpy as np\n",
  116. "\n",
  117. "print('Loading data ...')\n",
  118. "\n",
  119. "data_root='./timit_11/'\n",
  120. "train = np.load(data_root + 'train_11.npy')\n",
  121. "train_label = np.load(data_root + 'train_label_11.npy')\n",
  122. "test = np.load(data_root + 'test_11.npy')\n",
  123. "\n",
  124. "print('Size of training data: {}'.format(train.shape))\n",
  125. "print('Size of testing data: {}'.format(test.shape))"
  126. ],
  127. "execution_count": null,
  128. "outputs": [
  129. {
  130. "output_type": "stream",
  131. "text": [
  132. "Loading data ...\n",
  133. "Size of training data: (1229932, 429)\n",
  134. "Size of testing data: (451552, 429)\n"
  135. ],
  136. "name": "stdout"
  137. }
  138. ]
  139. },
  140. {
  141. "cell_type": "markdown",
  142. "metadata": {
  143. "id": "us5XW_x6udZQ"
  144. },
  145. "source": [
  146. "## Create Dataset"
  147. ]
  148. },
  149. {
  150. "cell_type": "code",
  151. "metadata": {
  152. "id": "Fjf5EcmJtf4e"
  153. },
  154. "source": [
  155. "import torch\n",
  156. "from torch.utils.data import Dataset\n",
  157. "\n",
  158. "class TIMITDataset(Dataset):\n",
  159. " def __init__(self, X, y=None):\n",
  160. " self.data = torch.from_numpy(X).float()\n",
  161. " if y is not None:\n",
  162. " y = y.astype(int)\n",
  163. " self.label = torch.LongTensor(y)\n",
  164. " else:\n",
  165. " self.label = None\n",
  166. "\n",
  167. " def __getitem__(self, idx):\n",
  168. " if self.label is not None:\n",
  169. " return self.data[idx], self.label[idx]\n",
  170. " else:\n",
  171. " return self.data[idx]\n",
  172. "\n",
  173. " def __len__(self):\n",
  174. " return len(self.data)\n"
  175. ],
  176. "execution_count": null,
  177. "outputs": []
  178. },
  179. {
  180. "cell_type": "markdown",
  181. "metadata": {
  182. "id": "otIC6WhGeh9v"
  183. },
  184. "source": [
  185. "Split the labeled data into a training set and a validation set, you can modify the variable `VAL_RATIO` to change the ratio of validation data."
  186. ]
  187. },
  188. {
  189. "cell_type": "code",
  190. "metadata": {
  191. "colab": {
  192. "base_uri": "https://localhost:8080/"
  193. },
  194. "id": "sYqi_lAuvC59",
  195. "outputId": "13dabe63-4849-47ee-fe04-57427b9d601c"
  196. },
  197. "source": [
  198. "VAL_RATIO = 0.2\n",
  199. "\n",
  200. "percent = int(train.shape[0] * (1 - VAL_RATIO))\n",
  201. "train_x, train_y, val_x, val_y = train[:percent], train_label[:percent], train[percent:], train_label[percent:]\n",
  202. "print('Size of training set: {}'.format(train_x.shape))\n",
  203. "print('Size of validation set: {}'.format(val_x.shape))"
  204. ],
  205. "execution_count": null,
  206. "outputs": [
  207. {
  208. "output_type": "stream",
  209. "text": [
  210. "Size of training set: (983945, 429)\n",
  211. "Size of validation set: (245987, 429)\n"
  212. ],
  213. "name": "stdout"
  214. }
  215. ]
  216. },
  217. {
  218. "cell_type": "markdown",
  219. "metadata": {
  220. "id": "nbCfclUIgMTX"
  221. },
  222. "source": [
  223. "Create a data loader from the dataset, feel free to tweak the variable `BATCH_SIZE` here."
  224. ]
  225. },
  226. {
  227. "cell_type": "code",
  228. "metadata": {
  229. "id": "RUCbQvqJurYc"
  230. },
  231. "source": [
  232. "BATCH_SIZE = 64\n",
  233. "\n",
  234. "from torch.utils.data import DataLoader\n",
  235. "\n",
  236. "train_set = TIMITDataset(train_x, train_y)\n",
  237. "val_set = TIMITDataset(val_x, val_y)\n",
  238. "train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True) #only shuffle the training data\n",
  239. "val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)"
  240. ],
  241. "execution_count": null,
  242. "outputs": []
  243. },
  244. {
  245. "cell_type": "markdown",
  246. "metadata": {
  247. "id": "_SY7X0lUgb50"
  248. },
  249. "source": [
  250. "Clean up the unneeded variables to save memory.<br>\n",
  251. "\n",
  252. "**notes: if you need to use these variables later, then you may remove this block or clean up unneeded variables later<br>the data size is quite huge, so be aware of memory usage in colab**"
  253. ]
  254. },
  255. {
  256. "cell_type": "code",
  257. "metadata": {
  258. "colab": {
  259. "base_uri": "https://localhost:8080/"
  260. },
  261. "id": "y8rzkGraeYeN",
  262. "outputId": "dc790996-a43c-4a99-90d4-e7928892a899"
  263. },
  264. "source": [
  265. "import gc\n",
  266. "\n",
  267. "del train, train_label, train_x, train_y, val_x, val_y\n",
  268. "gc.collect()"
  269. ],
  270. "execution_count": null,
  271. "outputs": [
  272. {
  273. "output_type": "execute_result",
  274. "data": {
  275. "text/plain": [
  276. "50"
  277. ]
  278. },
  279. "metadata": {
  280. "tags": []
  281. },
  282. "execution_count": 6
  283. }
  284. ]
  285. },
  286. {
  287. "cell_type": "markdown",
  288. "metadata": {
  289. "id": "IRqKNvNZwe3V"
  290. },
  291. "source": [
  292. "## Create Model"
  293. ]
  294. },
  295. {
  296. "cell_type": "markdown",
  297. "metadata": {
  298. "id": "FYr1ng5fh9pA"
  299. },
  300. "source": [
  301. "Define model architecture, you are encouraged to change and experiment with the model architecture."
  302. ]
  303. },
  304. {
  305. "cell_type": "code",
  306. "metadata": {
  307. "id": "lbZrwT6Ny0XL"
  308. },
  309. "source": [
  310. "import torch\n",
  311. "import torch.nn as nn\n",
  312. "\n",
  313. "class Classifier(nn.Module):\n",
  314. " def __init__(self):\n",
  315. " super(Classifier, self).__init__()\n",
  316. " self.layer1 = nn.Linear(429, 1024)\n",
  317. " self.layer2 = nn.Linear(1024, 512)\n",
  318. " self.layer3 = nn.Linear(512, 128)\n",
  319. " self.out = nn.Linear(128, 39) \n",
  320. "\n",
  321. " self.act_fn = nn.Sigmoid()\n",
  322. "\n",
  323. " def forward(self, x):\n",
  324. " x = self.layer1(x)\n",
  325. " x = self.act_fn(x)\n",
  326. "\n",
  327. " x = self.layer2(x)\n",
  328. " x = self.act_fn(x)\n",
  329. "\n",
  330. " x = self.layer3(x)\n",
  331. " x = self.act_fn(x)\n",
  332. "\n",
  333. " x = self.out(x)\n",
  334. " \n",
  335. " return x"
  336. ],
  337. "execution_count": null,
  338. "outputs": []
  339. },
  340. {
  341. "cell_type": "markdown",
  342. "metadata": {
  343. "id": "VRYciXZvPbYh"
  344. },
  345. "source": [
  346. "## Training"
  347. ]
  348. },
  349. {
  350. "cell_type": "code",
  351. "metadata": {
  352. "id": "y114Vmm3Ja6o"
  353. },
  354. "source": [
  355. "#check device\n",
  356. "def get_device():\n",
  357. " return 'cuda' if torch.cuda.is_available() else 'cpu'"
  358. ],
  359. "execution_count": null,
  360. "outputs": []
  361. },
  362. {
  363. "cell_type": "markdown",
  364. "metadata": {
  365. "id": "sEX-yjHjhGuH"
  366. },
  367. "source": [
  368. "Fix random seeds for reproducibility."
  369. ]
  370. },
  371. {
  372. "cell_type": "code",
  373. "metadata": {
  374. "id": "88xPiUnm0tAd"
  375. },
  376. "source": [
  377. "# fix random seed\n",
  378. "def same_seeds(seed):\n",
  379. " torch.manual_seed(seed)\n",
  380. " if torch.cuda.is_available():\n",
  381. " torch.cuda.manual_seed(seed)\n",
  382. " torch.cuda.manual_seed_all(seed) \n",
  383. " np.random.seed(seed) \n",
  384. " torch.backends.cudnn.benchmark = False\n",
  385. " torch.backends.cudnn.deterministic = True"
  386. ],
  387. "execution_count": null,
  388. "outputs": []
  389. },
  390. {
  391. "cell_type": "markdown",
  392. "metadata": {
  393. "id": "KbBcBXkSp6RA"
  394. },
  395. "source": [
  396. "Feel free to change the training parameters here."
  397. ]
  398. },
  399. {
  400. "cell_type": "code",
  401. "metadata": {
  402. "id": "QTp3ZXg1yO9Y"
  403. },
  404. "source": [
  405. "# fix random seed for reproducibility\n",
  406. "same_seeds(0)\n",
  407. "\n",
  408. "# get device \n",
  409. "device = get_device()\n",
  410. "print(f'DEVICE: {device}')\n",
  411. "\n",
  412. "# training parameters\n",
  413. "num_epoch = 20 # number of training epoch\n",
  414. "learning_rate = 0.0001 # learning rate\n",
  415. "\n",
  416. "# the path where checkpoint saved\n",
  417. "model_path = './model.ckpt'\n",
  418. "\n",
  419. "# create model, define a loss function, and optimizer\n",
  420. "model = Classifier().to(device)\n",
  421. "criterion = nn.CrossEntropyLoss() \n",
  422. "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
  423. ],
  424. "execution_count": null,
  425. "outputs": []
  426. },
  427. {
  428. "cell_type": "code",
  429. "metadata": {
  430. "id": "CdMWsBs7zzNs",
  431. "colab": {
  432. "base_uri": "https://localhost:8080/"
  433. },
  434. "outputId": "c5ed561e-610d-4a35-d936-fd97adf342a0"
  435. },
  436. "source": [
  437. "# start training\n",
  438. "\n",
  439. "best_acc = 0.0\n",
  440. "for epoch in range(num_epoch):\n",
  441. " train_acc = 0.0\n",
  442. " train_loss = 0.0\n",
  443. " val_acc = 0.0\n",
  444. " val_loss = 0.0\n",
  445. "\n",
  446. " # training\n",
  447. " model.train() # set the model to training mode\n",
  448. " for i, data in enumerate(train_loader):\n",
  449. " inputs, labels = data\n",
  450. " inputs, labels = inputs.to(device), labels.to(device)\n",
  451. " optimizer.zero_grad() \n",
  452. " outputs = model(inputs) \n",
  453. " batch_loss = criterion(outputs, labels)\n",
  454. " _, train_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n",
  455. " batch_loss.backward() \n",
  456. " optimizer.step() \n",
  457. "\n",
  458. " train_acc += (train_pred.cpu() == labels.cpu()).sum().item()\n",
  459. " train_loss += batch_loss.item()\n",
  460. "\n",
  461. " # validation\n",
  462. " if len(val_set) > 0:\n",
  463. " model.eval() # set the model to evaluation mode\n",
  464. " with torch.no_grad():\n",
  465. " for i, data in enumerate(val_loader):\n",
  466. " inputs, labels = data\n",
  467. " inputs, labels = inputs.to(device), labels.to(device)\n",
  468. " outputs = model(inputs)\n",
  469. " batch_loss = criterion(outputs, labels) \n",
  470. " _, val_pred = torch.max(outputs, 1) \n",
  471. " \n",
  472. " val_acc += (val_pred.cpu() == labels.cpu()).sum().item() # get the index of the class with the highest probability\n",
  473. " val_loss += batch_loss.item()\n",
  474. "\n",
  475. " print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f} | Val Acc: {:3.6f} loss: {:3.6f}'.format(\n",
  476. " epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader), val_acc/len(val_set), val_loss/len(val_loader)\n",
  477. " ))\n",
  478. "\n",
  479. " # if the model improves, save a checkpoint at this epoch\n",
  480. " if val_acc > best_acc:\n",
  481. " best_acc = val_acc\n",
  482. " torch.save(model.state_dict(), model_path)\n",
  483. " print('saving model with acc {:.3f}'.format(best_acc/len(val_set)))\n",
  484. " else:\n",
  485. " print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f}'.format(\n",
  486. " epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader)\n",
  487. " ))\n",
  488. "\n",
  489. "# if not validating, save the last epoch\n",
  490. "if len(val_set) == 0:\n",
  491. " torch.save(model.state_dict(), model_path)\n",
  492. " print('saving model at last epoch')\n"
  493. ],
  494. "execution_count": null,
  495. "outputs": [
  496. {
  497. "output_type": "stream",
  498. "text": [
  499. "[001/020] Train Acc: 0.467390 Loss: 1.812880 | Val Acc: 0.564884 loss: 1.440870\n",
  500. "saving model with acc 0.565\n",
  501. "[002/020] Train Acc: 0.594031 Loss: 1.332670 | Val Acc: 0.629594 loss: 1.209077\n",
  502. "saving model with acc 0.630\n",
  503. "[003/020] Train Acc: 0.644419 Loss: 1.154247 | Val Acc: 0.658295 loss: 1.102313\n",
  504. "saving model with acc 0.658\n",
  505. "[004/020] Train Acc: 0.672767 Loss: 1.051355 | Val Acc: 0.675568 loss: 1.040186\n",
  506. "saving model with acc 0.676\n",
  507. "[005/020] Train Acc: 0.691564 Loss: 0.982245 | Val Acc: 0.683853 loss: 1.004628\n",
  508. "saving model with acc 0.684\n",
  509. "[006/020] Train Acc: 0.705731 Loss: 0.930892 | Val Acc: 0.691707 loss: 0.977562\n",
  510. "saving model with acc 0.692\n",
  511. "[007/020] Train Acc: 0.716722 Loss: 0.890210 | Val Acc: 0.691016 loss: 0.973670\n",
  512. "[008/020] Train Acc: 0.726312 Loss: 0.856612 | Val Acc: 0.690207 loss: 0.971627\n",
  513. "[009/020] Train Acc: 0.734965 Loss: 0.827445 | Val Acc: 0.698561 loss: 0.942904\n",
  514. "saving model with acc 0.699\n",
  515. "[010/020] Train Acc: 0.741926 Loss: 0.801676 | Val Acc: 0.698854 loss: 0.946376\n",
  516. "saving model with acc 0.699\n",
  517. "[011/020] Train Acc: 0.748191 Loss: 0.779319 | Val Acc: 0.700944 loss: 0.938454\n",
  518. "saving model with acc 0.701\n",
  519. "[012/020] Train Acc: 0.754672 Loss: 0.758071 | Val Acc: 0.699423 loss: 0.940523\n",
  520. "[013/020] Train Acc: 0.759725 Loss: 0.739450 | Val Acc: 0.699728 loss: 0.951068\n",
  521. "[014/020] Train Acc: 0.765137 Loss: 0.721372 | Val Acc: 0.701903 loss: 0.938658\n",
  522. "saving model with acc 0.702\n",
  523. "[015/020] Train Acc: 0.769828 Loss: 0.704748 | Val Acc: 0.701761 loss: 0.937079\n",
  524. "[016/020] Train Acc: 0.774698 Loss: 0.688990 | Val Acc: 0.702293 loss: 0.938634\n",
  525. "saving model with acc 0.702\n",
  526. "[017/020] Train Acc: 0.779358 Loss: 0.674498 | Val Acc: 0.702492 loss: 0.943941\n",
  527. "saving model with acc 0.702\n",
  528. "[018/020] Train Acc: 0.783076 Loss: 0.660028 | Val Acc: 0.695195 loss: 0.966189\n",
  529. "[019/020] Train Acc: 0.787432 Loss: 0.646340 | Val Acc: 0.700708 loss: 0.958220\n",
  530. "[020/020] Train Acc: 0.791536 Loss: 0.633378 | Val Acc: 0.700643 loss: 0.957066\n"
  531. ],
  532. "name": "stdout"
  533. }
  534. ]
  535. },
  536. {
  537. "cell_type": "markdown",
  538. "metadata": {
  539. "id": "1Hi7jTn3PX-m"
  540. },
  541. "source": [
  542. "## Testing"
  543. ]
  544. },
  545. {
  546. "cell_type": "markdown",
  547. "metadata": {
  548. "id": "NfUECMFCn5VG"
  549. },
  550. "source": [
  551. "Create a testing dataset, and load model from the saved checkpoint."
  552. ]
  553. },
  554. {
  555. "cell_type": "code",
  556. "metadata": {
  557. "id": "1PKjtAScPWtr",
  558. "colab": {
  559. "base_uri": "https://localhost:8080/"
  560. },
  561. "outputId": "8c17272b-536a-4692-a95f-a3292766c698"
  562. },
  563. "source": [
  564. "# create testing dataset\n",
  565. "test_set = TIMITDataset(test, None)\n",
  566. "test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)\n",
  567. "\n",
  568. "# create model and load weights from checkpoint\n",
  569. "model = Classifier().to(device)\n",
  570. "model.load_state_dict(torch.load(model_path, map_location=device))"
  571. ],
  572. "execution_count": null,
  573. "outputs": [
  574. {
  575. "output_type": "execute_result",
  576. "data": {
  577. "text/plain": [
  578. "<All keys matched successfully>"
  579. ]
  580. },
  581. "metadata": {
  582. "tags": []
  583. },
  584. "execution_count": 12
  585. }
  586. ]
  587. },
  588. {
  589. "cell_type": "markdown",
  590. "metadata": {
  591. "id": "940TtCCdoYd0"
  592. },
  593. "source": [
  594. "Make prediction."
  595. ]
  596. },
  597. {
  598. "cell_type": "code",
  599. "metadata": {
  600. "id": "84HU5GGjPqR0"
  601. },
  602. "source": [
  603. "predict = []\n",
  604. "model.eval() # set the model to evaluation mode\n",
  605. "with torch.no_grad():\n",
  606. " for i, data in enumerate(test_loader):\n",
  607. " inputs = data\n",
  608. " inputs = inputs.to(device)\n",
  609. " outputs = model(inputs)\n",
  610. " _, test_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n",
  611. "\n",
  612. " for y in test_pred.cpu().numpy():\n",
  613. " predict.append(y)"
  614. ],
  615. "execution_count": null,
  616. "outputs": []
  617. },
  618. {
  619. "cell_type": "markdown",
  620. "metadata": {
  621. "id": "AWDf_C-omElb"
  622. },
  623. "source": [
  624. "Write prediction to a CSV file.\n",
  625. "\n",
  626. "After you finish running this block, download the file `prediction.csv` from the files section on the left-hand side and submit it to Kaggle."
  627. ]
  628. },
  629. {
  630. "cell_type": "code",
  631. "metadata": {
  632. "id": "GuljYSPHcZir"
  633. },
  634. "source": [
  635. "with open('prediction.csv', 'w') as f:\n",
  636. " f.write('Id,Class\\n')\n",
  637. " for i, y in enumerate(predict):\n",
  638. " f.write('{},{}\\n'.format(i, y))"
  639. ],
  640. "execution_count": null,
  641. "outputs": []
  642. }
  643. ]
  644. }