You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

fastnlp_1_minute_tutorial.ipynb 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {
  6. "collapsed": true
  7. },
  8. "source": [
  9. "# FastNLP 1分钟上手教程"
  10. ]
  11. },
  12. {
  13. "cell_type": "markdown",
  14. "metadata": {},
  15. "source": [
  16. "## step 1\n",
  17. "读取数据集"
  18. ]
  19. },
  20. {
  21. "cell_type": "code",
  22. "execution_count": 3,
  23. "metadata": {},
  24. "outputs": [
  25. {
  26. "name": "stderr",
  27. "output_type": "stream",
  28. "text": [
  29. "/Users/yh/miniconda2/envs/python3/lib/python3.6/site-packages/tqdm/autonotebook/__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
  30. " \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n"
  31. ]
  32. }
  33. ],
  34. "source": [
  35. "import sys\n",
  36. "sys.path.append(\"../\")\n",
  37. "\n",
  38. "from fastNLP import DataSet\n",
  39. "\n",
  40. "# linux_path = \"../test/data_for_tests/tutorial_sample_dataset.csv\"\n",
  41. "win_path = \"../test/data_for_tests/tutorial_sample_dataset.csv\"\n",
  42. "ds = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\\t')"
  43. ]
  44. },
  45. {
  46. "cell_type": "code",
  47. "execution_count": 8,
  48. "metadata": {},
  49. "outputs": [
  50. {
  51. "data": {
  52. "text/plain": [
  53. "{'raw_sentence': this quiet , introspective and entertaining independent is worth seeking .,\n",
  54. "'label': 4,\n",
  55. "'label_seq': 4,\n",
  56. "'words': ['this', 'quiet', ',', 'introspective', 'and', 'entertaining', 'independent', 'is', 'worth', 'seeking', '.']}"
  57. ]
  58. },
  59. "execution_count": 8,
  60. "metadata": {},
  61. "output_type": "execute_result"
  62. }
  63. ],
  64. "source": [
  65. "ds[1]"
  66. ]
  67. },
  68. {
  69. "cell_type": "markdown",
  70. "metadata": {},
  71. "source": [
  72. "## step 2\n",
  73. "数据预处理\n",
  74. "1. 类型转换\n",
  75. "2. 切分验证集\n",
  76. "3. 构建词典"
  77. ]
  78. },
  79. {
  80. "cell_type": "code",
  81. "execution_count": 4,
  82. "metadata": {},
  83. "outputs": [],
  84. "source": [
  85. "# 将所有字母转为小写\n",
  86. "ds.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n",
  87. "# label转int\n",
  88. "ds.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)\n",
  89. "\n",
  90. "def split_sent(ins):\n",
  91. " return ins['raw_sentence'].split()\n",
  92. "ds.apply(split_sent, new_field_name='words', is_input=True)\n"
  93. ]
  94. },
  95. {
  96. "cell_type": "code",
  97. "execution_count": 5,
  98. "metadata": {},
  99. "outputs": [
  100. {
  101. "name": "stdout",
  102. "output_type": "stream",
  103. "text": [
  104. "Train size: 54\n",
  105. "Test size: 23\n"
  106. ]
  107. }
  108. ],
  109. "source": [
  110. "# 分割训练集/验证集\n",
  111. "train_data, dev_data = ds.split(0.3)\n",
  112. "print(\"Train size: \", len(train_data))\n",
  113. "print(\"Test size: \", len(dev_data))"
  114. ]
  115. },
  116. {
  117. "cell_type": "code",
  118. "execution_count": 6,
  119. "metadata": {},
  120. "outputs": [],
  121. "source": [
  122. "from fastNLP import Vocabulary\n",
  123. "vocab = Vocabulary(min_freq=2)\n",
  124. "train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n",
  125. "\n",
  126. "# index句子, Vocabulary.to_index(word)\n",
  127. "train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n",
  128. "dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n"
  129. ]
  130. },
  131. {
  132. "cell_type": "markdown",
  133. "metadata": {},
  134. "source": [
  135. "## step 3\n",
  136. " 定义模型"
  137. ]
  138. },
  139. {
  140. "cell_type": "code",
  141. "execution_count": 62,
  142. "metadata": {},
  143. "outputs": [],
  144. "source": [
  145. "from fastNLP.models import CNNText\n",
  146. "model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n"
  147. ]
  148. },
  149. {
  150. "cell_type": "markdown",
  151. "metadata": {},
  152. "source": [
  153. "## step 4\n",
  154. "开始训练"
  155. ]
  156. },
  157. {
  158. "cell_type": "code",
  159. "execution_count": 63,
  160. "metadata": {},
  161. "outputs": [
  162. {
  163. "name": "stdout",
  164. "output_type": "stream",
  165. "text": [
  166. "training epochs started 2018-12-07 14:03:41\n"
  167. ]
  168. },
  169. {
  170. "data": {
  171. "text/plain": [
  172. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6), HTML(value='')), layout=Layout(display='i…"
  173. ]
  174. },
  175. "execution_count": 0,
  176. "metadata": {},
  177. "output_type": "execute_result"
  178. },
  179. {
  180. "name": "stdout",
  181. "output_type": "stream",
  182. "text": [
  183. "Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.26087\n",
  184. "Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.347826\n",
  185. "Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.608696\n",
  186. "Train finished!\n"
  187. ]
  188. }
  189. ],
  190. "source": [
  191. "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
  192. "trainer = Trainer(model=model, \n",
  193. " train_data=train_data, \n",
  194. " dev_data=dev_data,\n",
  195. " loss=CrossEntropyLoss(),\n",
  196. " metrics=AccuracyMetric()\n",
  197. " )\n",
  198. "trainer.train()\n",
  199. "print('Train finished!')\n"
  200. ]
  201. },
  202. {
  203. "cell_type": "markdown",
  204. "metadata": {},
  205. "source": [
  206. "### 本教程结束。更多操作请参考进阶教程。"
  207. ]
  208. },
  209. {
  210. "cell_type": "code",
  211. "execution_count": null,
  212. "metadata": {},
  213. "outputs": [],
  214. "source": []
  215. }
  216. ],
  217. "metadata": {
  218. "kernelspec": {
  219. "display_name": "Python 3",
  220. "language": "python",
  221. "name": "python3"
  222. },
  223. "language_info": {
  224. "codemirror_mode": {
  225. "name": "ipython",
  226. "version": 3
  227. },
  228. "file_extension": ".py",
  229. "mimetype": "text/x-python",
  230. "name": "python",
  231. "nbconvert_exporter": "python",
  232. "pygments_lexer": "ipython3",
  233. "version": "3.6.7"
  234. }
  235. },
  236. "nbformat": 4,
  237. "nbformat_minor": 1
  238. }