@@ -14,6 +14,7 @@
 # ============================================================================
 """ExplainLoader."""
+import math
 import os
 import re
 from collections import defaultdict
@@ -27,6 +28,7 @@ from mindinsight.datavisual.data_access.file_handler import FileHandler
 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
 from mindinsight.utils.exceptions import ParamValueError, UnknownError
 
+_NAN_CONSTANT = 'NaN'
 _NUM_DIGITS = 6
 
 _EXPLAIN_FIELD_NAMES = [
@@ -44,7 +46,10 @@ _SAMPLE_FIELD_NAMES = [
 
 def _round(score):
     """Take round of a number to given precision."""
-    return round(score, _NUM_DIGITS)
+    try:
+        return round(score, _NUM_DIGITS)
+    except TypeError:
+        return score
 
 
 class ExplainLoader:
@@ -405,7 +410,7 @@ class ExplainLoader:
             metric_score = benchmark.total_score
             label_score_event = benchmark.label_score
 
-            explainer_score[explainer][metric] = metric_score
+            explainer_score[explainer][metric] = _NAN_CONSTANT if math.isnan(metric_score) else metric_score
             new_label_score_dict = ExplainLoader._score_event_to_dict(label_score_event, metric)
             for label, scores_of_metric in new_label_score_dict.items():
                 if label not in label_score[explainer]:
@@ -551,5 +556,5 @@
         """Transfer metric scores per label to pre-defined structure."""
         new_label_score_dict = defaultdict(dict)
        for label_id, label_score in enumerate(label_score_event):
-            new_label_score_dict[label_id][metric] = label_score
+            new_label_score_dict[label_id][metric] = _NAN_CONSTANT if math.isnan(label_score) else label_score
         return new_label_score_dict
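
For readers wondering why the patch swaps float NaN for a string sentinel rather than passing the raw value through: standard JSON has no NaN literal, so serializing a raw `float('nan')` either emits non-standard output or fails outright in strict mode. The sketch below is a minimal, standalone illustration of that behavior; `sanitize_score` is a hypothetical helper that mirrors the conditional added in the diff, not a function from the MindInsight codebase.

```python
import json
import math

_NAN_CONSTANT = 'NaN'


def sanitize_score(score):
    # Mirror of the conditional added in the diff: swap float NaN for a
    # JSON-safe string sentinel, pass every other value through untouched.
    return _NAN_CONSTANT if math.isnan(score) else score


# Python's json module emits a non-standard literal for NaN by default,
# which strict consumers (e.g. JavaScript's JSON.parse) reject...
print(json.dumps(float('nan')))        # NaN  -- not valid JSON
# ...and refuses to serialize it at all in strict mode:
try:
    json.dumps(float('nan'), allow_nan=False)
except ValueError as err:
    print(err)                         # Out of range float values are not JSON compliant

# After sanitizing, the payload is plain, portable JSON:
print(json.dumps(sanitize_score(float('nan'))))  # "NaN"
print(json.dumps(sanitize_score(0.123456)))      # 0.123456
```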
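The reworked `_round` swallows the `TypeError` that `round()` raises for non-numeric input and returns the value unchanged. A plausible motivation (my reading of the patch, not stated in it): once benchmark scores can be the `'NaN'` placeholder string, any later rounding pass must tolerate non-float values; note that `round(float('nan'), 6)` succeeds, so the guard is specifically for non-numeric types. A standalone sketch of the patched helper:

```python
_NUM_DIGITS = 6
_NAN_CONSTANT = 'NaN'


def _round(score):
    """Take round of a number to given precision."""
    try:
        return round(score, _NUM_DIGITS)
    except TypeError:
        # Non-numeric scores (e.g. the 'NaN' placeholder string)
        # don't support round(); pass them through as-is.
        return score


print(_round(0.12345678))     # 0.123457
print(_round(_NAN_CONSTANT))  # NaN -- round(str) raises TypeError, value returned unchanged
```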