You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_fastNLP.py 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
# encoding: utf-8
import os
import shutil

from fastNLP.core.preprocess import save_pickle
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.fastnlp import FastNLP
from fastNLP.fastnlp import interpret_word_seg_results, interpret_cws_pos_results
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.models.sequence_modeling import AdvSeqLabel
from fastNLP.saver.model_saver import ModelSaver
# Machine-specific paths to pre-trained pickle files for each task.
# NOTE(review): these absolute paths are only valid on the original author's
# machine; the mocked tests below do not use them.
PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/fastNLP/reproduction/chinese_word_segment/save/"
PATH_TO_POS_TAG_PICKLE_FILES = "/home/zyfeng/data/crf_seg/"
PATH_TO_TEXT_CLASSIFICATION_PICKLE_FILES = "/home/zyfeng/data/text_classify/"

# Special vocabulary entries and their fixed dictionary indices.
DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
DEFAULT_RESERVED_LABEL = ['<reserved-2>',
                          '<reserved-3>',
                          '<reserved-4>']  # dict index = 2~4
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
                         DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
                         DEFAULT_RESERVED_LABEL[2]: 4}
  21. def word_seg(model_dir, config, section):
  22. nlp = FastNLP(model_dir=model_dir)
  23. nlp.load("cws_basic_model", config_file=config, section_name=section)
  24. text = ["这是最好的基于深度学习的中文分词系统。",
  25. "大王叫我来巡山。",
  26. "我党多年来致力于改善人民生活水平。"]
  27. results = nlp.run(text)
  28. print(results)
  29. for example in results:
  30. words, labels = [], []
  31. for res in example:
  32. words.append(res[0])
  33. labels.append(res[1])
  34. print(interpret_word_seg_results(words, labels))
  35. def mock_cws():
  36. os.makedirs("mock", exist_ok=True)
  37. text = ["这是最好的基于深度学习的中文分词系统。",
  38. "大王叫我来巡山。",
  39. "我党多年来致力于改善人民生活水平。"]
  40. word2id = Vocabulary()
  41. word_list = [ch for ch in "".join(text)]
  42. word2id.update(word_list)
  43. save_pickle(word2id, "./mock/", "word2id.pkl")
  44. class2id = Vocabulary(need_default=False)
  45. label_list = ['B', 'M', 'E', 'S']
  46. class2id.update(label_list)
  47. save_pickle(class2id, "./mock/", "class2id.pkl")
  48. model_args = {"vocab_size": len(word2id), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(class2id)}
  49. config_file = """
  50. [test_section]
  51. vocab_size = {}
  52. word_emb_dim = 50
  53. rnn_hidden_units = 50
  54. num_classes = {}
  55. """.format(len(word2id), len(class2id))
  56. with open("mock/test.cfg", "w", encoding="utf-8") as f:
  57. f.write(config_file)
  58. model = AdvSeqLabel(model_args)
  59. ModelSaver("mock/cws_basic_model_v_0.pkl").save_pytorch(model)
  60. def test_word_seg():
  61. # fake the model and pickles
  62. print("start mocking")
  63. mock_cws()
  64. # run the inference codes
  65. print("start testing")
  66. word_seg("./mock/", "test.cfg", "test_section")
  67. # clean up environments
  68. print("clean up")
  69. os.system("rm -rf mock")
  70. def pos_tag(model_dir, config, section):
  71. nlp = FastNLP(model_dir=model_dir)
  72. nlp.load("pos_tag_model", config_file=config, section_name=section)
  73. text = ["这是最好的基于深度学习的中文分词系统。",
  74. "大王叫我来巡山。",
  75. "我党多年来致力于改善人民生活水平。"]
  76. results = nlp.run(text)
  77. for example in results:
  78. words, labels = [], []
  79. for res in example:
  80. words.append(res[0])
  81. labels.append(res[1])
  82. try:
  83. print(interpret_cws_pos_results(words, labels))
  84. except RuntimeError:
  85. print("inconsistent pos tags. this is for test only.")
  86. def mock_pos_tag():
  87. os.makedirs("mock", exist_ok=True)
  88. text = ["这是最好的基于深度学习的中文分词系统。",
  89. "大王叫我来巡山。",
  90. "我党多年来致力于改善人民生活水平。"]
  91. vocab = Vocabulary()
  92. word_list = [ch for ch in "".join(text)]
  93. vocab.update(word_list)
  94. save_pickle(vocab, "./mock/", "word2id.pkl")
  95. idx2label = Vocabulary(need_default=False)
  96. label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv']
  97. idx2label.update(label_list)
  98. save_pickle(idx2label, "./mock/", "class2id.pkl")
  99. model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
  100. config_file = """
  101. [test_section]
  102. vocab_size = {}
  103. word_emb_dim = 50
  104. rnn_hidden_units = 50
  105. num_classes = {}
  106. """.format(len(vocab), len(idx2label))
  107. with open("mock/test.cfg", "w", encoding="utf-8") as f:
  108. f.write(config_file)
  109. model = AdvSeqLabel(model_args)
  110. ModelSaver("mock/pos_tag_model_v_0.pkl").save_pytorch(model)
  111. def test_pos_tag():
  112. mock_pos_tag()
  113. pos_tag("./mock/", "test.cfg", "test_section")
  114. os.system("rm -rf mock")
  115. def text_classify(model_dir, config, section):
  116. nlp = FastNLP(model_dir=model_dir)
  117. nlp.load("text_classify_model", config_file=config, section_name=section)
  118. text = [
  119. "世界物联网大会明日在京召开龙头股启动在即",
  120. "乌鲁木齐市新增一处城市中心旅游目的地",
  121. "朱元璋的大明朝真的源于明教吗?——告诉你一个真实的“明教”"]
  122. results = nlp.run(text)
  123. print(results)
  124. def mock_text_classify():
  125. os.makedirs("mock", exist_ok=True)
  126. text = ["世界物联网大会明日在京召开龙头股启动在即",
  127. "乌鲁木齐市新增一处城市中心旅游目的地",
  128. "朱元璋的大明朝真的源于明教吗?——告诉你一个真实的“明教”"
  129. ]
  130. vocab = Vocabulary()
  131. word_list = [ch for ch in "".join(text)]
  132. vocab.update(word_list)
  133. save_pickle(vocab, "./mock/", "word2id.pkl")
  134. idx2label = Vocabulary(need_default=False)
  135. label_list = ['class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F']
  136. idx2label.update(label_list)
  137. save_pickle(idx2label, "./mock/", "class2id.pkl")
  138. model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
  139. config_file = """
  140. [test_section]
  141. vocab_size = {}
  142. word_emb_dim = 50
  143. rnn_hidden_units = 50
  144. num_classes = {}
  145. """.format(len(vocab), len(idx2label))
  146. with open("mock/test.cfg", "w", encoding="utf-8") as f:
  147. f.write(config_file)
  148. model = CNNText(model_args)
  149. ModelSaver("mock/text_class_model_v0.pkl").save_pytorch(model)
  150. def test_text_classify():
  151. mock_text_classify()
  152. text_classify("./mock/", "test.cfg", "test_section")
  153. os.system("rm -rf mock")
  154. def test_word_seg_interpret():
  155. foo = [[('这', 'S'), ('是', 'S'), ('最', 'S'), ('好', 'S'), ('的', 'S'), ('基', 'B'), ('于', 'E'), ('深', 'B'), ('度', 'E'),
  156. ('学', 'B'), ('习', 'E'), ('的', 'S'), ('中', 'B'), ('文', 'E'), ('分', 'B'), ('词', 'E'), ('系', 'B'), ('统', 'E'),
  157. ('。', 'S')]]
  158. chars = [x[0] for x in foo[0]]
  159. labels = [x[1] for x in foo[0]]
  160. print(interpret_word_seg_results(chars, labels))
  161. def test_interpret_cws_pos_results():
  162. foo = [
  163. [('这', 'S-r'), ('是', 'S-v'), ('最', 'S-d'), ('好', 'S-a'), ('的', 'S-u'), ('基', 'B-p'), ('于', 'E-p'), ('深', 'B-d'),
  164. ('度', 'E-d'), ('学', 'B-v'), ('习', 'E-v'), ('的', 'S-u'), ('中', 'B-nz'), ('文', 'E-nz'), ('分', 'B-vn'),
  165. ('词', 'E-vn'), ('系', 'B-n'), ('统', 'E-n'), ('。', 'S-w')]
  166. ]
  167. chars = [x[0] for x in foo[0]]
  168. labels = [x[1] for x in foo[0]]
  169. print(interpret_cws_pos_results(chars, labels))
if __name__ == "__main__":
    # Run every smoke test in sequence when executed as a script.
    test_word_seg()
    test_pos_tag()
    test_text_classify()
    test_word_seg_interpret()
    test_interpret_cws_pos_results()