Browse Source

change the NaN value to 'NaN' string for json format

tags/v1.1.0
YuhanShi53 5 years ago
parent
commit
ca3cb83b68
1 changed files with 8 additions and 3 deletions
  1. +8
    -3
      mindinsight/explainer/manager/explain_loader.py

+ 8
- 3
mindinsight/explainer/manager/explain_loader.py View File

@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""ExplainLoader.""" """ExplainLoader."""


import math
import os import os
import re import re
from collections import defaultdict from collections import defaultdict
@@ -27,6 +28,7 @@ from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.utils.exceptions import ParamValueError, UnknownError from mindinsight.utils.exceptions import ParamValueError, UnknownError


_NAN_CONSTANT = 'NaN'
_NUM_DIGITS = 6 _NUM_DIGITS = 6


_EXPLAIN_FIELD_NAMES = [ _EXPLAIN_FIELD_NAMES = [
@@ -44,7 +46,10 @@ _SAMPLE_FIELD_NAMES = [


def _round(score): def _round(score):
"""Take round of a number to given precision.""" """Take round of a number to given precision."""
return round(score, _NUM_DIGITS)
try:
return round(score, _NUM_DIGITS)
except TypeError:
return score




class ExplainLoader: class ExplainLoader:
@@ -405,7 +410,7 @@ class ExplainLoader:
metric_score = benchmark.total_score metric_score = benchmark.total_score
label_score_event = benchmark.label_score label_score_event = benchmark.label_score


explainer_score[explainer][metric] = metric_score
explainer_score[explainer][metric] = _NAN_CONSTANT if math.isnan(metric_score) else metric_score
new_label_score_dict = ExplainLoader._score_event_to_dict(label_score_event, metric) new_label_score_dict = ExplainLoader._score_event_to_dict(label_score_event, metric)
for label, scores_of_metric in new_label_score_dict.items(): for label, scores_of_metric in new_label_score_dict.items():
if label not in label_score[explainer]: if label not in label_score[explainer]:
@@ -551,5 +556,5 @@ class ExplainLoader:
"""Transfer metric scores per label to pre-defined structure.""" """Transfer metric scores per label to pre-defined structure."""
new_label_score_dict = defaultdict(dict) new_label_score_dict = defaultdict(dict)
for label_id, label_score in enumerate(label_score_event): for label_id, label_score in enumerate(label_score_event):
new_label_score_dict[label_id][metric] = label_score
new_label_score_dict[label_id][metric] = _NAN_CONSTANT if math.isnan(label_score) else label_score
return new_label_score_dict return new_label_score_dict

Loading…
Cancel
Save