|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304 |
- {
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "name": "ML2021 HW15 Meta Learning.ipynb",
- "provenance": [],
- "collapsed_sections": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "name": "python"
- }
- },
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "wzVBe3h7Xh-2"
- },
- "source": [
- "<a name=\"top\"></a>\n",
- "# **HW15 Meta Learning: Few-shot Classification**\n",
- "\n",
- "Please mail to ntu-ml-2021spring-ta@googlegroups.com if you have any questions.\n",
- "\n",
- "Useful Links:\n",
- "1. [Go to hyperparameter setting.](#hyp)\n",
- "1. [Go to meta algorithm setting.](#modelsetting)\n",
- "1. [Go to main loop.](#mainloop)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "RdpzIMG6XsGK"
- },
- "source": [
- "## **Step 0: Check GPU**"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "zjjHsZbaL7SV"
- },
- "source": [
- "!nvidia-smi"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "cellView": "form",
- "id": "gWpc6vW3MQhv"
- },
- "source": [
- "#@markdown ### Install `qqdm`\n",
- "# Check if installed\n",
- "try:\n",
- " import qqdm\n",
- "except:\n",
- " ! pip install qqdm > /dev/null 2>&1\n",
- "print(\"Done!\")"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "bQ3wvyjnXwGX"
- },
- "source": [
- "## **Step 1: Download Data**\n",
- "\n",
- "Run the cell to download data, which has been pre-processed by TAs. \n",
- "The dataset has been augmented, so extra data augmentation is not required.\n"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "g7Gt4Jucug41"
- },
- "source": [
- "workspace_dir = '.'\n",
- "\n",
- "# gdown is a package that downloads files from google drive\n",
- "!gdown --id 1FLDrQ0k-iJ-mk8ors0WItqvwgu0w9J0U \\\n",
- " --output \"{workspace_dir}/Omniglot.tar.gz\""
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "AMGFHI9XX9ms"
- },
- "source": [
- "### Decompress the dataset\n",
- "\n",
- "Since the dataset is quite large, please wait and observe the main program [here](#mainprogram). \n",
- "You can come back here later by [*back to pre-process*](#preprocess)."
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "AvvlAQBUug42"
- },
- "source": [
- "# Use the `tar` command to decompress\n",
- "!tar -zxf \"{workspace_dir}/Omniglot.tar.gz\" \\\n",
- " -C \"{workspace_dir}/\""
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "T5P9eT0fYDqV"
- },
- "source": [
- "### Data Preview\n",
- "\n",
- "Just look at some data in the dataset."
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 297
- },
- "id": "7VtgHLurYE5x",
- "outputId": "961971b2-8b61-4d03-c06e-571a778ab52d"
- },
- "source": [
- "from PIL import Image\n",
- "from IPython.display import display\n",
- "for i in range(10, 20):\n",
- " im = Image.open(\"Omniglot/images_background/Japanese_(hiragana).0/character13/0500_\" + str (i) + \".png\")\n",
- " display(im)"
- ],
- "execution_count": null,
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABEUlEQVR4nGP8z4AbMOGRQ5F8uekLmux/BNjHs+0/CkDWqcKJppMFQv1jYGBgYGb69w/FMsb/DAwMDD+bnjMwMHxbb6nIyMDAwMBWqoyk8/+zRwwMDD//vGFmYGBgYBDhQNbJ8Pvvp7t/OL0nBENMZUG1s2vFsz+C71nYsPnz5QSz/Vt1f//DGgj/GV0MbEMYTqDJQrz7Sc/62VVZZtefKIEAlfy3XthM3TBD7Qu2EGL07ZPWncHGzog9bP/+/v1EvvUv9rBlYvk37VsgE1ad/34+zeeL+4UaKywMDAwM7w7/unH46puSKlZUjYz/GRj+z278y2xkbW7Cy4ApyfD1838mQVY0lzLAAx47IDqBDQpJAN4Euv7fFejQAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113AF743D0>"
- ]
- },
- "metadata": {
- "tags": []
- }
- },
- {
- "output_type": "display_data",
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABNklEQVR4nGP8zwAH//8xM6AAJiT2pdTXuCVfbvmGW5LhPwNuSUaGP7gl5ZkuoUqyMDAwMLw78I9PjVVYSu2AP9P//39ZUSQfFP/8/4dVWvER9y6GC08+T+dCltQ9zvD5wZPr9/7uPsMkzuXKgnAhHPz71aJ07eHHH3/gIghVDIwsEv9ERHF6RevDG9yBwMn8GZvk/2+3nvxnEOe4g+rR////////scmBX+nov+OCV/4jA4jkNR73ed4aD3M032GRvME/5ddV4QKRqX+xSH4vFM4/4cmg+eQ/Fsn/X6doiPDy7vmHVfL/3xdzjB2+o8r9h/mTSTxBn4UJ1SNIgfD1iTS6JDzgfxSJ7UEzFWbnr5fZQrP+YJf8FKcpsegHuhw0yti02QI9MGxkYPwPtZmREUMOJokdAAB60yoWf/hgewAAAABJRU5ErkJggg==\n",
- "text/plain": [
- "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113AEF16D0>"
- ]
- },
- "metadata": {
- "tags": []
- }
- },
- {
- "output_type": "display_data",
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABN0lEQVR4nGP8z4AbsCBz/j9lkGZE4jMi6/wTz7AQWTWKToafqMYy4bESXfI/bklG8Yc/cEoyGbz59R8OIA76eBNq2v/7P/c8gJnM6KHHwsDAcC7iN8y133KEYR7lMmdg/M/A8PUxTPWFtCXOMElGTkYWBgYGbg24rawSPEhuQATC/6dPGe7/Q/EKQvJa0GuGH39RPQpz+FNL0zNXW9gyv/1H8gyU/tuicOXf/6tcEs+RJGGB8HOnrwYjA1r4wSQ/3tP5/vXLpV+i7EiSsPh85/JEmJHh/eelfoyYkv8f7nvNwHDy7HkhbK79/+/fv19JFl/+Y3EQAwMjI+PrA36cWP35////m55y15E1/keSfOurdvAvdskfm/3kdv37j1Xy33Jh7ZWo+v7/h6fbV49UedGTIiO+7AAAZ4kCU7KEzEEAAAAASUVORK5CYII=\n",
- "text/plain": [
- "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113B605F10>"
- ]
- },
- "metadata": {
- "tags": []
- }
- },
- {
- "output_type": "display_data",
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABLUlEQVR4nGP8z4AbMKHx/3/7g1vyredeBIcFTfLP7c8w5r+/EMlX90UEISLv/315B5VbepDxPwMDA8OiAmYuiNDfZ0LcUNs/BkEkPzw4D3XHpyZHL0YIU9ydEeYVKP3dW3g51B2MCNcyQgCn47VfUCamVxjY/uH2JwpAlUQLS+RAeH3rnLbIN+ySn6JPsP35/18Kq7Hvz/edPhr75f9/bDoZGLgUGJveb32hgkUnn0H3ob/8vJ8PILT+R4BLMqIXDssyhPyGCTAiuf7fwQiuj9o/5FbC7EK2k8l+2RnFe20WjNiM/f///7+frp6v4Tz04HvyUlaIAbvOf1eNJQ8juKiSXy2lt/7FIfl7Dl/rn//YJf+tFrR7jKwY2Z8MZw/5qDAi8VEkGf4jSzEwAABSseqGZyInRAAAAABJRU5ErkJggg==\n",
- "text/plain": [
- "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F114126CE90>"
- ]
- },
- "metadata": {
- "tags": []
- }
- },
- {
- "output_type": "display_data",
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABJklEQVR4nGP8z4AbMKFy/+GRfF54D7fkh8Wv8RgLAxCXsGCR+Pfx0BtTfZjkz5dQNz/79+Lhv32HTj4TbtFjZGBg/M/AwHDO/xdE8s87Qdb/v3WsnYz5WBmgkh9PQ73wqLBVg0FAhwPmkv/I4JrgCWQukoN+/2NECy645P9bTXf4PVFDCG7sU33JNFMOwZvIxsIk/03k2/PnvpTsW2RJeAh9lzRjFtdBNRUuyfjzCwOHP2powhzEaPBypvvlvWjOhZn/xlZYUIgH1U641/5/e/6bufT8BSFs/mTkVmH4+gKHgxgYGBh+PmdhxCn5/a85D1YH/f///0Mk336UeEBI/ntXzdv9C7vkvyOWvJnf/2OX/GxlMBNNDuHPvxdlRFGcygBNJrgAAEPeDmCQZ6aqAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113AEF1810>"
- ]
- },
- "metadata": {
- "tags": []
- }
- },
- {
- "output_type": "display_data",
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABOElEQVR4nGP8z4AbMKHxXz1AVv0fBfxJt/+B4KHp/P/j7U+cxrKEvHyD2041ppP/cUrKqh34D3EKAwMDC5okq9Sdrz9/3v50ysqTEUny//9/v169/K+01e/Rjz8MUppQnf8YGBi+3bxw+fGLu78Y/nxlCjHXZRfiYmRgYPzP8KfzPsO/iw+4FUUEHAyZX0R0xcIcwsLA8P/lQwYmTzsDXlbGX6f/fGf8zYgcQr9//fr19/////9fRgsKczLO/occQiysrKxMDAwM/5fumr0vmEUcrhPZK38OmvowCkiaYQ0EJoWbuyvmWggjRJDj5IkrPw9D8V84Hyb5bsO3////v9sQJHbxP4bkWYmajXcfnM4QWPQHU/JXjbSYqKiwRj9SXP9nhEXQz/cfH/3n0BJCdiEjKQlscEsCAN5i3onYmdekAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113AEF1510>"
- ]
- },
- "metadata": {
- "tags": []
- }
- },
- {
- "output_type": "display_data",
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABO0lEQVR4nGP8z4AbMOGRo5bks4v/UCRZkNj/ly06zs3w4+eLf4yK7OiSDH9+/Xl+ZObDNwxcG00Qkk+v/WdgYPh/+1XWic9ySSa87DpIxu4v/sfAwMDw69OtpFAxfkaYSYz/GRgYGL68ZWBgYGA4H39IDy4D18nDw8DAwMDwnIkNWY6CQIB55T+2CIBI/j1+8B4DA5P0P2ySF4NYVVkYPi1m/IEq+///////Fwge+/Hz5/sipvif/5EARPKW2Mx/////P8wi+w5ZEmKsrF4vhx3jh/6/aO6CqLkfxS0iIiQbJvsWWScjVOnnM78ZWNTW95wXwuJPXkcGBoY/x5S5cIbQlYOubJh2/v318+fPn7ctLd78x3Dt//Uz/zEwMDxmXymMGUKMCop/GBgY1BI0UAMI6tp/mA5ASGIHADm3qpNJq4xdAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113B605750>"
- ]
- },
- "metadata": {
- "tags": []
- }
- },
- {
- "output_type": "display_data",
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABI0lEQVR4nGP8z4AbMKFy7yz/jsz9jwz+5su//P///79///7+/v//Pwuqzi9KfP9/Xdv1juEXZxMLA5Lk/39/3r153Mdw9BS/CANjNBMDAyPEQf8ffv92/vizq8+4pBlVg12FGBhYmeB2/nARFJSxTpokWfXpy89/MCdAjWWd8IVRjo/9+3Q+HkaERVBJJm0Ghp9ff6B5GuGg36WbGNXe4AiEd5tMC869xqHzwPcq7XuTsev8/5lTnlWPEUUSKRB+3+L/z4VdklH8qyfj/x84dHrs/8XwOxGHJKshA8NXjof/mLF5hYGBgYFD/84/bK5lYPj///+PxziMfbTg6/+nZ9KYsEo+2PWHgasoE8lKWHwyMDD8/8XAwMiKEgqMJKS+gZcEAF56gf6wykc6AAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113AEF1910>"
- ]
- },
- "metadata": {
- "tags": []
- }
- },
- {
- "output_type": "display_data",
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABXklEQVR4nGP8z4AbMOGRI1Xy9z/ckp+TVv3HKfnn6Lq/MDYLuiSHMkTjv98MzFDJv/9ZGBj+//339ef1symvmN49/bnn1H9rxv8MDAwM/yecjhFmuLPl7bMPv98IszH8ZmCUtuW0h+rkvhT3n4FVV95M/1tpgBsDpxazIDcjIyPUhq/P/zJwSLIxMr4z9u1nRnUQEy8vAwMDw3+G/yw8H+GOQ3bt/w9bnrOZfTZjwiL5/0LaPd63PJ/YsYXQ6wSezacOW/6+jYio/zDwb7Hkuf///5/ij/sDE0Lo/H9IXYOB4deET9+whS3vrbmvfx3axmTJjGns/6dJMpp+4gxid+EiSJL/f98sc7Ph6PmDVfL//++zZWM+/8cu+TZfOOrdf2yS/15u8BCpe/8fm+TXfm1Bj4O//2OT/Dtf0O/g1///sUreMF707T86gEp+S1/2B0PuPzSZbPscgpHUGBgAt9BS1wiwXusAAAAASUVORK5CYII=\n",
- "text/plain": [
- "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113AEF16D0>"
- ]
- },
- "metadata": {
- "tags": []
- }
- },
- {
- "output_type": "display_data",
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABQElEQVR4nGP8z4AbMKHwfvxA4TIi6/xbyNDPjMRnQVb5/wkDii2oxuK1Ew2gGMvwn5EBZjAjTPL3W4jAp6uGT86f+MfAwMDAnKwKde25oF8MDAwMDP9e8nD/kWNhYGBg4JymDZV8d+AvAwMDA8OX0mBXDSWIZ9gYGRgY/kPBv68Hdl6SmfXvPxKAOuj/g92rzzEwfcDmlf97nKs4Vu3y/IPml////////9VI5fL3//+v8KMaC9HJbs8kysHAIMT9D4vO/9eEJ/z59y6HMfEPFgcp+7XfF9tyjfMDSsBDJdl6eXf+F1s1GdVUeHz++cnAxOAlvAI5sOGxwsLNzfn9njlyXKNF2X8BLIGAAyBZ8f/zge9oamF++nG/S59L9RayN//DJP8tlpTL3XwbJfT+w73y7KShLDOqoajpFh3gdS0Aq5C/ToYG3GgAAAAASUVORK5CYII=\n",
- "text/plain": [
- "<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28 at 0x7F113B605750>"
- ]
- },
- "metadata": {
- "tags": []
- }
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "baVsWfcSYHVN"
- },
- "source": [
- "## **Step 2: Build the model**"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "gqiOdDLgYOlQ"
- },
- "source": [
- "### Library importation"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "-9pfkqh8gxHD"
- },
- "source": [
- "# Import modules we need\n",
- "import glob, random\n",
- "from collections import OrderedDict\n",
- "\n",
- "import numpy as np\n",
- "\n",
- "try:\n",
- " from qqdm.notebook import qqdm as tqdm\n",
- "except ModuleNotFoundError:\n",
- " from tqdm.auto import tqdm\n",
- "\n",
- "import torch, torch.nn as nn\n",
- "import torch.nn.functional as F\n",
- "from torch.utils.data import DataLoader, Dataset\n",
- "import torchvision.transforms as transforms\n",
- "\n",
- "from PIL import Image\n",
- "from IPython.display import display\n",
- "\n",
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- "\n",
- "# fix random seeds\n",
- "random_seed = 0\n",
- "random.seed(random_seed)\n",
- "np.random.seed(random_seed)\n",
- "torch.manual_seed(random_seed)\n",
- "if torch.cuda.is_available():\n",
- " torch.cuda.manual_seed_all(random_seed)"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "3TlwLtC1YRT7"
- },
- "source": [
- "### Model Construction Preliminaries\n",
- "\n",
- "Since our task is image classification, we need to build a CNN-based model. \n",
- "However, to implement MAML algorithm, we should adjust some code in `nn.Module`.\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "dFwB3tuEDYfy"
- },
- "source": [
- "Take a look at MAML pseudocode...\n",
- "\n",
- "<img src=\"https://i.imgur.com/9aHlvfX.png\" width=\"50%\" />\n",
- "\n",
- "On the 10-th line, what we take gradients on are those $\\theta$ representing \n",
- "<font color=\"#0CC\">**the original model parameters**</font> (outer loop) instead of those in the \n",
- "<font color=\"#0C0\">**inner loop**</font>, so we need to use `functional_forward` to compute the output \n",
- "logits of input image instead of `forward` in `nn.Module`.\n",
- "\n",
- "The following defines these functions.\n",
- "\n",
- "<!-- 由於在第10行,我們是要對原本的參數 θ 微分,並非 inner-loop (Line5~8) 的 θ' 微分,因此在 inner-loop,我們需要用 functional forward 的方式算出 input image 的 output logits,而不是直接用 nn.module 裡面的 forward(直接對 θ 微分)。在下面我們分別定義了 functional forward 以及 forward 函數。 -->"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "iuYQiPeQYc__"
- },
- "source": [
- "### Model block definition"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "GgFbbKHYg3Hk"
- },
- "source": [
- "def ConvBlock(in_ch: int, out_ch: int):\n",
- " return nn.Sequential(\n",
- " nn.Conv2d(in_ch, out_ch, 3, padding=1),\n",
- " nn.BatchNorm2d(out_ch),\n",
- " nn.ReLU(),\n",
- " nn.MaxPool2d(kernel_size=2, stride=2)\n",
- " )\n",
- "\n",
- "def ConvBlockFunction(x, w, b, w_bn, b_bn):\n",
- " x = F.conv2d(x, w, b, padding=1)\n",
- " x = F.batch_norm(x,\n",
- " running_mean=None,\n",
- " running_var=None,\n",
- " weight=w_bn, bias=b_bn,\n",
- " training=True)\n",
- " x = F.relu(x)\n",
- " x = F.max_pool2d(x, kernel_size=2, stride=2)\n",
- " return x"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "iQEzgWN7fi7B"
- },
- "source": [
- "### Model definition"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "0bFBGEQoHQUW"
- },
- "source": [
- "class Classifier(nn.Module):\n",
- " def __init__(self, in_ch, k_way):\n",
- " super(Classifier, self).__init__()\n",
- " self.conv1 = ConvBlock(in_ch, 64)\n",
- " self.conv2 = ConvBlock(64, 64)\n",
- " self.conv3 = ConvBlock(64, 64)\n",
- " self.conv4 = ConvBlock(64, 64)\n",
- " self.logits = nn.Linear(64, k_way)\n",
- "\n",
- " def forward(self, x):\n",
- " x = self.conv1(x)\n",
- " x = self.conv2(x)\n",
- " x = self.conv3(x)\n",
- " x = self.conv4(x)\n",
- " x = x.view(x.shape[0], -1)\n",
- " x = self.logits(x)\n",
- " return x\n",
- "\n",
- " def functional_forward(self, x, params):\n",
- " '''\n",
- " Arguments:\n",
- " x: input images [batch, 1, 28, 28]\n",
- " params: model parameters, \n",
- " i.e. weights and biases of convolution\n",
- " and weights and biases of \n",
- " batch normalization\n",
- "              type is an OrderedDict\n",
- "        '''\n",
- " for block in [1, 2, 3, 4]:\n",
- " x = ConvBlockFunction(\n",
- " x,\n",
- " params[f'conv{block}.0.weight'],\n",
- " params[f'conv{block}.0.bias'],\n",
- " params.get(f'conv{block}.1.weight'),\n",
- " params.get(f'conv{block}.1.bias'))\n",
- " x = x.view(x.shape[0], -1)\n",
- " x = F.linear(x,\n",
- " params['logits.weight'],\n",
- " params['logits.bias'])\n",
- " return x"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "gmJq_0B9Yj0G"
- },
- "source": [
- "### Create Label\n",
- "\n",
- "This function is used to create labels. \n",
- "In an N-way K-shot few-shot classification problem,\n",
- "each task has `n_way` classes, while there are `k_shot` images for each class. \n",
- "This is a function that creates such labels.\n"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "GQF5vgLvg5aX",
- "outputId": "5df41e04-290c-428b-b06f-cc749f09f027"
- },
- "source": [
- "def create_label(n_way, k_shot):\n",
- " return (torch.arange(n_way)\n",
- " .repeat_interleave(k_shot)\n",
- " .long())\n",
- "\n",
- "# Try to create labels for 5-way 2-shot setting\n",
- "create_label(5, 2)"
- ],
- "execution_count": null,
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])"
- ]
- },
- "metadata": {
- "tags": []
- },
- "execution_count": 9
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "2nCFv9PGw50J"
- },
- "source": [
- "### Accuracy calculation"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "FahDr0xQw50S"
- },
- "source": [
- "def calculate_accuracy(logits, val_label):\n",
- " \"\"\" utility function for accuracy calculation \"\"\"\n",
- " acc = np.asarray([(\n",
- " torch.argmax(logits, -1).cpu().numpy() == val_label.cpu().numpy())]\n",
- " ).mean() \n",
- " return acc"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "9Hl7ro2mYzsI"
- },
- "source": [
- "### Define Dataset\n",
- "\n",
- "Define the dataset. \n",
- "The dataset returns images of a random character, with (`k_shot + q_query`) images, \n",
- "so the size of returned tensor is `[k_shot+q_query, 1, 28, 28]`. \n"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "-tJ2mot9hHPb"
- },
- "source": [
- "class Omniglot(Dataset):\n",
- " def __init__(self, data_dir, k_way, q_query):\n",
- " self.file_list = [f for f in glob.glob(\n",
- " data_dir + \"**/character*\", \n",
- " recursive=True)]\n",
- " self.transform = transforms.Compose(\n",
- " [transforms.ToTensor()])\n",
- " self.n = k_way + q_query\n",
- "\n",
- " def __getitem__(self, idx):\n",
- " sample = np.arange(20)\n",
- "\n",
- " # For random sampling the characters we want.\n",
- " np.random.shuffle(sample) \n",
- " img_path = self.file_list[idx]\n",
- " img_list = [f for f in glob.glob(\n",
- " img_path + \"**/*.png\", recursive=True)]\n",
- " img_list.sort()\n",
- " imgs = [self.transform(\n",
- " Image.open(img_file)) \n",
- " for img_file in img_list]\n",
- " # `k_way + q_query` examples for each character\n",
- " imgs = torch.stack(imgs)[sample[:self.n]] \n",
- " return imgs\n",
- "\n",
- " def __len__(self):\n",
- " return len(self.file_list)"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Gm5iVp90Ylii"
- },
- "source": [
- "## **Step 3: Core MAML**\n",
- "\n",
- "Here is the main Meta Learning algorithm. \n",
- "The algorithm is exactly the same as the paper. \n",
- "What the function does is to update the parameters using \"the data of a meta-batch.\"\n",
- "Here we implement the second-order MAML (inner_train_step = 1), according to [the slides of meta learning in 2019 (p. 13 ~ p.18)](http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf#page=13&view=FitW)\n",
- "\n",
- "As for the mathematical derivation of the first-order version, please refer to [p.25 of the slides in 2019](http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf#page=25&view=FitW).\n",
- "\n",
- "The following is the algorithm with some explanation."
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "KjNxrWW_yNck"
- },
- "source": [
- "def OriginalMAML(\n",
- " model, optimizer, x, n_way, k_shot, q_query, loss_fn,\n",
- " inner_train_step=1, inner_lr=0.4, train=True):\n",
- " criterion, task_loss, task_acc = loss_fn, [], []\n",
- "\n",
- " for meta_batch in x:\n",
- " # Get data\n",
- " support_set = meta_batch[: n_way * k_shot] \n",
- " query_set = meta_batch[n_way * k_shot :] \n",
- " \n",
- " # Copy the params for inner loop\n",
- " fast_weights = OrderedDict(model.named_parameters())\n",
- " \n",
- " ### ---------- INNER TRAIN LOOP ---------- ###\n",
- " for inner_step in range(inner_train_step): \n",
- " # Simply training\n",
- " train_label = create_label(n_way, k_shot) \\\n",
- " .to(device)\n",
- " logits = model.functional_forward(\n",
- " support_set, fast_weights)\n",
- " loss = criterion(logits, train_label)\n",
- " # Inner gradients update! vvvvvvvvvvvvvvvvvvvv #\n",
- " \"\"\" Inner Loop Update \"\"\" #\n",
- " grads = torch.autograd.grad( #\n",
- " loss, fast_weights.values(), #\n",
- " create_graph=True) #\n",
- " # Perform SGD #\n",
- " fast_weights = OrderedDict( #\n",
- " (name, param - inner_lr * grad) #\n",
- " for ((name, param), grad) #\n",
- " in zip(fast_weights.items(), grads)) #\n",
- " # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ #\n",
- "\n",
- " ### ---------- INNER VALID LOOP ---------- ###\n",
- " val_label = create_label(n_way, q_query).to(device)\n",
- " \n",
- " # Collect gradients for outer loop\n",
- " logits = model.functional_forward(\n",
- " query_set, fast_weights) \n",
- " loss = criterion(logits, val_label)\n",
- " task_loss.append(loss)\n",
- " task_acc.append(\n",
- " calculate_accuracy(logits, val_label))\n",
- "\n",
- " # Update outer loop\n",
- " model.train()\n",
- " optimizer.zero_grad()\n",
- "\n",
- " meta_batch_loss = torch.stack(task_loss).mean()\n",
- " if train:\n",
- " meta_batch_loss.backward() # <--- May change later!\n",
- " optimizer.step()\n",
- " task_acc = np.mean(task_acc)\n",
- " return meta_batch_loss, task_acc"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "MF5ZahPdxKbp"
- },
- "source": [
- "## Variations of MAML\n",
- "\n",
- "### First-order approximation of MAML (FOMAML)\n",
- "\n",
- "Slightly modify the MAML mentioned earlier, applying first-order approximation to decrease the amount of computation.\n",
- "\n",
- "### Almost No Inner Loop (ANIL)\n",
- "\n",
- "The algorithm from [this paper](https://arxiv.org/abs/1909.09157), using the technique of feature reuse to decrease the amount of computation."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "qyQ7ZUN4foh-"
- },
- "source": [
- "To finish the modification required, we need to change some blocks of the MAML algorithm. \n",
- "Below, we have extracted three parts that may be modified into separate functions. \n",
- "Please choose to replace the functions with their alternative versions to complete the algorithm."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Ne5cOja0H8H7"
- },
- "source": [
- "### Part 1: Inner loop update"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "LChAX51sIFwi"
- },
- "source": [
- "MAML"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "Aqgb0kEVzQol"
- },
- "source": [
- "def inner_update_MAML(fast_weights, loss, inner_lr):\n",
- " \"\"\" Inner Loop Update \"\"\"\n",
- " grads = torch.autograd.grad(\n",
- " loss, fast_weights.values(), create_graph=True)\n",
- " # Perform SGD\n",
- " fast_weights = OrderedDict(\n",
- " (name, param - inner_lr * grad)\n",
- " for ((name, param), grad) in zip(fast_weights.items(), grads))\n",
- " return fast_weights"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "QnQ_BN-L2Gd7"
- },
- "source": [
- "Alternatives"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "Ug5LIO6V15cd"
- },
- "source": [
- "def inner_update_alt1(fast_weights, loss, inner_lr):\n",
- " grads = torch.autograd.grad(\n",
- " loss, fast_weights.values(), create_graph=False)\n",
- " # Perform SGD\n",
- " fast_weights = OrderedDict(\n",
- " (name, param - inner_lr * grad)\n",
- " for ((name, param), grad) in zip(fast_weights.items(), grads))\n",
- " return fast_weights\n",
- "\n",
- "def inner_update_alt2(fast_weights, loss, inner_lr):\n",
- " grads = torch.autograd.grad(\n",
- " loss, list(fast_weights.values())[-2:], create_graph=True)\n",
- " # Split out the logits\n",
- " for ((name, param), grad) in zip(\n",
- " list(fast_weights.items())[-2:], grads):\n",
- " fast_weights[name] = param - inner_lr * grad\n",
- " return fast_weights"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "1ZfaWPMt164t"
- },
- "source": [
- "### Part 2: Collect gradients"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "-W7zL2nN164u"
- },
- "source": [
- "MAML \n",
- "(This actually does nothing, as gradients are computed by PyTorch automatically.)"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "sgcuPPm2zSFL"
- },
- "source": [
- "def collect_gradients_MAML(\n",
- " special_grad: OrderedDict, fast_weights, model, len_data):\n",
- " \"\"\" Actually do nothing (just backwards later) \"\"\"\n",
- " return special_grad"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "2OxEME6l2QOO"
- },
- "source": [
- "Alternatives"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "fWLYwZlM2RZO"
- },
- "source": [
- "def collect_gradients_alt(\n",
- " special_grad: OrderedDict, fast_weights, model, len_data):\n",
- " \"\"\" Special gradient calculation \"\"\"\n",
- " diff = OrderedDict(\n",
- " (name, params - fast_weights[name]) \n",
- " for (name, params) in model.named_parameters())\n",
- " for name in diff:\n",
- " special_grad[name] = special_grad.get(name, 0) + diff[name] / len_data\n",
- " return special_grad"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ahqE-Sf92TID"
- },
- "source": [
- "### Part 3: Outer loop gradients calculation"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "-wr0hSd02TIE"
- },
- "source": [
- "MAML \n",
- "(Simply call PyTorch `backward`.)"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "_hBSQ02xzTXb"
- },
- "source": [
- "def outer_update_MAML(model, meta_batch_loss, grad_tensors):\n",
- " \"\"\" Simply backwards \"\"\"\n",
- " meta_batch_loss.backward()"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Q4zxf6yr2TIE"
- },
- "source": [
- "Alternatives"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "DEyCwYmI2bdC"
- },
- "source": [
- "def outer_update_alt(model, meta_batch_loss, grad_tensors):\n",
- " \"\"\" Replace the gradients\n",
- " with precalculated tensors \"\"\"\n",
- " for (name, params) in model.named_parameters():\n",
- " params.grad = grad_tensors[name]"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "z1jck3KE2g1D"
- },
- "source": [
- "### Complete the algorithm\n",
- "Here we have wrapped the algorithm in `MetaAlgorithmGenerator`. \n",
- "You can get your modified algorithm by filling in like this:\n",
- "```python\n",
- "MyAlgorithm = MetaAlgorithmGenerator(inner_update=inner_update_alt2)\n",
- "```\n",
- "By default, the three blocks will be filled with those of `MAML`."
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "XosNxVMDxL6V"
- },
- "source": [
- "def MetaAlgorithmGenerator(\n",
- " inner_update = inner_update_MAML, \n",
- " collect_gradients = collect_gradients_MAML, \n",
- " outer_update = outer_update_MAML):\n",
- "\n",
- " global calculate_accuracy\n",
- "\n",
- " def MetaAlgorithm(\n",
- " model, optimizer, x, n_way, k_shot, q_query, loss_fn,\n",
- " inner_train_step=1, inner_lr=0.4, train=True): \n",
- " criterion = loss_fn\n",
- " task_loss, task_acc = [], []\n",
- " special_grad = OrderedDict() # Added for variants!\n",
- "\n",
- " for meta_batch in x:\n",
- " support_set = meta_batch[: n_way * k_shot] \n",
- " query_set = meta_batch[n_way * k_shot :] \n",
- " \n",
- " fast_weights = OrderedDict(model.named_parameters())\n",
- " \n",
- " ### ---------- INNER TRAIN LOOP ---------- ###\n",
- " for inner_step in range(inner_train_step): \n",
- " train_label = create_label(n_way, k_shot).to(device)\n",
- " logits = model.functional_forward(support_set, fast_weights)\n",
- " loss = criterion(logits, train_label)\n",
- "\n",
- " fast_weights = inner_update(fast_weights, loss, inner_lr)\n",
- "\n",
- " ### ---------- INNER VALID LOOP ---------- ###\n",
- " val_label = create_label(n_way, q_query).to(device)\n",
- " # FIXME: W for val?\n",
- " special_grad = collect_gradients(\n",
- " special_grad, fast_weights, model, len(x))\n",
- " \n",
- " # Collect gradients for outer loop\n",
- " logits = model.functional_forward(query_set, fast_weights) \n",
- " loss = criterion(logits, val_label)\n",
- " task_loss.append(loss)\n",
- " task_acc.append(calculate_accuracy(logits, val_label))\n",
- "\n",
- " # Update outer loop\n",
- " model.train()\n",
- " optimizer.zero_grad()\n",
- "\n",
- " meta_batch_loss = torch.stack(task_loss).mean()\n",
- " if train:\n",
- " # Notice the update part!\n",
- " outer_update(model, meta_batch_loss, special_grad)\n",
- " optimizer.step()\n",
- " task_acc = np.mean(task_acc)\n",
- " return meta_batch_loss, task_acc\n",
- " return MetaAlgorithm"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "jEsPtV-GzbDv",
- "cellView": "form"
- },
- "source": [
- "#@title Here is the answer hidden, please fill in yourself!\n",
- "Give_me_the_answer = True #@param {\"type\": \"boolean\"}\n",
- "\n",
- "def HiddenAnswer():\n",
- " MAML = MetaAlgorithmGenerator()\n",
- " FOMAML = MetaAlgorithmGenerator(inner_update=inner_update_alt1)\n",
- " ANIL = MetaAlgorithmGenerator(inner_update=inner_update_alt2)\n",
- " return MAML, FOMAML, ANIL"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "2P__5N2Yz9O4"
- },
- "source": [
- "# `HiddenAnswer` is hidden in the last cell.\n",
- "if Give_me_the_answer:\n",
- " MAML, FOMAML, ANIL = HiddenAnswer()\n",
- "else: \n",
- " # TODO: Please fill in the function names \\\n",
- " # as the function arguments to finish the algorithm.\n",
- " MAML = MetaAlgorithmGenerator()\n",
- " FOMAML = MetaAlgorithmGenerator()\n",
- " ANIL = MetaAlgorithmGenerator()"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "nBoRBhVlZAST"
- },
- "source": [
- "## **Step 4: Initialization**\n",
- "\n",
- "After defining all components we need, the following initialize a model before training."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Ip-i7aseftUF"
- },
- "source": [
- "<a name=\"hyp\"></a>\n",
- "### Hyperparameters \n",
- "[Go back to top!](#top)"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "0wFHmVcBhE4M"
- },
- "source": [
- "n_way = 5\n",
- "k_shot = 1\n",
- "q_query = 1\n",
- "inner_train_step = 1\n",
- "inner_lr = 0.4\n",
- "meta_lr = 0.001\n",
- "meta_batch_size = 32\n",
- "max_epoch = 30\n",
- "eval_batches = test_batches = 20\n",
- "train_data_path = './Omniglot/images_background/'\n",
- "test_data_path = './Omniglot/images_evaluation/' "
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Uvzo7NVpfu5V"
- },
- "source": [
- "### Dataloader initialization"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "3I13GJavhP0_"
- },
- "source": [
- "def dataloader_init(datasets, num_workers=2):\n",
- " train_set, val_set, test_set = datasets\n",
- " train_loader = DataLoader(train_set,\n",
- " # The \"batch_size\" here is not \\\n",
- " # the meta batch size, but \\\n",
- " # how many different \\\n",
- " # characters in a task, \\\n",
- " # i.e. the \"n_way\" in \\\n",
- " # few-shot classification.\n",
- " batch_size=n_way,\n",
- " num_workers=num_workers,\n",
- " shuffle=True,\n",
- " drop_last=True)\n",
- " val_loader = DataLoader(val_set,\n",
- " batch_size=n_way,\n",
- " num_workers=num_workers,\n",
- " shuffle=True,\n",
- " drop_last=True)\n",
- " test_loader = DataLoader(test_set,\n",
- " batch_size=n_way,\n",
- " num_workers=num_workers,\n",
- " shuffle=True,\n",
- " drop_last=True)\n",
- " train_iter = iter(train_loader)\n",
- " val_iter = iter(val_loader)\n",
- " test_iter = iter(test_loader)\n",
- " return (train_loader, val_loader, test_loader), \\\n",
- " (train_iter, val_iter, test_iter)\n",
- "\n",
- "train_set, val_set = torch.utils.data.random_split(\n",
- " Omniglot(train_data_path, k_shot, q_query), [3200, 656])\n",
- "test_set = Omniglot(test_data_path, k_shot, q_query)\n",
- "\n",
- "(train_loader, val_loader, test_loader), \\\n",
- "(train_iter, val_iter, test_iter) = dataloader_init(\n",
- " (train_set, val_set, test_set))"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "KVund--bfw0e"
- },
- "source": [
- "### Model & optimizer initialization"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "Kxug882ihF2B"
- },
- "source": [
- "def model_init():\n",
- " meta_model = Classifier(1, n_way).to(device)\n",
- " optimizer = torch.optim.Adam(meta_model.parameters(), \n",
- " lr=meta_lr)\n",
- " loss_fn = nn.CrossEntropyLoss().to(device)\n",
- " return meta_model, optimizer, loss_fn\n",
- "\n",
- "meta_model, optimizer, loss_fn = model_init()"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "gj8cLRNLf2zg"
- },
- "source": [
- "### Utility function to get a meta-batch"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "zrkCSsxOhC-N"
- },
- "source": [
- "def get_meta_batch(meta_batch_size,\n",
- "                   k_shot, q_query, \n",
- "                   data_loader, iterator):\n",
- "    data = []\n",
- "    for _ in range(meta_batch_size):\n",
- "        try:\n",
- "            # a \"task_data\" tensor represents \\\n",
- "            # the data of a task, with size of \\\n",
- "            # [n_way, k_shot+q_query, 1, 28, 28]\n",
- "            # NOTE: use the built-in next(); the \\\n",
- "            # iterator's .next() method was \\\n",
- "            # removed in PyTorch 1.13.\n",
- "            task_data = next(iterator)\n",
- "        except StopIteration:\n",
- "            # Loader exhausted: restart from a fresh iterator.\n",
- "            iterator = iter(data_loader)\n",
- "            task_data = next(iterator)\n",
- "        train_data = (task_data[:, :k_shot]\n",
- "                      .reshape(-1, 1, 28, 28))\n",
- "        val_data = (task_data[:, k_shot:]\n",
- "                    .reshape(-1, 1, 28, 28))\n",
- "        task_data = torch.cat(\n",
- "            (train_data, val_data), 0)\n",
- "        data.append(task_data)\n",
- "    return torch.stack(data).to(device), iterator"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "O5JCtob4fyh_"
- },
- "source": [
- "<a name=\"modelsetting\"></a>\n",
- "### Choose the meta learning algorithm\n",
- "[Go back to top!](#top)"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "3av6pAI7OxOP"
- },
- "source": [
- "# You can change this to `FOMAML` or `ANIL`\n",
- "MetaAlgorithm = MAML"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "pWQczA3FwjEG"
- },
- "source": [
- "<a name=\"mainprog\" id=\"mainprog\"></a>\n",
- "## **Step 5: Main program for training & testing**"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "8EirEnaof7ep"
- },
- "source": [
- "### Start training!\n",
- "<a name=\"mainloop\"></a>\n",
- "[Go back to top!](#top)"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "JQZjJrLAhBWw"
- },
- "source": [
- "for epoch in range(max_epoch):\n",
- " print(\"Epoch %d\" % (epoch + 1))\n",
- " train_meta_loss = []\n",
- " train_acc = []\n",
- "    # The \"step\" here is a meta-gradient update step\n",
- " for step in tqdm(range(\n",
- " len(train_loader) // meta_batch_size)): \n",
- " x, train_iter = get_meta_batch(\n",
- " meta_batch_size, k_shot, q_query, \n",
- " train_loader, train_iter)\n",
- " meta_loss, acc = MetaAlgorithm(\n",
- " meta_model, optimizer, x, \n",
- " n_way, k_shot, q_query, loss_fn)\n",
- " train_meta_loss.append(meta_loss.item())\n",
- " train_acc.append(acc)\n",
- " print(\" Loss : \", \"%.3f\" % (np.mean(train_meta_loss)), end='\\t')\n",
- " print(\" Accuracy: \", \"%.3f %%\" % (np.mean(train_acc) * 100))\n",
- "\n",
- " # See the validation accuracy after each epoch.\n",
- "    # You are welcome to implement early stopping.\n",
- " val_acc = []\n",
- " for eval_step in tqdm(range(\n",
- " len(val_loader) // (eval_batches))):\n",
- " x, val_iter = get_meta_batch(\n",
- " eval_batches, k_shot, q_query, \n",
- " val_loader, val_iter)\n",
- " # We update three inner steps when testing.\n",
- " _, acc = MetaAlgorithm(meta_model, optimizer, x, \n",
- " n_way, k_shot, q_query, \n",
- " loss_fn, \n",
- " inner_train_step=3, \n",
- " train=False) \n",
- " val_acc.append(acc)\n",
- " print(\" Validation accuracy: \", \"%.3f %%\" % (np.mean(val_acc) * 100))"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "u5Ew8-POf9sw"
- },
- "source": [
- "### Testing the result"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "CYN_zGB3g_5_"
- },
- "source": [
- "test_acc = []\n",
- "for test_step in tqdm(range(\n",
- " len(test_loader) // (test_batches))):\n",
- " x, test_iter = get_meta_batch(\n",
- " test_batches, k_shot, q_query, \n",
- " test_loader, test_iter)\n",
- " # When testing, we update 3 inner-steps\n",
- " _, acc = MetaAlgorithm(meta_model, optimizer, x, \n",
- " n_way, k_shot, q_query, loss_fn, \n",
- " inner_train_step=3, train=False)\n",
- " test_acc.append(acc)\n",
- "print(\" Testing accuracy: \", \"%.3f %%\" % (np.mean(test_acc) * 100))"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "rtD8X3RLf-6w"
- },
- "source": [
- "## **Reference**\n",
- "1. Chelsea Finn, Pieter Abbeel, & Sergey Levine. (2017). [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.](https://arxiv.org/abs/1703.03400)\n",
- "1. Aniruddh Raghu, Maithra Raghu, Samy Bengio, & Oriol Vinyals. (2020). [Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness of MAML.](https://arxiv.org/abs/1909.09157)"
- ]
- }
- ]
- }
|