You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tutorial_7_metrics.ipynb 33 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# 使用Metric快速评测你的模型\n",
  8. "\n",
  9. "和上一篇教程一样的实验准备代码"
  10. ]
  11. },
  12. {
  13. "cell_type": "code",
  14. "execution_count": 1,
  15. "metadata": {},
  16. "outputs": [
  17. {
  18. "name": "stderr",
  19. "output_type": "stream",
  20. "text": [
  21. "/remote-home/ynzheng/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/io/loader/classification.py:340: UserWarning: SST2's test file has no target.\n"
  22. ]
  23. }
  24. ],
  25. "source": [
  26. "from fastNLP.io import SST2Pipe\n",
  27. "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
  28. "from fastNLP.models import CNNText\n",
  29. "from fastNLP import CrossEntropyLoss\n",
  30. "import torch\n",
  31. "from torch.optim import Adam\n",
  32. "from fastNLP import AccuracyMetric\n",
  33. "\n",
  34. "databundle = SST2Pipe().process_from_file()\n",
  35. "vocab = databundle.get_vocab('words')\n",
  36. "train_data = databundle.get_dataset('train')[:5000]\n",
  37. "train_data, test_data = train_data.split(0.015)\n",
  38. "dev_data = databundle.get_dataset('dev')\n",
  39. "\n",
  40. "model = CNNText((len(vocab),100), num_classes=2, dropout=0.1)\n",
  41. "loss = CrossEntropyLoss()\n",
  42. "metric = AccuracyMetric()\n",
  43. "optimizer = Adam(model.parameters(), lr=0.001)\n",
  44. "device = 0 if torch.cuda.is_available() else 'cpu'"
  45. ]
  46. },
  47. {
  48. "cell_type": "markdown",
  49. "metadata": {},
  50. "source": [
  51. "进行训练时,fastNLP提供了各种各样的 metrics 。 如前面的教程中所介绍,AccuracyMetric 类的对象被直接传到 Trainer 中,用于在训练过程中评测模型。"
  52. ]
  53. },
  54. {
  55. "cell_type": "code",
  56. "execution_count": 2,
  57. "metadata": {
  58. "scrolled": true
  59. },
  60. "outputs": [
  61. {
  62. "name": "stdout",
  63. "output_type": "stream",
  64. "text": [
  65. "input fields after batch(if batch size is 2):\n",
  66. "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) \n",
  67. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  68. "target fields after batch(if batch size is 2):\n",
  69. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  70. "\n",
  71. "training epochs started 2020-02-28-00-11-51\n"
  72. ]
  73. },
  74. {
  75. "data": {
  76. "application/vnd.jupyter.widget-view+json": {
  77. "model_id": "",
  78. "version_major": 2,
  79. "version_minor": 0
  80. },
  81. "text/plain": [
  82. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
  83. ]
  84. },
  85. "metadata": {},
  86. "output_type": "display_data"
  87. },
  88. {
  89. "data": {
  90. "application/vnd.jupyter.widget-view+json": {
  91. "model_id": "",
  92. "version_major": 2,
  93. "version_minor": 0
  94. },
  95. "text/plain": [
  96. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  97. ]
  98. },
  99. "metadata": {},
  100. "output_type": "display_data"
  101. },
  102. {
  103. "name": "stdout",
  104. "output_type": "stream",
  105. "text": [
  106. "\r",
  107. "Evaluate data in 0.16 seconds!\n",
  108. "\r",
  109. "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
  110. "\r",
  111. "AccuracyMetric: acc=0.722477\n",
  112. "\n"
  113. ]
  114. },
  115. {
  116. "data": {
  117. "application/vnd.jupyter.widget-view+json": {
  118. "model_id": "",
  119. "version_major": 2,
  120. "version_minor": 0
  121. },
  122. "text/plain": [
  123. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  124. ]
  125. },
  126. "metadata": {},
  127. "output_type": "display_data"
  128. },
  129. {
  130. "name": "stdout",
  131. "output_type": "stream",
  132. "text": [
  133. "\r",
  134. "Evaluate data in 0.36 seconds!\n",
  135. "\r",
  136. "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
  137. "\r",
  138. "AccuracyMetric: acc=0.762615\n",
  139. "\n"
  140. ]
  141. },
  142. {
  143. "data": {
  144. "application/vnd.jupyter.widget-view+json": {
  145. "model_id": "",
  146. "version_major": 2,
  147. "version_minor": 0
  148. },
  149. "text/plain": [
  150. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  151. ]
  152. },
  153. "metadata": {},
  154. "output_type": "display_data"
  155. },
  156. {
  157. "name": "stdout",
  158. "output_type": "stream",
  159. "text": [
  160. "\r",
  161. "Evaluate data in 0.16 seconds!\n",
  162. "\r",
  163. "Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
  164. "\r",
  165. "AccuracyMetric: acc=0.771789\n",
  166. "\n"
  167. ]
  168. },
  169. {
  170. "data": {
  171. "application/vnd.jupyter.widget-view+json": {
  172. "model_id": "",
  173. "version_major": 2,
  174. "version_minor": 0
  175. },
  176. "text/plain": [
  177. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  178. ]
  179. },
  180. "metadata": {},
  181. "output_type": "display_data"
  182. },
  183. {
  184. "name": "stdout",
  185. "output_type": "stream",
  186. "text": [
  187. "\r",
  188. "Evaluate data in 0.44 seconds!\n",
  189. "\r",
  190. "Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
  191. "\r",
  192. "AccuracyMetric: acc=0.759174\n",
  193. "\n"
  194. ]
  195. },
  196. {
  197. "data": {
  198. "application/vnd.jupyter.widget-view+json": {
  199. "model_id": "",
  200. "version_major": 2,
  201. "version_minor": 0
  202. },
  203. "text/plain": [
  204. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  205. ]
  206. },
  207. "metadata": {},
  208. "output_type": "display_data"
  209. },
  210. {
  211. "name": "stdout",
  212. "output_type": "stream",
  213. "text": [
  214. "\r",
  215. "Evaluate data in 0.29 seconds!\n",
  216. "\r",
  217. "Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
  218. "\r",
  219. "AccuracyMetric: acc=0.75344\n",
  220. "\n"
  221. ]
  222. },
  223. {
  224. "data": {
  225. "application/vnd.jupyter.widget-view+json": {
  226. "model_id": "",
  227. "version_major": 2,
  228. "version_minor": 0
  229. },
  230. "text/plain": [
  231. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  232. ]
  233. },
  234. "metadata": {},
  235. "output_type": "display_data"
  236. },
  237. {
  238. "name": "stdout",
  239. "output_type": "stream",
  240. "text": [
  241. "\r",
  242. "Evaluate data in 0.33 seconds!\n",
  243. "\r",
  244. "Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
  245. "\r",
  246. "AccuracyMetric: acc=0.75\n",
  247. "\n"
  248. ]
  249. },
  250. {
  251. "data": {
  252. "application/vnd.jupyter.widget-view+json": {
  253. "model_id": "",
  254. "version_major": 2,
  255. "version_minor": 0
  256. },
  257. "text/plain": [
  258. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  259. ]
  260. },
  261. "metadata": {},
  262. "output_type": "display_data"
  263. },
  264. {
  265. "name": "stdout",
  266. "output_type": "stream",
  267. "text": [
  268. "\r",
  269. "Evaluate data in 0.19 seconds!\n",
  270. "\r",
  271. "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
  272. "\r",
  273. "AccuracyMetric: acc=0.741972\n",
  274. "\n"
  275. ]
  276. },
  277. {
  278. "data": {
  279. "application/vnd.jupyter.widget-view+json": {
  280. "model_id": "",
  281. "version_major": 2,
  282. "version_minor": 0
  283. },
  284. "text/plain": [
  285. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  286. ]
  287. },
  288. "metadata": {},
  289. "output_type": "display_data"
  290. },
  291. {
  292. "name": "stdout",
  293. "output_type": "stream",
  294. "text": [
  295. "\r",
  296. "Evaluate data in 0.49 seconds!\n",
  297. "\r",
  298. "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
  299. "\r",
  300. "AccuracyMetric: acc=0.740826\n",
  301. "\n"
  302. ]
  303. },
  304. {
  305. "data": {
  306. "application/vnd.jupyter.widget-view+json": {
  307. "model_id": "",
  308. "version_major": 2,
  309. "version_minor": 0
  310. },
  311. "text/plain": [
  312. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  313. ]
  314. },
  315. "metadata": {},
  316. "output_type": "display_data"
  317. },
  318. {
  319. "name": "stdout",
  320. "output_type": "stream",
  321. "text": [
  322. "\r",
  323. "Evaluate data in 0.15 seconds!\n",
  324. "\r",
  325. "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
  326. "\r",
  327. "AccuracyMetric: acc=0.75\n",
  328. "\n"
  329. ]
  330. },
  331. {
  332. "data": {
  333. "application/vnd.jupyter.widget-view+json": {
  334. "model_id": "",
  335. "version_major": 2,
  336. "version_minor": 0
  337. },
  338. "text/plain": [
  339. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  340. ]
  341. },
  342. "metadata": {},
  343. "output_type": "display_data"
  344. },
  345. {
  346. "name": "stdout",
  347. "output_type": "stream",
  348. "text": [
  349. "\r",
  350. "Evaluate data in 0.16 seconds!\n",
  351. "\r",
  352. "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
  353. "\r",
  354. "AccuracyMetric: acc=0.752294\n",
  355. "\n",
  356. "\r\n",
  357. "In Epoch:3/Step:462, got best dev performance:\n",
  358. "AccuracyMetric: acc=0.771789\n",
  359. "Reloaded the best model.\n"
  360. ]
  361. },
  362. {
  363. "data": {
  364. "text/plain": [
  365. "{'best_eval': {'AccuracyMetric': {'acc': 0.771789}},\n",
  366. " 'best_epoch': 3,\n",
  367. " 'best_step': 462,\n",
  368. " 'seconds': 30.04}"
  369. ]
  370. },
  371. "execution_count": 2,
  372. "metadata": {},
  373. "output_type": "execute_result"
  374. }
  375. ],
  376. "source": [
  377. "trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
  378. " optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
  379. " metrics=metric, device=device)\n",
  380. "trainer.train()"
  381. ]
  382. },
  383. {
  384. "cell_type": "markdown",
  385. "metadata": {},
  386. "source": [
  387. "除了 AccuracyMetric 之外,SpanFPreRecMetric 也是一种非常常见的评价指标, 例如在序列标注问题中,常以span的方式计算 F-measure, precision, recall。\n",
  388. "\n",
  389. "另外,fastNLP 还实现了用于抽取式QA(如SQuAD)的metric ExtractiveQAMetric。 用户可以参考下面这个表格。\n",
  390. "\n",
  391. "| 名称 | 介绍 |\n",
  392. "| -------------------- | ------------------------------------------------- |\n",
  393. "| `MetricBase` | 自定义metrics需继承的基类 |\n",
  394. "| `AccuracyMetric` | 简单的正确率metric |\n",
  395. "| `SpanFPreRecMetric` | 同时计算 F-measure, precision, recall 值的 metric |\n",
  396. "| `ExtractiveQAMetric` | 用于抽取式QA任务 的metric |\n",
  397. "\n"
  398. ]
  399. },
  400. {
  401. "cell_type": "markdown",
  402. "metadata": {},
  403. "source": [
  404. "## 定义自己的metrics\n",
  405. "\n",
  406. "在定义自己的metrics类时需继承 fastNLP 的 MetricBase, 并重写 evaluate 和 get_metric 方法。\n",
  407. "\n",
  408. "- evaluate(xxx) 中传入一个批次的数据,将针对一个批次的预测结果做评价指标的累计\n",
  409. "\n",
  410. "- get_metric(xxx) 当所有数据处理完毕时调用该方法,它将根据 evaluate函数累计的评价指标统计量来计算最终的评价结果\n",
  411. "\n",
  412. "以分类问题中,Accuracy计算为例,假设model的forward返回dict中包含 pred 这个key, 并且该key需要用于Accuracy:\n",
  413. "\n",
  414. "```python\n",
  415. "class Model(nn.Module):\n",
  416. " def __init__(xxx):\n",
  417. " # do something\n",
  418. " def forward(self, xxx):\n",
  419. " # do something\n",
  420. " return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes\n",
  421. "```"
  422. ]
  423. },
  424. {
  425. "cell_type": "markdown",
  426. "metadata": {},
  427. "source": [
  428. "### Version 1\n",
  429. "\n",
  430. "假设dataset中 `target` 这个 field 是需要预测的值,并且该 field 被设置为了 target。对应的 `AccMetric` 可以按如下的方式定义:"
  431. ]
  432. },
  433. {
  434. "cell_type": "code",
  435. "execution_count": 3,
  436. "metadata": {},
  437. "outputs": [],
  438. "source": [
  439. "from fastNLP import MetricBase\n",
  440. "\n",
  441. "class AccMetric(MetricBase):\n",
  442. "\n",
  443. " def __init__(self):\n",
  444. " super().__init__()\n",
  445. " # 根据你的情况自定义指标\n",
  446. " self.total = 0\n",
  447. " self.acc_count = 0\n",
  448. "\n",
  449. " # evaluate的参数需要和DataSet 中 field 名以及模型输出的结果 field 名一致,不然找不到对应的value\n",
  450. " # pred, target 的参数是 fastNLP 的默认配置\n",
  451. " def evaluate(self, pred, target):\n",
  452. " # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric\n",
  453. " self.total += target.size(0)\n",
  454. " self.acc_count += target.eq(pred).sum().item()\n",
  455. "\n",
  456. " def get_metric(self, reset=True): # 在这里定义如何计算metric\n",
  457. " acc = self.acc_count/self.total\n",
  458. " if reset: # 是否清零以便重新计算\n",
  459. " self.acc_count = 0\n",
  460. " self.total = 0\n",
  461. " return {'acc': acc}\n",
  462. " # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中"
  463. ]
  464. },
  465. {
  466. "cell_type": "code",
  467. "execution_count": 4,
  468. "metadata": {
  469. "scrolled": true
  470. },
  471. "outputs": [
  472. {
  473. "name": "stdout",
  474. "output_type": "stream",
  475. "text": [
  476. "input fields after batch(if batch size is 2):\n",
  477. "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) \n",
  478. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  479. "target fields after batch(if batch size is 2):\n",
  480. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  481. "\n",
  482. "training epochs started 2020-02-28-00-12-21\n"
  483. ]
  484. },
  485. {
  486. "data": {
  487. "application/vnd.jupyter.widget-view+json": {
  488. "model_id": "",
  489. "version_major": 2,
  490. "version_minor": 0
  491. },
  492. "text/plain": [
  493. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
  494. ]
  495. },
  496. "metadata": {},
  497. "output_type": "display_data"
  498. },
  499. {
  500. "data": {
  501. "application/vnd.jupyter.widget-view+json": {
  502. "model_id": "",
  503. "version_major": 2,
  504. "version_minor": 0
  505. },
  506. "text/plain": [
  507. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  508. ]
  509. },
  510. "metadata": {},
  511. "output_type": "display_data"
  512. },
  513. {
  514. "name": "stdout",
  515. "output_type": "stream",
  516. "text": [
  517. "\r",
  518. "Evaluate data in 0.33 seconds!\n",
  519. "\r",
  520. "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
  521. "\r",
  522. "AccMetric: acc=0.7419724770642202\n",
  523. "\n"
  524. ]
  525. },
  526. {
  527. "data": {
  528. "application/vnd.jupyter.widget-view+json": {
  529. "model_id": "",
  530. "version_major": 2,
  531. "version_minor": 0
  532. },
  533. "text/plain": [
  534. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  535. ]
  536. },
  537. "metadata": {},
  538. "output_type": "display_data"
  539. },
  540. {
  541. "name": "stdout",
  542. "output_type": "stream",
  543. "text": [
  544. "\r",
  545. "Evaluate data in 0.19 seconds!\n",
  546. "\r",
  547. "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
  548. "\r",
  549. "AccMetric: acc=0.7660550458715596\n",
  550. "\n"
  551. ]
  552. },
  553. {
  554. "data": {
  555. "application/vnd.jupyter.widget-view+json": {
  556. "model_id": "",
  557. "version_major": 2,
  558. "version_minor": 0
  559. },
  560. "text/plain": [
  561. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  562. ]
  563. },
  564. "metadata": {},
  565. "output_type": "display_data"
  566. },
  567. {
  568. "name": "stdout",
  569. "output_type": "stream",
  570. "text": [
  571. "\r",
  572. "Evaluate data in 0.27 seconds!\n",
  573. "\r",
  574. "Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
  575. "\r",
  576. "AccMetric: acc=0.75\n",
  577. "\n"
  578. ]
  579. },
  580. {
  581. "data": {
  582. "application/vnd.jupyter.widget-view+json": {
  583. "model_id": "",
  584. "version_major": 2,
  585. "version_minor": 0
  586. },
  587. "text/plain": [
  588. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  589. ]
  590. },
  591. "metadata": {},
  592. "output_type": "display_data"
  593. },
  594. {
  595. "name": "stdout",
  596. "output_type": "stream",
  597. "text": [
  598. "\r",
  599. "Evaluate data in 0.24 seconds!\n",
  600. "\r",
  601. "Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
  602. "\r",
  603. "AccMetric: acc=0.7534403669724771\n",
  604. "\n"
  605. ]
  606. },
  607. {
  608. "data": {
  609. "application/vnd.jupyter.widget-view+json": {
  610. "model_id": "",
  611. "version_major": 2,
  612. "version_minor": 0
  613. },
  614. "text/plain": [
  615. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  616. ]
  617. },
  618. "metadata": {},
  619. "output_type": "display_data"
  620. },
  621. {
  622. "name": "stdout",
  623. "output_type": "stream",
  624. "text": [
  625. "\r",
  626. "Evaluate data in 0.29 seconds!\n",
  627. "\r",
  628. "Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
  629. "\r",
  630. "AccMetric: acc=0.7488532110091743\n",
  631. "\n"
  632. ]
  633. },
  634. {
  635. "data": {
  636. "application/vnd.jupyter.widget-view+json": {
  637. "model_id": "",
  638. "version_major": 2,
  639. "version_minor": 0
  640. },
  641. "text/plain": [
  642. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  643. ]
  644. },
  645. "metadata": {},
  646. "output_type": "display_data"
  647. },
  648. {
  649. "name": "stdout",
  650. "output_type": "stream",
  651. "text": [
  652. "\r",
  653. "Evaluate data in 0.14 seconds!\n",
  654. "\r",
  655. "Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
  656. "\r",
  657. "AccMetric: acc=0.7488532110091743\n",
  658. "\n"
  659. ]
  660. },
  661. {
  662. "data": {
  663. "application/vnd.jupyter.widget-view+json": {
  664. "model_id": "",
  665. "version_major": 2,
  666. "version_minor": 0
  667. },
  668. "text/plain": [
  669. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  670. ]
  671. },
  672. "metadata": {},
  673. "output_type": "display_data"
  674. },
  675. {
  676. "name": "stdout",
  677. "output_type": "stream",
  678. "text": [
  679. "\r",
  680. "Evaluate data in 0.27 seconds!\n",
  681. "\r",
  682. "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
  683. "\r",
  684. "AccMetric: acc=0.7568807339449541\n",
  685. "\n"
  686. ]
  687. },
  688. {
  689. "data": {
  690. "application/vnd.jupyter.widget-view+json": {
  691. "model_id": "",
  692. "version_major": 2,
  693. "version_minor": 0
  694. },
  695. "text/plain": [
  696. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  697. ]
  698. },
  699. "metadata": {},
  700. "output_type": "display_data"
  701. },
  702. {
  703. "name": "stdout",
  704. "output_type": "stream",
  705. "text": [
  706. "\r",
  707. "Evaluate data in 0.42 seconds!\n",
  708. "\r",
  709. "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
  710. "\r",
  711. "AccMetric: acc=0.7488532110091743\n",
  712. "\n"
  713. ]
  714. },
  715. {
  716. "data": {
  717. "application/vnd.jupyter.widget-view+json": {
  718. "model_id": "",
  719. "version_major": 2,
  720. "version_minor": 0
  721. },
  722. "text/plain": [
  723. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  724. ]
  725. },
  726. "metadata": {},
  727. "output_type": "display_data"
  728. },
  729. {
  730. "name": "stdout",
  731. "output_type": "stream",
  732. "text": [
  733. "\r",
  734. "Evaluate data in 0.16 seconds!\n",
  735. "\r",
  736. "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
  737. "\r",
  738. "AccMetric: acc=0.7408256880733946\n",
  739. "\n"
  740. ]
  741. },
  742. {
  743. "data": {
  744. "application/vnd.jupyter.widget-view+json": {
  745. "model_id": "",
  746. "version_major": 2,
  747. "version_minor": 0
  748. },
  749. "text/plain": [
  750. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  751. ]
  752. },
  753. "metadata": {},
  754. "output_type": "display_data"
  755. },
  756. {
  757. "name": "stdout",
  758. "output_type": "stream",
  759. "text": [
  760. "\r",
  761. "Evaluate data in 0.28 seconds!\n",
  762. "\r",
  763. "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
  764. "\r",
  765. "AccMetric: acc=0.7408256880733946\n",
  766. "\n",
  767. "\r\n",
  768. "In Epoch:2/Step:308, got best dev performance:\n",
  769. "AccMetric: acc=0.7660550458715596\n",
  770. "Reloaded the best model.\n"
  771. ]
  772. },
  773. {
  774. "data": {
  775. "text/plain": [
  776. "{'best_eval': {'AccMetric': {'acc': 0.7660550458715596}},\n",
  777. " 'best_epoch': 2,\n",
  778. " 'best_step': 308,\n",
  779. " 'seconds': 29.74}"
  780. ]
  781. },
  782. "execution_count": 4,
  783. "metadata": {},
  784. "output_type": "execute_result"
  785. }
  786. ],
  787. "source": [
  788. "trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
  789. " optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
  790. " metrics=AccMetric(), device=device)\n",
  791. "trainer.train()"
  792. ]
  793. },
  794. {
  795. "cell_type": "markdown",
  796. "metadata": {},
  797. "source": [
  798. "### Version 2\n",
  799. "\n",
  800. "如果需要复用 metric,比如下一次使用 `AccMetric` 时,dataset中目标field不叫 `target` 而叫 `y` ,或者model的输出不是 `pred`,则可以按如下的方式定义,在初始化时指定对应的 key:\n"
  801. ]
  802. },
  803. {
  804. "cell_type": "code",
  805. "execution_count": 5,
  806. "metadata": {},
  807. "outputs": [],
  808. "source": [
  809. "class AccMetric(MetricBase):\n",
  810. " def __init__(self, pred=None, target=None):\n",
  811. " \"\"\"\n",
  812. " 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时,\n",
  813. " acc_metric = AccMetric(pred='pred_y', target='y')即可。\n",
  814. " 当初始化为acc_metric = AccMetric() 时,fastNLP会直接使用 'pred', 'target' 作为key去索取对应的值\n",
  815. " \"\"\"\n",
  816. "\n",
  817. " super().__init__()\n",
  818. "\n",
  819. " # 如果没有进行该注册,则效果与 Version 1 就是一样的\n",
  820. " self._init_param_map(pred=pred, target=target) # 该方法会注册 pred 和 target. 仅需要注册evaluate()方法会用到的参数名即可\n",
  821. "\n",
  822. " # 根据你的情况自定义指标\n",
  823. " self.total = 0\n",
  824. " self.acc_count = 0\n",
  825. "\n",
  826. " # evaluate的参数需要和DataSet 中 field 名以及模型输出的结果 field 名一致,不然找不到对应的value\n",
  827. " # pred, target 的参数是 fastNLP 的默认配置\n",
  828. " def evaluate(self, pred, target):\n",
  829. " # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric\n",
  830. " self.total += target.size(0)\n",
  831. " self.acc_count += target.eq(pred).sum().item()\n",
  832. "\n",
  833. " def get_metric(self, reset=True): # 在这里定义如何计算metric\n",
  834. " acc = self.acc_count/self.total\n",
  835. " if reset: # 是否清零以便重新计算\n",
  836. " self.acc_count = 0\n",
  837. " self.total = 0\n",
  838. " return {'acc': acc}\n",
  839. " # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中"
  840. ]
  841. },
  842. {
  843. "cell_type": "code",
  844. "execution_count": 6,
  845. "metadata": {
  846. "scrolled": true
  847. },
  848. "outputs": [
  849. {
  850. "name": "stdout",
  851. "output_type": "stream",
  852. "text": [
  853. "input fields after batch(if batch size is 2):\n",
  854. "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) \n",
  855. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  856. "target fields after batch(if batch size is 2):\n",
  857. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  858. "\n",
  859. "training epochs started 2020-02-28-00-12-51\n"
  860. ]
  861. },
  862. {
  863. "data": {
  864. "application/vnd.jupyter.widget-view+json": {
  865. "model_id": "",
  866. "version_major": 2,
  867. "version_minor": 0
  868. },
  869. "text/plain": [
  870. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
  871. ]
  872. },
  873. "metadata": {},
  874. "output_type": "display_data"
  875. },
  876. {
  877. "data": {
  878. "application/vnd.jupyter.widget-view+json": {
  879. "model_id": "",
  880. "version_major": 2,
  881. "version_minor": 0
  882. },
  883. "text/plain": [
  884. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  885. ]
  886. },
  887. "metadata": {},
  888. "output_type": "display_data"
  889. },
  890. {
  891. "name": "stdout",
  892. "output_type": "stream",
  893. "text": [
  894. "\r",
  895. "Evaluate data in 0.24 seconds!\n",
  896. "\r",
  897. "Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
  898. "\r",
  899. "AccMetric: acc=0.7545871559633027\n",
  900. "\n"
  901. ]
  902. },
  903. {
  904. "data": {
  905. "application/vnd.jupyter.widget-view+json": {
  906. "model_id": "",
  907. "version_major": 2,
  908. "version_minor": 0
  909. },
  910. "text/plain": [
  911. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  912. ]
  913. },
  914. "metadata": {},
  915. "output_type": "display_data"
  916. },
  917. {
  918. "name": "stdout",
  919. "output_type": "stream",
  920. "text": [
  921. "\r",
  922. "Evaluate data in 0.24 seconds!\n",
  923. "\r",
  924. "Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
  925. "\r",
  926. "AccMetric: acc=0.7534403669724771\n",
  927. "\n"
  928. ]
  929. },
  930. {
  931. "data": {
  932. "application/vnd.jupyter.widget-view+json": {
  933. "model_id": "",
  934. "version_major": 2,
  935. "version_minor": 0
  936. },
  937. "text/plain": [
  938. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  939. ]
  940. },
  941. "metadata": {},
  942. "output_type": "display_data"
  943. },
  944. {
  945. "name": "stdout",
  946. "output_type": "stream",
  947. "text": [
  948. "\r",
  949. "Evaluate data in 0.18 seconds!\n",
  950. "\r",
  951. "Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
  952. "\r",
  953. "AccMetric: acc=0.7557339449541285\n",
  954. "\n"
  955. ]
  956. },
  957. {
  958. "data": {
  959. "application/vnd.jupyter.widget-view+json": {
  960. "model_id": "",
  961. "version_major": 2,
  962. "version_minor": 0
  963. },
  964. "text/plain": [
  965. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  966. ]
  967. },
  968. "metadata": {},
  969. "output_type": "display_data"
  970. },
  971. {
  972. "name": "stdout",
  973. "output_type": "stream",
  974. "text": [
  975. "\r",
  976. "Evaluate data in 0.11 seconds!\n",
  977. "\r",
  978. "Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
  979. "\r",
  980. "AccMetric: acc=0.7511467889908257\n",
  981. "\n"
  982. ]
  983. },
  984. {
  985. "data": {
  986. "application/vnd.jupyter.widget-view+json": {
  987. "model_id": "",
  988. "version_major": 2,
  989. "version_minor": 0
  990. },
  991. "text/plain": [
  992. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  993. ]
  994. },
  995. "metadata": {},
  996. "output_type": "display_data"
  997. },
  998. {
  999. "name": "stdout",
  1000. "output_type": "stream",
  1001. "text": [
  1002. "\r",
  1003. "Evaluate data in 0.19 seconds!\n",
  1004. "\r",
  1005. "Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
  1006. "\r",
  1007. "AccMetric: acc=0.7465596330275229\n",
  1008. "\n"
  1009. ]
  1010. },
  1011. {
  1012. "data": {
  1013. "application/vnd.jupyter.widget-view+json": {
  1014. "model_id": "",
  1015. "version_major": 2,
  1016. "version_minor": 0
  1017. },
  1018. "text/plain": [
  1019. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  1020. ]
  1021. },
  1022. "metadata": {},
  1023. "output_type": "display_data"
  1024. },
  1025. {
  1026. "name": "stdout",
  1027. "output_type": "stream",
  1028. "text": [
  1029. "\r",
  1030. "Evaluate data in 0.14 seconds!\n",
  1031. "\r",
  1032. "Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
  1033. "\r",
  1034. "AccMetric: acc=0.7454128440366973\n",
  1035. "\n"
  1036. ]
  1037. },
  1038. {
  1039. "data": {
  1040. "application/vnd.jupyter.widget-view+json": {
  1041. "model_id": "",
  1042. "version_major": 2,
  1043. "version_minor": 0
  1044. },
  1045. "text/plain": [
  1046. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  1047. ]
  1048. },
  1049. "metadata": {},
  1050. "output_type": "display_data"
  1051. },
  1052. {
  1053. "name": "stdout",
  1054. "output_type": "stream",
  1055. "text": [
  1056. "\r",
  1057. "Evaluate data in 0.43 seconds!\n",
  1058. "\r",
  1059. "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
  1060. "\r",
  1061. "AccMetric: acc=0.7488532110091743\n",
  1062. "\n"
  1063. ]
  1064. },
  1065. {
  1066. "data": {
  1067. "application/vnd.jupyter.widget-view+json": {
  1068. "model_id": "",
  1069. "version_major": 2,
  1070. "version_minor": 0
  1071. },
  1072. "text/plain": [
  1073. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  1074. ]
  1075. },
  1076. "metadata": {},
  1077. "output_type": "display_data"
  1078. },
  1079. {
  1080. "name": "stdout",
  1081. "output_type": "stream",
  1082. "text": [
  1083. "\r",
  1084. "Evaluate data in 0.21 seconds!\n",
  1085. "\r",
  1086. "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
  1087. "\r",
  1088. "AccMetric: acc=0.7431192660550459\n",
  1089. "\n"
  1090. ]
  1091. },
  1092. {
  1093. "data": {
  1094. "application/vnd.jupyter.widget-view+json": {
  1095. "model_id": "",
  1096. "version_major": 2,
  1097. "version_minor": 0
  1098. },
  1099. "text/plain": [
  1100. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  1101. ]
  1102. },
  1103. "metadata": {},
  1104. "output_type": "display_data"
  1105. },
  1106. {
  1107. "name": "stdout",
  1108. "output_type": "stream",
  1109. "text": [
  1110. "\r",
  1111. "Evaluate data in 0.1 seconds!\n",
  1112. "\r",
  1113. "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
  1114. "\r",
  1115. "AccMetric: acc=0.7477064220183486\n",
  1116. "\n"
  1117. ]
  1118. },
  1119. {
  1120. "data": {
  1121. "application/vnd.jupyter.widget-view+json": {
  1122. "model_id": "",
  1123. "version_major": 2,
  1124. "version_minor": 0
  1125. },
  1126. "text/plain": [
  1127. "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
  1128. ]
  1129. },
  1130. "metadata": {},
  1131. "output_type": "display_data"
  1132. },
  1133. {
  1134. "name": "stdout",
  1135. "output_type": "stream",
  1136. "text": [
  1137. "\r",
  1138. "Evaluate data in 0.29 seconds!\n",
  1139. "\r",
  1140. "Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
  1141. "\r",
  1142. "AccMetric: acc=0.7465596330275229\n",
  1143. "\n",
  1144. "\r\n",
  1145. "In Epoch:3/Step:462, got best dev performance:\n",
  1146. "AccMetric: acc=0.7557339449541285\n",
  1147. "Reloaded the best model.\n"
  1148. ]
  1149. },
  1150. {
  1151. "data": {
  1152. "text/plain": [
  1153. "{'best_eval': {'AccMetric': {'acc': 0.7557339449541285}},\n",
  1154. " 'best_epoch': 3,\n",
  1155. " 'best_step': 462,\n",
  1156. " 'seconds': 28.68}"
  1157. ]
  1158. },
  1159. "execution_count": 6,
  1160. "metadata": {},
  1161. "output_type": "execute_result"
  1162. }
  1163. ],
  1164. "source": [
  1165. "trainer = Trainer(train_data=train_data, model=model, loss=loss,\n",
  1166. " optimizer=optimizer, batch_size=32, dev_data=dev_data,\n",
  1167. " metrics=AccMetric(pred=\"pred\", target=\"target\"), device=device)\n",
  1168. "trainer.train()"
  1169. ]
  1170. },
  1171. {
  1172. "cell_type": "code",
  1173. "execution_count": null,
  1174. "metadata": {},
  1175. "outputs": [],
  1176. "source": []
  1177. }
  1178. ],
  1179. "metadata": {
  1180. "kernelspec": {
  1181. "display_name": "Python Now",
  1182. "language": "python",
  1183. "name": "now"
  1184. },
  1185. "language_info": {
  1186. "codemirror_mode": {
  1187. "name": "ipython",
  1188. "version": 3
  1189. },
  1190. "file_extension": ".py",
  1191. "mimetype": "text/x-python",
  1192. "name": "python",
  1193. "nbconvert_exporter": "python",
  1194. "pygments_lexer": "ipython3",
  1195. "version": "3.8.0"
  1196. }
  1197. },
  1198. "nbformat": 4,
  1199. "nbformat_minor": 2
  1200. }