You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ner_decode.py 3.7 kB

Lines 1–129
  1. import _pickle
  2. import os
  3. import torch
  4. from fastNLP.core.predictor import SeqLabelInfer
  5. from fastNLP.core.trainer import SeqLabelTrainer
  6. from fastNLP.loader.model_loader import ModelLoader
  7. from fastNLP.models.sequence_modeling import AdvSeqLabel
  8. class Decode(SeqLabelTrainer):
  9. def __init__(self, args):
  10. super(Decode, self).__init__(args)
  11. def decoder(self, network, sents, model_path):
  12. self.model = network
  13. self.model.load_state_dict(torch.load(model_path))
  14. out_put = []
  15. self.mode(network, test=True)
  16. for batch_x in sents:
  17. prediction = self.data_forward(self.model, batch_x)
  18. seq_tag = self.model.prediction(prediction, batch_x[1])
  19. out_put.append(list(seq_tag)[0])
  20. return out_put
  21. def process_sent(sents, word2id):
  22. sents_num = []
  23. for s in sents:
  24. sent_num = []
  25. for c in s:
  26. if c in word2id:
  27. sent_num.append(word2id[c])
  28. else:
  29. sent_num.append(word2id["<unk>"])
  30. sents_num.append(([sent_num], [len(sent_num)])) # batch_size is 1
  31. return sents_num
  32. def process_tag(sents, tags, id2class):
  33. Tags = []
  34. for ttt in tags:
  35. Tags.append([id2class[t] for t in ttt])
  36. Segs = []
  37. PosNers = []
  38. for sent, tag in zip(sents, tags):
  39. word__ = []
  40. lll__ = []
  41. for c, t in zip(sent, tag):
  42. t = id2class[t]
  43. l = t.split("-")
  44. split_ = l[0]
  45. pn = l[1]
  46. if split_ == "S":
  47. word__.append(c)
  48. lll__.append(pn)
  49. word_1 = ""
  50. elif split_ == "E":
  51. word_1 += c
  52. word__.append(word_1)
  53. lll__.append(pn)
  54. word_1 = ""
  55. elif split_ == "B":
  56. word_1 = ""
  57. word_1 += c
  58. else:
  59. word_1 += c
  60. Segs.append(word__)
  61. PosNers.append(lll__)
  62. return Segs, PosNers
  63. pickle_path = "data_for_tests/"
  64. model_path = "data_for_tests/model_best_dev.pkl"
  65. if __name__ == "__main__":
  66. with open(os.path.join(pickle_path, "id2word.pkl"), "rb") as f:
  67. id2word = _pickle.load(f)
  68. with open(os.path.join(pickle_path, "word2id.pkl"), "rb") as f:
  69. word2id = _pickle.load(f)
  70. with open(os.path.join(pickle_path, "id2class.pkl"), "rb") as f:
  71. id2class = _pickle.load(f)
  72. sent = ["中共中央总书记、国家主席江泽民",
  73. "逆向处理输入序列并返回逆序后的序列"] # here is input
  74. args = {"epochs": 1,
  75. "batch_size": 1,
  76. "pickle_path": "data_for_tests/",
  77. "validate": True,
  78. "save_best_dev": True,
  79. "model_saved_path": "data_for_tests/",
  80. "use_cuda": False,
  81. "vocab_size": len(word2id),
  82. "num_classes": len(id2class),
  83. "word_emb_dim": 50,
  84. "rnn_hidden_units": 100,
  85. }
  86. """
  87. network = AdvSeqLabel(args, None)
  88. decoder_ = Decode(args)
  89. tags_num = decoder_.decoder(network, process_sent(sent, word2id), model_path=model_path)
  90. output_seg, output_pn = process_tag(sent, tags_num, id2class) # here is output
  91. print(output_seg)
  92. print(output_pn)
  93. """
  94. # Define the same model
  95. model = AdvSeqLabel(args, None)
  96. # Dump trained parameters into the model
  97. ModelLoader.load_pytorch(model, "./data_for_tests/model_best_dev.pkl")
  98. print("model loaded!")
  99. # Inference interface
  100. infer = SeqLabelInfer(pickle_path)
  101. sent = [[ch for ch in s] for s in sent]
  102. results = infer.predict(model, sent)
  103. for res in results:
  104. print(res)
  105. print("Inference finished!")

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等