You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

evaluation.py 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Bert evaluation script.
  17. """
  18. import os
  19. import argparse
  20. import numpy as np
  21. import mindspore.common.dtype as mstype
  22. from mindspore import context
  23. from mindspore import log as logger
  24. from mindspore.common.tensor import Tensor
  25. import mindspore.dataset as de
  26. import mindspore.dataset.transforms.c_transforms as C
  27. from mindspore.train.model import Model
  28. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  29. from src.evaluation_config import cfg, bert_net_cfg
  30. from src.utils import BertNER, BertCLS
  31. from src.CRF import postprocess
  32. from src.cluener_evaluation import submit
  33. from src.finetune_config import tag_to_index
  34. class Accuracy():
  35. '''
  36. calculate accuracy
  37. '''
  38. def __init__(self):
  39. self.acc_num = 0
  40. self.total_num = 0
  41. def update(self, logits, labels):
  42. labels = labels.asnumpy()
  43. labels = np.reshape(labels, -1)
  44. logits = logits.asnumpy()
  45. logit_id = np.argmax(logits, axis=-1)
  46. self.acc_num += np.sum(labels == logit_id)
  47. self.total_num += len(labels)
  48. print("=========================accuracy is ", self.acc_num / self.total_num)
  49. class F1():
  50. '''
  51. calculate F1 score
  52. '''
  53. def __init__(self):
  54. self.TP = 0
  55. self.FP = 0
  56. self.FN = 0
  57. def update(self, logits, labels):
  58. '''
  59. update F1 score
  60. '''
  61. labels = labels.asnumpy()
  62. labels = np.reshape(labels, -1)
  63. if cfg.use_crf:
  64. backpointers, best_tag_id = logits
  65. best_path = postprocess(backpointers, best_tag_id)
  66. logit_id = []
  67. for ele in best_path:
  68. logit_id.extend(ele)
  69. else:
  70. logits = logits.asnumpy()
  71. logit_id = np.argmax(logits, axis=-1)
  72. logit_id = np.reshape(logit_id, -1)
  73. pos_eva = np.isin(logit_id, [i for i in range(1, cfg.num_labels)])
  74. pos_label = np.isin(labels, [i for i in range(1, cfg.num_labels)])
  75. self.TP += np.sum(pos_eva&pos_label)
  76. self.FP += np.sum(pos_eva&(~pos_label))
  77. self.FN += np.sum((~pos_eva)&pos_label)
  78. def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
  79. '''
  80. get dataset
  81. '''
  82. _ = distribute_file
  83. ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
  84. "segment_ids", "label_ids"])
  85. type_cast_op = C.TypeCast(mstype.int32)
  86. ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
  87. ds = ds.map(input_columns="input_mask", operations=type_cast_op)
  88. ds = ds.map(input_columns="input_ids", operations=type_cast_op)
  89. ds = ds.map(input_columns="label_ids", operations=type_cast_op)
  90. ds = ds.repeat(repeat_count)
  91. # apply shuffle operation
  92. buffer_size = 960
  93. ds = ds.shuffle(buffer_size=buffer_size)
  94. # apply batch operations
  95. ds = ds.batch(batch_size, drop_remainder=True)
  96. return ds
  97. def bert_predict(Evaluation):
  98. '''
  99. prediction function
  100. '''
  101. target = args_opt.device_target
  102. if target == "Ascend":
  103. devid = int(os.getenv('DEVICE_ID'))
  104. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
  105. elif target == "GPU":
  106. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  107. if bert_net_cfg.compute_type != mstype.float32:
  108. logger.warning('GPU only support fp32 temporarily, run with fp32.')
  109. bert_net_cfg.compute_type = mstype.float32
  110. else:
  111. raise Exception("Target error, GPU or Ascend is supported.")
  112. dataset = get_dataset(bert_net_cfg.batch_size, 1)
  113. if cfg.use_crf:
  114. net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True,
  115. tag_to_index=tag_to_index, dropout_prob=0.0)
  116. else:
  117. net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels)
  118. net_for_pretraining.set_train(False)
  119. param_dict = load_checkpoint(cfg.finetune_ckpt)
  120. load_param_into_net(net_for_pretraining, param_dict)
  121. model = Model(net_for_pretraining)
  122. return model, dataset
  123. def test_eval():
  124. '''
  125. evaluation function
  126. '''
  127. task_type = BertNER if cfg.task == "NER" else BertCLS
  128. model, dataset = bert_predict(task_type)
  129. if cfg.clue_benchmark:
  130. submit(model, cfg.data_file, bert_net_cfg.seq_length)
  131. else:
  132. callback = F1() if cfg.task == "NER" else Accuracy()
  133. columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
  134. for data in dataset.create_dict_iterator():
  135. input_data = []
  136. for i in columns_list:
  137. input_data.append(Tensor(data[i]))
  138. input_ids, input_mask, token_type_id, label_ids = input_data
  139. logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
  140. callback.update(logits, label_ids)
  141. print("==============================================================")
  142. if cfg.task == "NER":
  143. print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
  144. print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
  145. print("F1 {:.6f} ".format(2*callback.TP / (2*callback.TP + callback.FP + callback.FN)))
  146. else:
  147. print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
  148. callback.acc_num / callback.total_num))
  149. print("==============================================================")
  150. parser = argparse.ArgumentParser(description='Bert eval')
  151. parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
  152. args_opt = parser.parse_args()
  153. if __name__ == "__main__":
  154. num_labels = cfg.num_labels
  155. test_eval()