
run.py 8.2 kB
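Training and evaluation script for the fastNLP BiaffineParser: it loads CoNLL-X data (CTB 9.0 paths are hard-coded), builds vocabularies, indexes the data with fastNLP processors, trains with stepwise learning-rate decay, and saves an inference Pipeline to pipe.pkl.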

import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

import torch  # needed below for torch.tensor / torch.nn.Embedding / torch.save

import fastNLP
from fastNLP.core.trainer import Trainer
from fastNLP.core.dataset import DataSet        # used in convert()
from fastNLP.core.vocabulary import Vocabulary  # used to build word/pos/tag vocabs
from fastNLP.core.instance import Instance
from fastNLP.api.pipeline import Pipeline
from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss
from fastNLP.core.tester import Tester
from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.io.model_io import ModelLoader
from fastNLP.io.dataset_loader import ConllxDataLoader
from fastNLP.api.processor import *
from fastNLP.io.embed_loader import EmbedLoader
from fastNLP.core.callback import Callback

BOS = '<BOS>'
EOS = '<EOS>'
UNK = '<UNK>'
PAD = '<PAD>'
NUM = '<NUM>'
ENG = '<ENG>'

# change into the script's directory if it was invoked from elsewhere
if len(os.path.dirname(__file__)) != 0:
    os.chdir(os.path.dirname(__file__))

def convert(data):
    dataset = DataSet()
    for sample in data:
        # prepend a BOS pseudo-root token at position 0 so the 1-based head
        # indices from the CoNLL-X file line up with positions in the sequence
        word_seq = [BOS] + sample['words']
        pos_seq = [BOS] + sample['pos_tags']
        heads = [0] + sample['heads']
        head_tags = [BOS] + sample['labels']
        dataset.append(Instance(raw_words=word_seq,
                                pos=pos_seq,
                                gold_heads=heads,
                                arc_true=heads,
                                tags=head_tags))
    return dataset


def load(path):
    data = ConllxDataLoader().load(path)
    return convert(data)
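
# convert() above assumes each sample from ConllxDataLoader().load() is a dict with
# 'words', 'pos_tags', 'heads' and 'labels' keys, roughly like (illustrative values only):
#   {'words': ['中国', '进出口', '银行'], 'pos_tags': ['NR', 'NN', 'NN'],
#    'heads': [3, 3, 0], 'labels': ['nn', 'nn', 'root']}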

datadir = "/remote-home/yfshao/workdir/ctb9.0/"
train_data_name = "train.conllx"
dev_data_name = "dev.conllx"
test_data_name = "test.conllx"
emb_file_name = "/remote-home/yfshao/workdir/word_vector/cc.zh.300.vec"

cfgfile = './cfg.cfg'
processed_datadir = './save'

# Config Loader
train_args = ConfigSection()
model_args = ConfigSection()
optim_args = ConfigSection()
test_args = ConfigSection()  # needed by test() below; assumes cfg.cfg also has a [test] section
ConfigLoader.load_config(cfgfile, {"train": train_args, "model": model_args,
                                   "optim": optim_args, "test": test_args})
print('train Args:', train_args.data)
print('model Args:', model_args.data)
print('optim Args:', optim_args.data)

# Pickle Loader
def save_data(dirpath, **kwargs):
    import _pickle
    if not os.path.exists(dirpath):
        os.mkdir(dirpath)
    for name, data in kwargs.items():
        with open(os.path.join(dirpath, name + '.pkl'), 'wb') as f:
            _pickle.dump(data, f)


def load_data(dirpath):
    import _pickle
    datas = {}
    for f_name in os.listdir(dirpath):
        if not f_name.endswith('.pkl'):
            continue
        name = f_name[:-4]
        with open(os.path.join(dirpath, f_name), 'rb') as f:
            datas[name] = _pickle.load(f)
    return datas


def P2(data, field, length):
    ds = [ins for ins in data if len(ins[field]) >= length]
    data.clear()
    data.extend(ds)
    return ds


def update_v(vocab, data, field):
    data.apply(lambda x: vocab.add_word_lst(x[field]), new_field_name=None)


# use pretrain embedding
word_v = Vocabulary(unknown=UNK, padding=PAD)
pos_v = Vocabulary(unknown=None, padding=PAD)
tag_v = Vocabulary(unknown=None, padding=None)
train_data = load(os.path.join(datadir, train_data_name))
dev_data = load(os.path.join(datadir, dev_data_name))
test_data = load(os.path.join(datadir, test_data_name))
print('load raw data and preprocess')

num_p = Num2TagProcessor(tag=NUM, field_name='raw_words', new_added_field_name='words')
for ds in (train_data, dev_data, test_data):
    num_p(ds)

update_v(word_v, train_data, 'words')
update_v(pos_v, train_data, 'pos')
update_v(tag_v, train_data, 'tags')

print('vocab build success {}, {}, {}'.format(len(word_v), len(pos_v), len(tag_v)))

# Model
model_args['word_vocab_size'] = len(word_v)
model_args['pos_vocab_size'] = len(pos_v)
model_args['num_label'] = len(tag_v)

model = BiaffineParser(**model_args.data)
print(model)

word_idxp = IndexerProcessor(word_v, 'words', 'word_seq')
pos_idxp = IndexerProcessor(pos_v, 'pos', 'pos_seq')
tag_idxp = IndexerProcessor(tag_v, 'tags', 'label_true')
seq_p = SeqLenProcessor('word_seq', 'seq_lens')

set_input_p = SetInputProcessor('word_seq', 'pos_seq', 'seq_lens', flag=True)
set_target_p = SetTargetProcessor('arc_true', 'label_true', 'seq_lens', flag=True)

label_toword_p = Index2WordProcessor(vocab=tag_v, field_name='label_pred', new_added_field_name='label_pred_seq')

for ds in (train_data, dev_data, test_data):
    word_idxp(ds)
    pos_idxp(ds)
    tag_idxp(ds)
    seq_p(ds)
    set_input_p(ds)
    set_target_p(ds)

if train_args['use_golden_train']:
    train_data.set_input('gold_heads', flag=True)
train_args.data.pop('use_golden_train')

print(test_data[0])
print('train len {}'.format(len(train_data)))
print('dev len {}'.format(len(dev_data)))
print('test len {}'.format(len(test_data)))
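
# After the processors above run, each DataSet has:
#   input fields : 'word_seq', 'pos_seq', 'seq_lens' (plus 'gold_heads' on the
#                  training set when use_golden_train is enabled)
#   target fields: 'arc_true', 'label_true', 'seq_lens'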


def train(path):
    # test saving pipeline
    save_pipe(path)
    embed = EmbedLoader.fast_load_embedding(model_args['word_emb_dim'], emb_file_name, word_v)
    embed = torch.tensor(embed, dtype=torch.float32)

    # embed = EmbedLoader.fast_load_embedding(emb_dim=model_args['word_emb_dim'], emb_file=emb_file_name, vocab=word_v)
    # embed = torch.tensor(embed, dtype=torch.float32)
    # model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=True)
    model.word_embedding.padding_idx = word_v.padding_idx
    model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
    model.pos_embedding.padding_idx = pos_v.padding_idx
    model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)

    class MyCallback(Callback):
        def on_step_end(self, optimizer):
            step = self.trainer.step
            # learning rate decay
            if step > 0 and step % 1000 == 0:
                for pg in optimizer.param_groups:
                    pg['lr'] *= 0.93
                print('decay lr to {}'.format([pg['lr'] for pg in optimizer.param_groups]))

            if step == 3000:
                # start training embedding
                print('start training embedding at {}'.format(step))
                model = self.trainer.model
                for m in model.modules():
                    if isinstance(m, torch.nn.Embedding):
                        m.weight.requires_grad = True

    # Trainer
    trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data,
                      loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
                      **train_args.data,
                      optimizer=fastNLP.Adam(**optim_args.data),
                      save_path=path,
                      callbacks=[MyCallback()])

    # Start training
    try:
        trainer.train()
        print("Training finished!")
    finally:
        # save pipeline
        save_pipe(path)
        print('pipe saved')


def save_pipe(path):
    pipe = Pipeline(processors=[num_p, word_idxp, pos_idxp, seq_p, set_input_p])
    pipe.add_processor(ModelProcessor(model=model, batch_size=32))
    pipe.add_processor(label_toword_p)
    os.makedirs(path, exist_ok=True)
    torch.save({'pipeline': pipe,
                'names': ['num word_idx pos_idx seq set_input model tag_to_word'.split()],
                }, os.path.join(path, 'pipe.pkl'))
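
# The pipe.pkl saved above can later be reloaded for inference; a minimal sketch,
# assuming Pipeline applies its processors in order when called on a DataSet
# (the field names follow the processors defined above; values are illustrative):
#
#   saved = torch.load(os.path.join(path, 'pipe.pkl'))
#   infer_pipe = saved['pipeline']
#   ds = DataSet()
#   ds.append(Instance(raw_words=[BOS, '中国', '进出口', '银行']))
#   infer_pipe(ds)                   # assumption: calling the Pipeline runs all processors
#   print(ds[0]['label_pred_seq'])   # label predictions mapped back to strings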


def test(path):
    # Tester (test_args is loaded from the [test] section of cfg.cfg above)
    tester = Tester(**test_args.data)

    # Model
    model = BiaffineParser(**model_args.data)
    model.eval()
    try:
        ModelLoader.load_pytorch(model, path)
        print('model parameter loaded!')
    except Exception as _:
        print("No saved model. Abort test.")
        raise

    # Start testing
    print("Testing Train data")
    tester.test(model, train_data)
    print("Testing Dev data")
    tester.test(model, dev_data)
    print("Testing Test data")
    tester.test(model, test_data)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Run the biaffine dependency parser')
    parser.add_argument('--mode', help='running mode', choices=['train', 'test', 'infer'])
    parser.add_argument('--path', type=str, default='')
    # parser.add_argument('--dst', type=str, default='')
    args = parser.parse_args()
    if args.mode == 'train':
        train(args.path)
    elif args.mode == 'test':
        test(args.path)
    elif args.mode == 'infer':
        pass
    else:
        print('no mode specified for model!')
        parser.print_help()
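
Typical usage (assuming the hard-coded data paths and ./cfg.cfg are in place; the path arguments are placeholders):

    python run.py --mode train --path <save_dir>
    python run.py --mode test --path <saved_model_file>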