From: @yuhanshi Reviewed-by: @wuxuejian,@wenkai_dist Signed-off-by: @wenkai_disttags/v1.1.0
| @@ -418,29 +418,24 @@ class ImageClassificationRunner: | |||||
| inputs, labels, _ = self._unpack_next_element(next_element) | inputs, labels, _ = self._unpack_next_element(next_element) | ||||
| for idx, inp in enumerate(inputs): | for idx, inp in enumerate(inputs): | ||||
| inp = _EXPAND_DIMS(inp, 0) | inp = _EXPAND_DIMS(inp, 0) | ||||
| saliency_dict = saliency_dict_lst[idx] | |||||
| for label, saliency in saliency_dict.items(): | |||||
| if isinstance(benchmarker, Localization): | |||||
| _, _, bboxes = self._unpack_next_element(next_element, True) | |||||
| if label in labels[idx]: | |||||
| res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label], | |||||
| saliency=saliency) | |||||
| if np.any(res == np.nan): | |||||
| res = np.zeros_like(res) | |||||
| if isinstance(benchmarker, LabelAgnosticMetric): | |||||
| res = benchmarker.evaluate(explainer, inp) | |||||
| benchmarker.aggregate(res) | |||||
| else: | |||||
| saliency_dict = saliency_dict_lst[idx] | |||||
| for label, saliency in saliency_dict.items(): | |||||
| if isinstance(benchmarker, Localization): | |||||
| _, _, bboxes = self._unpack_next_element(next_element, True) | |||||
| if label in labels[idx]: | |||||
| res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label], | |||||
| saliency=saliency) | |||||
| benchmarker.aggregate(res, label) | |||||
| elif isinstance(benchmarker, LabelSensitiveMetric): | |||||
| res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) | |||||
| benchmarker.aggregate(res, label) | benchmarker.aggregate(res, label) | ||||
| elif isinstance(benchmarker, LabelSensitiveMetric): | |||||
| res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) | |||||
| if np.any(res == np.nan): | |||||
| res = np.zeros_like(res) | |||||
| benchmarker.aggregate(res, label) | |||||
| elif isinstance(benchmarker, LabelAgnosticMetric): | |||||
| res = benchmarker.evaluate(explainer, inp) | |||||
| if np.any(res == np.nan): | |||||
| res = np.zeros_like(res) | |||||
| benchmarker.aggregate(res) | |||||
| else: | |||||
| raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but' | |||||
| 'receive {}'.format(type(benchmarker))) | |||||
| else: | |||||
| raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but' | |||||
| 'receive {}'.format(type(benchmarker))) | |||||
| def _verify_data(self): | def _verify_data(self): | ||||
| """Verify dataset and labels.""" | """Verify dataset and labels.""" | ||||
| @@ -382,8 +382,6 @@ class Faithfulness(LabelSensitiveMetric): | |||||
| perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument | perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument | ||||
| perturb_method = "Constant" # perturbation method, all the perturbed pixels will be set to constant | perturb_method = "Constant" # perturbation method, all the perturbed pixels will be set to constant | ||||
| num_perturb_pixel_per_step = None # number of pixels for each perturbation step | |||||
| num_perturb_steps = 100 # separate the perturbation progress in to 100 steps. | |||||
| base_value = 0.0 # the pixel value set for the perturbed pixels | base_value = 0.0 # the pixel value set for the perturbed pixels | ||||
| check_value_type("activation_fn", activation_fn, nn.Cell) | check_value_type("activation_fn", activation_fn, nn.Cell) | ||||
| @@ -395,8 +393,6 @@ class Faithfulness(LabelSensitiveMetric): | |||||
| self._faithfulness_helper = method( | self._faithfulness_helper = method( | ||||
| perturb_percent=perturb_percent, | perturb_percent=perturb_percent, | ||||
| perturb_method=perturb_method, | perturb_method=perturb_method, | ||||
| perturb_pixel_per_step=num_perturb_pixel_per_step, | |||||
| num_perturbations=num_perturb_steps, | |||||
| base_value=base_value | base_value=base_value | ||||
| ) | ) | ||||
| @@ -15,6 +15,7 @@ | |||||
| """Base class for XAI metrics.""" | """Base class for XAI metrics.""" | ||||
| import copy | import copy | ||||
| import math | |||||
| from typing import Callable | from typing import Callable | ||||
| import numpy as np | import numpy as np | ||||
| @@ -88,11 +89,12 @@ class LabelAgnosticMetric(AttributionMetric): | |||||
| Return: | Return: | ||||
| float, averaged result. If no result is aggregate in the global_results, 0.0 will be returned. | float, averaged result. If no result is aggregate in the global_results, 0.0 will be returned. | ||||
| """ | """ | ||||
| if not self._global_results: | |||||
| return 0.0 | |||||
| results_sum = sum(self._global_results) | |||||
| count = len(self._global_results) | |||||
| return results_sum / count | |||||
| result_sum, count = 0, 0 | |||||
| for res in self._global_results: | |||||
| if math.isfinite(res): | |||||
| result_sum += res | |||||
| count += 1 | |||||
| return 0. if count == 0 else result_sum / count | |||||
| def aggregate(self, result): | def aggregate(self, result): | ||||
| """Aggregate single evaluation result to global results.""" | """Aggregate single evaluation result to global results.""" | ||||
| @@ -100,7 +102,7 @@ class LabelAgnosticMetric(AttributionMetric): | |||||
| self._global_results.append(result) | self._global_results.append(result) | ||||
| elif isinstance(result, (ms.Tensor, np.ndarray)): | elif isinstance(result, (ms.Tensor, np.ndarray)): | ||||
| result = format_tensor_to_ndarray(result) | result = format_tensor_to_ndarray(result) | ||||
| self._global_results.append(float(result)) | |||||
| self._global_results.extend([float(res) for res in result.reshape(-1)]) | |||||
| else: | else: | ||||
| raise TypeError('result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result)) | raise TypeError('result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result)) | ||||
| @@ -130,10 +132,12 @@ class LabelSensitiveMetric(AttributionMetric): | |||||
| @property | @property | ||||
| def num_labels(self): | def num_labels(self): | ||||
| """Number of labels used in evaluation.""" | |||||
| return self._num_labels | return self._num_labels | ||||
| @staticmethod | @staticmethod | ||||
| def _verify_params(num_labels): | def _verify_params(num_labels): | ||||
| """Checks whether num_labels is valid.""" | |||||
| check_value_type("num_labels", num_labels, int) | check_value_type("num_labels", num_labels, int) | ||||
| if num_labels < 1: | if num_labels < 1: | ||||
| raise ValueError("Argument num_labels must be parsed with a integer > 0.") | raise ValueError("Argument num_labels must be parsed with a integer > 0.") | ||||
| @@ -147,17 +151,19 @@ class LabelSensitiveMetric(AttributionMetric): | |||||
| target_np = format_tensor_to_ndarray(targets) | target_np = format_tensor_to_ndarray(targets) | ||||
| if len(target_np) > 1: | if len(target_np) > 1: | ||||
| raise ValueError("One result can not be aggreated to multiple targets.") | raise ValueError("One result can not be aggreated to multiple targets.") | ||||
| else: | |||||
| result_np = format_tensor_to_ndarray(result) | |||||
| elif isinstance(result, (ms.Tensor, np.ndarray)): | |||||
| result_np = format_tensor_to_ndarray(result).reshape(-1) | |||||
| if isinstance(targets, int): | if isinstance(targets, int): | ||||
| for res in result_np: | for res in result_np: | ||||
| self._global_results[targets].append(float(res)) | self._global_results[targets].append(float(res)) | ||||
| else: | else: | ||||
| target_np = format_tensor_to_ndarray(targets) | |||||
| target_np = format_tensor_to_ndarray(targets).reshape(-1) | |||||
| if len(target_np) != len(result_np): | if len(target_np) != len(result_np): | ||||
| raise ValueError("Length of result does not match with length of targets.") | raise ValueError("Length of result does not match with length of targets.") | ||||
| for tar, res in zip(target_np, result_np): | for tar, res in zip(target_np, result_np): | ||||
| self._global_results[int(tar)].append(float(res)) | self._global_results[int(tar)].append(float(res)) | ||||
| else: | |||||
| raise TypeError('Result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result)) | |||||
| def reset(self): | def reset(self): | ||||
| """Resets global_result.""" | """Resets global_result.""" | ||||
| @@ -168,16 +174,18 @@ class LabelSensitiveMetric(AttributionMetric): | |||||
| """ | """ | ||||
| Get the class performances by global result. | Get the class performances by global result. | ||||
| Returns: | Returns: | ||||
| (:class:`np.ndarray`): :attr:`num_labels`-dimensional vector | |||||
| containing per-class performance. | |||||
| (:class:`list`): a list of performances where each value is the average score of specific class. | |||||
| """ | """ | ||||
| count = np.array( | |||||
| [len(self._global_results[i]) for i in range(self._num_labels)]) | |||||
| result_sum = np.array( | |||||
| [sum(self._global_results[i]) for i in range(self._num_labels)]) | |||||
| return result_sum / count.clip(min=1) | |||||
| results_on_labels = [] | |||||
| for label_id in range(self._num_labels): | |||||
| sum_of_label, count_of_label = 0, 0 | |||||
| for res in self._global_results[label_id]: | |||||
| if math.isfinite(res): | |||||
| sum_of_label += res | |||||
| count_of_label += 1 | |||||
| results_on_labels.append(0. if count_of_label == 0 else sum_of_label / count_of_label) | |||||
| return results_on_labels | |||||
| @property | @property | ||||
| def performance(self): | def performance(self): | ||||
| @@ -187,13 +195,13 @@ class LabelSensitiveMetric(AttributionMetric): | |||||
| Returns: | Returns: | ||||
| (:class:`float`): mean performance. | (:class:`float`): mean performance. | ||||
| """ | """ | ||||
| count = sum( | |||||
| [len(self._global_results[i]) for i in range(self._num_labels)]) | |||||
| result_sum = sum( | |||||
| [sum(self._global_results[i]) for i in range(self._num_labels)]) | |||||
| if count == 0: | |||||
| return 0 | |||||
| return result_sum / count | |||||
| result_sum, count = 0, 0 | |||||
| for label_id in range(self._num_labels): | |||||
| for res in self._global_results[label_id]: | |||||
| if math.isfinite(res): | |||||
| result_sum += res | |||||
| count += 1 | |||||
| return 0. if count == 0 else result_sum / count | |||||
| def get_results(self): | def get_results(self): | ||||
| """Global result of the metric can be return""" | """Global result of the metric can be return""" | ||||
| @@ -122,8 +122,8 @@ class Robustness(LabelSensitiveMetric): | |||||
| perturbations.append(perturbation_on_single_sample) | perturbations.append(perturbation_on_single_sample) | ||||
| perturbations = np.vstack(perturbations) | perturbations = np.vstack(perturbations) | ||||
| perturbations_saliency = explainer(ms.Tensor(perturbations, ms.float32), targets).asnumpy() | perturbations_saliency = explainer(ms.Tensor(perturbations, ms.float32), targets).asnumpy() | ||||
| sensitivity = np.sum((perturbations_saliency - saliency_np) ** 2, | |||||
| axis=tuple(range(1, len(saliency_np.shape)))) | |||||
| sensitivity = np.sqrt(np.sum((perturbations_saliency - saliency_np) ** 2, | |||||
| axis=tuple(range(1, len(saliency_np.shape))))) | |||||
| sensitivities.append(sensitivity) | sensitivities.append(sensitivity) | ||||
| sensitivities = np.stack(sensitivities, axis=-1) | sensitivities = np.stack(sensitivities, axis=-1) | ||||
| max_sensitivity = np.max(sensitivities, axis=1) / norm | max_sensitivity = np.max(sensitivities, axis=1) / norm | ||||
| @@ -89,7 +89,7 @@ class Ablation: | |||||
| class AblationWithSaliency(Ablation): | class AblationWithSaliency(Ablation): | ||||
| """ | """ | ||||
| Perturbation generator to generate perturbations for a given array. | |||||
| Perturbation generator to generate perturbations w.r.t a given saliency map. | |||||
| Args: | Args: | ||||
| perturb_percent (float): percentage of pixels to perturb | perturb_percent (float): percentage of pixels to perturb | ||||
| @@ -143,28 +143,20 @@ class AblationWithSaliency(Ablation): | |||||
| """ | """ | ||||
| batch_size = saliency.shape[0] | batch_size = saliency.shape[0] | ||||
| expected_num_dim = len(saliency.shape) + 1 | |||||
| has_channel = num_channels is not None | has_channel = num_channels is not None | ||||
| num_channels = 1 if num_channels is None else num_channels | num_channels = 1 if num_channels is None else num_channels | ||||
| if has_channel: | if has_channel: | ||||
| saliency = saliency.mean(axis=1) | saliency = saliency.mean(axis=1) | ||||
| saliency_rank = rank_pixels(saliency, descending=True) | saliency_rank = rank_pixels(saliency, descending=True) | ||||
| num_pixels = reduce(lambda x, y: x * y, saliency.shape[1:]) | num_pixels = reduce(lambda x, y: x * y, saliency.shape[1:]) | ||||
| if self._pixel_per_step: | |||||
| pixel_per_step = self._pixel_per_step | |||||
| num_perturbations = math.floor(num_pixels * self._perturb_percent / self._pixel_per_step) | |||||
| elif self._num_perturbations: | |||||
| pixel_per_step = math.floor(num_pixels * self._perturb_percent / self._num_perturbations) | |||||
| num_perturbations = self._num_perturbations | |||||
| else: | |||||
| raise ValueError("Must provide either pixel_per_step or num_perturbations.") | |||||
| pixel_per_step, num_perturbations = self._check_and_format_perturb_param(num_pixels) | |||||
| masks = np.zeros((batch_size, num_perturbations, num_channels, saliency_rank.shape[1], saliency_rank.shape[2]), | masks = np.zeros((batch_size, num_perturbations, num_channels, saliency_rank.shape[1], saliency_rank.shape[2]), | ||||
| dtype=np.bool) | dtype=np.bool) | ||||
| # If the perturbation is added accumulately, the factor should be 0 to preserve the low bound of indexing. | |||||
| factor = 0 if self._is_accumulate else 1 | factor = 0 if self._is_accumulate else 1 | ||||
| for i in range(batch_size): | for i in range(batch_size): | ||||
| @@ -176,7 +168,23 @@ class AblationWithSaliency(Ablation): | |||||
| up_bound += pixel_per_step | up_bound += pixel_per_step | ||||
| masks = masks if has_channel else np.squeeze(masks, axis=2) | masks = masks if has_channel else np.squeeze(masks, axis=2) | ||||
| return masks | |||||
| def _check_and_format_perturb_param(self, num_pixels): | |||||
| """ | |||||
| Check whether the self._pixel_per_step and self._num_perturbation is valid. If the parameters are unreasonable, | |||||
| this function will try to reassign the parameters and raise ValueError when reassignment is failed. | |||||
| """ | |||||
| if self._pixel_per_step: | |||||
| pixel_per_step = self._pixel_per_step | |||||
| num_perturbations = math.floor(num_pixels * self._perturb_percent / self._pixel_per_step) | |||||
| elif self._num_perturbations: | |||||
| pixel_per_step = math.floor(num_pixels * self._perturb_percent / self._num_perturbations) | |||||
| num_perturbations = self._num_perturbations | |||||
| else: | |||||
| # If neither pixel_per_step or num_perturbations is provided, num_perturbations is determined by the square | |||||
| # root of product from the spatial size of saliency map. | |||||
| num_perturbations = math.floor(np.sqrt(num_pixels)) | |||||
| pixel_per_step = math.floor(num_pixels * self._perturb_percent / num_perturbations) | |||||
| if len(masks.shape) == expected_num_dim: | |||||
| return masks | |||||
| raise ValueError(f'Invalid masks shape {len(masks.shape)}, expect {expected_num_dim}-dim.') | |||||
| return pixel_per_step, num_perturbations | |||||
| @@ -14,8 +14,9 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Occlusion explainer.""" | """Occlusion explainer.""" | ||||
| from typing import Tuple | |||||
| import numpy as np | import numpy as np | ||||
| from numpy.lib.stride_tricks import as_strided | |||||
| import mindspore as ms | import mindspore as ms | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| @@ -25,24 +26,17 @@ from .replacement import Constant | |||||
| from ...._utils import abs_max | from ...._utils import abs_max | ||||
| def _generate_patches(array, window_size, stride): | |||||
| """View as windows.""" | |||||
| if not isinstance(array, np.ndarray): | |||||
| raise TypeError("`array` must be a numpy ndarray") | |||||
| arr_shape = np.array(array.shape) | |||||
| window_size = np.array(window_size, dtype=arr_shape.dtype) | |||||
| slices = tuple(slice(None, None, st) for st in stride) | |||||
| window_strides = np.array(array.strides) | |||||
| def _generate_patches(array, window_size: Tuple, strides: Tuple): | |||||
| """Generate patches from image w.r.t given window_size and strides.""" | |||||
| window_strides = array.strides | |||||
| slices = tuple(slice(None, None, stride) for stride in strides) | |||||
| indexing_strides = array[slices].strides | indexing_strides = array[slices].strides | ||||
| win_indices_shape = (((np.array(array.shape) - np.array(window_size)) // np.array(stride)) + 1) | |||||
| new_shape = tuple(list(win_indices_shape) + list(window_size)) | |||||
| strides = tuple(list(indexing_strides) + list(window_strides)) | |||||
| win_indices_shape = (np.array(array.shape) - np.array(window_size)) // np.array(strides) + 1 | |||||
| patches = as_strided(array, shape=new_shape, strides=strides) | |||||
| patches_shape = tuple(win_indices_shape) + window_size | |||||
| strides_in_memory = indexing_strides + window_strides | |||||
| patches = np.lib.stride_tricks.as_strided(array, shape=patches_shape, strides=strides_in_memory, writeable=False) | |||||
| patches = patches.reshape((-1,) + window_size) | |||||
| return patches | return patches | ||||
| @@ -159,7 +153,7 @@ class Occlusion(PerturbationAttribution): | |||||
| total_dim = np.prod(inputs.shape[1:]).item() | total_dim = np.prod(inputs.shape[1:]).item() | ||||
| template = np.arange(total_dim).reshape(inputs.shape[1:]) | template = np.arange(total_dim).reshape(inputs.shape[1:]) | ||||
| indices = _generate_patches(template, window_size, strides) | indices = _generate_patches(template, window_size, strides) | ||||
| num_perturbations = indices.reshape((-1,) + window_size).shape[0] | |||||
| num_perturbations = indices.shape[0] | |||||
| indices = indices.reshape(num_perturbations, -1) | indices = indices.reshape(num_perturbations, -1) | ||||
| mask = np.zeros((num_perturbations, total_dim), dtype=np.bool) | mask = np.zeros((num_perturbations, total_dim), dtype=np.bool) | ||||