You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

fastnlp_1_minute_tutorial.ipynb 6.7 kB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {
  6. "collapsed": true
  7. },
  8. "source": [
  9. "# FastNLP 1分钟上手教程"
  10. ]
  11. },
  12. {
  13. "cell_type": "markdown",
  14. "metadata": {},
  15. "source": [
  16. "## step 1\n",
  17. "读取数据集"
  18. ]
  19. },
  20. {
  21. "cell_type": "code",
  22. "execution_count": 50,
  23. "metadata": {},
  24. "outputs": [],
  25. "source": [
  26. "from fastNLP import DataSet\n",
  27. "# linux_path = \"../test/data_for_tests/tutorial_sample_dataset.csv\"\n",
  28. "win_path = \"C:\\\\Users\\\\zyfeng\\\\Desktop\\\\FudanNLP\\\\fastNLP\\\\test\\\\data_for_tests\\\\tutorial_sample_dataset.csv\"\n",
  29. "ds = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\\t')"
  30. ]
  31. },
  32. {
  33. "cell_type": "markdown",
  34. "metadata": {},
  35. "source": [
  36. "## step 2\n",
  37. "数据预处理\n",
  38. "1. 类型转换\n",
  39. "2. 切分验证集\n",
  40. "3. 构建词典"
  41. ]
  42. },
  43. {
  44. "cell_type": "code",
  45. "execution_count": 52,
  46. "metadata": {},
  47. "outputs": [],
  48. "source": [
  49. "# 将所有字母转为小写\n",
  50. "ds.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n",
  51. "# label转int\n",
  52. "ds.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)\n",
  53. "\n",
  54. "def split_sent(ins):\n",
  55. " return ins['raw_sentence'].split()\n",
  56. "ds.apply(split_sent, new_field_name='words', is_input=True)\n"
  57. ]
  58. },
  59. {
  60. "cell_type": "code",
  61. "execution_count": 60,
  62. "metadata": {
  63. "collapsed": false
  64. },
  65. "outputs": [
  66. {
  67. "name": "stdout",
  68. "output_type": "stream",
  69. "text": [
  70. "Train size: "
  71. ]
  72. },
  73. {
  74. "name": "stdout",
  75. "output_type": "stream",
  76. "text": [
  77. " "
  78. ]
  79. },
  80. {
  81. "name": "stdout",
  82. "output_type": "stream",
  83. "text": [
  84. "54"
  85. ]
  86. },
  87. {
  88. "name": "stdout",
  89. "output_type": "stream",
  90. "text": [
  91. "\n"
  92. ]
  93. },
  94. {
  95. "name": "stdout",
  96. "output_type": "stream",
  97. "text": [
  98. "Test size: "
  99. ]
  100. },
  101. {
  102. "name": "stdout",
  103. "output_type": "stream",
  104. "text": [
  105. " "
  106. ]
  107. },
  108. {
  109. "name": "stdout",
  110. "output_type": "stream",
  111. "text": [
  112. "23"
  113. ]
  114. },
  115. {
  116. "name": "stdout",
  117. "output_type": "stream",
  118. "text": [
  119. "\n"
  120. ]
  121. }
  122. ],
  123. "source": [
  124. "# 分割训练集/验证集\n",
  125. "train_data, dev_data = ds.split(0.3)\n",
  126. "print(\"Train size: \", len(train_data))\n",
  127. "print(\"Test size: \", len(dev_data))"
  128. ]
  129. },
  130. {
  131. "cell_type": "code",
  132. "execution_count": 61,
  133. "metadata": {},
  134. "outputs": [],
  135. "source": [
  136. "from fastNLP import Vocabulary\n",
  137. "vocab = Vocabulary(min_freq=2)\n",
  138. "train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n",
  139. "\n",
  140. "# index句子, Vocabulary.to_index(word)\n",
  141. "train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n",
  142. "dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n"
  143. ]
  144. },
  145. {
  146. "cell_type": "markdown",
  147. "metadata": {},
  148. "source": [
  149. "## step 3\n",
  150. "定义模型"
  151. ]
  152. },
  153. {
  154. "cell_type": "code",
  155. "execution_count": 62,
  156. "metadata": {},
  157. "outputs": [],
  158. "source": [
  159. "from fastNLP.models import CNNText\n",
  160. "model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n"
  161. ]
  162. },
  163. {
  164. "cell_type": "markdown",
  165. "metadata": {},
  166. "source": [
  167. "## step 4\n",
  168. "开始训练"
  169. ]
  170. },
  171. {
  172. "cell_type": "code",
  173. "execution_count": 63,
  174. "metadata": {},
  175. "outputs": [
  176. {
  177. "name": "stdout",
  178. "output_type": "stream",
  179. "text": [
  180. "training epochs started 2018-12-07 14:03:41"
  181. ]
  182. },
  183. {
  184. "name": "stdout",
  185. "output_type": "stream",
  186. "text": [
  187. "\n"
  188. ]
  189. },
  190. {
  191. "data": {
  192. "text/plain": [
  193. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6), HTML(value='')), layout=Layout(display='i…"
  194. ]
  195. },
  196. "execution_count": 0,
  197. "metadata": {},
  198. "output_type": "execute_result"
  199. },
  200. {
  201. "name": "stdout",
  202. "output_type": "stream",
  203. "text": [
  204. "\r"
  205. ]
  206. },
  207. {
  208. "name": "stdout",
  209. "output_type": "stream",
  210. "text": [
  211. "Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.26087"
  212. ]
  213. },
  214. {
  215. "name": "stdout",
  216. "output_type": "stream",
  217. "text": [
  218. "\n"
  219. ]
  220. },
  221. {
  222. "name": "stdout",
  223. "output_type": "stream",
  224. "text": [
  225. "\r"
  226. ]
  227. },
  228. {
  229. "name": "stdout",
  230. "output_type": "stream",
  231. "text": [
  232. "Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.347826"
  233. ]
  234. },
  235. {
  236. "name": "stdout",
  237. "output_type": "stream",
  238. "text": [
  239. "\n"
  240. ]
  241. },
  242. {
  243. "name": "stdout",
  244. "output_type": "stream",
  245. "text": [
  246. "\r"
  247. ]
  248. },
  249. {
  250. "name": "stdout",
  251. "output_type": "stream",
  252. "text": [
  253. "Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.608696"
  254. ]
  255. },
  256. {
  257. "name": "stdout",
  258. "output_type": "stream",
  259. "text": [
  260. "\n"
  261. ]
  262. },
  263. {
  264. "name": "stdout",
  265. "output_type": "stream",
  266. "text": [
  267. "\r"
  268. ]
  269. },
  270. {
  271. "name": "stdout",
  272. "output_type": "stream",
  273. "text": [
  274. "Train finished!"
  275. ]
  276. },
  277. {
  278. "name": "stdout",
  279. "output_type": "stream",
  280. "text": [
  281. "\n"
  282. ]
  283. }
  284. ],
  285. "source": [
  286. "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
  287. "trainer = Trainer(model=model, \n",
  288. " train_data=train_data, \n",
  289. " dev_data=dev_data,\n",
  290. " loss=CrossEntropyLoss(),\n",
  291. " metrics=AccuracyMetric()\n",
  292. " )\n",
  293. "trainer.train()\n",
  294. "print('Train finished!')\n"
  295. ]
  296. },
  297. {
  298. "cell_type": "markdown",
  299. "metadata": {},
  300. "source": [
  301. "### 本教程结束。更多操作请参考进阶教程。"
  302. ]
  303. },
  304. {
  305. "cell_type": "code",
  306. "execution_count": null,
  307. "metadata": {},
  308. "outputs": [],
  309. "source": []
  310. }
  311. ],
  312. "metadata": {
  313. "kernelspec": {
  314. "display_name": "Python 2",
  315. "language": "python",
  316. "name": "python2"
  317. },
  318. "language_info": {
  319. "codemirror_mode": {
  320. "name": "ipython",
  321. "version": 2
  322. },
  323. "file_extension": ".py",
  324. "mimetype": "text/x-python",
  325. "name": "python",
  326. "nbconvert_exporter": "python",
  327. "pygments_lexer": "ipython2",
  328. "version": "2.7.6"
  329. }
  330. },
  331. "nbformat": 4,
  332. "nbformat_minor": 0
  333. }