You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_runner.py 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Runner."""
  16. from time import time
  17. from typing import Tuple, List, Optional
  18. import numpy as np
  19. from mindspore.train.summary_pb2 import Explain
  20. import mindspore as ms
  21. import mindspore.dataset as ds
  22. from mindspore import log
  23. from mindspore.ops.operations import ExpandDims
  24. from mindspore.train.summary._summary_adapter import _convert_image_format, _make_image
  25. from mindspore.train.summary.summary_record import SummaryRecord
  26. from .benchmark import Localization
  27. from .benchmark._attribution.metric import AttributionMetric
  28. from .explanation._attribution._attribution import Attribution
# Operator instance reused by the runner to unsqueeze tensors at dim 0.
_EXPAND_DIMS = ExpandDims()
# Endpoint colours (RGBA, scaled to [0, 1]) of the saliency colour map used by
# `_make_rgba`: _CMAP_0 is blended in at low saliency, _CMAP_1 at high saliency.
_CMAP_0 = np.reshape(np.array([55, 25, 86, 255]), [1, 1, 4]) / 255
_CMAP_1 = np.reshape(np.array([255, 255, 0, 255]), [1, 1, 4]) / 255
  32. def _normalize(img_np):
  33. """Normalize the image in the numpy array to be in [0, 255]. """
  34. max_ = img_np.max()
  35. min_ = img_np.min()
  36. normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
  37. return (normed * 255).astype(np.uint8)
  38. def _make_rgba(saliency):
  39. """Make rgba image for saliency map."""
  40. saliency = saliency.asnumpy().squeeze()
  41. saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min()).clip(1e-10)
  42. rgba = np.empty((saliency.shape[0], saliency.shape[1], 4))
  43. rgba[:, :, :] = np.expand_dims(saliency, 2)
  44. rgba = rgba * _CMAP_1 + (1 - rgba) * _CMAP_0
  45. rgba[:, :, -1] = saliency * 1
  46. return rgba
  47. class ExplainRunner:
  48. """
  49. High-level API for users to generate results with the explanation methods and the evaluation methods.
  50. After generating results with the explanation methods and the evaluation methods, the results will be written into
  51. a specified file with 'mindspore.summary.SummaryRecord'. The stored content can be viewed using MindInsight.
  52. Args:
  53. summary_dir (str): The directory path to save the summary files which store the generated results.
  54. Default: "./"
  55. Examples:
  56. >>> # init a runner with a specified directory
  57. >>> summary_dir = "summary_dir"
  58. >>> runner = ExplainRunner(summary_dir)
  59. """
  60. def __init__(self, summary_dir: Optional[str] = "./"):
  61. self._summary_dir = summary_dir
  62. self._count = 0
  63. self._classes = None
  64. self._model = None
  65. def run(self,
  66. dataset: Tuple,
  67. explainers: List,
  68. benchmarkers: Optional[List] = None):
  69. """
  70. Genereate results and write results into the summary files in `self.summary_dir`.
  71. Args:
  72. dataset (tuple): A tuple that contains `mindspore.dataset` object for iteration and its labels.
  73. - dataset[0], a `mindspore.dataset` object to provide data to explain.
  74. - dataset[1], a list of string that specifies the label names of the dataset.
  75. explainers (list): A list of explanation objects to generate _attribution results.
  76. benchmarkers (list): A list of benchmark objects to generate evaluation results. Default: None
  77. Examples:
  78. >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
  79. >>> # obtain dataset object
  80. >>> dataset = get_dataset()
  81. >>> classes = ["cat", "dog", ...]
  82. >>> # load checkpoint to a network, e.g. resnet50
  83. >>> param_dict = load_checkpoint("checkpoint.ckpt")
  84. >>> net = resnet50(len(classes))
  85. >>> load_parama_into_net(net, param_dict)
  86. >>> # bind net with its output activation
  87. >>> model = nn.SequentialCell([net, nn.Sigmoid()])
  88. >>> gbp = GuidedBackprop(model)
  89. >>> gradient = Gradient(model)
  90. >>> runner = ExplainRunner("./")
  91. >>> explainers = [gbp, gradient]
  92. >>> runner.run((dataset, classes), explainers)
  93. """
  94. if not isinstance(dataset, tuple):
  95. raise TypeError("Argument `dataset` must be a tuple.")
  96. if len(dataset) != 2:
  97. raise ValueError("Argument `dataset` should be a tuple with length = 2.")
  98. dataset, classes = dataset
  99. self._verify_data_form(dataset, benchmarkers)
  100. self._classes = classes
  101. if explainers is None or not explainers:
  102. raise ValueError("Argument `explainers` can neither be None nor empty.")
  103. for exp in explainers:
  104. if not isinstance(exp, Attribution) or not isinstance(explainers, list):
  105. raise TypeError("Argument explainers should be a list of objects of classes in "
  106. "`mindspore.explainer.explanation._attribution`.")
  107. if benchmarkers is not None:
  108. for bench in benchmarkers:
  109. if not isinstance(bench, AttributionMetric) or not isinstance(explainers, list):
  110. raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation"
  111. "`mindspore.explainer.benchmark._attribution`.")
  112. self._model = explainers[0].model
  113. with SummaryRecord(self._summary_dir) as summary:
  114. print("Start running and writing......")
  115. begin = time()
  116. print("Start writing metadata.")
  117. explain = Explain()
  118. explain.metadata.label.extend(classes)
  119. exp_names = [exp.__class__.__name__ for exp in explainers]
  120. explain.metadata.explain_method.extend(exp_names)
  121. if benchmarkers is not None:
  122. bench_names = [bench.__class__.__name__ for bench in benchmarkers]
  123. explain.metadata.benchmark_method.extend(bench_names)
  124. summary.add_value("explainer", "metadata", explain)
  125. summary.record(1)
  126. print("Finish writing metadata.")
  127. now = time()
  128. print("Start running and writing inference data......")
  129. imageid_labels = self._run_inference(dataset, summary)
  130. print("Finish running and writing inference data. Time elapsed: {}s".format(time() - now))
  131. if benchmarkers is None:
  132. for exp in explainers:
  133. start = time()
  134. print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
  135. self._count = 0
  136. ds.config.set_seed(58)
  137. for idx, next_element in enumerate(dataset):
  138. now = time()
  139. self._run_exp_step(next_element, exp, imageid_labels, summary)
  140. print("Finish writing {}-th explanation data. Time elapsed: {}".format(
  141. idx, time() - now))
  142. print("Finish running and writing explanation data for {}. Time elapsed: {}".format(
  143. exp.__class__.__name__, time() - start))
  144. else:
  145. for exp in explainers:
  146. explain = Explain()
  147. for bench in benchmarkers:
  148. bench.reset()
  149. print(f"Start running and writing explanation and benchmark data for {exp.__class__.__name__}.")
  150. self._count = 0
  151. start = time()
  152. ds.config.set_seed(58)
  153. for idx, next_element in enumerate(dataset):
  154. now = time()
  155. saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary)
  156. print("Finish writing {}-th batch explanation data. Time elapsed: {}s".format(
  157. idx, time() - now))
  158. for bench in benchmarkers:
  159. now = time()
  160. self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
  161. print("Finish running {}-th batch benchmark data for {}. Time elapsed: {}s".format(
  162. idx, bench.__class__.__name__, time() - now))
  163. for bench in benchmarkers:
  164. benchmark = explain.benchmark.add()
  165. benchmark.explain_method = exp.__class__.__name__
  166. benchmark.benchmark_method = bench.__class__.__name__
  167. benchmark.total_score = bench.performance
  168. benchmark.label_score.extend(bench.class_performances)
  169. print("Finish running and writing explanation and benchmark data for {}. "
  170. "Time elapsed: {}s".format(exp.__class__.__name__, time() - start))
  171. summary.add_value('explainer', 'benchmark', explain)
  172. summary.record(1)
  173. print("Finish running and writing. Total time elapsed: {}s".format(time() - begin))
  174. @staticmethod
  175. def _verify_data_form(dataset, benchmarkers):
  176. """
  177. Verify the validity of dataset.
  178. Args:
  179. dataset (`ds`): the user parsed dataset.
  180. benchmarkers (list[`AttributionMetric`]): the user parsed benchmarkers.
  181. """
  182. next_element = dataset.create_tuple_iterator().get_next()
  183. if len(next_element) not in [1, 2, 3]:
  184. raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
  185. " as columns.")
  186. if len(next_element) == 3:
  187. inputs, labels, bboxes = next_element
  188. if bboxes.shape[-1] != 4:
  189. raise ValueError("The third element of dataset should be bounding boxes with shape of "
  190. "[batch_size, num_ground_truth, 4].")
  191. else:
  192. if True in [isinstance(bench, Localization) for bench in benchmarkers]:
  193. raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
  194. if len(next_element) == 2:
  195. inputs, labels = next_element
  196. if len(next_element) == 1:
  197. inputs = next_element[0]
  198. if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
  199. raise ValueError(
  200. "Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
  201. inputs.shape))
  202. if len(inputs.shape) == 3:
  203. log.warning(
  204. "Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
  205. " dimension as batch data.".format(inputs.shape))
  206. if len(next_element) > 1:
  207. if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
  208. raise ValueError(
  209. "Labels shape {} is unrecognizable: labels should not have more than two dimensions"
  210. " with length greater than 1.".format(labels.shape))
    def _transform_data(self, inputs, labels, bboxes, ifbbox):
        """
        Transform the data from one iteration of dataset to a unifying form for the follow-up operations.

        Args:
            inputs (Tensor): the image data
            labels (Tensor): the labels
            bboxes (Tensor): the bounding boxes data
            ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t
                label id will be returned. If False, the returned bboxes is the parsed bboxes.

        Returns:
            inputs (Tensor): the image data, unified to a 4D Tensor.
            labels (List[List[int]]): the ground truth labels.
            bboxes (Union[List[Dict], None, Tensor]): the bounding boxes
        """
        inputs = ms.Tensor(inputs, ms.float32)
        if len(inputs.shape) == 3:
            # A single CHW image: unsqueeze input (and labels/bboxes, when they are
            # Tensors) to batch form at dim 0.
            inputs = _EXPAND_DIMS(inputs, 0)
            if isinstance(labels, ms.Tensor):
                labels = ms.Tensor(labels, ms.int32)
                labels = _EXPAND_DIMS(labels, 0)
            if isinstance(bboxes, ms.Tensor):
                bboxes = ms.Tensor(bboxes, ms.int32)
                bboxes = _EXPAND_DIMS(bboxes, 0)

        input_len = len(inputs)
        if bboxes is not None and ifbbox:
            # Build, per sample, a {label_id: binary mask of shape (1, 1, H, W)} dict
            # from the raw 4-number boxes.
            bboxes = ms.Tensor(bboxes, ms.int32)
            masks_lst = []
            labels = labels.asnumpy().reshape([input_len, -1])
            bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
            for idx, label in enumerate(labels):
                height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
                masks = {}
                for j, label_item in enumerate(label):
                    target = int(label_item)
                    # Ignore padding/invalid label ids outside [0, num_classes).
                    if -1 < target < len(self._classes):
                        if target not in masks:
                            mask = np.zeros((1, 1, height, width))
                        else:
                            # Multiple boxes of the same label are OR-ed into one mask.
                            mask = masks[target]
                        x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
                        # NOTE(review): the first box coordinate indexes the height
                        # (second-to-last) axis here, i.e. boxes appear to be stored as
                        # (row, col, row_len, col_len) despite the x/y naming — TODO
                        # confirm against the dataset's bbox format.
                        mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
                        masks[target] = mask
                masks_lst.append(masks)
            bboxes = masks_lst

        labels = ms.Tensor(labels, ms.int32)
        if len(labels.shape) == 1:
            # One label per sample.
            labels_lst = [[int(i)] for i in labels.asnumpy()]
        else:
            # Multi-label rows: deduplicate and drop ids outside [0, num_classes).
            labels = labels.asnumpy().reshape([input_len, -1])
            labels_lst = []
            for item in labels:
                labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._classes))))
        labels = labels_lst
        return inputs, labels, bboxes
  265. def _unpack_next_element(self, next_element, ifbbox=False):
  266. """
  267. Unpack a single iteration of dataset.
  268. Args:
  269. next_element (Tuple): a single element iterated from dataset object.
  270. ifbbox (bool): whether to preprocess bboxes in self._transform_data.
  271. Returns:
  272. Tuple, a unified Tuple contains image_data, labels, and bounding boxes.
  273. """
  274. if len(next_element) == 3:
  275. inputs, labels, bboxes = next_element
  276. elif len(next_element) == 2:
  277. inputs, labels = next_element
  278. bboxes = None
  279. else:
  280. inputs = next_element[0]
  281. labels = [[] for x in inputs]
  282. bboxes = None
  283. inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
  284. return inputs, labels, bboxes
  285. @staticmethod
  286. def _make_label_batch(labels):
  287. """
  288. Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max
  289. length of all the rows in labels.
  290. Args:
  291. labels (List[List]): the union labels of a data batch.
  292. Returns:
  293. 2D Tensor.
  294. """
  295. max_len = max([len(l) for l in labels])
  296. batch_labels = np.zeros((len(labels), max_len))
  297. for idx, _ in enumerate(batch_labels):
  298. length = len(labels[idx])
  299. batch_labels[idx, :length] = np.array(labels[idx])
  300. return ms.Tensor(batch_labels, ms.int32)
  301. def _run_inference(self, dataset, summary, threshod=0.5):
  302. """
  303. Run inference for the dataset and write the inference related data into summary.
  304. Args:
  305. dataset (`ds`): the parsed dataset
  306. summary (`SummaryRecord`): the summary object to store the data
  307. threshold (float): the threshold for prediction.
  308. Returns:
  309. imageid_labels (dict): a dict that maps image_id and the union of its ground truth and predicted labels.
  310. """
  311. imageid_labels = {}
  312. ds.config.set_seed(58)
  313. self._count = 0
  314. for j, next_element in enumerate(dataset):
  315. now = time()
  316. inputs, labels, _ = self._unpack_next_element(next_element)
  317. prob = self._model(inputs).asnumpy()
  318. for idx, inp in enumerate(inputs):
  319. gt_labels = labels[idx]
  320. gt_probs = [float(prob[idx][i]) for i in gt_labels]
  321. data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
  322. _, _, _, image_string = _make_image(_normalize(data_np))
  323. predicted_labels = [int(i) for i in (prob[idx] > threshod).nonzero()[0]]
  324. predicted_probs = [float(prob[idx][i]) for i in predicted_labels]
  325. union_labs = list(set(gt_labels + predicted_labels))
  326. imageid_labels[str(self._count)] = union_labs
  327. explain = Explain()
  328. explain.image_id = str(self._count)
  329. explain.image_data = image_string
  330. summary.add_value("explainer", "image", explain)
  331. explain = Explain()
  332. explain.image_id = str(self._count)
  333. explain.ground_truth_label.extend(gt_labels)
  334. explain.inference.ground_truth_prob.extend(gt_probs)
  335. explain.inference.predicted_label.extend(predicted_labels)
  336. explain.inference.predicted_prob.extend(predicted_probs)
  337. summary.add_value("explainer", "inference", explain)
  338. summary.record(1)
  339. self._count += 1
  340. print("Finish running and writing {}-th batch inference data. Time elapsed: {}s".format(j, time() - now))
  341. return imageid_labels
    def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
        """
        Run the explanation for each step and write explanation results into summary.

        Args:
            next_element (Tuple): data of one step
            explainer (_Attribution): an Attribution object to generate saliency maps.
            imageid_labels (dict): a dict that maps the image_id and its union labels.
            summary (SummaryRecord): the summary object to store the data

        Returns:
            List of dict that maps label to its corresponding saliency map.
        """
        inputs, labels, _ = self._unpack_next_element(next_element)
        # Gather the union label list of every sample in this batch, keyed by image id.
        # `count` is a local cursor; `self._count` only advances in the write loop below.
        count = self._count
        unions = []
        for _ in range(len(labels)):
            unions_labels = imageid_labels[str(count)]
            unions.append(unions_labels)
            count += 1

        batch_unions = self._make_label_batch(unions)
        saliency_dict_lst = []

        # One explainer call per label column: each call explains the whole batch
        # w.r.t. the i-th union label of every sample.
        batch_saliency_full = []
        for i in range(len(batch_unions[0])):
            batch_saliency = explainer(inputs, batch_unions[:, i])
            batch_saliency_full.append(batch_saliency)

        for idx, union in enumerate(unions):
            saliency_dict = {}
            explain = Explain()
            explain.image_id = str(self._count)
            for k, lab in enumerate(union):
                # Slicing (not indexing) keeps the batch dimension on the saliency map.
                saliency = batch_saliency_full[k][idx:idx + 1]
                saliency_dict[lab] = saliency

                saliency_np = _make_rgba(saliency)
                _, _, _, saliency_string = _make_image(_normalize(saliency_np))

                explanation = explain.explanation.add()
                explanation.explain_method = explainer.__class__.__name__
                explanation.label = lab
                explanation.heatmap = saliency_string

            summary.add_value("explainer", "explanation", explain)
            summary.record(1)
            self._count += 1
            saliency_dict_lst.append(saliency_dict)
        return saliency_dict_lst
  384. def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
  385. """
  386. Run the explanation and evaluation for each step and write explanation results into summary.
  387. Args:
  388. next_element (Tuple): Data of one step
  389. explainer (`_Attribution`): An Attribution object to generate saliency maps.
  390. imageid_labels (dict): A dict that maps the image_id and its union labels.
  391. """
  392. inputs, labels, _ = self._unpack_next_element(next_element)
  393. for idx, inp in enumerate(inputs):
  394. inp = _EXPAND_DIMS(inp, 0)
  395. saliency_dict = saliency_dict_lst[idx]
  396. for label, saliency in saliency_dict.items():
  397. if isinstance(benchmarker, Localization):
  398. _, _, bboxes = self._unpack_next_element(next_element, True)
  399. if label in labels[idx]:
  400. res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
  401. saliency=saliency)
  402. benchmarker.aggregate(res, label)
  403. else:
  404. res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
  405. benchmarker.aggregate(res, label)