
run.py

import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

import fastNLP
import torch

from fastNLP.core.trainer import Trainer
from fastNLP.core.instance import Instance
from fastNLP.api.pipeline import Pipeline
from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.tester import Tester
from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.io.model_io import ModelLoader
from fastNLP.io.embed_loader import EmbedLoader
from fastNLP.io.model_io import ModelSaver
from reproduction.Biaffine_parser.util import ConllxDataLoader, MyDataloader
from fastNLP.api.processor import *

BOS = '<BOS>'
EOS = '<EOS>'
UNK = '<UNK>'
NUM = '<NUM>'
ENG = '<ENG>'
# cd into this script's directory if it was launched from somewhere else
if len(os.path.dirname(__file__)) != 0:
    os.chdir(os.path.dirname(__file__))

def convert(data):
    dataset = DataSet()
    for sample in data:
        word_seq = [BOS] + sample[0]
        pos_seq = [BOS] + sample[1]
        heads = [0] + list(map(int, sample[2]))
        head_tags = [BOS] + sample[3]
        dataset.append(Instance(words=word_seq,
                                pos=pos_seq,
                                gold_heads=heads,
                                arc_true=heads,
                                tags=head_tags))
    return dataset


def load(path):
    data = ConllxDataLoader().load(path)
    return convert(data)
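
# Illustrative sketch of what convert() produces: each ConllxDataLoader sample is a
# 4-item (words, pos_tags, heads, head_labels) record, so a sentence like
#   (['中国', '进出口', '银行'], ['NR', 'NN', 'NN'], ['3', '3', '0'], ['nn', 'nn', 'root'])
# becomes an Instance with words=['<BOS>', '中国', ...], pos=['<BOS>', 'NR', ...],
# gold_heads = arc_true = [0, 3, 3, 0] and tags=['<BOS>', 'nn', 'nn', 'root'].
# (Values are made up for illustration; the field order is read off convert() above.)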
# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT"
# datadir = "/home/yfshao/UD_English-EWT"
# train_data_name = "en_ewt-ud-train.conllu"
# dev_data_name = "en_ewt-ud-dev.conllu"
# emb_file_name = '/home/yfshao/glove.6B.100d.txt'
# loader = ConlluDataLoader()

# datadir = '/home/yfshao/workdir/parser-data/'
# train_data_name = "train_ctb5.txt"
# dev_data_name = "dev_ctb5.txt"
# test_data_name = "test_ctb5.txt"

datadir = "/home/yfshao/workdir/ctb7.0/"
train_data_name = "train.conllx"
dev_data_name = "dev.conllx"
test_data_name = "test.conllx"
# emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt"
emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec"

cfgfile = './cfg.cfg'
processed_datadir = './save'
# Config Loader
train_args = ConfigSection()
test_args = ConfigSection()  # assumes cfg.cfg also provides a [test] section; test() below needs it
model_args = ConfigSection()
optim_args = ConfigSection()
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})
print('train Args:', train_args.data)
print('test Args:', test_args.data)
print('model Args:', model_args.data)
print('optim Args:', optim_args.data)
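
# cfg.cfg is an ini-style config read by ConfigLoader; its section names must match the
# keys passed above. The keys inside each section are only illustrative guesses here --
# they have to match what Trainer, BiaffineParser and fastNLP.Adam below actually accept:
#   [train]
#   n_epochs = 40
#   batch_size = 32
#   use_golden_train = true
#   [optim]
#   lr = 0.002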
# Pickle Loader
def save_data(dirpath, **kwargs):
    import _pickle
    if not os.path.exists(dirpath):
        os.mkdir(dirpath)
    for name, data in kwargs.items():
        with open(os.path.join(dirpath, name + '.pkl'), 'wb') as f:
            _pickle.dump(data, f)


def load_data(dirpath):
    import _pickle
    datas = {}
    for f_name in os.listdir(dirpath):
        if not f_name.endswith('.pkl'):
            continue
        name = f_name[:-4]
        with open(os.path.join(dirpath, f_name), 'rb') as f:
            datas[name] = _pickle.load(f)
    return datas
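
# Usage sketch for the two cache helpers above (they are not called in the current flow):
#   save_data(processed_datadir, train=train_data, dev=dev_data, test=test_data)
#   cached = load_data(processed_datadir)   # -> {'train': ..., 'dev': ..., 'test': ...}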

def P2(data, field, length):
    # keep only instances whose `field` has at least `length` tokens (filters `data` in place)
    ds = [ins for ins in data if len(ins[field]) >= length]
    data.clear()
    data.extend(ds)
    return ds


def update_v(vocab, data, field):
    # add every token of `field` across the dataset to the vocabulary
    data.apply(lambda x: vocab.add_word_lst(x[field]), new_field_name=None)

print('load raw data and preprocess')

# use pretrain embedding
word_v = Vocabulary()
word_v.unknown_label = UNK
pos_v = Vocabulary()
tag_v = Vocabulary(unknown=None, padding=None)
train_data = load(os.path.join(datadir, train_data_name))
dev_data = load(os.path.join(datadir, dev_data_name))
test_data = load(os.path.join(datadir, test_data_name))
print(train_data[0])

num_p = Num2TagProcessor('words', 'words')
for ds in (train_data, dev_data, test_data):
    num_p(ds)

update_v(word_v, train_data, 'words')
update_v(pos_v, train_data, 'pos')
update_v(tag_v, train_data, 'tags')
print('vocab build success {}, {}, {}'.format(len(word_v), len(pos_v), len(tag_v)))

# embed, _ = EmbedLoader.fast_load_embedding(model_args['word_emb_dim'], emb_file_name, word_v)
# print(embed.size())
# Model
model_args['word_vocab_size'] = len(word_v)
model_args['pos_vocab_size'] = len(pos_v)
model_args['num_label'] = len(tag_v)

model = BiaffineParser(**model_args.data)
model.reset_parameters()

word_idxp = IndexerProcessor(word_v, 'words', 'word_seq')
pos_idxp = IndexerProcessor(pos_v, 'pos', 'pos_seq')
tag_idxp = IndexerProcessor(tag_v, 'tags', 'label_true')
seq_p = SeqLenProcessor('word_seq', 'seq_lens')
set_input_p = SetInputProcessor('word_seq', 'pos_seq', 'seq_lens', flag=True)
set_target_p = SetTargetProcessor('arc_true', 'label_true', 'seq_lens', flag=True)
label_toword_p = Index2WordProcessor(vocab=tag_v, field_name='label_pred', new_added_field_name='label_pred_seq')
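
# How the processors are wired up in this script: the Indexer/SeqLen processors add the
# integer fields word_seq / pos_seq / label_true / seq_lens next to the raw string fields,
# SetInputProcessor marks word_seq, pos_seq and seq_lens as model inputs, and
# SetTargetProcessor marks arc_true, label_true and seq_lens as training targets.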
for ds in (train_data, dev_data, test_data):
    word_idxp(ds)
    pos_idxp(ds)
    tag_idxp(ds)
    seq_p(ds)
    set_input_p(ds)
    set_target_p(ds)

if train_args['use_golden_train']:
    train_data.set_input('gold_heads', flag=True)
train_args.data.pop('use_golden_train')

ignore_label = pos_v['punct']

print(test_data[0])
print('train len {}'.format(len(train_data)))
print('dev len {}'.format(len(dev_data)))
print('test len {}'.format(len(test_data)))

def train(path):
    # test saving pipeline
    save_pipe(path)

    # Trainer
    trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data,
                      loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
                      **train_args.data,
                      optimizer=fastNLP.Adam(**optim_args.data),
                      save_path=path)

    # model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
    model.word_embedding.padding_idx = word_v.padding_idx
    model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
    model.pos_embedding.padding_idx = pos_v.padding_idx
    model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)

    # try:
    #     ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
    #     print('model parameter loaded!')
    # except Exception as _:
    #     print("No saved model. Continue.")
    #     pass

    # Start training
    trainer.train()
    print("Training finished!")

    # save pipeline
    save_pipe(path)
    print('pipe saved')


def save_pipe(path):
    pipe = Pipeline(processors=[num_p, word_idxp, pos_idxp, seq_p, set_input_p])
    pipe.add_processor(ModelProcessor(model=model, batch_size=32))
    pipe.add_processor(label_toword_p)
    torch.save(pipe, os.path.join(path, 'pipe.pkl'))
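
# Inference sketch for the saved pipeline (assumes fastNLP's Pipeline applies its processors
# when called on a DataSet; the prediction fields come from ModelProcessor, and the readable
# labels from label_toword_p above):
#   pipe = torch.load(os.path.join(path, 'pipe.pkl'))
#   ds = DataSet()
#   ds.append(Instance(words=[BOS, '中国', '进出口', '银行'], pos=[BOS, 'NR', 'NN', 'NN']))
#   pipe(ds)   # e.g. adds 'label_pred' and its string form 'label_pred_seq'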

def test(path):
    # Tester
    tester = Tester(**test_args.data)

    # Model
    model = BiaffineParser(**model_args.data)
    model.eval()
    try:
        ModelLoader.load_pytorch(model, path)
        print('model parameter loaded!')
    except Exception as _:
        print("No saved model. Abort test.")
        raise

    # Start testing
    print("Testing Train data")
    tester.test(model, train_data)
    print("Testing Dev data")
    tester.test(model, dev_data)
    print("Testing Test data")
    tester.test(model, test_data)

def build_pipe(parser_pipe_path):
    parser_pipe = torch.load(parser_pipe_path)
    return parser_pipe

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Train or test the biaffine dependency parser')
    parser.add_argument('--mode', help='set the running mode', choices=['train', 'test', 'infer', 'save'])
    parser.add_argument('--path', type=str, default='')
    # parser.add_argument('--dst', type=str, default='')
    args = parser.parse_args()
    if args.mode == 'train':
        train(args.path)
    elif args.mode == 'test':
        test(args.path)
    elif args.mode == 'infer':
        pass
    # elif args.mode == 'save':
    #     print(f'save model from {args.path} to {args.dst}')
    #     save_model(args.path, args.dst)
    #     load_path = os.path.dirname(args.dst)
    #     print(f'save pipeline in {load_path}')
    #     build(load_path)
    else:
        print('no mode specified for model!')
        parser.print_help()
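
# Example invocations (paths are placeholders; --mode train expects a directory to save
# into, --mode test expects the path of a saved model file):
#   python run.py --mode train --path ./save
#   python run.py --mode test --path ./save/<saved_model_file>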