|
|
|
@@ -14,6 +14,7 @@ |
|
|
|
# ============================================================================ |
|
|
|
"""ExplainLoader.""" |
|
|
|
|
|
|
|
import math |
|
|
|
import os |
|
|
|
import re |
|
|
|
from collections import defaultdict |
|
|
|
@@ -27,6 +28,7 @@ from mindinsight.datavisual.data_access.file_handler import FileHandler |
|
|
|
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError |
|
|
|
from mindinsight.utils.exceptions import ParamValueError, UnknownError |
|
|
|
|
|
|
|
_NAN_CONSTANT = 'NaN' |
|
|
|
_NUM_DIGITS = 6 |
|
|
|
|
|
|
|
_EXPLAIN_FIELD_NAMES = [ |
|
|
|
@@ -44,7 +46,10 @@ _SAMPLE_FIELD_NAMES = [ |
|
|
|
|
|
|
|
def _round(score): |
|
|
|
"""Take round of a number to given precision.""" |
|
|
|
return round(score, _NUM_DIGITS) |
|
|
|
try: |
|
|
|
return round(score, _NUM_DIGITS) |
|
|
|
except TypeError: |
|
|
|
return score |
|
|
|
|
|
|
|
|
|
|
|
class ExplainLoader: |
|
|
|
@@ -405,7 +410,7 @@ class ExplainLoader: |
|
|
|
metric_score = benchmark.total_score |
|
|
|
label_score_event = benchmark.label_score |
|
|
|
|
|
|
|
explainer_score[explainer][metric] = metric_score |
|
|
|
explainer_score[explainer][metric] = _NAN_CONSTANT if math.isnan(metric_score) else metric_score |
|
|
|
new_label_score_dict = ExplainLoader._score_event_to_dict(label_score_event, metric) |
|
|
|
for label, scores_of_metric in new_label_score_dict.items(): |
|
|
|
if label not in label_score[explainer]: |
|
|
|
@@ -551,5 +556,5 @@ class ExplainLoader: |
|
|
|
"""Transfer metric scores per label to pre-defined structure.""" |
|
|
|
new_label_score_dict = defaultdict(dict) |
|
|
|
for label_id, label_score in enumerate(label_score_event): |
|
|
|
new_label_score_dict[label_id][metric] = label_score |
|
|
|
new_label_score_dict[label_id][metric] = _NAN_CONSTANT if math.isnan(label_score) else label_score |
|
|
|
return new_label_score_dict |