
train_context.py 5.6 kB
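
This script trains a BiLSTM-CRF model for Chinese word segmentation with fastNLP. It reads CoNLL-format train/dev/test splits, chains a set of preprocessing processors (character segmentation, BMES tagging, bigram features, vocabulary indexing), trains and evaluates the model, and finally assembles the processors and model into a test pipeline and an inference pipeline that can be serialized with torch.save.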

from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
from fastNLP.api.processor import SeqLenProcessor
from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegProcessor
from reproduction.chinese_word_segment.process.cws_processor import CWSBMESTagProcessor
from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
from reproduction.chinese_word_segment.process.cws_processor import InputTargetProcessor
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
ds_name = 'msr'
tr_filename = '/home/hyan/ctb3/train.conllx'
dev_filename = '/home/hyan/ctb3/dev.conllx'

reader = ConllCWSReader()
tr_dataset = reader.load(tr_filename, cut_long_sent=True)
dev_dataset = reader.load(dev_filename)
print("Train: {}. Dev: {}".format(len(tr_dataset), len(dev_dataset)))
# 1. Prepare the processors
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')  # full-width -> half-width normalization
char_proc = CWSCharSegProcessor('raw_sentence', 'chars_lst')  # split sentences into character lists
tag_proc = CWSBMESTagProcessor('raw_sentence', 'target')  # derive BMES tags as the training target
bigram_proc = Pre2Post2BigramProcessor('chars_lst', 'bigrams_lst')  # character bigram features
char_vocab_proc = VocabIndexerProcessor('chars_lst', new_added_filed_name='chars')
bigram_vocab_proc = VocabIndexerProcessor('bigrams_lst', new_added_filed_name='bigrams', min_freq=4)
seq_len_proc = SeqLenProcessor('chars')
input_target_proc = InputTargetProcessor(input_fields=['chars', 'bigrams', 'seq_lens', 'target'],
                                         target_fields=['target', 'seq_lens'])
# 2. Apply the processors to the training set
fs2hs_proc(tr_dataset)
char_proc(tr_dataset)
tag_proc(tr_dataset)
bigram_proc(tr_dataset)
char_vocab_proc(tr_dataset)
bigram_vocab_proc(tr_dataset)
seq_len_proc(tr_dataset)

# 2.1 Process dev_dataset with the same processors
fs2hs_proc(dev_dataset)
char_proc(dev_dataset)
tag_proc(dev_dataset)
bigram_proc(dev_dataset)
char_vocab_proc(dev_dataset)
bigram_vocab_proc(dev_dataset)
seq_len_proc(dev_dataset)

input_target_proc(tr_dataset)
input_target_proc(dev_dataset)

print("Finished preparing data.")
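# At this point each dataset carries the raw_sentence, chars_lst/chars,
# bigrams_lst/bigrams, seq_lens and target fields, with the input and target
# fields for the Trainer marked by input_target_proc.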
# 3. The datasets are now ready for training
# TODO how should pretrained embeddings be handled?
import torch
from torch import optim

tag_size = tag_proc.tag_size
cws_model = CWSBiLSTMCRF(char_vocab_proc.get_vocab_size(), embed_dim=100,
                         bigram_vocab_num=bigram_vocab_proc.get_vocab_size(),
                         bigram_embed_dim=100, num_bigram_per_char=8,
                         hidden_size=200, bidirectional=True, embed_drop_p=0.2,
                         num_layers=1, tag_size=tag_size)
cws_model.cuda()

num_epochs = 5
optimizer = optim.Adagrad(cws_model.parameters(), lr=0.005)

from fastNLP.core.trainer import Trainer
from fastNLP.core.sampler import BucketSampler
from fastNLP.core.metrics import BMESF1PreRecMetric

metric = BMESF1PreRecMetric(target='tags')
trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=num_epochs,
                  batch_size=32, print_every=50, validate_every=-1, dev_data=dev_dataset, save_path=None,
                  optimizer=optimizer, check_code_level=0, metric_key='f', sampler=BucketSampler(), use_tqdm=True)
trainer.train()
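# Note (fastNLP Trainer semantics): loss=None makes the Trainer use the loss
# returned by the model's own forward pass; validate_every=-1 validates on
# dev_data once per epoch, and metric_key='f' tracks the best model by F1
# (save_path=None, so no checkpoint is written to disk).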
# 4. Assemble the components that need to be saved
pp = Pipeline()
pp.add_processor(fs2hs_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(tag_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_vocab_proc)
pp.add_processor(bigram_vocab_proc)
pp.add_processor(seq_len_proc)
pp.add_processor(input_target_proc)
# te_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
te_filename = '/home/hyan/ctb3/test.conllx'
te_dataset = reader.load(te_filename)
pp(te_dataset)

from fastNLP.core.tester import Tester

tester = Tester(data=te_dataset, model=cws_model, metrics=metric, batch_size=64, use_cuda=False,
                verbose=1)
tester.test()
#
# batch_size = 64
# te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
# pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher, type='bmes')
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
#                                                  pre * 100,
#                                                  rec * 100))
# TODO it seems the test pipeline and the inference pipeline need to be distinguished here
test_context_dict = {'pipeline': pp,
                     'model': cws_model}
# torch.save(test_context_dict, 'models/test_context_crf.pkl')
# 5. Assemble the inference pipeline (note: no tag_proc or input_target_proc,
# since raw input text carries no gold tags)
from fastNLP.api.processor import ModelProcessor
from reproduction.chinese_word_segment.process.cws_processor import BMES2OutputProcessor

model_proc = ModelProcessor(cws_model)  # runs the trained model inside the pipeline
output_proc = BMES2OutputProcessor()  # converts predicted BMES tags back into segmented text

pp = Pipeline()
pp.add_processor(fs2hs_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_vocab_proc)
pp.add_processor(bigram_vocab_proc)
pp.add_processor(seq_len_proc)
pp.add_processor(model_proc)
pp.add_processor(output_proc)
# TODO it seems the test pipeline and the inference pipeline need to be distinguished here
infer_context_dict = {'pipeline': pp}
# torch.save(infer_context_dict, 'models/cws_crf.pkl')

# TODO still need to work out how to map the output back onto the original text:
# 1. special tags that should not be substituted back
# 2. special tags that should be substituted back
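
# A minimal usage sketch (hypothetical; it assumes the torch.save call above
# has been uncommented so that 'models/cws_crf.pkl' exists). A Pipeline is
# applied to a DataSet in place, exactly as pp(te_dataset) is applied above:
#
#   infer_context = torch.load('models/cws_crf.pkl')
#   infer_pp = infer_context['pipeline']
#   new_dataset = reader.load('/path/to/new.conllx')  # hypothetical path
#   infer_pp(new_dataset)  # adds model predictions and the segmented output field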