diff --git a/13 Meta Learning/作业HW15/HW15.pdf b/13 Meta Learning/作业HW15/HW15.pdf new file mode 100644 index 0000000..29af013 Binary files /dev/null and b/13 Meta Learning/作业HW15/HW15.pdf differ diff --git a/13 Meta Learning/作业HW15/ML2021_HW15_Meta_Learning.ipynb b/13 Meta Learning/作业HW15/ML2021_HW15_Meta_Learning.ipynb new file mode 100644 index 0000000..afad0b5 --- /dev/null +++ b/13 Meta Learning/作业HW15/ML2021_HW15_Meta_Learning.ipynb @@ -0,0 +1,1304 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "ML2021 HW15 Meta Learning.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "wzVBe3h7Xh-2" + }, + "source": [ + "\n", + "# **HW15 Meta Learning: Few-shot Classification**\n", + "\n", + "Please mail to ntu-ml-2021spring-ta@googlegroups.com if you have any questions.\n", + "\n", + "Useful Links:\n", + "1. [Go to hyperparameter setting.](#hyp)\n", + "1. [Go to meta algorithm setting.](#modelsetting)\n", + "1. [Go to main loop.](#mainloop)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RdpzIMG6XsGK" + }, + "source": [ + "## **Step 0: Check GPU**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "zjjHsZbaL7SV" + }, + "source": [ + "!nvidia-smi" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "gWpc6vW3MQhv" + }, + "source": [ + "#@markdown ### Install `qqdm`\n", + "# Check if installed\n", + "try:\n", + " import qqdm\n", + "except:\n", + " ! pip install qqdm > /dev/null 2>&1\n", + "print(\"Done!\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bQ3wvyjnXwGX" + }, + "source": [ + "## **Step 1: Download Data**\n", + "\n", + "Run the cell to download data, which has been pre-processed by TAs. \n", + "The dataset has been augmented, so extra data augmentation is not required.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "g7Gt4Jucug41" + }, + "source": [ + "workspace_dir = '.'\n", + "\n", + "# gdown is a package that downloads files from google drive\n", + "!gdown --id 1FLDrQ0k-iJ-mk8ors0WItqvwgu0w9J0U \\\n", + " --output \"{workspace_dir}/Omniglot.tar.gz\"" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AMGFHI9XX9ms" + }, + "source": [ + "### Decompress the dataset\n", + "\n", + "Since the dataset is quite large, please wait and observe the main program [here](#mainprogram). \n", + "You can come back here later by [*back to pre-process*](#preprocess)." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "AvvlAQBUug42" + }, + "source": [ + "# Use `tar' command to decompress\n", + "!tar -zxf \"{workspace_dir}/Omniglot.tar.gz\" \\\n", + " -C \"{workspace_dir}/\"" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "T5P9eT0fYDqV" + }, + "source": [ + "### Data Preview\n", + "\n", + "Just look at some data in the dataset." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 297 + }, + "id": "7VtgHLurYE5x", + "outputId": "961971b2-8b61-4d03-c06e-571a778ab52d" + }, + "source": [ + "from PIL import Image\n", + "from IPython.display import display\n", + "for i in range(10, 20):\n", + " im = Image.open(\"Omniglot/images_background/Japanese_(hiragana).0/character13/0500_\" + str (i) + \".png\")\n", + " display(im)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABEUlEQVR4nGP8z4AbMOGRQ5F8uekLmux/BNjHs+0/CkDWqcKJppMFQv1jYGBgYGb69w/FMsb/DAwMDD+bnjMwMHxbb6nIyMDAwMBWqoyk8/+zRwwMDD//vGFmYGBgYBDhQNbJ8Pvvp7t/OL0nBENMZUG1s2vFsz+C71nYsPnz5QSz/Vt1f//DGgj/GV0MbEMYTqDJQrz7Sc/62VVZZtefKIEAlfy3XthM3TBD7Qu2EGL07ZPWncHGzog9bP/+/v1EvvUv9rBlYvk37VsgE1ad/34+zeeL+4UaKywMDAwM7w7/unH46puSKlZUjYz/GRj+z278y2xkbW7Cy4ApyfD1838mQVY0lzLAAx47IDqBDQpJAN4Euv7fFejQAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABNklEQVR4nGP8zwAH//8xM6AAJiT2pdTXuCVfbvmGW5LhPwNuSUaGP7gl5ZkuoUqyMDAwMLw78I9PjVVYSu2AP9P//39ZUSQfFP/8/4dVWvER9y6GC08+T+dCltQ9zvD5wZPr9/7uPsMkzuXKgnAhHPz71aJ07eHHH3/gIghVDIwsEv9ERHF6RevDG9yBwMn8GZvk/2+3nvxnEOe4g+rR////////scmBX+nov+OCV/4jA4jkNR73ed4aD3M032GRvME/5ddV4QKRqX+xSH4vFM4/4cmg+eQ/Fsn/X6doiPDy7vmHVfL/3xdzjB2+o8r9h/mTSTxBn4UJ1SNIgfD1iTS6JDzgfxSJ7UEzFWbnr5fZQrP+YJf8FKcpsegHuhw0yti02QI9MGxkYPwPtZmREUMOJokdAAB60yoWf/hgewAAAABJRU5ErkJggg==\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABN0lEQVR4nGP8z4AbsCBz/j9lkGZE4jMi6/wTz7AQWTWKToafqMYy4bESXfI/bklG8Yc/cEoyGbz59R8OIA76eBNq2v/7P/c8gJnM6KHHwsDAcC7iN8y133KEYR7lMmdg/M/A8PUxTPWFtCXOMElGTkYWBgYGbg24rawSPEhuQATC/6dPGe7/Q/EKQvJa0GuGH39RPQpz+FNL0zNXW9gyv/1H8gyU/tuicOXf/6tcEs+RJGGB8HOnrwYjA1r4wSQ/3tP5/vXLpV+i7EiSsPh85/JEmJHh/eelfoyYkv8f7nvNwHDy7HkhbK79/+/fv19JFl/+Y3EQAwMjI+PrA36cWP35////m55y15E1/keSfOurdvAvdskfm/3kdv37j1Xy33Jh7ZWo+v7/h6fbV49UedGTIiO+7AAAZ4kCU7KEzEEAAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABLUlEQVR4nGP8z4AbMKHx/3/7g1vyredeBIcFTfLP7c8w5r+/EMlX90UEISLv/315B5VbepDxPwMDA8OiAmYuiNDfZ0LcUNs/BkEkPzw4D3XHpyZHL0YIU9ydEeYVKP3dW3g51B2MCNcyQgCn47VfUCamVxjY/uH2JwpAlUQLS+RAeH3rnLbIN+ySn6JPsP35/18Kq7Hvz/edPhr75f9/bDoZGLgUGJveb32hgkUnn0H3ob/8vJ8PILT+R4BLMqIXDssyhPyGCTAiuf7fwQiuj9o/5FbC7EK2k8l+2RnFe20WjNiM/f///7+frp6v4Tz04HvyUlaIAbvOf1eNJQ8juKiSXy2lt/7FIfl7Dl/rn//YJf+tFrR7jKwY2Z8MZw/5qDAi8VEkGf4jSzEwAABSseqGZyInRAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABJklEQVR4nGP8z4AbMKFy/+GRfF54D7fkh8Wv8RgLAxCXsGCR+Pfx0BtTfZjkz5dQNz/79+Lhv32HTj4TbtFjZGBg/M/AwHDO/xdE8s87Qdb/v3WsnYz5WBmgkh9PQ73wqLBVg0FAhwPmkv/I4JrgCWQukoN+/2NECy645P9bTXf4PVFDCG7sU33JNFMOwZvIxsIk/03k2/PnvpTsW2RJeAh9lzRjFtdBNRUuyfjzCwOHP2powhzEaPBypvvlvWjOhZn/xlZYUIgH1U641/5/e/6bufT8BSFs/mTkVmH4+gKHgxgYGBh+PmdhxCn5/a85D1YH/f///0Mk336UeEBI/ntXzdv9C7vkvyOWvJnf/2OX/GxlMBNNDuHPvxdlRFGcygBNJrgAAEPeDmCQZ6aqAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": 
{ + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABOElEQVR4nGP8z4AbMKHxXz1AVv0fBfxJt/+B4KHp/P/j7U+cxrKEvHyD2041ppP/cUrKqh34D3EKAwMDC5okq9Sdrz9/3v50ysqTEUny//9/v169/K+01e/Rjz8MUppQnf8YGBi+3bxw+fGLu78Y/nxlCjHXZRfiYmRgYPzP8KfzPsO/iw+4FUUEHAyZX0R0xcIcwsLA8P/lQwYmTzsDXlbGX6f/fGf8zYgcQr9//fr19/////9fRgsKczLO/occQiysrKxMDAwM/5fumr0vmEUcrhPZK38OmvowCkiaYQ0EJoWbuyvmWggjRJDj5IkrPw9D8V84Hyb5bsO3////v9sQJHbxP4bkWYmajXcfnM4QWPQHU/JXjbSYqKiwRj9SXP9nhEXQz/cfH/3n0BJCdiEjKQlscEsCAN5i3onYmdekAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABO0lEQVR4nGP8z4AbMOGRo5bks4v/UCRZkNj/ly06zs3w4+eLf4yK7OiSDH9+/Xl+ZObDNwxcG00Qkk+v/WdgYPh/+1XWic9ySSa87DpIxu4v/sfAwMDw69OtpFAxfkaYSYz/GRgYGL68ZWBgYGA4H39IDy4D18nDw8DAwMDwnIkNWY6CQIB55T+2CIBI/j1+8B4DA5P0P2ySF4NYVVkYPi1m/IEq+///////Fwge+/Hz5/sipvif/5EARPKW2Mx/////P8wi+w5ZEmKsrF4vhx3jh/6/aO6CqLkfxS0iIiQbJvsWWScjVOnnM78ZWNTW95wXwuJPXkcGBoY/x5S5cIbQlYOubJh2/v318+fPn7ctLd78x3Dt//Uz/zEwMDxmXymMGUKMCop/GBgY1BI0UAMI6tp/mA5ASGIHADm3qpNJq4xdAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABI0lEQVR4nGP8z4AbMKFy7yz/jsz9jwz+5su//P///79///7+/v//Pwuqzi9KfP9/Xdv1juEXZxMLA5Lk/39/3r153Mdw9BS/CANjNBMDAyPEQf8ffv92/vizq8+4pBlVg12FGBhYmeB2/nARFJSxTpokWfXpy89/MCdAjWWd8IVRjo/9+3Q+HkaERVBJJm0Ghp9ff6B5GuGg36WbGNXe4AiEd5tMC869xqHzwPcq7XuTsev8/5lTnlWPEUUSKRB+3+L/z4VdklH8qyfj/x84dHrs/8XwOxGHJKshA8NXjof/mLF5hYGBgYFD/84/bK5lYPj///+PxziMfbTg6/+nZ9KYsEo+2PWHgasoE8lKWHwyMDD8/8XAwMiKEgqMJKS+gZcEAF56gf6wykc6AAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABXklEQVR4nGP8z4AbMOGRI1Xy9z/ckp+TVv3HKfnn6Lq/MDYLuiSHMkTjv98MzFDJv/9ZGBj+//339ef1symvmN49/bnn1H9rxv8MDAwM/yecjhFmuLPl7bMPv98IszH8ZmCUtuW0h+rkvhT3n4FVV95M/1tpgBsDpxazIDcjIyPUhq/P/zJwSLIxMr4z9u1nRnUQEy8vAwMDw3+G/yw8H+GOQ3bt/w9bnrOZfTZjwiL5/0LaPd63PJ/YsYXQ6wSezacOW/6+jYio/zDwb7Hkuf///5/ij/sDE0Lo/H9IXYOB4deET9+whS3vrbmvfx3axmTJjGns/6dJMpp+4gxid+EiSJL/f98sc7Ph6PmDVfL//++zZWM+/8cu+TZfOOrdf2yS/15u8BCpe/8fm+TXfm1Bj4O//2OT/Dtf0O/g1///sUreMF707T86gEp+S1/2B0PuPzSZbPscgpHUGBgAt9BS1wiwXusAAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABQElEQVR4nGP8z4AbMKHwfvxA4TIi6/xbyNDPjMRnQVb5/wkDii2oxuK1Ew2gGMvwn5EBZjAjTPL3W4jAp6uGT86f+MfAwMDAnKwKde25oF8MDAwMDP9e8nD/kWNhYGBg4JymDZV8d+AvAwMDA8OX0mBXDSWIZ9gYGRgY/kPBv68Hdl6SmfXvPxKAOuj/g92rzzEwfcDmlf97nKs4Vu3y/IPml////////9VI5fL3//+v8KMaC9HJbs8kysHAIMT9D4vO/9eEJ/z59y6HMfEPFgcp+7XfF9tyjfMDSsBDJdl6eXf+F1s1GdVUeHz++cnAxOAlvAI5sOGxwsLNzfn9njlyXKNF2X8BLIGAAyBZ8f/zge9oamF++nG/S59L9RayN//DJP8tlpTL3XwbJfT+w73y7KShLDOqoajpFh3gdS0Aq5C/ToYG3GgAAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "baVsWfcSYHVN" + }, + "source": [ + "## **Step 2: Build the model**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gqiOdDLgYOlQ" + }, + "source": [ + "### Library importation" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": 
"-9pfkqh8gxHD" + }, + "source": [ + "# Import modules we need\n", + "import glob, random\n", + "from collections import OrderedDict\n", + "\n", + "import numpy as np\n", + "\n", + "try:\n", + " from qqdm.notebook import qqdm as tqdm\n", + "except ModuleNotFoundError:\n", + " from tqdm.auto import tqdm\n", + "\n", + "import torch, torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader, Dataset\n", + "import torchvision.transforms as transforms\n", + "\n", + "from PIL import Image\n", + "from IPython.display import display\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "# fix random seeds\n", + "random_seed = 0\n", + "random.seed(random_seed)\n", + "np.random.seed(random_seed)\n", + "torch.manual_seed(random_seed)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(random_seed)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3TlwLtC1YRT7" + }, + "source": [ + "### Model Construction Preliminaries\n", + "\n", + "Since our task is image classification, we need to build a CNN-based model. \n", + "However, to implement MAML algorithm, we should adjust some code in `nn.Module`.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dFwB3tuEDYfy" + }, + "source": [ + "Take a look at MAML pseudocode...\n", + "\n", + "\n", + "\n", + "On the 10-th line, what we take gradients on are those $\\theta$ representing \n", + "**the original model parameters** (outer loop) instead of those in the \n", + "**inner loop**, so we need to use `functional_forward` to compute the output \n", + "logits of input image instead of `forward` in `nn.Module`.\n", + "\n", + "The following defines these functions.\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iuYQiPeQYc__" + }, + "source": [ + "### Model block definition" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GgFbbKHYg3Hk" + }, + "source": [ + "def ConvBlock(in_ch: int, out_ch: int):\n", + " return nn.Sequential(\n", + " nn.Conv2d(in_ch, out_ch, 3, padding=1),\n", + " nn.BatchNorm2d(out_ch),\n", + " nn.ReLU(),\n", + " nn.MaxPool2d(kernel_size=2, stride=2)\n", + " )\n", + "\n", + "def ConvBlockFunction(x, w, b, w_bn, b_bn):\n", + " x = F.conv2d(x, w, b, padding=1)\n", + " x = F.batch_norm(x,\n", + " running_mean=None,\n", + " running_var=None,\n", + " weight=w_bn, bias=b_bn,\n", + " training=True)\n", + " x = F.relu(x)\n", + " x = F.max_pool2d(x, kernel_size=2, stride=2)\n", + " return x" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iQEzgWN7fi7B" + }, + "source": [ + "### Model definition" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0bFBGEQoHQUW" + }, + "source": [ + "class Classifier(nn.Module):\n", + " def __init__(self, in_ch, k_way):\n", + " super(Classifier, self).__init__()\n", + " self.conv1 = ConvBlock(in_ch, 64)\n", + " self.conv2 = ConvBlock(64, 64)\n", + " self.conv3 = ConvBlock(64, 64)\n", + " self.conv4 = ConvBlock(64, 64)\n", + " self.logits = nn.Linear(64, k_way)\n", + "\n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = self.conv2(x)\n", + " x = self.conv3(x)\n", + " x = self.conv4(x)\n", + " x = x.view(x.shape[0], -1)\n", + " x = self.logits(x)\n", + " return x\n", + "\n", + " def functional_forward(self, x, params):\n", + " '''\n", + " Arguments:\n", + " x: input images [batch, 1, 28, 28]\n", + " params: model parameters, 
\n", + " i.e. weights and biases of convolution\n", + " and weights and biases of \n", + " batch normalization\n", + " type is an OrderedDict\n", + "\n", + " Arguments:\n", + " x: input images [batch, 1, 28, 28]\n", + " params: The model parameters, \n", + " i.e. weights and biases of convolution \n", + " and batch normalization layers\n", + " It's an `OrderedDict`\n", + " '''\n", + " for block in [1, 2, 3, 4]:\n", + " x = ConvBlockFunction(\n", + " x,\n", + " params[f'conv{block}.0.weight'],\n", + " params[f'conv{block}.0.bias'],\n", + " params.get(f'conv{block}.1.weight'),\n", + " params.get(f'conv{block}.1.bias'))\n", + " x = x.view(x.shape[0], -1)\n", + " x = F.linear(x,\n", + " params['logits.weight'],\n", + " params['logits.bias'])\n", + " return x" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gmJq_0B9Yj0G" + }, + "source": [ + "### Create Label\n", + "\n", + "This function is used to create labels. \n", + "In a N-way K-shot few-shot classification problem,\n", + "each task has `n_way` classes, while there are `k_shot` images for each class. \n", + "This is a function that creates such labels.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GQF5vgLvg5aX", + "outputId": "5df41e04-290c-428b-b06f-cc749f09f027" + }, + "source": [ + "def create_label(n_way, k_shot):\n", + " return (torch.arange(n_way)\n", + " .repeat_interleave(k_shot)\n", + " .long())\n", + "\n", + "# Try to create labels for 5-way 2-shot setting\n", + "create_label(5, 2)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 9 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2nCFv9PGw50J" + }, + "source": [ + "### Accuracy calculation" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "FahDr0xQw50S" + }, + "source": [ + "def calculate_accuracy(logits, val_label):\n", + " \"\"\" utility function for accuracy calculation \"\"\"\n", + " acc = np.asarray([(\n", + " torch.argmax(logits, -1).cpu().numpy() == val_label.cpu().numpy())]\n", + " ).mean() \n", + " return acc" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9Hl7ro2mYzsI" + }, + "source": [ + "### Define Dataset\n", + "\n", + "Define the dataset. \n", + "The dataset returns images of a random character, with (`k_shot + q_query`) images, \n", + "so the size of returned tensor is `[k_shot+q_query, 1, 28, 28]`. 
\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "-tJ2mot9hHPb" + }, + "source": [ + "class Omniglot(Dataset):\n", + " def __init__(self, data_dir, k_way, q_query):\n", + " self.file_list = [f for f in glob.glob(\n", + " data_dir + \"**/character*\", \n", + " recursive=True)]\n", + " self.transform = transforms.Compose(\n", + " [transforms.ToTensor()])\n", + " self.n = k_way + q_query\n", + "\n", + " def __getitem__(self, idx):\n", + " sample = np.arange(20)\n", + "\n", + " # For random sampling the characters we want.\n", + " np.random.shuffle(sample) \n", + " img_path = self.file_list[idx]\n", + " img_list = [f for f in glob.glob(\n", + " img_path + \"**/*.png\", recursive=True)]\n", + " img_list.sort()\n", + " imgs = [self.transform(\n", + " Image.open(img_file)) \n", + " for img_file in img_list]\n", + " # `k_way + q_query` examples for each character\n", + " imgs = torch.stack(imgs)[sample[:self.n]] \n", + " return imgs\n", + "\n", + " def __len__(self):\n", + " return len(self.file_list)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Gm5iVp90Ylii" + }, + "source": [ + "## **Step 3: Core MAML**\n", + "\n", + "Here is the main Meta Learning algorithm. \n", + "The algorithm is exactly the same as the paper. \n", + "What the function does is to update the parameters using \"the data of a meta-batch.\"\n", + "Here we implement the second-order MAML (inner_train_step = 1), according to [the slides of meta learning in 2019 (p. 13 ~ p.18)](http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf#page=13&view=FitW)\n", + "\n", + "As for the mathematical derivation of the first-order version, please refer to [p.25 of the slides in 2019](http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf#page=25&view=FitW).\n", + "\n", + "The following is the algorithm with some explanation." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "KjNxrWW_yNck" + }, + "source": [ + "def OriginalMAML(\n", + " model, optimizer, x, n_way, k_shot, q_query, loss_fn,\n", + " inner_train_step=1, inner_lr=0.4, train=True):\n", + " criterion, task_loss, task_acc = loss_fn, [], []\n", + "\n", + " for meta_batch in x:\n", + " # Get data\n", + " support_set = meta_batch[: n_way * k_shot] \n", + " query_set = meta_batch[n_way * k_shot :] \n", + " \n", + " # Copy the params for inner loop\n", + " fast_weights = OrderedDict(model.named_parameters())\n", + " \n", + " ### ---------- INNER TRAIN LOOP ---------- ###\n", + " for inner_step in range(inner_train_step): \n", + " # Simply training\n", + " train_label = create_label(n_way, k_shot) \\\n", + " .to(device)\n", + " logits = model.functional_forward(\n", + " support_set, fast_weights)\n", + " loss = criterion(logits, train_label)\n", + " # Inner gradients update! 
vvvvvvvvvvvvvvvvvvvv #\n", + " \"\"\" Inner Loop Update \"\"\" #\n", + " grads = torch.autograd.grad( #\n", + " loss, fast_weights.values(), #\n", + " create_graph=True) #\n", + " # Perform SGD #\n", + " fast_weights = OrderedDict( #\n", + " (name, param - inner_lr * grad) #\n", + " for ((name, param), grad) #\n", + " in zip(fast_weights.items(), grads)) #\n", + " # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ #\n", + "\n", + " ### ---------- INNER VALID LOOP ---------- ###\n", + " val_label = create_label(n_way, q_query).to(device)\n", + " \n", + " # Collect gradients for outer loop\n", + " logits = model.functional_forward(\n", + " query_set, fast_weights) \n", + " loss = criterion(logits, val_label)\n", + " task_loss.append(loss)\n", + " task_acc.append(\n", + " calculate_accuracy(logits, val_label))\n", + "\n", + " # Update outer loop\n", + " model.train()\n", + " optimizer.zero_grad()\n", + "\n", + " meta_batch_loss = torch.stack(task_loss).mean()\n", + " if train:\n", + " meta_batch_loss.backward() # <--- May change later!\n", + " optimizer.step()\n", + " task_acc = np.mean(task_acc)\n", + " return meta_batch_loss, task_acc" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MF5ZahPdxKbp" + }, + "source": [ + "## Variations of MAML\n", + "\n", + "### First-order approximation of MAML (FOMAML)\n", + "\n", + "Slightly modify the MAML mentioned earlier, applying first-order approximation to decrease amount of computation.\n", + "\n", + "### Almost No Inner Loop (ANIL)\n", + "\n", + "The algorithm from [this paper](https://arxiv.org/abs/1909.09157), using the technique of feature reuse to decrease amount of computation." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qyQ7ZUN4foh-" + }, + "source": [ + "To finish the modification required, we need to change some blocks of the MAML algorithm. \n", + "Below, we have replace three parts that may be modified as functions. \n", + "Please choose to replace the functions with their alternative versions to complete the algorithm." 
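As a quick reminder of where the computational savings come from (a sketch following the 2019 slides cited above; here $\alpha$ stands for `inner_lr`, and $\mathcal{L}^{s}$, $\mathcal{L}^{q}$ denote the support-set and query-set losses): with a single inner step $\theta' = \theta - \alpha \nabla_\theta \mathcal{L}^{s}(\theta)$, the meta-gradient of the query loss is

$$\nabla_\theta \mathcal{L}^{q}(\theta') = \big(I - \alpha \nabla_\theta^{2} \mathcal{L}^{s}(\theta)\big)\, \nabla_{\theta'} \mathcal{L}^{q}(\theta').$$

FOMAML drops the Hessian term (i.e. it approximates $\partial \theta' / \partial \theta$ by $I$), so its inner update no longer needs to keep the graph required for second-order gradients; ANIL keeps the full update rule but adapts only the final `logits` layer in the inner loop, reusing the convolutional features learned by the outer loop.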
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ne5cOja0H8H7" + }, + "source": [ + "### Part 1: Inner loop update" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LChAX51sIFwi" + }, + "source": [ + "MAML" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Aqgb0kEVzQol" + }, + "source": [ + "def inner_update_MAML(fast_weights, loss, inner_lr):\n", + " \"\"\" Inner Loop Update \"\"\"\n", + " grads = torch.autograd.grad(\n", + " loss, fast_weights.values(), create_graph=True)\n", + " # Perform SGD\n", + " fast_weights = OrderedDict(\n", + " (name, param - inner_lr * grad)\n", + " for ((name, param), grad) in zip(fast_weights.items(), grads))\n", + " return fast_weights" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QnQ_BN-L2Gd7" + }, + "source": [ + "Alternatives" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Ug5LIO6V15cd" + }, + "source": [ + "def inner_update_alt1(fast_weights, loss, inner_lr):\n", + " grads = torch.autograd.grad(\n", + " loss, fast_weights.values(), create_graph=False)\n", + " # Perform SGD\n", + " fast_weights = OrderedDict(\n", + " (name, param - inner_lr * grad)\n", + " for ((name, param), grad) in zip(fast_weights.items(), grads))\n", + " return fast_weights\n", + "\n", + "def inner_update_alt2(fast_weights, loss, inner_lr):\n", + " grads = torch.autograd.grad(\n", + " loss, list(fast_weights.values())[-2:], create_graph=True)\n", + " # Split out the logits\n", + " for ((name, param), grad) in zip(\n", + " list(fast_weights.items())[-2:], grads):\n", + " fast_weights[name] = param - inner_lr * grad\n", + " return fast_weights" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1ZfaWPMt164t" + }, + "source": [ + "### Part 2: Collect gradients" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-W7zL2nN164u" + }, + "source": [ + "MAML \n", + "(Actually do nothing as gradients are computed by PyTorch automatically.)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "sgcuPPm2zSFL" + }, + "source": [ + "def collect_gradients_MAML(\n", + " special_grad: OrderedDict, fast_weights, model, len_data):\n", + " \"\"\" Actually do nothing (just backwards later) \"\"\"\n", + " return special_grad" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2OxEME6l2QOO" + }, + "source": [ + "Alternatives" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fWLYwZlM2RZO" + }, + "source": [ + "def collect_gradients_alt(\n", + " special_grad: OrderedDict, fast_weights, model, len_data):\n", + " \"\"\" Special gradient calculation \"\"\"\n", + " diff = OrderedDict(\n", + " (name, params - fast_weights[name]) \n", + " for (name, params) in model.named_parameters())\n", + " for name in diff:\n", + " special_grad[name] = special_grad.get(name, 0) + diff[name] / len_data\n", + " return special_grad" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ahqE-Sf92TID" + }, + "source": [ + "### Part 3: Outer loop gradients calculation" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-wr0hSd02TIE" + }, + "source": [ + "MAML \n", + "(Simply call PyTorch `backward`.)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_hBSQ02xzTXb" + }, + "source": [ + "def outer_update_MAML(model, meta_batch_loss, grad_tensors):\n", + " \"\"\" Simply backwards \"\"\"\n", + 
" meta_batch_loss.backward()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Q4zxf6yr2TIE" + }, + "source": [ + "Alternatives" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "DEyCwYmI2bdC" + }, + "source": [ + "def outer_update_alt(model, meta_batch_loss, grad_tensors):\n", + " \"\"\" Replace the gradients\n", + " with precalculated tensors \"\"\"\n", + " for (name, params) in model.named_parameters():\n", + " params.grad = grad_tensors[name]" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z1jck3KE2g1D" + }, + "source": [ + "### Complete the algorithm\n", + "Here we have wrapped the algorithm in `MetaAlgorithmGenerator`. \n", + "You can get your modified algorithm by filling in like this:\n", + "```python\n", + "MyAlgorithm = MetaAlgorithmGenerator(inner_update=inner_update_alt2)\n", + "```\n", + "Default the three blocks will be filled with that of `MAML`." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "XosNxVMDxL6V" + }, + "source": [ + "def MetaAlgorithmGenerator(\n", + " inner_update = inner_update_MAML, \n", + " collect_gradients = collect_gradients_MAML, \n", + " outer_update = outer_update_MAML):\n", + "\n", + " global calculate_accuracy\n", + "\n", + " def MetaAlgorithm(\n", + " model, optimizer, x, n_way, k_shot, q_query, loss_fn,\n", + " inner_train_step=1, inner_lr=0.4, train=True): \n", + " criterion = loss_fn\n", + " task_loss, task_acc = [], []\n", + " special_grad = OrderedDict() # Added for variants!\n", + "\n", + " for meta_batch in x:\n", + " support_set = meta_batch[: n_way * k_shot] \n", + " query_set = meta_batch[n_way * k_shot :] \n", + " \n", + " fast_weights = OrderedDict(model.named_parameters())\n", + " \n", + " ### ---------- INNER TRAIN LOOP ---------- ###\n", + " for inner_step in range(inner_train_step): \n", + " train_label = create_label(n_way, k_shot).to(device)\n", + " logits = model.functional_forward(support_set, fast_weights)\n", + " loss = criterion(logits, train_label)\n", + "\n", + " fast_weights = inner_update(fast_weights, loss, inner_lr)\n", + "\n", + " ### ---------- INNER VALID LOOP ---------- ###\n", + " val_label = create_label(n_way, q_query).to(device)\n", + " # FIXME: W for val?\n", + " special_grad = collect_gradients(\n", + " special_grad, fast_weights, model, len(x))\n", + " \n", + " # Collect gradients for outer loop\n", + " logits = model.functional_forward(query_set, fast_weights) \n", + " loss = criterion(logits, val_label)\n", + " task_loss.append(loss)\n", + " task_acc.append(calculate_accuracy(logits, val_label))\n", + "\n", + " # Update outer loop\n", + " model.train()\n", + " optimizer.zero_grad()\n", + "\n", + " meta_batch_loss = torch.stack(task_loss).mean()\n", + " if train:\n", + " # Notice the update part!\n", + " outer_update(model, meta_batch_loss, special_grad)\n", + " optimizer.step()\n", + " task_acc = np.mean(task_acc)\n", + " return meta_batch_loss, task_acc\n", + " return MetaAlgorithm" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "jEsPtV-GzbDv", + "cellView": "form" + }, + "source": [ + "#@title Here is the answer hidden, please fill in yourself!\n", + "Give_me_the_answer = True #@param {\"type\": \"boolean\"}\n", + "\n", + "def HiddenAnswer():\n", + " MAML = MetaAlgorithmGenerator()\n", + " FOMAML = MetaAlgorithmGenerator(inner_update=inner_update_alt1)\n", + " ANIL = 
MetaAlgorithmGenerator(inner_update=inner_update_alt2)\n", + " return MAML, FOMAML, ANIL" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "2P__5N2Yz9O4" + }, + "source": [ + "# `HiddenAnswer` is hidden in the last cell.\n", + "if Give_me_the_answer:\n", + " MAML, FOMAML, ANIL = HiddenAnswer()\n", + "else: \n", + " # TODO: Please fill in the function names \\\n", + " # as the function arguments to finish the algorithm.\n", + " MAML = MetaAlgorithmGenerator()\n", + " FOMAML = MetaAlgorithmGenerator()\n", + " ANIL = MetaAlgorithmGenerator()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nBoRBhVlZAST" + }, + "source": [ + "## **Step 4: Initialization**\n", + "\n", + "After defining all components we need, the following initialize a model before training." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ip-i7aseftUF" + }, + "source": [ + "\n", + "### Hyperparameters \n", + "[Go back to top!](#top)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0wFHmVcBhE4M" + }, + "source": [ + "n_way = 5\n", + "k_shot = 1\n", + "q_query = 1\n", + "inner_train_step = 1\n", + "inner_lr = 0.4\n", + "meta_lr = 0.001\n", + "meta_batch_size = 32\n", + "max_epoch = 30\n", + "eval_batches = test_batches = 20\n", + "train_data_path = './Omniglot/images_background/'\n", + "test_data_path = './Omniglot/images_evaluation/' " + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Uvzo7NVpfu5V" + }, + "source": [ + "### Dataloader initialization" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3I13GJavhP0_" + }, + "source": [ + "def dataloader_init(datasets, num_workers=2):\n", + " train_set, val_set, test_set = datasets\n", + " train_loader = DataLoader(train_set,\n", + " # The \"batch_size\" here is not \\\n", + " # the meta batch size, but \\\n", + " # how many different \\\n", + " # characters in a task, \\\n", + " # i.e. 
the \"n_way\" in \\\n", + " # few-shot classification.\n", + " batch_size=n_way,\n", + " num_workers=num_workers,\n", + " shuffle=True,\n", + " drop_last=True)\n", + " val_loader = DataLoader(val_set,\n", + " batch_size=n_way,\n", + " num_workers=num_workers,\n", + " shuffle=True,\n", + " drop_last=True)\n", + " test_loader = DataLoader(test_set,\n", + " batch_size=n_way,\n", + " num_workers=num_workers,\n", + " shuffle=True,\n", + " drop_last=True)\n", + " train_iter = iter(train_loader)\n", + " val_iter = iter(val_loader)\n", + " test_iter = iter(test_loader)\n", + " return (train_loader, val_loader, test_loader), \\\n", + " (train_iter, val_iter, test_iter)\n", + "\n", + "train_set, val_set = torch.utils.data.random_split(\n", + " Omniglot(train_data_path, k_shot, q_query), [3200, 656])\n", + "test_set = Omniglot(test_data_path, k_shot, q_query)\n", + "\n", + "(train_loader, val_loader, test_loader), \\\n", + "(train_iter, val_iter, test_iter) = dataloader_init(\n", + " (train_set, val_set, test_set))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KVund--bfw0e" + }, + "source": [ + "### Model & optimizer initialization" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Kxug882ihF2B" + }, + "source": [ + "def model_init():\n", + " meta_model = Classifier(1, n_way).to(device)\n", + " optimizer = torch.optim.Adam(meta_model.parameters(), \n", + " lr=meta_lr)\n", + " loss_fn = nn.CrossEntropyLoss().to(device)\n", + " return meta_model, optimizer, loss_fn\n", + "\n", + "meta_model, optimizer, loss_fn = model_init()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gj8cLRNLf2zg" + }, + "source": [ + "### Utility function to get a meta-batch" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "zrkCSsxOhC-N" + }, + "source": [ + "def get_meta_batch(meta_batch_size,\n", + " k_shot, q_query, \n", + " data_loader, iterator):\n", + " data = []\n", + " for _ in range(meta_batch_size):\n", + " try:\n", + " # a \"task_data\" tensor is representing \\\n", + " # the data of a task, with size of \\\n", + " # [n_way, k_shot+q_query, 1, 28, 28]\n", + " task_data = iterator.next() \n", + " except StopIteration:\n", + " iterator = iter(data_loader)\n", + " task_data = iterator.next()\n", + " train_data = (task_data[:, :k_shot]\n", + " .reshape(-1, 1, 28, 28))\n", + " val_data = (task_data[:, k_shot:]\n", + " .reshape(-1, 1, 28, 28))\n", + " task_data = torch.cat(\n", + " (train_data, val_data), 0)\n", + " data.append(task_data)\n", + " return torch.stack(data).to(device), iterator" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O5JCtob4fyh_" + }, + "source": [ + "\n", + "### Choose the meta learning algorithm\n", + "[Go back to top!](#top)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3av6pAI7OxOP" + }, + "source": [ + "# You can change this to `FOMAML` or `ANIL`\n", + "MetaAlgorithm = MAML" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pWQczA3FwjEG" + }, + "source": [ + "\n", + "## **Step 5: Main program for training & testing**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8EirEnaof7ep" + }, + "source": [ + "### Start training!\n", + "\n", + "[Go back to top!](#top)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "JQZjJrLAhBWw" + }, + "source": [ + "for epoch in range(max_epoch):\n", + " 
print(\"Epoch %d\" % (epoch + 1))\n", + " train_meta_loss = []\n", + " train_acc = []\n", + " # The \"step\" here is a meta-gradinet update step\n", + " for step in tqdm(range(\n", + " len(train_loader) // meta_batch_size)): \n", + " x, train_iter = get_meta_batch(\n", + " meta_batch_size, k_shot, q_query, \n", + " train_loader, train_iter)\n", + " meta_loss, acc = MetaAlgorithm(\n", + " meta_model, optimizer, x, \n", + " n_way, k_shot, q_query, loss_fn)\n", + " train_meta_loss.append(meta_loss.item())\n", + " train_acc.append(acc)\n", + " print(\" Loss : \", \"%.3f\" % (np.mean(train_meta_loss)), end='\\t')\n", + " print(\" Accuracy: \", \"%.3f %%\" % (np.mean(train_acc) * 100))\n", + "\n", + " # See the validation accuracy after each epoch.\n", + " # Early stopping is welcomed to implement.\n", + " val_acc = []\n", + " for eval_step in tqdm(range(\n", + " len(val_loader) // (eval_batches))):\n", + " x, val_iter = get_meta_batch(\n", + " eval_batches, k_shot, q_query, \n", + " val_loader, val_iter)\n", + " # We update three inner steps when testing.\n", + " _, acc = MetaAlgorithm(meta_model, optimizer, x, \n", + " n_way, k_shot, q_query, \n", + " loss_fn, \n", + " inner_train_step=3, \n", + " train=False) \n", + " val_acc.append(acc)\n", + " print(\" Validation accuracy: \", \"%.3f %%\" % (np.mean(val_acc) * 100))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u5Ew8-POf9sw" + }, + "source": [ + "### Testing the result" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "CYN_zGB3g_5_" + }, + "source": [ + "test_acc = []\n", + "for test_step in tqdm(range(\n", + " len(test_loader) // (test_batches))):\n", + " x, test_iter = get_meta_batch(\n", + " test_batches, k_shot, q_query, \n", + " test_loader, test_iter)\n", + " # When testing, we update 3 inner-steps\n", + " _, acc = MetaAlgorithm(meta_model, optimizer, x, \n", + " n_way, k_shot, q_query, loss_fn, \n", + " inner_train_step=3, train=False)\n", + " test_acc.append(acc)\n", + "print(\" Testing accuracy: \", \"%.3f %%\" % (np.mean(test_acc) * 100))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rtD8X3RLf-6w" + }, + "source": [ + "## **Reference**\n", + "1. Chelsea Finn, Pieter Abbeel, & Sergey Levine. (2017). [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.](https://arxiv.org/abs/1909.09157)\n", + "1. Aniruddh Raghu, Maithra Raghu, Samy Bengio, & Oriol Vinyals. (2020). [Rapid Learning or Feature Reuse? 
Towards Understanding the Effectiveness of MAML.](https://arxiv.org/abs/1909.09157)" + ] + } + ] +} \ No newline at end of file diff --git a/README.md b/README.md index cf1ddbd..9f67ed9 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,8 @@ 2021/06/11/ 更新Meta Learning 及 HW13&HW14 +2021/06/18/ 更新HW15,随着李老师课程结语视频上传,2021机器学习基本结束啦。 + #------------------------------------------------------------------# B站视频地址:https://www.bilibili.com/video/BV1Wv411h7kN#reply4197445138 diff --git a/范例/HW15/HW15.ipynb b/范例/HW15/HW15.ipynb new file mode 100644 index 0000000..a65b831 --- /dev/null +++ b/范例/HW15/HW15.ipynb @@ -0,0 +1,1316 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "ML2021 HW15 Meta Learning.ipynb", + "provenance": [], + "collapsed_sections": [], + "authorship_tag": "ABX9TyOOROjqnnPfh0u5TL/TR6kp", + "include_colab_link": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wzVBe3h7Xh-2" + }, + "source": [ + "\n", + "# **HW15 Meta Learning: Few-shot Classification**\n", + "\n", + "Please mail to ntu-ml-2021spring-ta@googlegroups.com if you have any questions.\n", + "\n", + "Useful Links:\n", + "1. [Go to hyperparameter setting.](#hyp)\n", + "1. [Go to meta algorithm setting.](#modelsetting)\n", + "1. [Go to main loop.](#mainloop)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RdpzIMG6XsGK" + }, + "source": [ + "## **Step 0: Check GPU**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "zjjHsZbaL7SV" + }, + "source": [ + "!nvidia-smi" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "gWpc6vW3MQhv" + }, + "source": [ + "#@markdown ### Install `qqdm`\n", + "# Check if installed\n", + "try:\n", + " import qqdm\n", + "except:\n", + " ! pip install qqdm > /dev/null 2>&1\n", + "print(\"Done!\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bQ3wvyjnXwGX" + }, + "source": [ + "## **Step 1: Download Data**\n", + "\n", + "Run the cell to download data, which has been pre-processed by TAs. \n", + "The dataset has been augmented, so extra data augmentation is not required.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "g7Gt4Jucug41" + }, + "source": [ + "workspace_dir = '.'\n", + "\n", + "# gdown is a package that downloads files from google drive\n", + "!gdown --id 1FLDrQ0k-iJ-mk8ors0WItqvwgu0w9J0U \\\n", + " --output \"{workspace_dir}/Omniglot.tar.gz\"" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AMGFHI9XX9ms" + }, + "source": [ + "### Decompress the dataset\n", + "\n", + "Since the dataset is quite large, please wait and observe the main program [here](#mainprogram). \n", + "You can come back here later by [*back to pre-process*](#preprocess)." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "AvvlAQBUug42" + }, + "source": [ + "# Use `tar' command to decompress\n", + "!tar -zxf \"{workspace_dir}/Omniglot.tar.gz\" \\\n", + " -C \"{workspace_dir}/\"" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "T5P9eT0fYDqV" + }, + "source": [ + "### Data Preview\n", + "\n", + "Just look at some data in the dataset." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 297 + }, + "id": "7VtgHLurYE5x", + "outputId": "961971b2-8b61-4d03-c06e-571a778ab52d" + }, + "source": [ + "from PIL import Image\n", + "from IPython.display import display\n", + "for i in range(10, 20):\n", + " im = Image.open(\"Omniglot/images_background/Japanese_(hiragana).0/character13/0500_\" + str (i) + \".png\")\n", + " display(im)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABEUlEQVR4nGP8z4AbMOGRQ5F8uekLmux/BNjHs+0/CkDWqcKJppMFQv1jYGBgYGb69w/FMsb/DAwMDD+bnjMwMHxbb6nIyMDAwMBWqoyk8/+zRwwMDD//vGFmYGBgYBDhQNbJ8Pvvp7t/OL0nBENMZUG1s2vFsz+C71nYsPnz5QSz/Vt1f//DGgj/GV0MbEMYTqDJQrz7Sc/62VVZZtefKIEAlfy3XthM3TBD7Qu2EGL07ZPWncHGzog9bP/+/v1EvvUv9rBlYvk37VsgE1ad/34+zeeL+4UaKywMDAwM7w7/unH46puSKlZUjYz/GRj+z278y2xkbW7Cy4ApyfD1838mQVY0lzLAAx47IDqBDQpJAN4Euv7fFejQAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABNklEQVR4nGP8zwAH//8xM6AAJiT2pdTXuCVfbvmGW5LhPwNuSUaGP7gl5ZkuoUqyMDAwMLw78I9PjVVYSu2AP9P//39ZUSQfFP/8/4dVWvER9y6GC08+T+dCltQ9zvD5wZPr9/7uPsMkzuXKgnAhHPz71aJ07eHHH3/gIghVDIwsEv9ERHF6RevDG9yBwMn8GZvk/2+3nvxnEOe4g+rR////////scmBX+nov+OCV/4jA4jkNR73ed4aD3M032GRvME/5ddV4QKRqX+xSH4vFM4/4cmg+eQ/Fsn/X6doiPDy7vmHVfL/3xdzjB2+o8r9h/mTSTxBn4UJ1SNIgfD1iTS6JDzgfxSJ7UEzFWbnr5fZQrP+YJf8FKcpsegHuhw0yti02QI9MGxkYPwPtZmREUMOJokdAAB60yoWf/hgewAAAABJRU5ErkJggg==\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABN0lEQVR4nGP8z4AbsCBz/j9lkGZE4jMi6/wTz7AQWTWKToafqMYy4bESXfI/bklG8Yc/cEoyGbz59R8OIA76eBNq2v/7P/c8gJnM6KHHwsDAcC7iN8y133KEYR7lMmdg/M/A8PUxTPWFtCXOMElGTkYWBgYGbg24rawSPEhuQATC/6dPGe7/Q/EKQvJa0GuGH39RPQpz+FNL0zNXW9gyv/1H8gyU/tuicOXf/6tcEs+RJGGB8HOnrwYjA1r4wSQ/3tP5/vXLpV+i7EiSsPh85/JEmJHh/eelfoyYkv8f7nvNwHDy7HkhbK79/+/fv19JFl/+Y3EQAwMjI+PrA36cWP35////m55y15E1/keSfOurdvAvdskfm/3kdv37j1Xy33Jh7ZWo+v7/h6fbV49UedGTIiO+7AAAZ4kCU7KEzEEAAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABLUlEQVR4nGP8z4AbMKHx/3/7g1vyredeBIcFTfLP7c8w5r+/EMlX90UEISLv/315B5VbepDxPwMDA8OiAmYuiNDfZ0LcUNs/BkEkPzw4D3XHpyZHL0YIU9ydEeYVKP3dW3g51B2MCNcyQgCn47VfUCamVxjY/uH2JwpAlUQLS+RAeH3rnLbIN+ySn6JPsP35/18Kq7Hvz/edPhr75f9/bDoZGLgUGJveb32hgkUnn0H3ob/8vJ8PILT+R4BLMqIXDssyhPyGCTAiuf7fwQiuj9o/5FbC7EK2k8l+2RnFe20WjNiM/f///7+frp6v4Tz04HvyUlaIAbvOf1eNJQ8juKiSXy2lt/7FIfl7Dl/rn//YJf+tFrR7jKwY2Z8MZw/5qDAi8VEkGf4jSzEwAABSseqGZyInRAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABJklEQVR4nGP8z4AbMKFy/+GRfF54D7fkh8Wv8RgLAxCXsGCR+Pfx0BtTfZjkz5dQNz/79+Lhv32HTj4TbtFjZGBg/M/AwHDO/xdE8s87Qdb/v3WsnYz5WBmgkh9PQ73wqLBVg0FAhwPmkv/I4JrgCWQukoN+/2NECy645P9bTXf4PVFDCG7sU33JNFMOwZvIxsIk/03k2/PnvpTsW2RJeAh9lzRjFtdBNRUuyfjzCwOHP2powhzEaPBypvvlvWjOhZn/xlZYUIgH1U641/5/e/6bufT8BSFs/mTkVmH4+gKHgxgYGBh+PmdhxCn5/a85D1YH/f///0Mk336UeEBI/ntXzdv9C7vkvyOWvJnf/2OX/GxlMBNNDuHPvxdlRFGcygBNJrgAAEPeDmCQZ6aqAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABOElEQVR4nGP8z4AbMKHxXz1AVv0fBfxJt/+B4KHp/P/j7U+cxrKEvHyD2041ppP/cUrKqh34D3EKAwMDC5okq9Sdrz9/3v50ysqTEUny//9/v169/K+01e/Rjz8MUppQnf8YGBi+3bxw+fGLu78Y/nxlCjHXZRfiYmRgYPzP8KfzPsO/iw+4FUUEHAyZX0R0xcIcwsLA8P/lQwYmTzsDXlbGX6f/fGf8zYgcQr9//fr19/////9fRgsKczLO/occQiysrKxMDAwM/5fumr0vmEUcrhPZK38OmvowCkiaYQ0EJoWbuyvmWggjRJDj5IkrPw9D8V84Hyb5bsO3////v9sQJHbxP4bkWYmajXcfnM4QWPQHU/JXjbSYqKiwRj9SXP9nhEXQz/cfH/3n0BJCdiEjKQlscEsCAN5i3onYmdekAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABO0lEQVR4nGP8z4AbMOGRo5bks4v/UCRZkNj/ly06zs3w4+eLf4yK7OiSDH9+/Xl+ZObDNwxcG00Qkk+v/WdgYPh/+1XWic9ySSa87DpIxu4v/sfAwMDw69OtpFAxfkaYSYz/GRgYGL68ZWBgYGA4H39IDy4D18nDw8DAwMDwnIkNWY6CQIB55T+2CIBI/j1+8B4DA5P0P2ySF4NYVVkYPi1m/IEq+///////Fwge+/Hz5/sipvif/5EARPKW2Mx/////P8wi+w5ZEmKsrF4vhx3jh/6/aO6CqLkfxS0iIiQbJvsWWScjVOnnM78ZWNTW95wXwuJPXkcGBoY/x5S5cIbQlYOubJh2/v318+fPn7ctLd78x3Dt//Uz/zEwMDxmXymMGUKMCop/GBgY1BI0UAMI6tp/mA5ASGIHADm3qpNJq4xdAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABI0lEQVR4nGP8z4AbMKFy7yz/jsz9jwz+5su//P///79///7+/v//Pwuqzi9KfP9/Xdv1juEXZxMLA5Lk/39/3r153Mdw9BS/CANjNBMDAyPEQf8ffv92/vizq8+4pBlVg12FGBhYmeB2/nARFJSxTpokWfXpy89/MCdAjWWd8IVRjo/9+3Q+HkaERVBJJm0Ghp9ff6B5GuGg36WbGNXe4AiEd5tMC869xqHzwPcq7XuTsev8/5lTnlWPEUUSKRB+3+L/z4VdklH8qyfj/x84dHrs/8XwOxGHJKshA8NXjof/mLF5hYGBgYFD/84/bK5lYPj///+PxziMfbTg6/+nZ9KYsEo+2PWHgasoE8lKWHwyMDD8/8XAwMiKEgqMJKS+gZcEAF56gf6wykc6AAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABXklEQVR4nGP8z4AbMOGRI1Xy9z/ckp+TVv3HKfnn6Lq/MDYLuiSHMkTjv98MzFDJv/9ZGBj+//339ef1symvmN49/bnn1H9rxv8MDAwM/yecjhFmuLPl7bMPv98IszH8ZmCUtuW0h+rkvhT3n4FVV95M/1tpgBsDpxazIDcjIyPUhq/P/zJwSLIxMr4z9u1nRnUQEy8vAwMDw3+G/yw8H+GOQ3bt/w9bnrOZfTZjwiL5/0LaPd63PJ/YsYXQ6wSezacOW/6+jYio/zDwb7Hkuf///5/ij/sDE0Lo/H9IXYOB4deET9+whS3vrbmvfx3axmTJjGns/6dJMpp+4gxid+EiSJL/f98sc7Ph6PmDVfL//++zZWM+/8cu+TZfOOrdf2yS/15u8BCpe/8fm+TXfm1Bj4O//2OT/Dtf0O/g1///sUreMF707T86gEp+S1/2B0PuPzSZbPscgpHUGBgAt9BS1wiwXusAAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABQElEQVR4nGP8z4AbMKHwfvxA4TIi6/xbyNDPjMRnQVb5/wkDii2oxuK1Ew2gGMvwn5EBZjAjTPL3W4jAp6uGT86f+MfAwMDAnKwKde25oF8MDAwMDP9e8nD/kWNhYGBg4JymDZV8d+AvAwMDA8OX0mBXDSWIZ9gYGRgY/kPBv68Hdl6SmfXvPxKAOuj/g92rzzEwfcDmlf97nKs4Vu3y/IPml////////9VI5fL3//+v8KMaC9HJbs8kysHAIMT9D4vO/9eEJ/z59y6HMfEPFgcp+7XfF9tyjfMDSsBDJdl6eXf+F1s1GdVUeHz++cnAxOAlvAI5sOGxwsLNzfn9njlyXKNF2X8BLIGAAyBZ8f/zge9oamF++nG/S59L9RayN//DJP8tlpTL3XwbJfT+w73y7KShLDOqoajpFh3gdS0Aq5C/ToYG3GgAAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "baVsWfcSYHVN" + }, + "source": [ + "## **Step 2: Build the model**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gqiOdDLgYOlQ" + }, + "source": [ + "### Library importation" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "-9pfkqh8gxHD" + }, + "source": [ + "# Import modules we need\n", + "import glob, random\n", + "from collections import OrderedDict\n", + "\n", + "import numpy as np\n", + "\n", + "try:\n", + " from qqdm.notebook import qqdm as tqdm\n", + "except ModuleNotFoundError:\n", + " from tqdm.auto import tqdm\n", + "\n", + "import torch, torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader, Dataset\n", + "import torchvision.transforms as transforms\n", + "\n", + "from PIL import Image\n", + "from IPython.display import display\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "# fix random seeds\n", + "random_seed = 0\n", + "random.seed(random_seed)\n", + "np.random.seed(random_seed)\n", + "torch.manual_seed(random_seed)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(random_seed)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3TlwLtC1YRT7" + }, + "source": [ + "### Model Construction Preliminaries\n", + "\n", + "Since our task is image classification, we need to build a CNN-based model. 
\n", + "However, to implement MAML algorithm, we should adjust some code in `nn.Module`.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dFwB3tuEDYfy" + }, + "source": [ + "Take a look at MAML pseudocode...\n", + "\n", + "\n", + "\n", + "On the 10-th line, what we take gradients on are those $\\theta$ representing \n", + "**the original model parameters** (outer loop) instead of those in the \n", + "**inner loop**, so we need to use `functional_forward` to compute the output \n", + "logits of input image instead of `forward` in `nn.Module`.\n", + "\n", + "The following defines these functions.\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iuYQiPeQYc__" + }, + "source": [ + "### Model block definition" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GgFbbKHYg3Hk" + }, + "source": [ + "def ConvBlock(in_ch: int, out_ch: int):\n", + " return nn.Sequential(\n", + " nn.Conv2d(in_ch, out_ch, 3, padding=1),\n", + " nn.BatchNorm2d(out_ch),\n", + " nn.ReLU(),\n", + " nn.MaxPool2d(kernel_size=2, stride=2)\n", + " )\n", + "\n", + "def ConvBlockFunction(x, w, b, w_bn, b_bn):\n", + " x = F.conv2d(x, w, b, padding=1)\n", + " x = F.batch_norm(x,\n", + " running_mean=None,\n", + " running_var=None,\n", + " weight=w_bn, bias=b_bn,\n", + " training=True)\n", + " x = F.relu(x)\n", + " x = F.max_pool2d(x, kernel_size=2, stride=2)\n", + " return x" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iQEzgWN7fi7B" + }, + "source": [ + "### Model definition" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0bFBGEQoHQUW" + }, + "source": [ + "class Classifier(nn.Module):\n", + " def __init__(self, in_ch, k_way):\n", + " super(Classifier, self).__init__()\n", + " self.conv1 = ConvBlock(in_ch, 64)\n", + " self.conv2 = ConvBlock(64, 64)\n", + " self.conv3 = ConvBlock(64, 64)\n", + " self.conv4 = ConvBlock(64, 64)\n", + " self.logits = nn.Linear(64, k_way)\n", + "\n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = self.conv2(x)\n", + " x = self.conv3(x)\n", + " x = self.conv4(x)\n", + " x = x.view(x.shape[0], -1)\n", + " x = self.logits(x)\n", + " return x\n", + "\n", + " def functional_forward(self, x, params):\n", + " '''\n", + " Arguments:\n", + " x: input images [batch, 1, 28, 28]\n", + " params: model parameters, \n", + " i.e. weights and biases of convolution\n", + " and weights and biases of \n", + " batch normalization\n", + " type is an OrderedDict\n", + "\n", + " Arguments:\n", + " x: input images [batch, 1, 28, 28]\n", + " params: The model parameters, \n", + " i.e. weights and biases of convolution \n", + " and batch normalization layers\n", + " It's an `OrderedDict`\n", + " '''\n", + " for block in [1, 2, 3, 4]:\n", + " x = ConvBlockFunction(\n", + " x,\n", + " params[f'conv{block}.0.weight'],\n", + " params[f'conv{block}.0.bias'],\n", + " params.get(f'conv{block}.1.weight'),\n", + " params.get(f'conv{block}.1.bias'))\n", + " x = x.view(x.shape[0], -1)\n", + " x = F.linear(x,\n", + " params['logits.weight'],\n", + " params['logits.bias'])\n", + " return x" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gmJq_0B9Yj0G" + }, + "source": [ + "### Create Label\n", + "\n", + "This function is used to create labels. \n", + "In a N-way K-shot few-shot classification problem,\n", + "each task has `n_way` classes, while there are `k_shot` images for each class. 
\n", + "This is a function that creates such labels.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GQF5vgLvg5aX", + "outputId": "5df41e04-290c-428b-b06f-cc749f09f027" + }, + "source": [ + "def create_label(n_way, k_shot):\n", + " return (torch.arange(n_way)\n", + " .repeat_interleave(k_shot)\n", + " .long())\n", + "\n", + "# Try to create labels for 5-way 2-shot setting\n", + "create_label(5, 2)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 9 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2nCFv9PGw50J" + }, + "source": [ + "### Accuracy calculation" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "FahDr0xQw50S" + }, + "source": [ + "def calculate_accuracy(logits, val_label):\n", + " \"\"\" utility function for accuracy calculation \"\"\"\n", + " acc = np.asarray([(\n", + " torch.argmax(logits, -1).cpu().numpy() == val_label.cpu().numpy())]\n", + " ).mean() \n", + " return acc" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9Hl7ro2mYzsI" + }, + "source": [ + "### Define Dataset\n", + "\n", + "Define the dataset. \n", + "The dataset returns images of a random character, with (`k_shot + q_query`) images, \n", + "so the size of returned tensor is `[k_shot+q_query, 1, 28, 28]`. \n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "-tJ2mot9hHPb" + }, + "source": [ + "class Omniglot(Dataset):\n", + " def __init__(self, data_dir, k_way, q_query):\n", + " self.file_list = [f for f in glob.glob(\n", + " data_dir + \"**/character*\", \n", + " recursive=True)]\n", + " self.transform = transforms.Compose(\n", + " [transforms.ToTensor()])\n", + " self.n = k_way + q_query\n", + "\n", + " def __getitem__(self, idx):\n", + " sample = np.arange(20)\n", + "\n", + " # For random sampling the characters we want.\n", + " np.random.shuffle(sample) \n", + " img_path = self.file_list[idx]\n", + " img_list = [f for f in glob.glob(\n", + " img_path + \"**/*.png\", recursive=True)]\n", + " img_list.sort()\n", + " imgs = [self.transform(\n", + " Image.open(img_file)) \n", + " for img_file in img_list]\n", + " # `k_way + q_query` examples for each character\n", + " imgs = torch.stack(imgs)[sample[:self.n]] \n", + " return imgs\n", + "\n", + " def __len__(self):\n", + " return len(self.file_list)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Gm5iVp90Ylii" + }, + "source": [ + "## **Step 3: Core MAML**\n", + "\n", + "Here is the main Meta Learning algorithm. \n", + "The algorithm is exactly the same as the paper. \n", + "What the function does is to update the parameters using \"the data of a meta-batch.\"\n", + "Here we implement the second-order MAML (inner_train_step = 1), according to [the slides of meta learning in 2019 (p. 13 ~ p.18)](http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf#page=13&view=FitW)\n", + "\n", + "As for the mathematical derivation of the first-order version, please refer to [p.25 of the slides in 2019](http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf#page=25&view=FitW).\n", + "\n", + "The following is the algorithm with some explanation." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "KjNxrWW_yNck" + }, + "source": [ + "def OriginalMAML(\n", + " model, optimizer, x, n_way, k_shot, q_query, loss_fn,\n", + " inner_train_step=1, inner_lr=0.4, train=True):\n", + " criterion, task_loss, task_acc = loss_fn, [], []\n", + "\n", + " for meta_batch in x:\n", + " # Get data\n", + " support_set = meta_batch[: n_way * k_shot] \n", + " query_set = meta_batch[n_way * k_shot :] \n", + " \n", + " # Copy the params for inner loop\n", + " fast_weights = OrderedDict(model.named_parameters())\n", + " \n", + " ### ---------- INNER TRAIN LOOP ---------- ###\n", + " for inner_step in range(inner_train_step): \n", + " # Simply training\n", + " train_label = create_label(n_way, k_shot) \\\n", + " .to(device)\n", + " logits = model.functional_forward(\n", + " support_set, fast_weights)\n", + " loss = criterion(logits, train_label)\n", + " # Inner gradients update! vvvvvvvvvvvvvvvvvvvv #\n", + " \"\"\" Inner Loop Update \"\"\" #\n", + " grads = torch.autograd.grad( #\n", + " loss, fast_weights.values(), #\n", + " create_graph=True) #\n", + " # Perform SGD #\n", + " fast_weights = OrderedDict( #\n", + " (name, param - inner_lr * grad) #\n", + " for ((name, param), grad) #\n", + " in zip(fast_weights.items(), grads)) #\n", + " # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ #\n", + "\n", + " ### ---------- INNER VALID LOOP ---------- ###\n", + " val_label = create_label(n_way, q_query).to(device)\n", + " \n", + " # Collect gradients for outer loop\n", + " logits = model.functional_forward(\n", + " query_set, fast_weights) \n", + " loss = criterion(logits, val_label)\n", + " task_loss.append(loss)\n", + " task_acc.append(\n", + " calculate_accuracy(logits, val_label))\n", + "\n", + " # Update outer loop\n", + " model.train()\n", + " optimizer.zero_grad()\n", + "\n", + " meta_batch_loss = torch.stack(task_loss).mean()\n", + " if train:\n", + " meta_batch_loss.backward() # <--- May change later!\n", + " optimizer.step()\n", + " task_acc = np.mean(task_acc)\n", + " return meta_batch_loss, task_acc" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MF5ZahPdxKbp" + }, + "source": [ + "## Variations of MAML\n", + "\n", + "### First-order approximation of MAML (FOMAML)\n", + "\n", + "Slightly modify the MAML mentioned earlier, applying first-order approximation to decrease amount of computation.\n", + "\n", + "### Almost No Inner Loop (ANIL)\n", + "\n", + "The algorithm from [this paper](https://arxiv.org/abs/1909.09157), using the technique of feature reuse to decrease amount of computation." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qyQ7ZUN4foh-" + }, + "source": [ + "To finish the modification required, we need to change some blocks of the MAML algorithm. \n", + "Below, we have replace three parts that may be modified as functions. \n", + "Please choose to replace the functions with their alternative versions to complete the algorithm." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ne5cOja0H8H7" + }, + "source": [ + "### Part 1: Inner loop update" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LChAX51sIFwi" + }, + "source": [ + "MAML" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Aqgb0kEVzQol" + }, + "source": [ + "def inner_update_MAML(fast_weights, loss, inner_lr):\n", + " \"\"\" Inner Loop Update \"\"\"\n", + " grads = torch.autograd.grad(\n", + " loss, fast_weights.values(), create_graph=True)\n", + " # Perform SGD\n", + " fast_weights = OrderedDict(\n", + " (name, param - inner_lr * grad)\n", + " for ((name, param), grad) in zip(fast_weights.items(), grads))\n", + " return fast_weights" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QnQ_BN-L2Gd7" + }, + "source": [ + "Alternatives" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Ug5LIO6V15cd" + }, + "source": [ + "def inner_update_alt1(fast_weights, loss, inner_lr):\n", + " grads = torch.autograd.grad(\n", + " loss, fast_weights.values(), create_graph=False)\n", + " # Perform SGD\n", + " fast_weights = OrderedDict(\n", + " (name, param - inner_lr * grad)\n", + " for ((name, param), grad) in zip(fast_weights.items(), grads))\n", + " return fast_weights\n", + "\n", + "def inner_update_alt2(fast_weights, loss, inner_lr):\n", + " grads = torch.autograd.grad(\n", + " loss, list(fast_weights.values())[-2:], create_graph=True)\n", + " # Split out the logits\n", + " for ((name, param), grad) in zip(\n", + " list(fast_weights.items())[-2:], grads):\n", + " fast_weights[name] = param - inner_lr * grad\n", + " return fast_weights" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1ZfaWPMt164t" + }, + "source": [ + "### Part 2: Collect gradients" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-W7zL2nN164u" + }, + "source": [ + "MAML \n", + "(Actually do nothing as gradients are computed by PyTorch automatically.)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "sgcuPPm2zSFL" + }, + "source": [ + "def collect_gradients_MAML(\n", + " special_grad: OrderedDict, fast_weights, model, len_data):\n", + " \"\"\" Actually do nothing (just backwards later) \"\"\"\n", + " return special_grad" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2OxEME6l2QOO" + }, + "source": [ + "Alternatives" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fWLYwZlM2RZO" + }, + "source": [ + "def collect_gradients_alt(\n", + " special_grad: OrderedDict, fast_weights, model, len_data):\n", + " \"\"\" Special gradient calculation \"\"\"\n", + " diff = OrderedDict(\n", + " (name, params - fast_weights[name]) \n", + " for (name, params) in model.named_parameters())\n", + " for name in diff:\n", + " special_grad[name] = special_grad.get(name, 0) + diff[name] / len_data\n", + " return special_grad" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ahqE-Sf92TID" + }, + "source": [ + "### Part 3: Outer loop gradients calculation" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-wr0hSd02TIE" + }, + "source": [ + "MAML \n", + "(Simply call PyTorch `backward`.)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_hBSQ02xzTXb" + }, + "source": [ + "def outer_update_MAML(model, meta_batch_loss, grad_tensors):\n", + " \"\"\" Simply backwards \"\"\"\n", + 
" meta_batch_loss.backward()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Q4zxf6yr2TIE" + }, + "source": [ + "Alternatives" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "DEyCwYmI2bdC" + }, + "source": [ + "def outer_update_alt(model, meta_batch_loss, grad_tensors):\n", + " \"\"\" Replace the gradients\n", + " with precalculated tensors \"\"\"\n", + " for (name, params) in model.named_parameters():\n", + " params.grad = grad_tensors[name]" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z1jck3KE2g1D" + }, + "source": [ + "### Complete the algorithm\n", + "Here we have wrapped the algorithm in `MetaAlgorithmGenerator`. \n", + "You can get your modified algorithm by filling in like this:\n", + "```python\n", + "MyAlgorithm = MetaAlgorithmGenerator(inner_update=inner_update_alt2)\n", + "```\n", + "Default the three blocks will be filled with that of `MAML`." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "XosNxVMDxL6V" + }, + "source": [ + "def MetaAlgorithmGenerator(\n", + " inner_update = inner_update_MAML, \n", + " collect_gradients = collect_gradients_MAML, \n", + " outer_update = outer_update_MAML):\n", + "\n", + " global calculate_accuracy\n", + "\n", + " def MetaAlgorithm(\n", + " model, optimizer, x, n_way, k_shot, q_query, loss_fn,\n", + " inner_train_step=1, inner_lr=0.4, train=True): \n", + " criterion = loss_fn\n", + " task_loss, task_acc = [], []\n", + " special_grad = OrderedDict() # Added for variants!\n", + "\n", + " for meta_batch in x:\n", + " support_set = meta_batch[: n_way * k_shot] \n", + " query_set = meta_batch[n_way * k_shot :] \n", + " \n", + " fast_weights = OrderedDict(model.named_parameters())\n", + " \n", + " ### ---------- INNER TRAIN LOOP ---------- ###\n", + " for inner_step in range(inner_train_step): \n", + " train_label = create_label(n_way, k_shot).to(device)\n", + " logits = model.functional_forward(support_set, fast_weights)\n", + " loss = criterion(logits, train_label)\n", + "\n", + " fast_weights = inner_update(fast_weights, loss, inner_lr)\n", + "\n", + " ### ---------- INNER VALID LOOP ---------- ###\n", + " val_label = create_label(n_way, q_query).to(device)\n", + " # FIXME: W for val?\n", + " special_grad = collect_gradients(\n", + " special_grad, fast_weights, model, len(x))\n", + " \n", + " # Collect gradients for outer loop\n", + " logits = model.functional_forward(query_set, fast_weights) \n", + " loss = criterion(logits, val_label)\n", + " task_loss.append(loss)\n", + " task_acc.append(calculate_accuracy(logits, val_label))\n", + "\n", + " # Update outer loop\n", + " model.train()\n", + " optimizer.zero_grad()\n", + "\n", + " meta_batch_loss = torch.stack(task_loss).mean()\n", + " if train:\n", + " # Notice the update part!\n", + " outer_update(model, meta_batch_loss, special_grad)\n", + " optimizer.step()\n", + " task_acc = np.mean(task_acc)\n", + " return meta_batch_loss, task_acc\n", + " return MetaAlgorithm" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "jEsPtV-GzbDv", + "cellView": "form" + }, + "source": [ + "#@title Here is the answer hidden, please fill in yourself!\n", + "Give_me_the_answer = True #@param {\"type\": \"boolean\"}\n", + "\n", + "def HiddenAnswer():\n", + " MAML = MetaAlgorithmGenerator()\n", + " FOMAML = MetaAlgorithmGenerator(inner_update=inner_update_alt1)\n", + " ANIL = 
MetaAlgorithmGenerator(inner_update=inner_update_alt2)\n", + " return MAML, FOMAML, ANIL" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "2P__5N2Yz9O4" + }, + "source": [ + "# `HiddenAnswer` is hidden in the last cell.\n", + "if Give_me_the_answer:\n", + " MAML, FOMAML, ANIL = HiddenAnswer()\n", + "else: \n", + " # TODO: Please fill in the function names \\\n", + " # as the function arguments to finish the algorithm.\n", + " MAML = MetaAlgorithmGenerator()\n", + " FOMAML = MetaAlgorithmGenerator()\n", + " ANIL = MetaAlgorithmGenerator()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nBoRBhVlZAST" + }, + "source": [ + "## **Step 4: Initialization**\n", + "\n", + "After defining all components we need, the following initialize a model before training." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ip-i7aseftUF" + }, + "source": [ + "\n", + "### Hyperparameters \n", + "[Go back to top!](#top)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0wFHmVcBhE4M" + }, + "source": [ + "n_way = 5\n", + "k_shot = 1\n", + "q_query = 1\n", + "inner_train_step = 1\n", + "inner_lr = 0.4\n", + "meta_lr = 0.001\n", + "meta_batch_size = 32\n", + "max_epoch = 30\n", + "eval_batches = test_batches = 20\n", + "train_data_path = './Omniglot/images_background/'\n", + "test_data_path = './Omniglot/images_evaluation/' " + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Uvzo7NVpfu5V" + }, + "source": [ + "### Dataloader initialization" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3I13GJavhP0_" + }, + "source": [ + "def dataloader_init(datasets, num_workers=2):\n", + " train_set, val_set, test_set = datasets\n", + " train_loader = DataLoader(train_set,\n", + " # The \"batch_size\" here is not \\\n", + " # the meta batch size, but \\\n", + " # how many different \\\n", + " # characters in a task, \\\n", + " # i.e. 
the \"n_way\" in \\\n", + " # few-shot classification.\n", + " batch_size=n_way,\n", + " num_workers=num_workers,\n", + " shuffle=True,\n", + " drop_last=True)\n", + " val_loader = DataLoader(val_set,\n", + " batch_size=n_way,\n", + " num_workers=num_workers,\n", + " shuffle=True,\n", + " drop_last=True)\n", + " test_loader = DataLoader(test_set,\n", + " batch_size=n_way,\n", + " num_workers=num_workers,\n", + " shuffle=True,\n", + " drop_last=True)\n", + " train_iter = iter(train_loader)\n", + " val_iter = iter(val_loader)\n", + " test_iter = iter(test_loader)\n", + " return (train_loader, val_loader, test_loader), \\\n", + " (train_iter, val_iter, test_iter)\n", + "\n", + "train_set, val_set = torch.utils.data.random_split(\n", + " Omniglot(train_data_path, k_shot, q_query), [3200, 656])\n", + "test_set = Omniglot(test_data_path, k_shot, q_query)\n", + "\n", + "(train_loader, val_loader, test_loader), \\\n", + "(train_iter, val_iter, test_iter) = dataloader_init(\n", + " (train_set, val_set, test_set))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KVund--bfw0e" + }, + "source": [ + "### Model & optimizer initialization" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Kxug882ihF2B" + }, + "source": [ + "def model_init():\n", + " meta_model = Classifier(1, n_way).to(device)\n", + " optimizer = torch.optim.Adam(meta_model.parameters(), \n", + " lr=meta_lr)\n", + " loss_fn = nn.CrossEntropyLoss().to(device)\n", + " return meta_model, optimizer, loss_fn\n", + "\n", + "meta_model, optimizer, loss_fn = model_init()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gj8cLRNLf2zg" + }, + "source": [ + "### Utility function to get a meta-batch" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "zrkCSsxOhC-N" + }, + "source": [ + "def get_meta_batch(meta_batch_size,\n", + " k_shot, q_query, \n", + " data_loader, iterator):\n", + " data = []\n", + " for _ in range(meta_batch_size):\n", + " try:\n", + " # a \"task_data\" tensor is representing \\\n", + " # the data of a task, with size of \\\n", + " # [n_way, k_shot+q_query, 1, 28, 28]\n", + " task_data = iterator.next() \n", + " except StopIteration:\n", + " iterator = iter(data_loader)\n", + " task_data = iterator.next()\n", + " train_data = (task_data[:, :k_shot]\n", + " .reshape(-1, 1, 28, 28))\n", + " val_data = (task_data[:, k_shot:]\n", + " .reshape(-1, 1, 28, 28))\n", + " task_data = torch.cat(\n", + " (train_data, val_data), 0)\n", + " data.append(task_data)\n", + " return torch.stack(data).to(device), iterator" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O5JCtob4fyh_" + }, + "source": [ + "\n", + "### Choose the meta learning algorithm\n", + "[Go back to top!](#top)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3av6pAI7OxOP" + }, + "source": [ + "# You can change this to `FOMAML` or `ANIL`\n", + "MetaAlgorithm = MAML" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pWQczA3FwjEG" + }, + "source": [ + "\n", + "## **Step 5: Main program for training & testing**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8EirEnaof7ep" + }, + "source": [ + "### Start training!\n", + "\n", + "[Go back to top!](#top)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "JQZjJrLAhBWw" + }, + "source": [ + "for epoch in range(max_epoch):\n", + " 
print(\"Epoch %d\" % (epoch + 1))\n", + " train_meta_loss = []\n", + " train_acc = []\n", + " # The \"step\" here is a meta-gradinet update step\n", + " for step in tqdm(range(\n", + " len(train_loader) // meta_batch_size)): \n", + " x, train_iter = get_meta_batch(\n", + " meta_batch_size, k_shot, q_query, \n", + " train_loader, train_iter)\n", + " meta_loss, acc = MetaAlgorithm(\n", + " meta_model, optimizer, x, \n", + " n_way, k_shot, q_query, loss_fn)\n", + " train_meta_loss.append(meta_loss.item())\n", + " train_acc.append(acc)\n", + " print(\" Loss : \", \"%.3f\" % (np.mean(train_meta_loss)), end='\\t')\n", + " print(\" Accuracy: \", \"%.3f %%\" % (np.mean(train_acc) * 100))\n", + "\n", + " # See the validation accuracy after each epoch.\n", + " # Early stopping is welcomed to implement.\n", + " val_acc = []\n", + " for eval_step in tqdm(range(\n", + " len(val_loader) // (eval_batches))):\n", + " x, val_iter = get_meta_batch(\n", + " eval_batches, k_shot, q_query, \n", + " val_loader, val_iter)\n", + " # We update three inner steps when testing.\n", + " _, acc = MetaAlgorithm(meta_model, optimizer, x, \n", + " n_way, k_shot, q_query, \n", + " loss_fn, \n", + " inner_train_step=3, \n", + " train=False) \n", + " val_acc.append(acc)\n", + " print(\" Validation accuracy: \", \"%.3f %%\" % (np.mean(val_acc) * 100))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u5Ew8-POf9sw" + }, + "source": [ + "### Testing the result" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "CYN_zGB3g_5_" + }, + "source": [ + "test_acc = []\n", + "for test_step in tqdm(range(\n", + " len(test_loader) // (test_batches))):\n", + " x, test_iter = get_meta_batch(\n", + " test_batches, k_shot, q_query, \n", + " test_loader, test_iter)\n", + " # When testing, we update 3 inner-steps\n", + " _, acc = MetaAlgorithm(meta_model, optimizer, x, \n", + " n_way, k_shot, q_query, loss_fn, \n", + " inner_train_step=3, train=False)\n", + " test_acc.append(acc)\n", + "print(\" Testing accuracy: \", \"%.3f %%\" % (np.mean(test_acc) * 100))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rtD8X3RLf-6w" + }, + "source": [ + "## **Reference**\n", + "1. Chelsea Finn, Pieter Abbeel, & Sergey Levine. (2017). [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.](https://arxiv.org/abs/1909.09157)\n", + "1. Aniruddh Raghu, Maithra Raghu, Samy Bengio, & Oriol Vinyals. (2020). [Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness of MAML.](https://arxiv.org/abs/1909.09157)" + ] + } + ] +} \ No newline at end of file diff --git a/范例/HW15/HW15.pdf b/范例/HW15/HW15.pdf new file mode 100644 index 0000000..29af013 Binary files /dev/null and b/范例/HW15/HW15.pdf differ