import torch.nn as nn
# from pathes import *
from load_data import load_ontonotes4ner, equip_chinese_ner_with_skip, load_yangjie_rich_pretrain_word_list, \
    load_resume_ner, load_weibo_ner
from fastNLP.embeddings import StaticEmbedding
from models import LatticeLSTM_SeqLabel, LSTM_SeqLabel, LatticeLSTM_SeqLabel_V1
from fastNLP import CrossEntropyLoss, SpanFPreRecMetric, Trainer, AccuracyMetric, LossInForward
import torch.optim as optim
import argparse
import torch
import sys
from utils_ import LatticeLexiconPadder, SpanFPreRecMetric_YJ
from fastNLP import Tester
import fitlog
from fastNLP.core.callback import FitlogCallback
from utils import set_seed
import os
from fastNLP import LRScheduler
from torch.optim.lr_scheduler import LambdaLR

parser = argparse.ArgumentParser()
# NOTE: the boolean options below have no type conversion, so they are intended
# to be left at their defaults; strings passed on the command line are truthy.
parser.add_argument('--device', default='cuda:4')
parser.add_argument('--debug', default=False)
parser.add_argument('--norm_embed', default=True)
parser.add_argument('--batch', type=int, default=10)
parser.add_argument('--test_batch', type=int, default=1024)
parser.add_argument('--optim', default='sgd', help='adam|sgd')
parser.add_argument('--lr', type=float, default=0.045)
parser.add_argument('--model', default='lattice', help='lattice|lstm')
parser.add_argument('--skip_before_head', default=False)  # in paper it's false
parser.add_argument('--hidden', type=int, default=100)
parser.add_argument('--momentum', type=float, default=0)
parser.add_argument('--bi', default=True)
parser.add_argument('--dataset', default='ontonote', help='resume|ontonote|weibo|msra')
parser.add_argument('--use_bigram', default=True)
parser.add_argument('--embed_dropout', type=float, default=0.5)
parser.add_argument('--output_dropout', type=float, default=0.5)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--seed', type=int, default=100)

args = parser.parse_args()

set_seed(args.seed)

# Record the run configuration with fitlog.
fit_msg_list = [args.model, 'bi' if args.bi else 'uni', str(args.batch)]
if args.model == 'lattice':
    fit_msg_list.append(str(args.skip_before_head))
fit_msg = ' '.join(fit_msg_list)
fitlog.commit(__file__, fit_msg=fit_msg)
fitlog.add_hyper(args)

device = torch.device(args.device)
for k, v in args.__dict__.items():
    print(k, v)

refresh_data = False

from pathes import *
# ontonote4ner_cn_path = 0
# yangjie_rich_pretrain_unigram_path = 0
# yangjie_rich_pretrain_bigram_path = 0
# resume_ner_path = 0
# weibo_ner_path = 0

# Load the chosen dataset together with its vocabularies and embeddings.
if args.dataset == 'ontonote':
    datasets, vocabs, embeddings = load_ontonotes4ner(ontonote4ner_cn_path,
                                                      yangjie_rich_pretrain_unigram_path,
                                                      yangjie_rich_pretrain_bigram_path,
                                                      _refresh=refresh_data, index_token=False)
elif args.dataset == 'resume':
    datasets, vocabs, embeddings = load_resume_ner(resume_ner_path,
                                                   yangjie_rich_pretrain_unigram_path,
                                                   yangjie_rich_pretrain_bigram_path,
                                                   _refresh=refresh_data, index_token=False)
elif args.dataset == 'weibo':
    datasets, vocabs, embeddings = load_weibo_ner(weibo_ner_path,
                                                  yangjie_rich_pretrain_unigram_path,
                                                  yangjie_rich_pretrain_bigram_path,
                                                  _refresh=refresh_data, index_token=False)

# Dataset-specific hyper-parameter overrides.
if args.dataset == 'ontonote':
    args.batch = 10
    args.lr = 0.045
elif args.dataset == 'resume':
    args.batch = 1
    args.lr = 0.015
elif args.dataset == 'weibo':
    args.embed_dropout = 0.1
    args.output_dropout = 0.1

# Lexicon word list from the pretrained word embeddings, used to build the lattice skips.
w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path,
                                              _refresh=refresh_data)

cache_name = os.path.join('cache', args.dataset + '_lattice')
datasets, vocabs, embeddings = equip_chinese_ner_with_skip(datasets, vocabs, embeddings, w_list,
                                                           yangjie_rich_pretrain_word_path,
                                                           _refresh=refresh_data, _cache_fp=cache_name)

print(datasets['train'][0])
print('vocab info:')
for k, v in vocabs.items():
    print('{}:{}'.format(k, len(v)))

# Configure padders, input fields and target fields for every dataset split.
for k, v in datasets.items():
    if args.model == 'lattice':
        v.set_ignore_type('skips_l2r_word', 'skips_l2r_source', 'skips_r2l_word', 'skips_r2l_source')
        if args.skip_before_head:
            v.set_padder('skips_l2r_word', LatticeLexiconPadder())
            v.set_padder('skips_l2r_source', LatticeLexiconPadder())
            v.set_padder('skips_r2l_word', LatticeLexiconPadder())
            v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True))
        else:
            v.set_padder('skips_l2r_word', LatticeLexiconPadder())
            v.set_padder('skips_r2l_word', LatticeLexiconPadder())
            v.set_padder('skips_l2r_source', LatticeLexiconPadder(-1))
            v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True, dynamic_offset=1))
        if args.bi:
            v.set_input('chars', 'bigrams', 'seq_len',
                        'skips_l2r_word', 'skips_l2r_source', 'lexicon_count',
                        'skips_r2l_word', 'skips_r2l_source', 'lexicon_count_back',
                        'target',
                        use_1st_ins_infer_dim_type=True)
        else:
            v.set_input('chars', 'bigrams', 'seq_len',
                        'skips_l2r_word', 'skips_l2r_source', 'lexicon_count',
                        'target',
                        use_1st_ins_infer_dim_type=True)
        v.set_target('target', 'seq_len')
        v['target'].set_pad_val(0)
    elif args.model == 'lstm':
        v.set_ignore_type('skips_l2r_word', 'skips_l2r_source')
        v.set_padder('skips_l2r_word', LatticeLexiconPadder())
        v.set_padder('skips_l2r_source', LatticeLexiconPadder())
        v.set_input('chars', 'bigrams', 'seq_len', 'target',
                    use_1st_ins_infer_dim_type=True)
        v.set_target('target', 'seq_len')
        v['target'].set_pad_val(0)

print(datasets['dev']['skips_l2r_word'][100])

# Build the model.
if args.model == 'lattice':
    model = LatticeLSTM_SeqLabel_V1(embeddings['char'], embeddings['bigram'], embeddings['word'],
                                    hidden_size=args.hidden, label_size=len(vocabs['label']), device=args.device,
                                    embed_dropout=args.embed_dropout, output_dropout=args.output_dropout,
                                    skip_batch_first=True, bidirectional=args.bi, debug=args.debug,
                                    skip_before_head=args.skip_before_head, use_bigram=args.use_bigram)
elif args.model == 'lstm':
    model = LSTM_SeqLabel(embeddings['char'], embeddings['bigram'], embeddings['word'],
                          hidden_size=args.hidden, label_size=len(vocabs['label']), device=args.device,
                          bidirectional=args.bi,
                          embed_dropout=args.embed_dropout, output_dropout=args.output_dropout,
                          use_bigram=args.use_bigram)

# Loss and evaluation metrics.
loss = LossInForward()

f1_metric = SpanFPreRecMetric(vocabs['label'], pred='pred', target='target', seq_len='seq_len',
                              encoding_type='bmeso')
f1_metric_yj = SpanFPreRecMetric_YJ(vocabs['label'], pred='pred', target='target', seq_len='seq_len',
                                    encoding_type='bmesoyj')
acc_metric = AccuracyMetric(pred='pred', target='target', seq_len='seq_len')
metrics = [f1_metric, f1_metric_yj, acc_metric]

# Optimizer.
if args.optim == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
elif args.optim == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

# Evaluate on the test/train sets via fitlog each epoch and decay the learning rate per epoch.
callbacks = [
    FitlogCallback({'test': datasets['test'], 'train': datasets['train']}),
    LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.03) ** ep)),
]

trainer = Trainer(datasets['train'], model,
                  optimizer=optimizer,
                  loss=loss,
                  metrics=metrics,
                  dev_data=datasets['dev'],
                  device=device,
                  batch_size=args.batch,
                  n_epochs=args.epoch,
                  dev_batch_size=args.test_batch,
                  callbacks=callbacks)

trainer.train()
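# Example invocations (a sketch only; the file name main.py is assumed, and pathes.py
# must define the dataset/embedding paths imported above):
#   python main.py --dataset resume --model lattice --device cuda:0
#   python main.py --dataset weibo --model lstm --optim adam --lr 0.005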