You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

metrics.py 1.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Area under curve (AUC) metric
  17. """
  18. from mindspore.nn.metrics import Metric
  19. from sklearn.metrics import roc_auc_score
  20. class AUCMetric(Metric):
  21. """
  22. Area under cure metric
  23. """
  24. def __init__(self):
  25. super(AUCMetric, self).__init__()
  26. self.clear()
  27. def clear(self):
  28. """Clear the internal evaluation result."""
  29. self.true_labels = []
  30. self.pred_probs = []
  31. def update(self, *inputs): # inputs
  32. all_predict = inputs[1].asnumpy() # predict
  33. all_label = inputs[2].asnumpy() # label
  34. self.true_labels.extend(all_label.flatten().tolist())
  35. self.pred_probs.extend(all_predict.flatten().tolist())
  36. def eval(self):
  37. if len(self.true_labels) != len(self.pred_probs):
  38. raise RuntimeError(
  39. 'true_labels.size is not equal to pred_probs.size()')
  40. auc = roc_auc_score(self.true_labels, self.pred_probs)
  41. print("====" * 20 + " auc_metric end")
  42. print("====" * 20 + " auc: {}".format(auc))
  43. return auc