fastnlp_tutorial_5.ipynb 27 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fdd7ff16",
   "metadata": {},
   "source": [
    "# T5. An in-depth look at trainer and evaluator\n",
    "\n",
    "&emsp; 1 &ensp; More on the driver in fastNLP\n",
    " \n",
    "&emsp; &emsp; 1.1 &ensp; The design behind trainer and driver\n",
    "\n",
    "&emsp; &emsp; 1.2 &ensp; device and multi-GPU training\n",
    "\n",
    "&emsp; 2 &ensp; More metric types in fastNLP\n",
    "\n",
    "&emsp; &emsp; 2.1 &ensp; Predefined metric types\n",
    "\n",
    "&emsp; &emsp; 2.2 &ensp; Custom metric types\n",
    "\n",
    "&emsp; 3 &ensp; More on the trainer in fastNLP\n",
    "\n",
    "&emsp; &emsp; 3.1 &ensp; Internal structure of the trainer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08752c5a",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 1. More on the driver in fastNLP\n",
    "\n",
    "### 1.1 The design behind trainer and driver\n",
    "\n",
    "In `fastNLP 0.8`, the modules most central to model training are **the training module `trainer`, the evaluation module `evaluator`, and the driver module `driver`**.\n",
    "\n",
    "&emsp; `tutorial-0` introduced all three briefly: **the `driver` controls how the `model` is actually run during training and evaluation**,\n",
    "\n",
    "&emsp; &emsp; **the `evaluator` wraps the `metric` used for evaluation**, and **the `trainer` wraps the `optimizer` used for training** and **may also contain an `evaluator`**.\n",
    "\n",
    "The underlying purpose of this split is **compatibility with several `python` deep learning frameworks**, **such as `pytorch`, `paddle`, and `jittor`**.\n",
    "\n",
    "&emsp; Pseudocode for the training stage is shown in the purple column on the left below. Because **frameworks define models, losses, and tensors differently**, the training stage is split into\n",
    "\n",
    "&emsp; &emsp; a **framework-independent part, loop control and batch dispatch**, **implemented by the `trainer`** (blue column in the middle), and\n",
    "\n",
    "&emsp; &emsp; a **framework-dependent part, model invocation and numerical optimization**, **implemented by the `driver`** (red column on the right).\n",
    "\n",
    "| <div align=\"center\">Training procedure</div> | <div align=\"center\">Framework-independent (`trainer`)</div> | <div align=\"center\">Framework-dependent (`driver`)</div> |\n",
    "|:--|:--|:--|\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;\">try:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;\">try:</div> | |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">for epoch in 1:n_epochs:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">for epoch in 1:n_epochs:</div> | |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">for step in 1:total_steps:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:40px;\">for step in 1:total_steps:</div> | |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">batch = fetch_batch()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:60px;\">batch = fetch_batch()</div> | |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">loss = model.forward(batch)&emsp;</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">loss = model.forward(batch)&emsp;</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">loss.backward()</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">loss.backward()</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">model.update()</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">model.update()</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">model.clear_grad()</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">model.clear_grad()</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">if need_save:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:40px;\">if need_save:</div> | |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">model.save()</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">model.save()</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;\">except:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;\">except:</div> | |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">process_exception()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">process_exception()</div> | |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e55f07b",
   "metadata": {},
   "source": [
    "&emsp; Pseudocode for the evaluation stage is shown in the purple column on the left below. Again, because frameworks define models, losses, and tensors differently, the evaluation stage is split into\n",
    "\n",
    "&emsp; &emsp; a **framework-independent part, loop control plus dispatch and aggregation**, **implemented by the `evaluator`** (blue column in the middle), and\n",
    "\n",
    "&emsp; &emsp; a **framework-dependent part, model invocation and metric computation**, again **implemented by the `driver`** (red column on the right).\n",
    "\n",
    "| <div align=\"center\">Evaluation procedure</div> | <div align=\"center\">Framework-independent (`evaluator`)</div> | <div align=\"center\">Framework-dependent (`driver`)</div> |\n",
    "|:--|:--|:--|\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;\">try:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;\">try:</div> | |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">model.set_eval()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">model.set_eval()</div> | |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">for step in 1:total_steps:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">for step in 1:total_steps:</div> | |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">batch = fetch_batch()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:40px;\">batch = fetch_batch()</div> | |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">outputs = model.evaluate(batch)&emsp;</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:40px;\">outputs = model.evaluate(batch)&emsp;</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">metric.compute(batch, outputs)</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:40px;\">metric.compute(batch, outputs)</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">results = metric.get_metric()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">results = metric.get_metric()</div> | |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;\">except:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;\">except:</div> | |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">process_exception()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">process_exception()</div> | |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94ba11c6",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "So, from the programmer's point of view, `fastNLP v0.8` **uses a `driver` to let models built on `pytorch`, `paddle`, or `jittor`**\n",
    "\n",
    "&emsp; &emsp; **run on the same `trainer` and `evaluator`**, which is **one of the main improvements of `fastNLP v0.8` over earlier versions**.\n",
    "\n",
    "&emsp; From the `driver`'s point of view, `fastNLP v0.8` defines a `driver` base class that **converts all tensors to `numpy.ndarray`**,\n",
    "\n",
    "&emsp; &emsp; and derives three subclasses from it, `torch_driver`, `paddle_driver`, and `jittor_driver`, which provide\n",
    "\n",
    "&emsp; &emsp; compatibility with `pytorch`, `paddle`, and `jittor`. For hands-on use of the latter two, see the following `tutorial-6`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab1cea7d",
   "metadata": {},
   "source": [
    "### 1.2 device and multi-GPU training\n",
    "\n",
    "**`fastNLP v0.8` supports multi-GPU training**, which is enabled by **setting the `device` argument of the `trainer` to a list of GPU indices**.\n",
    "\n",
    "&emsp; Switching from one GPU to several affects data handling, the model, and evaluation alike. `fastNLP v0.8` guarantees that:\n",
    "\n",
    "&emsp; &emsp; when the data is split, the GPUs coordinate so that every sample is used for training and no sample is used twice;\n",
    "\n",
    "&emsp; &emsp; during training, the model replicas exchange gradients; during evaluation, each GPU first computes its own statistics, and the results are then aggregated.\n",
    "\n",
    "&emsp; For example, when `get_metric` runs during evaluation, `fastNLP v0.8` automatically aggregates `self.right` and `self.total`\n",
    "\n",
    "&emsp; &emsp; across GPUs using the **`aggregate_method`** registered for each of them, which is `sum` here,\n",
    "\n",
    "&emsp; &emsp; so that the `Accuracy` class returns statistics over all the data when `get_metric` is called. The code looks like this:\n",
    " \n",
    "```python\n",
    "trainer = Trainer(\n",
    "    model=model,                                # model is implemented in pytorch\n",
    "    train_dataloader=train_dataloader,\n",
    "    optimizers=optimizer,\n",
    "    ...\n",
    "    driver='torch',                             # use torch_driver\n",
    "    device=[0, 1],                              # GPUs: cuda:0 + cuda:1\n",
    "    ...\n",
    "    evaluate_dataloaders=evaluate_dataloader,\n",
    "    metrics={'acc': Accuracy()},\n",
    "    ...\n",
    ")\n",
    "\n",
    "class Accuracy(Metric):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.register_element(name='total', value=0, aggregate_method='sum')\n",
    "        self.register_element(name='right', value=0, aggregate_method='sum')\n",
    "```\n"
   ]
  },
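  {
   "cell_type": "markdown",
   "id": "t5-multigpu-worked-sum",
   "metadata": {},
   "source": [
    "&emsp; As a small worked example of why the counts are summed rather than the per-GPU accuracies averaged, suppose (the numbers are made up for illustration)\n",
    "\n",
    "&emsp; &emsp; that `cuda:0` evaluates 60 samples and gets 45 right, while `cuda:1` evaluates 40 samples and gets 20 right. Aggregating with `sum` gives\n",
    "\n",
    "$$\\text{total} = 60 + 40 = 100 \\qquad \\text{right} = 45 + 20 = 65 \\qquad \\text{acc} = \\frac{65}{100} = 0.65$$\n",
    "\n",
    "&emsp; Averaging the two per-GPU accuracies instead would give\n",
    "\n",
    "$$\\frac{1}{2}\\left(\\frac{45}{60} + \\frac{20}{40}\\right) = \\frac{0.75 + 0.5}{2} = 0.625$$\n",
    "\n",
    "&emsp; which differs from the accuracy over all the data whenever the GPUs see different numbers of samples."
   ]
  },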
  {
   "cell_type": "markdown",
   "id": "e2e0a210",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "Note: `fastNLP v0.8` does not support multi-GPU runs inside `jupyter`, only single-GPU runs, so none of the `tutorial`s demonstrate multi-GPU training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d19220c",
   "metadata": {},
   "source": [
    "## 2. More metric types in fastNLP\n",
    "\n",
    "### 2.1 Predefined metric types\n",
    "\n",
    "Besides the **accuracy metric `Accuracy`** seen throughout the earlier `tutorial`s, `fastNLP 0.8` provides other **predefined evaluation `metric`s**,\n",
    "\n",
    "&emsp; including **`Metric`, the base class of every `metric`**; `TransformersAccuracy`, an accuracy metric adapted to models from `Transformers`;\n",
    "\n",
    "&emsp; &emsp; **`ClassifyFPreRecMetric`, the `F1` score for classification** (which also reports recall `Rec` and precision `Pre`); and\n",
    "\n",
    "&emsp; &emsp; **`SpanFPreRecMetric`, the `F1` score for span extraction**. The table below summarizes them, and a detailed discussion follows.\n",
    "\n",
    "| <div align=\"center\">Name</div> | <div align=\"center\">Description</div> | <div align=\"center\">Path</div> |\n",
    "|:--|:--|:--|\n",
    "| `Metric` | Base class to inherit from when defining a `metric` | `/core/metrics/metric.py` |\n",
    "| `Accuracy` | Accuracy, the most commonly used | `/core/metrics/accuracy.py` |\n",
    "| `TransformersAccuracy` | Accuracy, compatible with models from `Transformers` | `/core/metrics/accuracy.py` |\n",
    "| `ClassifyFPreRecMetric` | Recall, precision, and F1, for **classification** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n",
    "| `SpanFPreRecMetric` | Recall, precision, and F1, for **span extraction** | `/core/metrics/span_f1_pre_rec_metric.py` |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdc083a3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "&emsp; As described in `tutorial-0`, every `metric` provides a `get_metric` function and an `update` function:\n",
    "\n",
    "&emsp; &emsp; **`update` accumulates the statistics of a single `batch`**, and **`get_metric` returns the final result**, which is then printed.\n",
    "\n",
    "\n",
    "### 2.1.1 Accuracy and TransformersAccuracy\n",
    "\n",
    "`Accuracy` is the proportion of correctly predicted samples `right_num` among all samples `total_num`, that is, `right_num / total_num`.\n",
    "\n",
    "&emsp; `get_metric` prints its result in the form **`{\"acc#xx\": float, 'total#xx': float, 'correct#xx': float}`**.\n",
    "\n",
    "&emsp; Initialization usually needs no arguments; `fastNLP` infers the backend framework `backend` from the arguments passed to `update`.\n",
    "\n",
    "&emsp; **`update` takes `pred`, `target`, and `seq_len`**; **the last one gives the length of each sample in the batch**.\n",
    "\n",
    "`TransformersAccuracy` inherits from `Accuracy` and exists only for compatibility with models from the `Transformers` framework:\n",
    "\n",
    "&emsp; in `update`, it converts the `attention_mask` produced by the `Transformers` framework into the `seq_len` argument.\n",
    "\n",
    "\n",
    "### 2.1.2 ClassifyFPreRecMetric and SpanFPreRecMetric\n",
    "\n",
    "`ClassifyFPreRecMetric` evaluates classification and `SpanFPreRecMetric` evaluates span extraction; the latter already appeared in `tutorial-4`.\n",
    "\n",
    "&emsp; What the two share: **first**, **both report recall `Rec`**, **precision `Pre`**, and **the `F1` score**.\n",
    "\n",
    "&emsp; &emsp; `get_metric` prints its result in the form **`{\"f#xx\": float, 'pre#xx': float, 'rec#xx': float}`**.\n",
    "\n",
    "&emsp; &emsp; The three are computed as follows, where `beta` defaults to `1`, so that `F1` is the harmonic mean of recall `Rec` and precision `Pre`:\n",
    "\n",
    "$$Rec=\\dfrac{\\text{number of samples correctly predicted as positive}}{\\text{number of samples that are actually positive}}\\qquad Pre=\\dfrac{\\text{number of samples correctly predicted as positive}}{\\text{number of samples predicted as positive}}$$\n",
    "\n",
    "$$F_{\\beta} = \\frac{(1 + \\beta^{2}) \\cdot Pre \\cdot Rec}{\\beta^{2} \\cdot Pre + Rec}$$\n",
    "\n",
    "&emsp; **Second**, setting `only_gross` to `False` makes them return `Rec-Pre-F1` for every class; in addition, the `f_type` argument selects between\n",
    "\n",
    "&emsp; &emsp; **`micro F1`** (**`Rec-Pre-F1` computed from counts pooled over all classes**) and **`macro F1`** (**`Rec-Pre-F1` computed per class, then averaged arithmetically**).\n",
    "\n",
    "&emsp; **Third**, both accept at initialization a **`tag_vocab` argument, a `fastNLP.Vocabulary` that records the mapping between label indices**\n",
    "\n",
    "&emsp; &emsp; **and label names in the dataset**, and an `ignore_labels` argument, a list of strings naming labels to exclude from the `Rec-Pre-F1` computation.\n",
    "\n",
    "Where the two differ: `ClassifyFPreRecMetric` targets plain classification, where labels are independent of one another and do not form pairs;\n",
    "\n",
    "&emsp; &emsp; **`SpanFPreRecMetric` targets the harder extraction problem**, where **labels `B-xx` and `I-xx`, or `B-xx` and `E-xx`, form pairs**.\n",
    "\n",
    "&emsp; When computing `Rec-Pre-F1`, `ClassifyFPreRecMetric` only has to check whether each label is correct, whereas\n",
    "\n",
    "&emsp; &emsp; `SpanFPreRecMetric` counts a prediction as correct only if **the labels obey the tagging scheme and the span they cover coincides with the gold span**.\n",
    "\n",
    "&emsp; &emsp; So, for the `NER` task on `CoNLL-2003` in `tutorial-4`, evaluating with `ClassifyFPreRecMetric`\n",
    "\n",
    "&emsp; &emsp; &emsp; or `Accuracy` produces scores that look very high, but only because those metrics are too lenient for the task.\n",
    "\n",
    "&emsp; &emsp; Finally, the part-of-speech tagging (`POS`) task on `CoNLL-2003` gives a brief demonstration of how `ClassifyFPreRecMetric` is used:\n",
    "\n",
    "```python\n",
    "from fastNLP import Vocabulary\n",
    "from fastNLP import ClassifyFPreRecMetric\n",
    "\n",
    "tag_vocab = Vocabulary(padding=None, unknown=None)  # maps label indices to label names\n",
    "tag_vocab.add_word_lst(['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', \n",
    "                        'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', \n",
    "                        'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', \n",
    "                        'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', \n",
    "                        'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB', ])  # pos_tags in CoNLL-2003\n",
    "ignore_labels = ['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', ]\n",
    "\n",
    "FPreRec = ClassifyFPreRecMetric(tag_vocab=tag_vocab, \n",
    "                                ignore_labels=ignore_labels,  # the labels above are ignored when scoring\n",
    "                                only_gross=True,              # defaults to True: report only the overall statistics over all classes\n",
    "                                f_type='micro')               # defaults to 'micro': Rec-Pre-F1 from counts pooled over all classes\n",
    "metrics = {'F1': FPreRec}\n",
    "```"
   ]
  },
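  {
   "cell_type": "markdown",
   "id": "t5-micro-macro-worked",
   "metadata": {},
   "source": [
    "&emsp; To make the difference between `micro F1` and `macro F1` concrete, here is a worked example with made-up counts for a two-class problem and `beta` equal to `1`.\n",
    "\n",
    "&emsp; &emsp; The classes A and B and all numbers are hypothetical: class A has 8 true positives, 3 false positives, and 1 false negative; class B has 3 true positives, 1 false positive, and 3 false negatives.\n",
    "\n",
    "$$Pre_{A} = \\frac{8}{8 + 3} = \\frac{8}{11} \\qquad Rec_{A} = \\frac{8}{8 + 1} = \\frac{8}{9} \\qquad F_{1}^{A} = \\frac{2 \\cdot \\frac{8}{11} \\cdot \\frac{8}{9}}{\\frac{8}{11} + \\frac{8}{9}} = \\frac{128}{160} = 0.8$$\n",
    "\n",
    "$$Pre_{B} = \\frac{3}{3 + 1} = 0.75 \\qquad Rec_{B} = \\frac{3}{3 + 3} = 0.5 \\qquad F_{1}^{B} = \\frac{2 \\cdot 0.75 \\cdot 0.5}{0.75 + 0.5} = 0.6$$\n",
    "\n",
    "&emsp; Following the description above, the macro value averages the per-class scores, while the micro value pools the counts first:\n",
    "\n",
    "$$F_{1}^{macro} = \\frac{0.8 + 0.6}{2} = 0.7$$\n",
    "\n",
    "$$Pre^{micro} = \\frac{8 + 3}{11 + 4} = \\frac{11}{15} \\qquad Rec^{micro} = \\frac{8 + 3}{9 + 6} = \\frac{11}{15} \\qquad F_{1}^{micro} = \\frac{11}{15} \\approx 0.733$$\n",
    "\n",
    "&emsp; The macro value weights both classes equally, whereas the micro value leans toward the more frequent class A."
   ]
  },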
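  {
   "cell_type": "markdown",
   "id": "8a22f522",
   "metadata": {},
   "source": [
    "### 2.2 Custom metric types\n",
    "\n",
    "As noted above, `Metric` is the base class of every `metric`, and `Accuracy` and the others are its subclasses. Likewise, a **custom `metric` type**\n",
    "\n",
    "&emsp; &emsp; **must inherit from `Metric`** and **define its own `__init__`, `update`, and `get_metric` functions**.\n",
    "\n",
    "&emsp; In `__init__`, define whatever variables the evaluation needs; here we reuse `total_num` and `right_num` from `Accuracy`.\n",
    "\n",
    "&emsp; In `update`, define how those variables are updated. As explained in `tutorial-0`, **the parameter names of `update`**\n",
    "\n",
    "&emsp; &emsp; **must match the names of the outputs that the model under evaluation returns from `evaluate_step`**, and likewise **the corresponding field names in the dataset**; this is **parameter matching**.\n",
    "\n",
    "&emsp; &emsp; In `fastNLP v0.8`, the default parameters of `update` are `pred`, the predictions, and `target`, the ground truth.\n",
    "\n",
    "&emsp; &emsp; Here they are deliberately changed to `pred`, the predictions, which matches the model output, and `true`, the ground truth, which requires renaming the dataset field.\n",
    "\n",
    "&emsp; In `get_metric`, define how the final value of the metric is computed, here simply the accuracy. The function must return a dict,\n",
    "\n",
    "&emsp; &emsp; in which the key string `'prefix'` is the name of this `metric` and is what appears in the `trainer`'s `progress bar`.\n",
    "\n",
    "Following these requirements, we define a simple evaluation module named `MyMetric` for classification and use it as a worked example."
   ]
  },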
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08a872e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import Metric\n",
    "\n",
    "class MyMetric(Metric):\n",
    "\n",
    "    def __init__(self):\n",
    "        Metric.__init__(self)  # initialize the base class\n",
    "        self.total_num = 0     # number of samples seen so far\n",
    "        self.right_num = 0     # number of correct predictions so far\n",
    "\n",
    "    def update(self, pred, true):\n",
    "        # accumulate the statistics of one batch\n",
    "        self.total_num += true.size(0)\n",
    "        self.right_num += true.eq(pred).sum().item()\n",
    "\n",
    "    def get_metric(self, reset=True):\n",
    "        acc = self.right_num / self.total_num\n",
    "        if reset:\n",
    "            self.total_num = 0\n",
    "            self.right_num = 0\n",
    "        return {'prefix': acc}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0155f447",
   "metadata": {},
   "source": [
    "&emsp; For the data, we again use the `load_dataset` function from the `datasets` module to load the `SST-2` binary classification dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ad81ac7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "sst2data = load_dataset('glue', 'sst2')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9d81760",
   "metadata": {},
   "source": [
    "Next comes data preprocessing. Note that because `update` in `MyMetric` takes parameters named `pred` and `true`,\n",
    "\n",
    "&emsp; the dataset field holding the prediction target must be renamed to `true` (for the predefined `metric`s it would be `target`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfb28b1b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from fastNLP import DataSet\n",
    "\n",
    "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
    "\n",
    "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'true': ins['label']}, \n",
    "                   progress_bar=\"tqdm\")\n",
    "dataset.delete_field('sentence')\n",
    "dataset.delete_field('label')\n",
    "dataset.delete_field('idx')\n",
    "\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "vocab = Vocabulary()\n",
    "vocab.from_dataset(dataset, field_name='words')\n",
    "vocab.index_dataset(dataset, field_name='words')\n",
    "\n",
    "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n",
    "\n",
    "from fastNLP import prepare_torch_dataloader\n",
    "\n",
    "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
    "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af3f8c63",
   "metadata": {},
   "source": [
    "&emsp; For the model, we again use the predefined `CNNText` model introduced in `tutorial-4` to perform binary classification on `SST-2`; it is created after the data so that `vocab` already exists."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fd210c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from fastNLP.models.torch import CNNText\n",
    "\n",
    "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
    "\n",
    "from torch.optim import AdamW\n",
    "\n",
    "optimizers = AdamW(params=model.parameters(), lr=5e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e21df35",
   "metadata": {},
   "source": [
    "Then initialize the `trainer` instance. The key of the pair passed to `metrics`, the string `'suffix'`, and the string `'prefix'` defined earlier\n",
    "\n",
    "&emsp; are joined and shown in the `trainer`'s `progress bar`, so the full output has the form `{'prefix#suffix': float}`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "926a9c50",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    driver='torch',\n",
    "    device=0,  # 'cuda'\n",
    "    n_epochs=10,\n",
    "    optimizers=optimizers,\n",
    "    train_dataloader=train_dataloader,\n",
    "    evaluate_dataloaders=evaluate_dataloader,\n",
    "    metrics={'suffix': MyMetric()}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e723b87",
   "metadata": {},
   "source": [
    "## 3. More on the trainer in fastNLP\n",
    "\n",
    "### 3.1 Internal structure of the trainer\n",
    "\n",
    "`tutorial-0` introduced the basic use of the `trainer`, and `tutorial-1` through `tutorial-4`\n",
    "\n",
    "&emsp; showed many examples of it in practice. Here we first add some detail on the internal structure of the training module `trainer`.\n",
    "\n",
    "\n",
    "\n",
    "'accumulation_steps', 'add_callback_fn', 'backward', 'batch_idx_in_epoch', 'batch_step_fn',\n",
    "'callback_manager', 'check_batch_step_fn', 'cur_epoch_idx', 'data_device', 'dataloader',\n",
    "'device', 'driver', 'driver_name', 'epoch_evaluate', 'evaluate_batch_step_fn', 'evaluate_dataloaders',\n",
    "'evaluate_every', 'evaluate_fn', 'evaluator', 'extract_loss_from_outputs', 'fp16',\n",
    "'get_no_sync_context', 'global_forward_batches', 'has_checked_train_batch_loop',\n",
    "'input_mapping', 'kwargs', 'larger_better', 'load_checkpoint', 'load_model', 'marker',\n",
    "'metrics', 'model', 'model_device', 'monitor', 'move_data_to_device', 'n_epochs', 'num_batches_per_epoch',\n",
    "'on', 'on_after_backward', 'on_after_optimizers_step', 'on_after_trainer_initialized',\n",
    "'on_after_zero_grad', 'on_before_backward', 'on_before_optimizers_step', 'on_before_zero_grad',\n",
    "'on_evaluate_begin', 'on_evaluate_end', 'on_exception', 'on_fetch_data_begin', 'on_fetch_data_end',\n",
    "'on_load_checkpoint', 'on_load_model', 'on_sanity_check_begin', 'on_sanity_check_end',\n",
    "'on_save_checkpoint', 'on_save_model', 'on_train_batch_begin', 'on_train_batch_end',\n",
    "'on_train_begin', 'on_train_end', 'on_train_epoch_begin', 'on_train_epoch_end',\n",
    "'optimizers', 'output_mapping', 'progress_bar', 'run', 'run_evaluate',\n",
    "'save_checkpoint', 'save_model', 'start_batch_idx_in_epoch', 'state',\n",
    "'step', 'step_evaluate', 'total_batches', 'train_batch_loop', 'train_dataloader', 'train_fn', 'train_step',\n",
    "'trainer_state', 'zero_grad'\n",
    "\n",
    "&emsp; `run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c348864c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43be274f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}