fastnlp_tutorial_0.ipynb 30 kB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "id": "aec0fde7",
  6. "metadata": {},
  7. "source": [
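  8. "# T0. Basic usage of trainer and evaluator\n",
  9. "\n",
  10. "&emsp; 1 &nbsp; The basic relationship between trainer and evaluator\n",
  11. " \n",
  12. "&emsp; &emsp; 1.1 &nbsp; Initializing the trainer and evaluator\n",
  13. "\n",
  14. "&emsp; &emsp; 1.2 &nbsp; What a driver is and how it must be used\n",
  15. "\n",
  16. "&emsp; &emsp; 1.3 &nbsp; Initializing an evaluator inside the trainer\n",
  17. "\n",
  18. "&emsp; 2 &nbsp; Training a model with trainer\n",
  19. "\n",
  20. "&emsp; &emsp; 2.1 &nbsp; The argmax model example\n",
  21. "\n",
  22. "&emsp; &emsp; 2.2 &nbsp; Parameter matching in trainer\n",
  23. "\n",
  24. "&emsp; &emsp; 2.3 &nbsp; Using trainer in practice\n",
  25. "\n",
  26. "&emsp; 3 &nbsp; Evaluating a model with evaluator\n",
  27. " \n",
  28. "&emsp; &emsp; 3.1 &nbsp; An evaluator initialized outside the trainer\n",
  29. "\n",
  30. "&emsp; &emsp; 3.2 &nbsp; An evaluator initialized inside the trainer"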
  31. ]
  32. },
  33. {
  34. "cell_type": "markdown",
  35. "id": "09ea669a",
  36. "metadata": {},
  37. "source": [
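  38. "## 1. The basic relationship between trainer and evaluator\n",
  39. "\n",
  40. "### 1.1 Initializing the trainer and evaluator\n",
  41. "\n",
  42. "In `fastNLP 0.8`, **the `Trainer` module handles training and the `Evaluator` module handles evaluation**\n",
  43. "\n",
  44. "&emsp; They correspond to the `Trainer` and `Tester` modules of earlier `fastNLP` versions and are defined as shown below\n",
  45. "\n",
  46. "In `fastNLP 0.8`, when a single `python` script first trains with `Trainer` and then evaluates with `Evaluator`,\n",
  47. "\n",
  48. "&emsp; the key question is **how to set the `driver` of both correctly**. This raises another question: what is a `driver`?\n",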
  49. "\n",
  50. "\n",
  51. "```python\n",
  52. "trainer = Trainer(\n",
  53. "    model=model,\n",
  54. "    train_dataloader=train_dataloader,\n",
  55. "    optimizers=optimizer,\n",
  56. "    # ...\n",
  57. "    driver=\"torch\",\n",
  58. "    device=0,\n",
  59. "    # ...\n",
  60. ")\n",
  61. "...\n",
  62. "evaluator = Evaluator(\n",
  63. "    model=model,\n",
  64. "    dataloaders=evaluate_dataloader,\n",
  65. "    metrics={'acc': Accuracy()},\n",
  66. "    # ...\n",
  67. "    driver=trainer.driver,\n",
  68. "    device=None,\n",
  69. "    # ...\n",
  70. ")\n",
  71. "```"
  72. ]
  73. },
  74. {
  75. "cell_type": "markdown",
  76. "id": "3c11fe1a",
  77. "metadata": {},
  78. "source": [
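  79. "### 1.2 What a driver is and how it must be used\n",
  80. "\n",
  81. "In `fastNLP 0.8`, the **`driver`** denotes **the part that ultimately executes each concrete step of training**\n",
  82. "\n",
  83. "&emsp; for example, running the forward and backward passes of the network, optimizing its parameters, and moving data between devices\n",
  84. "\n",
  85. "In `fastNLP 0.8`, **both `Trainer` and `Evaluator` rely on a concrete `driver` to carry out their overall workflow**\n",
  86. "\n",
  87. "&emsp; For how a `driver` relates to `Trainer` and `Evaluator`, see the framework design of `fastNLP 0.8`\n",
  88. "\n",
  89. "Note: within one script, `Trainer` and `Evaluator` should use the same `driver`\n",
  90. "\n",
  91. "&emsp; One rule must not be broken: **do not use a single-device `driver` before a multi-device `driver`** (???); doing so can cause many unexpected errors."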
  92. ]
  93. },
  94. {
  95. "cell_type": "markdown",
  96. "id": "2cac4a1a",
  97. "metadata": {},
  98. "source": [
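  99. "### 1.3 Initializing an Evaluator inside the Trainer\n",
  100. "\n",
  101. "In `fastNLP 0.8`, if the arguments **`evaluate_dataloaders` and `metrics`** are passed **when initializing the `Trainer`**,\n",
  102. "\n",
  103. "&emsp; the `Trainer` also initializes its own `Evaluator` internally to evaluate on the validation set during training\n",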
  104. "\n",
  105. "```python\n",
  106. "trainer = Trainer(\n",
  107. "    model=model,\n",
  108. "    train_dataloader=train_dataloader,\n",
  109. "    optimizers=optimizer,\n",
  110. "    # ...\n",
  111. "    driver=\"torch\",\n",
  112. "    device=0,\n",
  113. "    # ...\n",
  114. "    evaluate_dataloaders=evaluate_dataloader,\n",
  115. "    metrics={'acc': Accuracy()},\n",
  116. "    # ...\n",
  117. ")\n",
  118. "```"
  119. ]
  120. },
  121. {
  122. "cell_type": "markdown",
  123. "id": "0c9c7dda",
  124. "metadata": {},
  125. "source": [
  126. "## 2. Training a model with trainer"
  127. ]
  128. },
  129. {
  130. "cell_type": "markdown",
  131. "id": "524ac200",
  132. "metadata": {},
  133. "source": [
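  134. "### 2.1 The argmax model example\n",
  135. "\n",
  136. "This section trains an `argmax` model to briefly show how the `Trainer` module is used\n",
  137. "\n",
  138. "&emsp; The `argmax` model is defined with `pytorch`: it takes a batch of fixed-dimension vectors and outputs, for each vector, the index of its largest value\n",
  139. "\n",
  140. "&emsp; Besides the `forward` method that `pytorch` requires, the model also needs two more methods: **`train_step`** and **`evaluate_step`**"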
  141. ]
  142. },
  143. {
  144. "cell_type": "code",
  145. "execution_count": null,
  146. "id": "5314482b",
  147. "metadata": {
  148. "pycharm": {
  149. "is_executing": true
  150. }
  151. },
  152. "outputs": [],
  153. "source": [
  154. "import torch\n",
  155. "import torch.nn as nn\n",
  156. "\n",
  157. "class ArgMaxModel(nn.Module):\n",
  158. "    def __init__(self, num_labels, feature_dimension):\n",
  159. "        super(ArgMaxModel, self).__init__()\n",
  160. "        self.num_labels = num_labels\n",
  161. "\n",
  162. "        self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n",
  163. "        self.ac1 = nn.ReLU()\n",
  164. "        self.linear2 = nn.Linear(in_features=10, out_features=10)\n",
  165. "        self.ac2 = nn.ReLU()\n",
  166. "        self.output = nn.Linear(in_features=10, out_features=num_labels)\n",
  167. "        self.loss_fn = nn.CrossEntropyLoss()\n",
  168. "\n",
  169. "    def forward(self, x):\n",
  170. "        x = self.ac1(self.linear1(x))\n",
  171. "        x = self.ac2(self.linear2(x))\n",
  172. "        x = self.output(x)\n",
  173. "        return x\n",
  174. "\n",
  175. "    def train_step(self, x, y):\n",
  176. "        x = self(x)\n",
  177. "        return {\"loss\": self.loss_fn(x, y)}\n",
  178. "\n",
  179. "    def evaluate_step(self, x, y):\n",
  180. "        x = self(x)\n",
  181. "        x = torch.max(x, dim=-1)[1]  # index of the largest logit\n",
  182. "        return {\"pred\": x, \"target\": y}"
  183. ]
  184. },
  185. {
  186. "cell_type": "markdown",
  187. "id": "ca897322",
  188. "metadata": {},
  189. "source": [
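  190. "In `fastNLP 0.8`, **the method `train_step` is the default value of the `train_fn` parameter of `Trainer`**\n",
  191. "\n",
  192. "&emsp; During training, **`Trainer` obtains the loss of the current batch through the model method named by `train_fn`**\n",
  193. "\n",
  194. "&emsp; so `Trainer` first checks whether the model defines a `train_step` method\n",
  195. "\n",
  196. "&emsp; &emsp; If it does not, `Trainer` falls back to the model's `forward` method for the forward pass of training\n",
  197. "\n",
  198. "Note: in `fastNLP 0.8`, `Trainer` expects `train_step` to return a dictionary that stores the loss under the key `loss`\n",
  199. "\n",
  200. "&emsp; This can be customized further through the `output_mapping` parameter of `Trainer`; for details see this note (???)\n",
  201. "\n",
  202. "Likewise, in `fastNLP 0.8`, **the method `evaluate_step` is the default value of the `evaluate_fn` parameter of `Evaluator`**\n",
  203. "\n",
  204. "&emsp; During evaluation, **`Evaluator` obtains the evaluation result of the current batch through the model method named by `evaluate_fn`**\n",
  205. "\n",
  206. "&emsp; From the user's point of view, `evaluate_step` returns a dictionary whose contents match what the `metrics` passed to `Evaluator` expect\n",
  207. "\n",
  208. "<!-- &emsp; From the module's point of view, `fastNLP 0.8` matches the keys of this dictionary against the signature of each `metric`'s update function and automatically passes the `metric` what it needs; in other words, **parameter matching** is performed automatically. -->"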
  209. ]
  210. },
  211. {
  212. "cell_type": "markdown",
  213. "id": "fb3272eb",
  214. "metadata": {},
  215. "source": [
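  216. "### 2.2 Parameter matching in trainer\n",
  217. "\n",
  218. "Parameter matching in `fastNLP 0.8` covers two aspects. The first concerns the forward pass during training or evaluation: if a `batch` produced by the `dataloader` is a dictionary, fastNLP inspects the signature of the model's `train_step` or `evaluate_step` method and, for each parameter, picks the entry with the same name from the `batch` dictionary and passes it in. For an example, see the `__getitem__` method of `ArgMaxDatset` in the `Dataset` definition below. You can turn this behavior off by setting the `model_wo_auto_param_call` parameter of `Trainer` and `Evaluator`; once it is off, the `batch` is passed directly to your `train_step`, `evaluate_step`, or `forward` function.\n",
  219. "\n",
  220. "The second concerns `metrics`: once `metrics` are passed to `Trainer` or `Evaluator`, fastNLP calls them at each evaluation point to evaluate on `evaluate_dataloaders`, and this works mainly by parameter matching between each metric's `update` method and the data of one `batch`. Computing a metric usually takes two steps, `update` and `get_metric`: `update` accumulates the evaluation data of one `batch`, and `get_metric` computes the final value from the data accumulated so far. For `Accuracy`, for example, `update` adds the number of correct predictions in the `batch` to `right_num` and the number of predictions in the `batch` to `total_num`, and `get_metric` finally returns `right_num / total_num`.\n",
  221. "\n",
  222. "Because `metrics` in `fastNLP 0.8` are computed automatically (you only need to pass them to `Trainer` or `Evaluator`), they necessarily depend on parameter matching. For each `batch` generated from the `evaluate_dataloader`, fastNLP goes through every `metric` passed to `Trainer` (which forwards them to its internal `Evaluator`) or to `Evaluator`, inspects the signature of its `update` function, and picks the data for each parameter by name from the dictionary for that `batch` (in this tutorial, the one returned by `evaluate_step`)."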
  223. ]
  224. },
  225. {
  226. "cell_type": "markdown",
  227. "id": "f62b7bb1",
  228. "metadata": {},
  229. "source": [
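  230. "### 2.3 Using trainer in practice\n",
  231. "\n",
  232. "Next we create the dataset used for training. It takes three arguments (the feature dimension, the number of samples, and a random seed) and generates the requested number of vectors of dimension `feature_dimension`; the label of each vector is the index of its largest value."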
  233. ]
  234. },
  235. {
  236. "cell_type": "code",
  237. "execution_count": 2,
  238. "id": "fe612e61",
  239. "metadata": {
  240. "pycharm": {
  241. "is_executing": false
  242. }
  243. },
  244. "outputs": [],
  245. "source": [
  246. "from torch.utils.data import Dataset\n",
  247. "\n",
  248. "class ArgMaxDatset(Dataset):\n",
  249. "    def __init__(self, feature_dimension, data_num=1000, seed=0):\n",
  250. "        self.num_labels = feature_dimension\n",
  251. "        self.feature_dimension = feature_dimension\n",
  252. "        self.data_num = data_num\n",
  253. "        self.seed = seed\n",
  254. "\n",
  255. "        g = torch.Generator()\n",
  256. "        g.manual_seed(seed)  # use the seed argument so the data is reproducible per seed\n",
  257. "        self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()\n",
  258. "        self.y = torch.max(self.x, dim=-1)[1]  # label = index of the largest value\n",
  259. "\n",
  260. "    def __len__(self):\n",
  261. "        return self.data_num\n",
  262. "\n",
  263. "    def __getitem__(self, item):\n",
  264. "        return {\"x\": self.x[item], \"y\": self.y[item]}"
  265. ]
  266. },
  267. {
  268. "cell_type": "markdown",
  269. "id": "2cb96332",
  270. "metadata": {},
  271. "source": [
  272. "Now prepare the data and the model."
  273. ]
  274. },
  275. {
  276. "cell_type": "code",
  277. "execution_count": 3,
  278. "id": "76172ef8",
  279. "metadata": {
  280. "pycharm": {
  281. "is_executing": false
  282. }
  283. },
  284. "outputs": [],
  285. "source": [
  286. "from torch.utils.data import DataLoader\n",
  287. "\n",
  288. "train_dataset = ArgMaxDatset(feature_dimension=10, data_num=1000)\n",
  289. "evaluate_dataset = ArgMaxDatset(feature_dimension=10, data_num=100)\n",
  290. "\n",
  291. "train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n",
  292. "evaluate_dataloader = DataLoader(evaluate_dataset, batch_size=8)\n",
  293. "\n",
  294. "# num_labels is set to 10, matching feature_dimension, because we predict which of the ten positions has the highest probability.\n",
  295. "model = ArgMaxModel(num_labels=10, feature_dimension=10)"
  296. ]
  297. },
  298. {
  299. "cell_type": "markdown",
  300. "id": "4e7d25ee",
  301. "metadata": {},
  302. "source": [
  303. "Define the optimizer as well."
  304. ]
  305. },
  306. {
  307. "cell_type": "code",
  308. "execution_count": 4,
  309. "id": "dc28a2d9",
  310. "metadata": {
  311. "pycharm": {
  312. "is_executing": false
  313. }
  314. },
  315. "outputs": [],
  316. "source": [
  317. "from torch.optim import SGD\n",
  318. "\n",
  319. "optimizer = SGD(model.parameters(), lr=0.001)"
  320. ]
  321. },
  322. {
  323. "cell_type": "markdown",
  324. "id": "4f1fba81",
  325. "metadata": {},
  326. "source": [
  327. "Everything is ready, so let's start training with the Trainer!"
  328. ]
  329. },
  330. {
  331. "cell_type": "code",
  332. "execution_count": 5,
  333. "id": "b51b7a2d",
  334. "metadata": {
  335. "pycharm": {
  336. "is_executing": false
  337. }
  338. },
  339. "outputs": [
  340. {
  341. "data": {
  342. "text/html": [
  343. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  344. "</pre>\n"
  345. ],
  346. "text/plain": [
  347. "\n"
  348. ]
  349. },
  350. "metadata": {},
  351. "output_type": "display_data"
  352. },
  353. {
  354. "data": {
  355. "text/plain": [
  356. "['__annotations__',\n",
  357. " '__class__',\n",
  358. " '__delattr__',\n",
  359. " '__dict__',\n",
  360. " '__dir__',\n",
  361. " '__doc__',\n",
  362. " '__eq__',\n",
  363. " '__format__',\n",
  364. " '__ge__',\n",
  365. " '__getattribute__',\n",
  366. " '__gt__',\n",
  367. " '__hash__',\n",
  368. " '__init__',\n",
  369. " '__init_subclass__',\n",
  370. " '__le__',\n",
  371. " '__lt__',\n",
  372. " '__module__',\n",
  373. " '__ne__',\n",
  374. " '__new__',\n",
  375. " '__reduce__',\n",
  376. " '__reduce_ex__',\n",
  377. " '__repr__',\n",
  378. " '__setattr__',\n",
  379. " '__sizeof__',\n",
  380. " '__str__',\n",
  381. " '__subclasshook__',\n",
  382. " '__weakref__',\n",
  383. " '_check_callback_called_legality',\n",
  384. " '_check_train_batch_loop_legality',\n",
  385. " '_custom_callbacks',\n",
  386. " '_driver',\n",
  387. " '_evaluate_dataloaders',\n",
  388. " '_fetch_matched_fn_callbacks',\n",
  389. " '_set_num_eval_batch_per_dl',\n",
  390. " '_train_batch_loop',\n",
  391. " '_train_dataloader',\n",
  392. " '_train_step',\n",
  393. " '_train_step_signature_fn',\n",
  394. " 'accumulation_steps',\n",
  395. " 'add_callback_fn',\n",
  396. " 'backward',\n",
  397. " 'batch_idx_in_epoch',\n",
  398. " 'batch_step_fn',\n",
  399. " 'callback_manager',\n",
  400. " 'check_batch_step_fn',\n",
  401. " 'cur_epoch_idx',\n",
  402. " 'data_device',\n",
  403. " 'dataloader',\n",
  404. " 'device',\n",
  405. " 'driver',\n",
  406. " 'driver_name',\n",
  407. " 'epoch_validate',\n",
  408. " 'evaluate_batch_step_fn',\n",
  409. " 'evaluate_dataloaders',\n",
  410. " 'evaluate_every',\n",
  411. " 'evaluate_fn',\n",
  412. " 'evaluator',\n",
  413. " 'extract_loss_from_outputs',\n",
  414. " 'fp16',\n",
  415. " 'get_no_sync_context',\n",
  416. " 'global_forward_batches',\n",
  417. " 'has_checked_train_batch_loop',\n",
  418. " 'input_mapping',\n",
  419. " 'kwargs',\n",
  420. " 'larger_better',\n",
  421. " 'load',\n",
  422. " 'load_model',\n",
  423. " 'marker',\n",
  424. " 'metrics',\n",
  425. " 'model',\n",
  426. " 'model_device',\n",
  427. " 'monitor',\n",
  428. " 'move_data_to_device',\n",
  429. " 'n_epochs',\n",
  430. " 'num_batches_per_epoch',\n",
  431. " 'on',\n",
  432. " 'on_after_backward',\n",
  433. " 'on_after_optimizers_step',\n",
  434. " 'on_after_trainer_initialized',\n",
  435. " 'on_after_zero_grad',\n",
  436. " 'on_before_backward',\n",
  437. " 'on_before_optimizers_step',\n",
  438. " 'on_before_zero_grad',\n",
  439. " 'on_exception',\n",
  440. " 'on_fetch_data_begin',\n",
  441. " 'on_fetch_data_end',\n",
  442. " 'on_load_checkpoint',\n",
  443. " 'on_load_model',\n",
  444. " 'on_sanity_check_begin',\n",
  445. " 'on_sanity_check_end',\n",
  446. " 'on_save_checkpoint',\n",
  447. " 'on_save_model',\n",
  448. " 'on_train_batch_begin',\n",
  449. " 'on_train_batch_end',\n",
  450. " 'on_train_begin',\n",
  451. " 'on_train_end',\n",
  452. " 'on_train_epoch_begin',\n",
  453. " 'on_train_epoch_end',\n",
  454. " 'on_validate_begin',\n",
  455. " 'on_validate_end',\n",
  456. " 'optimizers',\n",
  457. " 'output_mapping',\n",
  458. " 'run',\n",
  459. " 'save',\n",
  460. " 'save_model',\n",
  461. " 'set_grad_to_none',\n",
  462. " 'state',\n",
  463. " 'step',\n",
  464. " 'step_validate',\n",
  465. " 'total_batches',\n",
  466. " 'train_batch_loop',\n",
  467. " 'train_dataloader',\n",
  468. " 'train_fn',\n",
  469. " 'train_step',\n",
  470. " 'trainer_state',\n",
  471. " 'zero_grad']"
  472. ]
  473. },
  474. "execution_count": 5,
  475. "metadata": {},
  476. "output_type": "execute_result"
  477. }
  478. ],
  479. "source": [
  480. "from fastNLP import Trainer\n",
  481. "\n",
  482. "# Define a Trainer\n",
  483. "trainer = Trainer(\n",
  484. "    model=model,\n",
  485. "    driver=\"torch\",  # train with pytorch\n",
  486. "    device=0,  # use GPU:0\n",
  487. "    train_dataloader=train_dataloader,\n",
  488. "    optimizers=optimizer,\n",
  489. "    n_epochs=10,  # train for 10 epochs\n",
  490. "    progress_bar=\"rich\"\n",
  491. ")\n",
  492. "dir(trainer)"
  493. ]
  494. },
  495. {
  496. "cell_type": "code",
  497. "execution_count": 8,
  498. "id": "f8fe9c32",
  499. "metadata": {},
  500. "outputs": [
  501. {
  502. "name": "stdout",
  503. "output_type": "stream",
  504. "text": [
  505. "FullArgSpec(args=['self', 'num_train_batch_per_epoch', 'num_eval_batch_per_dl', 'num_eval_sanity_batch', 'resume_from', 'resume_training', 'catch_KeyboardInterrupt'], varargs=None, varkw=None, defaults=(-1, -1, 2, None, True, None), kwonlyargs=[], kwonlydefaults=None, annotations={'num_train_batch_per_epoch': <class 'int'>, 'num_eval_batch_per_dl': <class 'int'>, 'num_eval_sanity_batch': <class 'int'>, 'resume_from': <class 'str'>, 'resume_training': <class 'bool'>})\n"
  506. ]
  507. }
  508. ],
  509. "source": [
  510. "import inspect \n",
  511. "\n",
  512. "print(inspect.getfullargspec(trainer.run))"
  513. ]
  514. },
  515. {
  516. "cell_type": "markdown",
  517. "id": "6e202d6e",
  518. "metadata": {},
  519. "source": [
  520. "No problems, so let's start the actual training!"
  521. ]
  522. },
  523. {
  524. "cell_type": "code",
  525. "execution_count": 9,
  526. "id": "ba047ead",
  527. "metadata": {
  528. "pycharm": {
  529. "is_executing": false
  530. }
  531. },
  532. "outputs": [
  533. {
  534. "data": {
  535. "application/vnd.jupyter.widget-view+json": {
  536. "model_id": "",
  537. "version_major": 2,
  538. "version_minor": 0
  539. },
  540. "text/plain": [
  541. "Output()"
  542. ]
  543. },
  544. "metadata": {},
  545. "output_type": "display_data"
  546. },
  547. {
  548. "data": {
  549. "text/html": [
  550. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  551. ],
  552. "text/plain": []
  553. },
  554. "metadata": {},
  555. "output_type": "display_data"
  556. },
  557. {
  558. "data": {
  559. "text/html": [
  560. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  561. "</pre>\n"
  562. ],
  563. "text/plain": [
  564. "\n"
  565. ]
  566. },
  567. "metadata": {},
  568. "output_type": "display_data"
  569. },
  570. {
  571. "data": {
  572. "text/html": [
  573. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  574. "</pre>\n"
  575. ],
  576. "text/plain": [
  577. "\n"
  578. ]
  579. },
  580. "metadata": {},
  581. "output_type": "display_data"
  582. }
  583. ],
  584. "source": [
  585. "trainer.run()"
  586. ]
  587. },
  588. {
  589. "cell_type": "markdown",
  590. "id": "eb8ca6cf",
  591. "metadata": {},
  592. "source": [
  593. "## 3. Evaluating a model with evaluator"
  594. ]
  595. },
  596. {
  597. "cell_type": "markdown",
  598. "id": "c16c5fa4",
  599. "metadata": {},
  600. "source": [
  601. "Now that the model is trained, let's use the Evaluator to see how well it performs."
  602. ]
  603. },
  604. {
  605. "cell_type": "code",
  606. "execution_count": 10,
  607. "id": "1c6b6b36",
  608. "metadata": {
  609. "pycharm": {
  610. "is_executing": false
  611. }
  612. },
  613. "outputs": [],
  614. "source": [
  615. "from fastNLP import Evaluator\n",
  616. "from fastNLP.core.metrics import Accuracy\n",
  617. "\n",
  618. "evaluator = Evaluator(\n",
  619. "    model=model,\n",
  620. "    driver=trainer.driver,  # reuse the driver that the trainer has already started\n",
  621. "    device=None,\n",
  622. "    dataloaders=evaluate_dataloader,\n",
  623. "    metrics={'acc': Accuracy()}  # note: this must be a dictionary\n",
  624. ")"
  625. ]
  626. },
  627. {
  628. "cell_type": "code",
  629. "execution_count": 11,
  630. "id": "257061df",
  631. "metadata": {
  632. "scrolled": true
  633. },
  634. "outputs": [
  635. {
  636. "data": {
  637. "text/plain": [
  638. "['__annotations__',\n",
  639. " '__class__',\n",
  640. " '__delattr__',\n",
  641. " '__dict__',\n",
  642. " '__dir__',\n",
  643. " '__doc__',\n",
  644. " '__eq__',\n",
  645. " '__format__',\n",
  646. " '__ge__',\n",
  647. " '__getattribute__',\n",
  648. " '__gt__',\n",
  649. " '__hash__',\n",
  650. " '__init__',\n",
  651. " '__init_subclass__',\n",
  652. " '__le__',\n",
  653. " '__lt__',\n",
  654. " '__module__',\n",
  655. " '__ne__',\n",
  656. " '__new__',\n",
  657. " '__reduce__',\n",
  658. " '__reduce_ex__',\n",
  659. " '__repr__',\n",
  660. " '__setattr__',\n",
  661. " '__sizeof__',\n",
  662. " '__str__',\n",
  663. " '__subclasshook__',\n",
  664. " '__weakref__',\n",
  665. " '_dist_sampler',\n",
  666. " '_evaluate_batch_loop',\n",
  667. " '_evaluate_step',\n",
  668. " '_evaluate_step_signature_fn',\n",
  669. " '_metric_wrapper',\n",
  670. " '_metrics',\n",
  671. " 'dataloaders',\n",
  672. " 'device',\n",
  673. " 'driver',\n",
  674. " 'evaluate_batch_loop',\n",
  675. " 'evaluate_batch_step_fn',\n",
  676. " 'evaluate_fn',\n",
  677. " 'evaluate_step',\n",
  678. " 'finally_progress_bar',\n",
  679. " 'get_dataloader_metric',\n",
  680. " 'input_mapping',\n",
  681. " 'metrics',\n",
  682. " 'metrics_wrapper',\n",
  683. " 'model',\n",
  684. " 'model_use_eval_mode',\n",
  685. " 'move_data_to_device',\n",
  686. " 'output_mapping',\n",
  687. " 'progress_bar',\n",
  688. " 'remove_progress_bar',\n",
  689. " 'reset',\n",
  690. " 'run',\n",
  691. " 'separator',\n",
  692. " 'start_progress_bar',\n",
  693. " 'update',\n",
  694. " 'update_progress_bar',\n",
  695. " 'verbose']"
  696. ]
  697. },
  698. "execution_count": 11,
  699. "metadata": {},
  700. "output_type": "execute_result"
  701. }
  702. ],
  703. "source": [
  704. "dir(evaluator)"
  705. ]
  706. },
  707. {
  708. "cell_type": "code",
  709. "execution_count": 12,
  710. "id": "f7cb0165",
  711. "metadata": {
  712. "pycharm": {
  713. "is_executing": false
  714. }
  715. },
  716. "outputs": [
  717. {
  718. "data": {
  719. "text/html": [
  720. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  721. ],
  722. "text/plain": []
  723. },
  724. "metadata": {},
  725. "output_type": "display_data"
  726. },
  727. {
  728. "data": {
  729. "text/html": [
  730. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  731. ],
  732. "text/plain": []
  733. },
  734. "metadata": {},
  735. "output_type": "display_data"
  736. },
  737. {
  738. "data": {
  739. "text/html": [
  740. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  741. "</pre>\n"
  742. ],
  743. "text/plain": [
  744. "\n"
  745. ]
  746. },
  747. "metadata": {},
  748. "output_type": "display_data"
  749. },
  750. {
  751. "data": {
  752. "text/html": [
  753. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.3</span><span style=\"font-weight: bold\">}</span>\n",
  754. "</pre>\n"
  755. ],
  756. "text/plain": [
  757. "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.3\u001b[0m\u001b[1m}\u001b[0m\n"
  758. ]
  759. },
  760. "metadata": {},
  761. "output_type": "display_data"
  762. },
  763. {
  764. "data": {
  765. "text/plain": [
  766. "{'acc#acc': 0.3}"
  767. ]
  768. },
  769. "execution_count": 12,
  770. "metadata": {},
  771. "output_type": "execute_result"
  772. }
  773. ],
  774. "source": [
  775. "evaluator.run()"
  776. ]
  777. },
  778. {
  779. "cell_type": "markdown",
  780. "id": "dd9f68fa",
  781. "metadata": {},
  782. "source": [
  783. "## 4. Adding metrics to the trainer for automatic evaluation"
  784. ]
  785. },
  786. {
  787. "cell_type": "markdown",
  788. "id": "ca97c9a4",
  789. "metadata": {},
  790. "source": [
  791. "Now let's try evaluating during training."
  792. ]
  793. },
  794. {
  795. "cell_type": "code",
  796. "execution_count": 13,
  797. "id": "183c7d19",
  798. "metadata": {
  799. "pycharm": {
  800. "is_executing": false
  801. }
  802. },
  803. "outputs": [],
  804. "source": [
  805. "# Define a new Trainer\n",
  806. "\n",
  807. "trainer = Trainer(\n",
  808. "    model=model,\n",
  809. "    driver=trainer.driver,  # we are in the same script, so the driver must be reused here as well\n",
  810. "    train_dataloader=train_dataloader,\n",
  811. "    evaluate_dataloaders=evaluate_dataloader,\n",
  812. "    metrics={'acc': Accuracy()},\n",
  813. "    optimizers=optimizer,\n",
  814. "    n_epochs=10,  # train for 10 epochs\n",
  815. "    evaluate_every=-1,  # evaluate at the end of every epoch\n",
  816. ")"
  817. ]
  818. },
  819. {
  820. "cell_type": "markdown",
  821. "id": "714cc404",
  822. "metadata": {},
  823. "source": [
  824. "Train again."
  825. ]
  826. },
  827. {
  828. "cell_type": "code",
  829. "execution_count": 14,
  830. "id": "2e4daa2c",
  831. "metadata": {
  832. "pycharm": {
  833. "is_executing": false
  834. }
  835. },
  836. "outputs": [
  837. {
  838. "data": {
  839. "text/html": [
  840. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  841. ],
  842. "text/plain": []
  843. },
  844. "metadata": {},
  845. "output_type": "display_data"
  846. },
  847. {
  848. "data": {
  849. "text/html": [
  850. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  851. ],
  852. "text/plain": []
  853. },
  854. "metadata": {},
  855. "output_type": "display_data"
  856. },
  857. {
  858. "data": {
  859. "text/html": [
  860. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  861. "</pre>\n"
  862. ],
  863. "text/plain": [
  864. "\n"
  865. ]
  866. },
  867. "metadata": {},
  868. "output_type": "display_data"
  869. },
  870. {
  871. "data": {
  872. "text/html": [
  873. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  874. "</pre>\n"
  875. ],
  876. "text/plain": [
  877. "\n"
  878. ]
  879. },
  880. "metadata": {},
  881. "output_type": "display_data"
  882. }
  883. ],
  884. "source": [
  885. "trainer.run()"
  886. ]
  887. },
  888. {
  889. "cell_type": "code",
  890. "execution_count": 15,
  891. "id": "eabda5eb",
  892. "metadata": {},
  893. "outputs": [],
  894. "source": [
  895. "evaluator = Evaluator(\n",
  896. "    model=model,\n",
  897. "    driver=trainer.driver,  # reuse the driver that the trainer has already started\n",
  898. "    dataloaders=evaluate_dataloader,\n",
  899. "    metrics={'acc': Accuracy()}  # note: this must be a dictionary\n",
  900. ")"
  901. ]
  902. },
  903. {
  904. "cell_type": "code",
  905. "execution_count": 16,
  906. "id": "a310d157",
  907. "metadata": {},
  908. "outputs": [
  909. {
  910. "data": {
  911. "text/html": [
  912. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  913. ],
  914. "text/plain": []
  915. },
  916. "metadata": {},
  917. "output_type": "display_data"
  918. },
  919. {
  920. "data": {
  921. "text/html": [
  922. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  923. ],
  924. "text/plain": []
  925. },
  926. "metadata": {},
  927. "output_type": "display_data"
  928. },
  929. {
  930. "data": {
  931. "text/html": [
  932. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  933. "</pre>\n"
  934. ],
  935. "text/plain": [
  936. "\n"
  937. ]
  938. },
  939. "metadata": {},
  940. "output_type": "display_data"
  941. },
  942. {
  943. "data": {
  944. "text/html": [
  945. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.5</span><span style=\"font-weight: bold\">}</span>\n",
  946. "</pre>\n"
  947. ],
  948. "text/plain": [
  949. "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.5\u001b[0m\u001b[1m}\u001b[0m\n"
  950. ]
  951. },
  952. "metadata": {},
  953. "output_type": "display_data"
  954. },
  955. {
  956. "data": {
  957. "text/plain": [
  958. "{'acc#acc': 0.5}"
  959. ]
  960. },
  961. "execution_count": 16,
  962. "metadata": {},
  963. "output_type": "execute_result"
  964. }
  965. ],
  966. "source": [
  967. "evaluator.run()"
  968. ]
  969. },
  970. {
  971. "cell_type": "code",
  972. "execution_count": null,
  973. "id": "f1ef78f0",
  974. "metadata": {},
  975. "outputs": [],
  976. "source": []
  977. }
  978. ],
  979. "metadata": {
  980. "kernelspec": {
  981. "display_name": "Python 3 (ipykernel)",
  982. "language": "python",
  983. "name": "python3"
  984. },
  985. "language_info": {
  986. "codemirror_mode": {
  987. "name": "ipython",
  988. "version": 3
  989. },
  990. "file_extension": ".py",
  991. "mimetype": "text/x-python",
  992. "name": "python",
  993. "nbconvert_exporter": "python",
  994. "pygments_lexer": "ipython3",
  995. "version": "3.7.4"
  996. },
  997. "pycharm": {
  998. "stem_cell": {
  999. "cell_type": "raw",
  1000. "metadata": {
  1001. "collapsed": false
  1002. },
  1003. "source": []
  1004. }
  1005. }
  1006. },
  1007. "nbformat": 4,
  1008. "nbformat_minor": 5
  1009. }