您最多选择25个标签 标签必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Base class for XAI metrics."""
  16. import numpy as np
  17. from mindspore.train._utils import check_value_type
  18. from ..._operators import Tensor
  19. from ..._utils import format_tensor_to_ndarray
  20. from ...explanation._attribution._attribution import Attribution
  21. def verify_argument(inputs, arg_name):
  22. """Verify the validity of the parsed arguments."""
  23. check_value_type(arg_name, inputs, Tensor)
  24. if len(inputs.shape) != 4:
  25. raise ValueError('Argument {} must be a 4D Tensor.'.format(arg_name))
  26. if len(inputs) > 1:
  27. raise ValueError('Support single data evaluation only, but got {}.'.format(len(inputs)))
  28. def verify_targets(targets, num_labels):
  29. """Verify the validity of the parsed targets."""
  30. check_value_type('targets', targets, (int, Tensor))
  31. if isinstance(targets, Tensor):
  32. if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != 1):
  33. raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
  34. 'it should have the length = 1 as we only support single evaluation now.')
  35. targets = int(targets.asnumpy()[0]) if len(targets.shape) == 1 else int(targets.asnumpy())
  36. if targets > num_labels - 1 or targets < 0:
  37. raise ValueError('Parsed targets exceed the label range.')
  38. class AttributionMetric:
  39. """Super class of XAI metric class used in classification scenarios."""
  40. def __init__(self, num_labels=None):
  41. self._num_labels = num_labels
  42. self._global_results = {i: [] for i in range(num_labels)}
  43. def evaluate(self, explainer, inputs, targets, saliency=None):
  44. """This function evaluates on a single sample and return the result."""
  45. raise NotImplementedError
  46. def aggregate(self, result, targets):
  47. """Aggregates single result to global_results."""
  48. if isinstance(result, float):
  49. if isinstance(targets, int):
  50. self._global_results[targets].append(result)
  51. else:
  52. target_np = format_tensor_to_ndarray(targets)
  53. if len(target_np) > 1:
  54. raise ValueError("One result can not be aggreated to multiple targets.")
  55. else:
  56. result_np = format_tensor_to_ndarray(result)
  57. if isinstance(targets, int):
  58. for res in result_np:
  59. self._global_results[targets].append(float(res))
  60. else:
  61. target_np = format_tensor_to_ndarray(targets)
  62. if len(target_np) != len(result_np):
  63. raise ValueError("Length of result does not match with length of targets.")
  64. for tar, res in zip(target_np, result_np):
  65. self._global_results[int(tar)].append(float(res))
  66. def reset(self):
  67. """Resets global_result."""
  68. self._global_results = {i: [] for i in range(self._num_labels)}
  69. @property
  70. def class_performances(self):
  71. """
  72. Get the class performances by global result.
  73. Returns:
  74. (:class:`np.ndarray`): :attr:`num_labels`-dimensional vector
  75. containing per-class performance.
  76. """
  77. count = np.array(
  78. [len(self._global_results[i]) for i in range(self._num_labels)])
  79. result_sum = np.array(
  80. [sum(self._global_results[i]) for i in range(self._num_labels)])
  81. return result_sum / count.clip(min=1)
  82. @property
  83. def performance(self):
  84. """
  85. Get the performance by global result.
  86. Returns:
  87. (:class:`float`): mean performance.
  88. """
  89. count = sum(
  90. [len(self._global_results[i]) for i in range(self._num_labels)])
  91. result_sum = sum(
  92. [sum(self._global_results[i]) for i in range(self._num_labels)])
  93. if count == 0:
  94. return 0
  95. return result_sum / count
  96. def get_results(self):
  97. """Global result of the metric can be return"""
  98. return self._global_results
  99. def _check_evaluate_param(self, explainer, inputs, targets, saliency):
  100. """Check the evaluate parameters."""
  101. check_value_type('explainer', explainer, Attribution)
  102. verify_argument(inputs, 'inputs')
  103. verify_targets(targets, self._num_labels)
  104. check_value_type('saliency', saliency, (Tensor, type(None)))