You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

postprocess.py 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Evaluation for NAML"""
  16. import os
  17. import argparse
  18. import numpy as np
  19. from sklearn.metrics import roc_auc_score
  20. parser = argparse.ArgumentParser(description="")
  21. parser.add_argument("--result_path", type=str, default="", help="Device id")
  22. parser.add_argument("--label_path", type=str, default="", help="output file name.")
  23. args = parser.parse_args()
  24. def AUC(y_true, y_pred):
  25. return roc_auc_score(y_true, y_pred)
  26. def MRR(y_true, y_pred):
  27. index = np.argsort(y_pred)[::-1]
  28. y_true = np.take(y_true, index)
  29. score = y_true / (np.arange(len(y_true)) + 1)
  30. return np.sum(score) / np.sum(y_true)
  31. def DCG(y_true, y_pred, n):
  32. index = np.argsort(y_pred)[::-1]
  33. y_true = np.take(y_true, index[:n])
  34. score = (2 ** y_true - 1) / np.log2(np.arange(len(y_true)) + 2)
  35. return np.sum(score)
  36. def nDCG(y_true, y_pred, n):
  37. return DCG(y_true, y_pred, n) / DCG(y_true, y_true, n)
  38. class NAMLMetric:
  39. """
  40. Metric method
  41. """
  42. def __init__(self):
  43. super(NAMLMetric, self).__init__()
  44. self.AUC_list = []
  45. self.MRR_list = []
  46. self.nDCG5_list = []
  47. self.nDCG10_list = []
  48. def clear(self):
  49. """Clear the internal evaluation result."""
  50. self.AUC_list = []
  51. self.MRR_list = []
  52. self.nDCG5_list = []
  53. self.nDCG10_list = []
  54. def update(self, predict, y_true):
  55. predict = predict.flatten()
  56. y_true = y_true.flatten()
  57. # predict = np.interp(predict, (predict.min(), predict.max()), (0, 1))
  58. self.AUC_list.append(AUC(y_true, predict))
  59. self.MRR_list.append(MRR(y_true, predict))
  60. self.nDCG5_list.append(nDCG(y_true, predict, 5))
  61. self.nDCG10_list.append(nDCG(y_true, predict, 10))
  62. def eval(self):
  63. auc = np.mean(self.AUC_list)
  64. print('AUC:', auc)
  65. print('MRR:', np.mean(self.MRR_list))
  66. print('nDCG@5:', np.mean(self.nDCG5_list))
  67. print('nDCG@10:', np.mean(self.nDCG10_list))
  68. return auc
  69. def get_metric(result_path, label_path, metric):
  70. """get accuracy"""
  71. result_files = os.listdir(result_path)
  72. for file in result_files:
  73. result_file = os.path.join(result_path, file)
  74. pred = np.fromfile(result_file, dtype=np.float32)
  75. label_file = os.path.join(label_path, file)
  76. label = np.fromfile(label_file, dtype=np.int32)
  77. if np.nan in pred:
  78. continue
  79. metric.update(pred, label)
  80. auc = metric.eval()
  81. return auc
  82. if __name__ == "__main__":
  83. naml_metric = NAMLMetric()
  84. get_metric(args.result_path, args.label_path, naml_metric)