
run.py

import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from collections import defaultdict
import math
import torch

from fastNLP.core.trainer import Trainer
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.field import TextField, SeqLabelField
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
from fastNLP.core.tester import Tester
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.loader.embed_loader import EmbedLoader
from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.saver.model_saver import ModelSaver

# If the script is launched from another directory, chdir into this file's
# directory so the relative paths below ('./cfg.cfg', './save') resolve.
if len(os.path.dirname(__file__)) != 0:
    os.chdir(os.path.dirname(__file__))
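
# MyDataLoader reads a CoNLL-U treebank file and turns each sentence into a
# fastNLP Instance with word, POS tag, head-index and head-label fields.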
class MyDataLoader(object):
    def __init__(self, pickle_path):
        self.pickle_path = pickle_path

    def load(self, path, word_v=None, pos_v=None, headtag_v=None):
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):
                    # blank line marks the end of the current sentence
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    # CoNLL-U comment line
                    continue
                else:
                    sample.append(line.split('\t'))
            if len(sample) > 0:
                datalist.append(sample)

        ds = DataSet(name='conll')
        for sample in datalist:
            res = self.get_one(sample)
            if word_v is not None:
                # build the vocabularies while loading the training set
                word_v.update(res[0])
                pos_v.update(res[1])
                headtag_v.update(res[3])
            ds.append(Instance(word_seq=TextField(res[0], is_target=False),
                               pos_seq=TextField(res[1], is_target=False),
                               head_indices=SeqLabelField(res[2], is_target=True),
                               head_labels=TextField(res[3], is_target=True),
                               seq_mask=SeqLabelField([1] * len(res[0]), is_target=False)))
        return ds
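
    # Columns of a CoNLL-U token line: w[1] is the word form, w[3] the UPOS
    # tag, w[6] the head index and w[7] the dependency relation. Tokens whose
    # head field is '_' (multi-word token ranges) are skipped. A dummy <root>
    # token is prepended so head indices can point at position 0.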
    def get_one(self, sample):
        text = ['<root>']
        pos_tags = ['<root>']
        heads = [0]
        head_tags = ['root']
        for w in sample:
            t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
            if t3 == '_':
                continue
            text.append(t1)
            pos_tags.append(t2)
            heads.append(int(t3))
            head_tags.append(t4)
        return text, pos_tags, heads, head_tags

    def index_data(self, dataset, word_v, pos_v, tag_v):
        dataset.index_field('word_seq', word_v)
        dataset.index_field('pos_seq', pos_v)
        dataset.index_field('head_labels', tag_v)

# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT"
datadir = "/home/yfshao/UD_English-EWT"
cfgfile = './cfg.cfg'
train_data_name = "en_ewt-ud-train.conllu"
dev_data_name = "en_ewt-ud-dev.conllu"
emb_file_name = '/home/yfshao/glove.6B.100d.txt'
processed_datadir = './save'

# Config Loader
train_args = ConfigSection()
test_args = ConfigSection()
model_args = ConfigSection()
optim_args = ConfigSection()
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})
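# cfg.cfg is expected to provide [train], [test], [model] and [optim]
# sections matching the section names requested above.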

# Pickle helpers: cache the vocabularies and preprocessed datasets on disk
def save_data(dirpath, **kwargs):
    import _pickle
    if not os.path.exists(dirpath):
        os.mkdir(dirpath)
    for name, data in kwargs.items():
        with open(os.path.join(dirpath, name + '.pkl'), 'wb') as f:
            _pickle.dump(data, f)

def load_data(dirpath):
    import _pickle
    datas = {}
    for f_name in os.listdir(dirpath):
        if not f_name.endswith('.pkl'):
            continue
        name = f_name[:-4]
        with open(os.path.join(dirpath, f_name), 'rb') as f:
            datas[name] = _pickle.load(f)
    return datas
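
# MyTester is a minimal evaluator: it runs the model over the dataset in
# sequential batches, collects the per-batch evaluation tensors returned by
# model.evaluate(), concatenates them and feeds the result to model.metrics().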
class MyTester(object):
    def __init__(self, batch_size, use_cuda=False, **kwargs):
        self.batch_size = batch_size
        self.use_cuda = use_cuda

    def test(self, model, dataset):
        self.model = model.cuda() if self.use_cuda else model
        self.model.eval()
        batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda)
        eval_res = defaultdict(list)
        for batch_x, batch_y in batchiter:
            with torch.no_grad():
                pred_y = self.model(**batch_x)
                eval_one = self.model.evaluate(**pred_y, **batch_y)
            for eval_name, tensor in eval_one.items():
                eval_res[eval_name].append(tensor)
        tmp = {}
        for eval_name, tensorlist in eval_res.items():
            tmp[eval_name] = torch.cat(tensorlist, dim=0)
        self.res = self.model.metrics(**tmp)

    def show_metrics(self):
        s = ""
        for name, val in self.res.items():
            s += '{}: {:.2f}\t'.format(name, val)
        return s
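
# Load cached pickles when available; otherwise parse the raw CoNLL-U files,
# build the vocabularies from the training set and cache everything.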
loader = MyDataLoader('')
try:
    data_dict = load_data(processed_datadir)
    word_v = data_dict['word_v']
    pos_v = data_dict['pos_v']
    tag_v = data_dict['tag_v']
    train_data = data_dict['train_data']
    dev_data = data_dict['dev_data']
    print('use saved pickles')
except Exception as _:
    print('load raw data and preprocess')
    word_v = Vocabulary(need_default=True, min_freq=2)
    pos_v = Vocabulary(need_default=True)
    tag_v = Vocabulary(need_default=False)
    train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v)
    dev_data = loader.load(os.path.join(datadir, dev_data_name))
    save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)

loader.index_data(train_data, word_v, pos_v, tag_v)
loader.index_data(dev_data, word_v, pos_v, tag_v)
print(len(train_data))
print(len(dev_data))
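
# If the configured epoch count is non-positive, pick one that amounts to
# roughly 50,000 optimizer updates: 50000 / (len(train_data) / batch_size).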
ep = train_args['epochs']
train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep
model_args['word_vocab_size'] = len(word_v)
model_args['pos_vocab_size'] = len(pos_v)
model_args['num_label'] = len(tag_v)
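
# train() patches the stock fastNLP Trainer with closures, rather than
# subclassing it, so that training uses Adam with an exponentially decaying
# learning rate and validates with MyTester.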
def train():
    # Trainer
    trainer = Trainer(**train_args.data)

    def _define_optim(obj):
        obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data)
        # continuous exponential decay: the LR is multiplied by 0.75 for
        # every 50k scheduler steps
        obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4))

    def _update(obj):
        obj._scheduler.step()
        obj._optimizer.step()

    trainer.define_optimizer = lambda: _define_optim(trainer)
    trainer.update = lambda: _update(trainer)
    trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth)
    trainer._create_validator = lambda x: MyTester(**test_args.data)

    # Model
    model = BiaffineParser(**model_args.data)
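    # Initialize the word embedding from pretrained GloVe vectors and zero
    # the rows at the padding indices so padding contributes nothing.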
    embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v,
                                          os.path.join(processed_datadir, 'word_emb.pkl'))
    model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
    model.word_embedding.padding_idx = word_v.padding_idx
    model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
    model.pos_embedding.padding_idx = pos_v.padding_idx
    model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)

    # Resume from a previous checkpoint if one exists
    try:
        ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
        print('model parameter loaded!')
    except Exception as _:
        print("No saved model. Continue.")

    # Start training
    trainer.train(model, train_data, dev_data)
    print("Training finished!")

    # Saver
    saver = ModelSaver("./save/saved_model.pkl")
    saver.save_pytorch(model)
    print("Model saved!")

def test():
    # Tester
    tester = MyTester(**test_args.data)

    # Model
    model = BiaffineParser(**model_args.data)
    try:
        ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
        print('model parameter loaded!')
    except Exception as _:
        print("No saved model. Abort test.")
        raise

    # Start testing
    tester.test(model, dev_data)
    print(tester.show_metrics())
    print("Testing finished!")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Train or evaluate a biaffine dependency parser')
    parser.add_argument('--mode', help='run mode', choices=['train', 'test', 'infer'])
    args = parser.parse_args()
    if args.mode == 'train':
        train()
    elif args.mode == 'test':
        test()
    elif args.mode == 'infer':
        # no infer() is defined in this script
        raise NotImplementedError('infer mode is not implemented')
    else:
        print('no mode specified for model!')
        parser.print_help()
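
# Usage (assuming cfg.cfg and the UD_English-EWT paths above are in place):
#   python run.py --mode train
#   python run.py --mode test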