You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

localization.py 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Localization metrics."""
  16. import numpy as np
  17. from mindspore.train._utils import check_value_type
  18. from .metric import AttributionMetric
  19. from ..._operators import maximum, reshape, Tensor
  20. from ..._utils import format_tensor_to_ndarray
  21. def _get_max_position(saliency):
  22. """Get the position of the max pixel of the saliency map."""
  23. saliency = saliency.asnumpy()
  24. w = saliency.shape[3]
  25. saliency = np.reshape(saliency, (len(saliency), -1))
  26. max_arg = np.argmax(saliency, axis=1)
  27. return max_arg // w, max_arg - (max_arg // w) * w
  28. def _mask_out_saliency(saliency, threshold):
  29. """Keep the saliency map with value greater than threshold."""
  30. max_value = maximum(saliency)
  31. mask_out = saliency > (reshape(max_value, (len(saliency), -1, 1, 1)) * threshold)
  32. return mask_out
  33. class Localization(AttributionMetric):
  34. """
  35. Provides evaluation on the localization capability of XAI methods.
  36. We support two metrics for the evaluation of localization capability: "PointingGame" and "IoSR".
  37. For metric "PointingGame", the localization capability is calculated as the ratio of data in which the max position
  38. of their saliency maps lies within the bounding boxes. Specifically, for a single datum, given the saliency map and
  39. its bounding box, if the max point of its saliency map lies within the bounding box, the evaluation result is 1
  40. otherwise 0.
  41. For metric "IoSR" (Intersection over Salient Region), the localization capability is calculated as the intersection
  42. of the bounding box and the salient region over the area of the salient region.
  43. Args:
  44. num_labels (int): number of classes in the dataset.
  45. metric (str): specific metric to calculate localization capability.
  46. Options: "PointingGame", "IoSR".
  47. Default: "PointingGame".
  48. Examples:
  49. >>> from mindspore.explainer.benchmark import Localization
  50. >>> num_labels = 100
  51. >>> localization = Localization(num_labels, "PointingGame")
  52. """
  53. def __init__(self,
  54. num_labels,
  55. metric="PointingGame"
  56. ):
  57. super(Localization, self).__init__(num_labels)
  58. self._verify_metrics(metric)
  59. self._metric = metric
  60. # Arg for specific metric, for "PointingGame" it should be an integer indicating the tolerance
  61. # of "PointingGame", while for "IoSR" it should be a float number
  62. # indicating the threshold to choose salient region. Default: 25.
  63. if self._metric == "PointingGame":
  64. self._metric_arg = 15
  65. else:
  66. self._metric_arg = 0.5
  67. @staticmethod
  68. def _verify_metrics(metric):
  69. """Verify the user defined metric."""
  70. supports = ["PointingGame", "IoSR"]
  71. if metric not in supports:
  72. raise ValueError("Metric should be one of {}".format(supports))
  73. def evaluate(self, explainer, inputs, targets, saliency=None, mask=None):
  74. """
  75. Evaluate localization on a single data sample.
  76. Args:
  77. explainer (Explanation): The explainer to be evaluated, see `mindspore/explainer/explanation`.
  78. inputs (Tensor): data sample. Currently only support single sample at each call.
  79. targets (int): target label to evaluate on.
  80. saliency (Tensor): A saliency tensor.
  81. mask (Union[Tensor, np.ndarray]): ground truth bounding box/masks for the inputs w.r.t targets.
  82. Returns:
  83. np.ndarray, result of localization evaluated on explainer
  84. Examples:
  85. >>> # init an explainer, the network should contain the output activation function.
  86. >>> gradient = Gradient(network)
  87. >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
  88. >>> masks = np.zeros(1, 1, 224, 224)
  89. >>> masks[:, :, 65: 100, 65: 100] = 1
  90. >>> targets = 5
  91. >>> # usage 1: input the explainer and the data to be explained,
  92. >>> # calculate the faithfulness with the specified metric
  93. >>> res = localization.evaluate(gradient, inputs, targets, mask=masks)
  94. >>> # usage 2: input the generated saliency map
  95. >>> saliency = gradient(inputs, targets)
  96. >>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks)
  97. """
  98. self._check_evaluate_param(explainer, inputs, targets, saliency)
  99. mask_np = format_tensor_to_ndarray(mask)[0]
  100. if saliency is None:
  101. saliency = explainer(inputs, targets)
  102. if self._metric == "PointingGame":
  103. point = _get_max_position(saliency)
  104. x, y = np.meshgrid(
  105. (np.arange(mask_np.shape[1]) - point[0]) ** 2,
  106. (np.arange(mask_np.shape[2]) - point[1]) ** 2)
  107. max_region = (x + y) < self._metric_arg ** 2
  108. # if max_region has overlap with mask_np return 1 otherwise 0.
  109. result = 1 if (mask_np.astype(bool) & max_region).any() else 0
  110. elif self._metric == "IoSR":
  111. mask_out = _mask_out_saliency(saliency, self._metric_arg)
  112. mask_out_np = format_tensor_to_ndarray(mask_out)
  113. overlap = np.sum(mask_np.astype(bool) & mask_out_np.astype(bool))
  114. saliency_area = np.sum(mask_out_np)
  115. result = overlap / saliency_area.clip(min=1e-10)
  116. return np.array([result], np.float)
  117. def _check_evaluate_param_with_mask(self, explainer, inputs, targets, saliency, mask):
  118. self._check_evaluate_param(explainer, inputs, targets, saliency)
  119. check_value_type('mask', mask, (Tensor, np.ndarray))
  120. if len(inputs.shape) != 4:
  121. raise ValueError('Argument mask must be 4D Tensor')