diff --git a/mindinsight/explainer/manager/explain_loader.py b/mindinsight/explainer/manager/explain_loader.py index 1c70c368..61ad80f3 100644 --- a/mindinsight/explainer/manager/explain_loader.py +++ b/mindinsight/explainer/manager/explain_loader.py @@ -14,6 +14,7 @@ # ============================================================================ """ExplainLoader.""" +import math import os import re from collections import defaultdict @@ -27,6 +28,7 @@ from mindinsight.datavisual.data_access.file_handler import FileHandler from mindinsight.datavisual.common.exceptions import TrainJobNotExistError from mindinsight.utils.exceptions import ParamValueError, UnknownError +_NAN_CONSTANT = 'NaN' _NUM_DIGITS = 6 _EXPLAIN_FIELD_NAMES = [ @@ -44,7 +46,10 @@ _SAMPLE_FIELD_NAMES = [ def _round(score): """Take round of a number to given precision.""" - return round(score, _NUM_DIGITS) + try: + return round(score, _NUM_DIGITS) + except TypeError: + return score class ExplainLoader: @@ -405,7 +410,7 @@ class ExplainLoader: metric_score = benchmark.total_score label_score_event = benchmark.label_score - explainer_score[explainer][metric] = metric_score + explainer_score[explainer][metric] = _NAN_CONSTANT if math.isnan(metric_score) else metric_score new_label_score_dict = ExplainLoader._score_event_to_dict(label_score_event, metric) for label, scores_of_metric in new_label_score_dict.items(): if label not in label_score[explainer]: @@ -551,5 +556,5 @@ class ExplainLoader: """Transfer metric scores per label to pre-defined structure.""" new_label_score_dict = defaultdict(dict) for label_id, label_score in enumerate(label_score_event): - new_label_score_dict[label_id][metric] = label_score + new_label_score_dict[label_id][metric] = _NAN_CONSTANT if math.isnan(label_score) else label_score return new_label_score_dict