You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

evaluation.py 9.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Bert evaluation script.
  17. """
  18. import os
  19. import argparse
  20. import math
  21. import numpy as np
  22. import mindspore.common.dtype as mstype
  23. from mindspore import context
  24. from mindspore import log as logger
  25. from mindspore.common.tensor import Tensor
  26. import mindspore.dataset as de
  27. import mindspore.dataset.transforms.c_transforms as C
  28. from mindspore.train.model import Model
  29. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  30. from src.evaluation_config import cfg, bert_net_cfg
  31. from src.utils import BertNER, BertCLS, BertReg
  32. from src.CRF import postprocess
  33. from src.cluener_evaluation import submit
  34. from src.finetune_config import tag_to_index
  35. class Accuracy():
  36. """
  37. calculate accuracy
  38. """
  39. def __init__(self):
  40. self.acc_num = 0
  41. self.total_num = 0
  42. def update(self, logits, labels):
  43. """
  44. Update accuracy
  45. """
  46. labels = labels.asnumpy()
  47. labels = np.reshape(labels, -1)
  48. logits = logits.asnumpy()
  49. logit_id = np.argmax(logits, axis=-1)
  50. self.acc_num += np.sum(labels == logit_id)
  51. self.total_num += len(labels)
  52. print("=========================accuracy is ", self.acc_num / self.total_num)
  53. class F1():
  54. """
  55. calculate F1 score
  56. """
  57. def __init__(self):
  58. self.TP = 0
  59. self.FP = 0
  60. self.FN = 0
  61. def update(self, logits, labels):
  62. """
  63. update F1 score
  64. """
  65. labels = labels.asnumpy()
  66. labels = np.reshape(labels, -1)
  67. if cfg.use_crf:
  68. backpointers, best_tag_id = logits
  69. best_path = postprocess(backpointers, best_tag_id)
  70. logit_id = []
  71. for ele in best_path:
  72. logit_id.extend(ele)
  73. else:
  74. logits = logits.asnumpy()
  75. logit_id = np.argmax(logits, axis=-1)
  76. logit_id = np.reshape(logit_id, -1)
  77. pos_eva = np.isin(logit_id, [i for i in range(1, cfg.num_labels)])
  78. pos_label = np.isin(labels, [i for i in range(1, cfg.num_labels)])
  79. self.TP += np.sum(pos_eva&pos_label)
  80. self.FP += np.sum(pos_eva&(~pos_label))
  81. self.FN += np.sum((~pos_eva)&pos_label)
  82. class MCC():
  83. """
  84. Calculate Matthews Correlation Coefficient.
  85. """
  86. def __init__(self):
  87. self.TP = 0
  88. self.FP = 0
  89. self.FN = 0
  90. self.TN = 0
  91. def update(self, logits, labels):
  92. """
  93. Update MCC score
  94. """
  95. labels = labels.asnumpy()
  96. labels = np.reshape(labels, -1)
  97. labels = labels.astype(np.bool)
  98. logits = logits.asnumpy()
  99. logit_id = np.argmax(logits, axis=-1)
  100. logit_id = np.reshape(logit_id, -1)
  101. logit_id = logit_id.astype(np.bool)
  102. ornot = logit_id ^ labels
  103. self.TP += (~ornot & labels).sum()
  104. self.FP += (ornot & ~labels).sum()
  105. self.FN += (ornot & labels).sum()
  106. self.TN += (~ornot & ~labels).sum()
  107. class Spearman_Correlation():
  108. """
  109. calculate Spearman Correlation coefficient
  110. """
  111. def __init__(self):
  112. self.label = []
  113. self.logit = []
  114. def update(self, logits, labels):
  115. """
  116. Update Spearman Correlation
  117. """
  118. labels = labels.asnumpy()
  119. labels = np.reshape(labels, -1)
  120. logits = logits.asnumpy()
  121. logits = np.reshape(logits, -1)
  122. self.label.append(labels)
  123. self.logit.append(logits)
  124. def cal(self):
  125. """
  126. Calculate Spearman Correlation
  127. """
  128. label = np.concatenate(self.label)
  129. logit = np.concatenate(self.logit)
  130. sort_label = label.argsort()[::-1]
  131. sort_logit = logit.argsort()[::-1]
  132. n = len(label)
  133. d_acc = 0
  134. for i in range(n):
  135. d = np.where(sort_label == i)[0] - np.where(sort_logit == i)[0]
  136. d_acc += d**2
  137. ps = 1 - 6*d_acc/n/(n**2-1)
  138. return ps
  139. def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
  140. """
  141. get dataset
  142. """
  143. _ = distribute_file
  144. ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
  145. "segment_ids", "label_ids"])
  146. type_cast_op = C.TypeCast(mstype.int32)
  147. ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
  148. ds = ds.map(input_columns="input_mask", operations=type_cast_op)
  149. ds = ds.map(input_columns="input_ids", operations=type_cast_op)
  150. if cfg.task == "Regression":
  151. type_cast_op_float = C.TypeCast(mstype.float32)
  152. ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
  153. else:
  154. ds = ds.map(input_columns="label_ids", operations=type_cast_op)
  155. ds = ds.repeat(repeat_count)
  156. # apply shuffle operation
  157. buffer_size = 960
  158. ds = ds.shuffle(buffer_size=buffer_size)
  159. # apply batch operations
  160. ds = ds.batch(batch_size, drop_remainder=True)
  161. return ds
  162. def bert_predict(Evaluation):
  163. """
  164. prediction function
  165. """
  166. target = args_opt.device_target
  167. if target == "Ascend":
  168. devid = int(os.getenv('DEVICE_ID'))
  169. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
  170. elif target == "GPU":
  171. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  172. if bert_net_cfg.compute_type != mstype.float32:
  173. logger.warning('GPU only support fp32 temporarily, run with fp32.')
  174. bert_net_cfg.compute_type = mstype.float32
  175. else:
  176. raise Exception("Target error, GPU or Ascend is supported.")
  177. dataset = get_dataset(bert_net_cfg.batch_size, 1)
  178. if cfg.use_crf:
  179. net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True,
  180. tag_to_index=tag_to_index, dropout_prob=0.0)
  181. else:
  182. net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels)
  183. net_for_pretraining.set_train(False)
  184. param_dict = load_checkpoint(cfg.finetune_ckpt)
  185. load_param_into_net(net_for_pretraining, param_dict)
  186. model = Model(net_for_pretraining)
  187. return model, dataset
  188. def test_eval():
  189. """
  190. evaluation function
  191. """
  192. if cfg.task == "SeqLabeling":
  193. task_type = BertNER
  194. elif cfg.task == "Regression":
  195. task_type = BertReg
  196. elif cfg.task == "Classification":
  197. task_type = BertCLS
  198. elif cfg.task == "COLA":
  199. task_type = BertCLS
  200. else:
  201. raise ValueError("Task not supported.")
  202. model, dataset = bert_predict(task_type)
  203. if cfg.clue_benchmark:
  204. submit(model, cfg.data_file, bert_net_cfg.seq_length)
  205. else:
  206. if cfg.task == "SeqLabeling":
  207. callback = F1()
  208. elif cfg.task == "COLA":
  209. callback = MCC()
  210. elif cfg.task == "Regression":
  211. callback = Spearman_Correlation()
  212. else:
  213. callback = Accuracy()
  214. columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
  215. for data in dataset.create_dict_iterator():
  216. input_data = []
  217. for i in columns_list:
  218. input_data.append(Tensor(data[i]))
  219. input_ids, input_mask, token_type_id, label_ids = input_data
  220. logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
  221. callback.update(logits, label_ids)
  222. print("==============================================================")
  223. if cfg.task == "SeqLabeling":
  224. print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
  225. print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
  226. print("F1 {:.6f} ".format(2*callback.TP / (2*callback.TP + callback.FP + callback.FN)))
  227. elif cfg.task == "COLA":
  228. TP = callback.TP
  229. TN = callback.TN
  230. FP = callback.FP
  231. FN = callback.FN
  232. mcc = (TP*TN-FP*FN)/math.sqrt((TP+FP)*(TP+FN)*(TN+FP)*(TN+FN))
  233. print("MCC: {:.6f}".format(mcc))
  234. elif cfg.task == "Regression":
  235. print("Spearman Correlation is {:.6f}".format(callback.cal()[0]))
  236. else:
  237. print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
  238. callback.acc_num / callback.total_num))
  239. print("==============================================================")
  240. parser = argparse.ArgumentParser(description='Bert eval')
  241. parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
  242. args_opt = parser.parse_args()
  243. if __name__ == "__main__":
  244. num_labels = cfg.num_labels
  245. test_eval()