You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ner.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import _pickle
  2. import os
  3. import numpy as np
  4. import torch
  5. from fastNLP.core.preprocess import SeqLabelPreprocess
  6. from fastNLP.core.tester import SeqLabelTester
  7. from fastNLP.core.trainer import SeqLabelTrainer
  8. from fastNLP.models.sequence_modeling import AdvSeqLabel
  class MyNERTrainer(SeqLabelTrainer):
      """Sequence-labeling trainer using Adam with a step-decay learning-rate schedule."""

      def __init__(self, train_args):
          super().__init__(train_args)
          # Filled in by define_optimizer(); stays None until that is called.
          self.scheduler = None

      def define_optimizer(self):
          """Override: build the Adam optimizer and the StepLR scheduler attached to it.

          The learning rate starts at 0.001 and is multiplied by 0.5 every
          3000 scheduler steps.
          """
          # NOTE(review): self._model is presumably set by the base trainer - confirm.
          adam = torch.optim.Adam(self._model.parameters(), lr=0.001)
          self.optimizer = adam
          self.scheduler = torch.optim.lr_scheduler.StepLR(adam, step_size=3000, gamma=0.5)

      def update(self):
          """Override: apply one optimizer step, then advance the scheduler."""
          self.optimizer.step()
          self.scheduler.step()

      def _create_validator(self, valid_args):
          """Return the tester used for validation."""
          return MyNERTester(valid_args)

      def best_eval_result(self, validator):
          """Record a new best accuracy if the validator beats the previous one.

          :param validator: object whose ``metrics()`` returns the accuracy
          :return: True when ``best_accuracy`` was updated, False otherwise
          """
          accuracy = validator.metrics()
          # NOTE(review): best_accuracy is not initialised in this class;
          # presumably the base trainer provides it - confirm.
          if not accuracy > self.best_accuracy:
              return False
          self.best_accuracy = accuracy
          return True
  class MyNERTester(SeqLabelTester):
      """Tester that scores predictions by token-level accuracy on each sequence prefix."""

      def __init__(self, test_args):
          super().__init__(test_args)

      def _evaluate(self, prediction, batch_y, seq_len):
          """Compute the fraction of correctly labeled tokens in one batch.

          Only the first ``seq_len[i]`` positions of sample ``i`` are counted.

          :param prediction: [batch_size, seq_len, num_classes]
          :param batch_y: [batch_size, seq_len]
          :param seq_len: [batch_size]
          :return: float accuracy; 0.0 when the batch contains no tokens
          """
          total = 0
          correct = 0
          # The predicted class is the arg-max over the class dimension.
          _, indices = torch.max(prediction, 2)
          for predicted, gold, length in zip(indices, batch_y, seq_len):
              # Normalise to a plain int so the arithmetic below does not depend
              # on whether seq_len holds Python ints or 0-dim tensors.
              length = int(length)
              total += length
              matches = predicted[:length].cpu().numpy() == gold[:length].cpu().numpy()
              correct += int(np.sum(matches))
          if total == 0:
              # An empty batch (or all-zero lengths) used to raise ZeroDivisionError.
              return 0.0
          return correct / total

      def evaluate(self, predict, truth):
          """Score one batch against the ground truth.

          NOTE(review): self.seq_len is presumably set by the base tester
          before this is called - confirm.
          """
          return self._evaluate(predict, truth, self.seq_len)

      def metrics(self):
          """Return the mean of the values stored in ``self.eval_history``."""
          return np.mean(self.eval_history)

      def show_matrices(self):
          """Return a printable summary of the current dev accuracy."""
          return "dev accuracy={:.2f}".format(float(self.metrics()))
  def embedding_process(emb_file, word_dict, emb_dim, emb_pkl):
      """Build the embedding matrix for a vocabulary, caching it as a pickle.

      If ``emb_pkl`` already exists it is loaded and returned as is. It is not
      checked against ``word_dict`` or ``emb_dim``, and it is unpickled without
      validation, so it must come from a trusted location.

      Otherwise a ``(len(word_dict), emb_dim)`` matrix is drawn uniformly from
      [-1, 1). Each line of ``emb_file`` that has exactly ``emb_dim + 1``
      whitespace-separated fields, and whose first field is a key of
      ``word_dict``, overwrites the row at ``word_dict[word]``. The result is
      written to ``emb_pkl`` before being returned.

      :param emb_file: path of the text embedding file (UTF-8, ``word v1 ... vN``)
      :param word_dict: mapping from word to row index
      :param emb_dim: number of vector components per word
      :param emb_pkl: path of the pickle cache
      :return: numpy array holding the embedding matrix
      """
      # Fast path: reuse a previously built matrix.
      if os.path.exists(emb_pkl):
          with open(emb_pkl, "rb") as cache:
              return _pickle.load(cache)

      expected_fields = emb_dim + 1
      with open(emb_file, "r", encoding="utf-8") as source:
          # Words missing from the embedding file keep this random initialisation.
          matrix = np.random.uniform(-1, 1, size=(len(word_dict), emb_dim))
          for raw in source:
              fields = raw.strip().split()
              # Skip malformed lines and words outside the vocabulary.
              if len(fields) == expected_fields and fields[0] in word_dict:
                  matrix[word_dict[fields[0]]] = [float(value) for value in fields[1:]]

      with open(emb_pkl, "wb") as cache:
          _pickle.dump(matrix, cache)
      return matrix
  def data_load(data_file):
      """Read a token-per-line labeled corpus into a list of sentences.

      Each line with at least two whitespace-separated fields contributes its
      first field as a token and its second as that token's label. Any line
      with fewer than two fields (typically a blank line) ends the current
      sentence.

      Unlike the earlier version, a final sentence that is not followed by a
      separator line is kept, and runs of separator lines no longer produce
      empty ``[[], []]`` samples.

      :param data_file: path of the UTF-8 corpus file
      :return: list of ``[tokens, labels]`` pairs, one per sentence
      """
      all_data = []
      sent, label = [], []
      with open(data_file, "r", encoding="utf-8") as f:
          for line in f:
              fields = line.split()
              if len(fields) > 1:
                  sent.append(fields[0])
                  label.append(fields[1])
              elif sent:
                  # Separator line: close the sentence collected so far.
                  all_data.append([sent, label])
                  sent, label = [], []
      if sent:
          # The file ended without a trailing separator line.
          all_data.append([sent, label])
      return all_data
  # Input corpus, pretrained vectors, and the directories used for pickles and checkpoints.
  data_path = "data_for_tests/people.txt"
  emb_path = "data_for_tests/emb50.txt"
  pick_path = "data_for_tests/"
  save_path = "data_for_tests/"

  if __name__ == "__main__":
      # Load the raw corpus and let the preprocessor split it into train/dev sets.
      data = data_load(data_path)
      preprocess = SeqLabelPreprocess()
      data_train, data_dev = preprocess.run(
          data, pickle_path=pick_path, train_dev_split=0.3
      )

      # Pretrained embeddings are currently disabled, so the model receives None.
      # embedding_process(emb_path, <word-to-index mapping>, 50, <cache path>)
      # would build the matrix; it would then need converting to a float tensor.
      emb = None

      # Options that drive the training loop.
      trainer_options = {
          "epochs": 20,
          "batch_size": 1,
          "pickle_path": pick_path,
          "validate": True,
          "save_best_dev": True,
          "model_saved_path": save_path,
          "use_cuda": True,
      }
      # Model hyper-parameters; the sizes are read from the preprocessor after run().
      model_options = {
          "vocab_size": preprocess.vocab_size,
          "num_classes": preprocess.num_classes,
          "word_emb_dim": 50,
          "rnn_hidden_units": 100,
      }
      # The same dict is handed to both the model and the trainer.
      args = {**trainer_options, **model_options}

      networks = AdvSeqLabel(args, emb)
      trainer = MyNERTrainer(args)
      trainer.train(networks, data_train, data_dev)
      print("Training finished!")

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等