You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

evaluation.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Bert evaluation script.
  17. """
  18. import os
  19. import numpy as np
  20. import mindspore.common.dtype as mstype
  21. from mindspore import context
  22. from mindspore.common.tensor import Tensor
  23. import mindspore.dataset as de
  24. import mindspore.dataset.transforms.c_transforms as C
  25. from mindspore.train.model import Model
  26. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  27. from src.evaluation_config import cfg, bert_net_cfg
  28. from src.utils import BertNER, BertCLS
  29. from src.CRF import postprocess
  30. from src.cluener_evaluation import submit
  31. from src.finetune_config import tag_to_index
  32. class Accuracy():
  33. '''
  34. calculate accuracy
  35. '''
  36. def __init__(self):
  37. self.acc_num = 0
  38. self.total_num = 0
  39. def update(self, logits, labels):
  40. labels = labels.asnumpy()
  41. labels = np.reshape(labels, -1)
  42. logits = logits.asnumpy()
  43. logit_id = np.argmax(logits, axis=-1)
  44. self.acc_num += np.sum(labels == logit_id)
  45. self.total_num += len(labels)
  46. print("=========================accuracy is ", self.acc_num / self.total_num)
  47. class F1():
  48. '''
  49. calculate F1 score
  50. '''
  51. def __init__(self):
  52. self.TP = 0
  53. self.FP = 0
  54. self.FN = 0
  55. def update(self, logits, labels):
  56. '''
  57. update F1 score
  58. '''
  59. labels = labels.asnumpy()
  60. labels = np.reshape(labels, -1)
  61. if cfg.use_crf:
  62. backpointers, best_tag_id = logits
  63. best_path = postprocess(backpointers, best_tag_id)
  64. logit_id = []
  65. for ele in best_path:
  66. logit_id.extend(ele)
  67. else:
  68. logits = logits.asnumpy()
  69. logit_id = np.argmax(logits, axis=-1)
  70. logit_id = np.reshape(logit_id, -1)
  71. pos_eva = np.isin(logit_id, [i for i in range(1, cfg.num_labels)])
  72. pos_label = np.isin(labels, [i for i in range(1, cfg.num_labels)])
  73. self.TP += np.sum(pos_eva&pos_label)
  74. self.FP += np.sum(pos_eva&(~pos_label))
  75. self.FN += np.sum((~pos_eva)&pos_label)
  76. def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
  77. '''
  78. get dataset
  79. '''
  80. _ = distribute_file
  81. ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
  82. "segment_ids", "label_ids"])
  83. type_cast_op = C.TypeCast(mstype.int32)
  84. ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
  85. ds = ds.map(input_columns="input_mask", operations=type_cast_op)
  86. ds = ds.map(input_columns="input_ids", operations=type_cast_op)
  87. ds = ds.map(input_columns="label_ids", operations=type_cast_op)
  88. ds = ds.repeat(repeat_count)
  89. # apply shuffle operation
  90. buffer_size = 960
  91. ds = ds.shuffle(buffer_size=buffer_size)
  92. # apply batch operations
  93. ds = ds.batch(batch_size, drop_remainder=True)
  94. return ds
  95. def bert_predict(Evaluation):
  96. '''
  97. prediction function
  98. '''
  99. devid = int(os.getenv('DEVICE_ID'))
  100. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
  101. dataset = get_dataset(bert_net_cfg.batch_size, 1)
  102. if cfg.use_crf:
  103. net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True,
  104. tag_to_index=tag_to_index, dropout_prob=0.0)
  105. else:
  106. net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels)
  107. net_for_pretraining.set_train(False)
  108. param_dict = load_checkpoint(cfg.finetune_ckpt)
  109. load_param_into_net(net_for_pretraining, param_dict)
  110. model = Model(net_for_pretraining)
  111. return model, dataset
  112. def test_eval():
  113. '''
  114. evaluation function
  115. '''
  116. task_type = BertNER if cfg.task == "NER" else BertCLS
  117. model, dataset = bert_predict(task_type)
  118. if cfg.clue_benchmark:
  119. submit(model, cfg.data_file, bert_net_cfg.seq_length)
  120. else:
  121. callback = F1() if cfg.task == "NER" else Accuracy()
  122. columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
  123. for data in dataset.create_dict_iterator():
  124. input_data = []
  125. for i in columns_list:
  126. input_data.append(Tensor(data[i]))
  127. input_ids, input_mask, token_type_id, label_ids = input_data
  128. logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
  129. callback.update(logits, label_ids)
  130. print("==============================================================")
  131. if cfg.task == "NER":
  132. print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
  133. print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
  134. print("F1 {:.6f} ".format(2*callback.TP / (2*callback.TP + callback.FP + callback.FN)))
  135. else:
  136. print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
  137. callback.acc_num / callback.total_num))
  138. print("==============================================================")
  139. if __name__ == "__main__":
  140. num_labels = cfg.num_labels
  141. test_eval()