You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

ML2021_HW15_Meta_Learning.ipynb 48 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304
  1. {
  2. "nbformat": 4,
  3. "nbformat_minor": 0,
  4. "metadata": {
  5. "accelerator": "GPU",
  6. "colab": {
  7. "name": "ML2021 HW15 Meta Learning.ipynb",
  8. "provenance": [],
  9. "collapsed_sections": []
  10. },
  11. "kernelspec": {
  12. "display_name": "Python 3",
  13. "name": "python3"
  14. },
  15. "language_info": {
  16. "name": "python"
  17. }
  18. },
  19. "cells": [
  20. {
  21. "cell_type": "markdown",
  22. "metadata": {
  23. "id": "wzVBe3h7Xh-2"
  24. },
  25. "source": [
  26. "<a name=\"top\"></a>\n",
  27. "# **HW15 Meta Learning: Few-shot Classification**\n",
  28. "\n",
  29. "Please mail to ntu-ml-2021spring-ta@googlegroups.com if you have any questions.\n",
  30. "\n",
  31. "Useful Links:\n",
  32. "1. [Go to hyperparameter setting.](#hyp)\n",
  33. "1. [Go to meta algorithm setting.](#modelsetting)\n",
  34. "1. [Go to main loop.](#mainloop)"
  35. ]
  36. },
  37. {
  38. "cell_type": "markdown",
  39. "metadata": {
  40. "id": "RdpzIMG6XsGK"
  41. },
  42. "source": [
  43. "## **Step 0: Check GPU**"
  44. ]
  45. },
  46. {
  47. "cell_type": "code",
  48. "metadata": {
  49. "id": "zjjHsZbaL7SV"
  50. },
  51. "source": [
  52. "!nvidia-smi"
  53. ],
  54. "execution_count": null,
  55. "outputs": []
  56. },
  57. {
  58. "cell_type": "code",
  59. "metadata": {
  60. "cellView": "form",
  61. "id": "gWpc6vW3MQhv"
  62. },
  63. "source": [
  64. "#@markdown ### Install `qqdm`\n",
  65. "# Check if installed\n",
  66. "try:\n",
  67. " import qqdm\n",
  68. "except:\n",
  69. " ! pip install qqdm > /dev/null 2>&1\n",
  70. "print(\"Done!\")"
  71. ],
  72. "execution_count": null,
  73. "outputs": []
  74. },
  75. {
  76. "cell_type": "markdown",
  77. "metadata": {
  78. "id": "bQ3wvyjnXwGX"
  79. },
  80. "source": [
  81. "## **Step 1: Download Data**\n",
  82. "\n",
  83. "Run the cell to download data, which has been pre-processed by TAs. \n",
  84. "The dataset has been augmented, so extra data augmentation is not required.\n"
  85. ]
  86. },
  87. {
  88. "cell_type": "code",
  89. "metadata": {
  90. "id": "g7Gt4Jucug41"
  91. },
  92. "source": [
  93. "workspace_dir = '.'\n",
  94. "\n",
  95. "# gdown is a package that downloads files from google drive\n",
  96. "!gdown --id 1FLDrQ0k-iJ-mk8ors0WItqvwgu0w9J0U \\\n",
  97. " --output \"{workspace_dir}/Omniglot.tar.gz\""
  98. ],
  99. "execution_count": null,
  100. "outputs": []
  101. },
  102. {
  103. "cell_type": "markdown",
  104. "metadata": {
  105. "id": "AMGFHI9XX9ms"
  106. },
  107. "source": [
  108. "### Decompress the dataset\n",
  109. "\n",
  110. "Since the dataset is quite large, please wait and observe the main program [here](#mainprogram). \n",
  111. "You can come back here later by [*back to pre-process*](#preprocess)."
  112. ]
  113. },
  114. {
  115. "cell_type": "code",
  116. "metadata": {
  117. "id": "AvvlAQBUug42"
  118. },
  119. "source": [
  120. "# Use `tar' command to decompress\n",
  121. "!tar -zxf \"{workspace_dir}/Omniglot.tar.gz\" \\\n",
  122. " -C \"{workspace_dir}/\""
  123. ],
  124. "execution_count": null,
  125. "outputs": []
  126. },
  127. {
  128. "cell_type": "markdown",
  129. "metadata": {
  130. "id": "T5P9eT0fYDqV"
  131. },
  132. "source": [
  133. "### Data Preview\n",
  134. "\n",
  135. "Just look at some data in the dataset."
  136. ]
  137. },
  138. {
  139. "cell_type": "code",
  140. "metadata": {
  141. "colab": {
  142. "base_uri": "https://localhost:8080/",
  143. "height": 297
  144. },
  145. "id": "7VtgHLurYE5x",
  146. "outputId": "961971b2-8b61-4d03-c06e-571a778ab52d"
  147. },
  148. "source": [
  149. "from PIL import Image\n",
  150. "from IPython.display import display\n",
  151. "for i in range(10, 20):\n",
  152. " im = Image.open(\"Omniglot/images_background/Japanese_(hiragana).0/character13/0500_\" + str (i) + \".png\")\n",
  153. " display(im)"
  154. ],
  155. "execution_count": null,
  156. "outputs": [
  157. {
  158. "output_type": "display_data",
  159. "data": {
  160. "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABEUlEQVR4nGP8z4AbMOGRQ5F8uekLmux/BNjHs+0/CkDWqcKJppMFQv1jYGBgYGb69w/FMsb/DAwMDD+bnjMwMHxbb6nIyMDAwMBWqoyk8/+zRwwMDD//vGFmYGBgYBDhQNbJ8Pvvp7t/OL0nBENMZUG1s2vFsz+C71nYsPnz5QSz/Vt1f//DGgj/GV0MbEMYTqDJQrz7Sc/62VVZZtefKIEAlfy3XthM3TBD7Qu2EGL07ZPWncHGzog9bP/+/v1EvvUv9rBlYvk37VsgE1ad/34+zeeL+4UaKywMDAwM7w7/unH46puSKlZUjYz/GRj+z278y2xkbW7Cy4ApyfD1838mQVY0lzLAAx47IDqBDQpJAN4Euv7fFejQAAAAAElFTkSuQmCC\n",
  161. "text/plain": [
  162. "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113AF743D0>"
  163. ]
  164. },
  165. "metadata": {
  166. "tags": []
  167. }
  168. },
  169. {
  170. "output_type": "display_data",
  171. "data": {
  172. "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABNklEQVR4nGP8zwAH//8xM6AAJiT2pdTXuCVfbvmGW5LhPwNuSUaGP7gl5ZkuoUqyMDAwMLw78I9PjVVYSu2AP9P//39ZUSQfFP/8/4dVWvER9y6GC08+T+dCltQ9zvD5wZPr9/7uPsMkzuXKgnAhHPz71aJ07eHHH3/gIghVDIwsEv9ERHF6RevDG9yBwMn8GZvk/2+3nvxnEOe4g+rR////////scmBX+nov+OCV/4jA4jkNR73ed4aD3M032GRvME/5ddV4QKRqX+xSH4vFM4/4cmg+eQ/Fsn/X6doiPDy7vmHVfL/3xdzjB2+o8r9h/mTSTxBn4UJ1SNIgfD1iTS6JDzgfxSJ7UEzFWbnr5fZQrP+YJf8FKcpsegHuhw0yti02QI9MGxkYPwPtZmREUMOJokdAAB60yoWf/hgewAAAABJRU5ErkJggg==\n",
  173. "text/plain": [
  174. "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113AEF16D0>"
  175. ]
  176. },
  177. "metadata": {
  178. "tags": []
  179. }
  180. },
  181. {
  182. "output_type": "display_data",
  183. "data": {
  184. "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABN0lEQVR4nGP8z4AbsCBz/j9lkGZE4jMi6/wTz7AQWTWKToafqMYy4bESXfI/bklG8Yc/cEoyGbz59R8OIA76eBNq2v/7P/c8gJnM6KHHwsDAcC7iN8y133KEYR7lMmdg/M/A8PUxTPWFtCXOMElGTkYWBgYGbg24rawSPEhuQATC/6dPGe7/Q/EKQvJa0GuGH39RPQpz+FNL0zNXW9gyv/1H8gyU/tuicOXf/6tcEs+RJGGB8HOnrwYjA1r4wSQ/3tP5/vXLpV+i7EiSsPh85/JEmJHh/eelfoyYkv8f7nvNwHDy7HkhbK79/+/fv19JFl/+Y3EQAwMjI+PrA36cWP35////m55y15E1/keSfOurdvAvdskfm/3kdv37j1Xy33Jh7ZWo+v7/h6fbV49UedGTIiO+7AAAZ4kCU7KEzEEAAAAASUVORK5CYII=\n",
  185. "text/plain": [
  186. "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113B605F10>"
  187. ]
  188. },
  189. "metadata": {
  190. "tags": []
  191. }
  192. },
  193. {
  194. "output_type": "display_data",
  195. "data": {
  196. "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABLUlEQVR4nGP8z4AbMKHx/3/7g1vyredeBIcFTfLP7c8w5r+/EMlX90UEISLv/315B5VbepDxPwMDA8OiAmYuiNDfZ0LcUNs/BkEkPzw4D3XHpyZHL0YIU9ydEeYVKP3dW3g51B2MCNcyQgCn47VfUCamVxjY/uH2JwpAlUQLS+RAeH3rnLbIN+ySn6JPsP35/18Kq7Hvz/edPhr75f9/bDoZGLgUGJveb32hgkUnn0H3ob/8vJ8PILT+R4BLMqIXDssyhPyGCTAiuf7fwQiuj9o/5FbC7EK2k8l+2RnFe20WjNiM/f///7+frp6v4Tz04HvyUlaIAbvOf1eNJQ8juKiSXy2lt/7FIfl7Dl/rn//YJf+tFrR7jKwY2Z8MZw/5qDAi8VEkGf4jSzEwAABSseqGZyInRAAAAABJRU5ErkJggg==\n",
  197. "text/plain": [
  198. "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F114126CE90>"
  199. ]
  200. },
  201. "metadata": {
  202. "tags": []
  203. }
  204. },
  205. {
  206. "output_type": "display_data",
  207. "data": {
  208. "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABJklEQVR4nGP8z4AbMKFy/+GRfF54D7fkh8Wv8RgLAxCXsGCR+Pfx0BtTfZjkz5dQNz/79+Lhv32HTj4TbtFjZGBg/M/AwHDO/xdE8s87Qdb/v3WsnYz5WBmgkh9PQ73wqLBVg0FAhwPmkv/I4JrgCWQukoN+/2NECy645P9bTXf4PVFDCG7sU33JNFMOwZvIxsIk/03k2/PnvpTsW2RJeAh9lzRjFtdBNRUuyfjzCwOHP2powhzEaPBypvvlvWjOhZn/xlZYUIgH1U641/5/e/6bufT8BSFs/mTkVmH4+gKHgxgYGBh+PmdhxCn5/a85D1YH/f///0Mk336UeEBI/ntXzdv9C7vkvyOWvJnf/2OX/GxlMBNNDuHPvxdlRFGcygBNJrgAAEPeDmCQZ6aqAAAAAElFTkSuQmCC\n",
  209. "text/plain": [
  210. "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113AEF1810>"
  211. ]
  212. },
  213. "metadata": {
  214. "tags": []
  215. }
  216. },
  217. {
  218. "output_type": "display_data",
  219. "data": {
  220. "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABOElEQVR4nGP8z4AbMKHxXz1AVv0fBfxJt/+B4KHp/P/j7U+cxrKEvHyD2041ppP/cUrKqh34D3EKAwMDC5okq9Sdrz9/3v50ysqTEUny//9/v169/K+01e/Rjz8MUppQnf8YGBi+3bxw+fGLu78Y/nxlCjHXZRfiYmRgYPzP8KfzPsO/iw+4FUUEHAyZX0R0xcIcwsLA8P/lQwYmTzsDXlbGX6f/fGf8zYgcQr9//fr19/////9fRgsKczLO/occQiysrKxMDAwM/5fumr0vmEUcrhPZK38OmvowCkiaYQ0EJoWbuyvmWggjRJDj5IkrPw9D8V84Hyb5bsO3////v9sQJHbxP4bkWYmajXcfnM4QWPQHU/JXjbSYqKiwRj9SXP9nhEXQz/cfH/3n0BJCdiEjKQlscEsCAN5i3onYmdekAAAAAElFTkSuQmCC\n",
  221. "text/plain": [
  222. "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113AEF1510>"
  223. ]
  224. },
  225. "metadata": {
  226. "tags": []
  227. }
  228. },
  229. {
  230. "output_type": "display_data",
  231. "data": {
  232. "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABO0lEQVR4nGP8z4AbMOGRo5bks4v/UCRZkNj/ly06zs3w4+eLf4yK7OiSDH9+/Xl+ZObDNwxcG00Qkk+v/WdgYPh/+1XWic9ySSa87DpIxu4v/sfAwMDw69OtpFAxfkaYSYz/GRgYGL68ZWBgYGA4H39IDy4D18nDw8DAwMDwnIkNWY6CQIB55T+2CIBI/j1+8B4DA5P0P2ySF4NYVVkYPi1m/IEq+///////Fwge+/Hz5/sipvif/5EARPKW2Mx/////P8wi+w5ZEmKsrF4vhx3jh/6/aO6CqLkfxS0iIiQbJvsWWScjVOnnM78ZWNTW95wXwuJPXkcGBoY/x5S5cIbQlYOubJh2/v318+fPn7ctLd78x3Dt//Uz/zEwMDxmXymMGUKMCop/GBgY1BI0UAMI6tp/mA5ASGIHADm3qpNJq4xdAAAAAElFTkSuQmCC\n",
  233. "text/plain": [
  234. "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113B605750>"
  235. ]
  236. },
  237. "metadata": {
  238. "tags": []
  239. }
  240. },
  241. {
  242. "output_type": "display_data",
  243. "data": {
  244. "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABI0lEQVR4nGP8z4AbMKFy7yz/jsz9jwz+5su//P///79///7+/v//Pwuqzi9KfP9/Xdv1juEXZxMLA5Lk/39/3r153Mdw9BS/CANjNBMDAyPEQf8ffv92/vizq8+4pBlVg12FGBhYmeB2/nARFJSxTpokWfXpy89/MCdAjWWd8IVRjo/9+3Q+HkaERVBJJm0Ghp9ff6B5GuGg36WbGNXe4AiEd5tMC869xqHzwPcq7XuTsev8/5lTnlWPEUUSKRB+3+L/z4VdklH8qyfj/x84dHrs/8XwOxGHJKshA8NXjof/mLF5hYGBgYFD/84/bK5lYPj///+PxziMfbTg6/+nZ9KYsEo+2PWHgasoE8lKWHwyMDD8/8XAwMiKEgqMJKS+gZcEAF56gf6wykc6AAAAAElFTkSuQmCC\n",
  245. "text/plain": [
  246. "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113AEF1910>"
  247. ]
  248. },
  249. "metadata": {
  250. "tags": []
  251. }
  252. },
  253. {
  254. "output_type": "display_data",
  255. "data": {
  256. "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABXklEQVR4nGP8z4AbMOGRI1Xy9z/ckp+TVv3HKfnn6Lq/MDYLuiSHMkTjv98MzFDJv/9ZGBj+//339ef1symvmN49/bnn1H9rxv8MDAwM/yecjhFmuLPl7bMPv98IszH8ZmCUtuW0h+rkvhT3n4FVV95M/1tpgBsDpxazIDcjIyPUhq/P/zJwSLIxMr4z9u1nRnUQEy8vAwMDw3+G/yw8H+GOQ3bt/w9bnrOZfTZjwiL5/0LaPd63PJ/YsYXQ6wSezacOW/6+jYio/zDwb7Hkuf///5/ij/sDE0Lo/H9IXYOB4deET9+whS3vrbmvfx3axmTJjGns/6dJMpp+4gxid+EiSJL/f98sc7Ph6PmDVfL//++zZWM+/8cu+TZfOOrdf2yS/15u8BCpe/8fm+TXfm1Bj4O//2OT/Dtf0O/g1///sUreMF707T86gEp+S1/2B0PuPzSZbPscgpHUGBgAt9BS1wiwXusAAAAASUVORK5CYII=\n",
  257. "text/plain": [
  258. "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113AEF16D0>"
  259. ]
  260. },
  261. "metadata": {
  262. "tags": []
  263. }
  264. },
  265. {
  266. "output_type": "display_data",
  267. "data": {
  268. "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABQElEQVR4nGP8z4AbMKHwfvxA4TIi6/xbyNDPjMRnQVb5/wkDii2oxuK1Ew2gGMvwn5EBZjAjTPL3W4jAp6uGT86f+MfAwMDAnKwKde25oF8MDAwMDP9e8nD/kWNhYGBg4JymDZV8d+AvAwMDA8OX0mBXDSWIZ9gYGRgY/kPBv68Hdl6SmfXvPxKAOuj/g92rzzEwfcDmlf97nKs4Vu3y/IPml////////9VI5fL3//+v8KMaC9HJbs8kysHAIMT9D4vO/9eEJ/z59y6HMfEPFgcp+7XfF9tyjfMDSsBDJdl6eXf+F1s1GdVUeHz++cnAxOAlvAI5sOGxwsLNzfn9njlyXKNF2X8BLIGAAyBZ8f/zge9oamF++nG/S59L9RayN//DJP8tlpTL3XwbJfT+w73y7KShLDOqoajpFh3gdS0Aq5C/ToYG3GgAAAAASUVORK5CYII=\n",
  269. "text/plain": [
  270. "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113B605750>"
  271. ]
  272. },
  273. "metadata": {
  274. "tags": []
  275. }
  276. }
  277. ]
  278. },
  279. {
  280. "cell_type": "markdown",
  281. "metadata": {
  282. "id": "baVsWfcSYHVN"
  283. },
  284. "source": [
  285. "## **Step 2: Build the model**"
  286. ]
  287. },
  288. {
  289. "cell_type": "markdown",
  290. "metadata": {
  291. "id": "gqiOdDLgYOlQ"
  292. },
  293. "source": [
  294. "### Library importation"
  295. ]
  296. },
  297. {
  298. "cell_type": "code",
  299. "metadata": {
  300. "id": "-9pfkqh8gxHD"
  301. },
  302. "source": [
  303. "# Import modules we need\n",
  304. "import glob, random\n",
  305. "from collections import OrderedDict\n",
  306. "\n",
  307. "import numpy as np\n",
  308. "\n",
  309. "try:\n",
  310. " from qqdm.notebook import qqdm as tqdm\n",
  311. "except ModuleNotFoundError:\n",
  312. " from tqdm.auto import tqdm\n",
  313. "\n",
  314. "import torch, torch.nn as nn\n",
  315. "import torch.nn.functional as F\n",
  316. "from torch.utils.data import DataLoader, Dataset\n",
  317. "import torchvision.transforms as transforms\n",
  318. "\n",
  319. "from PIL import Image\n",
  320. "from IPython.display import display\n",
  321. "\n",
  322. "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
  323. "\n",
  324. "# fix random seeds\n",
  325. "random_seed = 0\n",
  326. "random.seed(random_seed)\n",
  327. "np.random.seed(random_seed)\n",
  328. "torch.manual_seed(random_seed)\n",
  329. "if torch.cuda.is_available():\n",
  330. " torch.cuda.manual_seed_all(random_seed)"
  331. ],
  332. "execution_count": null,
  333. "outputs": []
  334. },
  335. {
  336. "cell_type": "markdown",
  337. "metadata": {
  338. "id": "3TlwLtC1YRT7"
  339. },
  340. "source": [
  341. "### Model Construction Preliminaries\n",
  342. "\n",
  343. "Since our task is image classification, we need to build a CNN-based model. \n",
  344. "However, to implement the MAML algorithm, we should adjust some code in `nn.Module`.\n"
  345. ]
  346. },
  347. {
  348. "cell_type": "markdown",
  349. "metadata": {
  350. "id": "dFwB3tuEDYfy"
  351. },
  352. "source": [
  353. "Take a look at MAML pseudocode...\n",
  354. "\n",
  355. "<img src=\"https://i.imgur.com/9aHlvfX.png\" width=\"50%\" />\n",
  356. "\n",
  357. "On the 10-th line, what we take gradients on are those $\\theta$ representing \n",
  358. "<font color=\"#0CC\">**the original model parameters**</font> (outer loop) instead of those in the \n",
  359. "<font color=\"#0C0\">**inner loop**</font>, so we need to use `functional_forward` to compute the output \n",
  360. "logits of input image instead of `forward` in `nn.Module`.\n",
  361. "\n",
  362. "The following defines these functions.\n",
  363. "\n",
  364. "<!-- 由於在第10行,我們是要對原本的參數 θ 微分,並非 inner-loop (Line5~8) 的 θ' 微分,因此在 inner-loop,我們需要用 functional forward 的方式算出 input image 的 output logits,而不是直接用 nn.module 裡面的 forward(直接對 θ 微分)。在下面我們分別定義了 functional forward 以及 forward 函數。 -->"
  365. ]
  366. },
  367. {
  368. "cell_type": "markdown",
  369. "metadata": {
  370. "id": "iuYQiPeQYc__"
  371. },
  372. "source": [
  373. "### Model block definition"
  374. ]
  375. },
  376. {
  377. "cell_type": "code",
  378. "metadata": {
  379. "id": "GgFbbKHYg3Hk"
  380. },
  381. "source": [
  382. "def ConvBlock(in_ch: int, out_ch: int):\n",
  383. " return nn.Sequential(\n",
  384. " nn.Conv2d(in_ch, out_ch, 3, padding=1),\n",
  385. " nn.BatchNorm2d(out_ch),\n",
  386. " nn.ReLU(),\n",
  387. " nn.MaxPool2d(kernel_size=2, stride=2)\n",
  388. " )\n",
  389. "\n",
  390. "def ConvBlockFunction(x, w, b, w_bn, b_bn):\n",
  391. " x = F.conv2d(x, w, b, padding=1)\n",
  392. " x = F.batch_norm(x,\n",
  393. " running_mean=None,\n",
  394. " running_var=None,\n",
  395. " weight=w_bn, bias=b_bn,\n",
  396. " training=True)\n",
  397. " x = F.relu(x)\n",
  398. " x = F.max_pool2d(x, kernel_size=2, stride=2)\n",
  399. " return x"
  400. ],
  401. "execution_count": null,
  402. "outputs": []
  403. },
  404. {
  405. "cell_type": "markdown",
  406. "metadata": {
  407. "id": "iQEzgWN7fi7B"
  408. },
  409. "source": [
  410. "### Model definition"
  411. ]
  412. },
  413. {
  414. "cell_type": "code",
  415. "metadata": {
  416. "id": "0bFBGEQoHQUW"
  417. },
  418. "source": [
  419. "class Classifier(nn.Module):\n",
  420. " def __init__(self, in_ch, k_way):\n",
  421. " super(Classifier, self).__init__()\n",
  422. " self.conv1 = ConvBlock(in_ch, 64)\n",
  423. " self.conv2 = ConvBlock(64, 64)\n",
  424. " self.conv3 = ConvBlock(64, 64)\n",
  425. " self.conv4 = ConvBlock(64, 64)\n",
  426. " self.logits = nn.Linear(64, k_way)\n",
  427. "\n",
  428. " def forward(self, x):\n",
  429. " x = self.conv1(x)\n",
  430. " x = self.conv2(x)\n",
  431. " x = self.conv3(x)\n",
  432. " x = self.conv4(x)\n",
  433. " x = x.view(x.shape[0], -1)\n",
  434. " x = self.logits(x)\n",
  435. " return x\n",
  436. "\n",
  437. " def functional_forward(self, x, params):\n",
  438. " '''\n",
  439. " Arguments:\n",
  440. " x: input images [batch, 1, 28, 28]\n",
  441. " params: the model parameters,\n",
  442. " i.e. weights and biases of the convolution\n",
  443. " and batch normalization layers,\n",
  444. " stored as an `OrderedDict`\n",
  453. " '''\n",
  454. " for block in [1, 2, 3, 4]:\n",
  455. " x = ConvBlockFunction(\n",
  456. " x,\n",
  457. " params[f'conv{block}.0.weight'],\n",
  458. " params[f'conv{block}.0.bias'],\n",
  459. " params.get(f'conv{block}.1.weight'),\n",
  460. " params.get(f'conv{block}.1.bias'))\n",
  461. " x = x.view(x.shape[0], -1)\n",
  462. " x = F.linear(x,\n",
  463. " params['logits.weight'],\n",
  464. " params['logits.bias'])\n",
  465. " return x"
  466. ],
  467. "execution_count": null,
  468. "outputs": []
  469. },
  470. {
  471. "cell_type": "markdown",
  472. "metadata": {
  473. "id": "gmJq_0B9Yj0G"
  474. },
  475. "source": [
  476. "### Create Label\n",
  477. "\n",
  478. "This function is used to create labels. \n",
  479. "In a N-way K-shot few-shot classification problem,\n",
  480. "each task has `n_way` classes, while there are `k_shot` images for each class. \n",
  481. "This is a function that creates such labels.\n"
  482. ]
  483. },
  484. {
  485. "cell_type": "code",
  486. "metadata": {
  487. "colab": {
  488. "base_uri": "https://localhost:8080/"
  489. },
  490. "id": "GQF5vgLvg5aX",
  491. "outputId": "5df41e04-290c-428b-b06f-cc749f09f027"
  492. },
  493. "source": [
  494. "def create_label(n_way, k_shot):\n",
  495. " return (torch.arange(n_way)\n",
  496. " .repeat_interleave(k_shot)\n",
  497. " .long())\n",
  498. "\n",
  499. "# Try to create labels for 5-way 2-shot setting\n",
  500. "create_label(5, 2)"
  501. ],
  502. "execution_count": null,
  503. "outputs": [
  504. {
  505. "output_type": "execute_result",
  506. "data": {
  507. "text/plain": [
  508. "tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])"
  509. ]
  510. },
  511. "metadata": {
  512. "tags": []
  513. },
  514. "execution_count": 9
  515. }
  516. ]
  517. },
  518. {
  519. "cell_type": "markdown",
  520. "metadata": {
  521. "id": "2nCFv9PGw50J"
  522. },
  523. "source": [
  524. "### Accuracy calculation"
  525. ]
  526. },
  527. {
  528. "cell_type": "code",
  529. "metadata": {
  530. "id": "FahDr0xQw50S"
  531. },
  532. "source": [
  533. "def calculate_accuracy(logits, val_label):\n",
  534. " \"\"\" utility function for accuracy calculation \"\"\"\n",
  535. " acc = np.asarray([(\n",
  536. " torch.argmax(logits, -1).cpu().numpy() == val_label.cpu().numpy())]\n",
  537. " ).mean() \n",
  538. " return acc"
  539. ],
  540. "execution_count": null,
  541. "outputs": []
  542. },
  543. {
  544. "cell_type": "markdown",
  545. "metadata": {
  546. "id": "9Hl7ro2mYzsI"
  547. },
  548. "source": [
  549. "### Define Dataset\n",
  550. "\n",
  551. "Define the dataset. \n",
  552. "The dataset returns images of a random character, with (`k_shot + q_query`) images, \n",
  553. "so the size of returned tensor is `[k_shot+q_query, 1, 28, 28]`. \n"
  554. ]
  555. },
  556. {
  557. "cell_type": "code",
  558. "metadata": {
  559. "id": "-tJ2mot9hHPb"
  560. },
  561. "source": [
  562. "class Omniglot(Dataset):\n",
  563. " def __init__(self, data_dir, k_way, q_query):\n",
  564. " self.file_list = [f for f in glob.glob(\n",
  565. " data_dir + \"**/character*\", \n",
  566. " recursive=True)]\n",
  567. " self.transform = transforms.Compose(\n",
  568. " [transforms.ToTensor()])\n",
  569. " self.n = k_way + q_query\n",
  570. "\n",
  571. " def __getitem__(self, idx):\n",
  572. " sample = np.arange(20)\n",
  573. "\n",
  574. " # For random sampling the characters we want.\n",
  575. " np.random.shuffle(sample) \n",
  576. " img_path = self.file_list[idx]\n",
  577. " img_list = [f for f in glob.glob(\n",
  578. " img_path + \"**/*.png\", recursive=True)]\n",
  579. " img_list.sort()\n",
  580. " imgs = [self.transform(\n",
  581. " Image.open(img_file)) \n",
  582. " for img_file in img_list]\n",
  583. " # `k_way + q_query` examples for each character\n",
  584. " imgs = torch.stack(imgs)[sample[:self.n]] \n",
  585. " return imgs\n",
  586. "\n",
  587. " def __len__(self):\n",
  588. " return len(self.file_list)"
  589. ],
  590. "execution_count": null,
  591. "outputs": []
  592. },
  593. {
  594. "cell_type": "markdown",
  595. "metadata": {
  596. "id": "Gm5iVp90Ylii"
  597. },
  598. "source": [
  599. "## **Step 3: Core MAML**\n",
  600. "\n",
  601. "Here is the main Meta Learning algorithm. \n",
  602. "The algorithm is exactly the same as the paper. \n",
  603. "What the function does is to update the parameters using \"the data of a meta-batch.\"\n",
  604. "Here we implement the second-order MAML (inner_train_step = 1), according to [the slides of meta learning in 2019 (p. 13 ~ p.18)](http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf#page=13&view=FitW)\n",
  605. "\n",
  606. "As for the mathematical derivation of the first-order version, please refer to [p.25 of the slides in 2019](http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf#page=25&view=FitW).\n",
  607. "\n",
  608. "The following is the algorithm with some explanation."
  609. ]
  610. },
  611. {
  612. "cell_type": "code",
  613. "metadata": {
  614. "id": "KjNxrWW_yNck"
  615. },
  616. "source": [
  617. "def OriginalMAML(\n",
  618. " model, optimizer, x, n_way, k_shot, q_query, loss_fn,\n",
  619. " inner_train_step=1, inner_lr=0.4, train=True):\n",
  620. " criterion, task_loss, task_acc = loss_fn, [], []\n",
  621. "\n",
  622. " for meta_batch in x:\n",
  623. " # Get data\n",
  624. " support_set = meta_batch[: n_way * k_shot] \n",
  625. " query_set = meta_batch[n_way * k_shot :] \n",
  626. " \n",
  627. " # Copy the params for inner loop\n",
  628. " fast_weights = OrderedDict(model.named_parameters())\n",
  629. " \n",
  630. " ### ---------- INNER TRAIN LOOP ---------- ###\n",
  631. " for inner_step in range(inner_train_step): \n",
  632. " # Simply training\n",
  633. " train_label = create_label(n_way, k_shot) \\\n",
  634. " .to(device)\n",
  635. " logits = model.functional_forward(\n",
  636. " support_set, fast_weights)\n",
  637. " loss = criterion(logits, train_label)\n",
  638. " # Inner gradients update! vvvvvvvvvvvvvvvvvvvv #\n",
  639. " \"\"\" Inner Loop Update \"\"\" #\n",
  640. " grads = torch.autograd.grad( #\n",
  641. " loss, fast_weights.values(), #\n",
  642. " create_graph=True) #\n",
  643. " # Perform SGD #\n",
  644. " fast_weights = OrderedDict( #\n",
  645. " (name, param - inner_lr * grad) #\n",
  646. " for ((name, param), grad) #\n",
  647. " in zip(fast_weights.items(), grads)) #\n",
  648. " # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ #\n",
  649. "\n",
  650. " ### ---------- INNER VALID LOOP ---------- ###\n",
  651. " val_label = create_label(n_way, q_query).to(device)\n",
  652. " \n",
  653. " # Collect gradients for outer loop\n",
  654. " logits = model.functional_forward(\n",
  655. " query_set, fast_weights) \n",
  656. " loss = criterion(logits, val_label)\n",
  657. " task_loss.append(loss)\n",
  658. " task_acc.append(\n",
  659. " calculate_accuracy(logits, val_label))\n",
  660. "\n",
  661. " # Update outer loop\n",
  662. " model.train()\n",
  663. " optimizer.zero_grad()\n",
  664. "\n",
  665. " meta_batch_loss = torch.stack(task_loss).mean()\n",
  666. " if train:\n",
  667. " meta_batch_loss.backward() # <--- May change later!\n",
  668. " optimizer.step()\n",
  669. " task_acc = np.mean(task_acc)\n",
  670. " return meta_batch_loss, task_acc"
  671. ],
  672. "execution_count": null,
  673. "outputs": []
  674. },
  675. {
  676. "cell_type": "markdown",
  677. "metadata": {
  678. "id": "MF5ZahPdxKbp"
  679. },
  680. "source": [
  681. "## Variations of MAML\n",
  682. "\n",
  683. "### First-order approximation of MAML (FOMAML)\n",
  684. "\n",
  685. "Slightly modify the MAML mentioned earlier, applying a first-order approximation to decrease the amount of computation.\n",
  686. "\n",
  687. "### Almost No Inner Loop (ANIL)\n",
  688. "\n",
  689. "The algorithm from [this paper](https://arxiv.org/abs/1909.09157), using the technique of feature reuse to decrease the amount of computation."
  690. ]
  691. },
  692. {
  693. "cell_type": "markdown",
  694. "metadata": {
  695. "id": "qyQ7ZUN4foh-"
  696. },
  697. "source": [
  698. "To finish the modification required, we need to change some blocks of the MAML algorithm. \n",
  699. "Below, we have replaced the three parts that may be modified as functions. \n",
  700. "Please choose to replace the functions with their alternative versions to complete the algorithm."
  701. ]
  702. },
  703. {
  704. "cell_type": "markdown",
  705. "metadata": {
  706. "id": "Ne5cOja0H8H7"
  707. },
  708. "source": [
  709. "### Part 1: Inner loop update"
  710. ]
  711. },
  712. {
  713. "cell_type": "markdown",
  714. "metadata": {
  715. "id": "LChAX51sIFwi"
  716. },
  717. "source": [
  718. "MAML"
  719. ]
  720. },
  721. {
  722. "cell_type": "code",
  723. "metadata": {
  724. "id": "Aqgb0kEVzQol"
  725. },
  726. "source": [
  727. "def inner_update_MAML(fast_weights, loss, inner_lr):\n",
  728. " \"\"\" Inner Loop Update \"\"\"\n",
  729. " grads = torch.autograd.grad(\n",
  730. " loss, fast_weights.values(), create_graph=True)\n",
  731. " # Perform SGD\n",
  732. " fast_weights = OrderedDict(\n",
  733. " (name, param - inner_lr * grad)\n",
  734. " for ((name, param), grad) in zip(fast_weights.items(), grads))\n",
  735. " return fast_weights"
  736. ],
  737. "execution_count": null,
  738. "outputs": []
  739. },
  740. {
  741. "cell_type": "markdown",
  742. "metadata": {
  743. "id": "QnQ_BN-L2Gd7"
  744. },
  745. "source": [
  746. "Alternatives"
  747. ]
  748. },
  749. {
  750. "cell_type": "code",
  751. "metadata": {
  752. "id": "Ug5LIO6V15cd"
  753. },
  754. "source": [
  755. "def inner_update_alt1(fast_weights, loss, inner_lr):\n",
  756. " grads = torch.autograd.grad(\n",
  757. " loss, fast_weights.values(), create_graph=False)\n",
  758. " # Perform SGD\n",
  759. " fast_weights = OrderedDict(\n",
  760. " (name, param - inner_lr * grad)\n",
  761. " for ((name, param), grad) in zip(fast_weights.items(), grads))\n",
  762. " return fast_weights\n",
  763. "\n",
  764. "def inner_update_alt2(fast_weights, loss, inner_lr):\n",
  765. " grads = torch.autograd.grad(\n",
  766. " loss, list(fast_weights.values())[-2:], create_graph=True)\n",
  767. " # Split out the logits\n",
  768. " for ((name, param), grad) in zip(\n",
  769. " list(fast_weights.items())[-2:], grads):\n",
  770. " fast_weights[name] = param - inner_lr * grad\n",
  771. " return fast_weights"
  772. ],
  773. "execution_count": null,
  774. "outputs": []
  775. },
  776. {
  777. "cell_type": "markdown",
  778. "metadata": {
  779. "id": "1ZfaWPMt164t"
  780. },
  781. "source": [
  782. "### Part 2: Collect gradients"
  783. ]
  784. },
  785. {
  786. "cell_type": "markdown",
  787. "metadata": {
  788. "id": "-W7zL2nN164u"
  789. },
  790. "source": [
  791. "MAML \n",
  792. "(Actually does nothing, as gradients are computed by PyTorch automatically.)"
  793. ]
  794. },
  795. {
  796. "cell_type": "code",
  797. "metadata": {
  798. "id": "sgcuPPm2zSFL"
  799. },
  800. "source": [
  801. "def collect_gradients_MAML(\n",
  802. " special_grad: OrderedDict, fast_weights, model, len_data):\n",
  803. " \"\"\" Actually do nothing (just backwards later) \"\"\"\n",
  804. " return special_grad"
  805. ],
  806. "execution_count": null,
  807. "outputs": []
  808. },
  809. {
  810. "cell_type": "markdown",
  811. "metadata": {
  812. "id": "2OxEME6l2QOO"
  813. },
  814. "source": [
  815. "Alternatives"
  816. ]
  817. },
  818. {
  819. "cell_type": "code",
  820. "metadata": {
  821. "id": "fWLYwZlM2RZO"
  822. },
  823. "source": [
  824. "def collect_gradients_alt(\n",
  825. " special_grad: OrderedDict, fast_weights, model, len_data):\n",
  826. " \"\"\" Special gradient calculation \"\"\"\n",
  827. " diff = OrderedDict(\n",
  828. " (name, params - fast_weights[name]) \n",
  829. " for (name, params) in model.named_parameters())\n",
  830. " for name in diff:\n",
  831. " special_grad[name] = special_grad.get(name, 0) + diff[name] / len_data\n",
  832. " return special_grad"
  833. ],
  834. "execution_count": null,
  835. "outputs": []
  836. },
  837. {
  838. "cell_type": "markdown",
  839. "metadata": {
  840. "id": "ahqE-Sf92TID"
  841. },
  842. "source": [
  843. "### Part 3: Outer loop gradients calculation"
  844. ]
  845. },
  846. {
  847. "cell_type": "markdown",
  848. "metadata": {
  849. "id": "-wr0hSd02TIE"
  850. },
  851. "source": [
  852. "MAML \n",
  853. "(Simply call PyTorch `backward`.)"
  854. ]
  855. },
  856. {
  857. "cell_type": "code",
  858. "metadata": {
  859. "id": "_hBSQ02xzTXb"
  860. },
  861. "source": [
  862. "def outer_update_MAML(model, meta_batch_loss, grad_tensors):\n",
  863. " \"\"\" Simply backwards \"\"\"\n",
  864. " meta_batch_loss.backward()"
  865. ],
  866. "execution_count": null,
  867. "outputs": []
  868. },
  869. {
  870. "cell_type": "markdown",
  871. "metadata": {
  872. "id": "Q4zxf6yr2TIE"
  873. },
  874. "source": [
  875. "Alternatives"
  876. ]
  877. },
  878. {
  879. "cell_type": "code",
  880. "metadata": {
  881. "id": "DEyCwYmI2bdC"
  882. },
  883. "source": [
  884. "def outer_update_alt(model, meta_batch_loss, grad_tensors):\n",
  885. " \"\"\" Replace the gradients\n",
  886. " with precalculated tensors \"\"\"\n",
  887. " for (name, params) in model.named_parameters():\n",
  888. " params.grad = grad_tensors[name]"
  889. ],
  890. "execution_count": null,
  891. "outputs": []
  892. },
  893. {
  894. "cell_type": "markdown",
  895. "metadata": {
  896. "id": "z1jck3KE2g1D"
  897. },
  898. "source": [
  899. "### Complete the algorithm\n",
  900. "Here we have wrapped the algorithm in `MetaAlgorithmGenerator`. \n",
  901. "You can get your modified algorithm by filling in like this:\n",
  902. "```python\n",
  903. "MyAlgorithm = MetaAlgorithmGenerator(inner_update=inner_update_alt2)\n",
  904. "```\n",
  905. "Default the three blocks will be filled with that of `MAML`."
  906. ]
  907. },
  908. {
  909. "cell_type": "code",
  910. "metadata": {
  911. "id": "XosNxVMDxL6V"
  912. },
  913. "source": [
  914. "def MetaAlgorithmGenerator(\n",
  915. " inner_update = inner_update_MAML, \n",
  916. " collect_gradients = collect_gradients_MAML, \n",
  917. " outer_update = outer_update_MAML):\n",
  918. "\n",
  919. " global calculate_accuracy\n",
  920. "\n",
  921. " def MetaAlgorithm(\n",
  922. " model, optimizer, x, n_way, k_shot, q_query, loss_fn,\n",
  923. " inner_train_step=1, inner_lr=0.4, train=True): \n",
  924. " criterion = loss_fn\n",
  925. " task_loss, task_acc = [], []\n",
  926. " special_grad = OrderedDict() # Added for variants!\n",
  927. "\n",
  928. " for meta_batch in x:\n",
  929. " support_set = meta_batch[: n_way * k_shot] \n",
  930. " query_set = meta_batch[n_way * k_shot :] \n",
  931. " \n",
  932. " fast_weights = OrderedDict(model.named_parameters())\n",
  933. " \n",
  934. " ### ---------- INNER TRAIN LOOP ---------- ###\n",
  935. " for inner_step in range(inner_train_step): \n",
  936. " train_label = create_label(n_way, k_shot).to(device)\n",
  937. " logits = model.functional_forward(support_set, fast_weights)\n",
  938. " loss = criterion(logits, train_label)\n",
  939. "\n",
  940. " fast_weights = inner_update(fast_weights, loss, inner_lr)\n",
  941. "\n",
  942. " ### ---------- INNER VALID LOOP ---------- ###\n",
  943. " val_label = create_label(n_way, q_query).to(device)\n",
  944. " # FIXME: W for val?\n",
  945. " special_grad = collect_gradients(\n",
  946. " special_grad, fast_weights, model, len(x))\n",
  947. " \n",
  948. " # Collect gradients for outer loop\n",
  949. " logits = model.functional_forward(query_set, fast_weights) \n",
  950. " loss = criterion(logits, val_label)\n",
  951. " task_loss.append(loss)\n",
  952. " task_acc.append(calculate_accuracy(logits, val_label))\n",
  953. "\n",
  954. " # Update outer loop\n",
  955. " model.train()\n",
  956. " optimizer.zero_grad()\n",
  957. "\n",
  958. " meta_batch_loss = torch.stack(task_loss).mean()\n",
  959. " if train:\n",
  960. " # Notice the update part!\n",
  961. " outer_update(model, meta_batch_loss, special_grad)\n",
  962. " optimizer.step()\n",
  963. " task_acc = np.mean(task_acc)\n",
  964. " return meta_batch_loss, task_acc\n",
  965. " return MetaAlgorithm"
  966. ],
  967. "execution_count": null,
  968. "outputs": []
  969. },
  970. {
  971. "cell_type": "code",
  972. "metadata": {
  973. "id": "jEsPtV-GzbDv",
  974. "cellView": "form"
  975. },
  976. "source": [
  977. "#@title Here is the answer hidden, please fill in yourself!\n",
  978. "Give_me_the_answer = True #@param {\"type\": \"boolean\"}\n",
  979. "\n",
  980. "def HiddenAnswer():\n",
  981. " MAML = MetaAlgorithmGenerator()\n",
  982. " FOMAML = MetaAlgorithmGenerator(inner_update=inner_update_alt1)\n",
  983. " ANIL = MetaAlgorithmGenerator(inner_update=inner_update_alt2)\n",
  984. " return MAML, FOMAML, ANIL"
  985. ],
  986. "execution_count": null,
  987. "outputs": []
  988. },
  989. {
  990. "cell_type": "code",
  991. "metadata": {
  992. "id": "2P__5N2Yz9O4"
  993. },
  994. "source": [
  995. "# `HiddenAnswer` is hidden in the last cell.\n",
  996. "if Give_me_the_answer:\n",
  997. " MAML, FOMAML, ANIL = HiddenAnswer()\n",
  998. "else: \n",
  999. " # TODO: Please fill in the function names \\\n",
  1000. " # as the function arguments to finish the algorithm.\n",
  1001. " MAML = MetaAlgorithmGenerator()\n",
  1002. " FOMAML = MetaAlgorithmGenerator()\n",
  1003. " ANIL = MetaAlgorithmGenerator()"
  1004. ],
  1005. "execution_count": null,
  1006. "outputs": []
  1007. },
  1008. {
  1009. "cell_type": "markdown",
  1010. "metadata": {
  1011. "id": "nBoRBhVlZAST"
  1012. },
  1013. "source": [
  1014. "## **Step 4: Initialization**\n",
  1015. "\n",
  1016. "After defining all components we need, the following initialize a model before training."
  1017. ]
  1018. },
  1019. {
  1020. "cell_type": "markdown",
  1021. "metadata": {
  1022. "id": "Ip-i7aseftUF"
  1023. },
  1024. "source": [
  1025. "<a name=\"hyp\"></a>\n",
  1026. "### Hyperparameters \n",
  1027. "[Go back to top!](#top)"
  1028. ]
  1029. },
  1030. {
  1031. "cell_type": "code",
  1032. "metadata": {
  1033. "id": "0wFHmVcBhE4M"
  1034. },
  1035. "source": [
  1036. "n_way = 5\n",
  1037. "k_shot = 1\n",
  1038. "q_query = 1\n",
  1039. "inner_train_step = 1\n",
  1040. "inner_lr = 0.4\n",
  1041. "meta_lr = 0.001\n",
  1042. "meta_batch_size = 32\n",
  1043. "max_epoch = 30\n",
  1044. "eval_batches = test_batches = 20\n",
  1045. "train_data_path = './Omniglot/images_background/'\n",
  1046. "test_data_path = './Omniglot/images_evaluation/' "
  1047. ],
  1048. "execution_count": null,
  1049. "outputs": []
  1050. },
  1051. {
  1052. "cell_type": "markdown",
  1053. "metadata": {
  1054. "id": "Uvzo7NVpfu5V"
  1055. },
  1056. "source": [
  1057. "### Dataloader initialization"
  1058. ]
  1059. },
  1060. {
  1061. "cell_type": "code",
  1062. "metadata": {
  1063. "id": "3I13GJavhP0_"
  1064. },
  1065. "source": [
  1066. "def dataloader_init(datasets, num_workers=2):\n",
  1067. " train_set, val_set, test_set = datasets\n",
  1068. " train_loader = DataLoader(train_set,\n",
  1069. " # The \"batch_size\" here is not \\\n",
  1070. " # the meta batch size, but \\\n",
  1071. " # how many different \\\n",
  1072. " # characters in a task, \\\n",
  1073. " # i.e. the \"n_way\" in \\\n",
  1074. " # few-shot classification.\n",
  1075. " batch_size=n_way,\n",
  1076. " num_workers=num_workers,\n",
  1077. " shuffle=True,\n",
  1078. " drop_last=True)\n",
  1079. " val_loader = DataLoader(val_set,\n",
  1080. " batch_size=n_way,\n",
  1081. " num_workers=num_workers,\n",
  1082. " shuffle=True,\n",
  1083. " drop_last=True)\n",
  1084. " test_loader = DataLoader(test_set,\n",
  1085. " batch_size=n_way,\n",
  1086. " num_workers=num_workers,\n",
  1087. " shuffle=True,\n",
  1088. " drop_last=True)\n",
  1089. " train_iter = iter(train_loader)\n",
  1090. " val_iter = iter(val_loader)\n",
  1091. " test_iter = iter(test_loader)\n",
  1092. " return (train_loader, val_loader, test_loader), \\\n",
  1093. " (train_iter, val_iter, test_iter)\n",
  1094. "\n",
  1095. "train_set, val_set = torch.utils.data.random_split(\n",
  1096. " Omniglot(train_data_path, k_shot, q_query), [3200, 656])\n",
  1097. "test_set = Omniglot(test_data_path, k_shot, q_query)\n",
  1098. "\n",
  1099. "(train_loader, val_loader, test_loader), \\\n",
  1100. "(train_iter, val_iter, test_iter) = dataloader_init(\n",
  1101. " (train_set, val_set, test_set))"
  1102. ],
  1103. "execution_count": null,
  1104. "outputs": []
  1105. },
  1106. {
  1107. "cell_type": "markdown",
  1108. "metadata": {
  1109. "id": "KVund--bfw0e"
  1110. },
  1111. "source": [
  1112. "### Model & optimizer initialization"
  1113. ]
  1114. },
  1115. {
  1116. "cell_type": "code",
  1117. "metadata": {
  1118. "id": "Kxug882ihF2B"
  1119. },
  1120. "source": [
  1121. "def model_init():\n",
  1122. " meta_model = Classifier(1, n_way).to(device)\n",
  1123. " optimizer = torch.optim.Adam(meta_model.parameters(), \n",
  1124. " lr=meta_lr)\n",
  1125. " loss_fn = nn.CrossEntropyLoss().to(device)\n",
  1126. " return meta_model, optimizer, loss_fn\n",
  1127. "\n",
  1128. "meta_model, optimizer, loss_fn = model_init()"
  1129. ],
  1130. "execution_count": null,
  1131. "outputs": []
  1132. },
  1133. {
  1134. "cell_type": "markdown",
  1135. "metadata": {
  1136. "id": "gj8cLRNLf2zg"
  1137. },
  1138. "source": [
  1139. "### Utility function to get a meta-batch"
  1140. ]
  1141. },
  1142. {
  1143. "cell_type": "code",
  1144. "metadata": {
  1145. "id": "zrkCSsxOhC-N"
  1146. },
  1147. "source": [
  1148. "def get_meta_batch(meta_batch_size,\n",
  1149. " k_shot, q_query, \n",
  1150. " data_loader, iterator):\n",
  1151. " data = []\n",
  1152. " for _ in range(meta_batch_size):\n",
  1153. " try:\n",
  1154. " # a \"task_data\" tensor is representing \\\n",
  1155. " # the data of a task, with size of \\\n",
  1156. " # [n_way, k_shot+q_query, 1, 28, 28]\n",
  1157. " task_data = next(iterator) \n",
  1158. " except StopIteration:\n",
  1159. " iterator = iter(data_loader)\n",
  1160. " task_data = next(iterator)\n",
  1161. " train_data = (task_data[:, :k_shot]\n",
  1162. " .reshape(-1, 1, 28, 28))\n",
  1163. " val_data = (task_data[:, k_shot:]\n",
  1164. " .reshape(-1, 1, 28, 28))\n",
  1165. " task_data = torch.cat(\n",
  1166. " (train_data, val_data), 0)\n",
  1167. " data.append(task_data)\n",
  1168. " return torch.stack(data).to(device), iterator"
  1169. ],
  1170. "execution_count": null,
  1171. "outputs": []
  1172. },
  1173. {
  1174. "cell_type": "markdown",
  1175. "metadata": {
  1176. "id": "O5JCtob4fyh_"
  1177. },
  1178. "source": [
  1179. "<a name=\"modelsetting\"></a>\n",
  1180. "### Choose the meta learning algorithm\n",
  1181. "[Go back to top!](#top)"
  1182. ]
  1183. },
  1184. {
  1185. "cell_type": "code",
  1186. "metadata": {
  1187. "id": "3av6pAI7OxOP"
  1188. },
  1189. "source": [
  1190. "# You can change this to `FOMAML` or `ANIL`\n",
  1191. "MetaAlgorithm = MAML"
  1192. ],
  1193. "execution_count": null,
  1194. "outputs": []
  1195. },
  1196. {
  1197. "cell_type": "markdown",
  1198. "metadata": {
  1199. "id": "pWQczA3FwjEG"
  1200. },
  1201. "source": [
  1202. "<a name=\"mainprog\" id=\"mainprog\"></a>\n",
  1203. "## **Step 5: Main program for training & testing**"
  1204. ]
  1205. },
  1206. {
  1207. "cell_type": "markdown",
  1208. "metadata": {
  1209. "id": "8EirEnaof7ep"
  1210. },
  1211. "source": [
  1212. "### Start training!\n",
  1213. "<a name=\"mainloop\"></a>\n",
  1214. "[Go back to top!](#top)"
  1215. ]
  1216. },
  1217. {
  1218. "cell_type": "code",
  1219. "metadata": {
  1220. "id": "JQZjJrLAhBWw"
  1221. },
  1222. "source": [
  1223. "for epoch in range(max_epoch):\n",
  1224. " print(\"Epoch %d\" % (epoch + 1))\n",
  1225. " train_meta_loss = []\n",
  1226. " train_acc = []\n",
  1227. " # The \"step\" here is a meta-gradient update step\n",
  1228. " for step in tqdm(range(\n",
  1229. " len(train_loader) // meta_batch_size)): \n",
  1230. " x, train_iter = get_meta_batch(\n",
  1231. " meta_batch_size, k_shot, q_query, \n",
  1232. " train_loader, train_iter)\n",
  1233. " meta_loss, acc = MetaAlgorithm(\n",
  1234. " meta_model, optimizer, x, \n",
  1235. " n_way, k_shot, q_query, loss_fn)\n",
  1236. " train_meta_loss.append(meta_loss.item())\n",
  1237. " train_acc.append(acc)\n",
  1238. " print(\" Loss : \", \"%.3f\" % (np.mean(train_meta_loss)), end='\\t')\n",
  1239. " print(\" Accuracy: \", \"%.3f %%\" % (np.mean(train_acc) * 100))\n",
  1240. "\n",
  1241. " # See the validation accuracy after each epoch.\n",
  1242. " # Early stopping is welcomed to implement.\n",
  1243. " val_acc = []\n",
  1244. " for eval_step in tqdm(range(\n",
  1245. " len(val_loader) // (eval_batches))):\n",
  1246. " x, val_iter = get_meta_batch(\n",
  1247. " eval_batches, k_shot, q_query, \n",
  1248. " val_loader, val_iter)\n",
  1249. " # We update three inner steps when testing.\n",
  1250. " _, acc = MetaAlgorithm(meta_model, optimizer, x, \n",
  1251. " n_way, k_shot, q_query, \n",
  1252. " loss_fn, \n",
  1253. " inner_train_step=3, \n",
  1254. " train=False) \n",
  1255. " val_acc.append(acc)\n",
  1256. " print(\" Validation accuracy: \", \"%.3f %%\" % (np.mean(val_acc) * 100))"
  1257. ],
  1258. "execution_count": null,
  1259. "outputs": []
  1260. },
  1261. {
  1262. "cell_type": "markdown",
  1263. "metadata": {
  1264. "id": "u5Ew8-POf9sw"
  1265. },
  1266. "source": [
  1267. "### Testing the result"
  1268. ]
  1269. },
  1270. {
  1271. "cell_type": "code",
  1272. "metadata": {
  1273. "id": "CYN_zGB3g_5_"
  1274. },
  1275. "source": [
  1276. "test_acc = []\n",
  1277. "for test_step in tqdm(range(\n",
  1278. " len(test_loader) // (test_batches))):\n",
  1279. " x, test_iter = get_meta_batch(\n",
  1280. " test_batches, k_shot, q_query, \n",
  1281. " test_loader, test_iter)\n",
  1282. " # When testing, we update 3 inner-steps\n",
  1283. " _, acc = MetaAlgorithm(meta_model, optimizer, x, \n",
  1284. " n_way, k_shot, q_query, loss_fn, \n",
  1285. " inner_train_step=3, train=False)\n",
  1286. " test_acc.append(acc)\n",
  1287. "print(\" Testing accuracy: \", \"%.3f %%\" % (np.mean(test_acc) * 100))"
  1288. ],
  1289. "execution_count": null,
  1290. "outputs": []
  1291. },
  1292. {
  1293. "cell_type": "markdown",
  1294. "metadata": {
  1295. "id": "rtD8X3RLf-6w"
  1296. },
  1297. "source": [
  1298. "## **Reference**\n",
  1299. "1. Chelsea Finn, Pieter Abbeel, & Sergey Levine. (2017). [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.](https://arxiv.org/abs/1703.03400)\n",
  1300. "1. Aniruddh Raghu, Maithra Raghu, Samy Bengio, & Oriol Vinyals. (2020). [Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness of MAML.](https://arxiv.org/abs/1909.09157)"
  1301. ]
  1302. }
  1303. ]
  1304. }