"""
Area under curve metric.
"""
from sklearn.metrics import roc_auc_score

from mindspore import context
from mindspore.nn.metrics import Metric
from mindspore.communication.management import get_rank, get_group_size


class AUCMetric(Metric):
    """
    Area under the ROC curve (AUC) metric for binary classification.
    Accumulates predictions and labels across batches.
    """

    def __init__(self):
        super(AUCMetric, self).__init__()
        self.clear()
        self.full_batch = context.get_auto_parallel_context("full_batch")

    def clear(self):
        """Clear the internal evaluation result."""
        self.true_labels = []
        self.pred_probs = []

    def update(self, *inputs):
        """Accumulate one batch of (loss, predict, label) network outputs."""
        all_predict = inputs[1].asnumpy().flatten().tolist()  # predict
        all_label = inputs[2].asnumpy().flatten().tolist()  # label
        self.pred_probs.extend(all_predict)
        if self.full_batch:
            # With full_batch parallelism each device sees the whole global
            # batch of labels, so keep only this rank's shard to line up with
            # the predictions accumulated on this device.
            rank_id = get_rank()
            group_size = get_group_size()
            gap = len(all_label) // group_size
            self.true_labels.extend(all_label[rank_id * gap: (rank_id + 1) * gap])
        else:
            self.true_labels.extend(all_label)

    def eval(self):
        """Compute the AUC over everything accumulated since clear()."""
        if len(self.true_labels) != len(self.pred_probs):
            raise RuntimeError(
                'true_labels size is not equal to pred_probs size')
        return roc_auc_score(self.true_labels, self.pred_probs)
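

# Minimal usage sketch (illustrative only, not part of the training script):
# assumes a single, non-distributed device where full_batch is False, and uses
# hypothetical tensors named loss/predict/label matching the (loss, predict,
# label) tuple that update() indexes above. The loss tensor is ignored by the
# metric; only predict (inputs[1]) and label (inputs[2]) are read.
if __name__ == "__main__":
    import numpy as np
    from mindspore import Tensor

    metric = AUCMetric()
    metric.clear()
    loss = Tensor(np.zeros((1,), dtype=np.float32))  # placeholder, unused
    predict = Tensor(np.array([[0.9], [0.1], [0.8], [0.3]], dtype=np.float32))
    label = Tensor(np.array([[1.0], [0.0], [1.0], [0.0]], dtype=np.float32))
    metric.update(loss, predict, label)
    print("auc:", metric.eval())  # 1.0 for this perfectly ranked toy batch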