文本分类.ipynb 26 kB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
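  7. "## Text classification\n",
  8. "Text classification assigns a sentence or passage to a specific category, as in spam detection or sentiment classification.\n",
  9. "\n",
  10. "Example: \n",
  11. "1,商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!\n",
  12. "\n",
  13. "\n",
  14. "The leading 1 is the label of this review and indicates positive sentiment. The data can be downloaded from http://dbcloud.irocn.cn:8989/api/public/dl/dataset/chn_senti_corp.zip and unpacked manually, or downloaded automatically by fastNLP.\n",
  15. "\n",
  16. "The figure below shows what the data looks like. Next, we use fastNLP to train a classification network on this data."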
  17. ]
  18. },
  19. {
  20. "cell_type": "markdown",
  21. "metadata": {},
  22. "source": [
  23. "![jupyter](./cn_cls_example.png)"
  24. ]
  25. },
  26. {
  27. "cell_type": "markdown",
  28. "metadata": {},
  29. "source": [
  30. "## Steps\n",
  31. "The workflow has five steps: \n",
  32. "(1) Load the data \n",
  33. "(2) Preprocess the data \n",
  34. "(3) Choose pretrained word embeddings \n",
  35. "(4) Build the model \n",
  36. "(5) Train the model "
  37. ]
  38. },
  39. {
  40. "cell_type": "markdown",
  41. "metadata": {},
  42. "source": [
  43. "### (1) Load the data\n",
  44. "fastNLP can download and load many datasets automatically. For the data used here, \\ref{Loader} downloads and loads it for us. See \\ref{Loader} for more on using Loaders."
  45. ]
  46. },
  47. {
  48. "cell_type": "code",
  49. "execution_count": 1,
  50. "metadata": {},
  51. "outputs": [],
  52. "source": [
  53. "from fastNLP.io import ChnSentiCorpLoader\n",
  54. "\n",
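  55. "loader = ChnSentiCorpLoader() # initialize a loader for the Chinese sentiment classification corpus\n",
  56. "data_dir = loader.download() # downloads the data to the default cache directory and returns that path\n",
  57. "data_bundle = loader.load(data_dir) # reads the data under data_dir into a DataBundle"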
  58. ]
  59. },
  60. {
  61. "cell_type": "markdown",
  62. "metadata": {},
  63. "source": [
  64. "See \\ref{} for an introduction to DataBundle. We can print basic information about this data_bundle."
  65. ]
  66. },
  67. {
  68. "cell_type": "code",
  69. "execution_count": 2,
  70. "metadata": {},
  71. "outputs": [
  72. {
  73. "name": "stdout",
  74. "output_type": "stream",
  75. "text": [
  76. "In total 3 datasets:\n",
  77. "\tdev has 1200 instances.\n",
  78. "\ttrain has 9600 instances.\n",
  79. "\ttest has 1200 instances.\n",
  80. "In total 0 vocabs:\n",
  81. "\n"
  82. ]
  83. }
  84. ],
  85. "source": [
  86. "print(data_bundle)"
  87. ]
  88. },
  89. {
  90. "cell_type": "markdown",
  91. "metadata": {},
  92. "source": [
  93. "As shown, the data_bundle contains three \\ref{DataSet} objects in total. The code below shows what a DataSet looks like."
  94. ]
  95. },
  96. {
  97. "cell_type": "code",
  98. "execution_count": 6,
  99. "metadata": {},
  100. "outputs": [
  101. {
  102. "name": "stdout",
  103. "output_type": "stream",
  104. "text": [
  105. "DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n",
  106. "'target': 1 type=str},\n",
  107. "{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n",
  108. "'target': 1 type=str})\n"
  109. ]
  110. }
  111. ],
  112. "source": [
  113. "print(data_bundle.get_dataset('train')[:2]) # show the first two samples of the train set"
  114. ]
  115. },
  116. {
  117. "cell_type": "markdown",
  118. "metadata": {},
  119. "source": [
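  120. "### (2) Preprocess the data\n",
  121. "In NLP tasks, preprocessing usually consists of (a) splitting each sentence into characters or words and (b) converting the text into indices. \n",
  122. "\n",
  123. "fastNLP provides processing classes for many datasets; here we use its ChnSentiCorpPipe directly. See \\ref{Pipe} for more on Pipes."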
  124. ]
  125. },
  126. {
  127. "cell_type": "code",
  128. "execution_count": 3,
  129. "metadata": {},
  130. "outputs": [],
  131. "source": [
  132. "from fastNLP.io import ChnSentiCorpPipe\n",
  133. "\n",
  134. "pipe = ChnSentiCorpPipe()\n",
  135. "data_bundle = pipe.process(data_bundle) # every Pipe implements process(), which takes and returns a DataBundle"
  136. ]
  137. },
  138. {
  139. "cell_type": "code",
  140. "execution_count": 4,
  141. "metadata": {},
  142. "outputs": [
  143. {
  144. "name": "stdout",
  145. "output_type": "stream",
  146. "text": [
  147. "In total 3 datasets:\n",
  148. "\tdev has 1200 instances.\n",
  149. "\ttrain has 9600 instances.\n",
  150. "\ttest has 1200 instances.\n",
  151. "In total 2 vocabs:\n",
  152. "\tchars has 4409 entries.\n",
  153. "\ttarget has 2 entries.\n",
  154. "\n"
  155. ]
  156. }
  157. ],
  158. "source": [
  159. "print(data_bundle) # print data_bundle again to see what changed"
  160. ]
  161. },
  162. {
  163. "cell_type": "markdown",
  164. "metadata": {},
  165. "source": [
  166. "Besides the three \\ref{DataSet} objects it already held, the data_bundle now has two new \\ref{Vocabulary} objects. We can print the contents of a DataSet."
  167. ]
  168. },
  169. {
  170. "cell_type": "code",
  171. "execution_count": 5,
  172. "metadata": {},
  173. "outputs": [
  174. {
  175. "name": "stdout",
  176. "output_type": "stream",
  177. "text": [
  178. "DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n",
  179. "'target': 1 type=int,\n",
  180. "'chars': [338, 464, 1400, 784, 468, 739, 3, 289, 151, 21, 5, 88, 143, 2, 9, 81, 134, 2573, 766, 233, 196, 23, 536, 342, 297, 2, 405, 698, 132, 281, 74, 744, 1048, 74, 420, 387, 74, 412, 433, 74, 2021, 180, 8, 219, 1929, 213, 4, 34, 31, 96, 363, 8, 230, 2, 66, 18, 229, 331, 768, 4, 11, 1094, 479, 17, 35, 593, 3, 1126, 967, 2, 151, 245, 12, 44, 2, 6, 52, 260, 263, 635, 5, 152, 162, 4, 11, 336, 3, 154, 132, 5, 236, 443, 3, 2, 18, 229, 761, 700, 4, 11, 48, 59, 653, 2, 8, 230] type=list,\n",
  181. "'seq_len': 106 type=int},\n",
  182. "{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n",
  183. "'target': 1 type=int,\n",
  184. "'chars': [50, 133, 20, 135, 945, 520, 343, 24, 3, 301, 176, 350, 86, 785, 2, 456, 24, 461, 163, 443, 128, 109, 6, 47, 7, 2, 916, 152, 162, 524, 296, 44, 301, 176, 2, 1384, 524, 296, 259, 88, 143, 2, 92, 67, 26, 12, 277, 269, 2, 188, 223, 26, 228, 83, 6, 63] type=list,\n",
  185. "'seq_len': 56 type=int})\n"
  186. ]
  187. }
  188. ],
  189. "source": [
  190. "print(data_bundle.get_dataset('train')[:2])"
  191. ]
  192. },
  193. {
  194. "cell_type": "markdown",
  195. "metadata": {},
  196. "source": [
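  197. "There is a new chars column holding lists of indices, the target column now holds integers, and a seq_len column records the length of each sequence. The names of the chars and target columns match the names of the two Vocabulary objects in data_bundle. Let us print a Vocabulary to see what it contains."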
  198. ]
  199. },
  200. {
  201. "cell_type": "code",
  202. "execution_count": 6,
  203. "metadata": {},
  204. "outputs": [
  205. {
  206. "name": "stdout",
  207. "output_type": "stream",
  208. "text": [
  209. "Vocabulary(['选', '择', '珠', '江', '花']...)\n"
  210. ]
  211. }
  212. ],
  213. "source": [
  214. "char_vocab = data_bundle.get_vocab('chars')\n",
  215. "print(char_vocab)"
  216. ]
  217. },
  218. {
  219. "cell_type": "markdown",
  220. "metadata": {},
  221. "source": [
  222. "Vocabulary is a class that records the mapping between words and indices. For example:"
  223. ]
  224. },
  225. {
  226. "cell_type": "code",
  227. "execution_count": 7,
  228. "metadata": {},
  229. "outputs": [
  230. {
  231. "name": "stdout",
  232. "output_type": "stream",
  233. "text": [
  234. "The index of '选' is 338\n",
  235. "index 338 maps to the character 选\n"
  236. ]
  237. }
  238. ],
  239. "source": [
  240. "index = char_vocab.to_index('选')\n",
  241. "print(\"The index of '选' is {}\".format(index)) # matches the first index in chars of the first instance printed above\n",
  242. "print(\"index {} maps to the character {}\".format(index, char_vocab.to_word(index))) "
  243. ]
  244. },
  245. {
  246. "cell_type": "markdown",
  247. "metadata": {},
  248. "source": [
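  249. "### (3) Choose pretrained word embeddings \n",
  250. "Pretrained models such as Word2vec, GloVe, ELMo, and BERT can improve model performance, so it is important to choose suitable pretrained embeddings before training on a specific task. fastNLP provides several Embedding classes that make loading these pretrained models easier; see \\ref{Embedding} for details. Here we first show an example that uses word2vec embeddings pretrained on Chinese characters, and later one that uses BERT for text classification. The pretrained embedding used here is 'cn-char-fastnlp-100d', which fastNLP downloads automatically to the local cache. For the embeddings that fastNLP can load by name, and related notes, see \\ref{Embedding}."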
  251. ]
  252. },
  253. {
  254. "cell_type": "code",
  255. "execution_count": 8,
  256. "metadata": {},
  257. "outputs": [
  258. {
  259. "name": "stdout",
  260. "output_type": "stream",
  261. "text": [
  262. "Found 4321 out of 4409 words in the pre-training embedding.\n"
  263. ]
  264. }
  265. ],
  266. "source": [
  267. "from fastNLP.embeddings import StaticEmbedding\n",
  268. "\n",
  269. "word2vec_embed = StaticEmbedding(char_vocab, model_dir_or_name='cn-char-fastnlp-100d')"
  270. ]
  271. },
  272. {
  273. "cell_type": "markdown",
  274. "metadata": {},
  275. "source": [
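  276. "### (4) Build the model\n",
  277. "The model used here is a bidirectional LSTM followed by max pooling over the sequence and a linear classifier (figure to be added)."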
  278. ]
  279. },
  280. {
  281. "cell_type": "code",
  282. "execution_count": 9,
  283. "metadata": {},
  284. "outputs": [],
  285. "source": [
  286. "from torch import nn\n",
  287. "from fastNLP.modules import LSTM\n",
  288. "import torch\n",
  289. "\n",
  290. "# define the model\n",
  291. "class BiLSTMMaxPoolCls(nn.Module):\n",
  292. "    def __init__(self, embed, num_classes, hidden_size=400, num_layers=1, dropout=0.3):\n",
  293. "        super().__init__()\n",
  294. "        self.embed = embed\n",
  295. "        \n",
  296. "        self.lstm = LSTM(self.embed.embedding_dim, hidden_size=hidden_size//2, num_layers=num_layers, \n",
  297. "                         batch_first=True, bidirectional=True)\n",
  298. "        self.dropout_layer = nn.Dropout(dropout)\n",
  299. "        self.fc = nn.Linear(hidden_size, num_classes)\n",
  300. "        \n",
  301. "    def forward(self, chars, seq_len): # argument names must match the field names in the DataSet: the DataSet has a chars field, so the argument must be named chars\n",
  302. "        # chars: [batch_size, max_len]\n",
  303. "        # seq_len: [batch_size, ]\n",
  304. "        chars = self.embed(chars)\n",
  305. "        outputs, _ = self.lstm(chars, seq_len)\n",
  306. "        outputs = self.dropout_layer(outputs)\n",
  307. "        outputs, _ = torch.max(outputs, dim=1)\n",
  308. "        outputs = self.fc(outputs)\n",
  309. "        \n",
  310. "        return {'pred':outputs} # [batch_size, num_classes]; the return value must be a dict, and 'pred' is the recommended key for predictions\n",
  311. "\n",
  312. "# initialize the model\n",
  313. "model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))"
  314. ]
  315. },
  316. {
  317. "cell_type": "markdown",
  318. "metadata": {},
  319. "source": [
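  320. "### (5) Train the model\n",
  321. "fastNLP provides a Trainer object to organize training. It computes the loss (so a loss type must be specified when the Trainer is initialized), applies the gradient updates (so an optimizer must be supplied), and evaluates on the validation set (so a Metric must be supplied)."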
  322. ]
  323. },
  324. {
  325. "cell_type": "code",
  326. "execution_count": 10,
  327. "metadata": {},
  328. "outputs": [
  329. {
  330. "name": "stdout",
  331. "output_type": "stream",
  332. "text": [
  333. "input fields after batch(if batch size is 2):\n",
  334. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  335. "\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n",
  336. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  337. "target fields after batch(if batch size is 2):\n",
  338. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  339. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  340. "\n",
  341. "Evaluate data in 0.01 seconds!\n",
  342. "training epochs started 2019-09-03-23-57-10\n"
  343. ]
  344. },
  345. {
  346. "data": {
  347. "text/plain": [
  348. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3000), HTML(value='')), layout=Layout(display…"
  349. ]
  350. },
  351. "metadata": {},
  352. "output_type": "display_data"
  353. },
  354. {
  355. "data": {
  356. "text/plain": [
  357. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  358. ]
  359. },
  360. "metadata": {},
  361. "output_type": "display_data"
  362. },
  363. {
  364. "name": "stdout",
  365. "output_type": "stream",
  366. "text": [
  367. "\r",
  368. "Evaluate data in 0.43 seconds!\n",
  369. "\r",
  370. "Evaluation on dev at Epoch 1/10. Step:300/3000: \n",
  371. "\r",
  372. "AccuracyMetric: acc=0.81\n",
  373. "\n"
  374. ]
  375. },
  376. {
  377. "data": {
  378. "text/plain": [
  379. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  380. ]
  381. },
  382. "metadata": {},
  383. "output_type": "display_data"
  384. },
  385. {
  386. "name": "stdout",
  387. "output_type": "stream",
  388. "text": [
  389. "\r",
  390. "Evaluate data in 0.44 seconds!\n",
  391. "\r",
  392. "Evaluation on dev at Epoch 2/10. Step:600/3000: \n",
  393. "\r",
  394. "AccuracyMetric: acc=0.8675\n",
  395. "\n"
  396. ]
  397. },
  398. {
  399. "data": {
  400. "text/plain": [
  401. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  402. ]
  403. },
  404. "metadata": {},
  405. "output_type": "display_data"
  406. },
  407. {
  408. "name": "stdout",
  409. "output_type": "stream",
  410. "text": [
  411. "\r",
  412. "Evaluate data in 0.44 seconds!\n",
  413. "\r",
  414. "Evaluation on dev at Epoch 3/10. Step:900/3000: \n",
  415. "\r",
  416. "AccuracyMetric: acc=0.878333\n",
  417. "\n"
  418. ]
  419. },
  420. {
  421. "data": {
  422. "text/plain": [
  423. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  424. ]
  425. },
  426. "metadata": {},
  427. "output_type": "display_data"
  428. },
  429. {
  430. "name": "stdout",
  431. "output_type": "stream",
  432. "text": [
  433. "\r",
  434. "Evaluate data in 0.43 seconds!\n",
  435. "\r",
  436. "Evaluation on dev at Epoch 4/10. Step:1200/3000: \n",
  437. "\r",
  438. "AccuracyMetric: acc=0.873333\n",
  439. "\n"
  440. ]
  441. },
  442. {
  443. "data": {
  444. "text/plain": [
  445. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  446. ]
  447. },
  448. "metadata": {},
  449. "output_type": "display_data"
  450. },
  451. {
  452. "name": "stdout",
  453. "output_type": "stream",
  454. "text": [
  455. "\r",
  456. "Evaluate data in 0.44 seconds!\n",
  457. "\r",
  458. "Evaluation on dev at Epoch 5/10. Step:1500/3000: \n",
  459. "\r",
  460. "AccuracyMetric: acc=0.878333\n",
  461. "\n"
  462. ]
  463. },
  464. {
  465. "data": {
  466. "text/plain": [
  467. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  468. ]
  469. },
  470. "metadata": {},
  471. "output_type": "display_data"
  472. },
  473. {
  474. "name": "stdout",
  475. "output_type": "stream",
  476. "text": [
  477. "\r",
  478. "Evaluate data in 0.42 seconds!\n",
  479. "\r",
  480. "Evaluation on dev at Epoch 6/10. Step:1800/3000: \n",
  481. "\r",
  482. "AccuracyMetric: acc=0.895833\n",
  483. "\n"
  484. ]
  485. },
  486. {
  487. "data": {
  488. "text/plain": [
  489. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  490. ]
  491. },
  492. "metadata": {},
  493. "output_type": "display_data"
  494. },
  495. {
  496. "name": "stdout",
  497. "output_type": "stream",
  498. "text": [
  499. "\r",
  500. "Evaluate data in 0.44 seconds!\n",
  501. "\r",
  502. "Evaluation on dev at Epoch 7/10. Step:2100/3000: \n",
  503. "\r",
  504. "AccuracyMetric: acc=0.8975\n",
  505. "\n"
  506. ]
  507. },
  508. {
  509. "data": {
  510. "text/plain": [
  511. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  512. ]
  513. },
  514. "metadata": {},
  515. "output_type": "display_data"
  516. },
  517. {
  518. "name": "stdout",
  519. "output_type": "stream",
  520. "text": [
  521. "\r",
  522. "Evaluate data in 0.43 seconds!\n",
  523. "\r",
  524. "Evaluation on dev at Epoch 8/10. Step:2400/3000: \n",
  525. "\r",
  526. "AccuracyMetric: acc=0.894167\n",
  527. "\n"
  528. ]
  529. },
  530. {
  531. "data": {
  532. "text/plain": [
  533. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  534. ]
  535. },
  536. "metadata": {},
  537. "output_type": "display_data"
  538. },
  539. {
  540. "name": "stdout",
  541. "output_type": "stream",
  542. "text": [
  543. "\r",
  544. "Evaluate data in 0.48 seconds!\n",
  545. "\r",
  546. "Evaluation on dev at Epoch 9/10. Step:2700/3000: \n",
  547. "\r",
  548. "AccuracyMetric: acc=0.8875\n",
  549. "\n"
  550. ]
  551. },
  552. {
  553. "data": {
  554. "text/plain": [
  555. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  556. ]
  557. },
  558. "metadata": {},
  559. "output_type": "display_data"
  560. },
  561. {
  562. "name": "stdout",
  563. "output_type": "stream",
  564. "text": [
  565. "\r",
  566. "Evaluate data in 0.43 seconds!\n",
  567. "\r",
  568. "Evaluation on dev at Epoch 10/10. Step:3000/3000: \n",
  569. "\r",
  570. "AccuracyMetric: acc=0.895833\n",
  571. "\n",
  572. "\r\n",
  573. "In Epoch:7/Step:2100, got best dev performance:\n",
  574. "AccuracyMetric: acc=0.8975\n",
  575. "Reloaded the best model.\n"
  576. ]
  577. },
  578. {
  579. "data": {
  580. "text/plain": [
  581. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…"
  582. ]
  583. },
  584. "metadata": {},
  585. "output_type": "display_data"
  586. },
  587. {
  588. "name": "stdout",
  589. "output_type": "stream",
  590. "text": [
  591. "\r",
  592. "Evaluate data in 0.34 seconds!\n",
  593. "[tester] \n",
  594. "AccuracyMetric: acc=0.8975\n"
  595. ]
  596. },
  597. {
  598. "data": {
  599. "text/plain": [
  600. "{'AccuracyMetric': {'acc': 0.8975}}"
  601. ]
  602. },
  603. "execution_count": 10,
  604. "metadata": {},
  605. "output_type": "execute_result"
  606. }
  607. ],
  608. "source": [
  609. "from fastNLP import Trainer\n",
  610. "from fastNLP import CrossEntropyLoss\n",
  611. "from torch.optim import Adam\n",
  612. "from fastNLP import AccuracyMetric\n",
  613. "\n",
  614. "loss = CrossEntropyLoss()\n",
  615. "optimizer = Adam(model.parameters(), lr=0.001)\n",
  616. "metric = AccuracyMetric()\n",
  617. "device = 0 if torch.cuda.is_available() else 'cpu' # run on the GPU when one is available; training is faster there\n",
  618. "\n",
  619. "trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n",
  620. "                  optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),\n",
  621. "                  metrics=metric, device=device)\n",
  622. "trainer.train() # start training; by default the model that performed best on dev is reloaded when training finishes\n",
  623. "\n",
  624. "# evaluate the model on the test set\n",
  625. "from fastNLP import Tester\n",
  626. "print(\"Performance on test is:\")\n",
  627. "tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
  628. "tester.test()"
  629. ]
  630. },
  631. {
  632. "cell_type": "markdown",
  633. "metadata": {},
  634. "source": [
  635. "### Text classification with BERT"
  636. ]
  637. },
  638. {
  639. "cell_type": "code",
  640. "execution_count": 12,
  641. "metadata": {},
  642. "outputs": [
  643. {
  644. "name": "stdout",
  645. "output_type": "stream",
  646. "text": [
  647. "loading vocabulary file /home/yh/.fastNLP/embedding/bert-chinese-wwm/vocab.txt\n",
  648. "Load pre-trained BERT parameters from file /home/yh/.fastNLP/embedding/bert-chinese-wwm/chinese_wwm_pytorch.bin.\n",
  649. "Start to generating word pieces for word.\n",
  650. "Found(Or segment into word pieces) 4286 words out of 4409.\n",
  651. "input fields after batch(if batch size is 2):\n",
  652. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  653. "\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n",
  654. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  655. "target fields after batch(if batch size is 2):\n",
  656. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  657. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  658. "\n",
  659. "Evaluate data in 0.05 seconds!\n",
  660. "training epochs started 2019-09-04-00-02-37\n"
  661. ]
  662. },
  663. {
  664. "data": {
  665. "text/plain": [
  666. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3600), HTML(value='')), layout=Layout(display…"
  667. ]
  668. },
  669. "metadata": {},
  670. "output_type": "display_data"
  671. },
  672. {
  673. "data": {
  674. "text/plain": [
  675. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
  676. ]
  677. },
  678. "metadata": {},
  679. "output_type": "display_data"
  680. },
  681. {
  682. "name": "stdout",
  683. "output_type": "stream",
  684. "text": [
  685. "\r",
  686. "Evaluate data in 15.89 seconds!\n",
  687. "\r",
  688. "Evaluation on dev at Epoch 1/3. Step:1200/3600: \n",
  689. "\r",
  690. "AccuracyMetric: acc=0.9\n",
  691. "\n"
  692. ]
  693. },
  694. {
  695. "data": {
  696. "text/plain": [
  697. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
  698. ]
  699. },
  700. "metadata": {},
  701. "output_type": "display_data"
  702. },
  703. {
  704. "name": "stdout",
  705. "output_type": "stream",
  706. "text": [
  707. "\r",
  708. "Evaluate data in 15.92 seconds!\n",
  709. "\r",
  710. "Evaluation on dev at Epoch 2/3. Step:2400/3600: \n",
  711. "\r",
  712. "AccuracyMetric: acc=0.904167\n",
  713. "\n"
  714. ]
  715. },
  716. {
  717. "data": {
  718. "text/plain": [
  719. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
  720. ]
  721. },
  722. "metadata": {},
  723. "output_type": "display_data"
  724. },
  725. {
  726. "name": "stdout",
  727. "output_type": "stream",
  728. "text": [
  729. "\r",
  730. "Evaluate data in 15.91 seconds!\n",
  731. "\r",
  732. "Evaluation on dev at Epoch 3/3. Step:3600/3600: \n",
  733. "\r",
  734. "AccuracyMetric: acc=0.918333\n",
  735. "\n",
  736. "\r\n",
  737. "In Epoch:3/Step:3600, got best dev performance:\n",
  738. "AccuracyMetric: acc=0.918333\n",
  739. "Reloaded the best model.\n",
  740. "Performance on test is:\n"
  741. ]
  742. },
  743. {
  744. "data": {
  745. "text/plain": [
  746. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…"
  747. ]
  748. },
  749. "metadata": {},
  750. "output_type": "display_data"
  751. },
  752. {
  753. "name": "stdout",
  754. "output_type": "stream",
  755. "text": [
  756. "\r",
  757. "Evaluate data in 29.24 seconds!\n",
  758. "[tester] \n",
  759. "AccuracyMetric: acc=0.919167\n"
  760. ]
  761. },
  762. {
  763. "data": {
  764. "text/plain": [
  765. "{'AccuracyMetric': {'acc': 0.919167}}"
  766. ]
  767. },
  768. "execution_count": 12,
  769. "metadata": {},
  770. "output_type": "execute_result"
  771. }
  772. ],
  773. "source": [
  774. "# only the Embedding needs to be swapped\n",
  775. "from fastNLP.embeddings import BertEmbedding\n",
  776. "\n",
  777. "# for this demonstration the BERT weights are frozen (requires_grad=False)\n",
  778. "bert_embed = BertEmbedding(char_vocab, model_dir_or_name='cn', auto_truncate=True, requires_grad=False)\n",
  779. "model = BiLSTMMaxPoolCls(bert_embed, len(data_bundle.get_vocab('target')), )\n",
  780. "\n",
  781. "\n",
  782. "import torch\n",
  783. "from fastNLP import Trainer\n",
  784. "from fastNLP import CrossEntropyLoss\n",
  785. "from torch.optim import Adam\n",
  786. "from fastNLP import AccuracyMetric\n",
  787. "\n",
  788. "loss = CrossEntropyLoss()\n",
  789. "optimizer = Adam(model.parameters(), lr=2e-5)\n",
  790. "metric = AccuracyMetric()\n",
  791. "device = 0 if torch.cuda.is_available() else 'cpu' # run on the GPU when one is available; training is faster there\n",
  792. "\n",
  793. "trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n",
  794. "                  optimizer=optimizer, batch_size=16, dev_data=data_bundle.get_dataset('dev'),\n",
  795. "                  metrics=metric, device=device, n_epochs=3)\n",
  796. "trainer.train() # start training; by default the model that performed best on dev is reloaded when training finishes\n",
  797. "\n",
  798. "# evaluate the model on the test set\n",
  799. "from fastNLP import Tester\n",
  800. "print(\"Performance on test is:\")\n",
  801. "tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
  802. "tester.test()"
  803. ]
  804. },
  805. {
  806. "cell_type": "code",
  807. "execution_count": null,
  808. "metadata": {},
  809. "outputs": [],
  810. "source": []
  811. }
  812. ],
  813. "metadata": {
  814. "kernelspec": {
  815. "display_name": "Python 3",
  816. "language": "python",
  817. "name": "python3"
  818. },
  819. "language_info": {
  820. "codemirror_mode": {
  821. "name": "ipython",
  822. "version": 3
  823. },
  824. "file_extension": ".py",
  825. "mimetype": "text/x-python",
  826. "name": "python",
  827. "nbconvert_exporter": "python",
  828. "pygments_lexer": "ipython3",
  829. "version": "3.6.7"
  830. }
  831. },
  832. "nbformat": 4,
  833. "nbformat_minor": 2
  834. }
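
Step (2) of the notebook describes preprocessing as splitting text into characters and converting them to indices, with the results stored in the chars and seq_len fields. The sketch below illustrates that idea in plain Python, without fastNLP. It is not how ChnSentiCorpPipe or Vocabulary are implemented: the helper names (`build_char_vocab`, `encode`, `pad_batch`), the `<pad>`/`<unk>` tokens, and the frequency-based ordering are all assumptions made for illustration.

```python
from collections import Counter

PAD, UNK = "<pad>", "<unk>"  # assumed special tokens, not taken from fastNLP


def build_char_vocab(texts):
    """Map each character to an index, most frequent characters first."""
    counts = Counter(ch for text in texts for ch in text)
    # Sort by descending count, then by character, so the result is deterministic.
    ordered = sorted(counts, key=lambda ch: (-counts[ch], ch))
    return {ch: i for i, ch in enumerate([PAD, UNK] + ordered)}


def encode(text, vocab):
    """Split a string into characters and look up the index of each one."""
    return [vocab.get(ch, vocab[UNK]) for ch in text]


def pad_batch(sequences, pad_index=0):
    """Right-pad index lists to a common length; also return the true lengths."""
    seq_len = [len(seq) for seq in sequences]
    max_len = max(seq_len)
    padded = [seq + [pad_index] * (max_len - len(seq)) for seq in sequences]
    return padded, seq_len


texts = ["房间很大", "床很大"]
vocab = build_char_vocab(texts)
chars = [encode(t, vocab) for t in texts]
batch, seq_len = pad_batch(chars, pad_index=vocab[PAD])
print(batch, seq_len)
```

The padded batch and the list of true lengths play the same roles as the `chars` tensor of shape `[batch_size, max_len]` and the `seq_len` tensor that the notebook's `forward(self, chars, seq_len)` receives.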