You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

bert_embedding_tutorial.ipynb 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# BertEmbedding的各种用法\n",
  8. "fastNLP的BertEmbedding以pytorch-transformers的BertModel的代码为基础,是一个使用BERT对words进行编码的Embedding。\n",
  9. "\n",
  10. "使用BertEmbedding和fastNLP.models.bert中的模型,可以搭建将BERT应用于五种下游任务的模型。本教程以其中的文本分类和句子匹配两种任务为例。\n",
  11. "\n",
  12. "*预训练好的Embedding参数及数据集的介绍和自动下载功能见 [Embedding教程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html) 和 [数据处理教程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_load_dataset.html)。*\n",
  13. "\n",
  14. "## 1. BERT for Sequence Classification\n",
  15. "在文本分类任务中,我们采用SST数据集作为例子来介绍BertEmbedding的使用方法。"
  16. ]
  17. },
  18. {
  19. "cell_type": "code",
  20. "execution_count": 1,
  21. "metadata": {},
  22. "outputs": [],
  23. "source": [
  24. "import warnings\n",
  25. "import torch\n",
  26. "warnings.filterwarnings(\"ignore\")"
  27. ]
  28. },
  29. {
  30. "cell_type": "code",
  31. "execution_count": 2,
  32. "metadata": {},
  33. "outputs": [
  34. {
  35. "data": {
  36. "text/plain": [
  37. "In total 3 datasets:\n",
  38. "\ttest has 2210 instances.\n",
  39. "\ttrain has 8544 instances.\n",
  40. "\tdev has 1101 instances.\n",
  41. "In total 2 vocabs:\n",
  42. "\twords has 21701 entries.\n",
  43. "\ttarget has 5 entries."
  44. ]
  45. },
  46. "execution_count": 2,
  47. "metadata": {},
  48. "output_type": "execute_result"
  49. }
  50. ],
  51. "source": [
  52. "# 载入数据集\n",
  53. "from fastNLP.io import SSTPipe\n",
  54. "data_bundle = SSTPipe(subtree=False, train_subtree=False, lower=False, tokenizer='raw').process_from_file()\n",
  55. "data_bundle"
  56. ]
  57. },
  58. {
  59. "cell_type": "code",
  60. "execution_count": 3,
  61. "metadata": {},
  62. "outputs": [
  63. {
  64. "name": "stdout",
  65. "output_type": "stream",
  66. "text": [
  67. "loading vocabulary file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/vocab.txt\n",
  68. "Load pre-trained BERT parameters from file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/pytorch_model.bin.\n",
  69. "Start to generate word pieces for word.\n",
  70. "Found(Or segment into word pieces) 21701 words out of 21701.\n"
  71. ]
  72. }
  73. ],
  74. "source": [
  75. "# 载入BertEmbedding\n",
  76. "from fastNLP.embeddings import BertEmbedding\n",
  77. "embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)"
  78. ]
  79. },
  80. {
  81. "cell_type": "code",
  82. "execution_count": 4,
  83. "metadata": {},
  84. "outputs": [],
  85. "source": [
  86. "# 载入模型\n",
  87. "from fastNLP.models import BertForSequenceClassification\n",
  88. "model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))"
  89. ]
  90. },
  91. {
  92. "cell_type": "code",
  93. "execution_count": 5,
  94. "metadata": {},
  95. "outputs": [
  96. {
  97. "name": "stdout",
  98. "output_type": "stream",
  99. "text": [
  100. "input fields after batch(if batch size is 2):\n",
  101. "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 37]) \n",
  102. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  103. "target fields after batch(if batch size is 2):\n",
  104. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  105. "\n",
  106. "training epochs started 2019-09-11-17-35-26\n"
  107. ]
  108. },
  109. {
  110. "data": {
  111. "application/vnd.jupyter.widget-view+json": {
  112. "model_id": "",
  113. "version_major": 2,
  114. "version_minor": 0
  115. },
  116. "text/plain": [
  117. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=268), HTML(value='')), layout=Layout(display=…"
  118. ]
  119. },
  120. "metadata": {},
  121. "output_type": "display_data"
  122. },
  123. {
  124. "data": {
  125. "application/vnd.jupyter.widget-view+json": {
  126. "model_id": "",
  127. "version_major": 2,
  128. "version_minor": 0
  129. },
  130. "text/plain": [
  131. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
  132. ]
  133. },
  134. "metadata": {},
  135. "output_type": "display_data"
  136. },
  137. {
  138. "name": "stdout",
  139. "output_type": "stream",
  140. "text": [
  141. "Evaluate data in 2.08 seconds!\n",
  142. "Evaluation on dev at Epoch 1/2. Step:134/268: \n",
  143. "AccuracyMetric: acc=0.459582\n",
  144. "\n"
  145. ]
  146. },
  147. {
  148. "data": {
  149. "application/vnd.jupyter.widget-view+json": {
  150. "model_id": "",
  151. "version_major": 2,
  152. "version_minor": 0
  153. },
  154. "text/plain": [
  155. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
  156. ]
  157. },
  158. "metadata": {},
  159. "output_type": "display_data"
  160. },
  161. {
  162. "name": "stdout",
  163. "output_type": "stream",
  164. "text": [
  165. "Evaluate data in 2.2 seconds!\n",
  166. "Evaluation on dev at Epoch 2/2. Step:268/268: \n",
  167. "AccuracyMetric: acc=0.468665\n",
  168. "\n",
  169. "\n",
  170. "In Epoch:2/Step:268, got best dev performance:\n",
  171. "AccuracyMetric: acc=0.468665\n",
  172. "Reloaded the best model.\n"
  173. ]
  174. },
  175. {
  176. "data": {
  177. "text/plain": [
  178. "{'best_eval': {'AccuracyMetric': {'acc': 0.468665}},\n",
  179. " 'best_epoch': 2,\n",
  180. " 'best_step': 268,\n",
  181. " 'seconds': 114.5}"
  182. ]
  183. },
  184. "execution_count": 5,
  185. "metadata": {},
  186. "output_type": "execute_result"
  187. }
  188. ],
  189. "source": [
  190. "# 训练模型\n",
  191. "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
  192. "trainer = Trainer(data_bundle.get_dataset('train'), model, \n",
  193. " optimizer=Adam(model_params=model.parameters(), lr=2e-5), \n",
  194. " loss=CrossEntropyLoss(), device=[0],\n",
  195. " batch_size=64, dev_data=data_bundle.get_dataset('dev'), \n",
  196. " metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n",
  197. "trainer.train()"
  198. ]
  199. },
  200. {
  201. "cell_type": "code",
  202. "execution_count": 6,
  203. "metadata": {},
  204. "outputs": [
  205. {
  206. "data": {
  207. "application/vnd.jupyter.widget-view+json": {
  208. "model_id": "",
  209. "version_major": 2,
  210. "version_minor": 0
  211. },
  212. "text/plain": [
  213. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
  214. ]
  215. },
  216. "metadata": {},
  217. "output_type": "display_data"
  218. },
  219. {
  220. "name": "stdout",
  221. "output_type": "stream",
  222. "text": [
  223. "\r",
  224. "Evaluate data in 4.52 seconds!\n",
  225. "[tester] \n",
  226. "AccuracyMetric: acc=0.504072\n"
  227. ]
  228. },
  229. {
  230. "data": {
  231. "text/plain": [
  232. "{'AccuracyMetric': {'acc': 0.504072}}"
  233. ]
  234. },
  235. "execution_count": 6,
  236. "# 在测试集上评估模型\n",
  237. "output_type": "execute_result"
  238. }
  239. ],
  240. "source": [
  241. "# 测试结果并删除模型\n",
  242. "from fastNLP import Tester\n",
  243. "tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())\n",
  244. "tester.test()"
  245. ]
  246. },
  247. {
  248. "cell_type": "markdown",
  249. "metadata": {},
  250. "source": [
  251. "\n",
  252. "## 2. BERT for Sentence Matching\n",
  253. "在Matching任务中,我们采用RTE数据集作为例子来介绍BertEmbedding的使用方法。"
  254. ]
  255. },
  256. {
  257. "cell_type": "code",
  258. "execution_count": 7,
  259. "metadata": {},
  260. "outputs": [
  261. {
  262. "data": {
  263. "text/plain": [
  264. "In total 3 datasets:\n",
  265. "\ttest has 3000 instances.\n",
  266. "\ttrain has 2490 instances.\n",
  267. "\tdev has 277 instances.\n",
  268. "In total 2 vocabs:\n",
  269. "\twords has 41281 entries.\n",
  270. "\ttarget has 2 entries."
  271. ]
  272. },
  273. "execution_count": 7,
  274. "metadata": {},
  275. "output_type": "execute_result"
  276. }
  277. ],
  278. "source": [
  279. "# 载入数据集\n",
  280. "from fastNLP.io import RTEBertPipe\n",
  281. "data_bundle = RTEBertPipe(lower=False, tokenizer='raw').process_from_file()\n",
  282. "data_bundle"
  283. ]
  284. },
  285. {
  286. "cell_type": "code",
  287. "execution_count": 8,
  288. "metadata": {},
  289. "outputs": [
  290. {
  291. "name": "stdout",
  292. "output_type": "stream",
  293. "text": [
  294. "loading vocabulary file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/vocab.txt\n",
  295. "Load pre-trained BERT parameters from file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/pytorch_model.bin.\n",
  296. "Start to generate word pieces for word.\n",
  297. "Found(Or segment into word pieces) 41279 words out of 41281.\n"
  298. ]
  299. }
  300. ],
  301. "source": [
  302. "# 载入BertEmbedding\n",
  303. "from fastNLP.embeddings import BertEmbedding\n",
  304. "embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)"
  305. ]
  306. },
  307. {
  308. "cell_type": "code",
  309. "execution_count": 9,
  310. "metadata": {},
  311. "outputs": [],
  312. "source": [
  313. "# 载入模型\n",
  314. "from fastNLP.models import BertForSentenceMatching\n",
  315. "model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))"
  316. ]
  317. },
  318. {
  319. "cell_type": "code",
  320. "execution_count": 10,
  321. "metadata": {},
  322. "outputs": [
  323. {
  324. "name": "stdout",
  325. "output_type": "stream",
  326. "text": [
  327. "input fields after batch(if batch size is 2):\n",
  328. "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 45]) \n",
  329. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  330. "target fields after batch(if batch size is 2):\n",
  331. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  332. "\n",
  333. "training epochs started 2019-09-11-17-37-36\n"
  334. ]
  335. },
  336. {
  337. "data": {
  338. "application/vnd.jupyter.widget-view+json": {
  339. "model_id": "",
  340. "version_major": 2,
  341. "version_minor": 0
  342. },
  343. "text/plain": [
  344. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=312), HTML(value='')), layout=Layout(display=…"
  345. ]
  346. },
  347. "metadata": {},
  348. "output_type": "display_data"
  349. },
  350. {
  351. "data": {
  352. "application/vnd.jupyter.widget-view+json": {
  353. "model_id": "",
  354. "version_major": 2,
  355. "version_minor": 0
  356. },
  357. "text/plain": [
  358. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
  359. ]
  360. },
  361. "metadata": {},
  362. "output_type": "display_data"
  363. },
  364. {
  365. "name": "stdout",
  366. "output_type": "stream",
  367. "text": [
  368. "Evaluate data in 1.72 seconds!\n",
  369. "Evaluation on dev at Epoch 1/2. Step:156/312: \n",
  370. "AccuracyMetric: acc=0.624549\n",
  371. "\n"
  372. ]
  373. },
  374. {
  375. "data": {
  376. "application/vnd.jupyter.widget-view+json": {
  377. "model_id": "",
  378. "version_major": 2,
  379. "version_minor": 0
  380. },
  381. "text/plain": [
  382. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
  383. ]
  384. },
  385. "metadata": {},
  386. "output_type": "display_data"
  387. },
  388. {
  389. "name": "stdout",
  390. "output_type": "stream",
  391. "text": [
  392. "Evaluate data in 1.74 seconds!\n",
  393. "Evaluation on dev at Epoch 2/2. Step:312/312: \n",
  394. "AccuracyMetric: acc=0.649819\n",
  395. "\n",
  396. "\n",
  397. "In Epoch:2/Step:312, got best dev performance:\n",
  398. "AccuracyMetric: acc=0.649819\n",
  399. "Reloaded the best model.\n"
  400. ]
  401. },
  402. {
  403. "data": {
  404. "text/plain": [
  405. "{'best_eval': {'AccuracyMetric': {'acc': 0.649819}},\n",
  406. " 'best_epoch': 2,\n",
  407. " 'best_step': 312,\n",
  408. " 'seconds': 109.87}"
  409. ]
  410. },
  411. "execution_count": 10,
  412. "metadata": {},
  413. "output_type": "execute_result"
  414. }
  415. ],
  416. "source": [
  417. "# 训练模型\n",
  418. "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
  419. "trainer = Trainer(data_bundle.get_dataset('train'), model, \n",
  420. " optimizer=Adam(model_params=model.parameters(), lr=2e-5), \n",
  421. " loss=CrossEntropyLoss(), device=[0],\n",
  422. " batch_size=16, dev_data=data_bundle.get_dataset('dev'), \n",
  423. " metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n",
  424. "trainer.train()"
  425. ]
  426. },
  427. {
  428. "cell_type": "code",
  429. "execution_count": null,
  430. "metadata": {},
  431. "outputs": [],
  432. "source": []
  433. },
  434. {
  435. "cell_type": "code",
  436. "execution_count": null,
  437. "metadata": {},
  438. "outputs": [],
  439. "source": []
  440. },
  441. {
  442. "cell_type": "code",
  443. "execution_count": null,
  444. "metadata": {},
  445. "outputs": [],
  446. "source": []
  447. }
  448. ],
  449. "metadata": {
  450. "kernelspec": {
  451. "display_name": "Python 3",
  452. "language": "python",
  453. "name": "python3"
  454. },
  455. "language_info": {
  456. "codemirror_mode": {
  457. "name": "ipython",
  458. "version": 3
  459. },
  460. "file_extension": ".py",
  461. "mimetype": "text/x-python",
  462. "name": "python",
  463. "nbconvert_exporter": "python",
  464. "pygments_lexer": "ipython3",
  465. "version": "3.7.0"
  466. }
  467. },
  468. "nbformat": 4,
  469. "nbformat_minor": 2
  470. }