|
|
|
@@ -301,9 +301,10 @@ class ExplainLoader: |
|
|
|
'id': sample_id, |
|
|
|
'name': str(sample_id), |
|
|
|
'image': sample_info['image'], |
|
|
|
'labels': sample_info['ground_truth_label'], |
|
|
|
'labels': [self._metadata['labels'][i] for i in sample_info['ground_truth_label']], |
|
|
|
} |
|
|
|
|
|
|
|
# Check whether the sample has valid label-prob pairs. |
|
|
|
if not ExplainLoader._is_inference_valid(sample_info): |
|
|
|
continue |
|
|
|
|
|
|
|
@@ -348,7 +349,7 @@ class ExplainLoader: |
|
|
|
elif tag == ExplainFieldsEnum.SAMPLE_ID.value: |
|
|
|
self._import_sample_from_event(event) |
|
|
|
else: |
|
|
|
logger.info('Unknown ExplainField: %s', tag) |
|
|
|
logger.info('Unknown ExplainField: %s.', tag) |
|
|
|
|
|
|
|
def _is_metadata_empty(self): |
|
|
|
"""Check whether metadata is completely loaded first.""" |
|
|
|
@@ -455,13 +456,11 @@ class ExplainLoader: |
|
|
|
|
|
|
|
for tag in _SAMPLE_FIELD_NAMES: |
|
|
|
try: |
|
|
|
if ExplainLoader._is_attr_empty(sample, tag.value): |
|
|
|
continue |
|
|
|
if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL: |
|
|
|
self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label)) |
|
|
|
elif tag == ExplainFieldsEnum.INFERENCE: |
|
|
|
self._import_inference_from_event(sample, sample_id) |
|
|
|
elif tag == ExplainFieldsEnum.EXPLANATION: |
|
|
|
else: |
|
|
|
self._import_explanation_from_event(sample, sample_id) |
|
|
|
except UnknownError as ex: |
|
|
|
logger.warning("Parse %s data failed within image related data, detail: %r", tag, str(ex)) |
|
|
|
@@ -518,59 +517,35 @@ class ExplainLoader: |
|
|
|
Returns: |
|
|
|
list[str], filename list. |
|
|
|
""" |
|
|
|
return list(filter(lambda filename: (re.search(r'summary\.\d+', filename) and filename.endswith("_explain")), |
|
|
|
filenames)) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _is_attr_empty(event, attr_name) -> bool: |
|
|
|
if not getattr(event, attr_name): |
|
|
|
return True |
|
|
|
for item in getattr(event, attr_name): |
|
|
|
if not isinstance(item, list) or item: |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _is_ground_truth_label_valid(sample_id: str, sample_info: Dict) -> bool: |
|
|
|
if len(sample_info['ground_truth_label']) != len(sample_info['ground_truth_prob']): |
|
|
|
logger.info('length of ground_truth_prob does not match the length of ground_truth_label' |
|
|
|
' length of ground_truth_label is: %s but length of ground_truth_prob is: %s.'
|
|
|
'sample_id is : %s.', |
|
|
|
len(sample_info['ground_truth_label']), len(sample_info['ground_truth_prob']), sample_id) |
|
|
|
return False |
|
|
|
return True |
|
|
|
return list(filter( |
|
|
|
lambda filename: (re.search(r'summary\.\d+', filename) and filename.endswith("_explain")), filenames)) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _is_inference_valid(sample): |
|
|
|
""" |
|
|
|
Check whether the inference data is empty or have the same length. |
|
|
|
|
|
|
|
If the probs have different length with the labels, it can be confusing when assigning each prob to label. |
|
|
|
'is_inference_valid' return True only when the data size of match to each other. Note that prob data could be |
|
|
|
If probs have different length with the labels, it can be confusing when assigning each prob to label. |
|
|
|
'_is_inference_valid' returns True only when the data sizes match each other. Note that prob data could be
|
|
|
empty, so empty prob will pass the check. |
|
|
|
""" |
|
|
|
ground_truth_len = len(sample['ground_truth_label']) |
|
|
|
for name in ['ground_truth_prob', 'ground_truth_prob_sd', |
|
|
|
'ground_truth_prob_itl95_low', 'ground_truth_prob_itl95_hi']: |
|
|
|
if sample[name] and len(sample[name]) != ground_truth_len: |
|
|
|
logger.info('Length of %s does not match the ground_truth_label. Length of ground_truth_label: %d,'
' length of %s: %d', name, ground_truth_len, name, len(sample[name]))
|
|
|
return False |
|
|
|
|
|
|
|
predicted_len = len(sample['predicted_label']) |
|
|
|
for name in ['predicted_prob', 'predicted_prob_sd', |
|
|
|
'predicted_prob_itl95_low', 'predicted_prob_itl95_hi']: |
|
|
|
if sample[name] and len(sample[name]) != predicted_len: |
|
|
|
logger.info('Length of %s does not match the predicted_labels. Length of predicted_label: %d,'
' length of %s: %d', name, predicted_len, name, len(sample[name]))
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _is_predicted_label_valid(sample_id: str, sample_info: Dict) -> bool: |
|
|
|
if len(sample_info['predicted_label']) != len(sample_info['predicted_prob']): |
|
|
|
logger.info('length of predicted_probs does not match the length of predicted_labels' |
|
|
|
'length of predicted_probs: %s but receive length of predicted_label: %s, sample_id: %s.', |
|
|
|
len(sample_info['predicted_prob']), len(sample_info['predicted_label']), sample_id) |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _score_event_to_dict(label_score_event, metric) -> Dict: |
|
|
|
"""Transfer metric scores per label to pre-defined structure.""" |
|
|
|
|