Browse Source

!8740 Fix full batch error in field slice mode

From: @huangxinjing
Reviewed-by: @yao_yf,@stsuteng,@kisnwang
Signed-off-by: @stsuteng
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ae62d15bc0
1 changed files with 1 additions and 10 deletions
  1. +1
    -10
      model_zoo/official/recommend/wide_and_deep/src/metrics.py

+ 1
- 10
model_zoo/official/recommend/wide_and_deep/src/metrics.py View File

@@ -18,9 +18,7 @@ Area under cure metric
"""

from sklearn.metrics import roc_auc_score
from mindspore import context
from mindspore.nn.metrics import Metric
from mindspore.communication.management import get_rank, get_group_size

class AUCMetric(Metric):
"""
@@ -30,7 +28,6 @@ class AUCMetric(Metric):
def __init__(self):
super(AUCMetric, self).__init__()
self.clear()
self.full_batch = context.get_auto_parallel_context("full_batch")

def clear(self):
"""Clear the internal evaluation result."""
@@ -42,13 +39,7 @@ class AUCMetric(Metric):
all_predict = inputs[1].asnumpy().flatten().tolist() # predict
all_label = inputs[2].asnumpy().flatten().tolist() # label
self.pred_probs.extend(all_predict)
if self.full_batch:
rank_id = get_rank()
group_size = get_group_size()
gap = len(all_label) // group_size
self.true_labels.extend(all_label[rank_id*gap: (rank_id+1)*gap])
else:
self.true_labels.extend(all_label)
self.true_labels.extend(all_label)

def eval(self):
if len(self.true_labels) != len(self.pred_probs):


Loading…
Cancel
Save