You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

_utils.py 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Utils for MindExplain"""
# Public API of this module: the names exported by a star-import.
# Every name listed here is defined further down in this file.
__all__ = [
    'ForwardProbe',
    'calc_auc',
    'calc_correlation',
    'format_tensor_to_ndarray',
    'generate_one_hot',
    'rank_pixels',
    'resize',
    'retrieve_layer_by_name',
    'retrieve_layer',
    'unify_inputs',
    'unify_targets'
]
  29. from typing import Tuple, Union
  30. import numpy as np
  31. from PIL import Image
  32. import mindspore as ms
  33. import mindspore.nn as nn
  34. import mindspore.ops.operations as op
# Short private aliases used in the type annotations of this module.
_Array = np.ndarray
_Module = nn.Cell
_Tensor = ms.Tensor
  38. def generate_one_hot(indices, depth):
  39. r"""
  40. Simple wrap of OneHot operation, the on_value an off_value are fixed to 1.0
  41. and 0.0.
  42. """
  43. on_value = ms.Tensor(1.0, ms.float32)
  44. off_value = ms.Tensor(0.0, ms.float32)
  45. weights = op.OneHot()(indices, depth, on_value, off_value)
  46. return weights
  47. def unify_inputs(inputs) -> tuple:
  48. """Unify inputs of explainer."""
  49. if isinstance(inputs, tuple):
  50. return inputs
  51. if isinstance(inputs, ms.Tensor):
  52. inputs = (inputs,)
  53. elif isinstance(inputs, np.ndarray):
  54. inputs = (ms.Tensor(inputs),)
  55. else:
  56. raise TypeError(
  57. 'inputs must be one of [tuple, ms.Tensor or np.ndarray], '
  58. 'but get {}'.format(type(inputs)))
  59. return inputs
  60. def unify_targets(targets) -> ms.Tensor:
  61. """Unify targets labels of explainer."""
  62. if isinstance(targets, ms.Tensor):
  63. return targets
  64. if isinstance(targets, list):
  65. targets = ms.Tensor(targets, dtype=ms.int32)
  66. if isinstance(targets, int):
  67. targets = ms.Tensor([targets], dtype=ms.int32)
  68. else:
  69. raise TypeError(
  70. 'targets must be one of [int, list or ms.Tensor], '
  71. 'but get {}'.format(type(targets)))
  72. return targets
  73. def retrieve_layer_by_name(model: _Module, layer_name: str):
  74. """
  75. Retrieve the layer in the model by the given layer_name.
  76. Args:
  77. model (_Module): model which contains the target layer
  78. layer_name (str): name of target layer
  79. Return:
  80. - target_layer (_Module)
  81. Raise:
  82. ValueError: is module with given layer_name is not found in the model,
  83. raise ValueError.
  84. """
  85. if not isinstance(layer_name, str):
  86. raise TypeError('layer_name should be type of str, but receive {}.'
  87. .format(type(layer_name)))
  88. if not layer_name:
  89. return model
  90. target_layer = None
  91. for name, cell in model.cells_and_names():
  92. if name == layer_name:
  93. target_layer = cell
  94. return target_layer
  95. if target_layer is None:
  96. raise ValueError(
  97. 'Cannot match {}, please provide target layer'
  98. 'in the given model.'.format(layer_name))
  99. return None
  100. def retrieve_layer(model: _Module, target_layer: Union[str, _Module] = ''):
  101. """
  102. Retrieve the layer in the model.
  103. 'target' can be either a layer name or a Cell object. Given the layer name,
  104. the method will search thourgh the model and return the matched layer. If a
  105. Cell object is provided, it will check whether the given layer exists
  106. in the model. If target layer is not found in the model, ValueError will
  107. be raised.
  108. Args:
  109. model (_Module): the model to retrieve the target layer
  110. target_layer (Union[str, _Module]): target layer to retrieve. Can be
  111. either string (layer name) or the Cell object. If '' is provided,
  112. the input model will be returned.
  113. Return:
  114. target layer (_Module)
  115. """
  116. if isinstance(target_layer, str):
  117. target_layer = retrieve_layer_by_name(model, target_layer)
  118. return target_layer
  119. if isinstance(target_layer, _Module):
  120. for _, cell in model.cells_and_names():
  121. if target_layer is cell:
  122. return target_layer
  123. raise ValueError(
  124. 'Model not contain cell {}, fail to probe.'.format(target_layer)
  125. )
  126. raise TypeError('layer_name must have type of str or ms.nn.Cell,'
  127. 'but receive {}'.format(type(target_layer)))
  128. class ForwardProbe:
  129. """
  130. Probe to capture output of specific layer in a given model.
  131. Args:
  132. target_layer (_Module): name of target layer or just provide the
  133. target layer.
  134. """
  135. def __init__(self, target_layer: _Module):
  136. self._target_layer = target_layer
  137. self._original_construct = self._target_layer.construct
  138. self._intermediate_tensor = None
  139. @property
  140. def value(self):
  141. return self._intermediate_tensor
  142. def __enter__(self):
  143. self._target_layer.construct = self._new_construct
  144. return self
  145. def __exit__(self, *_):
  146. self._target_layer.construct = self._original_construct
  147. self._intermediate_tensor = None
  148. return False
  149. def _new_construct(self, *inputs):
  150. outputs = self._original_construct(*inputs)
  151. self._intermediate_tensor = outputs
  152. return outputs
  153. def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray:
  154. """Unify `mindspore.Tensor` and `np.ndarray` to `np.ndarray`. """
  155. if isinstance(x, ms.Tensor):
  156. x = x.asnumpy()
  157. if not isinstance(x, np.ndarray):
  158. raise TypeError('input should be one of [ms.Tensor or np.ndarray],'
  159. ' but receive {}'.format(type(x)))
  160. return x
  161. def calc_correlation(x: Union[ms.Tensor, np.ndarray],
  162. y: Union[ms.Tensor, np.ndarray]) -> float:
  163. """Calculate Pearson correlation coefficient between two arrays. """
  164. x = format_tensor_to_ndarray(x)
  165. y = format_tensor_to_ndarray(y)
  166. faithfulness = -np.corrcoef(x, y)[0, 1]
  167. return faithfulness
  168. def calc_auc(x: _Array) -> float:
  169. """Calculate the Aera under Curve."""
  170. # take mean for multiple patches if the model is fully convolutional model
  171. if len(x.shape) == 4:
  172. x = np.mean(np.mean(x, axis=2), axis=3)
  173. auc = (x.sum() - x[0] - x[-1]) / len(x)
  174. return float(auc)
  175. def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
  176. """
  177. Generate rank order fo every pixel in an 2D array.
  178. The rank order start from 0 to (num_pixel-1). If descending is True, the
  179. rank order will generate in a descending order, otherwise in ascending
  180. order.
  181. Example:
  182. x = np.array([[4., 3., 1.], [5., 9., 1.]])
  183. rank_pixels(x, descending=True)
  184. >> np.array([[2, 3, 4], [1, 0, 5]])
  185. rank_pixels(x, descending=False)
  186. >> np.array([[3, 2, 0], [4, 5, 1]])
  187. """
  188. if len(inputs.shape) != 2:
  189. raise ValueError('Only support 2D array currently')
  190. flatten_saliency = inputs.reshape(-1)
  191. factor = -1 if descending else 1
  192. sorted_arg = np.argsort(factor * flatten_saliency, axis=0)
  193. flatten_rank = np.zeros_like(sorted_arg)
  194. flatten_rank[sorted_arg] = np.arange(0, sorted_arg.shape[0])
  195. rank_map = flatten_rank.reshape(inputs.shape)
  196. return rank_map
  197. def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor:
  198. """
  199. Resize the intermediate layer _attribution to the same size as inputs.
  200. Args:
  201. inputs (ms.Tensor): the input tensor to be resized
  202. size (tupleint]): the targeted size resize to
  203. mode (str): the resize mode. Options: 'nearest_neighbor', 'bilinear'
  204. Returns:
  205. outputs (ms.Tensor): the resized tensor.
  206. Raises:
  207. ValueError: the resize mode is not in ['nearest_neighbor',
  208. 'bilinear'].
  209. """
  210. h, w = size
  211. if mode == 'nearest_neighbor':
  212. resize_nn = op.ResizeNearestNeighbor((h, w))
  213. outputs = resize_nn(inputs)
  214. elif mode == 'bilinear':
  215. inputs_np = inputs.asnumpy()
  216. inputs_np = np.transpose(inputs_np, [0, 2, 3, 1])
  217. array_lst = []
  218. for inp in inputs_np:
  219. array = (np.repeat(inp, 3, axis=2) * 255).astype(np.uint8)
  220. image = Image.fromarray(array)
  221. image = image.resize(size, resample=Image.BILINEAR)
  222. array = np.asarray(image).astype(np.float32) / 255
  223. array_lst.append(array[:, :, 0:1])
  224. resized_np = np.transpose(array_lst, [0, 3, 1, 2])
  225. outputs = ms.Tensor(resized_np, inputs.dtype)
  226. else:
  227. raise ValueError('Unsupported resize mode {}'.format(mode))
  228. return outputs