You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

fastnlp_10min_tutorial_v2.ipynb 22 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "fastNLP上手教程\n",
  8. "-------\n",
  9. "\n",
  10. "fastNLP提供方便的数据预处理,训练和测试模型的功能"
  11. ]
  12. },
  13. {
  14. "cell_type": "markdown",
  15. "metadata": {},
  16. "source": [
  17. "DataSet & Instance\n",
  18. "------\n",
  19. "\n",
  20. "fastNLP用DataSet和Instance保存和处理数据。每个DataSet表示一个数据集,每个Instance表示一个数据样本。一个DataSet存有多个Instance,每个Instance可以自定义存哪些内容。\n",
  21. "\n",
  22. "有一些read_*方法,可以轻松从文件读取数据,存成DataSet。"
  23. ]
  24. },
  25. {
  26. "cell_type": "code",
  27. "execution_count": 9,
  28. "metadata": {},
  29. "outputs": [
  30. {
  31. "name": "stdout",
  32. "output_type": "stream",
  33. "text": [
  34. "8529"
  35. ]
  36. },
  37. {
  38. "name": "stdout",
  39. "output_type": "stream",
  40. "text": [
  41. "\n"
  42. ]
  43. }
  44. ],
  45. "source": [
  46. "from fastNLP import DataSet\n",
  47. "from fastNLP import Instance\n",
  48. "\n",
  49. "# 从csv读取数据到DataSet\n",
  50. "dataset = DataSet.read_csv('../sentence.csv', headers=('raw_sentence', 'label'), sep='\\t')\n",
  51. "print(len(dataset))"
  52. ]
  53. },
  54. {
  55. "cell_type": "code",
  56. "execution_count": 10,
  57. "metadata": {},
  58. "outputs": [
  59. {
  60. "name": "stdout",
  61. "output_type": "stream",
  62. "text": [
  63. "{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}"
  64. ]
  65. },
  66. {
  67. "name": "stdout",
  68. "output_type": "stream",
  69. "text": [
  70. "\n"
  71. ]
  72. }
  73. ],
  74. "source": [
  75. "# 使用数字索引[k],获取第k个样本\n",
  76. "print(dataset[0])\n",
  77. "\n",
  78. "# 索引也可以是负数\n",
  79. "print(dataset[-3])"
  80. ]
  81. },
  82. {
  83. "cell_type": "markdown",
  84. "metadata": {},
  85. "source": [
  86. "## Instance\n",
  87. "Instance表示一个样本,由一个或多个field(域,属性,特征)组成,每个field有名字和值。\n",
  88. "\n",
  89. "在初始化Instance时即可定义它包含的域,使用 \"field_name=field_value\"的写法。"
  90. ]
  91. },
  92. {
  93. "cell_type": "code",
  94. "execution_count": 11,
  95. "metadata": {},
  96. "outputs": [
  97. {
  98. "data": {
  99. "text/plain": [
  100. "{'raw_sentence': fake data,\n'label': 0}"
  101. ]
  102. },
  103. "execution_count": 11,
  104. "metadata": {},
  105. "output_type": "execute_result"
  106. }
  107. ],
  108. "source": [
  109. "# DataSet.append(Instance)加入新数据\n",
  110. "dataset.append(Instance(raw_sentence='fake data', label='0'))\n",
  111. "dataset[-1]"
  112. ]
  113. },
  114. {
  115. "cell_type": "markdown",
  116. "metadata": {},
  117. "source": [
  118. "## DataSet.apply方法\n",
  119. "数据预处理利器"
  120. ]
  121. },
  122. {
  123. "cell_type": "code",
  124. "execution_count": 12,
  125. "metadata": {},
  126. "outputs": [
  127. {
  128. "name": "stdout",
  129. "output_type": "stream",
  130. "text": [
  131. "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}"
  132. ]
  133. },
  134. {
  135. "name": "stdout",
  136. "output_type": "stream",
  137. "text": [
  138. "\n"
  139. ]
  140. }
  141. ],
  142. "source": [
  143. "# 将所有字母转为小写\n",
  144. "dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n",
  145. "print(dataset[0])"
  146. ]
  147. },
  148. {
  149. "cell_type": "code",
  150. "execution_count": 13,
  151. "metadata": {},
  152. "outputs": [
  153. {
  154. "name": "stdout",
  155. "output_type": "stream",
  156. "text": [
  157. "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}"
  158. ]
  159. },
  160. {
  161. "name": "stdout",
  162. "output_type": "stream",
  163. "text": [
  164. "\n"
  165. ]
  166. }
  167. ],
  168. "source": [
  169. "# label转int\n",
  170. "dataset.apply(lambda x: int(x['label']), new_field_name='label')\n",
  171. "print(dataset[0])"
  172. ]
  173. },
  174. {
  175. "cell_type": "code",
  176. "execution_count": 14,
  177. "metadata": {},
  178. "outputs": [
  179. {
  180. "name": "stdout",
  181. "output_type": "stream",
  182. "text": [
  183. "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1,\n'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.']}"
  184. ]
  185. },
  186. {
  187. "name": "stdout",
  188. "output_type": "stream",
  189. "text": [
  190. "\n"
  191. ]
  192. }
  193. ],
  194. "source": [
  195. "# 使用空格分割句子\n",
  196. "def split_sent(ins):\n",
  197. " return ins['raw_sentence'].split()\n",
  198. "dataset.apply(split_sent, new_field_name='words')\n",
  199. "print(dataset[0])"
  200. ]
  201. },
  202. {
  203. "cell_type": "code",
  204. "execution_count": 15,
  205. "metadata": {},
  206. "outputs": [
  207. {
  208. "name": "stdout",
  209. "output_type": "stream",
  210. "text": [
  211. "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1,\n'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'],\n'seq_len': 37}"
  212. ]
  213. },
  214. {
  215. "name": "stdout",
  216. "output_type": "stream",
  217. "text": [
  218. "\n"
  219. ]
  220. }
  221. ],
  222. "source": [
  223. "# 增加长度信息\n",
  224. "dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')\n",
  225. "print(dataset[0])"
  226. ]
  227. },
  228. {
  229. "cell_type": "markdown",
  230. "metadata": {},
  231. "source": [
  232. "## DataSet.drop\n",
  233. "筛选数据"
  234. ]
  235. },
  236. {
  237. "cell_type": "code",
  238. "execution_count": 16,
  239. "metadata": {},
  240. "outputs": [
  241. {
  242. "name": "stdout",
  243. "output_type": "stream",
  244. "text": [
  245. "8358"
  246. ]
  247. },
  248. {
  249. "name": "stdout",
  250. "output_type": "stream",
  251. "text": [
  252. "\n"
  253. ]
  254. }
  255. ],
  256. "source": [
  257. "dataset.drop(lambda x: x['seq_len'] <= 3)\n",
  258. "print(len(dataset))"
  259. ]
  260. },
  261. {
  262. "cell_type": "markdown",
  263. "metadata": {},
  264. "source": [
  265. "## 配置DataSet\n",
  266. "1. 哪些域是特征,哪些域是标签\n",
  267. "2. 切分训练集/验证集"
  268. ]
  269. },
  270. {
  271. "cell_type": "code",
  272. "execution_count": 17,
  273. "metadata": {},
  274. "outputs": [],
  275. "source": [
  276. "# 设置DataSet中,哪些field要转为tensor\n",
  277. "\n",
  278. "# set target,loss或evaluate中的golden,计算loss,模型评估时使用\n",
  279. "dataset.set_target(\"label\")\n",
  280. "# set input,模型forward时使用\n",
  281. "dataset.set_input(\"words\")"
  282. ]
  283. },
  284. {
  285. "cell_type": "code",
  286. "execution_count": 18,
  287. "metadata": {},
  288. "outputs": [
  289. {
  290. "name": "stdout",
  291. "output_type": "stream",
  292. "text": [
  293. "5851"
  294. ]
  295. },
  296. {
  297. "name": "stdout",
  298. "output_type": "stream",
  299. "text": [
  300. "\n"
  301. ]
  302. },
  303. {
  304. "name": "stdout",
  305. "output_type": "stream",
  306. "text": [
  307. "2507"
  308. ]
  309. },
  310. {
  311. "name": "stdout",
  312. "output_type": "stream",
  313. "text": [
  314. "\n"
  315. ]
  316. }
  317. ],
  318. "source": [
  319. "# 分出测试集、训练集\n",
  320. "\n",
  321. "test_data, train_data = dataset.split(0.3)\n",
  322. "print(len(test_data))\n",
  323. "print(len(train_data))"
  324. ]
  325. },
  326. {
  327. "cell_type": "markdown",
  328. "metadata": {},
  329. "source": [
  330. "Vocabulary\n",
  331. "------\n",
  332. "\n",
  333. "fastNLP中的Vocabulary轻松构建词表,将词转成数字"
  334. ]
  335. },
  336. {
  337. "cell_type": "code",
  338. "execution_count": 19,
  339. "metadata": {},
  340. "outputs": [
  341. {
  342. "name": "stdout",
  343. "output_type": "stream",
  344. "text": [
  345. "{'raw_sentence': the project 's filmmakers forgot to include anything even halfway scary as they poorly rejigger fatal attraction into a high school setting .,\n'label': 0,\n'words': [4, 423, 9, 316, 1, 8, 1, 312, 72, 1478, 885, 14, 86, 725, 1, 1913, 1431, 53, 5, 455, 736, 1, 2],\n'seq_len': 23}"
  346. ]
  347. },
  348. {
  349. "name": "stdout",
  350. "output_type": "stream",
  351. "text": [
  352. "\n"
  353. ]
  354. }
  355. ],
  356. "source": [
  357. "from fastNLP import Vocabulary\n",
  358. "\n",
  359. "# 构建词表, Vocabulary.add(word)\n",
  360. "vocab = Vocabulary(min_freq=2)\n",
  361. "train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n",
  362. "vocab.build_vocab()\n",
  363. "\n",
  364. "# index句子, Vocabulary.to_index(word)\n",
  365. "train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n",
  366. "test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n",
  367. "\n",
  368. "\n",
  369. "print(test_data[0])"
  370. ]
  371. },
  372. {
  373. "cell_type": "markdown",
  374. "metadata": {},
  375. "source": [
  376. "# Model\n",
  377. "定义一个PyTorch模型"
  378. ]
  379. },
  380. {
  381. "cell_type": "code",
  382. "execution_count": 20,
  383. "metadata": {},
  384. "outputs": [
  385. {
  386. "data": {
  387. "text/plain": [
  388. "CNNText(\n (embed): Embedding(\n (embed): Embedding(3459, 50, padding_idx=0)\n (dropout): Dropout(p=0.0)\n )\n (conv_pool): ConvMaxpool(\n (convs): ModuleList(\n (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n )\n )\n (dropout): Dropout(p=0.1)\n (fc): Linear(\n (linear): Linear(in_features=12, out_features=5, bias=True)\n )\n)"
  389. ]
  390. },
  391. "execution_count": 20,
  392. "metadata": {},
  393. "output_type": "execute_result"
  394. }
  395. ],
  396. "source": [
  397. "from fastNLP.models import CNNText\n",
  398. "model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n",
  399. "model"
  400. ]
  401. },
  402. {
  403. "cell_type": "markdown",
  404. "metadata": {},
  405. "source": [
  406. "这是上述模型的forward方法。如果你不知道什么是forward方法,请参考我们的PyTorch教程。\n",
  407. "\n",
  408. "注意两点:\n",
  409. "1. forward参数名字叫**word_seq**,请记住。\n",
  410. "2. forward的返回值是一个**dict**,其中有个key的名字叫**output**。\n",
  411. "\n",
  412. "```Python\n",
  413. " def forward(self, word_seq):\n",
  414. " \"\"\"\n",
  415. "\n",
  416. " :param word_seq: torch.LongTensor, [batch_size, seq_len]\n",
  417. " :return output: dict of torch.LongTensor, [batch_size, num_classes]\n",
  418. " \"\"\"\n",
  419. " x = self.embed(word_seq) # [N,L] -> [N,L,C]\n",
  420. " x = self.conv_pool(x) # [N,L,C] -> [N,C]\n",
  421. " x = self.dropout(x)\n",
  422. " x = self.fc(x) # [N,C] -> [N, N_class]\n",
  423. " return {'output': x}\n",
  424. "```"
  425. ]
  426. },
  427. {
  428. "cell_type": "markdown",
  429. "metadata": {},
  430. "source": [
  431. "这是上述模型的predict方法,是用来直接输出该任务的预测结果,与forward目的不同。\n",
  432. "\n",
  433. "注意两点:\n",
  434. "1. predict参数名也叫**word_seq**。\n",
  435. "2. predict的返回值也是一个**dict**,其中有个key的名字叫**predict**。\n",
  436. "\n",
  437. "```\n",
  438. " def predict(self, word_seq):\n",
  439. " \"\"\"\n",
  440. "\n",
  441. " :param word_seq: torch.LongTensor, [batch_size, seq_len]\n",
  442. " :return predict: dict of torch.LongTensor, [batch_size, seq_len]\n",
  443. " \"\"\"\n",
  444. " output = self(word_seq)\n",
  445. " _, predict = output['output'].max(dim=1)\n",
  446. " return {'predict': predict}\n",
  447. "```"
  448. ]
  449. },
  450. {
  451. "cell_type": "markdown",
  452. "metadata": {},
  453. "source": [
  454. "Trainer & Tester\n",
  455. "------\n",
  456. "\n",
  457. "使用fastNLP的Trainer训练模型"
  458. ]
  459. },
  460. {
  461. "cell_type": "code",
  462. "execution_count": 21,
  463. "metadata": {},
  464. "outputs": [],
  465. "source": [
  466. "from fastNLP import Trainer\n",
  467. "from copy import deepcopy\n",
  468. "from fastNLP.core.losses import CrossEntropyLoss\n",
  469. "from fastNLP.core.metrics import AccuracyMetric\n",
  470. "\n",
  471. "\n",
  472. "# 更改DataSet中对应field的名称,与模型的forward的参数名一致\n",
  473. "# 因为forward的参数叫word_seq, 所以要把原本叫words的field改名为word_seq\n",
  474. "# 这里的演示是让你了解这种**命名规则**\n",
  475. "train_data.rename_field('words', 'word_seq')\n",
  476. "test_data.rename_field('words', 'word_seq')\n",
  477. "\n",
  478. "# 顺便把label换名为label_seq\n",
  479. "train_data.rename_field('label', 'label_seq')\n",
  480. "test_data.rename_field('label', 'label_seq')"
  481. ]
  482. },
  483. {
  484. "cell_type": "markdown",
  485. "metadata": {},
  486. "source": [
  487. "### loss\n",
  488. "训练模型需要提供一个损失函数\n",
  489. "\n",
  490. "下面提供了一个在分类问题中常用的交叉熵损失。注意它的**初始化参数**。\n",
  491. "\n",
  492. "pred参数对应的是模型的forward返回的dict的一个key的名字,这里是\"output\"。\n",
  493. "\n",
  494. "target参数对应的是dataset作为标签的field的名字,这里是\"label_seq\"。"
  495. ]
  496. },
  497. {
  498. "cell_type": "code",
  499. "execution_count": 22,
  500. "metadata": {},
  501. "outputs": [],
  502. "source": [
  503. "loss = CrossEntropyLoss(pred=\"output\", target=\"label_seq\")"
  504. ]
  505. },
  506. {
  507. "cell_type": "markdown",
  508. "metadata": {},
  509. "source": [
  510. "### Metric\n",
  511. "定义评价指标\n",
  512. "\n",
  513. "这里使用准确率。参数的“命名规则”跟上面类似。\n",
  514. "\n",
  515. "pred参数对应的是模型的predict方法返回的dict的一个key的名字,这里是\"predict\"。\n",
  516. "\n",
  517. "target参数对应的是dataset作为标签的field的名字,这里是\"label_seq\"。"
  518. ]
  519. },
  520. {
  521. "cell_type": "code",
  522. "execution_count": 23,
  523. "metadata": {},
  524. "outputs": [],
  525. "source": [
  526. "metric = AccuracyMetric(pred=\"predict\", target=\"label_seq\")"
  527. ]
  528. },
  529. {
  530. "cell_type": "code",
  531. "execution_count": 24,
  532. "metadata": {},
  533. "outputs": [
  534. {
  535. "name": "stdout",
  536. "output_type": "stream",
  537. "text": [
  538. "training epochs started 2018-12-07 14:11:31"
  539. ]
  540. },
  541. {
  542. "name": "stdout",
  543. "output_type": "stream",
  544. "text": [
  545. "\n"
  546. ]
  547. },
  548. {
  549. "data": {
  550. "text/plain": [
  551. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=915), HTML(value='')), layout=Layout(display=…"
  552. ]
  553. },
  554. "execution_count": 0,
  555. "metadata": {},
  556. "output_type": "execute_result"
  557. },
  558. {
  559. "name": "stdout",
  560. "output_type": "stream",
  561. "text": [
  562. "\r"
  563. ]
  564. },
  565. {
  566. "name": "stdout",
  567. "output_type": "stream",
  568. "text": [
  569. "Epoch 1/5. Step:183/915. AccuracyMetric: acc=0.350367"
  570. ]
  571. },
  572. {
  573. "name": "stdout",
  574. "output_type": "stream",
  575. "text": [
  576. "\n"
  577. ]
  578. },
  579. {
  580. "name": "stdout",
  581. "output_type": "stream",
  582. "text": [
  583. "\r"
  584. ]
  585. },
  586. {
  587. "name": "stdout",
  588. "output_type": "stream",
  589. "text": [
  590. "Epoch 2/5. Step:366/915. AccuracyMetric: acc=0.409332"
  591. ]
  592. },
  593. {
  594. "name": "stdout",
  595. "output_type": "stream",
  596. "text": [
  597. "\n"
  598. ]
  599. },
  600. {
  601. "name": "stdout",
  602. "output_type": "stream",
  603. "text": [
  604. "\r"
  605. ]
  606. },
  607. {
  608. "name": "stdout",
  609. "output_type": "stream",
  610. "text": [
  611. "Epoch 3/5. Step:549/915. AccuracyMetric: acc=0.572552"
  612. ]
  613. },
  614. {
  615. "name": "stdout",
  616. "output_type": "stream",
  617. "text": [
  618. "\n"
  619. ]
  620. },
  621. {
  622. "name": "stdout",
  623. "output_type": "stream",
  624. "text": [
  625. "\r"
  626. ]
  627. },
  628. {
  629. "name": "stdout",
  630. "output_type": "stream",
  631. "text": [
  632. "Epoch 4/5. Step:732/915. AccuracyMetric: acc=0.711331"
  633. ]
  634. },
  635. {
  636. "name": "stdout",
  637. "output_type": "stream",
  638. "text": [
  639. "\n"
  640. ]
  641. },
  642. {
  643. "name": "stdout",
  644. "output_type": "stream",
  645. "text": [
  646. "\r"
  647. ]
  648. },
  649. {
  650. "name": "stdout",
  651. "output_type": "stream",
  652. "text": [
  653. "Epoch 5/5. Step:915/915. AccuracyMetric: acc=0.801572"
  654. ]
  655. },
  656. {
  657. "name": "stdout",
  658. "output_type": "stream",
  659. "text": [
  660. "\n"
  661. ]
  662. },
  663. {
  664. "name": "stdout",
  665. "output_type": "stream",
  666. "text": [
  667. "\r"
  668. ]
  669. }
  670. ],
  671. "source": [
  672. "# 实例化Trainer,传入模型和数据,进行训练\n",
  673. "# 先在test_data拟合\n",
  674. "copy_model = deepcopy(model)\n",
  675. "overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,\n",
  676. " loss=loss,\n",
  677. " metrics=metric,\n",
  678. " save_path=None,\n",
  679. " batch_size=32,\n",
  680. " n_epochs=5)\n",
  681. "overfit_trainer.train()"
  682. ]
  683. },
  684. {
  685. "cell_type": "code",
  686. "execution_count": 25,
  687. "metadata": {},
  688. "outputs": [
  689. {
  690. "name": "stdout",
  691. "output_type": "stream",
  692. "text": [
  693. "training epochs started 2018-12-07 14:12:21"
  694. ]
  695. },
  696. {
  697. "name": "stdout",
  698. "output_type": "stream",
  699. "text": [
  700. "\n"
  701. ]
  702. },
  703. {
  704. "data": {
  705. "text/plain": [
  706. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=395), HTML(value='')), layout=Layout(display=…"
  707. ]
  708. },
  709. "execution_count": 0,
  710. "metadata": {},
  711. "output_type": "execute_result"
  712. },
  713. {
  714. "name": "stdout",
  715. "output_type": "stream",
  716. "text": [
  717. "\r"
  718. ]
  719. },
  720. {
  721. "name": "stdout",
  722. "output_type": "stream",
  723. "text": [
  724. "Epoch 1/5. Step:79/395. AccuracyMetric: acc=0.250043"
  725. ]
  726. },
  727. {
  728. "name": "stdout",
  729. "output_type": "stream",
  730. "text": [
  731. "\n"
  732. ]
  733. },
  734. {
  735. "name": "stdout",
  736. "output_type": "stream",
  737. "text": [
  738. "\r"
  739. ]
  740. },
  741. {
  742. "name": "stdout",
  743. "output_type": "stream",
  744. "text": [
  745. "Epoch 2/5. Step:158/395. AccuracyMetric: acc=0.280807"
  746. ]
  747. },
  748. {
  749. "name": "stdout",
  750. "output_type": "stream",
  751. "text": [
  752. "\n"
  753. ]
  754. },
  755. {
  756. "name": "stdout",
  757. "output_type": "stream",
  758. "text": [
  759. "\r"
  760. ]
  761. },
  762. {
  763. "name": "stdout",
  764. "output_type": "stream",
  765. "text": [
  766. "Epoch 3/5. Step:237/395. AccuracyMetric: acc=0.280978"
  767. ]
  768. },
  769. {
  770. "name": "stdout",
  771. "output_type": "stream",
  772. "text": [
  773. "\n"
  774. ]
  775. },
  776. {
  777. "name": "stdout",
  778. "output_type": "stream",
  779. "text": [
  780. "\r"
  781. ]
  782. },
  783. {
  784. "name": "stdout",
  785. "output_type": "stream",
  786. "text": [
  787. "Epoch 4/5. Step:316/395. AccuracyMetric: acc=0.285592"
  788. ]
  789. },
  790. {
  791. "name": "stdout",
  792. "output_type": "stream",
  793. "text": [
  794. "\n"
  795. ]
  796. },
  797. {
  798. "name": "stdout",
  799. "output_type": "stream",
  800. "text": [
  801. "\r"
  802. ]
  803. },
  804. {
  805. "name": "stdout",
  806. "output_type": "stream",
  807. "text": [
  808. "Epoch 5/5. Step:395/395. AccuracyMetric: acc=0.278927"
  809. ]
  810. },
  811. {
  812. "name": "stdout",
  813. "output_type": "stream",
  814. "text": [
  815. "\n"
  816. ]
  817. },
  818. {
  819. "name": "stdout",
  820. "output_type": "stream",
  821. "text": [
  822. "\r"
  823. ]
  824. }
  825. ],
  826. "source": [
  827. "# 用train_data训练,在test_data验证\n",
  828. "trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,\n",
  829. " loss=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n",
  830. " metrics=AccuracyMetric(pred=\"predict\", target=\"label_seq\"),\n",
  831. " save_path=None,\n",
  832. " batch_size=32,\n",
  833. " n_epochs=5)\n",
  834. "trainer.train()\n",
  835. "print('Train finished!')"
  836. ]
  837. },
  838. {
  839. "cell_type": "code",
  840. "execution_count": 26,
  841. "metadata": {},
  842. "outputs": [
  843. {
  844. "name": "stdout",
  845. "output_type": "stream",
  846. "text": [
  847. "[tester] \nAccuracyMetric: acc=0.280636"
  848. ]
  849. },
  850. {
  851. "name": "stdout",
  852. "output_type": "stream",
  853. "text": [
  854. "\n"
  855. ]
  856. },
  857. {
  858. "name": "stdout",
  859. "output_type": "stream",
  860. "text": [
  861. "{'AccuracyMetric': {'acc': 0.280636}}"
  862. ]
  863. },
  864. {
  865. "name": "stdout",
  866. "output_type": "stream",
  867. "text": [
  868. "\n"
  869. ]
  870. }
  871. ],
  872. "source": [
  873. "# 调用Tester在test_data上评价效果\n",
  874. "from fastNLP import Tester\n",
  875. "\n",
  876. "tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred=\"predict\", target=\"label_seq\"),\n",
  877. " batch_size=4)\n",
  878. "acc = tester.test()\n",
  879. "print(acc)"
  880. ]
  881. },
  882. {
  883. "cell_type": "code",
  884. "execution_count": null,
  885. "metadata": {},
  886. "outputs": [],
  887. "source": []
  888. }
  889. ],
  890. "metadata": {
  891. "kernelspec": {
  892. "display_name": "Python 3",
  893. "language": "python",
  894. "name": "python3"
  895. },
  896. "language_info": {
  897. "codemirror_mode": {
  898. "name": "ipython",
  899. "version": 3
  900. },
  901. "file_extension": ".py",
  902. "mimetype": "text/x-python",
  903. "name": "python",
  904. "nbconvert_exporter": "python",
  905. "pygments_lexer": "ipython3",
  906. "version": "3.6.7"
  907. }
  908. },
  909. "nbformat": 4,
  910. "nbformat_minor": 2
  911. }