You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

_runner.py 31 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Runner."""
  16. import os
  17. import re
  18. import traceback
  19. from time import time
  20. from typing import Tuple, List, Optional
  21. import numpy as np
  22. from PIL import Image
  23. from scipy.stats import beta
  24. import mindspore as ms
  25. import mindspore.dataset as ds
  26. from mindspore import log
  27. from mindspore.nn import Softmax, Cell
  28. from mindspore.nn.probability.toolbox import UncertaintyEvaluation
  29. from mindspore.ops.operations import ExpandDims
  30. from mindspore.train._utils import check_value_type
  31. from mindspore.train.summary._summary_adapter import _convert_image_format
  32. from mindspore.train.summary.summary_record import SummaryRecord
  33. from mindspore.train.summary_pb2 import Explain
  34. from .benchmark import Localization
  35. from .explanation import RISE
  36. from .benchmark._attribution.metric import AttributionMetric, LabelSensitiveMetric, LabelAgnosticMetric
  37. from .explanation._attribution.attribution import Attribution
  38. # datafile directory names
  39. _DATAFILE_DIRNAME_PREFIX = "_explain_"
  40. _ORIGINAL_IMAGE_DIRNAME = "origin_images"
  41. _HEATMAP_DIRNAME = "heatmap"
  42. # max. no. of sample per directory
  43. _SAMPLE_PER_DIR = 1000
  44. _EXPAND_DIMS = ExpandDims()
  45. _SEED = 58 # set a seed to fix the iterating order of the dataset
  46. def _normalize(img_np):
  47. """Normalize the numpy image to the range of [0, 1]. """
  48. max_ = img_np.max()
  49. min_ = img_np.min()
  50. normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
  51. return normed
  52. def _np_to_image(img_np, mode):
  53. """Convert numpy array to PIL image."""
  54. return Image.fromarray(np.uint8(img_np * 255), mode=mode)
  55. def _calc_prob_interval(volume, probs, prob_vars):
  56. """Compute the confidence interval of probability."""
  57. if not isinstance(probs, np.ndarray):
  58. probs = np.asarray(probs)
  59. if not isinstance(prob_vars, np.ndarray):
  60. prob_vars = np.asarray(prob_vars)
  61. one_minus_probs = 1 - probs
  62. alpha_coef = (np.square(probs) * one_minus_probs / prob_vars) - probs
  63. beta_coef = alpha_coef * one_minus_probs / probs
  64. intervals = beta.interval(volume, alpha_coef, beta_coef)
  65. # avoid invalid result due to extreme small value of prob_vars
  66. lows = []
  67. highs = []
  68. for i, low in enumerate(intervals[0]):
  69. high = intervals[1][i]
  70. if prob_vars[i] <= 0 or \
  71. not np.isfinite(low) or low > probs[i] or \
  72. not np.isfinite(high) or high < probs[i]:
  73. low = probs[i]
  74. high = probs[i]
  75. lows.append(low)
  76. highs.append(high)
  77. return lows, highs
  78. def _get_id_dirname(sample_id: int):
  79. """Get the name of parent directory of the image id."""
  80. return str(int(sample_id / _SAMPLE_PER_DIR) * _SAMPLE_PER_DIR)
  81. def _extract_timestamp(filename: str):
  82. """Extract timestamp from summary filename."""
  83. matched = re.search(r"summary\.(\d+)", filename)
  84. if matched:
  85. return int(matched.group(1))
  86. return None
  87. class ExplainRunner:
  88. """
  89. A high-level API for users to generate and store results of the explanation methods and the evaluation methods.
  90. After generating results with the explanation methods and the evaluation methods, the results will be written into
  91. a specified file with `mindspore.summary.SummaryRecord`. The stored content can be viewed using MindInsight.
  92. Update in 2020.11: Adjust the storage structure and format of the data. Summary files generated by previous version
  93. will be deprecated and will not be supported in MindInsight of current version.
  94. Args:
  95. summary_dir (str, optional): The directory path to save the summary files which store the generated results.
  96. Default: "./"
  97. Examples:
  98. >>> from mindspore.explainer import ExplainRunner
  99. >>> # init a runner with a specified directory
  100. >>> summary_dir = "summary_dir"
  101. >>> runner = ExplainRunner(summary_dir)
  102. """
  103. def __init__(self, summary_dir: Optional[str] = "./"):
  104. check_value_type("summary_dir", summary_dir, str)
  105. self._summary_dir = summary_dir
  106. self._count = 0
  107. self._classes = None
  108. self._model = None
  109. self._uncertainty = None
  110. self._summary_timestamp = None
  111. def run(self,
  112. dataset: Tuple,
  113. explainers: List,
  114. benchmarkers: Optional[List] = None,
  115. uncertainty: Optional[UncertaintyEvaluation] = None,
  116. activation_fn: Optional[Cell] = Softmax()):
  117. """
  118. Genereates results and writes results into the summary files in `summary_dir` specified during the object
  119. initialization.
  120. Args:
  121. dataset (tuple): A tuple that contains `mindspore.dataset` object for iteration and its labels.
  122. - dataset[0]: A `mindspore.dataset` object to provide data to explain.
  123. - dataset[1]: A list of string that specifies the label names of the dataset.
  124. explainers (list[Explanation]): A list of explanation objects to generate attribution results. Explanation
  125. object is an instance initialized with the explanation methods in module
  126. `mindspore.explainer.explanation`.
  127. benchmarkers (list[Benchmark], optional): A list of benchmark objects to generate evaluation results.
  128. Default: None
  129. uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference
  130. uncertainty of samples.
  131. activation_fn (Cell, optional): The activation layer that transforms the output of the network to
  132. label probability distribution :math:`P(y|x)`. Default: Softmax().
  133. Examples:
  134. >>> from mindspore.explainer import ExplainRunner
  135. >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
  136. >>> from mindspore.nn import Sigmoid
  137. >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
  138. >>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10
  139. >>> dataset = get_dataset('/path/to/Cifar10_dataset')
  140. >>> classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'turck']
  141. >>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10
  142. >>> param_dict = load_checkpoint("checkpoint.ckpt")
  143. >>> net = resnet50(len(classes))
  144. >>> load_param_into_net(net, param_dict)
  145. >>> gbp = GuidedBackprop(net)
  146. >>> gradient = Gradient(net)
  147. >>> runner = ExplainRunner("./")
  148. >>> explainers = [gbp, gradient]
  149. >>> runner.run((dataset, classes), explainers, activation_fn=Sigmoid())
  150. """
  151. check_value_type("dataset", dataset, tuple)
  152. if len(dataset) != 2:
  153. raise ValueError("Argument `dataset` should be a tuple with length = 2.")
  154. dataset, classes = dataset
  155. if benchmarkers is None:
  156. benchmarkers = []
  157. self._verify_data_form(dataset, benchmarkers)
  158. self._classes = classes
  159. check_value_type("explainers", explainers, list)
  160. if not explainers:
  161. raise ValueError("Argument `explainers` must be a non-empty list")
  162. for exp in explainers:
  163. if not isinstance(exp, Attribution):
  164. raise TypeError("Argument `explainers` should be a list of objects of classes in "
  165. "`mindspore.explainer.explanation`.")
  166. if benchmarkers:
  167. check_value_type("benchmarkers", benchmarkers, list)
  168. for bench in benchmarkers:
  169. if not isinstance(bench, AttributionMetric):
  170. raise TypeError("Argument `benchmarkers` should be a list of objects of classes in explanation"
  171. "`mindspore.explainer.benchmark`.")
  172. check_value_type("activation_fn", activation_fn, Cell)
  173. self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn])
  174. next_element = next(dataset.create_tuple_iterator())
  175. inputs, _, _ = self._unpack_next_element(next_element)
  176. prop_test = self._model(inputs)
  177. check_value_type("output of model im explainer", prop_test, ms.Tensor)
  178. if prop_test.shape[1] != len(self._classes):
  179. raise ValueError("The dimension of model output does not match the length of dataset classes. Please "
  180. "check dataset classes or the black-box model in the explainer again.")
  181. if uncertainty is not None:
  182. check_value_type("uncertainty", uncertainty, UncertaintyEvaluation)
  183. prop_var_test = uncertainty.eval_epistemic_uncertainty(inputs)
  184. check_value_type("output of uncertainty", prop_var_test, np.ndarray)
  185. if prop_var_test.shape[1] != len(self._classes):
  186. raise ValueError("The dimension of uncertainty output does not match the length of dataset classes"
  187. "classes. Please check dataset classes or the black-box model in the explainer again.")
  188. self._uncertainty = uncertainty
  189. else:
  190. self._uncertainty = None
  191. with SummaryRecord(self._summary_dir) as summary:
  192. spacer = '{:120}\r'
  193. print("Start running and writing......")
  194. begin = time()
  195. print("Start writing metadata......")
  196. self._summary_timestamp = _extract_timestamp(summary.event_file_name)
  197. if self._summary_timestamp is None:
  198. raise RuntimeError("Cannot extract timestamp from summary filename!"
  199. " It should contains a timestamp of 10 digits.")
  200. explain = Explain()
  201. explain.metadata.label.extend(classes)
  202. exp_names = [exp.__class__.__name__ for exp in explainers]
  203. explain.metadata.explain_method.extend(exp_names)
  204. if benchmarkers:
  205. bench_names = [bench.__class__.__name__ for bench in benchmarkers]
  206. explain.metadata.benchmark_method.extend(bench_names)
  207. summary.add_value("explainer", "metadata", explain)
  208. summary.record(1)
  209. print("Finish writing metadata.")
  210. now = time()
  211. print("Start running and writing inference data.....")
  212. imageid_labels = self._run_inference(dataset, summary)
  213. print(spacer.format("Finish running and writing inference data. "
  214. "Time elapsed: {:.3f} s".format(time() - now)))
  215. if not benchmarkers:
  216. for exp in explainers:
  217. start = time()
  218. print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
  219. self._count = 0
  220. ds.config.set_seed(_SEED)
  221. for idx, next_element in enumerate(dataset):
  222. now = time()
  223. self._run_exp_step(next_element, exp, imageid_labels, summary)
  224. print(spacer.format("Finish writing {}-th explanation data for {}. Time elapsed: "
  225. "{:.3f} s".format(idx, exp.__class__.__name__, time() - now)), end='')
  226. print(spacer.format(
  227. "Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format(
  228. exp.__class__.__name__, time() - start)))
  229. else:
  230. for exp in explainers:
  231. explain = Explain()
  232. for bench in benchmarkers:
  233. bench.reset()
  234. print(f"Start running and writing explanation and "
  235. f"benchmark data for {exp.__class__.__name__}......")
  236. self._count = 0
  237. start = time()
  238. ds.config.set_seed(_SEED)
  239. for idx, next_element in enumerate(dataset):
  240. now = time()
  241. saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary)
  242. print(spacer.format(
  243. "Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
  244. idx, exp.__class__.__name__, time() - now)), end='')
  245. for bench in benchmarkers:
  246. now = time()
  247. self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
  248. print(spacer.format(
  249. "Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
  250. idx, bench.__class__.__name__, exp.__class__.__name__, time() - now)), end='')
  251. for bench in benchmarkers:
  252. benchmark = explain.benchmark.add()
  253. benchmark.explain_method = exp.__class__.__name__
  254. benchmark.benchmark_method = bench.__class__.__name__
  255. benchmark.total_score = bench.performance
  256. if isinstance(bench, LabelSensitiveMetric):
  257. benchmark.label_score.extend(bench.class_performances)
  258. print(spacer.format("Finish running and writing explanation and benchmark data for {}. "
  259. "Time elapsed: {:.3f} s".format(exp.__class__.__name__, time() - start)))
  260. summary.add_value('explainer', 'benchmark', explain)
  261. summary.record(1)
  262. print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin))
  263. @staticmethod
  264. def _verify_data_form(dataset, benchmarkers):
  265. """
  266. Verify the validity of dataset.
  267. Args:
  268. dataset (`ds`): the user parsed dataset.
  269. benchmarkers (list[`AttributionMetric`]): the user parsed benchmarkers.
  270. """
  271. next_element = next(dataset.create_tuple_iterator())
  272. if len(next_element) not in [1, 2, 3]:
  273. raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
  274. " as columns.")
  275. if len(next_element) == 3:
  276. inputs, labels, bboxes = next_element
  277. if bboxes.shape[-1] != 4:
  278. raise ValueError("The third element of dataset should be bounding boxes with shape of "
  279. "[batch_size, num_ground_truth, 4].")
  280. else:
  281. if any(map(lambda benchmarker: isinstance(benchmarker, Localization), benchmarkers)):
  282. raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
  283. if len(next_element) == 2:
  284. inputs, labels = next_element
  285. if len(next_element) == 1:
  286. inputs = next_element[0]
  287. if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
  288. raise ValueError(
  289. "Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
  290. inputs.shape))
  291. if len(inputs.shape) == 3:
  292. log.warning(
  293. "Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
  294. " dimension as batch data.".format(inputs.shape))
  295. if len(next_element) > 1:
  296. if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
  297. raise ValueError(
  298. "Labels shape {} is unrecognizable: labels should not have more than two dimensions"
  299. " with length greater than 1.".format(labels.shape))
  300. def _transform_data(self, inputs, labels, bboxes, ifbbox):
  301. """
  302. Transform the data from one iteration of dataset to a unifying form for the follow-up operations.
  303. Args:
  304. inputs (Tensor): the image data
  305. labels (Tensor): the labels
  306. bboxes (Tensor): the boudnding boxes data
  307. ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t label
  308. id will be returned. If False, the returned bboxes is the the parsed bboxes.
  309. Returns:
  310. inputs (Tensor): the image data, unified to a 4D Tensor.
  311. labels (List[List[int]]): the ground truth labels.
  312. bboxes (Union[List[Dict], None, Tensor]): the bounding boxes
  313. """
  314. inputs = ms.Tensor(inputs, ms.float32)
  315. if len(inputs.shape) == 3:
  316. inputs = _EXPAND_DIMS(inputs, 0)
  317. if isinstance(labels, ms.Tensor):
  318. labels = ms.Tensor(labels, ms.int32)
  319. labels = _EXPAND_DIMS(labels, 0)
  320. if isinstance(bboxes, ms.Tensor):
  321. bboxes = ms.Tensor(bboxes, ms.int32)
  322. bboxes = _EXPAND_DIMS(bboxes, 0)
  323. input_len = len(inputs)
  324. if bboxes is not None and ifbbox:
  325. bboxes = ms.Tensor(bboxes, ms.int32)
  326. masks_lst = []
  327. labels = labels.asnumpy().reshape([input_len, -1])
  328. bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
  329. for idx, label in enumerate(labels):
  330. height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
  331. masks = {}
  332. for j, label_item in enumerate(label):
  333. target = int(label_item)
  334. if -1 < target < len(self._classes):
  335. if target not in masks:
  336. mask = np.zeros((1, 1, height, width))
  337. else:
  338. mask = masks[target]
  339. x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
  340. mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
  341. masks[target] = mask
  342. masks_lst.append(masks)
  343. bboxes = masks_lst
  344. labels = ms.Tensor(labels, ms.int32)
  345. if len(labels.shape) == 1:
  346. labels_lst = [[int(i)] for i in labels.asnumpy()]
  347. else:
  348. labels = labels.asnumpy().reshape([input_len, -1])
  349. labels_lst = []
  350. for item in labels:
  351. labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._classes))))
  352. labels = labels_lst
  353. return inputs, labels, bboxes
  354. def _unpack_next_element(self, next_element, ifbbox=False):
  355. """
  356. Unpack a single iteration of dataset.
  357. Args:
  358. next_element (Tuple): a single element iterated from dataset object.
  359. ifbbox (bool): whether to preprocess bboxes in self._transform_data.
  360. Returns:
  361. Tuple, a unified Tuple contains image_data, labels, and bounding boxes.
  362. """
  363. if len(next_element) == 3:
  364. inputs, labels, bboxes = next_element
  365. elif len(next_element) == 2:
  366. inputs, labels = next_element
  367. bboxes = None
  368. else:
  369. inputs = next_element[0]
  370. labels = [[] for _ in inputs]
  371. bboxes = None
  372. inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
  373. return inputs, labels, bboxes
  374. @staticmethod
  375. def _make_label_batch(labels):
  376. """
  377. Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max
  378. length of all the rows in labels.
  379. Args:
  380. labels (List[List]): the union labels of a data batch.
  381. Returns:
  382. 2D Tensor.
  383. """
  384. max_len = max([len(label) for label in labels])
  385. batch_labels = np.zeros((len(labels), max_len))
  386. for idx, _ in enumerate(batch_labels):
  387. length = len(labels[idx])
  388. batch_labels[idx, :length] = np.array(labels[idx])
  389. return ms.Tensor(batch_labels, ms.int32)
  390. def _run_inference(self, dataset, summary, threshold=0.5):
  391. """
  392. Run inference for the dataset and write the inference related data into summary.
  393. Args:
  394. dataset (`ds`): the parsed dataset
  395. summary (`SummaryRecord`): the summary object to store the data
  396. threshold (float): the threshold for prediction.
  397. Returns:
  398. imageid_labels (dict): a dict that maps image_id and the union of its ground truth and predicted labels.
  399. """
  400. spacer = '{:120}\r'
  401. imageid_labels = {}
  402. ds.config.set_seed(_SEED)
  403. self._count = 0
  404. for j, next_element in enumerate(dataset):
  405. now = time()
  406. inputs, labels, _ = self._unpack_next_element(next_element)
  407. prob = self._model(inputs).asnumpy()
  408. if self._uncertainty is not None:
  409. prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
  410. prob_sd = np.sqrt(prob_var)
  411. else:
  412. prob_var = prob_sd = None
  413. for idx, inp in enumerate(inputs):
  414. gt_labels = labels[idx]
  415. gt_probs = [float(prob[idx][i]) for i in gt_labels]
  416. data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
  417. original_image = _np_to_image(_normalize(data_np), mode='RGB')
  418. original_image_path = self._save_original_image(self._count, original_image)
  419. predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
  420. predicted_probs = [float(prob[idx][i]) for i in predicted_labels]
  421. has_uncertainty = False
  422. gt_prob_sds = gt_prob_itl95_lows = gt_prob_itl95_his = None
  423. predicted_prob_sds = predicted_prob_itl95_lows = predicted_prob_itl95_his = None
  424. if prob_var is not None:
  425. gt_prob_sds = [float(prob_sd[idx][i]) for i in gt_labels]
  426. predicted_prob_sds = [float(prob_sd[idx][i]) for i in predicted_labels]
  427. try:
  428. gt_prob_itl95_lows, gt_prob_itl95_his = \
  429. _calc_prob_interval(0.95, gt_probs, [float(prob_var[idx][i]) for i in gt_labels])
  430. predicted_prob_itl95_lows, predicted_prob_itl95_his = \
  431. _calc_prob_interval(0.95, predicted_probs, [float(prob_var[idx][i])
  432. for i in predicted_labels])
  433. has_uncertainty = True
  434. except ValueError:
  435. log.error(traceback.format_exc())
  436. log.error("Error on calculating uncertainty")
  437. union_labs = list(set(gt_labels + predicted_labels))
  438. imageid_labels[str(self._count)] = union_labs
  439. explain = Explain()
  440. explain.sample_id = self._count
  441. explain.image_path = original_image_path
  442. summary.add_value("explainer", "sample", explain)
  443. explain = Explain()
  444. explain.sample_id = self._count
  445. explain.ground_truth_label.extend(gt_labels)
  446. explain.inference.ground_truth_prob.extend(gt_probs)
  447. explain.inference.predicted_label.extend(predicted_labels)
  448. explain.inference.predicted_prob.extend(predicted_probs)
  449. if has_uncertainty:
  450. explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
  451. explain.inference.ground_truth_prob_itl95_low.extend(gt_prob_itl95_lows)
  452. explain.inference.ground_truth_prob_itl95_hi.extend(gt_prob_itl95_his)
  453. explain.inference.predicted_prob_sd.extend(predicted_prob_sds)
  454. explain.inference.predicted_prob_itl95_low.extend(predicted_prob_itl95_lows)
  455. explain.inference.predicted_prob_itl95_hi.extend(predicted_prob_itl95_his)
  456. summary.add_value("explainer", "inference", explain)
  457. summary.record(1)
  458. self._count += 1
  459. print(spacer.format("Finish running and writing {}-th batch inference data."
  460. " Time elapsed: {:.3f} s".format(j, time() - now)),
  461. end='')
  462. return imageid_labels
  463. def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
  464. """
  465. Run the explanation for each step and write explanation results into summary.
  466. Args:
  467. next_element (Tuple): data of one step
  468. explainer (_Attribution): an Attribution object to generate saliency maps.
  469. imageid_labels (dict): a dict that maps the image_id and its union labels.
  470. summary (SummaryRecord): the summary object to store the data
  471. Returns:
  472. List of dict that maps label to its corresponding saliency map.
  473. """
  474. inputs, labels, _ = self._unpack_next_element(next_element)
  475. count = self._count
  476. unions = []
  477. for _ in range(len(labels)):
  478. unions_labels = imageid_labels[str(count)]
  479. unions.append(unions_labels)
  480. count += 1
  481. batch_unions = self._make_label_batch(unions)
  482. saliency_dict_lst = []
  483. if isinstance(explainer, RISE):
  484. batch_saliency_full = explainer(inputs, batch_unions)
  485. else:
  486. batch_saliency_full = []
  487. for i in range(len(batch_unions[0])):
  488. batch_saliency = explainer(inputs, batch_unions[:, i])
  489. batch_saliency_full.append(batch_saliency)
  490. concat = ms.ops.operations.Concat(1)
  491. batch_saliency_full = concat(tuple(batch_saliency_full))
  492. for idx, union in enumerate(unions):
  493. saliency_dict = {}
  494. explain = Explain()
  495. explain.sample_id = self._count
  496. for k, lab in enumerate(union):
  497. saliency = batch_saliency_full[idx:idx + 1, k:k + 1]
  498. saliency_dict[lab] = saliency
  499. saliency_np = _normalize(saliency.asnumpy().squeeze())
  500. saliency_image = _np_to_image(saliency_np, mode='L')
  501. heatmap_path = self._save_heatmap(explainer.__class__.__name__, lab, self._count, saliency_image)
  502. explanation = explain.explanation.add()
  503. explanation.explain_method = explainer.__class__.__name__
  504. explanation.heatmap_path = heatmap_path
  505. explanation.label = lab
  506. summary.add_value("explainer", "explanation", explain)
  507. summary.record(1)
  508. self._count += 1
  509. saliency_dict_lst.append(saliency_dict)
  510. return saliency_dict_lst
  511. def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
  512. """
  513. Run the explanation and evaluation for each step and write explanation results into summary.
  514. Args:
  515. next_element (Tuple): Data of one step
  516. explainer (`_Attribution`): An Attribution object to generate saliency maps.
  517. """
  518. inputs, labels, _ = self._unpack_next_element(next_element)
  519. for idx, inp in enumerate(inputs):
  520. inp = _EXPAND_DIMS(inp, 0)
  521. if isinstance(benchmarker, LabelAgnosticMetric):
  522. res = benchmarker.evaluate(explainer, inp)
  523. res[np.isnan(res)] = 0.0
  524. benchmarker.aggregate(res)
  525. else:
  526. saliency_dict = saliency_dict_lst[idx]
  527. for label, saliency in saliency_dict.items():
  528. if isinstance(benchmarker, Localization):
  529. _, _, bboxes = self._unpack_next_element(next_element, True)
  530. if label in labels[idx]:
  531. res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
  532. saliency=saliency)
  533. res[np.isnan(res)] = 0.0
  534. benchmarker.aggregate(res, label)
  535. elif isinstance(benchmarker, LabelSensitiveMetric):
  536. res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
  537. res[np.isnan(res)] = 0.0
  538. benchmarker.aggregate(res, label)
  539. else:
  540. raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
  541. 'receive {}'.format(type(benchmarker)))
  542. def _save_original_image(self, sample_id: int, image):
  543. """Save an image to summary directory."""
  544. id_dirname = _get_id_dirname(sample_id)
  545. relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
  546. _ORIGINAL_IMAGE_DIRNAME,
  547. id_dirname)
  548. os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
  549. relative_path = os.path.join(relative_dir, f"{sample_id}.jpg")
  550. save_path = os.path.join(self._summary_dir, relative_path)
  551. with open(save_path, "wb") as file:
  552. image.save(file)
  553. return relative_path
  554. def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image):
  555. """Save heatmap image to summary directory."""
  556. id_dirname = _get_id_dirname(sample_id)
  557. relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
  558. _HEATMAP_DIRNAME,
  559. explain_method,
  560. id_dirname)
  561. os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
  562. relative_path = os.path.join(relative_dir, f"{sample_id}_{class_id}.jpg")
  563. save_path = os.path.join(self._summary_dir, relative_path)
  564. with open(save_path, "wb") as file:
  565. image.save(file, optimize=True)
  566. return relative_path