You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

文本分类.ipynb 34 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "## 文本分类(Text classification)\n",
  8. "文本分类任务是将一句话或一段话划分到某个具体的类别。比如垃圾邮件识别,文本情绪分类等。\n",
  9. "\n",
  10. "Example:\n",
  11. "1,商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!\n",
  12. "\n",
  13. "\n",
  14. "其中开头的1是指这条评论的标签,表示这是正面的情绪。我们将使用到的数据可以通过http://dbcloud.irocn.cn:8989/api/public/dl/dataset/chn_senti_corp.zip 下载并解压,当然也可以通过fastNLP自动下载该数据。\n",
  15. "\n",
  16. "数据中的内容如下图所示。接下来,我们将用fastNLP在这个数据上训练一个分类网络。"
  17. ]
  18. },
  19. {
  20. "cell_type": "markdown",
  21. "metadata": {},
  22. "source": [
  23. "![jupyter](./cn_cls_example.png)"
  24. ]
  25. },
  26. {
  27. "cell_type": "markdown",
  28. "metadata": {},
  29. "source": [
  30. "## 步骤\n",
  31. "一共有以下的几个步骤 \n",
  32. "(1) 读取数据 \n",
  33. "(2) 预处理数据 \n",
  34. "(3) 选择预训练词向量 \n",
  35. "(4) 创建模型 \n",
  36. "(5) 训练模型 "
  37. ]
  38. },
  39. {
  40. "cell_type": "markdown",
  41. "metadata": {},
  42. "source": [
  43. "### (1) 读取数据\n",
  44. "fastNLP提供多种数据的自动下载与自动加载功能,对于这里我们要用到的数据,我们可以用\\ref{Loader}自动下载并加载该数据。更多有关Loader的使用可以参考\\ref{Loader}"
  45. ]
  46. },
  47. {
  48. "cell_type": "code",
  49. "execution_count": 1,
  50. "metadata": {
  51. "collapsed": true
  52. },
  53. "outputs": [],
  54. "source": [
  55. "from fastNLP.io import ChnSentiCorpLoader\n",
  56. "\n",
  57. "loader = ChnSentiCorpLoader() # loader for the ChnSentiCorp Chinese sentiment-classification dataset\n",
  58. "data_dir = loader.download() # downloads the data to the default cache directory and returns that directory\n",
  59. "data_bundle = loader.load(data_dir) # reads the data found under data_dir into a DataBundle"
  60. ]
  61. },
  62. {
  63. "cell_type": "markdown",
  64. "metadata": {},
  65. "source": [
  66. "DataBundle的相关介绍,可以参考\\ref{DataBundle}。我们可以打印该data_bundle的基本信息。"
  67. ]
  68. },
  69. {
  70. "cell_type": "code",
  71. "execution_count": 2,
  72. "metadata": {},
  73. "outputs": [
  74. {
  75. "name": "stdout",
  76. "output_type": "stream",
  77. "text": [
  78. "In total 3 datasets:\n",
  79. "\tdev has 1200 instances.\n",
  80. "\ttrain has 9600 instances.\n",
  81. "\ttest has 1200 instances.\n",
  82. "In total 0 vocabs:\n",
  83. "\n"
  84. ]
  85. }
  86. ],
  87. "source": [
  88. "print(data_bundle) # summary: how many DataSets / Vocabularies the bundle holds and their sizes"
  89. ]
  90. },
  91. {
  92. "cell_type": "markdown",
  93. "metadata": {},
  94. "source": [
  95. "可以看出,该data_bundle中一共含有三个\\ref{DataSet}。通过下面的代码,我们可以查看DataSet的基本情况"
  96. ]
  97. },
  98. {
  99. "cell_type": "code",
  100. "execution_count": 6,
  101. "metadata": {},
  102. "outputs": [
  103. {
  104. "name": "stdout",
  105. "output_type": "stream",
  106. "text": [
  107. "DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n",
  108. "'target': 1 type=str},\n",
  109. "{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n",
  110. "'target': 1 type=str})\n"
  111. ]
  112. }
  113. ],
  114. "source": [
  115. "print(data_bundle.get_dataset('train')[:2]) # show the first two samples of the train set"
  116. ]
  117. },
  118. {
  119. "cell_type": "markdown",
  120. "metadata": {},
  121. "source": [
  122. "### (2) 预处理数据\n",
  123. "在NLP任务中,预处理一般包括: (a)将一整句话切分成汉字或者词; (b)将文本转换为index \n",
  124. "\n",
  125. "fastNLP中也提供了多种数据集的处理类,这里我们直接使用fastNLP的ChnSentiCorpPipe。更多关于Pipe的说明可以参考\\ref{Pipe}。"
  126. ]
  127. },
  128. {
  129. "cell_type": "code",
  130. "execution_count": 3,
  131. "metadata": {
  132. "collapsed": true
  133. },
  134. "outputs": [],
  135. "source": [
  136. "from fastNLP.io import ChnSentiCorpPipe\n",
  137. "\n",
  138. "pipe = ChnSentiCorpPipe() # preprocessing pipeline for ChnSentiCorp (the next print shows the Vocabularies it adds)\n",
  139. "data_bundle = pipe.process(data_bundle) # every Pipe implements process(); it takes a DataBundle and returns a DataBundle"
  140. ]
  141. },
  142. {
  143. "cell_type": "code",
  144. "execution_count": 4,
  145. "metadata": {},
  146. "outputs": [
  147. {
  148. "name": "stdout",
  149. "output_type": "stream",
  150. "text": [
  151. "In total 3 datasets:\n",
  152. "\tdev has 1200 instances.\n",
  153. "\ttrain has 9600 instances.\n",
  154. "\ttest has 1200 instances.\n",
  155. "In total 2 vocabs:\n",
  156. "\tchars has 4409 entries.\n",
  157. "\ttarget has 2 entries.\n",
  158. "\n"
  159. ]
  160. }
  161. ],
  162. "source": [
  163. "print(data_bundle) # print data_bundle again to see what the Pipe changed"
  164. ]
  165. },
  166. {
  167. "cell_type": "markdown",
  168. "metadata": {},
  169. "source": [
  170. "可以看到除了之前已经包含的3个\\ref{DataSet}, 还新增了两个\\ref{Vocabulary}。我们可以打印DataSet中的内容"
  171. ]
  172. },
  173. {
  174. "cell_type": "code",
  175. "execution_count": 5,
  176. "metadata": {},
  177. "outputs": [
  178. {
  179. "name": "stdout",
  180. "output_type": "stream",
  181. "text": [
  182. "DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n",
  183. "'target': 1 type=int,\n",
  184. "'chars': [338, 464, 1400, 784, 468, 739, 3, 289, 151, 21, 5, 88, 143, 2, 9, 81, 134, 2573, 766, 233, 196, 23, 536, 342, 297, 2, 405, 698, 132, 281, 74, 744, 1048, 74, 420, 387, 74, 412, 433, 74, 2021, 180, 8, 219, 1929, 213, 4, 34, 31, 96, 363, 8, 230, 2, 66, 18, 229, 331, 768, 4, 11, 1094, 479, 17, 35, 593, 3, 1126, 967, 2, 151, 245, 12, 44, 2, 6, 52, 260, 263, 635, 5, 152, 162, 4, 11, 336, 3, 154, 132, 5, 236, 443, 3, 2, 18, 229, 761, 700, 4, 11, 48, 59, 653, 2, 8, 230] type=list,\n",
  185. "'seq_len': 106 type=int},\n",
  186. "{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n",
  187. "'target': 1 type=int,\n",
  188. "'chars': [50, 133, 20, 135, 945, 520, 343, 24, 3, 301, 176, 350, 86, 785, 2, 456, 24, 461, 163, 443, 128, 109, 6, 47, 7, 2, 916, 152, 162, 524, 296, 44, 301, 176, 2, 1384, 524, 296, 259, 88, 143, 2, 92, 67, 26, 12, 277, 269, 2, 188, 223, 26, 228, 83, 6, 63] type=list,\n",
  189. "'seq_len': 56 type=int})\n"
  190. ]
  191. }
  192. ],
  193. "source": [
  194. "print(data_bundle.get_dataset('train')[:2]) # first two train samples, now including the fields added by the Pipe"
  195. ]
  196. },
  197. {
  198. "cell_type": "markdown",
  199. "metadata": {},
  200. "source": [
  201. "新增了一列由数字列表构成的chars(以及记录句子长度的seq_len列),同时target列变为了数字。可以看出chars和target这两列的名称刚好与data_bundle中两个Vocabulary的名称一致,我们可以打印一下Vocabulary看看里面的内容。"
  202. ]
  203. },
  204. {
  205. "cell_type": "code",
  206. "execution_count": 6,
  207. "metadata": {},
  208. "outputs": [
  209. {
  210. "name": "stdout",
  211. "output_type": "stream",
  212. "text": [
  213. "Vocabulary(['选', '择', '珠', '江', '花']...)\n"
  214. ]
  215. }
  216. ],
  217. "source": [
  218. "char_vocab = data_bundle.get_vocab('chars') # the Vocabulary the Pipe built for the 'chars' field\n",
  219. "print(char_vocab)"
  220. ]
  221. },
  222. {
  223. "cell_type": "markdown",
  224. "metadata": {},
  225. "source": [
  226. "Vocabulary是一个记录着词语与index之间映射关系的类,比如"
  227. ]
  228. },
  229. {
  230. "cell_type": "code",
  231. "execution_count": 7,
  232. "metadata": {},
  233. "outputs": [
  234. {
  235. "name": "stdout",
  236. "output_type": "stream",
  237. "text": [
  238. "'选'的index是338\n",
  239. "index:338对应的汉字是选\n"
  240. ]
  241. }
  242. ],
  243. "source": [
  244. "index = char_vocab.to_index('选') # character -> index\n",
  245. "print(\"'选'的index是{}\".format(index)) # equals the first index in 'chars' of the first instance printed above\n",
  246. "print(\"index:{}对应的汉字是{}\".format(index, char_vocab.to_word(index))) # index -> character (inverse mapping)"
  247. ]
  248. },
  249. {
  250. "cell_type": "markdown",
  251. "metadata": {},
  252. "source": [
  253. "### (3) 选择预训练词向量 \n",
  254. "由于Word2vec, Glove, Elmo, Bert等预训练模型可以增强模型的性能,所以在训练具体任务前,选择合适的预训练词向量非常重要。在fastNLP中我们提供了多种Embedding使得加载这些预训练模型的过程变得更加便捷。更多关于Embedding的说明可以参考\\ref{Embedding}。这里我们先给出一个使用word2vec中文字向量的示例,之后再给出一个使用Bert进行文本分类的示例。这里使用的预训练词向量为'cn-char-fastnlp-100d',fastNLP将自动下载该embedding至本地缓存。fastNLP支持通过名字指定的Embedding及其说明可以参见\\ref{Embedding}"
  255. ]
  256. },
  257. {
  258. "cell_type": "code",
  259. "execution_count": 8,
  260. "metadata": {},
  261. "outputs": [
  262. {
  263. "name": "stdout",
  264. "output_type": "stream",
  265. "text": [
  266. "Found 4321 out of 4409 words in the pre-training embedding.\n"
  267. ]
  268. }
  269. ],
  270. "source": [
  271. "from fastNLP.embeddings import StaticEmbedding\n",
  272. "\n",
  273. "word2vec_embed = StaticEmbedding(char_vocab, model_dir_or_name='cn-char-fastnlp-100d') # pretrained Chinese character embedding, auto-downloaded to the local cache"
  274. ]
  275. },
  276. {
  277. "cell_type": "markdown",
  278. "metadata": {},
  279. "source": [
  280. "### (4) 创建模型\n",
  281. "这里我们使用的模型结构为: 字向量(Embedding) → 双向LSTM → Dropout → 沿序列长度方向做max pooling → 全连接层,输出每个类别的得分。(TODO: 补充模型结构图)"
  282. ]
  283. },
  284. {
  285. "cell_type": "code",
  286. "execution_count": 9,
  287. "metadata": {
  288. "collapsed": true
  289. },
  290. "outputs": [],
  291. "source": [
  292. "import torch\n",
  293. "from torch import nn\n",
  294. "from fastNLP.modules import LSTM\n",
  295. "\n",
  296. "# Define the model\n",
  297. "class BiLSTMMaxPoolCls(nn.Module):\n",
  298. "    \"\"\"Text classifier: embedding -> BiLSTM -> dropout -> max pooling over time -> linear layer.\n",
  299. "\n",
  300. "    Args:\n",
  301. "        embed: a fastNLP embedding (StaticEmbedding / BertEmbedding in this notebook); must provide embedding_dim.\n",
  302. "        num_classes: number of output classes.\n",
  303. "        hidden_size: total BiLSTM output size; each direction gets hidden_size // 2.\n",
  304. "        num_layers: number of stacked LSTM layers.\n",
  305. "        dropout: dropout probability applied to the LSTM outputs.\n",
  306. "    \"\"\"\n",
  307. "    def __init__(self, embed, num_classes, hidden_size=400, num_layers=1, dropout=0.3):\n",
  308. "        super().__init__()\n",
  309. "        self.embed = embed\n",
  310. "\n",
  311. "        self.lstm = LSTM(self.embed.embedding_dim, hidden_size=hidden_size//2, num_layers=num_layers,\n",
  312. "                         batch_first=True, bidirectional=True)\n",
  313. "        self.dropout_layer = nn.Dropout(dropout)\n",
  314. "        self.fc = nn.Linear(hidden_size, num_classes)\n",
  315. "\n",
  316. "    def forward(self, chars, seq_len):  # parameter names must match DataSet field names; the DataSet has 'chars', so this must be chars\n",
  317. "        # chars: [batch_size, max_len]\n",
  318. "        # seq_len: [batch_size, ]\n",
  319. "        chars = self.embed(chars)\n",
  320. "        outputs, _ = self.lstm(chars, seq_len)  # [batch_size, max_len, hidden_size]\n",
  321. "        outputs = self.dropout_layer(outputs)\n",
  322. "        # LSTM outputs at padded time steps are presumably zero-filled (packed-sequence padding; TODO confirm).\n",
  323. "        # Mask them with -inf so padding can never win the max: otherwise a feature whose real values are all\n",
  324. "        # negative pools to 0, and the result depends on how much padding the rest of the batch forces.\n",
  325. "        positions = torch.arange(outputs.size(1), device=outputs.device)\n",
  326. "        pad_mask = positions.unsqueeze(0) >= seq_len.unsqueeze(1)  # [batch_size, max_len], True at padding\n",
  327. "        outputs = outputs.masked_fill(pad_mask.unsqueeze(-1), float('-inf'))\n",
  328. "        outputs, _ = torch.max(outputs, dim=1)  # [batch_size, hidden_size]\n",
  329. "        outputs = self.fc(outputs)  # [batch_size, num_classes]\n",
  330. "\n",
  331. "        return {'pred': outputs}  # must return a dict; the prediction key should be 'pred'\n",
  332. "\n",
  333. "# Build the model\n",
  334. "model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))"
  320. ]
  321. },
  322. {
  323. "cell_type": "markdown",
  324. "metadata": {},
  325. "source": [
  326. "### (5) 训练模型\n",
  327. "fastNLP提供了Trainer对象来组织训练过程,包括完成loss计算(所以在初始化Trainer的时候需要指定loss类型),梯度更新(所以在初始化Trainer的时候需要提供优化器optimizer)以及在验证集上的性能验证(所以在初始化时需要提供一个Metric)"
  328. ]
  329. },
  330. {
  331. "cell_type": "code",
  332. "execution_count": 10,
  333. "metadata": {},
  334. "outputs": [
  335. {
  336. "name": "stdout",
  337. "output_type": "stream",
  338. "text": [
  339. "input fields after batch(if batch size is 2):\n",
  340. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  341. "\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n",
  342. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  343. "target fields after batch(if batch size is 2):\n",
  344. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  345. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  346. "\n",
  347. "Evaluate data in 0.01 seconds!\n",
  348. "training epochs started 2019-09-03-23-57-10\n"
  349. ]
  350. },
  351. {
  352. "data": {
  353. "text/plain": [
  354. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3000), HTML(value='')), layout=Layout(display…"
  355. ]
  356. },
  357. "metadata": {},
  358. "output_type": "display_data"
  359. },
  360. {
  361. "data": {
  362. "text/plain": [
  363. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  364. ]
  365. },
  366. "metadata": {},
  367. "output_type": "display_data"
  368. },
  369. {
  370. "name": "stdout",
  371. "output_type": "stream",
  372. "text": [
  373. "\r",
  374. "Evaluate data in 0.43 seconds!\n",
  375. "\r",
  376. "Evaluation on dev at Epoch 1/10. Step:300/3000: \n",
  377. "\r",
  378. "AccuracyMetric: acc=0.81\n",
  379. "\n"
  380. ]
  381. },
  382. {
  383. "data": {
  384. "text/plain": [
  385. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  386. ]
  387. },
  388. "metadata": {},
  389. "output_type": "display_data"
  390. },
  391. {
  392. "name": "stdout",
  393. "output_type": "stream",
  394. "text": [
  395. "\r",
  396. "Evaluate data in 0.44 seconds!\n",
  397. "\r",
  398. "Evaluation on dev at Epoch 2/10. Step:600/3000: \n",
  399. "\r",
  400. "AccuracyMetric: acc=0.8675\n",
  401. "\n"
  402. ]
  403. },
  404. {
  405. "data": {
  406. "text/plain": [
  407. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  408. ]
  409. },
  410. "metadata": {},
  411. "output_type": "display_data"
  412. },
  413. {
  414. "name": "stdout",
  415. "output_type": "stream",
  416. "text": [
  417. "\r",
  418. "Evaluate data in 0.44 seconds!\n",
  419. "\r",
  420. "Evaluation on dev at Epoch 3/10. Step:900/3000: \n",
  421. "\r",
  422. "AccuracyMetric: acc=0.878333\n",
  423. "\n"
  424. ]
  425. },
  426. {
  427. "data": {
  428. "text/plain": [
  429. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  430. ]
  431. },
  432. "metadata": {},
  433. "output_type": "display_data"
  434. },
  435. {
  436. "name": "stdout",
  437. "output_type": "stream",
  438. "text": [
  439. "\r",
  440. "Evaluate data in 0.43 seconds!\n",
  441. "\r",
  442. "Evaluation on dev at Epoch 4/10. Step:1200/3000: \n",
  443. "\r",
  444. "AccuracyMetric: acc=0.873333\n",
  445. "\n"
  446. ]
  447. },
  448. {
  449. "data": {
  450. "text/plain": [
  451. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  452. ]
  453. },
  454. "metadata": {},
  455. "output_type": "display_data"
  456. },
  457. {
  458. "name": "stdout",
  459. "output_type": "stream",
  460. "text": [
  461. "\r",
  462. "Evaluate data in 0.44 seconds!\n",
  463. "\r",
  464. "Evaluation on dev at Epoch 5/10. Step:1500/3000: \n",
  465. "\r",
  466. "AccuracyMetric: acc=0.878333\n",
  467. "\n"
  468. ]
  469. },
  470. {
  471. "data": {
  472. "text/plain": [
  473. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  474. ]
  475. },
  476. "metadata": {},
  477. "output_type": "display_data"
  478. },
  479. {
  480. "name": "stdout",
  481. "output_type": "stream",
  482. "text": [
  483. "\r",
  484. "Evaluate data in 0.42 seconds!\n",
  485. "\r",
  486. "Evaluation on dev at Epoch 6/10. Step:1800/3000: \n",
  487. "\r",
  488. "AccuracyMetric: acc=0.895833\n",
  489. "\n"
  490. ]
  491. },
  492. {
  493. "data": {
  494. "text/plain": [
  495. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  496. ]
  497. },
  498. "metadata": {},
  499. "output_type": "display_data"
  500. },
  501. {
  502. "name": "stdout",
  503. "output_type": "stream",
  504. "text": [
  505. "\r",
  506. "Evaluate data in 0.44 seconds!\n",
  507. "\r",
  508. "Evaluation on dev at Epoch 7/10. Step:2100/3000: \n",
  509. "\r",
  510. "AccuracyMetric: acc=0.8975\n",
  511. "\n"
  512. ]
  513. },
  514. {
  515. "data": {
  516. "text/plain": [
  517. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  518. ]
  519. },
  520. "metadata": {},
  521. "output_type": "display_data"
  522. },
  523. {
  524. "name": "stdout",
  525. "output_type": "stream",
  526. "text": [
  527. "\r",
  528. "Evaluate data in 0.43 seconds!\n",
  529. "\r",
  530. "Evaluation on dev at Epoch 8/10. Step:2400/3000: \n",
  531. "\r",
  532. "AccuracyMetric: acc=0.894167\n",
  533. "\n"
  534. ]
  535. },
  536. {
  537. "data": {
  538. "text/plain": [
  539. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  540. ]
  541. },
  542. "metadata": {},
  543. "output_type": "display_data"
  544. },
  545. {
  546. "name": "stdout",
  547. "output_type": "stream",
  548. "text": [
  549. "\r",
  550. "Evaluate data in 0.48 seconds!\n",
  551. "\r",
  552. "Evaluation on dev at Epoch 9/10. Step:2700/3000: \n",
  553. "\r",
  554. "AccuracyMetric: acc=0.8875\n",
  555. "\n"
  556. ]
  557. },
  558. {
  559. "data": {
  560. "text/plain": [
  561. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  562. ]
  563. },
  564. "metadata": {},
  565. "output_type": "display_data"
  566. },
  567. {
  568. "name": "stdout",
  569. "output_type": "stream",
  570. "text": [
  571. "\r",
  572. "Evaluate data in 0.43 seconds!\n",
  573. "\r",
  574. "Evaluation on dev at Epoch 10/10. Step:3000/3000: \n",
  575. "\r",
  576. "AccuracyMetric: acc=0.895833\n",
  577. "\n",
  578. "\r\n",
  579. "In Epoch:7/Step:2100, got best dev performance:\n",
  580. "AccuracyMetric: acc=0.8975\n",
  581. "Reloaded the best model.\n"
  582. ]
  583. },
  584. {
  585. "data": {
  586. "text/plain": [
  587. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…"
  588. ]
  589. },
  590. "metadata": {},
  591. "output_type": "display_data"
  592. },
  593. {
  594. "name": "stdout",
  595. "output_type": "stream",
  596. "text": [
  597. "\r",
  598. "Evaluate data in 0.34 seconds!\n",
  599. "[tester] \n",
  600. "AccuracyMetric: acc=0.8975\n"
  601. ]
  602. },
  603. {
  604. "data": {
  605. "text/plain": [
  606. "{'AccuracyMetric': {'acc': 0.8975}}"
  607. ]
  608. },
  609. "execution_count": 10,
  610. "metadata": {},
  611. "output_type": "execute_result"
  612. }
  613. ],
  614. "source": [
  615. "from fastNLP import Trainer\n",
  616. "from fastNLP import CrossEntropyLoss\n",
  617. "from torch.optim import Adam\n",
  618. "from fastNLP import AccuracyMetric\n",
  619. "\n",
  620. "loss = CrossEntropyLoss() # classification loss computed from the model's 'pred' output and the 'target' field\n",
  621. "optimizer = Adam(model.parameters(), lr=0.001)\n",
  622. "metric = AccuracyMetric() # accuracy, used for the dev-set evaluations and the final test\n",
  623. "device = 0 if torch.cuda.is_available() else 'cpu' # GPU 0 if available (much faster training), otherwise CPU; torch is imported in the model cell above\n",
  624. "\n",
  625. "trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n",
  626. " optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),\n",
  627. " metrics=metric, device=device)\n",
  628. "trainer.train() # start training; when it finishes, the model that scored best on dev is reloaded by default\n",
  629. "\n",
  630. "# Evaluate the model on the test set\n",
  631. "from fastNLP import Tester\n",
  632. "print(\"Performance on test is:\")\n",
  633. "tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
  634. "tester.test()"
  635. ]
  636. },
  637. {
  638. "cell_type": "markdown",
  639. "metadata": {},
  640. "source": [
  641. "### 使用Bert进行文本分类"
  642. ]
  643. },
  644. {
  645. "cell_type": "code",
  646. "execution_count": 12,
  647. "metadata": {},
  648. "outputs": [
  649. {
  650. "name": "stdout",
  651. "output_type": "stream",
  652. "text": [
  653. "loading vocabulary file /home/yh/.fastNLP/embedding/bert-chinese-wwm/vocab.txt\n",
  654. "Load pre-trained BERT parameters from file /home/yh/.fastNLP/embedding/bert-chinese-wwm/chinese_wwm_pytorch.bin.\n",
  655. "Start to generating word pieces for word.\n",
  656. "Found(Or segment into word pieces) 4286 words out of 4409.\n",
  657. "input fields after batch(if batch size is 2):\n",
  658. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  659. "\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n",
  660. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  661. "target fields after batch(if batch size is 2):\n",
  662. "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  663. "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
  664. "\n",
  665. "Evaluate data in 0.05 seconds!\n",
  666. "training epochs started 2019-09-04-00-02-37\n"
  667. ]
  668. },
  669. {
  670. "data": {
  671. "text/plain": [
  672. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3600), HTML(value='')), layout=Layout(display…"
  673. ]
  674. },
  675. "metadata": {},
  676. "output_type": "display_data"
  677. },
  678. {
  679. "data": {
  680. "text/plain": [
  681. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
  682. ]
  683. },
  684. "metadata": {},
  685. "output_type": "display_data"
  686. },
  687. {
  688. "name": "stdout",
  689. "output_type": "stream",
  690. "text": [
  691. "\r",
  692. "Evaluate data in 15.89 seconds!\n",
  693. "\r",
  694. "Evaluation on dev at Epoch 1/3. Step:1200/3600: \n",
  695. "\r",
  696. "AccuracyMetric: acc=0.9\n",
  697. "\n"
  698. ]
  699. },
  700. {
  701. "data": {
  702. "text/plain": [
  703. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
  704. ]
  705. },
  706. "metadata": {},
  707. "output_type": "display_data"
  708. },
  709. {
  710. "name": "stdout",
  711. "output_type": "stream",
  712. "text": [
  713. "\r",
  714. "Evaluate data in 15.92 seconds!\n",
  715. "\r",
  716. "Evaluation on dev at Epoch 2/3. Step:2400/3600: \n",
  717. "\r",
  718. "AccuracyMetric: acc=0.904167\n",
  719. "\n"
  720. ]
  721. },
  722. {
  723. "data": {
  724. "text/plain": [
  725. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
  726. ]
  727. },
  728. "metadata": {},
  729. "output_type": "display_data"
  730. },
  731. {
  732. "name": "stdout",
  733. "output_type": "stream",
  734. "text": [
  735. "\r",
  736. "Evaluate data in 15.91 seconds!\n",
  737. "\r",
  738. "Evaluation on dev at Epoch 3/3. Step:3600/3600: \n",
  739. "\r",
  740. "AccuracyMetric: acc=0.918333\n",
  741. "\n",
  742. "\r\n",
  743. "In Epoch:3/Step:3600, got best dev performance:\n",
  744. "AccuracyMetric: acc=0.918333\n",
  745. "Reloaded the best model.\n",
  746. "Performance on test is:\n"
  747. ]
  748. },
  749. {
  750. "data": {
  751. "text/plain": [
  752. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…"
  753. ]
  754. },
  755. "metadata": {},
  756. "output_type": "display_data"
  757. },
  758. {
  759. "name": "stdout",
  760. "output_type": "stream",
  761. "text": [
  762. "\r",
  763. "Evaluate data in 29.24 seconds!\n",
  764. "[tester] \n",
  765. "AccuracyMetric: acc=0.919167\n"
  766. ]
  767. },
  768. {
  769. "data": {
  770. "text/plain": [
  771. "{'AccuracyMetric': {'acc': 0.919167}}"
  772. ]
  773. },
  774. "execution_count": 12,
  775. "metadata": {},
  776. "output_type": "execute_result"
  777. }
  778. ],
  779. "source": [
  780. "# Only the embedding changes: the same BiLSTMMaxPoolCls is reused on top of BERT\n",
  781. "import torch\n",
  782. "from torch.optim import Adam\n",
  783. "from fastNLP import AccuracyMetric, CrossEntropyLoss, Tester, Trainer\n",
  784. "from fastNLP.embeddings import BertEmbedding\n",
  785. "\n",
  786. "# For this demo BERT's weights are frozen (requires_grad=False); only the LSTM and the classifier are trained\n",
  787. "bert_embed = BertEmbedding(char_vocab, model_dir_or_name='cn', auto_truncate=True, requires_grad=False)\n",
  788. "model = BiLSTMMaxPoolCls(bert_embed, len(data_bundle.get_vocab('target')))\n",
  789. "\n",
  790. "loss = CrossEntropyLoss()\n",
  791. "optimizer = Adam(model.parameters(), lr=2e-5)\n",
  792. "metric = AccuracyMetric()\n",
  793. "device = 0 if torch.cuda.is_available() else 'cpu' # GPU 0 if available (much faster training), otherwise CPU\n",
  794. "\n",
  795. "# Model selection must use the dev set: passing the test set here would pick the \"best\" model on test\n",
  796. "# and make the test accuracy reported below optimistically biased (test-set leakage).\n",
  797. "trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss,\n",
  798. "                  optimizer=optimizer, batch_size=16, dev_data=data_bundle.get_dataset('dev'),\n",
  799. "                  metrics=metric, device=device, n_epochs=3)\n",
  800. "trainer.train() # start training; when it finishes, the model that scored best on dev is reloaded by default\n",
  801. "\n",
  802. "# Evaluate the selected model on the held-out test set\n",
  803. "print(\"Performance on test is:\")\n",
  804. "tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
  805. "tester.test()"
  809. ]
  810. },
  811. {
  812. "cell_type": "markdown",
  813. "metadata": {},
  814. "source": [
  815. "### 基于词进行文本分类"
  816. ]
  817. },
  818. {
  819. "cell_type": "markdown",
  820. "metadata": {},
  821. "source": [
  822. "由于中文句子中词与词之间没有显式的边界,一般需要先通过分词器对句子进行分词操作。\n",
  823. "下面的例子演示了如何不基于fastNLP已有的数据读取、预处理代码进行文本分类。"
  824. ]
  825. },
  826. {
  827. "cell_type": "markdown",
  828. "metadata": {},
  829. "source": [
  830. "### (1) 读取数据"
  831. ]
  832. },
  833. {
  834. "cell_type": "markdown",
  835. "metadata": {},
  836. "source": [
  837. "这里我们继续以之前的数据为例,但这次只使用fastNLP下载数据,而不使用fastNLP自带的数据读取代码。"
  838. ]
  839. },
  840. {
  841. "cell_type": "code",
  842. "execution_count": null,
  843. "metadata": {
  844. "collapsed": true
  845. },
  846. "outputs": [],
  847. "source": [
  848. "from fastNLP.io import ChnSentiCorpLoader\n",
  849. "\n",
  850. "loader = ChnSentiCorpLoader() # loader for the ChnSentiCorp Chinese sentiment-classification dataset; only download() is used here\n",
  851. "data_dir = loader.download() # downloads the data to the default cache directory and returns that directory"
  852. ]
  853. },
  854. {
  855. "cell_type": "markdown",
  856. "metadata": {},
  857. "source": [
  858. "下面我们先定义一个read_file_to_dataset的函数, 即给定一个文件路径,读取其中的内容,并返回一个DataSet。然后我们将所有的DataSet放入到DataBundle对象中来方便接下来的预处理"
  859. ]
  860. },
  861. {
  862. "cell_type": "code",
  863. "execution_count": null,
  864. "metadata": {
  865. "collapsed": true
  866. },
  867. "outputs": [],
  868. "source": [
  869. "import os\n",
  870. "from fastNLP import DataSet, Instance\n",
  871. "from fastNLP.io import DataBundle\n",
  872. "\n",
  873. "\n",
  874. "def read_file_to_dataset(fp):\n",
  875. "    \"\"\"Read a TSV file (one header line, then `label<TAB>text` rows) into a DataSet.\n",
  876. "\n",
  877. "    Each row becomes an Instance with 'target' (the label, kept as a string) and 'raw_chars' (the text).\n",
  878. "    \"\"\"\n",
  879. "    ds = DataSet()\n",
  880. "    # Explicit encoding: otherwise open() uses the platform default (e.g. GBK on Chinese Windows); data assumed utf-8\n",
  881. "    with open(fp, 'r', encoding='utf-8') as f:\n",
  882. "        f.readline()  # the first line is the column header; skip it\n",
  883. "        for line in f:\n",
  884. "            line = line.strip()\n",
  885. "            if not line:  # tolerate blank lines, e.g. an empty last line\n",
  886. "                continue\n",
  887. "            # maxsplit=1 so a tab inside the review text cannot break the unpacking\n",
  888. "            target, chars = line.split('\\t', 1)\n",
  889. "            ds.append(Instance(target=target, raw_chars=chars))\n",
  890. "    return ds\n",
  891. "\n",
  892. "\n",
  893. "data_bundle = DataBundle()\n",
  894. "for name in ['train.tsv', 'dev.tsv', 'test.tsv']:\n",
  895. "    ds = read_file_to_dataset(os.path.join(data_dir, name))\n",
  896. "    data_bundle.set_dataset(name=name.split('.')[0], dataset=ds)\n",
  897. "\n",
  898. "print(data_bundle)  # check the size of each split; expected:\n",
  899. "# In total 3 datasets:\n",
  900. "# train has 9600 instances.\n",
  901. "# dev has 1200 instances.\n",
  902. "# test has 1200 instances."
  896. ]
  897. },
  898. {
  899. "cell_type": "markdown",
  900. "metadata": {},
  901. "source": [
  902. "### (2) 数据预处理"
  903. ]
  904. },
  905. {
  906. "cell_type": "markdown",
  907. "metadata": {},
  908. "source": [
  909. "在这里,我们首先把句子通过 [fastHan](http://gitee.com/fastnlp/fastHan) 进行分词操作,然后创建词表,并将词语转换为序号。"
  910. ]
  911. },
  912. {
  913. "cell_type": "code",
  914. "execution_count": null,
  915. "metadata": {
  916. "collapsed": true
  917. },
  918. "outputs": [],
  919. "source": [
  920. "from fastHan import FastHan\n",
  921. "from fastNLP import Vocabulary\n",
  922. "\n",
  923. "# Named `segmenter` (not `model`) so that rebinding `model` to the classifier later cannot break word_seg\n",
  924. "segmenter = FastHan()\n",
  925. "\n",
  926. "def word_seg(ins):\n",
  927. "    \"\"\"Segment ins['raw_chars'] with fastHan (target='CWS'); returns the first element of fastHan's result, presumably this sentence's word list.\"\"\"\n",
  928. "    # Some reviews are long, so only the first 128 characters are segmented\n",
  929. "    return segmenter(ins['raw_chars'][:128], target='CWS')[0]\n",
  930. "\n",
  931. "for _, ds in data_bundle.iter_datasets():\n",
  932. "    # apply() runs word_seg on every instance and stores the result in a new 'raw_words' field;\n",
  933. "    # besides apply, fastNLP also offers apply_field and apply_more (which can create several fields at once)\n",
  934. "    ds.apply(word_seg, new_field_name='raw_words')\n",
  935. "\n",
  936. "# Build the word vocabulary on raw_words; non-training splits are recommended to go in no_create_entry_dataset.\n",
  937. "# A Vocabulary can also be filled via add_word(), add_word_lst(), ...; see http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_2_vocabulary.html\n",
  938. "vocab = Vocabulary()\n",
  939. "vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_words',\n",
  940. "                   no_create_entry_dataset=[data_bundle.get_dataset('dev'),\n",
  941. "                                            data_bundle.get_dataset('test')])\n",
  942. "\n",
  943. "# Convert raw_words to indices with this vocabulary and store them in a new 'words' field\n",
  944. "vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'),\n",
  945. "                    data_bundle.get_dataset('test'), field_name='raw_words', new_field_name='words')\n",
  946. "\n",
  947. "# The BiLSTM needs each sequence's true length, so add a 'seq_len' field computed from 'words'\n",
  948. "for _, ds in data_bundle.iter_datasets():\n",
  949. "    ds.add_seq_len('words')\n",
  950. "\n",
  951. "# Label vocabulary: labels need no padding or unknown entries; built from the train split only\n",
  952. "target_vocab = Vocabulary(padding=None, unknown=None)\n",
  953. "target_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='target')\n",
  954. "# Without new_field_name, index_dataset overwrites the original 'target' field\n",
  955. "target_vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'),\n",
  956. "                           data_bundle.get_dataset('test'), field_name='target')\n",
  957. "\n",
  958. "# Store both vocabularies in data_bundle for later use\n",
  959. "data_bundle.set_vocab(field_name='words', vocab=vocab)\n",
  960. "data_bundle.set_vocab(field_name='target', vocab=target_vocab)\n",
  961. "\n",
  962. "# Only fields marked as input/target are fetched (and padded automatically) in the training loop; see\n",
  963. "# http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_6_datasetiter.html\n",
  964. "data_bundle.set_target('target')\n",
  965. "data_bundle.set_input('words', 'seq_len')  # DataSet has the same two methods\n",
  966. "# For fields that should be input/target but need no automatic padding or a custom padding, see\n",
  967. "# http://www.fastnlp.top/docs/fastNLP/fastNLP.core.dataset.html\n",
  968. "\n",
  969. "print(data_bundle.get_dataset('train')[:2])  # inspect the current contents of the dataset"
  969. ]
  970. },
  971. {
  972. "cell_type": "markdown",
  973. "metadata": {},
  974. "source": [
  975. "### (3) 选择预训练词向量"
  976. ]
  977. },
  978. {
  979. "cell_type": "markdown",
  980. "metadata": {},
  981. "source": [
  982. "这里我们选择腾讯的预训练中文词向量,可以在 [腾讯词向量](https://ai.tencent.com/ailab/nlp/en/embedding.html) 处下载并解压。这里我们不能直接使用BERT,因为BERT是基于中文字进行预训练的。"
  983. ]
  984. },
  985. {
  986. "cell_type": "code",
  987. "execution_count": null,
  988. "metadata": {
  989. "collapsed": true
  990. },
  991. "outputs": [],
  992. "source": [
  993. "from fastNLP.embeddings import StaticEmbedding\n",
  994. "\n",
  995. "word2vec_embed = StaticEmbedding(data_bundle.get_vocab('words'), \n",
  996. " model_dir_or_name='/path/to/Tencent_AILab_ChineseEmbedding.txt') # placeholder: point this at the downloaded and unzipped Tencent embedding file"
  997. ]
  998. },
  999. {
  1000. "cell_type": "code",
  1001. "execution_count": null,
  1002. "metadata": {
  1003. "collapsed": true
  1004. },
  1005. "outputs": [],
  1006. "source": [
  1007. "import torch\n",
  1008. "from torch.optim import Adam\n",
  1009. "from fastNLP import AccuracyMetric, CrossEntropyLoss, Tester, Trainer\n",
  1010. "\n",
  1011. "\n",
  1012. "class BiLSTMMaxPoolWordCls(BiLSTMMaxPoolCls):\n",
  1013. "    \"\"\"BiLSTMMaxPoolCls that reads the 'words' field instead of 'chars'.\n",
  1014. "\n",
  1015. "    fastNLP passes input fields to forward() by parameter name; this DataSet has 'words', not 'chars',\n",
  1016. "    so the base class's forward(chars, seq_len) would never receive its input.\n",
  1017. "    \"\"\"\n",
  1018. "    def forward(self, words, seq_len):\n",
  1019. "        return super().forward(words, seq_len)\n",
  1020. "\n",
  1021. "\n",
  1022. "# Make sure the length field the BiLSTM needs exists and is an input (recomputing it is harmless)\n",
  1023. "for _, ds in data_bundle.iter_datasets():\n",
  1024. "    ds.add_seq_len('words')\n",
  1025. "data_bundle.set_input('words', 'seq_len')\n",
  1026. "\n",
  1027. "# Build the model\n",
  1028. "model = BiLSTMMaxPoolWordCls(word2vec_embed, len(data_bundle.get_vocab('target')))\n",
  1029. "\n",
  1030. "# Train\n",
  1031. "loss = CrossEntropyLoss()\n",
  1032. "optimizer = Adam(model.parameters(), lr=0.001)\n",
  1033. "metric = AccuracyMetric()\n",
  1034. "device = 0 if torch.cuda.is_available() else 'cpu' # GPU 0 if available (much faster training), otherwise CPU\n",
  1035. "\n",
  1036. "trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss,\n",
  1037. "                  optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),\n",
  1038. "                  metrics=metric, device=device)\n",
  1039. "trainer.train() # start training; when it finishes, the model that scored best on dev is reloaded by default\n",
  1040. "\n",
  1041. "# Evaluate the model on the test set\n",
  1042. "print(\"Performance on test is:\")\n",
  1043. "tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
  1044. "tester.test()"
  1026. ]
  1027. }
  1028. ],
  1029. "metadata": {
  1030. "kernelspec": {
  1031. "display_name": "Python 3",
  1032. "language": "python",
  1033. "name": "python3"
  1034. },
  1035. "language_info": {
  1036. "codemirror_mode": {
  1037. "name": "ipython",
  1038. "version": 3
  1039. },
  1040. "file_extension": ".py",
  1041. "mimetype": "text/x-python",
  1042. "name": "python",
  1043. "nbconvert_exporter": "python",
  1044. "pygments_lexer": "ipython3",
  1045. "version": "3.6.10"
  1046. }
  1047. },
  1048. "nbformat": 4,
  1049. "nbformat_minor": 2
  1050. }