You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

localization.py 8.0 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Localization metrics."""
  16. import numpy as np
  17. from mindspore.train._utils import check_value_type
  18. from .metric import LabelSensitiveMetric
  19. from ..._operators import maximum, reshape, Tensor
  20. from ..._utils import format_tensor_to_ndarray
  21. def _get_max_position(saliency):
  22. """Get the position of the max pixel of the saliency map."""
  23. saliency = saliency.asnumpy()
  24. w = saliency.shape[3]
  25. saliency = np.reshape(saliency, (len(saliency), -1))
  26. max_arg = np.argmax(saliency, axis=1)
  27. return max_arg // w, max_arg - (max_arg // w) * w
  28. def _mask_out_saliency(saliency, threshold):
  29. """Keep the saliency map with value greater than threshold."""
  30. max_value = maximum(saliency)
  31. mask_out = saliency > (reshape(max_value, (len(saliency), -1, 1, 1)) * threshold)
  32. return mask_out
  33. class Localization(LabelSensitiveMetric):
  34. r"""
  35. Provides evaluation on the localization capability of XAI methods.
  36. Three specific metrics to obtain quantified results are supported: "PointingGame", and "IoSR"
  37. (Intersection over Salient Region).
  38. For metric "PointingGame", the localization capability is calculated as the ratio of data in which the max position
  39. of their saliency maps lies within the bounding boxes. Specifically, for a single datum, given the saliency map and
  40. its bounding box, if the max point of its saliency map lies within the bounding box, the evaluation result is 1
  41. otherwise 0.
  42. For metric "IoSR" (Intersection over Salient Region), the localization capability is calculated as the intersection
  43. of the bounding box and the salient region over the area of the salient region. The salient region is defined as
  44. the region whose value exceeds :math:`\theta * \max{saliency}`.
  45. Args:
  46. num_labels (int): Number of classes in the dataset.
  47. metric (str, optional): Specific metric to calculate localization capability.
  48. Options: "PointingGame", "IoSR". Default: "PointingGame".
  49. Raises:
  50. TypeError: Be raised for any argument type problem.
  51. Supported Platforms:
  52. ``Ascend`` ``GPU``
  53. """
  54. def __init__(self,
  55. num_labels,
  56. metric="PointingGame"
  57. ):
  58. super(Localization, self).__init__(num_labels)
  59. self._verify_metrics(metric)
  60. self._metric = metric
  61. # Arg for specific metric, for "PointingGame" it should be an integer indicating the tolerance
  62. # of "PointingGame", while for "IoSR" it should be a float number
  63. # indicating the threshold to choose salient region. Default: 25.
  64. if self._metric == "PointingGame":
  65. self._metric_arg = 15
  66. else:
  67. self._metric_arg = 0.5
  68. @staticmethod
  69. def _verify_metrics(metric):
  70. """Verify the user defined metric."""
  71. supports = ["PointingGame", "IoSR"]
  72. if metric not in supports:
  73. raise ValueError("Metric should be one of {}".format(supports))
  74. def evaluate(self, explainer, inputs, targets, saliency=None, mask=None):
  75. """
  76. Evaluate localization on a single data sample.
  77. Note:
  78. Currently only single sample (:math:`N=1`) at each call is supported.
  79. Args:
  80. explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`.
  81. inputs (Tensor): A data sample, a 4D tensor of shape :math:`(N, C, H, W)`.
  82. targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
  83. If `targets` is a 1D tensor, its length should be the same as `inputs`.
  84. saliency (Tensor, optional): The saliency map to be evaluated, a 4D tensor of shape :math:`(N, 1, H, W)`.
  85. If it is None, the parsed `explainer` will generate the saliency map with `inputs` and `targets` and
  86. continue the evaluation. Default: None.
  87. mask (Tensor, numpy.ndarray): Ground truth bounding box/masks for the inputs w.r.t targets, a 4D tensor
  88. or numpy.ndarray of shape :math:`(N, 1, H, W)`.
  89. Returns:
  90. numpy.ndarray, 1D array of shape :math:`(N,)`, result of localization evaluated on `explainer`.
  91. Raises:
  92. ValueError: Be raised for any argument value problem.
  93. Examples:
  94. >>> import numpy as np
  95. >>> import mindspore as ms
  96. >>> from mindspore.explainer.explanation import Gradient
  97. >>> from mindspore.explainer.benchmark import Localization
  98. >>> from mindspore import context
  99. >>>
  100. >>> context.set_context(mode=context.PYNATIVE_MODE)
  101. >>> num_labels = 10
  102. >>> localization = Localization(num_labels, "PointingGame")
  103. >>>
  104. >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
  105. >>> net = LeNet5(10, num_channel=3)
  106. >>> gradient = Gradient(net)
  107. >>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
  108. >>> masks = np.zeros([1, 1, 32, 32])
  109. >>> masks[:, :, 10: 20, 10: 20] = 1
  110. >>> targets = 5
  111. >>> # usage 1: input the explainer and the data to be explained,
  112. >>> # localization is a Localization instance
  113. >>> res = localization.evaluate(gradient, inputs, targets, mask=masks)
  114. >>> print(res.shape)
  115. (1,)
  116. >>> # usage 2: input the generated saliency map
  117. >>> saliency = gradient(inputs, targets)
  118. >>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks)
  119. >>> print(res.shape)
  120. (1,)
  121. """
  122. self._check_evaluate_param_with_mask(explainer, inputs, targets, saliency, mask)
  123. mask_np = format_tensor_to_ndarray(mask)[0]
  124. if saliency is None:
  125. saliency = explainer(inputs, targets)
  126. if self._metric == "PointingGame":
  127. point = _get_max_position(saliency)
  128. x, y = np.meshgrid(
  129. (np.arange(mask_np.shape[1]) - point[0]) ** 2,
  130. (np.arange(mask_np.shape[2]) - point[1]) ** 2)
  131. max_region = (x + y) < self._metric_arg ** 2
  132. # if max_region has overlap with mask_np return 1 otherwise 0.
  133. result = 1 if (mask_np.astype(bool) & max_region).any() else 0
  134. elif self._metric == "IoSR":
  135. mask_out_np = format_tensor_to_ndarray(_mask_out_saliency(saliency, self._metric_arg))
  136. overlap = np.sum(mask_np.astype(bool) & mask_out_np.astype(bool))
  137. saliency_area = np.sum(mask_out_np)
  138. result = overlap / saliency_area.clip(min=1e-10)
  139. return np.array([result], np.float)
  140. def _check_evaluate_param_with_mask(self, explainer, inputs, targets, saliency, mask):
  141. self._check_evaluate_param(explainer, inputs, targets, saliency)
  142. if len(inputs.shape) != 4:
  143. raise ValueError('Argument mask must be 4D Tensor')
  144. if mask is None:
  145. raise ValueError('To compute localization, mask must be provided.')
  146. check_value_type('mask', mask, (Tensor, np.ndarray))
  147. if len(mask.shape) != 4 or len(mask) != len(inputs):
  148. raise ValueError("The input mask must be 4-dimensional (1, 1, h, w) with same length of inputs.")