From: @yuhanshi Reviewed-by: @wuxuejian,@wenkai_dist Signed-off-by: @wenkai_disttags/v1.1.0
| @@ -418,29 +418,24 @@ class ImageClassificationRunner: | |||
| inputs, labels, _ = self._unpack_next_element(next_element) | |||
| for idx, inp in enumerate(inputs): | |||
| inp = _EXPAND_DIMS(inp, 0) | |||
| saliency_dict = saliency_dict_lst[idx] | |||
| for label, saliency in saliency_dict.items(): | |||
| if isinstance(benchmarker, Localization): | |||
| _, _, bboxes = self._unpack_next_element(next_element, True) | |||
| if label in labels[idx]: | |||
| res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label], | |||
| saliency=saliency) | |||
| if np.any(res == np.nan): | |||
| res = np.zeros_like(res) | |||
| if isinstance(benchmarker, LabelAgnosticMetric): | |||
| res = benchmarker.evaluate(explainer, inp) | |||
| benchmarker.aggregate(res) | |||
| else: | |||
| saliency_dict = saliency_dict_lst[idx] | |||
| for label, saliency in saliency_dict.items(): | |||
| if isinstance(benchmarker, Localization): | |||
| _, _, bboxes = self._unpack_next_element(next_element, True) | |||
| if label in labels[idx]: | |||
| res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label], | |||
| saliency=saliency) | |||
| benchmarker.aggregate(res, label) | |||
| elif isinstance(benchmarker, LabelSensitiveMetric): | |||
| res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) | |||
| benchmarker.aggregate(res, label) | |||
| elif isinstance(benchmarker, LabelSensitiveMetric): | |||
| res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) | |||
| if np.any(res == np.nan): | |||
| res = np.zeros_like(res) | |||
| benchmarker.aggregate(res, label) | |||
| elif isinstance(benchmarker, LabelAgnosticMetric): | |||
| res = benchmarker.evaluate(explainer, inp) | |||
| if np.any(res == np.nan): | |||
| res = np.zeros_like(res) | |||
| benchmarker.aggregate(res) | |||
| else: | |||
| raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but' | |||
| 'receive {}'.format(type(benchmarker))) | |||
| else: | |||
| raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but' | |||
| 'receive {}'.format(type(benchmarker))) | |||
| def _verify_data(self): | |||
| """Verify dataset and labels.""" | |||
| @@ -382,8 +382,6 @@ class Faithfulness(LabelSensitiveMetric): | |||
| perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument | |||
| perturb_method = "Constant" # perturbation method, all the perturbed pixels will be set to constant | |||
| num_perturb_pixel_per_step = None # number of pixels for each perturbation step | |||
| num_perturb_steps = 100 # separate the perturbation progress in to 100 steps. | |||
| base_value = 0.0 # the pixel value set for the perturbed pixels | |||
| check_value_type("activation_fn", activation_fn, nn.Cell) | |||
| @@ -395,8 +393,6 @@ class Faithfulness(LabelSensitiveMetric): | |||
| self._faithfulness_helper = method( | |||
| perturb_percent=perturb_percent, | |||
| perturb_method=perturb_method, | |||
| perturb_pixel_per_step=num_perturb_pixel_per_step, | |||
| num_perturbations=num_perturb_steps, | |||
| base_value=base_value | |||
| ) | |||
| @@ -15,6 +15,7 @@ | |||
| """Base class for XAI metrics.""" | |||
| import copy | |||
| import math | |||
| from typing import Callable | |||
| import numpy as np | |||
| @@ -88,11 +89,12 @@ class LabelAgnosticMetric(AttributionMetric): | |||
| Return: | |||
| float, averaged result. If no result is aggregate in the global_results, 0.0 will be returned. | |||
| """ | |||
| if not self._global_results: | |||
| return 0.0 | |||
| results_sum = sum(self._global_results) | |||
| count = len(self._global_results) | |||
| return results_sum / count | |||
| result_sum, count = 0, 0 | |||
| for res in self._global_results: | |||
| if math.isfinite(res): | |||
| result_sum += res | |||
| count += 1 | |||
| return 0. if count == 0 else result_sum / count | |||
| def aggregate(self, result): | |||
| """Aggregate single evaluation result to global results.""" | |||
| @@ -100,7 +102,7 @@ class LabelAgnosticMetric(AttributionMetric): | |||
| self._global_results.append(result) | |||
| elif isinstance(result, (ms.Tensor, np.ndarray)): | |||
| result = format_tensor_to_ndarray(result) | |||
| self._global_results.append(float(result)) | |||
| self._global_results.extend([float(res) for res in result.reshape(-1)]) | |||
| else: | |||
| raise TypeError('result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result)) | |||
| @@ -130,10 +132,12 @@ class LabelSensitiveMetric(AttributionMetric): | |||
| @property | |||
| def num_labels(self): | |||
| """Number of labels used in evaluation.""" | |||
| return self._num_labels | |||
| @staticmethod | |||
| def _verify_params(num_labels): | |||
| """Checks whether num_labels is valid.""" | |||
| check_value_type("num_labels", num_labels, int) | |||
| if num_labels < 1: | |||
| raise ValueError("Argument num_labels must be parsed with a integer > 0.") | |||
| @@ -147,17 +151,19 @@ class LabelSensitiveMetric(AttributionMetric): | |||
| target_np = format_tensor_to_ndarray(targets) | |||
| if len(target_np) > 1: | |||
| raise ValueError("One result can not be aggreated to multiple targets.") | |||
| else: | |||
| result_np = format_tensor_to_ndarray(result) | |||
| elif isinstance(result, (ms.Tensor, np.ndarray)): | |||
| result_np = format_tensor_to_ndarray(result).reshape(-1) | |||
| if isinstance(targets, int): | |||
| for res in result_np: | |||
| self._global_results[targets].append(float(res)) | |||
| else: | |||
| target_np = format_tensor_to_ndarray(targets) | |||
| target_np = format_tensor_to_ndarray(targets).reshape(-1) | |||
| if len(target_np) != len(result_np): | |||
| raise ValueError("Length of result does not match with length of targets.") | |||
| for tar, res in zip(target_np, result_np): | |||
| self._global_results[int(tar)].append(float(res)) | |||
| else: | |||
| raise TypeError('Result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result)) | |||
| def reset(self): | |||
| """Resets global_result.""" | |||
| @@ -168,16 +174,18 @@ class LabelSensitiveMetric(AttributionMetric): | |||
| """ | |||
| Get the class performances by global result. | |||
| Returns: | |||
| (:class:`np.ndarray`): :attr:`num_labels`-dimensional vector | |||
| containing per-class performance. | |||
| (:class:`list`): a list of performances where each value is the average score of specific class. | |||
| """ | |||
| count = np.array( | |||
| [len(self._global_results[i]) for i in range(self._num_labels)]) | |||
| result_sum = np.array( | |||
| [sum(self._global_results[i]) for i in range(self._num_labels)]) | |||
| return result_sum / count.clip(min=1) | |||
| results_on_labels = [] | |||
| for label_id in range(self._num_labels): | |||
| sum_of_label, count_of_label = 0, 0 | |||
| for res in self._global_results[label_id]: | |||
| if math.isfinite(res): | |||
| sum_of_label += res | |||
| count_of_label += 1 | |||
| results_on_labels.append(0. if count_of_label == 0 else sum_of_label / count_of_label) | |||
| return results_on_labels | |||
| @property | |||
| def performance(self): | |||
| @@ -187,13 +195,13 @@ class LabelSensitiveMetric(AttributionMetric): | |||
| Returns: | |||
| (:class:`float`): mean performance. | |||
| """ | |||
| count = sum( | |||
| [len(self._global_results[i]) for i in range(self._num_labels)]) | |||
| result_sum = sum( | |||
| [sum(self._global_results[i]) for i in range(self._num_labels)]) | |||
| if count == 0: | |||
| return 0 | |||
| return result_sum / count | |||
| result_sum, count = 0, 0 | |||
| for label_id in range(self._num_labels): | |||
| for res in self._global_results[label_id]: | |||
| if math.isfinite(res): | |||
| result_sum += res | |||
| count += 1 | |||
| return 0. if count == 0 else result_sum / count | |||
| def get_results(self): | |||
| """Global result of the metric can be return""" | |||
| @@ -122,8 +122,8 @@ class Robustness(LabelSensitiveMetric): | |||
| perturbations.append(perturbation_on_single_sample) | |||
| perturbations = np.vstack(perturbations) | |||
| perturbations_saliency = explainer(ms.Tensor(perturbations, ms.float32), targets).asnumpy() | |||
| sensitivity = np.sum((perturbations_saliency - saliency_np) ** 2, | |||
| axis=tuple(range(1, len(saliency_np.shape)))) | |||
| sensitivity = np.sqrt(np.sum((perturbations_saliency - saliency_np) ** 2, | |||
| axis=tuple(range(1, len(saliency_np.shape))))) | |||
| sensitivities.append(sensitivity) | |||
| sensitivities = np.stack(sensitivities, axis=-1) | |||
| max_sensitivity = np.max(sensitivities, axis=1) / norm | |||
| @@ -89,7 +89,7 @@ class Ablation: | |||
| class AblationWithSaliency(Ablation): | |||
| """ | |||
| Perturbation generator to generate perturbations for a given array. | |||
| Perturbation generator to generate perturbations w.r.t a given saliency map. | |||
| Args: | |||
| perturb_percent (float): percentage of pixels to perturb | |||
| @@ -143,28 +143,20 @@ class AblationWithSaliency(Ablation): | |||
| """ | |||
| batch_size = saliency.shape[0] | |||
| expected_num_dim = len(saliency.shape) + 1 | |||
| has_channel = num_channels is not None | |||
| num_channels = 1 if num_channels is None else num_channels | |||
| if has_channel: | |||
| saliency = saliency.mean(axis=1) | |||
| saliency_rank = rank_pixels(saliency, descending=True) | |||
| num_pixels = reduce(lambda x, y: x * y, saliency.shape[1:]) | |||
| if self._pixel_per_step: | |||
| pixel_per_step = self._pixel_per_step | |||
| num_perturbations = math.floor(num_pixels * self._perturb_percent / self._pixel_per_step) | |||
| elif self._num_perturbations: | |||
| pixel_per_step = math.floor(num_pixels * self._perturb_percent / self._num_perturbations) | |||
| num_perturbations = self._num_perturbations | |||
| else: | |||
| raise ValueError("Must provide either pixel_per_step or num_perturbations.") | |||
| pixel_per_step, num_perturbations = self._check_and_format_perturb_param(num_pixels) | |||
| masks = np.zeros((batch_size, num_perturbations, num_channels, saliency_rank.shape[1], saliency_rank.shape[2]), | |||
| dtype=np.bool) | |||
| # If the perturbation is added accumulately, the factor should be 0 to preserve the low bound of indexing. | |||
| factor = 0 if self._is_accumulate else 1 | |||
| for i in range(batch_size): | |||
| @@ -176,7 +168,23 @@ class AblationWithSaliency(Ablation): | |||
| up_bound += pixel_per_step | |||
| masks = masks if has_channel else np.squeeze(masks, axis=2) | |||
| return masks | |||
| def _check_and_format_perturb_param(self, num_pixels): | |||
| """ | |||
| Check whether the self._pixel_per_step and self._num_perturbation is valid. If the parameters are unreasonable, | |||
| this function will try to reassign the parameters and raise ValueError when reassignment is failed. | |||
| """ | |||
| if self._pixel_per_step: | |||
| pixel_per_step = self._pixel_per_step | |||
| num_perturbations = math.floor(num_pixels * self._perturb_percent / self._pixel_per_step) | |||
| elif self._num_perturbations: | |||
| pixel_per_step = math.floor(num_pixels * self._perturb_percent / self._num_perturbations) | |||
| num_perturbations = self._num_perturbations | |||
| else: | |||
| # If neither pixel_per_step or num_perturbations is provided, num_perturbations is determined by the square | |||
| # root of product from the spatial size of saliency map. | |||
| num_perturbations = math.floor(np.sqrt(num_pixels)) | |||
| pixel_per_step = math.floor(num_pixels * self._perturb_percent / num_perturbations) | |||
| if len(masks.shape) == expected_num_dim: | |||
| return masks | |||
| raise ValueError(f'Invalid masks shape {len(masks.shape)}, expect {expected_num_dim}-dim.') | |||
| return pixel_per_step, num_perturbations | |||
| @@ -14,8 +14,9 @@ | |||
| # ============================================================================ | |||
| """Occlusion explainer.""" | |||
| from typing import Tuple | |||
| import numpy as np | |||
| from numpy.lib.stride_tricks import as_strided | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| @@ -25,24 +26,17 @@ from .replacement import Constant | |||
| from ...._utils import abs_max | |||
| def _generate_patches(array, window_size, stride): | |||
| """View as windows.""" | |||
| if not isinstance(array, np.ndarray): | |||
| raise TypeError("`array` must be a numpy ndarray") | |||
| arr_shape = np.array(array.shape) | |||
| window_size = np.array(window_size, dtype=arr_shape.dtype) | |||
| slices = tuple(slice(None, None, st) for st in stride) | |||
| window_strides = np.array(array.strides) | |||
| def _generate_patches(array, window_size: Tuple, strides: Tuple): | |||
| """Generate patches from image w.r.t given window_size and strides.""" | |||
| window_strides = array.strides | |||
| slices = tuple(slice(None, None, stride) for stride in strides) | |||
| indexing_strides = array[slices].strides | |||
| win_indices_shape = (((np.array(array.shape) - np.array(window_size)) // np.array(stride)) + 1) | |||
| new_shape = tuple(list(win_indices_shape) + list(window_size)) | |||
| strides = tuple(list(indexing_strides) + list(window_strides)) | |||
| win_indices_shape = (np.array(array.shape) - np.array(window_size)) // np.array(strides) + 1 | |||
| patches = as_strided(array, shape=new_shape, strides=strides) | |||
| patches_shape = tuple(win_indices_shape) + window_size | |||
| strides_in_memory = indexing_strides + window_strides | |||
| patches = np.lib.stride_tricks.as_strided(array, shape=patches_shape, strides=strides_in_memory, writeable=False) | |||
| patches = patches.reshape((-1,) + window_size) | |||
| return patches | |||
| @@ -159,7 +153,7 @@ class Occlusion(PerturbationAttribution): | |||
| total_dim = np.prod(inputs.shape[1:]).item() | |||
| template = np.arange(total_dim).reshape(inputs.shape[1:]) | |||
| indices = _generate_patches(template, window_size, strides) | |||
| num_perturbations = indices.reshape((-1,) + window_size).shape[0] | |||
| num_perturbations = indices.shape[0] | |||
| indices = indices.reshape(num_perturbations, -1) | |||
| mask = np.zeros((num_perturbations, total_dim), dtype=np.bool) | |||