{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# E1. 使用 DistilBert 完成 SST2 分类" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "4.18.0\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader, Dataset\n", "\n", "import transformers\n", "from transformers import AutoTokenizer\n", "from transformers import AutoModelForSequenceClassification\n", "\n", "import sys\n", "sys.path.append('..')\n", "\n", "import fastNLP\n", "from fastNLP import Trainer\n", "from fastNLP.core.utils.utils import dataclass_to_dict\n", "from fastNLP.core.metrics import Accuracy\n", "\n", "print(transformers.__version__)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]\n", "\n", "task = \"sst2\"\n", "model_checkpoint = \"distilbert-base-uncased\"" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "253d79d7a67e4dc88338448b5bcb3fb9", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00[21:00:11] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", "\n" ], "text/plain": [ "\u001b[2;36m[21:00:11]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=22992;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=669026;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"acc#acc\": 0.871875,\n",
       "  \"total#acc\": 320.0,\n",
       "  \"correct#acc\": 279.0\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"acc#acc\": 0.878125,\n",
       "  \"total#acc\": 320.0,\n",
       "  \"correct#acc\": 281.0\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"acc#acc\": 0.871875,\n",
       "  \"total#acc\": 320.0,\n",
       "  \"correct#acc\": 279.0\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"acc#acc\": 0.903125,\n",
       "  \"total#acc\": 320.0,\n",
       "  \"correct#acc\": 289.0\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.903125\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m289.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"acc#acc\": 0.871875,\n",
       "  \"total#acc\": 320.0,\n",
       "  \"correct#acc\": 279.0\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"acc#acc\": 0.890625,\n",
       "  \"total#acc\": 320.0,\n",
       "  \"correct#acc\": 285.0\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"acc#acc\": 0.875,\n",
       "  \"total#acc\": 320.0,\n",
       "  \"correct#acc\": 280.0\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"acc#acc\": 0.8875,\n",
       "  \"total#acc\": 320.0,\n",
       "  \"correct#acc\": 284.0\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"acc#acc\": 0.8875,\n",
       "  \"total#acc\": 320.0,\n",
       "  \"correct#acc\": 284.0\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
       "
\n" ], "text/plain": [ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n",
       "  \"acc#acc\": 0.890625,\n",
       "  \"total#acc\": 320.0,\n",
       "  \"correct#acc\": 285.0\n",
       "}\n",
       "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 1 }