You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

run.py 12 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago

  1. import os
  2. import sys
  3. sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
  4. import torch
  5. import re
  6. from fastNLP.core.trainer import Trainer
  7. from fastNLP.core.metrics import Evaluator
  8. from fastNLP.core.instance import Instance
  9. from fastNLP.core.vocabulary import Vocabulary
  10. from fastNLP.core.dataset import DataSet
  11. from fastNLP.core.field import TextField, SeqLabelField
  12. from fastNLP.core.tester import Tester
  13. from fastNLP.io.config_io import ConfigLoader, ConfigSection
  14. from fastNLP.io.model_io import ModelLoader, ModelSaver
  15. from fastNLP.io.embed_loader import EmbedLoader
  16. from fastNLP.models.biaffine_parser import BiaffineParser
  17. BOS = '<BOS>'
  18. EOS = '<EOS>'
  19. UNK = '<OOV>'
  20. NUM = '<NUM>'
  21. ENG = '<ENG>'
  22. # not in the file's dir
  23. if len(os.path.dirname(__file__)) != 0:
  24. os.chdir(os.path.dirname(__file__))
  25. class ConlluDataLoader(object):
  26. def load(self, path):
  27. datalist = []
  28. with open(path, 'r', encoding='utf-8') as f:
  29. sample = []
  30. for line in f:
  31. if line.startswith('\n'):
  32. datalist.append(sample)
  33. sample = []
  34. elif line.startswith('#'):
  35. continue
  36. else:
  37. sample.append(line.split('\t'))
  38. if len(sample) > 0:
  39. datalist.append(sample)
  40. ds = DataSet(name='conll')
  41. for sample in datalist:
  42. # print(sample)
  43. res = self.get_one(sample)
  44. ds.append(Instance(word_seq=TextField(res[0], is_target=False),
  45. pos_seq=TextField(res[1], is_target=False),
  46. head_indices=SeqLabelField(res[2], is_target=True),
  47. head_labels=TextField(res[3], is_target=True)))
  48. return ds
  49. def get_one(self, sample):
  50. text = []
  51. pos_tags = []
  52. heads = []
  53. head_tags = []
  54. for w in sample:
  55. t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
  56. if t3 == '_':
  57. continue
  58. text.append(t1)
  59. pos_tags.append(t2)
  60. heads.append(int(t3))
  61. head_tags.append(t4)
  62. return (text, pos_tags, heads, head_tags)
  63. class CTBDataLoader(object):
  64. def load(self, data_path):
  65. with open(data_path, "r", encoding="utf-8") as f:
  66. lines = f.readlines()
  67. data = self.parse(lines)
  68. return self.convert(data)
  69. def parse(self, lines):
  70. """
  71. [
  72. [word], [pos], [head_index], [head_tag]
  73. ]
  74. """
  75. sample = []
  76. data = []
  77. for i, line in enumerate(lines):
  78. line = line.strip()
  79. if len(line) == 0 or i+1 == len(lines):
  80. data.append(list(map(list, zip(*sample))))
  81. sample = []
  82. else:
  83. sample.append(line.split())
  84. return data
  85. def convert(self, data):
  86. dataset = DataSet()
  87. for sample in data:
  88. word_seq = [BOS] + sample[0] + [EOS]
  89. pos_seq = [BOS] + sample[1] + [EOS]
  90. heads = [0] + list(map(int, sample[2])) + [0]
  91. head_tags = [BOS] + sample[3] + [EOS]
  92. dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
  93. pos_seq=TextField(pos_seq, is_target=False),
  94. gold_heads=SeqLabelField(heads, is_target=False),
  95. head_indices=SeqLabelField(heads, is_target=True),
  96. head_labels=TextField(head_tags, is_target=True)))
  97. return dataset
  98. # datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT"
  99. # datadir = "/home/yfshao/UD_English-EWT"
  100. # train_data_name = "en_ewt-ud-train.conllu"
  101. # dev_data_name = "en_ewt-ud-dev.conllu"
  102. # emb_file_name = '/home/yfshao/glove.6B.100d.txt'
  103. # loader = ConlluDataLoader()
  104. datadir = '/home/yfshao/workdir/parser-data/'
  105. train_data_name = "train_ctb5.txt"
  106. dev_data_name = "dev_ctb5.txt"
  107. test_data_name = "test_ctb5.txt"
  108. emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt"
  109. # emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec"
  110. loader = CTBDataLoader()
  111. cfgfile = './cfg.cfg'
  112. processed_datadir = './save'
  113. # Config Loader
  114. train_args = ConfigSection()
  115. test_args = ConfigSection()
  116. model_args = ConfigSection()
  117. optim_args = ConfigSection()
  118. ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})
  119. print('trainre Args:', train_args.data)
  120. print('test Args:', test_args.data)
  121. print('optim Args:', optim_args.data)
  122. # Pickle Loader
  123. def save_data(dirpath, **kwargs):
  124. import _pickle
  125. if not os.path.exists(dirpath):
  126. os.mkdir(dirpath)
  127. for name, data in kwargs.items():
  128. with open(os.path.join(dirpath, name+'.pkl'), 'wb') as f:
  129. _pickle.dump(data, f)
  130. def load_data(dirpath):
  131. import _pickle
  132. datas = {}
  133. for f_name in os.listdir(dirpath):
  134. if not f_name.endswith('.pkl'):
  135. continue
  136. name = f_name[:-4]
  137. with open(os.path.join(dirpath, f_name), 'rb') as f:
  138. datas[name] = _pickle.load(f)
  139. return datas
  140. def P2(data, field, length):
  141. ds = [ins for ins in data if ins[field].get_length() >= length]
  142. data.clear()
  143. data.extend(ds)
  144. return ds
  145. def P1(data, field):
  146. def reeng(w):
  147. return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG
  148. def renum(w):
  149. return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM
  150. for ins in data:
  151. ori = ins[field].contents()
  152. s = list(map(renum, map(reeng, ori)))
  153. if s != ori:
  154. # print(ori)
  155. # print(s)
  156. # print()
  157. ins[field] = ins[field].new(s)
  158. return data
  159. class ParserEvaluator(Evaluator):
  160. def __init__(self, ignore_label):
  161. super(ParserEvaluator, self).__init__()
  162. self.ignore = ignore_label
  163. def __call__(self, predict_list, truth_list):
  164. head_all, label_all, total_all = 0, 0, 0
  165. for pred, truth in zip(predict_list, truth_list):
  166. head, label, total = self.evaluate(**pred, **truth)
  167. head_all += head
  168. label_all += label
  169. total_all += total
  170. return {'UAS': head_all*1.0 / total_all, 'LAS': label_all*1.0 / total_all}
  171. def evaluate(self, head_pred, label_pred, head_indices, head_labels, seq_mask, **_):
  172. """
  173. Evaluate the performance of prediction.
  174. :return : performance results.
  175. head_pred_corrct: number of correct predicted heads.
  176. label_pred_correct: number of correct predicted labels.
  177. total_tokens: number of predicted tokens
  178. """
  179. seq_mask *= (head_labels != self.ignore).long()
  180. head_pred_correct = (head_pred == head_indices).long() * seq_mask
  181. _, label_preds = torch.max(label_pred, dim=2)
  182. label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
  183. return head_pred_correct.sum().item(), label_pred_correct.sum().item(), seq_mask.sum().item()
  184. try:
  185. data_dict = load_data(processed_datadir)
  186. word_v = data_dict['word_v']
  187. pos_v = data_dict['pos_v']
  188. tag_v = data_dict['tag_v']
  189. train_data = data_dict['train_data']
  190. dev_data = data_dict['dev_data']
  191. test_data = data_dict['test_data']
  192. print('use saved pickles')
  193. except Exception as _:
  194. print('load raw data and preprocess')
  195. # use pretrain embedding
  196. word_v = Vocabulary(need_default=True, min_freq=2)
  197. word_v.unknown_label = UNK
  198. pos_v = Vocabulary(need_default=True)
  199. tag_v = Vocabulary(need_default=False)
  200. train_data = loader.load(os.path.join(datadir, train_data_name))
  201. dev_data = loader.load(os.path.join(datadir, dev_data_name))
  202. test_data = loader.load(os.path.join(datadir, test_data_name))
  203. train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
  204. datasets = (train_data, dev_data, test_data)
  205. save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data)
  206. embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
  207. print(len(word_v))
  208. print(embed.size())
  209. # Model
  210. model_args['word_vocab_size'] = len(word_v)
  211. model_args['pos_vocab_size'] = len(pos_v)
  212. model_args['num_label'] = len(tag_v)
  213. model = BiaffineParser(**model_args.data)
  214. model.reset_parameters()
  215. datasets = (train_data, dev_data, test_data)
  216. for ds in datasets:
  217. ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
  218. ds.set_origin_len('word_seq')
  219. if train_args['use_golden_train']:
  220. train_data.set_target(gold_heads=False)
  221. else:
  222. train_data.set_target(gold_heads=None)
  223. train_args.data.pop('use_golden_train')
  224. ignore_label = pos_v['P']
  225. print(test_data[0])
  226. print(len(train_data))
  227. print(len(dev_data))
  228. print(len(test_data))
  229. def train(path):
  230. # Trainer
  231. trainer = Trainer(**train_args.data)
  232. def _define_optim(obj):
  233. lr = optim_args.data['lr']
  234. embed_params = set(obj._model.word_embedding.parameters())
  235. decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters())
  236. params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params]
  237. obj._optimizer = torch.optim.Adam([
  238. {'params': list(embed_params), 'lr':lr*0.1},
  239. {'params': list(decay_params), **optim_args.data},
  240. {'params': params}
  241. ], lr=lr, betas=(0.9, 0.9))
  242. obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))
  243. def _update(obj):
  244. # torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0)
  245. obj._scheduler.step()
  246. obj._optimizer.step()
  247. trainer.define_optimizer = lambda: _define_optim(trainer)
  248. trainer.update = lambda: _update(trainer)
  249. trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)))
  250. model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
  251. model.word_embedding.padding_idx = word_v.padding_idx
  252. model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
  253. model.pos_embedding.padding_idx = pos_v.padding_idx
  254. model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)
  255. # try:
  256. # ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
  257. # print('model parameter loaded!')
  258. # except Exception as _:
  259. # print("No saved model. Continue.")
  260. # pass
  261. # Start training
  262. trainer.train(model, train_data, dev_data)
  263. print("Training finished!")
  264. # Saver
  265. saver = ModelSaver("./save/saved_model.pkl")
  266. saver.save_pytorch(model)
  267. print("Model saved!")
  268. def test(path):
  269. # Tester
  270. tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))
  271. # Model
  272. model = BiaffineParser(**model_args.data)
  273. model.eval()
  274. try:
  275. ModelLoader.load_pytorch(model, path)
  276. print('model parameter loaded!')
  277. except Exception as _:
  278. print("No saved model. Abort test.")
  279. raise
  280. # Start training
  281. print("Testing Train data")
  282. tester.test(model, train_data)
  283. print("Testing Dev data")
  284. tester.test(model, dev_data)
  285. print("Testing Test data")
  286. tester.test(model, test_data)
  287. if __name__ == "__main__":
  288. import argparse
  289. parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
  290. parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer'])
  291. parser.add_argument('--path', type=str, default='')
  292. args = parser.parse_args()
  293. if args.mode == 'train':
  294. train(args.path)
  295. elif args.mode == 'test':
  296. test(args.path)
  297. elif args.mode == 'infer':
  298. pass
  299. else:
  300. print('no mode specified for model!')
  301. parser.print_help()