
main.py

import torch.nn as nn
# from pathes import *
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,load_resume_ner,load_weibo_ner
from fastNLP.embeddings import StaticEmbedding
from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1
from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward
import torch.optim as optim
import argparse
import torch
import sys
from utils_ import LatticeLexiconPadder,SpanFPreRecMetric_YJ
from fastNLP import Tester
import fitlog
from fastNLP.core.callback import FitlogCallback
from utils import set_seed
import os
from fastNLP import LRScheduler
from torch.optim.lr_scheduler import LambdaLR
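
# Command-line configuration. Note that no type= is given for the numeric and
# boolean flags, so values overridden on the command line arrive as strings.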
parser = argparse.ArgumentParser()
parser.add_argument('--device',default='cuda:4')
parser.add_argument('--debug',default=False)
parser.add_argument('--norm_embed',default=True)
parser.add_argument('--batch',default=10)
parser.add_argument('--test_batch',default=1024)
parser.add_argument('--optim',default='sgd',help='adam|sgd')
parser.add_argument('--lr',default=0.045)
parser.add_argument('--model',default='lattice',help='lattice|lstm')
parser.add_argument('--skip_before_head',default=False)  # in the paper it's False
parser.add_argument('--hidden',default=100)
parser.add_argument('--momentum',default=0)
parser.add_argument('--bi',default=True)
parser.add_argument('--dataset',default='ontonote',help='resume|ontonote|weibo|msra')
parser.add_argument('--use_bigram',default=True)
parser.add_argument('--embed_dropout',default=0.5)
parser.add_argument('--output_dropout',default=0.5)
parser.add_argument('--epoch',default=100)
parser.add_argument('--seed',default=100)
args = parser.parse_args()
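
# Fix the random seed and register this run (code commit + hyperparameters) with fitlog.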
set_seed(args.seed)

fit_msg_list = [args.model,'bi' if args.bi else 'uni',str(args.batch)]
if args.model == 'lattice':
    fit_msg_list.append(str(args.skip_before_head))
fit_msg = ' '.join(fit_msg_list)
fitlog.commit(__file__,fit_msg=fit_msg)
fitlog.add_hyper(args)

device = torch.device(args.device)
for k,v in args.__dict__.items():
    print(k,v)

refresh_data = False
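
# Dataset and pretrained-embedding paths come from pathes.py; load the corpus selected
# by --dataset together with its unigram/bigram vocabularies and embeddings.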
from pathes import *
# ontonote4ner_cn_path = 0
# yangjie_rich_pretrain_unigram_path = 0
# yangjie_rich_pretrain_bigram_path = 0
# resume_ner_path = 0
# weibo_ner_path = 0

if args.dataset == 'ontonote':
    datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                    _refresh=refresh_data,index_token=False,
                                                    )
elif args.dataset == 'resume':
    datasets,vocabs,embeddings = load_resume_ner(resume_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                 _refresh=refresh_data,index_token=False,
                                                 )
elif args.dataset == 'weibo':
    datasets,vocabs,embeddings = load_weibo_ner(weibo_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                _refresh=refresh_data,index_token=False,
                                                )
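
# Per-dataset overrides of batch size, learning rate and dropout.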
if args.dataset == 'ontonote':
    args.batch = 10
    args.lr = 0.045
elif args.dataset == 'resume':
    args.batch = 1
    args.lr = 0.015
elif args.dataset == 'weibo':
    args.embed_dropout = 0.1
    args.output_dropout = 0.1
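
# Load the pretrained lexicon word list and augment every sentence with the lattice
# "skip" word paths; the processed data is cached under cache/<dataset>_lattice.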
w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path,
                                              _refresh=refresh_data)

cache_name = os.path.join('cache',args.dataset+'_lattice')
datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path,
                                                         _refresh=refresh_data,_cache_fp=cache_name)
print(datasets['train'][0])
print('vocab info:')
for k,v in vocabs.items():
    print('{}:{}'.format(k,len(v)))
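
# Declare model inputs/targets for each split and attach padders for the
# variable-length lattice fields.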
for k,v in datasets.items():
    if args.model == 'lattice':
        v.set_ignore_type('skips_l2r_word','skips_l2r_source','skips_r2l_word','skips_r2l_source')
        if args.skip_before_head:
            v.set_padder('skips_l2r_word',LatticeLexiconPadder())
            v.set_padder('skips_l2r_source',LatticeLexiconPadder())
            v.set_padder('skips_r2l_word',LatticeLexiconPadder())
            v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True))
        else:
            v.set_padder('skips_l2r_word',LatticeLexiconPadder())
            v.set_padder('skips_r2l_word',LatticeLexiconPadder())
            v.set_padder('skips_l2r_source',LatticeLexiconPadder(-1))
            v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True,dynamic_offset=1))
        if args.bi:
            v.set_input('chars','bigrams','seq_len',
                        'skips_l2r_word','skips_l2r_source','lexicon_count',
                        'skips_r2l_word','skips_r2l_source','lexicon_count_back',
                        'target',
                        use_1st_ins_infer_dim_type=True)
        else:
            v.set_input('chars','bigrams','seq_len',
                        'skips_l2r_word','skips_l2r_source','lexicon_count',
                        'target',
                        use_1st_ins_infer_dim_type=True)
        v.set_target('target','seq_len')
        v['target'].set_pad_val(0)
    elif args.model == 'lstm':
        v.set_ignore_type('skips_l2r_word','skips_l2r_source')
        v.set_padder('skips_l2r_word',LatticeLexiconPadder())
        v.set_padder('skips_l2r_source',LatticeLexiconPadder())
        v.set_input('chars','bigrams','seq_len','target',
                    use_1st_ins_infer_dim_type=True)
        v.set_target('target','seq_len')
        v['target'].set_pad_val(0)

print(datasets['dev']['skips_l2r_word'][100])
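
# Build the model: lattice LSTM sequence labeler or a plain (Bi)LSTM baseline.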
if args.model == 'lattice':
    model = LatticeLSTM_SeqLabel_V1(embeddings['char'],embeddings['bigram'],embeddings['word'],
                                    hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
                                    embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
                                    skip_batch_first=True,bidirectional=args.bi,debug=args.debug,
                                    skip_before_head=args.skip_before_head,use_bigram=args.use_bigram
                                    )
elif args.model == 'lstm':
    model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'],
                          hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
                          bidirectional=args.bi,
                          embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
                          use_bigram=args.use_bigram)
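
# The model computes its loss in forward(); evaluate with span F1 under two label
# encodings plus token-level accuracy.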
loss = LossInForward()

f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmeso')
f1_metric_yj = SpanFPreRecMetric_YJ(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmesoyj')
acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len')
metrics = [f1_metric,f1_metric_yj,acc_metric]
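
# Optimizer: plain SGD by default (momentum defaults to 0), or Adam.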
if args.optim == 'adam':
    optimizer = optim.Adam(model.parameters(),lr=args.lr)
elif args.optim == 'sgd':
    optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum)
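
# Log metrics on the train and test sets through fitlog, and decay the learning rate
# by a factor of 1/1.03 every epoch.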
callbacks = [
    FitlogCallback({'test':datasets['test'],'train':datasets['train']}),
    LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.03)**ep))
]
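
# Train with fastNLP's Trainer, validating on the dev set during training.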
trainer = Trainer(datasets['train'],model,
                  optimizer=optimizer,
                  loss=loss,
                  metrics=metrics,
                  dev_data=datasets['dev'],
                  device=device,
                  batch_size=args.batch,
                  n_epochs=args.epoch,
                  dev_batch_size=args.test_batch,
                  callbacks=callbacks)

trainer.train()