| @@ -0,0 +1,470 @@ | |||
| { | |||
| "cells": [ | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "# BertEmbedding的各种用法\n", | |||
| "fastNLP的BertEmbedding以pytorch-transformer.BertModel的代码为基础,是一个使用BERT对words进行编码的Embedding。\n", | |||
| "\n", | |||
| "使用BertEmbedding和fastNLP.models.bert里面模型可以搭建BERT应用到五种下游任务的模型。\n", | |||
| "\n", | |||
| "*预训练好的Embedding参数及数据集的介绍和自动下载功能见 [Embedding教程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html) 和 [数据处理教程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_load_dataset.html)。*\n", | |||
| "\n", | |||
| "## 1. BERT for Squence Classification\n", | |||
| "在文本分类任务中,我们采用SST数据集作为例子来介绍BertEmbedding的使用方法。" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 1, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "import warnings\n", | |||
| "import torch\n", | |||
| "warnings.filterwarnings(\"ignore\")" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 2, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "data": { | |||
| "text/plain": [ | |||
| "In total 3 datasets:\n", | |||
| "\ttest has 2210 instances.\n", | |||
| "\ttrain has 8544 instances.\n", | |||
| "\tdev has 1101 instances.\n", | |||
| "In total 2 vocabs:\n", | |||
| "\twords has 21701 entries.\n", | |||
| "\ttarget has 5 entries." | |||
| ] | |||
| }, | |||
| "execution_count": 2, | |||
| "metadata": {}, | |||
| "output_type": "execute_result" | |||
| } | |||
| ], | |||
| "source": [ | |||
| "# 载入数据集\n", | |||
| "from fastNLP.io import SSTPipe\n", | |||
| "data_bundle = SSTPipe(subtree=False, train_subtree=False, lower=False, tokenizer='raw').process_from_file()\n", | |||
| "data_bundle" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 3, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "loading vocabulary file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/vocab.txt\n", | |||
| "Load pre-trained BERT parameters from file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/pytorch_model.bin.\n", | |||
| "Start to generate word pieces for word.\n", | |||
| "Found(Or segment into word pieces) 21701 words out of 21701.\n" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "# 载入BertEmbedding\n", | |||
| "from fastNLP.embeddings import BertEmbedding\n", | |||
| "embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 4, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# 载入模型\n", | |||
| "from fastNLP.models import BertForSequenceClassification\n", | |||
| "model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 5, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "input fields after batch(if batch size is 2):\n", | |||
| "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 37]) \n", | |||
| "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||
| "target fields after batch(if batch size is 2):\n", | |||
| "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||
| "\n", | |||
| "training epochs started 2019-09-11-17-35-26\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=268), HTML(value='')), layout=Layout(display=…" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Evaluate data in 2.08 seconds!\n", | |||
| "Evaluation on dev at Epoch 1/2. Step:134/268: \n", | |||
| "AccuracyMetric: acc=0.459582\n", | |||
| "\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Evaluate data in 2.2 seconds!\n", | |||
| "Evaluation on dev at Epoch 2/2. Step:268/268: \n", | |||
| "AccuracyMetric: acc=0.468665\n", | |||
| "\n", | |||
| "\n", | |||
| "In Epoch:2/Step:268, got best dev performance:\n", | |||
| "AccuracyMetric: acc=0.468665\n", | |||
| "Reloaded the best model.\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "text/plain": [ | |||
| "{'best_eval': {'AccuracyMetric': {'acc': 0.468665}},\n", | |||
| " 'best_epoch': 2,\n", | |||
| " 'best_step': 268,\n", | |||
| " 'seconds': 114.5}" | |||
| ] | |||
| }, | |||
| "execution_count": 5, | |||
| "metadata": {}, | |||
| "output_type": "execute_result" | |||
| } | |||
| ], | |||
| "source": [ | |||
| "# 训练模型\n", | |||
| "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n", | |||
| "trainer = Trainer(data_bundle.get_dataset('train'), model, \n", | |||
| " optimizer=Adam(model_params=model.parameters(), lr=2e-5), \n", | |||
| " loss=CrossEntropyLoss(), device=[0],\n", | |||
| " batch_size=64, dev_data=data_bundle.get_dataset('dev'), \n", | |||
| " metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n", | |||
| "trainer.train()" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 6, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "\r", | |||
| "Evaluate data in 4.52 seconds!\n", | |||
| "[tester] \n", | |||
| "AccuracyMetric: acc=0.504072\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "text/plain": [ | |||
| "{'AccuracyMetric': {'acc': 0.504072}}" | |||
| ] | |||
| }, | |||
| "execution_count": 6, | |||
| "metadata": {}, | |||
| "output_type": "execute_result" | |||
| } | |||
| ], | |||
| "source": [ | |||
| "# 测试结果并删除模型\n", | |||
| "from fastNLP import Tester\n", | |||
| "tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())\n", | |||
| "tester.test()" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "\n", | |||
| "## 2. BERT for Sentence Matching\n", | |||
| "在Matching任务中,我们采用RTE数据集作为例子来介绍BertEmbedding的使用方法。" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 7, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "data": { | |||
| "text/plain": [ | |||
| "In total 3 datasets:\n", | |||
| "\ttest has 3000 instances.\n", | |||
| "\ttrain has 2490 instances.\n", | |||
| "\tdev has 277 instances.\n", | |||
| "In total 2 vocabs:\n", | |||
| "\twords has 41281 entries.\n", | |||
| "\ttarget has 2 entries." | |||
| ] | |||
| }, | |||
| "execution_count": 7, | |||
| "metadata": {}, | |||
| "output_type": "execute_result" | |||
| } | |||
| ], | |||
| "source": [ | |||
| "# 载入数据集\n", | |||
| "from fastNLP.io import RTEBertPipe\n", | |||
| "data_bundle = RTEBertPipe(lower=False, tokenizer='raw').process_from_file()\n", | |||
| "data_bundle" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 8, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "loading vocabulary file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/vocab.txt\n", | |||
| "Load pre-trained BERT parameters from file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/pytorch_model.bin.\n", | |||
| "Start to generate word pieces for word.\n", | |||
| "Found(Or segment into word pieces) 41279 words out of 41281.\n" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "# 载入BertEmbedding\n", | |||
| "from fastNLP.embeddings import BertEmbedding\n", | |||
| "embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 9, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# 载入模型\n", | |||
| "from fastNLP.models import BertForSentenceMatching\n", | |||
| "model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 10, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "input fields after batch(if batch size is 2):\n", | |||
| "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 45]) \n", | |||
| "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||
| "target fields after batch(if batch size is 2):\n", | |||
| "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||
| "\n", | |||
| "training epochs started 2019-09-11-17-37-36\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=312), HTML(value='')), layout=Layout(display=…" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Evaluate data in 1.72 seconds!\n", | |||
| "Evaluation on dev at Epoch 1/2. Step:156/312: \n", | |||
| "AccuracyMetric: acc=0.624549\n", | |||
| "\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Evaluate data in 1.74 seconds!\n", | |||
| "Evaluation on dev at Epoch 2/2. Step:312/312: \n", | |||
| "AccuracyMetric: acc=0.649819\n", | |||
| "\n", | |||
| "\n", | |||
| "In Epoch:2/Step:312, got best dev performance:\n", | |||
| "AccuracyMetric: acc=0.649819\n", | |||
| "Reloaded the best model.\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "text/plain": [ | |||
| "{'best_eval': {'AccuracyMetric': {'acc': 0.649819}},\n", | |||
| " 'best_epoch': 2,\n", | |||
| " 'best_step': 312,\n", | |||
| " 'seconds': 109.87}" | |||
| ] | |||
| }, | |||
| "execution_count": 10, | |||
| "metadata": {}, | |||
| "output_type": "execute_result" | |||
| } | |||
| ], | |||
| "source": [ | |||
| "# 训练模型\n", | |||
| "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n", | |||
| "trainer = Trainer(data_bundle.get_dataset('train'), model, \n", | |||
| " optimizer=Adam(model_params=model.parameters(), lr=2e-5), \n", | |||
| " loss=CrossEntropyLoss(), device=[0],\n", | |||
| " batch_size=16, dev_data=data_bundle.get_dataset('dev'), \n", | |||
| " metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n", | |||
| "trainer.train()" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [] | |||
| } | |||
| ], | |||
| "metadata": { | |||
| "kernelspec": { | |||
| "display_name": "Python 3", | |||
| "language": "python", | |||
| "name": "python3" | |||
| }, | |||
| "language_info": { | |||
| "codemirror_mode": { | |||
| "name": "ipython", | |||
| "version": 3 | |||
| }, | |||
| "file_extension": ".py", | |||
| "mimetype": "text/x-python", | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.7.0" | |||
| } | |||
| }, | |||
| "nbformat": 4, | |||
| "nbformat_minor": 2 | |||
| } | |||