You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_utils.py 9.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
"""
Utility helpers for MindExplain.

The public helpers, listed in ``__all__``, cover:

* normalizing explainer arguments (``unify_inputs``, ``unify_targets``,
  ``format_tensor_to_ndarray``);
* locating layers in a model and capturing their outputs
  (``retrieve_layer_by_name``, ``retrieve_layer``, ``ForwardProbe``);
* saliency post-processing (``abs_max``, ``generate_one_hot``,
  ``rank_pixels``, ``resize``);
* curve and correlation computations (``calc_auc``, ``calc_correlation``).
"""
__all__ = [
    'ForwardProbe',
    'abs_max',
    'calc_auc',
    'calc_correlation',
    'format_tensor_to_ndarray',
    'generate_one_hot',
    'rank_pixels',
    'resize',
    'retrieve_layer_by_name',
    'retrieve_layer',
    'unify_inputs',
    'unify_targets'
]
  30. from typing import Tuple, Union
  31. import numpy as np
  32. from PIL import Image
  33. import mindspore as ms
  34. import mindspore.nn as nn
  35. import mindspore.ops.operations as op
  36. _Array = np.ndarray
  37. _Module = nn.Cell
  38. _Tensor = ms.Tensor
  39. def abs_max(gradients):
  40. """
  41. Transform gradients to saliency through abs then take max along channels.
  42. Args:
  43. gradients (_Tensor): Gradients which will be transformed to saliency map.
  44. Returns:
  45. _Tensor, saliency map integrated from gradients.
  46. """
  47. gradients = op.Abs()(gradients)
  48. saliency = op.ReduceMax(keep_dims=True)(gradients, axis=1)
  49. return saliency
  50. def generate_one_hot(indices, depth):
  51. r"""
  52. Simple wrap of OneHot operation, the on_value an off_value are fixed to 1.0
  53. and 0.0.
  54. """
  55. on_value = ms.Tensor(1.0, ms.float32)
  56. off_value = ms.Tensor(0.0, ms.float32)
  57. weights = op.OneHot()(indices, depth, on_value, off_value)
  58. return weights
  59. def unify_inputs(inputs) -> tuple:
  60. """Unify inputs of explainer."""
  61. if isinstance(inputs, tuple):
  62. return inputs
  63. if isinstance(inputs, ms.Tensor):
  64. inputs = (inputs,)
  65. elif isinstance(inputs, np.ndarray):
  66. inputs = (ms.Tensor(inputs),)
  67. else:
  68. raise TypeError(
  69. 'inputs must be one of [tuple, ms.Tensor or np.ndarray], '
  70. 'but get {}'.format(type(inputs)))
  71. return inputs
  72. def unify_targets(targets) -> ms.Tensor:
  73. """Unify targets labels of explainer."""
  74. if isinstance(targets, ms.Tensor):
  75. return targets
  76. if isinstance(targets, list):
  77. targets = ms.Tensor(targets, dtype=ms.int32)
  78. if isinstance(targets, int):
  79. targets = ms.Tensor([targets], dtype=ms.int32)
  80. else:
  81. raise TypeError(
  82. 'targets must be one of [int, list or ms.Tensor], '
  83. 'but get {}'.format(type(targets)))
  84. return targets
  85. def retrieve_layer_by_name(model: _Module, layer_name: str):
  86. """
  87. Retrieve the layer in the model by the given layer_name.
  88. Args:
  89. model (Cell): Model which contains the target layer.
  90. layer_name (str): Name of target layer.
  91. Returns:
  92. Cell, the target layer.
  93. Raises:
  94. ValueError: If module with given layer_name is not found in the model.
  95. """
  96. if not isinstance(layer_name, str):
  97. raise TypeError('layer_name should be type of str, but receive {}.'
  98. .format(type(layer_name)))
  99. if not layer_name:
  100. return model
  101. target_layer = None
  102. for name, cell in model.cells_and_names():
  103. if name == layer_name:
  104. target_layer = cell
  105. return target_layer
  106. if target_layer is None:
  107. raise ValueError(
  108. 'Cannot match {}, please provide target layer'
  109. 'in the given model.'.format(layer_name))
  110. return None
  111. def retrieve_layer(model: _Module, target_layer: Union[str, _Module] = ''):
  112. """
  113. Retrieve the layer in the model.
  114. 'target' can be either a layer name or a Cell object. Given the layer name,
  115. the method will search thourgh the model and return the matched layer. If a
  116. Cell object is provided, it will check whether the given layer exists
  117. in the model. If target layer is not found in the model, ValueError will
  118. be raised.
  119. Args:
  120. model (Cell): Model which contains the target layer.
  121. target_layer (str, Cell): Name of target layer or the target layer instance.
  122. Returns:
  123. Cell, the target layer.
  124. Raises:
  125. ValueError: If module with given layer_name is not found in the model.
  126. """
  127. if isinstance(target_layer, str):
  128. target_layer = retrieve_layer_by_name(model, target_layer)
  129. return target_layer
  130. if isinstance(target_layer, _Module):
  131. for _, cell in model.cells_and_names():
  132. if target_layer is cell:
  133. return target_layer
  134. raise ValueError(
  135. 'Model not contain cell {}, fail to probe.'.format(target_layer)
  136. )
  137. raise TypeError('layer_name must have type of str or ms.nn.Cell,'
  138. 'but receive {}'.format(type(target_layer)))
  139. class ForwardProbe:
  140. """
  141. Probe to capture output of specific layer in a given model.
  142. Args:
  143. target_layer (str, Cell): Name of target layer or the target layer instance.
  144. """
  145. def __init__(self, target_layer: _Module):
  146. self._target_layer = target_layer
  147. self._original_construct = self._target_layer.construct
  148. self._intermediate_tensor = None
  149. @property
  150. def value(self):
  151. return self._intermediate_tensor
  152. def __enter__(self):
  153. self._target_layer.construct = self._new_construct
  154. return self
  155. def __exit__(self, *_):
  156. self._target_layer.construct = self._original_construct
  157. self._intermediate_tensor = None
  158. return False
  159. def _new_construct(self, *inputs):
  160. outputs = self._original_construct(*inputs)
  161. self._intermediate_tensor = outputs
  162. return outputs
  163. def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray:
  164. """Unify Tensor and numpy.array to numpy.array."""
  165. if isinstance(x, ms.Tensor):
  166. x = x.asnumpy()
  167. if not isinstance(x, np.ndarray):
  168. raise TypeError('input should be one of [ms.Tensor or np.ndarray],'
  169. ' but receive {}'.format(type(x)))
  170. return x
  171. def calc_correlation(x: Union[ms.Tensor, np.ndarray],
  172. y: Union[ms.Tensor, np.ndarray]) -> float:
  173. """Calculate Pearson correlation coefficient between two vectors."""
  174. x = format_tensor_to_ndarray(x)
  175. y = format_tensor_to_ndarray(y)
  176. if len(x.shape) > 1 or len(y.shape) > 1:
  177. raise ValueError('"calc_correlation" only support 1-dim vectors currently, but get shape {} and {}.'
  178. .format(len(x.shape), len(y.shape)))
  179. if np.all(x == 0) or np.all(y == 0):
  180. return np.float(0)
  181. faithfulness = np.corrcoef(x, y)[0, 1]
  182. return faithfulness
  183. def calc_auc(x: _Array) -> _Array:
  184. """Calculate the Area under Curve."""
  185. # take mean for multiple patches if the model is fully convolutional model
  186. if len(x.shape) == 4:
  187. x = np.mean(np.mean(x, axis=2), axis=3)
  188. auc = (x.sum() - x[0] - x[-1]) / len(x)
  189. return auc
  190. def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
  191. """
  192. Generate rank order for every pixel in an 2D array.
  193. The rank order start from 0 to (num_pixel-1). If descending is True, the
  194. rank order will generate in a descending order, otherwise in ascending
  195. order.
  196. """
  197. if len(inputs.shape) < 2 or len(inputs.shape) > 3:
  198. raise ValueError('Only support 2D or 3D inputs currently.')
  199. batch_size = inputs.shape[0]
  200. flatten_saliency = inputs.reshape(batch_size, -1)
  201. factor = -1 if descending else 1
  202. sorted_arg = np.argsort(factor * flatten_saliency, axis=1)
  203. flatten_rank = np.zeros_like(sorted_arg)
  204. arange = np.arange(flatten_saliency.shape[1])
  205. for i in range(batch_size):
  206. flatten_rank[i][sorted_arg[i]] = arange
  207. rank_map = flatten_rank.reshape(inputs.shape)
  208. return rank_map
  209. def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor:
  210. """
  211. Resize the intermediate layer _attribution to the same size as inputs.
  212. Args:
  213. inputs (Tensor): The input tensor to be resized.
  214. size (tuple[int]): The targeted size resize to.
  215. mode (str): The resize mode. Options: 'nearest_neighbor', 'bilinear'.
  216. Returns:
  217. Tensor, the resized tensor.
  218. Raises:
  219. ValueError: the resize mode is not in ['nearest_neighbor', 'bilinear'].
  220. """
  221. h, w = size
  222. if mode == 'nearest_neighbor':
  223. resize_nn = op.ResizeNearestNeighbor((h, w))
  224. outputs = resize_nn(inputs)
  225. elif mode == 'bilinear':
  226. inputs_np = inputs.asnumpy()
  227. inputs_np = np.transpose(inputs_np, [0, 2, 3, 1])
  228. array_lst = []
  229. for inp in inputs_np:
  230. array = (np.repeat(inp, 3, axis=2) * 255).astype(np.uint8)
  231. image = Image.fromarray(array)
  232. image = image.resize(size, resample=Image.BILINEAR)
  233. array = np.asarray(image).astype(np.float32) / 255
  234. array_lst.append(array[:, :, 0:1])
  235. resized_np = np.transpose(array_lst, [0, 3, 1, 2])
  236. outputs = ms.Tensor(resized_np, inputs.dtype)
  237. else:
  238. raise ValueError('Unsupported resize mode {}.'.format(mode))
  239. return outputs