You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_utils.py 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Utils for MindExplain"""
  16. __all__ = [
  17. 'ForwardProbe',
  18. 'abs_max',
  19. 'calc_auc',
  20. 'calc_correlation',
  21. 'format_tensor_to_ndarray',
  22. 'generate_one_hot',
  23. 'rank_pixels',
  24. 'resize',
  25. 'retrieve_layer_by_name',
  26. 'retrieve_layer',
  27. 'unify_inputs',
  28. 'unify_targets'
  29. ]
  30. from typing import Tuple, Union
  31. import numpy as np
  32. from PIL import Image
  33. import mindspore as ms
  34. import mindspore.nn as nn
  35. import mindspore.ops.operations as op
  36. _Array = np.ndarray
  37. _Module = nn.Cell
  38. _Tensor = ms.Tensor
  39. def abs_max(gradients):
  40. """
  41. Transform gradients to saliency through abs then take max along channels.
  42. Args:
  43. gradients (_Tensor): Gradients which will be transformed to saliency map.
  44. Returns:
  45. _Tensor, saliency map integrated from gradients.
  46. """
  47. gradients = op.Abs()(gradients)
  48. saliency = op.ReduceMax(keep_dims=True)(gradients, axis=1)
  49. return saliency
  50. def generate_one_hot(indices, depth):
  51. r"""
  52. Simple wrap of OneHot operation, the on_value an off_value are fixed to 1.0
  53. and 0.0.
  54. """
  55. on_value = ms.Tensor(1.0, ms.float32)
  56. off_value = ms.Tensor(0.0, ms.float32)
  57. weights = op.OneHot()(indices, depth, on_value, off_value)
  58. return weights
  59. def unify_inputs(inputs) -> tuple:
  60. """Unify inputs of explainer."""
  61. if isinstance(inputs, tuple):
  62. return inputs
  63. if isinstance(inputs, ms.Tensor):
  64. inputs = (inputs,)
  65. elif isinstance(inputs, np.ndarray):
  66. inputs = (ms.Tensor(inputs),)
  67. else:
  68. raise TypeError(
  69. 'inputs must be one of [tuple, ms.Tensor or np.ndarray], '
  70. 'but get {}'.format(type(inputs)))
  71. return inputs
  72. def unify_targets(targets) -> ms.Tensor:
  73. """Unify targets labels of explainer."""
  74. if isinstance(targets, ms.Tensor):
  75. return targets
  76. if isinstance(targets, list):
  77. targets = ms.Tensor(targets, dtype=ms.int32)
  78. if isinstance(targets, int):
  79. targets = ms.Tensor([targets], dtype=ms.int32)
  80. else:
  81. raise TypeError(
  82. 'targets must be one of [int, list or ms.Tensor], '
  83. 'but get {}'.format(type(targets)))
  84. return targets
  85. def retrieve_layer_by_name(model: _Module, layer_name: str):
  86. """
  87. Retrieve the layer in the model by the given layer_name.
  88. Args:
  89. model (_Module): model which contains the target layer
  90. layer_name (str): name of target layer
  91. Return:
  92. - target_layer (_Module)
  93. Raise:
  94. ValueError: if module with given layer_name is not found in the model,
  95. raise ValueError.
  96. """
  97. if not isinstance(layer_name, str):
  98. raise TypeError('layer_name should be type of str, but receive {}.'
  99. .format(type(layer_name)))
  100. if not layer_name:
  101. return model
  102. target_layer = None
  103. for name, cell in model.cells_and_names():
  104. if name == layer_name:
  105. target_layer = cell
  106. return target_layer
  107. if target_layer is None:
  108. raise ValueError(
  109. 'Cannot match {}, please provide target layer'
  110. 'in the given model.'.format(layer_name))
  111. return None
  112. def retrieve_layer(model: _Module, target_layer: Union[str, _Module] = ''):
  113. """
  114. Retrieve the layer in the model.
  115. 'target' can be either a layer name or a Cell object. Given the layer name,
  116. the method will search thourgh the model and return the matched layer. If a
  117. Cell object is provided, it will check whether the given layer exists
  118. in the model. If target layer is not found in the model, ValueError will
  119. be raised.
  120. Args:
  121. model (_Module): the model to retrieve the target layer
  122. target_layer (Union[str, _Module]): target layer to retrieve. Can be
  123. either string (layer name) or the Cell object. If '' is provided,
  124. the input model will be returned.
  125. Return:
  126. target layer (_Module)
  127. """
  128. if isinstance(target_layer, str):
  129. target_layer = retrieve_layer_by_name(model, target_layer)
  130. return target_layer
  131. if isinstance(target_layer, _Module):
  132. for _, cell in model.cells_and_names():
  133. if target_layer is cell:
  134. return target_layer
  135. raise ValueError(
  136. 'Model not contain cell {}, fail to probe.'.format(target_layer)
  137. )
  138. raise TypeError('layer_name must have type of str or ms.nn.Cell,'
  139. 'but receive {}'.format(type(target_layer)))
  140. class ForwardProbe:
  141. """
  142. Probe to capture output of specific layer in a given model.
  143. Args:
  144. target_layer (_Module): name of target layer or just provide the
  145. target layer.
  146. """
  147. def __init__(self, target_layer: _Module):
  148. self._target_layer = target_layer
  149. self._original_construct = self._target_layer.construct
  150. self._intermediate_tensor = None
  151. @property
  152. def value(self):
  153. return self._intermediate_tensor
  154. def __enter__(self):
  155. self._target_layer.construct = self._new_construct
  156. return self
  157. def __exit__(self, *_):
  158. self._target_layer.construct = self._original_construct
  159. self._intermediate_tensor = None
  160. return False
  161. def _new_construct(self, *inputs):
  162. outputs = self._original_construct(*inputs)
  163. self._intermediate_tensor = outputs
  164. return outputs
  165. def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray:
  166. """Unify `mindspore.Tensor` and `np.ndarray` to `np.ndarray`. """
  167. if isinstance(x, ms.Tensor):
  168. x = x.asnumpy()
  169. if not isinstance(x, np.ndarray):
  170. raise TypeError('input should be one of [ms.Tensor or np.ndarray],'
  171. ' but receive {}'.format(type(x)))
  172. return x
  173. def calc_correlation(x: Union[ms.Tensor, np.ndarray],
  174. y: Union[ms.Tensor, np.ndarray]) -> float:
  175. """Calculate Pearson correlation coefficient between two vectors."""
  176. x = format_tensor_to_ndarray(x)
  177. y = format_tensor_to_ndarray(y)
  178. if len(x.shape) > 1 or len(y.shape) > 1:
  179. raise ValueError('"calc_correlation" only support 1-dim vectors currently, but get shape {} and {}.'
  180. .format(len(x.shape), len(y.shape)))
  181. if np.all(x == 0) or np.all(y == 0):
  182. return np.float(0)
  183. faithfulness = -np.corrcoef(x, y)[0, 1]
  184. return faithfulness
  185. def calc_auc(x: _Array) -> _Array:
  186. """Calculate the Aera under Curve."""
  187. # take mean for multiple patches if the model is fully convolutional model
  188. if len(x.shape) == 4:
  189. x = np.mean(np.mean(x, axis=2), axis=3)
  190. auc = (x.sum() - x[0] - x[-1]) / len(x)
  191. return auc
  192. def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
  193. """
  194. Generate rank order fo every pixel in an 2D array.
  195. The rank order start from 0 to (num_pixel-1). If descending is True, the
  196. rank order will generate in a descending order, otherwise in ascending
  197. order.
  198. Example:
  199. x = np.array([[4., 3., 1.], [5., 9., 1.]])
  200. rank_pixels(x, descending=True)
  201. >> np.array([[2, 3, 4], [1, 0, 5]])
  202. rank_pixels(x, descending=False)
  203. >> np.array([[3, 2, 0], [4, 5, 1]])
  204. """
  205. if len(inputs.shape) < 2 or len(inputs.shape) > 3:
  206. raise ValueError('Only support 2D or 3D inputs currently.')
  207. batch_size = inputs.shape[0]
  208. flatten_saliency = inputs.reshape(batch_size, -1)
  209. factor = -1 if descending else 1
  210. sorted_arg = np.argsort(factor * flatten_saliency, axis=1)
  211. flatten_rank = np.zeros_like(sorted_arg)
  212. arange = np.arange(flatten_saliency.shape[1])
  213. for i in range(batch_size):
  214. flatten_rank[i][sorted_arg[i]] = arange
  215. rank_map = flatten_rank.reshape(inputs.shape)
  216. return rank_map
  217. def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor:
  218. """
  219. Resize the intermediate layer _attribution to the same size as inputs.
  220. Args:
  221. inputs (ms.Tensor): the input tensor to be resized
  222. size (tupleint]): the targeted size resize to
  223. mode (str): the resize mode. Options: 'nearest_neighbor', 'bilinear'
  224. Returns:
  225. outputs (ms.Tensor): the resized tensor.
  226. Raises:
  227. ValueError: the resize mode is not in ['nearest_neighbor',
  228. 'bilinear'].
  229. """
  230. h, w = size
  231. if mode == 'nearest_neighbor':
  232. resize_nn = op.ResizeNearestNeighbor((h, w))
  233. outputs = resize_nn(inputs)
  234. elif mode == 'bilinear':
  235. inputs_np = inputs.asnumpy()
  236. inputs_np = np.transpose(inputs_np, [0, 2, 3, 1])
  237. array_lst = []
  238. for inp in inputs_np:
  239. array = (np.repeat(inp, 3, axis=2) * 255).astype(np.uint8)
  240. image = Image.fromarray(array)
  241. image = image.resize(size, resample=Image.BILINEAR)
  242. array = np.asarray(image).astype(np.float32) / 255
  243. array_lst.append(array[:, :, 0:1])
  244. resized_np = np.transpose(array_lst, [0, 3, 1, 2])
  245. outputs = ms.Tensor(resized_np, inputs.dtype)
  246. else:
  247. raise ValueError('Unsupported resize mode {}'.format(mode))
  248. return outputs