You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

event_parse.py 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """EventParser for summary event."""
  16. from collections import namedtuple
  17. from typing import Dict, Iterable, List, Optional, Tuple
  18. from mindinsight.explainer.common.enums import PluginNameEnum
  19. from mindinsight.explainer.common.log import logger
  20. from mindinsight.utils.exceptions import UnknownError
  21. _IMAGE_DATA_TAGS = {
  22. 'image_data': PluginNameEnum.IMAGE_DATA.value,
  23. 'ground_truth_label': PluginNameEnum.GROUND_TRUTH_LABEL.value,
  24. 'inference': PluginNameEnum.INFERENCE.value,
  25. 'explanation': PluginNameEnum.EXPLANATION.value
  26. }
  27. class EventParser:
  28. """Parser for event data."""
  29. def __init__(self, job):
  30. self._job = job
  31. self._sample_pool = {}
  32. def clear(self):
  33. """Clear the loaded data."""
  34. self._sample_pool.clear()
  35. def parse_metadata(self, metadata) -> Tuple[List, List, List]:
  36. """Parse the metadata event."""
  37. explainers = list(metadata.explain_method)
  38. metrics = list(metadata.benchmark_method)
  39. labels = list(metadata.label)
  40. return explainers, metrics, labels
  41. def parse_benchmark(self, benchmark) -> Dict:
  42. """Parse the benchmark event."""
  43. imported_benchmark = {}
  44. for explainer_result in benchmark:
  45. explainer = explainer_result.explain_method
  46. total_score = explainer_result.total_score
  47. label_score = explainer_result.label_score
  48. explainer_benchmark = {
  49. 'explainer': explainer,
  50. 'evaluations': EventParser._total_score_to_dict(total_score),
  51. 'class_scores': EventParser._label_score_to_dict(label_score, self._job.labels)
  52. }
  53. imported_benchmark[explainer] = explainer_benchmark
  54. return imported_benchmark
  55. def parse_sample(self, sample: namedtuple) -> Optional[namedtuple]:
  56. """Parse the sample event."""
  57. sample_id = sample.image_id
  58. if sample_id not in self._sample_pool:
  59. self._sample_pool[sample_id] = sample
  60. return None
  61. for tag in _IMAGE_DATA_TAGS:
  62. try:
  63. if tag == PluginNameEnum.INFERENCE.value:
  64. self._parse_inference(sample, sample_id)
  65. elif tag == PluginNameEnum.EXPLANATION.value:
  66. self._parse_explanation(sample, sample_id)
  67. else:
  68. self._parse_sample_info(sample, sample_id, tag)
  69. except UnknownError as ex:
  70. logger.warning("Parse %s data failed within image related data,"
  71. " detail: %r", tag, str(ex))
  72. continue
  73. if EventParser._is_sample_data_complete(self._sample_pool[sample_id]):
  74. return self._sample_pool.pop(sample_id)
  75. if EventParser._is_ready_for_display(self._sample_pool[sample_id]):
  76. return self._sample_pool[sample_id]
  77. return None
  78. def _parse_inference(self, event, sample_id):
  79. """Parse the inference event."""
  80. self._sample_pool[sample_id].inference.ground_truth_prob.extend(
  81. event.inference.ground_truth_prob)
  82. self._sample_pool[sample_id].inference.predicted_label.extend(
  83. event.inference.predicted_label)
  84. self._sample_pool[sample_id].inference.predicted_prob.extend(
  85. event.inference.predicted_prob)
  86. def _parse_explanation(self, event, sample_id):
  87. """Parse the explanation event."""
  88. if event.explanation:
  89. for explanation_item in event.explanation:
  90. new_explanation = self._sample_pool[sample_id].explanation.add()
  91. new_explanation.explain_method = explanation_item.explain_method
  92. new_explanation.label = explanation_item.label
  93. new_explanation.heatmap = explanation_item.heatmap
  94. def _parse_sample_info(self, event, sample_id, tag):
  95. """Parse the event containing image info."""
  96. if not getattr(self._sample_pool[sample_id], tag):
  97. setattr(self._sample_pool[sample_id], tag, getattr(event, tag))
  98. @staticmethod
  99. def _total_score_to_dict(total_scores: Iterable):
  100. """Transfer a list of benchmark score to a list of dict."""
  101. evaluation_info = []
  102. for total_score in total_scores:
  103. metric_result = {
  104. 'metric': total_score.benchmark_method,
  105. 'score': total_score.score}
  106. evaluation_info.append(metric_result)
  107. return evaluation_info
  108. @staticmethod
  109. def _label_score_to_dict(label_scores: Iterable, labels: List[str]):
  110. """Transfer a list of benchmark score."""
  111. evaluation_info = [{'label': label, 'evaluations': []}
  112. for label in labels]
  113. for label_score in label_scores:
  114. metric = label_score.benchmark_method
  115. for i, score in enumerate(label_score.score):
  116. label_metric_score = {
  117. 'metric': metric,
  118. 'score': score}
  119. evaluation_info[i]['evaluations'].append(label_metric_score)
  120. return evaluation_info
  121. @staticmethod
  122. def _is_sample_data_complete(image_container: namedtuple) -> bool:
  123. """Check whether sample data completely loaded."""
  124. required_attrs = ['image_id', 'image_data', 'ground_truth_label', 'inference', 'explanation']
  125. for attr in required_attrs:
  126. if not EventParser.is_attr_ready(image_container, attr):
  127. return False
  128. return True
  129. @staticmethod
  130. def _is_ready_for_display(image_container: namedtuple) -> bool:
  131. """
  132. Check whether the image_container is ready for frontend display.
  133. Args:
  134. image_container (nametuple): container consists of sample data
  135. Return:
  136. bool: whether the image_container if ready for display
  137. """
  138. required_attrs = ['image_id', 'image_data', 'ground_truth_label', 'inference']
  139. for attr in required_attrs:
  140. if not EventParser.is_attr_ready(image_container, attr):
  141. return False
  142. return True
  143. @staticmethod
  144. def is_attr_ready(image_container: namedtuple, attr: str) -> bool:
  145. """
  146. Check whether the given attribute is ready in image_container.
  147. Args:
  148. image_container (nametuple): container consist of sample data
  149. attr (str): attribute to check
  150. Returns:
  151. bool, whether the attr is ready
  152. """
  153. if getattr(image_container, attr, False):
  154. return True
  155. return False