You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

infer.py 3.0 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import os
  2. import sys
  3. sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])
  4. from fastNLP.api.processor import *
  5. from fastNLP.models.biaffine_parser import BiaffineParser
  6. from fastNLP.io.config_io import ConfigSection, ConfigLoader
  7. import _pickle as pickle
  8. import torch
  9. def _load(path):
  10. with open(path, 'rb') as f:
  11. obj = pickle.load(f)
  12. return obj
  13. def _load_all(src):
  14. model_path = src
  15. src = os.path.dirname(src)
  16. word_v = _load(src+'/word_v.pkl')
  17. pos_v = _load(src+'/pos_v.pkl')
  18. tag_v = _load(src+'/tag_v.pkl')
  19. pos_pp = torch.load(src+'/pos_pp.pkl')['pipeline']
  20. model_args = ConfigSection()
  21. ConfigLoader.load_config('cfg.cfg', {'model': model_args})
  22. model_args['word_vocab_size'] = len(word_v)
  23. model_args['pos_vocab_size'] = len(pos_v)
  24. model_args['num_label'] = len(tag_v)
  25. model = BiaffineParser(**model_args.data)
  26. model.load_state_dict(torch.load(model_path))
  27. return {
  28. 'word_v': word_v,
  29. 'pos_v': pos_v,
  30. 'tag_v': tag_v,
  31. 'model': model,
  32. 'pos_pp':pos_pp,
  33. }
  34. def build(load_path, save_path):
  35. BOS = '<BOS>'
  36. NUM = '<NUM>'
  37. _dict = _load_all(load_path)
  38. word_vocab = _dict['word_v']
  39. pos_vocab = _dict['pos_v']
  40. tag_vocab = _dict['tag_v']
  41. pos_pp = _dict['pos_pp']
  42. model = _dict['model']
  43. print('load model from {}'.format(load_path))
  44. word_seq = 'raw_word_seq'
  45. pos_seq = 'raw_pos_seq'
  46. # build pipeline
  47. # input
  48. pipe = pos_pp
  49. pipe.pipeline.pop(-1)
  50. pipe.add_processor(Num2TagProcessor(NUM, 'word_list', word_seq))
  51. pipe.add_processor(PreAppendProcessor(BOS, word_seq))
  52. pipe.add_processor(PreAppendProcessor(BOS, 'pos_list', pos_seq))
  53. pipe.add_processor(IndexerProcessor(word_vocab, word_seq, 'word_seq'))
  54. pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, 'pos_seq'))
  55. pipe.add_processor(SeqLenProcessor('word_seq', 'word_seq_origin_len'))
  56. pipe.add_processor(SetTensorProcessor({'word_seq':True, 'pos_seq':True, 'word_seq_origin_len':True}, default=False))
  57. pipe.add_processor(ModelProcessor(model, 'word_seq_origin_len'))
  58. pipe.add_processor(SliceProcessor(1, None, None, 'head_pred', 'heads'))
  59. pipe.add_processor(SliceProcessor(1, None, None, 'label_pred', 'label_pred'))
  60. pipe.add_processor(Index2WordProcessor(tag_vocab, 'label_pred', 'labels'))
  61. if not os.path.exists(save_path):
  62. os.makedirs(save_path)
  63. with open(save_path+'/pipeline.pkl', 'wb') as f:
  64. torch.save({'pipeline': pipe}, f)
  65. print('save pipeline in {}'.format(save_path))
  66. import argparse
  67. parser = argparse.ArgumentParser(description='build pipeline for parser.')
  68. parser.add_argument('--src', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/save')
  69. parser.add_argument('--dst', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/pipe')
  70. args = parser.parse_args()
  71. build(args.src, args.dst)