# Various Uses of BertEmbedding
fastNLP's BertEmbedding is built on the code of pytorch-transformer.BertModel. It is an Embedding that uses BERT to encode words.

Combining BertEmbedding with the models in fastNLP.models.bert lets you build models that apply BERT to five kinds of downstream tasks.

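*For an introduction to the pre-trained Embedding parameters and the datasets, and to how they are downloaded automatically, see the [Embedding tutorial](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html) and the [data processing tutorial](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_load_dataset.html).*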

## 1. BERT for Sequence Classification
For text classification, we use the SST dataset as an example to show how BertEmbedding is used.

In [1]:
import warnings
import torch
warnings.filterwarnings("ignore")

In [2]:
# Load the dataset
from fastNLP.io import SSTPipe
data_bundle = SSTPipe(subtree=False, train_subtree=False, lower=False, tokenizer='raw').process_from_file()
data_bundle

In total 3 datasets:
	test has 2210 instances.
	train has 8544 instances.
	dev has 1101 instances.
In total 2 vocabs:
	words has 21701 entries.
	target has 5 entries.

In [3]:
# Load BertEmbedding
from fastNLP.embeddings import BertEmbedding
embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)

loading vocabulary file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/vocab.txt
Load pre-trained BERT parameters from file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/pytorch_model.bin.
Start to generate word pieces for word.
Found(Or segment into word pieces) 21701 words out of 21701.
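
The log line about generating word pieces refers to splitting each vocabulary word into sub-word units that BERT's own vocabulary contains. As a rough illustration only, the sketch below applies greedy longest-match-first segmentation with a tiny made-up vocabulary. The names `wordpiece` and `toy_vocab` are hypothetical, and fastNLP's actual implementation may differ in detail.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one word into word pieces."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it.
        while start < end:
            cand = word[start:end]
            if start > 0:
                cand = "##" + cand  # continuation pieces carry a "##" prefix
            if cand in vocab:
                piece = cand
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece matched: the whole word becomes unknown
        pieces.append(piece)
        start = end
    return pieces


# Toy vocabulary, invented for this example only.
toy_vocab = {"un", "##afford", "##able", "play", "##ing"}
print(wordpiece("unaffordable", toy_vocab))
print(wordpiece("playing", toy_vocab))
print(wordpiece("xyz", toy_vocab))
```

A word that cannot be covered by any pieces falls back to the unknown token, which is one plausible reason a log could report fewer words found than the vocabulary size.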


In [4]:
# Build the model
from fastNLP.models import BertForSequenceClassification
model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))

In [5]:
# Train the model
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
trainer = Trainer(data_bundle.get_dataset('train'), model, 
                  optimizer=Adam(model_params=model.parameters(), lr=2e-5), 
                  loss=CrossEntropyLoss(), device=[0],
                  batch_size=64, dev_data=data_bundle.get_dataset('dev'), 
                  metrics=AccuracyMetric(), n_epochs=2, print_every=1)
trainer.train()

input fields after batch(if batch size is 2):
	words: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 37]) 
	seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) 
target fields after batch(if batch size is 2):
	target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) 
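
The `[2, 37]` shape reported for `words` suggests that the sequences in a batch are padded to a common length, while `seq_len` keeps the original lengths. A minimal sketch of that idea in plain Python follows. `pad_batch` and the padding index 0 are assumptions made for illustration, not fastNLP's API.

```python
def pad_batch(sequences, pad_idx=0):
    """Pad index sequences to the longest one and record the true lengths."""
    seq_len = [len(s) for s in sequences]
    max_len = max(seq_len)
    words = [s + [pad_idx] * (max_len - len(s)) for s in sequences]
    return words, seq_len


# Two made-up sequences of word indices with different lengths.
batch = [[12, 7, 93, 4], [5, 18]]
words, seq_len = pad_batch(batch)
print(words)
print(seq_len)
```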

training epochs started 2019-09-11-17-35-26



Evaluate data in 2.08 seconds!
Evaluation on dev at Epoch 1/2. Step:134/268: 
AccuracyMetric: acc=0.459582



Evaluate data in 2.2 seconds!
Evaluation on dev at Epoch 2/2. Step:268/268: 
AccuracyMetric: acc=0.468665


In Epoch:2/Step:268, got best dev performance:
AccuracyMetric: acc=0.468665
Reloaded the best model.


{'best_eval': {'AccuracyMetric': {'acc': 0.468665}},
 'best_epoch': 2,
 'best_step': 268,
 'seconds': 114.5}
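
The step counts in the log follow from the dataset size and the batch size: 8544 training instances at a batch size of 64 give 134 steps per epoch, hence 268 steps over 2 epochs. The short check below reproduces that arithmetic. `steps_per_epoch` is a hypothetical helper, and it assumes the last partial batch is kept.

```python
import math


def steps_per_epoch(n_instances, batch_size):
    """Number of batches needed to cover the data once (last batch may be partial)."""
    return math.ceil(n_instances / batch_size)


n_epochs = 2
train_steps = steps_per_epoch(8544, 64)
total_steps = train_steps * n_epochs
print(train_steps, total_steps)
```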

In [6]:
# Evaluate the trained model on the test set
from fastNLP import Tester
tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())
tester.test()

Evaluate data in 4.52 seconds!
[tester] 
AccuracyMetric: acc=0.504072


{'AccuracyMetric': {'acc': 0.504072}}
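
The reported accuracy is the fraction of instances whose predicted label equals the target. As a sketch of that computation (not the `AccuracyMetric` implementation), the snippet below takes the argmax over per-class scores. The scores and targets are made-up values.

```python
def accuracy(scores, targets):
    """Fraction of rows whose highest-scoring class matches the target label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in scores]
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)


# Four hypothetical instances with three classes each.
scores = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9], [0.4, 0.3, 0.2]]
targets = [1, 0, 1, 0]
print(accuracy(scores, targets))
```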


## 2. BERT for Sentence Matching
For the matching task, we use the RTE dataset as an example to show how BertEmbedding is used.

In [7]:
# Load the dataset
from fastNLP.io import RTEBertPipe
data_bundle = RTEBertPipe(lower=False, tokenizer='raw').process_from_file()
data_bundle

In total 3 datasets:
	test has 3000 instances.
	train has 2490 instances.
	dev has 277 instances.
In total 2 vocabs:
	words has 41281 entries.
	target has 2 entries.

In [8]:
# Load BertEmbedding
from fastNLP.embeddings import BertEmbedding
embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)

loading vocabulary file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/vocab.txt
Load pre-trained BERT parameters from file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/pytorch_model.bin.
Start to generate word pieces for word.
Found(Or segment into word pieces) 41279 words out of 41281.


In [9]:
# Build the model
from fastNLP.models import BertForSentenceMatching
model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))

In [10]:
# Train the model
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
trainer = Trainer(data_bundle.get_dataset('train'), model, 
                  optimizer=Adam(model_params=model.parameters(), lr=2e-5), 
                  loss=CrossEntropyLoss(), device=[0],
                  batch_size=16, dev_data=data_bundle.get_dataset('dev'), 
                  metrics=AccuracyMetric(), n_epochs=2, print_every=1)
trainer.train()

input fields after batch(if batch size is 2):
	words: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 45]) 
	seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) 
target fields after batch(if batch size is 2):
	target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) 
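
The matching model also receives a single `words` field, so the two sentences of each pair must already be joined into one sequence by the time they reach the model. The sketch below shows the standard BERT convention for sentence pairs, with `[CLS]` and `[SEP]` markers and segment ids. Whether `RTEBertPipe` and `BertEmbedding` split the work exactly this way is an assumption, and `build_pair_input` is a hypothetical helper.

```python
def build_pair_input(premise, hypothesis):
    """Join two token lists in the BERT pair format and mark each segment."""
    tokens = ["[CLS]"] + premise + ["[SEP]"] + hypothesis + ["[SEP]"]
    # Segment 0 covers [CLS], the premise, and the first [SEP];
    # segment 1 covers the hypothesis and the final [SEP].
    segment_ids = [0] * (len(premise) + 2) + [1] * (len(hypothesis) + 1)
    return tokens, segment_ids


# Made-up sentence pair for illustration.
tokens, segment_ids = build_pair_input(["A", "man", "sleeps"], ["He", "rests"])
print(tokens)
print(segment_ids)
```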

training epochs started 2019-09-11-17-37-36



Evaluate data in 1.72 seconds!
Evaluation on dev at Epoch 1/2. Step:156/312: 
AccuracyMetric: acc=0.624549



Evaluate data in 1.74 seconds!
Evaluation on dev at Epoch 2/2. Step:312/312: 
AccuracyMetric: acc=0.649819


In Epoch:2/Step:312, got best dev performance:
AccuracyMetric: acc=0.649819
Reloaded the best model.


{'best_eval': {'AccuracyMetric': {'acc': 0.649819}},
 'best_epoch': 2,
 'best_step': 312,
 'seconds': 109.87}