Browse Source

Fix label saving format in ExplainLoader and remove useless methods

tags/v1.1.0
YuhanShi53 5 years ago
parent
commit
5535995070
2 changed files with 12 additions and 38 deletions
  1. +12
    -37
      mindinsight/explainer/manager/explain_loader.py
  2. +0
    -1
      mindinsight/explainer/manager/explain_manager.py

+ 12
- 37
mindinsight/explainer/manager/explain_loader.py View File

@@ -301,9 +301,10 @@ class ExplainLoader:
'id': sample_id,
'name': str(sample_id),
'image': sample_info['image'],
'labels': sample_info['ground_truth_label'],
'labels': [self._metadata['labels'][i] for i in sample_info['ground_truth_label']],
}

# Check whether the sample has valid label-prob pairs.
if not ExplainLoader._is_inference_valid(sample_info):
continue

@@ -348,7 +349,7 @@ class ExplainLoader:
elif tag == ExplainFieldsEnum.SAMPLE_ID.value:
self._import_sample_from_event(event)
else:
logger.info('Unknown ExplainField: %s', tag)
logger.info('Unknown ExplainField: %s.', tag)

def _is_metadata_empty(self):
"""Check whether metadata is completely loaded first."""
@@ -455,13 +456,11 @@ class ExplainLoader:

for tag in _SAMPLE_FIELD_NAMES:
try:
if ExplainLoader._is_attr_empty(sample, tag.value):
continue
if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL:
self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label))
elif tag == ExplainFieldsEnum.INFERENCE:
self._import_inference_from_event(sample, sample_id)
elif tag == ExplainFieldsEnum.EXPLANATION:
else:
self._import_explanation_from_event(sample, sample_id)
except UnknownError as ex:
logger.warning("Parse %s data failed within image related data, detail: %r", tag, str(ex))
@@ -518,59 +517,35 @@ class ExplainLoader:
Returns:
list[str], filename list.
"""
return list(filter(lambda filename: (re.search(r'summary\.\d+', filename) and filename.endswith("_explain")),
filenames))

@staticmethod
def _is_attr_empty(event, attr_name) -> bool:
    """Return True if the named event attribute is unset or holds only empty lists."""
    values = getattr(event, attr_name)
    if not values:
        return True
    # The attribute counts as empty only when every element is itself an
    # empty list; any non-list element or non-empty list makes it non-empty.
    return all(isinstance(value, list) and not value for value in values)

@staticmethod
def _is_ground_truth_label_valid(sample_id: str, sample_info: Dict) -> bool:
    """Check that a sample's ground-truth labels and probs have matching lengths.

    Args:
        sample_id (str): ID of the sample being validated, used for logging only.
        sample_info (Dict): Parsed sample data containing the
            'ground_truth_label' and 'ground_truth_prob' lists.

    Returns:
        bool, True if the two lists have the same length.
    """
    label_len = len(sample_info['ground_truth_label'])
    prob_len = len(sample_info['ground_truth_prob'])
    if label_len != prob_len:
        # A length mismatch makes the label -> prob assignment ambiguous.
        # Fixed: the original message had a typo ('ground_turth_label') and
        # implicit string concatenation with no separating spaces.
        logger.info('Length of ground_truth_prob does not match the length of ground_truth_label. '
                    'Length of ground_truth_label is: %s but length of ground_truth_prob is: %s. '
                    'sample_id is: %s.', label_len, prob_len, sample_id)
        return False
    return True
return list(filter(
lambda filename: (re.search(r'summary\.\d+', filename) and filename.endswith("_explain")), filenames))

@staticmethod
def _is_inference_valid(sample):
    """
    Check whether the inference data is empty or has consistent lengths.

    If probs have a different length from the labels, it can be confusing when
    assigning each prob to a label. '_is_inference_valid' returns True only when
    the data sizes match each other. Note that prob data could be empty, so an
    empty prob list will pass the check.

    Args:
        sample (dict): Parsed sample data holding 'ground_truth_label',
            'predicted_label' and the corresponding prob/statistics lists.

    Returns:
        bool, True if every non-empty prob list matches its label list in length.
    """
    ground_truth_len = len(sample['ground_truth_label'])
    for name in ['ground_truth_prob', 'ground_truth_prob_sd',
                 'ground_truth_prob_itl95_low', 'ground_truth_prob_itl95_hi']:
        # Empty prob data is allowed; only a non-empty mismatch is invalid.
        if sample[name] and len(sample[name]) != ground_truth_len:
            logger.info('Length of %s does not match the ground_truth_label. Length of ground_truth_label: %d, '
                        'length of %s: %d.', name, ground_truth_len, name, len(sample[name]))
            return False

    predicted_len = len(sample['predicted_label'])
    for name in ['predicted_prob', 'predicted_prob_sd',
                 'predicted_prob_itl95_low', 'predicted_prob_itl95_hi']:
        if sample[name] and len(sample[name]) != predicted_len:
            logger.info('Length of %s does not match the predicted_label. Length of predicted_label: %d, '
                        'length of %s: %d.', name, predicted_len, name, len(sample[name]))
            return False
    return True

@staticmethod
def _is_predicted_label_valid(sample_id: str, sample_info: Dict) -> bool:
    """Check that a sample's predicted labels and probs have matching lengths.

    Args:
        sample_id (str): ID of the sample being validated, used for logging only.
        sample_info (Dict): Parsed sample data containing the
            'predicted_label' and 'predicted_prob' lists.

    Returns:
        bool, True if the two lists have the same length.
    """
    label_len = len(sample_info['predicted_label'])
    prob_len = len(sample_info['predicted_prob'])
    if label_len != prob_len:
        # Fixed: the original message glued two string literals together
        # with no separating space or punctuation.
        logger.info('Length of predicted_prob does not match the length of predicted_label. '
                    'Length of predicted_prob: %s but receive length of predicted_label: %s, sample_id: %s.',
                    prob_len, label_len, sample_id)
        return False
    return True

@staticmethod
def _score_event_to_dict(label_score_event, metric) -> Dict:
"""Transfer metric scores per label to pre-defined structure."""


+ 0
- 1
mindinsight/explainer/manager/explain_manager.py View File

@@ -133,7 +133,6 @@ class ExplainManager:
return
time.sleep(repeat_interval)
except UnknownError as ex:
logger.exception(ex)
logger.error('Unexpected error happens when loading data. Loading status: %s, loading pool size: %d'
'Detail: %s', self._loading_status, len(self._loader_pool), str(ex))



Loading…
Cancel
Save