Browse Source

change the NaN value to 'NaN' string for json format

tags/v1.1.0
YuhanShi53 5 years ago
parent
commit
ca3cb83b68
1 changed files with 8 additions and 3 deletions
  1. +8
    -3
      mindinsight/explainer/manager/explain_loader.py

+ 8
- 3
mindinsight/explainer/manager/explain_loader.py View File

@@ -14,6 +14,7 @@
# ============================================================================
"""ExplainLoader."""

import math
import os
import re
from collections import defaultdict
@@ -27,6 +28,7 @@ from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.utils.exceptions import ParamValueError, UnknownError

_NAN_CONSTANT = 'NaN'
_NUM_DIGITS = 6

_EXPLAIN_FIELD_NAMES = [
@@ -44,7 +46,10 @@ _SAMPLE_FIELD_NAMES = [

def _round(score):
"""Take round of a number to given precision."""
return round(score, _NUM_DIGITS)
try:
return round(score, _NUM_DIGITS)
except TypeError:
return score


class ExplainLoader:
@@ -405,7 +410,7 @@ class ExplainLoader:
metric_score = benchmark.total_score
label_score_event = benchmark.label_score

explainer_score[explainer][metric] = metric_score
explainer_score[explainer][metric] = _NAN_CONSTANT if math.isnan(metric_score) else metric_score
new_label_score_dict = ExplainLoader._score_event_to_dict(label_score_event, metric)
for label, scores_of_metric in new_label_score_dict.items():
if label not in label_score[explainer]:
@@ -551,5 +556,5 @@ class ExplainLoader:
"""Transfer metric scores per label to pre-defined structure."""
new_label_score_dict = defaultdict(dict)
for label_id, label_score in enumerate(label_score_event):
new_label_score_dict[label_id][metric] = label_score
new_label_score_dict[label_id][metric] = _NAN_CONSTANT if math.isnan(label_score) else label_score
return new_label_score_dict

Loading…
Cancel
Save