
explain_job.py 13 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """ExplainJob."""
  16. import os
  17. from datetime import datetime
  18. from typing import List, Iterable, Union
  19. from mindinsight.explainer.common.enums import PluginNameEnum
  20. from mindinsight.explainer.common.log import logger
  21. from mindinsight.explainer.manager.explain_parser import _ExplainParser
  22. from mindinsight.explainer.manager.event_parse import EventParser
  23. from mindinsight.datavisual.data_access.file_handler import FileHandler
  24. from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
class ExplainJob:
    """ExplainJob, which manages the records in a summary file."""

    def __init__(self,
                 job_id: str,
                 summary_dir: str,
                 create_time: float,
                 latest_update_time: float):
        self._job_id = job_id
        self._summary_dir = summary_dir
        self._parser = _ExplainParser(summary_dir)
        self._event_parser = EventParser(self)
        self._latest_update_time = latest_update_time
        self._create_time = create_time
        self._labels = []
        self._metrics = []
        self._explainers = []
        self._samples_info = {}
        self._labels_info = {}
        self._benchmark = {}
        self._overlay_dict = {}
        self._image_dict = {}

    @property
    def all_classes(self):
        """
        Return a list of label info.

        Returns:
            class_objs (List[ClassObj]): a list of class objects, each object
                contains:

                - id (int): label id
                - label (str): label name
                - sample_count (int): number of samples for each label
        """
        all_classes_return = []
        for label_id, label_info in self._labels_info.items():
            single_info = {'id': label_id,
                           'label': label_info['label'],
                           'sample_count': len(label_info['sample_ids'])}
            all_classes_return.append(single_info)
        return all_classes_return

    @property
    def explainers(self):
        """
        Return a list of explainer names.

        Returns:
            list[str], explainer names.
        """
        return self._explainers

    @property
    def explainer_scores(self):
        """Return the evaluation results for every explainer."""
        return list(self._benchmark.values())

    @property
    def sample_count(self):
        """
        Return the total number of samples in the job.

        Returns:
            int, total number of samples.
        """
        return len(self._samples_info)

    @property
    def train_id(self):
        """
        Return the ID of the explain job.

        Returns:
            str, id of the ExplainJob object.
        """
        return self._job_id

    @property
    def metrics(self):
        """
        Return a list of metric names.

        Returns:
            list[str], metric names.
        """
        return self._metrics

    @property
    def min_confidence(self):
        """
        Return the minimum confidence.

        Returns:
            float, optional, the minimum confidence; currently always None.
        """
        return None

    @property
    def create_time(self):
        """
        Return the create time of the summary file.

        Returns:
            float, creation timestamp.
        """
        return self._create_time

    @property
    def labels(self):
        """Return the labels contained in the job."""
        return self._labels

    @property
    def latest_update_time(self):
        """
        Return the last modification timestamp of the summary file.

        Returns:
            float, last modification timestamp.
        """
        return self._latest_update_time

    @latest_update_time.setter
    def latest_update_time(self, new_time: Union[float, datetime]):
        """
        Update the latest_update_time timestamp manually.

        Args:
            new_time (Union[float, datetime]): updated timestamp for the job.
        """
        if isinstance(new_time, datetime):
            self._latest_update_time = new_time.timestamp()
        elif isinstance(new_time, (int, float)):
            self._latest_update_time = new_time
        else:
            raise TypeError('new_time should be a float or datetime')

    @property
    def loader_id(self):
        """Return the job id."""
        return self._job_id

    @property
    def samples(self):
        """Return the information of all samples in the job."""
        return self._samples_info

    @staticmethod
    def get_create_time(file_path: str) -> float:
        """Return the creation timestamp of the given path."""
        create_time = os.stat(file_path).st_ctime
        return create_time

    @staticmethod
    def get_update_time(file_path: str) -> float:
        """Return the last modification timestamp of the given path."""
        update_time = os.stat(file_path).st_mtime
        return update_time

    @staticmethod
    def _total_score_to_dict(total_scores: Iterable):
        """Convert a list of overall benchmark scores to a list of dicts."""
        evaluation_info = []
        for total_score in total_scores:
            metric_result = {'metric': total_score.benchmark_method,
                             'score': total_score.score}
            evaluation_info.append(metric_result)
        return evaluation_info

    @staticmethod
    def _label_score_to_dict(label_scores: Iterable, labels: List[str]):
        """Convert a list of per-label benchmark scores to a list of dicts."""
        evaluation_info = [{'label': label, 'evaluations': []}
                           for label in labels]
        for label_score in label_scores:
            metric = label_score.benchmark_method
            for i, score in enumerate(label_score.score):
                label_metric_score = {'metric': metric, 'score': score}
                evaluation_info[i]['evaluations'].append(label_metric_score)
        return evaluation_info

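    # For reference, the two converters above produce shapes along these lines
    # (keys come from the code; the metric names and values below are made up):
    #   _total_score_to_dict(...) -> [{'metric': 'faithfulness', 'score': 0.83}, ...]
    #   _label_score_to_dict(...) -> [{'label': 'cat', 'evaluations': [
    #                                     {'metric': 'faithfulness', 'score': 0.79}, ...]}, ...]
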
    def _initialize_labels_info(self):
        """Initialize a dict for the labels in the job."""
        if self._labels is None:
            logger.warning('No labels are provided in job %s', self._job_id)
            return

        for label_id, label in enumerate(self._labels):
            self._labels_info[label_id] = {'label': label,
                                           'sample_ids': set()}

    def _explanation_to_dict(self, explanation, sample_id):
        """Convert an explanation event to its dict storage form."""
        explainer_name = explanation.explain_method
        explain_label = explanation.label
        saliency = explanation.heatmap
        saliency_id = '{}_{}_{}'.format(
            sample_id, explain_label, explainer_name)

        explain_info = {
            'explainer': explainer_name,
            'overlay': saliency_id,
        }
        self._overlay_dict[saliency_id] = saliency
        return explain_info

    def _image_container_to_dict(self, sample_data):
        """Convert the image container to its dict storage form."""
        sample_id = sample_data.image_id

        sample_info = {
            'id': sample_id,
            'name': sample_id,
            'labels': [self._labels_info[x]['label']
                       for x in sample_data.ground_truth_label],
            'inferences': []}
        self._image_dict[sample_id] = sample_data.image_data

        ground_truth_labels = list(sample_data.ground_truth_label)
        ground_truth_probs = list(sample_data.inference.ground_truth_prob)
        predicted_labels = list(sample_data.inference.predicted_label)
        predicted_probs = list(sample_data.inference.predicted_prob)

        inference_info = {}
        for label, prob in zip(
                ground_truth_labels + predicted_labels,
                ground_truth_probs + predicted_probs):
            inference_info[label] = {
                'label': self._labels_info[label]['label'],
                'confidence': prob,
                'saliency_maps': []}

        if EventParser.is_attr_ready(sample_data, 'explanation'):
            for explanation in sample_data.explanation:
                explanation_dict = self._explanation_to_dict(
                    explanation, sample_id)
                inference_info[explanation.label]['saliency_maps'].append(explanation_dict)

        sample_info['inferences'] = list(inference_info.values())
        return sample_info

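    # For reference, a populated sample_info dict looks roughly like this
    # (keys come from the code above; the example values are made up):
    #   {'id': 'sample_1', 'name': 'sample_1', 'labels': ['cat'],
    #    'inferences': [{'label': 'cat', 'confidence': 0.92,
    #                    'saliency_maps': [{'explainer': 'Gradient',
    #                                       'overlay': 'sample_1_0_Gradient'}]}]}
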
    def _import_sample(self, sample):
        """Add the sample object of a given sample id."""
        for label_id in sample.ground_truth_label:
            self._labels_info[label_id]['sample_ids'].add(sample.image_id)

        sample_info = self._image_container_to_dict(sample)
        self._samples_info.update({sample_info['id']: sample_info})

    def retrieve_image(self, image_id: str):
        """
        Retrieve image data from the job given an image_id.

        Returns:
            str, image data encoded in base64.
        """
        return self._image_dict.get(image_id, None)

    def retrieve_overlay(self, overlay_id: str):
        """
        Retrieve a saliency map from the job given an overlay_id.

        Returns:
            str, saliency map data encoded in base64.
        """
        return self._overlay_dict.get(overlay_id, None)

    def get_all_samples(self):
        """
        Return a list of sample information cached in the explain job.

        Returns:
            sample_list (List[SampleObj]): a list of sample objects, each object
                consists of:

                - id (int): sample id
                - name (str): basename of image
                - labels (list[str]): list of labels
                - inferences (list[dict]): list of inference results
        """
        samples_in_list = list(self._samples_info.values())
        return samples_in_list

    def _is_metadata_empty(self):
        """Check whether the metadata is still empty, i.e. not loaded yet."""
        if not self._explainers or not self._metrics or not self._labels:
            return True
        return False

    def _import_data_from_event(self, event):
        """Parse and import data from the event data."""
        tags = {
            'image_id': PluginNameEnum.IMAGE_ID,
            'benchmark': PluginNameEnum.BENCHMARK,
            'metadata': PluginNameEnum.METADATA
        }

        if 'metadata' not in event and self._is_metadata_empty():
            raise ValueError('metadata is empty, should write metadata first '
                             'in the summary.')

        for tag in tags:
            if tag not in event:
                continue

            if tag == PluginNameEnum.IMAGE_ID.value:
                sample_event = event[tag]
                sample_data = self._event_parser.parse_sample(sample_event)
                if sample_data is not None:
                    self._import_sample(sample_data)
                continue

            if tag == PluginNameEnum.BENCHMARK.value:
                benchmark_event = event[tag].benchmark
                benchmark = self._event_parser.parse_benchmark(benchmark_event)
                self._benchmark = benchmark
            elif tag == PluginNameEnum.METADATA.value:
                metadata_event = event[tag].metadata
                metadata = self._event_parser.parse_metadata(metadata_event)
                self._explainers, self._metrics, self._labels = metadata
                self._initialize_labels_info()

    def load(self):
        """Start loading data from the parser."""
        valid_file_names = []
        for filename in FileHandler.list_dir(self._summary_dir):
            if FileHandler.is_file(
                    FileHandler.join(self._summary_dir, filename)):
                valid_file_names.append(filename)

        if not valid_file_names:
            raise TrainJobNotExistError(
                'No summary file found in %s, explain job will be deleted.' % self._summary_dir)

        is_end = False
        while not is_end:
            is_clean, is_end, event = self._parser.parse_explain(valid_file_names)

            if is_clean:
                logger.info('Summary file in %s updated, clean the loaded data and reload.',
                            self._summary_dir)
                self._clean_job()

            if event:
                self._import_data_from_event(event)

    def _clean_job(self):
        """Clean the cached data in the job."""
        self._latest_update_time = ExplainJob.get_update_time(self._summary_dir)
        self._create_time = ExplainJob.get_create_time(self._summary_dir)
        self._labels.clear()
        self._metrics.clear()
        self._explainers.clear()
        self._samples_info.clear()
        self._labels_info.clear()
        self._benchmark.clear()
        self._overlay_dict.clear()
        self._image_dict.clear()
        self._event_parser.clear()
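

# Minimal usage sketch (illustrative, not part of the original module). The
# summary directory path below is hypothetical, and in MindInsight this class
# is presumably driven by the explainer manager rather than run directly.
if __name__ == '__main__':
    _summary_dir = '/path/to/summary_dir'  # hypothetical path containing summary files
    _job = ExplainJob(job_id=_summary_dir,  # job_id format here is an assumption
                      summary_dir=_summary_dir,
                      create_time=ExplainJob.get_create_time(_summary_dir),
                      latest_update_time=ExplainJob.get_update_time(_summary_dir))
    _job.load()  # raises TrainJobNotExistError if no summary file is found
    print('sample count:', _job.sample_count)
    print('classes:', _job.all_classes)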