You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tutorial_8_modules_models.ipynb 28 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# 使用Modules和Models快速搭建自定义模型\n",
  8. "\n",
  9. "modules 和 models 用于构建 fastNLP 所需的神经网络模型,它们可以和 torch.nn 中的模型一起使用。 下面我们会分三节介绍构建模型的具体方法。\n"
  10. ]
  11. },
  12. {
  13. "cell_type": "markdown",
  14. "metadata": {},
  15. "source": [
  16. "我们首先准备好和上篇教程一样的基础实验代码"
  17. ]
  18. },
  19. {
  20. "cell_type": "code",
  21. "execution_count": 2,
  22. "metadata": {},
  23. "outputs": [],
  24. "source": [
  25. "from fastNLP.io import SST2Pipe\n",
  26. "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
  27. "import torch\n",
  28. "\n",
  29. "databundle = SST2Pipe().process_from_file()\n",
  30. "vocab = databundle.get_vocab('words')\n",
  31. "train_data = databundle.get_dataset('train')[:5000]\n",
  32. "train_data, test_data = train_data.split(0.015)\n",
  33. "dev_data = databundle.get_dataset('dev')\n",
  34. "\n",
  35. "loss = CrossEntropyLoss()\n",
  36. "metric = AccuracyMetric()\n",
  37. "device = 0 if torch.cuda.is_available() else 'cpu'"
  38. ]
  39. },
  40. {
  41. "cell_type": "markdown",
  42. "metadata": {},
  43. "source": [
  44. "## 使用 models 中的模型\n",
  45. "\n",
  46. "fastNLP 在 models 模块中内置了如 CNNText 、 SeqLabeling 等完整的模型,以供用户直接使用。 以文本分类的任务为例,我们从 models 中导入 CNNText 模型,用它进行训练。"
  47. ]
  48. },
  49. {
  50. "cell_type": "code",
  51. "execution_count": 3,
  52. "metadata": {},
  53. "outputs": [
  54. {
  55. "name": "stdout",
  56. "output_type": "stream",
  57. "text": [
  58. "input fields after batch(if batch size is 2):\n",
  59. "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 41]) \n",
  60. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  61. "target fields after batch(if batch size is 2):\n",
  62. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  63. "\n",
  64. "training epochs started 2020-02-28-00-56-04\n"
  65. ]
  66. },
  67. {
  68. "data": {
  69. "application/vnd.jupyter.widget-view+json": {
  70. "model_id": "",
  71. "version_major": 2,
  72. "version_minor": 0
  73. },
  74. "text/plain": [
  75. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
  76. ]
  77. },
  78. "metadata": {},
  79. "output_type": "display_data"
  80. },
  81. {
  82. "data": {
  83. "application/vnd.jupyter.widget-view+json": {
  84. "model_id": "",
  85. "version_major": 2,
  86. "version_minor": 0
  87. },
  88. "text/plain": [
  89. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  90. ]
  91. },
  92. "metadata": {},
  93. "output_type": "display_data"
  94. },
  95. {
  96. "name": "stdout",
  97. "output_type": "stream",
  98. "text": [
  99. "\r",
  100. "Evaluate data in 0.22 seconds!\n",
  101. "\r",
  102. "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
  103. "\r",
  104. "AccuracyMetric: acc=0.760321\n",
  105. "\n"
  106. ]
  107. },
  108. {
  109. "data": {
  110. "application/vnd.jupyter.widget-view+json": {
  111. "model_id": "",
  112. "version_major": 2,
  113. "version_minor": 0
  114. },
  115. "text/plain": [
  116. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  117. ]
  118. },
  119. "metadata": {},
  120. "output_type": "display_data"
  121. },
  122. {
  123. "name": "stdout",
  124. "output_type": "stream",
  125. "text": [
  126. "\r",
  127. "Evaluate data in 0.29 seconds!\n",
  128. "\r",
  129. "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
  130. "\r",
  131. "AccuracyMetric: acc=0.727064\n",
  132. "\n"
  133. ]
  134. },
  135. {
  136. "data": {
  137. "application/vnd.jupyter.widget-view+json": {
  138. "model_id": "",
  139. "version_major": 2,
  140. "version_minor": 0
  141. },
  142. "text/plain": [
  143. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  144. ]
  145. },
  146. "metadata": {},
  147. "output_type": "display_data"
  148. },
  149. {
  150. "name": "stdout",
  151. "output_type": "stream",
  152. "text": [
  153. "\r",
  154. "Evaluate data in 0.48 seconds!\n",
  155. "\r",
  156. "Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
  157. "\r",
  158. "AccuracyMetric: acc=0.758028\n",
  159. "\n"
  160. ]
  161. },
  162. {
  163. "data": {
  164. "application/vnd.jupyter.widget-view+json": {
  165. "model_id": "",
  166. "version_major": 2,
  167. "version_minor": 0
  168. },
  169. "text/plain": [
  170. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  171. ]
  172. },
  173. "metadata": {},
  174. "output_type": "display_data"
  175. },
  176. {
  177. "name": "stdout",
  178. "output_type": "stream",
  179. "text": [
  180. "\r",
  181. "Evaluate data in 0.24 seconds!\n",
  182. "\r",
  183. "Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
  184. "\r",
  185. "AccuracyMetric: acc=0.759174\n",
  186. "\n"
  187. ]
  188. },
  189. {
  190. "data": {
  191. "application/vnd.jupyter.widget-view+json": {
  192. "model_id": "",
  193. "version_major": 2,
  194. "version_minor": 0
  195. },
  196. "text/plain": [
  197. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  198. ]
  199. },
  200. "metadata": {},
  201. "output_type": "display_data"
  202. },
  203. {
  204. "name": "stdout",
  205. "output_type": "stream",
  206. "text": [
  207. "\r",
  208. "Evaluate data in 0.47 seconds!\n",
  209. "\r",
  210. "Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
  211. "\r",
  212. "AccuracyMetric: acc=0.743119\n",
  213. "\n"
  214. ]
  215. },
  216. {
  217. "data": {
  218. "application/vnd.jupyter.widget-view+json": {
  219. "model_id": "",
  220. "version_major": 2,
  221. "version_minor": 0
  222. },
  223. "text/plain": [
  224. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  225. ]
  226. },
  227. "metadata": {},
  228. "output_type": "display_data"
  229. },
  230. {
  231. "name": "stdout",
  232. "output_type": "stream",
  233. "text": [
  234. "\r",
  235. "Evaluate data in 0.22 seconds!\n",
  236. "\r",
  237. "Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
  238. "\r",
  239. "AccuracyMetric: acc=0.756881\n",
  240. "\n"
  241. ]
  242. },
  243. {
  244. "data": {
  245. "application/vnd.jupyter.widget-view+json": {
  246. "model_id": "",
  247. "version_major": 2,
  248. "version_minor": 0
  249. },
  250. "text/plain": [
  251. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  252. ]
  253. },
  254. "metadata": {},
  255. "output_type": "display_data"
  256. },
  257. {
  258. "name": "stdout",
  259. "output_type": "stream",
  260. "text": [
  261. "\r",
  262. "Evaluate data in 0.21 seconds!\n",
  263. "\r",
  264. "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
  265. "\r",
  266. "AccuracyMetric: acc=0.752294\n",
  267. "\n"
  268. ]
  269. },
  270. {
  271. "data": {
  272. "application/vnd.jupyter.widget-view+json": {
  273. "model_id": "",
  274. "version_major": 2,
  275. "version_minor": 0
  276. },
  277. "text/plain": [
  278. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  279. ]
  280. },
  281. "metadata": {},
  282. "output_type": "display_data"
  283. },
  284. {
  285. "name": "stdout",
  286. "output_type": "stream",
  287. "text": [
  288. "\r",
  289. "Evaluate data in 0.21 seconds!\n",
  290. "\r",
  291. "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
  292. "\r",
  293. "AccuracyMetric: acc=0.756881\n",
  294. "\n"
  295. ]
  296. },
  297. {
  298. "data": {
  299. "application/vnd.jupyter.widget-view+json": {
  300. "model_id": "",
  301. "version_major": 2,
  302. "version_minor": 0
  303. },
  304. "text/plain": [
  305. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  306. ]
  307. },
  308. "metadata": {},
  309. "output_type": "display_data"
  310. },
  311. {
  312. "name": "stdout",
  313. "output_type": "stream",
  314. "text": [
  315. "\r",
  316. "Evaluate data in 0.15 seconds!\n",
  317. "\r",
  318. "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
  319. "\r",
  320. "AccuracyMetric: acc=0.75344\n",
  321. "\n"
  322. ]
  323. },
  324. {
  325. "data": {
  326. "application/vnd.jupyter.widget-view+json": {
  327. "model_id": "",
  328. "version_major": 2,
  329. "version_minor": 0
  330. },
  331. "text/plain": [
  332. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  333. ]
  334. },
  335. "metadata": {},
  336. "output_type": "display_data"
  337. },
  338. {
  339. "name": "stdout",
  340. "output_type": "stream",
  341. "text": [
  342. "\r",
  343. "Evaluate data in 0.12 seconds!\n",
  344. "\r",
  345. "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
  346. "\r",
  347. "AccuracyMetric: acc=0.752294\n",
  348. "\n",
  349. "\r\n",
  350. "In Epoch:1/Step:154, got best dev performance:\n",
  351. "AccuracyMetric: acc=0.760321\n",
  352. "Reloaded the best model.\n"
  353. ]
  354. },
  355. {
  356. "data": {
  357. "text/plain": [
  358. "{'best_eval': {'AccuracyMetric': {'acc': 0.760321}},\n",
  359. " 'best_epoch': 1,\n",
  360. " 'best_step': 154,\n",
  361. " 'seconds': 29.3}"
  362. ]
  363. },
  364. "execution_count": 3,
  365. "metadata": {},
  366. "output_type": "execute_result"
  367. }
  368. ],
  369. "source": [
  370. "from fastNLP.models import CNNText\n",
  371. "\n",
  372. "model_cnn = CNNText((len(vocab),100), num_classes=2, dropout=0.1)\n",
  373. "\n",
  374. "trainer = Trainer(train_data=train_data, dev_data=dev_data, metrics=metric,\n",
  375. " loss=loss, device=device, model=model_cnn)\n",
  376. "trainer.train()"
  377. ]
  378. },
  379. {
  380. "cell_type": "markdown",
  381. "metadata": {},
  382. "source": [
  383. "在 IPython 环境中输入 model_cnn ,我们可以看到 model_cnn 的网络结构"
  384. ]
  385. },
  386. {
  387. "cell_type": "code",
  388. "execution_count": 4,
  389. "metadata": {},
  390. "outputs": [
  391. {
  392. "data": {
  393. "text/plain": [
  394. "CNNText(\n",
  395. " (embed): Embedding(\n",
  396. " (embed): Embedding(16292, 100)\n",
  397. " (dropout): Dropout(p=0.0, inplace=False)\n",
  398. " )\n",
  399. " (conv_pool): ConvMaxpool(\n",
  400. " (convs): ModuleList(\n",
  401. " (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n",
  402. " (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
  403. " (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
  404. " )\n",
  405. " )\n",
  406. " (dropout): Dropout(p=0.1, inplace=False)\n",
  407. " (fc): Linear(in_features=120, out_features=2, bias=True)\n",
  408. ")"
  409. ]
  410. },
  411. "execution_count": 4,
  412. "metadata": {},
  413. "output_type": "execute_result"
  414. }
  415. ],
  416. "source": [
  417. "model_cnn"
  418. ]
  419. },
  420. {
  421. "cell_type": "markdown",
  422. "metadata": {},
  423. "source": [
  424. "## 使用 torch.nn 编写模型\n",
  425. "\n",
  426. "fastNLP 完全支持使用 PyTorch 编写的模型,但与 PyTorch 中编写模型的常见方法不同, 用于 fastNLP 的模型中 forward 函数需要返回一个字典,字典中至少需要包含 pred 这个字段。\n",
  427. "\n",
  428. "下面是使用 PyTorch 中的 torch.nn 模块编写的文本分类模型,注意观察代码中标注的向量维度。 由于 PyTorch 的 nn.LSTM 默认把 seq_len 放在第一维,forward 中需要多次处理维度顺序"
  429. ]
  430. },
  431. {
  432. "cell_type": "code",
  433. "execution_count": 5,
  434. "metadata": {},
  435. "outputs": [],
  436. "source": [
  437. "import torch\n",
  438. "import torch.nn as nn\n",
  439. "\n",
  440. "class LSTMText(nn.Module):\n",
  441. " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
  442. " super().__init__()\n",
  443. "\n",
  444. " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
  445. " self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True, dropout=dropout)\n",
  446. " self.fc = nn.Linear(hidden_dim * 2, output_dim)\n",
  447. " self.dropout = nn.Dropout(dropout)\n",
  448. "\n",
  449. " def forward(self, words):\n",
  450. " # (input) words : (batch_size, seq_len)\n",
  451. " words = words.permute(1,0)\n",
  452. " # words : (seq_len, batch_size)\n",
  453. "\n",
  454. " embedded = self.dropout(self.embedding(words))\n",
  455. " # embedded : (seq_len, batch_size, embedding_dim)\n",
  456. " output, (hidden, cell) = self.lstm(embedded)\n",
  457. " # output: (seq_len, batch_size, hidden_dim * 2)\n",
  458. " # hidden: (num_layers * 2, batch_size, hidden_dim)\n",
  459. " # cell: (num_layers * 2, batch_size, hidden_dim)\n",
  460. "\n",
  461. " hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)\n",
  462. " hidden = self.dropout(hidden)\n",
  463. " # hidden: (batch_size, hidden_dim * 2)\n",
  464. "\n",
  465. " pred = self.fc(hidden.squeeze(0))\n",
  466. " # result: (batch_size, output_dim)\n",
  467. " return {\"pred\":pred}"
  468. ]
  469. },
  470. {
  471. "cell_type": "markdown",
  472. "metadata": {},
  473. "source": [
  474. "我们同样可以在 iPython 环境中查看这个模型的网络结构"
  475. ]
  476. },
  477. {
  478. "cell_type": "code",
  479. "execution_count": 6,
  480. "metadata": {},
  481. "outputs": [
  482. {
  483. "data": {
  484. "text/plain": [
  485. "LSTMText(\n",
  486. " (embedding): Embedding(16292, 100)\n",
  487. " (lstm): LSTM(100, 64, num_layers=2, dropout=0.5, bidirectional=True)\n",
  488. " (fc): Linear(in_features=128, out_features=2, bias=True)\n",
  489. " (dropout): Dropout(p=0.5, inplace=False)\n",
  490. ")"
  491. ]
  492. },
  493. "execution_count": 6,
  494. "metadata": {},
  495. "output_type": "execute_result"
  496. }
  497. ],
  498. "source": [
  499. "model_lstm = LSTMText(len(vocab), 100, 2)\n",
  500. "model_lstm "
  501. ]
  502. },
  503. {
  504. "cell_type": "code",
  505. "execution_count": 7,
  506. "metadata": {
  507. "scrolled": true
  508. },
  509. "outputs": [
  510. {
  511. "name": "stdout",
  512. "output_type": "stream",
  513. "text": [
  514. "input fields after batch(if batch size is 2):\n",
  515. "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 41]) \n",
  516. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  517. "target fields after batch(if batch size is 2):\n",
  518. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  519. "\n",
  520. "training epochs started 2020-02-28-00-56-34\n"
  521. ]
  522. },
  523. {
  524. "data": {
  525. "application/vnd.jupyter.widget-view+json": {
  526. "model_id": "",
  527. "version_major": 2,
  528. "version_minor": 0
  529. },
  530. "text/plain": [
  531. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
  532. ]
  533. },
  534. "metadata": {},
  535. "output_type": "display_data"
  536. },
  537. {
  538. "data": {
  539. "application/vnd.jupyter.widget-view+json": {
  540. "model_id": "",
  541. "version_major": 2,
  542. "version_minor": 0
  543. },
  544. "text/plain": [
  545. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  546. ]
  547. },
  548. "metadata": {},
  549. "output_type": "display_data"
  550. },
  551. {
  552. "name": "stdout",
  553. "output_type": "stream",
  554. "text": [
  555. "\r",
  556. "Evaluate data in 0.36 seconds!\n",
  557. "\r",
  558. "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
  559. "\r",
  560. "AccuracyMetric: acc=0.59289\n",
  561. "\n"
  562. ]
  563. },
  564. {
  565. "data": {
  566. "application/vnd.jupyter.widget-view+json": {
  567. "model_id": "",
  568. "version_major": 2,
  569. "version_minor": 0
  570. },
  571. "text/plain": [
  572. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  573. ]
  574. },
  575. "metadata": {},
  576. "output_type": "display_data"
  577. },
  578. {
  579. "name": "stdout",
  580. "output_type": "stream",
  581. "text": [
  582. "\r",
  583. "Evaluate data in 0.35 seconds!\n",
  584. "\r",
  585. "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
  586. "\r",
  587. "AccuracyMetric: acc=0.674312\n",
  588. "\n"
  589. ]
  590. },
  591. {
  592. "data": {
  593. "application/vnd.jupyter.widget-view+json": {
  594. "model_id": "",
  595. "version_major": 2,
  596. "version_minor": 0
  597. },
  598. "text/plain": [
  599. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  600. ]
  601. },
  602. "metadata": {},
  603. "output_type": "display_data"
  604. },
  605. {
  606. "name": "stdout",
  607. "output_type": "stream",
  608. "text": [
  609. "\r",
  610. "Evaluate data in 0.21 seconds!\n",
  611. "\r",
  612. "Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
  613. "\r",
  614. "AccuracyMetric: acc=0.724771\n",
  615. "\n"
  616. ]
  617. },
  618. {
  619. "data": {
  620. "application/vnd.jupyter.widget-view+json": {
  621. "model_id": "",
  622. "version_major": 2,
  623. "version_minor": 0
  624. },
  625. "text/plain": [
  626. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  627. ]
  628. },
  629. "metadata": {},
  630. "output_type": "display_data"
  631. },
  632. {
  633. "name": "stdout",
  634. "output_type": "stream",
  635. "text": [
  636. "\r",
  637. "Evaluate data in 0.4 seconds!\n",
  638. "\r",
  639. "Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
  640. "\r",
  641. "AccuracyMetric: acc=0.748853\n",
  642. "\n"
  643. ]
  644. },
  645. {
  646. "data": {
  647. "application/vnd.jupyter.widget-view+json": {
  648. "model_id": "",
  649. "version_major": 2,
  650. "version_minor": 0
  651. },
  652. "text/plain": [
  653. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  654. ]
  655. },
  656. "metadata": {},
  657. "output_type": "display_data"
  658. },
  659. {
  660. "name": "stdout",
  661. "output_type": "stream",
  662. "text": [
  663. "\r",
  664. "Evaluate data in 0.24 seconds!\n",
  665. "\r",
  666. "Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
  667. "\r",
  668. "AccuracyMetric: acc=0.756881\n",
  669. "\n"
  670. ]
  671. },
  672. {
  673. "data": {
  674. "application/vnd.jupyter.widget-view+json": {
  675. "model_id": "",
  676. "version_major": 2,
  677. "version_minor": 0
  678. },
  679. "text/plain": [
  680. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  681. ]
  682. },
  683. "metadata": {},
  684. "output_type": "display_data"
  685. },
  686. {
  687. "name": "stdout",
  688. "output_type": "stream",
  689. "text": [
  690. "\r",
  691. "Evaluate data in 0.29 seconds!\n",
  692. "\r",
  693. "Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
  694. "\r",
  695. "AccuracyMetric: acc=0.741972\n",
  696. "\n"
  697. ]
  698. },
  699. {
  700. "data": {
  701. "application/vnd.jupyter.widget-view+json": {
  702. "model_id": "",
  703. "version_major": 2,
  704. "version_minor": 0
  705. },
  706. "text/plain": [
  707. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  708. ]
  709. },
  710. "metadata": {},
  711. "output_type": "display_data"
  712. },
  713. {
  714. "name": "stdout",
  715. "output_type": "stream",
  716. "text": [
  717. "\r",
  718. "Evaluate data in 0.32 seconds!\n",
  719. "\r",
  720. "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
  721. "\r",
  722. "AccuracyMetric: acc=0.754587\n",
  723. "\n"
  724. ]
  725. },
  726. {
  727. "data": {
  728. "application/vnd.jupyter.widget-view+json": {
  729. "model_id": "",
  730. "version_major": 2,
  731. "version_minor": 0
  732. },
  733. "text/plain": [
  734. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  735. ]
  736. },
  737. "metadata": {},
  738. "output_type": "display_data"
  739. },
  740. {
  741. "name": "stdout",
  742. "output_type": "stream",
  743. "text": [
  744. "\r",
  745. "Evaluate data in 0.24 seconds!\n",
  746. "\r",
  747. "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
  748. "\r",
  749. "AccuracyMetric: acc=0.756881\n",
  750. "\n"
  751. ]
  752. },
  753. {
  754. "data": {
  755. "application/vnd.jupyter.widget-view+json": {
  756. "model_id": "",
  757. "version_major": 2,
  758. "version_minor": 0
  759. },
  760. "text/plain": [
  761. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  762. ]
  763. },
  764. "metadata": {},
  765. "output_type": "display_data"
  766. },
  767. {
  768. "name": "stdout",
  769. "output_type": "stream",
  770. "text": [
  771. "\r",
  772. "Evaluate data in 0.28 seconds!\n",
  773. "\r",
  774. "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
  775. "\r",
  776. "AccuracyMetric: acc=0.740826\n",
  777. "\n"
  778. ]
  779. },
  780. {
  781. "data": {
  782. "application/vnd.jupyter.widget-view+json": {
  783. "model_id": "",
  784. "version_major": 2,
  785. "version_minor": 0
  786. },
  787. "text/plain": [
  788. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  789. ]
  790. },
  791. "metadata": {},
  792. "output_type": "display_data"
  793. },
  794. {
  795. "name": "stdout",
  796. "output_type": "stream",
  797. "text": [
  798. "\r",
  799. "Evaluate data in 0.23 seconds!\n",
  800. "\r",
  801. "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
  802. "\r",
  803. "AccuracyMetric: acc=0.751147\n",
  804. "\n",
  805. "\r\n",
  806. "In Epoch:5/Step:770, got best dev performance:\n",
  807. "AccuracyMetric: acc=0.756881\n",
  808. "Reloaded the best model.\n"
  809. ]
  810. },
  811. {
  812. "data": {
  813. "text/plain": [
  814. "{'best_eval': {'AccuracyMetric': {'acc': 0.756881}},\n",
  815. " 'best_epoch': 5,\n",
  816. " 'best_step': 770,\n",
  817. " 'seconds': 45.69}"
  818. ]
  819. },
  820. "execution_count": 7,
  821. "metadata": {},
  822. "output_type": "execute_result"
  823. }
  824. ],
  825. "source": [
  826. "trainer = Trainer(train_data=train_data, dev_data=dev_data, metrics=metric,\n",
  827. " loss=loss, device=device, model=model_lstm)\n",
  828. "trainer.train()"
  829. ]
  830. },
  831. {
  832. "cell_type": "markdown",
  833. "metadata": {},
  834. "source": [
  835. "## 使用 modules 编写模型\n",
  836. "\n",
  837. "下面我们使用 fastNLP.modules 中的组件来构建同样的网络。由于 fastNLP 统一把 batch_size 放在第一维, 在编写代码的过程中会有一定的便利。"
  838. ]
  839. },
  840. {
  841. "cell_type": "code",
  842. "execution_count": 8,
  843. "metadata": {},
  844. "outputs": [
  845. {
  846. "data": {
  847. "text/plain": [
  848. "MyText(\n",
  849. " (embedding): Embedding(\n",
  850. " (embed): Embedding(16292, 100)\n",
  851. " (dropout): Dropout(p=0.0, inplace=False)\n",
  852. " )\n",
  853. " (lstm): LSTM(\n",
  854. " (lstm): LSTM(100, 64, num_layers=2, batch_first=True, bidirectional=True)\n",
  855. " )\n",
  856. " (mlp): MLP(\n",
  857. " (hiddens): ModuleList()\n",
  858. " (output): Linear(in_features=128, out_features=2, bias=True)\n",
  859. " (dropout): Dropout(p=0.5, inplace=False)\n",
  860. " )\n",
  861. ")"
  862. ]
  863. },
  864. "execution_count": 8,
  865. "metadata": {},
  866. "output_type": "execute_result"
  867. }
  868. ],
  869. "source": [
  870. "from fastNLP.modules import LSTM, MLP\n",
  871. "from fastNLP.embeddings import Embedding\n",
  872. "\n",
  873. "\n",
  874. "class MyText(nn.Module):\n",
  875. " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
  876. " super().__init__()\n",
  877. "\n",
  878. " self.embedding = Embedding((vocab_size, embedding_dim))\n",
  879. " self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n",
  880. " self.mlp = MLP([hidden_dim*2,output_dim], dropout=dropout)\n",
  881. "\n",
  882. " def forward(self, words):\n",
  883. " embedded = self.embedding(words)\n",
  884. " _,(hidden,_) = self.lstm(embedded)\n",
  885. " pred = self.mlp(torch.cat((hidden[-1],hidden[-2]),dim=1))\n",
  886. " return {\"pred\":pred}\n",
  887. " \n",
  888. "model_text = MyText(len(vocab), 100, 2)\n",
  889. "model_text"
  890. ]
  891. },
  892. {
  893. "cell_type": "code",
  894. "execution_count": null,
  895. "metadata": {},
  896. "outputs": [
  897. {
  898. "name": "stdout",
  899. "output_type": "stream",
  900. "text": [
  901. "input fields after batch(if batch size is 2):\n",
  902. "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 41]) \n",
  903. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  904. "target fields after batch(if batch size is 2):\n",
  905. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  906. "\n",
  907. "training epochs started 2020-02-28-00-57-19\n"
  908. ]
  909. },
  910. {
  911. "data": {
  912. "application/vnd.jupyter.widget-view+json": {
  913. "model_id": "16a35f2b0ef0457dae15c5f240a19a3a",
  914. "version_major": 2,
  915. "version_minor": 0
  916. },
  917. "text/plain": [
  918. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
  919. ]
  920. },
  921. "metadata": {},
  922. "output_type": "display_data"
  923. },
  924. {
  925. "data": {
  926. "application/vnd.jupyter.widget-view+json": {
  927. "model_id": "",
  928. "version_major": 2,
  929. "version_minor": 0
  930. },
  931. "text/plain": [
  932. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  933. ]
  934. },
  935. "metadata": {},
  936. "output_type": "display_data"
  937. },
  938. {
  939. "name": "stdout",
  940. "output_type": "stream",
  941. "text": [
  942. "\r",
  943. "Evaluate data in 0.38 seconds!\n",
  944. "\r",
  945. "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
  946. "\r",
  947. "AccuracyMetric: acc=0.767202\n",
  948. "\n"
  949. ]
  950. },
  951. {
  952. "data": {
  953. "application/vnd.jupyter.widget-view+json": {
  954. "model_id": "",
  955. "version_major": 2,
  956. "version_minor": 0
  957. },
  958. "text/plain": [
  959. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  960. ]
  961. },
  962. "metadata": {},
  963. "output_type": "display_data"
  964. },
  965. {
  966. "name": "stdout",
  967. "output_type": "stream",
  968. "text": [
  969. "\r",
  970. "Evaluate data in 0.22 seconds!\n",
  971. "\r",
  972. "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
  973. "\r",
  974. "AccuracyMetric: acc=0.743119\n",
  975. "\n"
  976. ]
  977. }
  978. ],
  979. "source": [
  980. "trainer = Trainer(train_data=train_data, dev_data=dev_data, metrics=metric,\n",
  981. " loss=loss, device=device, model=model_lstm)\n",
  982. "trainer.train()"
  983. ]
  984. },
  985. {
  986. "cell_type": "code",
  987. "execution_count": null,
  988. "metadata": {},
  989. "outputs": [],
  990. "source": []
  991. }
  992. ],
  993. "metadata": {
  994. "kernelspec": {
  995. "display_name": "Python Now",
  996. "language": "python",
  997. "name": "now"
  998. },
  999. "language_info": {
  1000. "codemirror_mode": {
  1001. "name": "ipython",
  1002. "version": 3
  1003. },
  1004. "file_extension": ".py",
  1005. "mimetype": "text/x-python",
  1006. "name": "python",
  1007. "nbconvert_exporter": "python",
  1008. "pygments_lexer": "ipython3",
  1009. "version": "3.8.0"
  1010. }
  1011. },
  1012. "nbformat": 4,
  1013. "nbformat_minor": 2
  1014. }