
fastnlp_1min_tutorial.ipynb 6.3 kB

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# fastNLP 1-Minute Tutorial"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 1\n",
"Read the dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\users\\zyfeng\\miniconda3\\envs\\fastnlp\\lib\\site-packages\\tqdm\\autonotebook\\__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
" \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n"
]
}
],
"source": [
"import sys\n",
"sys.path.append(\"../\")\n",
"\n",
"from fastNLP import DataSet\n",
"\n",
"data_path = \"./sample_data/tutorial_sample_dataset.csv\"\n",
"ds = DataSet.read_csv(data_path, headers=('raw_sentence', 'label'), sep='\\t')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'raw_sentence': This quiet , introspective and entertaining independent is worth seeking . type=str,\n",
"'label': 4 type=str}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds[1]"
]
},
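{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (not part of the original tutorial): `len()` on the `DataSet` gives the number of loaded instances; the same call is used below when printing the train/dev sizes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# How many instances were read from the CSV?\n",
"print(\"Total instances:\", len(ds))"
]
},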
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 2\n",
"Data preprocessing\n",
"1. Type conversion\n",
"2. Split off a validation set\n",
"3. Build the vocabulary"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Lowercase every raw sentence\n",
"ds.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n",
"# Convert the label to int\n",
"ds.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)\n",
"\n",
"def split_sent(ins):\n",
" return ins['raw_sentence'].split()\n",
"ds.apply(split_sent, new_field_name='words', is_input=True)\n"
]
},
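{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `apply` calls above add fields to the dataset in place, so re-inspecting an instance (an added check, not in the original notebook) should now show `label_seq` and `words` alongside the original columns."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The instance should now carry 'label_seq' (int) and 'words' (token list)\n",
"# in addition to 'raw_sentence' and 'label'.\n",
"ds[1]"
]
},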
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train size: 54\n",
"Test size: 23\n"
]
}
],
"source": [
"# Split into training / validation sets\n",
"train_data, dev_data = ds.split(0.3)\n",
"print(\"Train size: \", len(train_data))\n",
"print(\"Test size: \", len(dev_data))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Vocabulary\n",
"vocab = Vocabulary(min_freq=2)\n",
"train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n",
"\n",
"# Index the sentences with Vocabulary.to_index(word)\n",
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n",
"dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n"
]
},
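{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick look at the vocabulary (an added check, not in the original notebook): `len(vocab)` is the same value used below to size the embedding layer, and `vocab.to_index` is the lookup used for indexing above. With `min_freq=2`, rare or unseen words should map to the unknown entry."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the vocabulary built from the training split.\n",
"print(\"Vocabulary size:\", len(vocab))\n",
"print(\"Index of 'the':\", vocab.to_index('the'))"
]
},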
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3\n",
"Define the model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.models import CNNText\n",
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n"
]
},
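{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough size check (added here, and assuming `CNNText` behaves like a regular `torch.nn.Module`, since fastNLP is built on PyTorch), the trainable parameters can be counted with standard PyTorch calls."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rough model-size check; assumes CNNText is a standard torch.nn.Module.\n",
"n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(\"Trainable parameters:\", n_params)"
]
},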
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4\n",
"Start training"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:numpy.ndarray (2)dtype:object, (3)shape:(2,) \n",
"\tword_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 25]) \n",
"target fields after batch(if batch size is 2):\n",
"\tlabel_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2019-01-12 17-00-48\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "23979df0f63e446fbb0406b919b91dd3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6), HTML(value='')), layout=Layout(display='i…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluation at Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.173913\n",
"Evaluation at Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.26087\n",
"Evaluation at Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.304348\n",
"\n",
"In Epoch:3/Step:6, got best dev performance:AccuracyMetric: acc=0.304348\n",
"Reloaded the best model.\n",
"Train finished!\n"
]
}
],
"source": [
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
"trainer = Trainer(model=model, \n",
" train_data=train_data, \n",
" dev_data=dev_data,\n",
" loss=CrossEntropyLoss(),\n",
" metrics=AccuracyMetric()\n",
" )\n",
"trainer.train()\n",
"print('Train finished!')\n"
]
},
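{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since `trainer.train()` reports that the best model was reloaded, it can be re-evaluated on the held-out data. The cell below is a minimal sketch (not from the original tutorial) that assumes fastNLP's `Tester` resolves input/target fields the same way `Trainer` does; pass explicit field names to `AccuracyMetric` if your version requires them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: re-evaluate the reloaded best model on dev_data.\n",
"# Assumes Tester matches input/target fields automatically, like Trainer above.\n",
"from fastNLP import Tester\n",
"tester = Tester(data=dev_data, model=model, metrics=AccuracyMetric())\n",
"tester.test()"
]
},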
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### This concludes the tutorial. For more operations, see the advanced tutorials."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 1
}