You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

_image_classification_runner.py 46 kB

5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Image Classification Runner."""
  16. import os
  17. import re
  18. import json
  19. from time import time
  20. import numpy as np
  21. from scipy.stats import beta
  22. from PIL import Image
  23. import mindspore as ms
  24. from mindspore import context
  25. from mindspore import log
  26. import mindspore.dataset as ds
  27. from mindspore.dataset import Dataset
  28. from mindspore.nn import Cell, SequentialCell
  29. from mindspore.ops.operations import ExpandDims
  30. from mindspore.train._utils import check_value_type
  31. from mindspore.train.summary._summary_adapter import _convert_image_format
  32. from mindspore.train.summary.summary_record import SummaryRecord
  33. from mindspore.train.summary_pb2 import Explain
  34. from mindspore.nn.probability.toolbox.uncertainty_evaluation import UncertaintyEvaluation
  35. from mindspore.explainer.benchmark import Localization
  36. from mindspore.explainer.benchmark._attribution.metric import AttributionMetric
  37. from mindspore.explainer.benchmark._attribution.metric import LabelSensitiveMetric
  38. from mindspore.explainer.benchmark._attribution.metric import LabelAgnosticMetric
  39. from mindspore.explainer.explanation import RISE
  40. from mindspore.explainer.explanation._attribution.attribution import Attribution
  41. from mindspore.explainer.explanation._counterfactual import hierarchical_occlusion as hoc
  42. _EXPAND_DIMS = ExpandDims()
  43. def _normalize(img_np):
  44. """Normalize the numpy image to the range of [0, 1]. """
  45. max_ = img_np.max()
  46. min_ = img_np.min()
  47. normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
  48. return normed
  def _np_to_image(img_np, mode):
      """
      Build a PIL image from a numpy array.

      The array is multiplied by 255 and cast to uint8 before being handed to PIL.

      Args:
          img_np (numpy.ndarray): The image data (presumably already scaled to [0, 1] - confirm with callers).
          mode (str): The PIL image mode passed through to `Image.fromarray`.

      Returns:
          PIL.Image.Image, the converted image.
      """
      pixels = np.uint8(img_np * 255)
      return Image.fromarray(pixels, mode=mode)
  52. class ImageClassificationRunner:
  53. """
  54. A high-level API for users to generate and store results of the explanation methods and the evaluation methods.
  55. Update in 2020.11: Adjust the storage structure and format of the data. Summary files generated by previous version
  56. will be deprecated and will not be supported in MindInsight of current version.
  57. Args:
  58. summary_dir (str): The directory path to save the summary files which store the generated results.
  59. data (tuple[Dataset, list[str]]): Tuple of dataset and the corresponding class label list. The dataset
  60. should provide [images], [images, labels] or [images, labels, bboxes] as columns. The label list must
  61. share the exact same length and order of the network outputs.
  62. network (Cell): The network(with logit outputs) to be explained.
  63. activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For
  64. single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks,
  65. `nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long as
  66. when combining this function with network, the final output is the probability of the input.
  67. Examples:
  68. >>> from mindspore.explainer import ImageClassificationRunner
  69. >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
  70. >>> from mindspore.explainer.benchmark import Faithfulness
  71. >>> from mindspore.nn import Softmax
  72. >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
  73. >>>
  74. >>> # The detail of AlexNet is shown in model_zoo.official.cv.alexnet.src.alexnet.py
  75. >>> net = AlexNet(10)
  76. >>> # Load the checkpoint
  77. >>> param_dict = load_checkpoint("/path/to/checkpoint")
  78. >>> load_param_into_net(net, param_dict)
  79. []
  80. >>>
  81. >>> # Prepare the dataset for explaining and evaluation.
  82. >>> # The detail of create_dataset_cifar10 method is shown in model_zoo.official.cv.alexnet.src.dataset.py
  83. >>>
  84. >>> dataset = create_dataset_cifar10("/path/to/cifar/dataset", 1)
  85. >>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
  86. >>>
  87. >>> activation_fn = Softmax()
  88. >>> gbp = GuidedBackprop(net)
  89. >>> gradient = Gradient(net)
  90. >>> explainers = [gbp, gradient]
  91. >>> faithfulness = Faithfulness(len(labels), activation_fn, "NaiveFaithfulness")
  92. >>> benchmarkers = [faithfulness]
  93. >>>
  94. >>> runner = ImageClassificationRunner("./summary_dir", (dataset, labels), net, activation_fn)
  95. >>> runner.register_saliency(explainers=explainers, benchmarkers=benchmarkers)
  96. >>> runner.run()
  97. """
  98. # datafile directory names
  99. _DATAFILE_DIRNAME_PREFIX = "_explain_"
  100. _ORIGINAL_IMAGE_DIRNAME = "origin_images"
  101. _HEATMAP_DIRNAME = "heatmap"
  102. # special filenames
  103. _MANIFEST_FILENAME = "manifest.json"
  104. # max. no. of sample per directory
  105. _SAMPLE_PER_DIR = 1000
  106. # seed for fixing the iterating order of the dataset
  107. _DATASET_SEED = 58
  108. # printing spacer
  109. _SPACER = "{:120}\r"
  110. # datafile directory's permission
  111. _DIR_MODE = 0o750
  112. # datafile's permission
  113. _FILE_MODE = 0o600
  def __init__(self,
               summary_dir,
               data,
               network,
               activation_fn):
      """
      Validate the arguments, store them and run the initial data/network/environment verification.

      `network` and `activation_fn` are chained into a SequentialCell that is switched to
      non-training mode; that combined cell is what the runner uses for inference.
      """
      # Argument validation: the order of the checks decides which error is reported first.
      check_value_type("data", data, tuple)
      if len(data) != 2:
          raise ValueError("Argument data is not a tuple with 2 elements")
      dataset, label_names = data
      check_value_type("data[0]", dataset, Dataset)
      check_value_type("data[1]", label_names, list)
      if any(not isinstance(name, str) for name in label_names):
          raise ValueError("Argument data[1] is not list of str.")
      check_value_type("summary_dir", summary_dir, str)
      check_value_type("network", network, Cell)
      check_value_type("activation_fn", activation_fn, Cell)

      # User supplied objects.
      self._summary_dir = summary_dir
      self._dataset = dataset
      self._labels = label_names
      self._network = network

      # Optional modules, filled in by the register_*() methods.
      self._explainers = None
      self._benchmarkers = None
      self._uncertainty = None
      self._hoc_searcher = None

      # Run-time bookkeeping.
      self._summary_timestamp = None
      self._sample_index = -1
      self._manifest = None

      # Network with the activation appended, used in inference (non-training) mode.
      self._full_network = SequentialCell([self._network, activation_fn])
      self._full_network.set_train(False)

      self._verify_data_n_settings(check_data_n_network=True,
                                   check_environment=True)
  def register_saliency(self,
                        explainers,
                        benchmarkers=None):
      """
      Register the saliency explainers and, optionally, the benchmarkers that score them.

      Note:
          This function can not be invoked more than once on each runner.

      Args:
          explainers (list[Attribution]): The explainers to be evaluated,
              see `mindspore.explainer.explanation`. All explainers' class must be distinct and their network
              must be the exact same instance of the runner's network.
          benchmarkers (list[AttributionMetric], optional): The benchmarkers for scoring the explainers,
              see `mindspore.explainer.benchmark`. All benchmarkers' class must be distinct.

      Raises:
          ValueError: Be raised for any data or settings' value problem.
          TypeError: Be raised for any data or settings' type problem.
          RuntimeError: Be raised if this function was invoked before.
      """
      check_value_type("explainers", explainers, list)
      if any(not isinstance(item, Attribution) for item in explainers):
          raise TypeError("Argument explainers is not list of mindspore.explainer.explanation .")
      if not explainers:
          raise ValueError("Argument explainers is empty.")

      if benchmarkers is not None:
          check_value_type("benchmarkers", benchmarkers, list)
          if any(not isinstance(item, AttributionMetric) for item in benchmarkers):
              raise TypeError("Argument benchmarkers is not list of mindspore.explainer.benchmark .")

      if self._explainers is not None:
          raise RuntimeError("Function register_saliency() was invoked already.")

      self._explainers = explainers
      self._benchmarkers = benchmarkers

      try:
          self._verify_data_n_settings(check_saliency=True, check_environment=True)
      except (ValueError, TypeError):
          # Roll the registration back so that the runner stays usable after a failed verification.
          self._explainers = None
          self._benchmarkers = None
          raise
  181. def register_hierarchical_occlusion(self):
  182. """
  183. Register hierarchical occlusion instances.
  184. Notes:
  185. Input images are required to be in 3 channels formats and the length of side short must be equals to or
  186. greater than 56 pixels. This function can not be invoked more than once on each runner.
  187. Raises:
  188. ValueError: Be raised for any data or settings' value problem.
  189. RuntimeError: Be raised if the function was called already.
  190. """
  191. if self._hoc_searcher is not None:
  192. raise RuntimeError("Function register_hierarchical_occlusion() was invoked already.")
  193. self._hoc_searcher = hoc.Searcher(self._full_network)
  194. try:
  195. self._verify_data_n_settings(check_hoc=True, check_environment=True)
  196. except ValueError:
  197. self._hoc_searcher = None
  198. raise
  199. def register_uncertainty(self):
  200. """
  201. Register uncertainty instance to compute the epistemic uncertainty base on the Bayes' theorem.
  202. Note:
  203. Please refer to the documentation of mindspore.nn.probability.toolbox.uncertainty_evaluation for the
  204. details. The actual output is standard deviation of the classification predictions and the corresponding
  205. 95% confidence intervals. Users have to invoke register_saliency() as well for the uncertainty results are
  206. going to be shown on the saliency map page in MindInsight. This function can not be invoked more then once
  207. on each runner.
  208. Raises:
  209. RuntimeError: Be raised if the function was called already.
  210. """
  211. if self._uncertainty is not None:
  212. raise RuntimeError("Function register_uncertainty() was invoked already.")
  213. self._uncertainty = UncertaintyEvaluation(model=self._full_network,
  214. train_dataset=None,
  215. task_type='classification',
  216. num_classes=len(self._labels))
  def run(self):
      """
      Run the explain job and save the result as a summary in summary_dir.

      The job verifies all settings, writes the metadata, runs inference over the dataset, then runs the
      saliency explanations when they are registered, and finally saves the manifest.

      Note:
          User should call register_saliency() once before running this function.

      Raises:
          ValueError: Be raised for any data or settings' value problem.
          TypeError: Be raised for any data or settings' type problem.
          RuntimeError: Be raised for any runtime problem.
      """
      self._verify_data_n_settings(check_all=True)
      # Every feature flag starts as False and is switched on once its first record is produced.
      self._manifest = dict.fromkeys(
          ("saliency_map", "benchmark", "uncertainty", "hierarchical_occlusion"), False)

      with SummaryRecord(self._summary_dir, raise_exception=True) as summary:
          print("Start running and writing......")
          begin = time()

          timestamp = self._extract_timestamp(summary.event_file_name)
          self._summary_timestamp = timestamp
          if timestamp is None:
              raise RuntimeError("Cannot extract timestamp from summary filename!"
                                 " It should contains a timestamp after 'summary.' .")

          self._save_metadata(summary)

          imageid_labels = self._run_inference(summary)
          # Captured right after inference because _run_saliency() restarts the sample counter.
          sample_count = self._sample_index

          if self._is_saliency_registered:
              self._run_saliency(summary, imageid_labels)
              saliency_written = self._manifest["saliency_map"]
              if not saliency_written:
                  raise RuntimeError(
                      f"No saliency map was generated in {sample_count} samples. "
                      f"Please make sure the dataset, labels, activation function and network are properly trained "
                      f"and configured.")

          if self._is_hoc_registered:
              hoc_written = self._manifest["hierarchical_occlusion"]
              if not hoc_written:
                  raise RuntimeError(
                      f"No Hierarchical Occlusion result was found in {sample_count} samples. "
                      f"Please make sure the dataset, labels, activation function and network are properly trained "
                      f"and configured.")

          self._save_manifest()

          elapsed = time() - begin
          print("Finish running and writing. Total time elapsed: {:.3f} s".format(elapsed))
  256. @property
  257. def _is_hoc_registered(self):
  258. """Check if HOC module is registered."""
  259. return self._hoc_searcher is not None
  260. @property
  261. def _is_saliency_registered(self):
  262. """Check if saliency module is registered."""
  263. return bool(self._explainers)
  264. @property
  265. def _is_uncertainty_registered(self):
  266. """Check if uncertainty module is registered."""
  267. return self._uncertainty is not None
  def _save_metadata(self, summary):
      """
      Write the metadata of the explain job into the summary.

      The metadata consists of the label names, plus the class names of the registered explainers and
      benchmarkers when saliency is registered.

      Args:
          summary (SummaryRecord): The summary object to store the data.
      """
      print("Start writing metadata......")

      explain = Explain()
      metadata = explain.metadata
      metadata.label.extend(self._labels)

      if self._is_saliency_registered:
          metadata.explain_method.extend(
              [explainer.__class__.__name__ for explainer in self._explainers])
          if self._benchmarkers is not None:
              metadata.benchmark_method.extend(
                  [benchmarker.__class__.__name__ for benchmarker in self._benchmarkers])

      summary.add_value("explainer", "metadata", explain)
      summary.record(1)

      print("Finish writing metadata.")
  def _run_inference(self, summary, threshold=0.5):
      """
      Run inference for the dataset and write the inference related data into summary.

      For every sample the original image is saved, the ground truth and predicted labels with their
      probabilities are recorded, and HOC is run when it is registered. When an uncertainty evaluator is
      registered, standard deviations and interval bounds are recorded as well.

      Args:
          summary (SummaryRecord): The summary object to store the data.
          threshold (float): A label counts as predicted when its probability is strictly above this value.

      Returns:
          dict, maps the sample id (as str) to the union of its ground truth and predicted labels.
      """
      wrote_uncertainty = False
      id_to_union = {}
      self._sample_index = 0
      # Fix the seed so that later passes iterate the dataset in the same order.
      ds.config.set_seed(self._DATASET_SEED)

      for batch_no, next_element in enumerate(self._dataset):
          now = time()
          inputs, labels, _ = self._unpack_next_element(next_element)
          prob = self._full_network(inputs).asnumpy()

          if self._uncertainty is None:
              prob_var = None
          else:
              prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
          with_uncertainty = prob_var is not None

          for idx, inp in enumerate(inputs):
              sample_prob = prob[idx]

              # Ground truth labels and their probabilities.
              gt_labels = labels[idx]
              gt_probs = [float(sample_prob[i]) for i in gt_labels]
              if with_uncertainty:
                  gt_prob_vars = [float(prob_var[idx][i]) for i in gt_labels]
                  gt_itl_lows, gt_itl_his, gt_prob_sds = \
                      self._calc_beta_intervals(gt_probs, gt_prob_vars)

              # Save the original image next to the summary.
              data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
              original_image = _np_to_image(_normalize(data_np), mode='RGB')
              original_image_path = self._save_original_image(self._sample_index, original_image)

              # Predicted labels are those whose probability exceeds the threshold.
              predicted_labels = [int(i) for i in (sample_prob > threshold).nonzero()[0]]
              predicted_probs = [float(sample_prob[i]) for i in predicted_labels]
              if with_uncertainty:
                  predicted_prob_vars = [float(prob_var[idx][i]) for i in predicted_labels]
                  predicted_itl_lows, predicted_itl_his, predicted_prob_sds = \
                      self._calc_beta_intervals(predicted_probs, predicted_prob_vars)

              id_to_union[str(self._sample_index)] = list(set(gt_labels + predicted_labels))

              # Record of the sample itself.
              sample_rec = Explain()
              sample_rec.sample_id = self._sample_index
              sample_rec.image_path = original_image_path
              summary.add_value("explainer", "sample", sample_rec)

              # Record of the inference result.
              inference_rec = Explain()
              inference_rec.sample_id = self._sample_index
              inference_rec.ground_truth_label.extend(gt_labels)
              inference = inference_rec.inference
              inference.ground_truth_prob.extend(gt_probs)
              inference.predicted_label.extend(predicted_labels)
              inference.predicted_prob.extend(predicted_probs)

              if with_uncertainty:
                  inference.ground_truth_prob_sd.extend(gt_prob_sds)
                  inference.ground_truth_prob_itl95_low.extend(gt_itl_lows)
                  inference.ground_truth_prob_itl95_hi.extend(gt_itl_his)
                  inference.predicted_prob_sd.extend(predicted_prob_sds)
                  inference.predicted_prob_itl95_low.extend(predicted_itl_lows)
                  inference.predicted_prob_itl95_hi.extend(predicted_itl_his)
                  wrote_uncertainty = True

              summary.add_value("explainer", "inference", inference_rec)
              summary.record(1)

              if self._is_hoc_registered:
                  self._run_hoc(summary, self._sample_index, inputs[idx], sample_prob)

              self._sample_index += 1

          self._spaced_print("Finish running and writing {}-th batch inference data."
                             " Time elapsed: {:.3f} s".format(batch_no, time() - now))

      if wrote_uncertainty:
          self._manifest["uncertainty"] = True

      return id_to_union
  def _run_saliency(self, summary, sample_id_labels):
      """
      Run every registered explainer over the dataset, scoring it with the benchmarkers when there are any.

      Args:
          summary (SummaryRecord): The summary object to store the data.
          sample_id_labels (dict): A dict that maps the sample id to its union labels.
      """
      if self._benchmarkers:
          # Explanation together with benchmark scoring.
          for exp in self._explainers:
              exp_name = exp.__class__.__name__
              explain = Explain()
              for bench in self._benchmarkers:
                  bench.reset()
              print(f"Start running and writing explanation and "
                    f"benchmark data for {exp.__class__.__name__}......")
              self._sample_index = 0
              start = time()
              # Same seed as the inference pass, so that samples come in the same order.
              ds.config.set_seed(self._DATASET_SEED)
              for idx, next_element in enumerate(self._dataset):
                  now = time()
                  self._spaced_print("Start running {}-th explanation data for {}......".format(
                      idx, exp_name))
                  saliency_dict_lst = self._run_exp_step(next_element, exp, sample_id_labels, summary)
                  self._spaced_print(
                      "Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
                          idx, exp_name, time() - now))
                  for bench in self._benchmarkers:
                      bench_name = bench.__class__.__name__
                      now = time()
                      self._spaced_print(
                          "Start running {}-th batch {} data for {}......".format(
                              idx, bench_name, exp_name))
                      self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
                      self._spaced_print(
                          "Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
                              idx, bench_name, exp_name, time() - now))

              # Collect the scores that the benchmarkers aggregated over the whole dataset.
              for bench in self._benchmarkers:
                  benchmark = explain.benchmark.add()
                  benchmark.explain_method = exp_name
                  benchmark.benchmark_method = bench.__class__.__name__
                  benchmark.total_score = bench.performance
                  if isinstance(bench, LabelSensitiveMetric):
                      benchmark.label_score.extend(bench.class_performances)

              self._spaced_print("Finish running and writing explanation and benchmark data for {}. "
                                 "Time elapsed: {:.3f} s".format(exp_name, time() - start))
              summary.add_value('explainer', 'benchmark', explain)
              summary.record(1)
      else:
          # Explanation only: no benchmarker was registered.
          for exp in self._explainers:
              exp_name = exp.__class__.__name__
              start = time()
              print("Start running and writing explanation data for {}......".format(exp_name))
              self._sample_index = 0
              ds.config.set_seed(self._DATASET_SEED)
              for idx, next_element in enumerate(self._dataset):
                  now = time()
                  self._spaced_print("Start running {}-th explanation data for {}......".format(
                      idx, exp_name))
                  self._run_exp_step(next_element, exp, sample_id_labels, summary)
                  self._spaced_print("Finish writing {}-th explanation data for {}. Time elapsed: "
                                     "{:.3f} s".format(idx, exp_name, time() - now))
              self._spaced_print(
                  "Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format(
                      exp_name, time() - start))
  def _run_hoc(self, summary, sample_id, sample_input, prob):
      """
      Run the HOC search for a sample image, and then save the result to summary.

      Args:
          summary (SummaryRecord): The summary object to store the data.
          sample_id (int): The sample ID.
          sample_input (Union[Tensor, np.ndarray]): Sample image, either 3-dimensional or 4-dimensional; a
              3-dimensional input gets a leading axis added.
          prob (Union[Tensor, np.ndarray]): The sample's classification prediction output. HOC runs for the
              labels whose prediction output is strictly larger than the HOC searcher's threshold.
      """
      if isinstance(sample_input, ms.Tensor):
          sample_input = sample_input.asnumpy()
      if len(sample_input.shape) == 3:
          sample_input = np.expand_dims(sample_input, axis=0)

      explain = Explain()
      explain.sample_id = sample_id
      str_mask = hoc.auto_str_mask(sample_input)
      compiled_mask = None  # compiled lazily, only once a label passes the threshold
      found_any = False

      for label_idx, label_prob in enumerate(prob):
          if not label_prob > self._hoc_searcher.threshold:
              continue
          if compiled_mask is None:
              compiled_mask = hoc.compile_mask(str_mask, sample_input)
          try:
              edit_tree, layer_outputs = self._hoc_searcher.search(sample_input, label_idx, compiled_mask)
          except hoc.NoValidResultError:
              log.warning(f"No Hierarchical Occlusion result was found in sample#{sample_id} "
                          f"label:{self._labels[label_idx]}, skipped.")
              continue

          found_any = True
          hoc_rec = explain.hoc.add()
          hoc_rec.label = label_idx
          hoc_rec.mask = str_mask
          # One record per layer of the edit tree, each carrying its output and the boxes of its steps.
          for layer in range(edit_tree.max_layer + 1):
              steps = edit_tree.get_layer_or_leaf_steps(layer)
              hoc_layer = hoc_rec.layer.add()
              hoc_layer.prob = layer_outputs[layer]
              for step in steps:
                  hoc_layer.box.extend(list(step.box))

      if found_any:
          summary.add_value("explainer", "hoc", explain)
          summary.record(1)
          self._manifest['hierarchical_occlusion'] = True
  def _run_exp_step(self, next_element, explainer, sample_id_labels, summary):
      """
      Run the explanation for one step (batch) and write the explanation results into summary.

      Args:
          next_element (Tuple): Data of one step.
          explainer (_Attribution): An Attribution object to generate saliency maps.
          sample_id_labels (dict): A dict that maps the sample id and its union labels.
          summary (SummaryRecord): The summary object to store the data.

      Returns:
          list, List of dict that maps label to its corresponding saliency map.
      """
      inputs, labels, _ = self._unpack_next_element(next_element)

      # Union labels of each sample in this batch, looked up by consecutive sample ids.
      first_index = self._sample_index
      unions = [sample_id_labels[str(first_index + offset)] for offset in range(len(labels))]
      batch_unions = self._make_label_batch(unions)

      if isinstance(explainer, RISE):
          # RISE is called once with the whole label batch.
          batch_saliency_full = explainer(inputs, batch_unions)
      else:
          # Other explainers are called once per label column; the results are concatenated on axis 1.
          per_column = [explainer(inputs, batch_unions[:, col]) for col in range(len(batch_unions[0]))]
          concat = ms.ops.operations.Concat(1)
          batch_saliency_full = concat(tuple(per_column))

      explainer_name = explainer.__class__.__name__
      wrote_saliency = False
      saliency_dict_lst = []
      for idx, union in enumerate(unions):
          saliency_dict = {}
          explain = Explain()
          explain.sample_id = self._sample_index
          for k, lab in enumerate(union):
              saliency = batch_saliency_full[idx:idx + 1, k:k + 1]
              saliency_dict[lab] = saliency

              heatmap = _np_to_image(_normalize(saliency.asnumpy().squeeze()), mode='L')
              heatmap_path = self._save_heatmap(explainer_name, lab, self._sample_index, heatmap)

              explanation = explain.explanation.add()
              explanation.explain_method = explainer_name
              explanation.heatmap_path = heatmap_path
              explanation.label = lab

              wrote_saliency = True

          summary.add_value("explainer", "explanation", explain)
          summary.record(1)

          self._sample_index += 1
          saliency_dict_lst.append(saliency_dict)

      if wrote_saliency:
          self._manifest['saliency_map'] = True

      return saliency_dict_lst
  def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
      """
      Evaluate the explainer with one benchmarker on one step (batch) and aggregate the results.

      Args:
          next_element (Tuple): Data of one step.
          explainer (_Attribution): The explainer being evaluated.
          benchmarker (AttributionMetric): The benchmarker that scores the explainer.
          saliency_dict_lst (list): One dict per sample in the batch, mapping a label to its saliency map.

      Raises:
          TypeError: Be raised if the benchmarker is neither label sensitive nor label agnostic.
      """
      inputs, labels, _ = self._unpack_next_element(next_element)

      # The benchmarker's kind does not change while iterating, so classify it once.
      is_agnostic = isinstance(benchmarker, LabelAgnosticMetric)
      is_localization = isinstance(benchmarker, Localization)
      is_sensitive = isinstance(benchmarker, LabelSensitiveMetric)

      # Bounding boxes belong to the whole batch; unpack them on first use only instead of per label.
      bboxes = None

      for idx, inp in enumerate(inputs):
          inp = _EXPAND_DIMS(inp, 0)
          if is_agnostic:
              res = benchmarker.evaluate(explainer, inp)
              benchmarker.aggregate(res)
              continue

          for label, saliency in saliency_dict_lst[idx].items():
              if is_localization:
                  if bboxes is None:
                      _, _, bboxes = self._unpack_next_element(next_element, True)
                  # Localization is only evaluated for labels that are in the sample's ground truth.
                  if label in labels[idx]:
                      res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
                                                 saliency=saliency)
                      benchmarker.aggregate(res, label)
              elif is_sensitive:
                  res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
                  benchmarker.aggregate(res, label)
              else:
                  # The two adjacent literals are concatenated, hence the trailing space after "but".
                  raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but '
                                  'receive {}'.format(type(benchmarker)))

      self._manifest['benchmark'] = True
  526. @staticmethod
  527. def _calc_beta_intervals(means, variances, prob=0.95):
  528. """Calculate confidence interval of beta distributions."""
  529. if not isinstance(means, np.ndarray):
  530. means = np.array(means)
  531. if not isinstance(variances, np.ndarray):
  532. variances = np.array(variances)
  533. with np.errstate(divide='ignore'):
  534. coef_a = ((means ** 2) * (1 - means) / variances) - means
  535. coef_b = (coef_a * (1 - means)) / means
  536. itl_lows, itl_his = beta.interval(prob, coef_a, coef_b)
  537. sds = np.sqrt(variances)
  538. for i in range(itl_lows.shape[0]):
  539. if not np.isfinite(sds[i]) or not np.isfinite(itl_lows[i]) or not np.isfinite(itl_his[i]):
  540. itl_lows[i] = means[i]
  541. itl_his[i] = means[i]
  542. sds[i] = 0
  543. return itl_lows, itl_his, sds
  544. def _verify_labels(self):
  545. """Verify labels."""
  546. label_set = set()
  547. if not self._labels:
  548. raise ValueError(f"The label list provided is empty.")
  549. for i, label in enumerate(self._labels):
  550. if label.strip() == "":
  551. raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is "
  552. f"no empty label.")
  553. if label in label_set:
  554. raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.")
  555. label_set.add(label)
  556. def _verify_ds_sample(self, sample):
  557. """Verify a dataset sample."""
  558. if len(sample) not in [1, 2, 3]:
  559. raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
  560. " as columns.")
  561. if len(sample) == 3:
  562. inputs, labels, bboxes = sample
  563. if bboxes.shape[-1] != 4:
  564. raise ValueError("The third element of dataset should be bounding boxes with shape of "
  565. "[batch_size, num_ground_truth, 4].")
  566. else:
  567. if self._benchmarkers is not None:
  568. if any([isinstance(bench, Localization) for bench in self._benchmarkers]):
  569. raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
  570. if len(sample) == 2:
  571. inputs, labels = sample
  572. if len(sample) == 1:
  573. inputs = sample[0]
  574. if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
  575. raise ValueError(
  576. "Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
  577. inputs.shape))
  578. if len(inputs.shape) == 3:
  579. log.warning(
  580. "Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
  581. " dimension as batch data.".format(inputs.shape))
  582. if len(sample) > 1:
  583. if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
  584. raise ValueError(
  585. "Labels shape {} is unrecognizable: outputs should not have more than two dimensions"
  586. " with length greater than 1.".format(labels.shape))
  587. if self._is_hoc_registered:
  588. if inputs.shape[-3] != 3:
  589. raise ValueError(
  590. "Hierarchical occlusion is registered, images must be in 3 channels format, but "
  591. "{} channel(s) is(are) encountered.".format(inputs.shape[-3]))
  592. short_side = min(inputs.shape[-2:])
  593. if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN:
  594. raise ValueError(
  595. "Hierarchical occlusion is registered, images' short side must be equals to or greater then "
  596. "{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side))
  597. def _verify_data(self):
  598. """Verify dataset and labels."""
  599. self._verify_labels()
  600. try:
  601. sample = next(self._dataset.create_tuple_iterator())
  602. except StopIteration:
  603. raise ValueError("The dataset provided is empty.")
  604. self._verify_ds_sample(sample)
  605. def _verify_network(self):
  606. """Verify the network."""
  607. next_element = next(self._dataset.create_tuple_iterator())
  608. inputs, _, _ = self._unpack_next_element(next_element)
  609. prop_test = self._full_network(inputs)
  610. check_value_type("output of network in explainer", prop_test, ms.Tensor)
  611. if prop_test.shape[1] != len(self._labels):
  612. raise ValueError("The dimension of network output does not match the no. of classes. Please "
  613. "check labels or the network in the explainer again.")
  614. def _verify_saliency(self):
  615. """Verify the saliency settings."""
  616. if self._explainers:
  617. explainer_classes = []
  618. for explainer in self._explainers:
  619. if explainer.__class__ in explainer_classes:
  620. raise ValueError(f"Repeated {explainer.__class__.__name__} explainer! "
  621. "Please make sure all explainers' class is distinct.")
  622. if explainer.network is not self._network:
  623. raise ValueError(f"The network of {explainer.__class__.__name__} explainer is different "
  624. "instance from network of runner. Please make sure they are the same "
  625. "instance.")
  626. explainer_classes.append(explainer.__class__)
  627. if self._benchmarkers:
  628. benchmarker_classes = []
  629. for benchmarker in self._benchmarkers:
  630. if benchmarker.__class__ in benchmarker_classes:
  631. raise ValueError(f"Repeated {benchmarker.__class__.__name__} benchmarker! "
  632. "Please make sure all benchmarkers' class is distinct.")
  633. if isinstance(benchmarker, LabelSensitiveMetric) and benchmarker.num_labels != len(self._labels):
  634. raise ValueError(f"The num_labels of {benchmarker.__class__.__name__} benchmarker is different "
  635. "from no. of labels of runner. Please make them are the same.")
  636. benchmarker_classes.append(benchmarker.__class__)
  637. def _verify_data_n_settings(self,
  638. check_all=False,
  639. check_registration=False,
  640. check_data_n_network=False,
  641. check_saliency=False,
  642. check_hoc=False,
  643. check_environment=False):
  644. """
  645. Verify the validity of dataset and other settings.
  646. Args:
  647. check_all (bool): Set it True for checking everything.
  648. check_registration (bool): Set it True for checking registrations, check if it is enough to invoke run().
  649. check_data_n_network (bool): Set it True for checking data and network.
  650. check_saliency (bool): Set it True for checking saliency related settings.
  651. check_hoc (bool): Set it True for checking HOC related settings.
  652. check_environment (bool): Set it True for checking environment conditions.
  653. Raises:
  654. ValueError: Be raised for any data or settings' value problem.
  655. TypeError: Be raised for any data or settings' type problem.
  656. RuntimeError: Be raised for any runtime problem.
  657. """
  658. if check_all:
  659. check_registration = True
  660. check_data_n_network = True
  661. check_saliency = True
  662. check_hoc = True
  663. check_environment = True
  664. if check_environment:
  665. device_target = context.get_context('device_target')
  666. if device_target not in ("Ascend", "GPU"):
  667. raise RuntimeError(f"Unsupported device_target: '{device_target}', "
  668. f"only 'Ascend' or 'GPU' is supported. "
  669. f"Please call context.set_context(device_target='Ascend') or "
  670. f"context.set_context(device_target='GPU').")
  671. if check_environment or check_saliency:
  672. if self._is_saliency_registered:
  673. mode = context.get_context('mode')
  674. if mode != context.PYNATIVE_MODE:
  675. raise RuntimeError("Context mode: GRAPH_MODE is not supported, "
  676. "please call context.set_context(mode=context.PYNATIVE_MODE).")
  677. if check_registration:
  678. if self._is_uncertainty_registered and not self._is_saliency_registered:
  679. raise ValueError("Function register_uncertainty() is called but register_saliency() is not.")
  680. if not self._is_saliency_registered and not self._is_hoc_registered:
  681. raise ValueError("No explanation module was registered, user should at least call register_saliency() "
  682. "or register_hierarchical_occlusion() once with proper arguments.")
  683. if check_data_n_network or check_saliency or check_hoc:
  684. self._verify_data()
  685. if check_data_n_network:
  686. self._verify_network()
  687. if check_saliency:
  688. self._verify_saliency()
  689. def _transform_data(self, inputs, labels, bboxes, ifbbox):
  690. """
  691. Transform the data from one iteration of dataset to a unifying form for the follow-up operations.
  692. Args:
  693. inputs (Tensor): the image data
  694. labels (Tensor): the labels
  695. bboxes (Tensor): the boudnding boxes data
  696. ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t
  697. label id will be returned. If False, the returned bboxes is the the parsed bboxes.
  698. Returns:
  699. inputs (Tensor): the image data, unified to a 4D Tensor.
  700. labels (list[list[int]]): the ground truth labels.
  701. bboxes (Union[list[dict], None, Tensor]): the bounding boxes
  702. """
  703. inputs = ms.Tensor(inputs, ms.float32)
  704. if len(inputs.shape) == 3:
  705. inputs = _EXPAND_DIMS(inputs, 0)
  706. if isinstance(labels, ms.Tensor):
  707. labels = ms.Tensor(labels, ms.int32)
  708. labels = _EXPAND_DIMS(labels, 0)
  709. if isinstance(bboxes, ms.Tensor):
  710. bboxes = ms.Tensor(bboxes, ms.int32)
  711. bboxes = _EXPAND_DIMS(bboxes, 0)
  712. input_len = len(inputs)
  713. if bboxes is not None and ifbbox:
  714. bboxes = ms.Tensor(bboxes, ms.int32)
  715. masks_lst = []
  716. labels = labels.asnumpy().reshape([input_len, -1])
  717. bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
  718. for idx, label in enumerate(labels):
  719. height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
  720. masks = {}
  721. for j, label_item in enumerate(label):
  722. target = int(label_item)
  723. if -1 < target < len(self._labels):
  724. if target not in masks:
  725. mask = np.zeros((1, 1, height, width))
  726. else:
  727. mask = masks[target]
  728. x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
  729. mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
  730. masks[target] = mask
  731. masks_lst.append(masks)
  732. bboxes = masks_lst
  733. labels = ms.Tensor(labels, ms.int32)
  734. if len(labels.shape) == 1:
  735. labels_lst = [[int(i)] for i in labels.asnumpy()]
  736. else:
  737. labels = labels.asnumpy().reshape([input_len, -1])
  738. labels_lst = []
  739. for item in labels:
  740. labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._labels))))
  741. labels = labels_lst
  742. return inputs, labels, bboxes
  743. def _unpack_next_element(self, next_element, ifbbox=False):
  744. """
  745. Unpack a single iteration of dataset.
  746. Args:
  747. next_element (Tuple): a single element iterated from dataset object.
  748. ifbbox (bool): whether to preprocess bboxes in self._transform_data.
  749. Returns:
  750. tuple, a unified Tuple contains image_data, labels, and bounding boxes.
  751. """
  752. if len(next_element) == 3:
  753. inputs, labels, bboxes = next_element
  754. elif len(next_element) == 2:
  755. inputs, labels = next_element
  756. bboxes = None
  757. else:
  758. inputs = next_element[0]
  759. labels = [[] for _ in inputs]
  760. bboxes = None
  761. inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
  762. return inputs, labels, bboxes
  763. @staticmethod
  764. def _make_label_batch(labels):
  765. """
  766. Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max
  767. length of all the rows in labels.
  768. Args:
  769. labels (List[List]): the union labels of a data batch.
  770. Returns:
  771. 2D Tensor.
  772. """
  773. max_len = max([len(label) for label in labels])
  774. batch_labels = np.zeros((len(labels), max_len))
  775. for idx, _ in enumerate(batch_labels):
  776. length = len(labels[idx])
  777. batch_labels[idx, :length] = np.array(labels[idx])
  778. return ms.Tensor(batch_labels, ms.int32)
  779. def _save_manifest(self):
  780. """Save manifest.json underneath datafile directory."""
  781. if self._manifest is None:
  782. raise RuntimeError("Manifest not yet be initialized.")
  783. path_tokens = [self._summary_dir,
  784. self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp)]
  785. abs_dir_path = self._create_subdir(*path_tokens)
  786. save_path = os.path.join(abs_dir_path, self._MANIFEST_FILENAME)
  787. with open(save_path, 'w') as file:
  788. json.dump(self._manifest, file, indent=4)
  789. os.chmod(save_path, self._FILE_MODE)
  790. def _save_original_image(self, sample_id, image):
  791. """Save an image to summary directory."""
  792. id_dirname = self._get_sample_dirname(sample_id)
  793. path_tokens = [self._summary_dir,
  794. self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
  795. self._ORIGINAL_IMAGE_DIRNAME,
  796. id_dirname]
  797. abs_dir_path = self._create_subdir(*path_tokens)
  798. filename = f"{sample_id}.jpg"
  799. save_path = os.path.join(abs_dir_path, filename)
  800. image.save(save_path)
  801. os.chmod(save_path, self._FILE_MODE)
  802. return os.path.join(*path_tokens[1:], filename)
  803. def _save_heatmap(self, explain_method, class_id, sample_id, image):
  804. """Save heatmap image to summary directory."""
  805. id_dirname = self._get_sample_dirname(sample_id)
  806. path_tokens = [self._summary_dir,
  807. self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
  808. self._HEATMAP_DIRNAME,
  809. explain_method,
  810. id_dirname]
  811. abs_dir_path = self._create_subdir(*path_tokens)
  812. filename = f"{sample_id}_{class_id}.jpg"
  813. save_path = os.path.join(abs_dir_path, filename)
  814. image.save(save_path, optimize=True)
  815. os.chmod(save_path, self._FILE_MODE)
  816. return os.path.join(*path_tokens[1:], filename)
  817. def _create_subdir(self, *args):
  818. """Recursively create subdirectories."""
  819. abs_path = None
  820. for token in args:
  821. if abs_path is None:
  822. abs_path = os.path.realpath(token)
  823. else:
  824. abs_path = os.path.join(abs_path, token)
  825. # os.makedirs() don't set intermediate dir permission properly, we mkdir() one by one
  826. try:
  827. os.mkdir(abs_path, mode=self._DIR_MODE)
  828. # In some platform, mode may be ignored in os.mkdir(), we have to chmod() again to make sure
  829. os.chmod(abs_path, mode=self._DIR_MODE)
  830. except FileExistsError:
  831. pass
  832. return abs_path
  833. @classmethod
  834. def _get_sample_dirname(cls, sample_id):
  835. """Get the name of parent directory of the image id."""
  836. return str(int(sample_id / cls._SAMPLE_PER_DIR) * cls._SAMPLE_PER_DIR)
  837. @staticmethod
  838. def _extract_timestamp(filename):
  839. """Extract timestamp from summary filename."""
  840. matched = re.search(r"summary\.(\d+)", filename)
  841. if matched:
  842. return int(matched.group(1))
  843. return None
  844. @classmethod
  845. def _spaced_print(cls, message):
  846. """Spaced message printing."""
  847. # workaround to print logs starting new line in case line width mismatch.
  848. print(cls._SPACER.format(message))