{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Various Ways to Use BertEmbedding\n", "fastNLP's BertEmbedding is built on the BertModel code from pytorch-transformers. It is an Embedding that encodes words with BERT.\n", "\n", "By combining BertEmbedding with the models in fastNLP.models.bert, you can build models that apply BERT to five kinds of downstream tasks.\n", "\n", "*For an introduction to the pre-trained Embedding weights and the datasets, and to how they are downloaded automatically, see the [Embedding tutorial](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html) and the [data processing tutorial](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_load_dataset.html).*\n", "\n", "## 1. BERT for Sequence Classification\n", "For text classification, we use the SST dataset as an example to show how BertEmbedding is used." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "import torch\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "In total 3 datasets:\n", "\ttest has 2210 instances.\n", "\ttrain has 8544 instances.\n", "\tdev has 1101 instances.\n", "In total 2 vocabs:\n", "\twords has 21701 entries.\n", "\ttarget has 5 entries."
] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the dataset; lower=False keeps the original casing to match the cased BERT model\n", "from fastNLP.io import SSTPipe\n", "data_bundle = SSTPipe(subtree=False, train_subtree=False, lower=False, tokenizer='raw').process_from_file()\n", "data_bundle" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loading vocabulary file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/vocab.txt\n", "Load pre-trained BERT parameters from file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/pytorch_model.bin.\n", "Start to generate word pieces for word.\n", "Found(Or segment into word pieces) 21701 words out of 21701.\n" ] } ], "source": [ "# Load BertEmbedding; include_cls_sep=True keeps the [CLS] and [SEP] vectors in the output\n", "from fastNLP.embeddings import BertEmbedding\n", "embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Load the model; the second argument is the number of target classes\n", "from fastNLP.models import BertForSequenceClassification\n", "model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input fields after batch(if batch size is 2):\n", "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 37]) \n", "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "target fields after batch(if batch size is 2):\n", "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "\n", "training epochs started 2019-09-11-17-35-26\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluate data in 2.08 seconds!\n", "Evaluation on dev at Epoch 1/2. Step:134/268: \n", "AccuracyMetric: acc=0.459582\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluate data in 2.2 seconds!\n", "Evaluation on dev at Epoch 2/2. Step:268/268: \n", "AccuracyMetric: acc=0.468665\n", "\n", "\n", "In Epoch:2/Step:268, got best dev performance:\n", "AccuracyMetric: acc=0.468665\n", "Reloaded the best model.\n" ] }, { "data": { "text/plain": [ "{'best_eval': {'AccuracyMetric': {'acc': 0.468665}},\n", " 'best_epoch': 2,\n", " 'best_step': 268,\n", " 'seconds': 114.5}" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Train the model: fine-tune with Adam (lr=2e-5) for 2 epochs on GPU 0, evaluating on dev\n", "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n", "trainer = Trainer(data_bundle.get_dataset('train'), model, \n", " optimizer=Adam(model_params=model.parameters(), lr=2e-5), \n", " loss=CrossEntropyLoss(), device=[0],\n", " batch_size=64, dev_data=data_bundle.get_dataset('dev'), \n", " metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 4.52 seconds!\n", "[tester] \n", "AccuracyMetric: acc=0.504072\n" ] }, { "data": { "text/plain": [ "{'AccuracyMetric': {'acc': 0.504072}}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluate the trained model on the test set\n", "from fastNLP import Tester\n", "tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())\n", "tester.test()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 2. BERT for Sentence Matching\n", "For sentence matching, we use the RTE dataset as an example to show how BertEmbedding is used." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "In total 3 datasets:\n", "\ttest has 3000 instances.\n", "\ttrain has 2490 instances.\n", "\tdev has 277 instances.\n", "In total 2 vocabs:\n", "\twords has 41281 entries.\n", "\ttarget has 2 entries."
] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the dataset; RTEBertPipe joins each sentence pair into a single words field, separated by [SEP]\n", "from fastNLP.io import RTEBertPipe\n", "data_bundle = RTEBertPipe(lower=False, tokenizer='raw').process_from_file()\n", "data_bundle" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loading vocabulary file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/vocab.txt\n", "Load pre-trained BERT parameters from file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/pytorch_model.bin.\n", "Start to generate word pieces for word.\n", "Found(Or segment into word pieces) 41279 words out of 41281.\n" ] } ], "source": [ "# Load BertEmbedding\n", "from fastNLP.embeddings import BertEmbedding\n", "embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Load the model\n", "from fastNLP.models import BertForSentenceMatching\n", "model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input fields after batch(if batch size is 2):\n", "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 45]) \n", "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "target fields after batch(if batch size is 2):\n", "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "\n", "training epochs started 2019-09-11-17-37-36\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluate data in 1.72 seconds!\n", "Evaluation on dev at Epoch 1/2. Step:156/312: \n", "AccuracyMetric: acc=0.624549\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluate data in 1.74 seconds!\n", "Evaluation on dev at Epoch 2/2. Step:312/312: \n", "AccuracyMetric: acc=0.649819\n", "\n", "\n", "In Epoch:2/Step:312, got best dev performance:\n", "AccuracyMetric: acc=0.649819\n", "Reloaded the best model.\n" ] }, { "data": { "text/plain": [ "{'best_eval': {'AccuracyMetric': {'acc': 0.649819}},\n", " 'best_epoch': 2,\n", " 'best_step': 312,\n", " 'seconds': 109.87}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Train the model (smaller batch size of 16 for the longer sentence pairs)\n", "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n", "trainer = Trainer(data_bundle.get_dataset('train'), model, \n", " optimizer=Adam(model_params=model.parameters(), lr=2e-5), \n", " loss=CrossEntropyLoss(), device=[0],\n", " batch_size=16, dev_data=data_bundle.get_dataset('dev'), \n", " metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n", "trainer.train()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 2 }