From: @yuhanshi Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -13,19 +13,6 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Packaged operations based on MindSpore.""" | |||
| from typing import List, Tuple, Union, Callable | |||
| import numpy as np | |||
| import mindspore | |||
| import mindspore.ops.operations as op | |||
| from mindspore import nn | |||
| _Axis = Union[int, Tuple[int, ...], List[int]] | |||
| _Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]] | |||
| _Number = Union[int, float, np.int, np.float] | |||
| _Shape = Union[int, Tuple[int, ...]] | |||
| Tensor = mindspore.Tensor | |||
| __all__ = [ | |||
| 'absolute', | |||
| @@ -41,6 +28,7 @@ __all__ = [ | |||
| 'mean', | |||
| 'mul', | |||
| 'sort', | |||
| 'sqrt', | |||
| 'squeeze', | |||
| 'tile', | |||
| 'reshape', | |||
| @@ -51,6 +39,20 @@ __all__ = [ | |||
| 'summation' | |||
| ] | |||
| from typing import List, Tuple, Union, Callable | |||
| import numpy as np | |||
| import mindspore | |||
| from mindspore import nn | |||
| import mindspore.ops.operations as op | |||
| _Axis = Union[int, Tuple[int, ...], List[int]] | |||
| _Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]] | |||
| _Number = Union[int, float, np.int, np.float] | |||
| _Shape = Union[int, Tuple[int, ...]] | |||
| Tensor = mindspore.Tensor | |||
| def absolute(inputs: Tensor) -> Tensor: | |||
| """Get the absolute value of a tensor value.""" | |||
| @@ -33,11 +33,10 @@ from mindspore.train._utils import check_value_type | |||
| from mindspore.train.summary._summary_adapter import _convert_image_format | |||
| from mindspore.train.summary.summary_record import SummaryRecord | |||
| from mindspore.train.summary_pb2 import Explain | |||
| from .benchmark import Localization | |||
| from .benchmark._attribution.metric import AttributionMetric | |||
| from .explanation import RISE | |||
| from .explanation._attribution._attribution import Attribution | |||
| from .benchmark._attribution.metric import AttributionMetric, LabelSensitiveMetric, LabelAgnosticMetric | |||
| from .explanation._attribution.attribution import Attribution | |||
| # datafile directory names | |||
| _DATAFILE_DIRNAME_PREFIX = "_explain_" | |||
| @@ -293,7 +292,8 @@ class ExplainRunner: | |||
| benchmark.benchmark_method = bench.__class__.__name__ | |||
| benchmark.total_score = bench.performance | |||
| benchmark.label_score.extend(bench.class_performances) | |||
| if isinstance(bench, LabelSensitiveMetric): | |||
| benchmark.label_score.extend(bench.class_performances) | |||
| print(spacer.format("Finish running and writing explanation and benchmark data for {}. " | |||
| "Time elapsed: {:.3f} s".format(exp.__class__.__name__, time() - start))) | |||
| @@ -603,7 +603,6 @@ class ExplainRunner: | |||
| Args: | |||
| next_element (Tuple): Data of one step | |||
| explainer (`_Attribution`): An Attribution object to generate saliency maps. | |||
| imageid_labels (dict): A dict that maps the image_id and its union labels. | |||
| """ | |||
| inputs, labels, _ = self._unpack_next_element(next_element) | |||
| for idx, inp in enumerate(inputs): | |||
| @@ -615,10 +614,22 @@ class ExplainRunner: | |||
| if label in labels[idx]: | |||
| res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label], | |||
| saliency=saliency) | |||
| if np.any(res == np.nan): | |||
| res = np.zeros_like(res) | |||
| benchmarker.aggregate(res, label) | |||
| else: | |||
| elif isinstance(benchmarker, LabelSensitiveMetric): | |||
| res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) | |||
| if np.any(res == np.nan): | |||
| res = np.zeros_like(res) | |||
| benchmarker.aggregate(res, label) | |||
| elif isinstance(benchmarker, LabelAgnosticMetric): | |||
| res = benchmarker.evaluate(explainer, inp) | |||
| if np.any(res == np.nan): | |||
| res = np.zeros_like(res) | |||
| benchmarker.aggregate(res) | |||
| else: | |||
| raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but' | |||
| 'receive {}'.format(type(benchmarker))) | |||
| def _save_original_image(self, sample_id: int, image): | |||
| """Save an image to summary directory.""" | |||
| @@ -16,6 +16,7 @@ | |||
| __all__ = [ | |||
| 'ForwardProbe', | |||
| 'abs_max', | |||
| 'calc_auc', | |||
| 'calc_correlation', | |||
| 'format_tensor_to_ndarray', | |||
| @@ -29,7 +30,6 @@ __all__ = [ | |||
| ] | |||
| from typing import Tuple, Union | |||
| import math | |||
| import numpy as np | |||
| from PIL import Image | |||
| @@ -43,6 +43,21 @@ _Module = nn.Cell | |||
| _Tensor = ms.Tensor | |||
| def abs_max(gradients): | |||
| """ | |||
| Transform gradients to saliency through abs then take max along channels. | |||
| Args: | |||
| gradients (_Tensor): Gradients which will be transformed to saliency map. | |||
| Returns: | |||
| _Tensor, saliency map integrated from gradients. | |||
| """ | |||
| gradients = op.Abs()(gradients) | |||
| saliency = op.ReduceMax(keep_dims=True)(gradients, axis=1) | |||
| return saliency | |||
| def generate_one_hot(indices, depth): | |||
| r""" | |||
| Simple wrap of OneHot operation, the on_value an off_value are fixed to 1.0 | |||
| @@ -96,7 +111,7 @@ def retrieve_layer_by_name(model: _Module, layer_name: str): | |||
| - target_layer (_Module) | |||
| Raise: | |||
| ValueError: is module with given layer_name is not found in the model, | |||
| ValueError: if module with given layer_name is not found in the model, | |||
| raise ValueError. | |||
| """ | |||
| @@ -201,23 +216,28 @@ def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray: | |||
| def calc_correlation(x: Union[ms.Tensor, np.ndarray], | |||
| y: Union[ms.Tensor, np.ndarray]) -> float: | |||
| """Calculate Pearson correlation coefficient between two arrays. """ | |||
| """Calculate Pearson correlation coefficient between two vectors.""" | |||
| x = format_tensor_to_ndarray(x) | |||
| y = format_tensor_to_ndarray(y) | |||
| faithfulness = -np.corrcoef(x, y)[0, 1] | |||
| if math.isnan(faithfulness): | |||
| if len(x.shape) > 1 or len(y.shape) > 1: | |||
| raise ValueError('"calc_correlation" only support 1-dim vectors currently, but get shape {} and {}.' | |||
| .format(len(x.shape), len(y.shape))) | |||
| if np.all(x == 0) or np.all(y == 0): | |||
| return np.float(0) | |||
| faithfulness = -np.corrcoef(x, y)[0, 1] | |||
| return faithfulness | |||
| def calc_auc(x: _Array) -> float: | |||
| def calc_auc(x: _Array) -> _Array: | |||
| """Calculate the Aera under Curve.""" | |||
| # take mean for multiple patches if the model is fully convolutional model | |||
| if len(x.shape) == 4: | |||
| x = np.mean(np.mean(x, axis=2), axis=3) | |||
| auc = (x.sum() - x[0] - x[-1]) / len(x) | |||
| return float(auc) | |||
| return auc | |||
| def rank_pixels(inputs: _Array, descending: bool = True) -> _Array: | |||
| @@ -235,13 +255,17 @@ def rank_pixels(inputs: _Array, descending: bool = True) -> _Array: | |||
| rank_pixels(x, descending=False) | |||
| >> np.array([[3, 2, 0], [4, 5, 1]]) | |||
| """ | |||
| if len(inputs.shape) != 2: | |||
| raise ValueError('Only support 2D array currently') | |||
| flatten_saliency = inputs.reshape(-1) | |||
| if len(inputs.shape) < 2 or len(inputs.shape) > 3: | |||
| raise ValueError('Only support 2D or 3D inputs currently.') | |||
| batch_size = inputs.shape[0] | |||
| flatten_saliency = inputs.reshape(batch_size, -1) | |||
| factor = -1 if descending else 1 | |||
| sorted_arg = np.argsort(factor * flatten_saliency, axis=0) | |||
| sorted_arg = np.argsort(factor * flatten_saliency, axis=1) | |||
| flatten_rank = np.zeros_like(sorted_arg) | |||
| flatten_rank[sorted_arg] = np.arange(0, sorted_arg.shape[0]) | |||
| arange = np.arange(flatten_saliency.shape[1]) | |||
| for i in range(batch_size): | |||
| flatten_rank[i][sorted_arg[i]] = arange | |||
| rank_map = flatten_rank.reshape(inputs.shape) | |||
| return rank_map | |||
| @@ -14,10 +14,14 @@ | |||
| # ============================================================================ | |||
| """Predefined XAI metrics.""" | |||
| from ._attribution.class_sensitivity import ClassSensitivity | |||
| from ._attribution.faithfulness import Faithfulness | |||
| from ._attribution.localization import Localization | |||
| from ._attribution.robustness import Robustness | |||
| __all__ = [ | |||
| "ClassSensitivity", | |||
| "Faithfulness", | |||
| "Localization" | |||
| "Localization", | |||
| "Robustness" | |||
| ] | |||
| @@ -13,11 +13,3 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Predefined XAI metrics""" | |||
| from .faithfulness import Faithfulness | |||
| from .localization import Localization | |||
| __all__ = [ | |||
| "Faithfulness", | |||
| "Localization" | |||
| ] | |||
| @@ -0,0 +1,73 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Class Sensitivity.""" | |||
| import numpy as np | |||
| from mindspore import Tensor | |||
| from .metric import LabelAgnosticMetric | |||
| from ... import _operators as ops | |||
| from ...explanation._attribution.attribution import Attribution | |||
| from ..._utils import calc_correlation | |||
| class ClassSensitivity(LabelAgnosticMetric): | |||
| r""" | |||
| Class sensitivity metric used to evaluate attribution-based explanations. | |||
| Reasonable atrribution-based explainers are expected to generate distinct saliency maps for different labels, | |||
| especially for labels of highest confidence and low confidence. Class sensitivity evaluates the explainer through | |||
| computing the correlation between saliency maps of highest-confidence and lowest-confidence labels. Explainer with | |||
| better class sensitivity will receive lower correlation score. To make the evaluation results intuitive, the | |||
| returned score will take negative on correlation and normalize. | |||
| """ | |||
| def evaluate(self, explainer: Attribution, inputs: Tensor) -> np.ndarray: | |||
| """ | |||
| Evaluate class sensitivity on a single data sample. | |||
| Args: | |||
| explainer (Attribution): The explainer to be evaluated, see `mindspore.explainer.explanation`. | |||
| inputs (Tensor): A data sample, a 4D tensor of shape :math:`(N, C, H, W)`. | |||
| Returns: | |||
| numpy.ndarray, 1D array of shape :math:`(N,)`, result of class sensitivity evaluated on `explainer`. | |||
| Examples: | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import Gradient | |||
| >>> gradient = Gradient() | |||
| >>> x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> class_sensitivity = ClassSensitivity() | |||
| >>> res = class_sensitivity.evaluate(gradient, x) | |||
| """ | |||
| self._check_evaluate_param(explainer, inputs) | |||
| outputs = explainer.model(inputs) | |||
| max_confidence_label = ops.argmax(outputs) | |||
| min_confidence_label = ops.argmin(outputs) | |||
| max_confidence_saliency = explainer(inputs, max_confidence_label).asnumpy() | |||
| min_confidence_saliency = explainer(inputs, min_confidence_label).asnumpy() | |||
| correlations = [] | |||
| for i in range(inputs.shape[0]): | |||
| correlation = calc_correlation(max_confidence_saliency[i].reshape(-1), | |||
| min_confidence_saliency[i].reshape(-1)) | |||
| normalized_correlation = (-correlation + 1) / 2 | |||
| correlations.append(normalized_correlation) | |||
| return np.array(correlations, np.float) | |||
| @@ -12,21 +12,19 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Faithfulness""" | |||
| import math | |||
| from typing import Callable, Optional, Union, Tuple | |||
| """Faithfulness.""" | |||
| from typing import Callable, Optional, Union | |||
| import numpy as np | |||
| from scipy.ndimage.filters import gaussian_filter | |||
| from mindspore import log | |||
| import mindspore as ms | |||
| from mindspore.train._utils import check_value_type | |||
| import mindspore.nn as nn | |||
| import mindspore.ops.operations as op | |||
| from .metric import AttributionMetric | |||
| from ..._utils import calc_correlation, calc_auc, format_tensor_to_ndarray, rank_pixels | |||
| from ...explanation._attribution._attribution import Attribution as _Attribution | |||
| from .metric import LabelSensitiveMetric | |||
| from ..._utils import calc_auc, format_tensor_to_ndarray | |||
| from ...explanation._attribution import Attribution as _Attribution | |||
| from ...explanation._attribution._perturbation.replacement import Constant, GaussianBlur | |||
| from ...explanation._attribution._perturbation.ablation import AblationWithSaliency | |||
| _Array = np.ndarray | |||
| _Explainer = Union[_Attribution, Callable] | |||
| @@ -36,189 +34,19 @@ _Module = nn.Cell | |||
| def _calc_feature_importance(saliency: _Array, masks: _Array) -> _Array: | |||
| """Calculate feature important w.r.t given masks.""" | |||
| feature_importance = [] | |||
| num_perturbations = masks.shape[0] | |||
| for i in range(num_perturbations): | |||
| patch_feature_importance = saliency[masks[i]].sum() / masks[i].sum() | |||
| feature_importance.append(patch_feature_importance) | |||
| feature_importance = np.array(feature_importance, dtype=np.float32) | |||
| if saliency.shape[1] < masks.shape[2]: | |||
| saliency = np.repeat(saliency, repeats=masks.shape[2], axis=1) | |||
| batch_size = masks.shape[0] | |||
| num_perturbations = masks.shape[1] | |||
| saliency = np.repeat(saliency, repeats=num_perturbations, axis=0) | |||
| saliency = saliency.reshape([batch_size, num_perturbations, -1]) | |||
| masks = masks.reshape([batch_size, num_perturbations, -1]) | |||
| feature_importance = saliency * masks | |||
| feature_importance = feature_importance.sum(-1) / masks.sum(-1) | |||
| return feature_importance | |||
| class _BaseReplacement: | |||
| """ | |||
| Base class of generator for generating different replacement for perturbations. | |||
| Args: | |||
| kwargs: Optional args for generating replacement. Derived class need to | |||
| add necessary arg names and default value to '_necessary_args'. | |||
| If the argument has no default value, the value should be set to | |||
| 'EMPTY' to mark the required args. Initializing an object will | |||
| check the given kwargs w.r.t '_necessary_args'. | |||
| Raise: | |||
| ValueError: Raise when provided kwargs not contain necessary arg names with 'EMPTY' mark. | |||
| """ | |||
| _necessary_args = {} | |||
| def __init__(self, **kwargs): | |||
| self._replace_args = self._necessary_args.copy() | |||
| for key, value in self._replace_args.items(): | |||
| if key in kwargs.keys(): | |||
| self._replace_args[key] = kwargs[key] | |||
| elif key not in kwargs.keys() and value == 'EMPTY': | |||
| raise ValueError(f"Missing keyword arg {key} for {self.__class__.__name__}.") | |||
| __call__: Callable | |||
| """ | |||
| Generate replacement for perturbations. Derived class should overwrite this | |||
| function to generate different replacement for perturbing. | |||
| Args: | |||
| inputs (_Array): Array to be perturb. | |||
| Returns: | |||
| - replacement (_Array): Array to provide alternative pixels for every | |||
| position in the given | |||
| inputs. The returned array should have same shape as inputs. | |||
| """ | |||
| class Constant(_BaseReplacement): | |||
| """ Generator to provide constant-value replacement for perturbations """ | |||
| _necessary_args = {'base_value': 'EMPTY'} | |||
| def __call__(self, inputs: _Array) -> _Array: | |||
| replacement = np.ones_like(inputs, dtype=np.float32) | |||
| replacement *= self._replace_args['base_value'] | |||
| return replacement | |||
| class GaussianBlur(_BaseReplacement): | |||
| """ Generator to provided gaussian blurred inputs for perturbation. """ | |||
| _necessary_args = {'sigma': 0.7} | |||
| def __call__(self, inputs: _Array) -> _Array: | |||
| sigma = self._replace_args['sigma'] | |||
| replacement = gaussian_filter(inputs, sigma=sigma) | |||
| return replacement | |||
| class Perturb: | |||
| """ | |||
| Perturbation generator to generate perturbations for a given array. | |||
| Args: | |||
| perturb_percent (float): percentage of pixels to perturb | |||
| perturb_mode (str): specify perturbing mode, through deleting or | |||
| inserting pixels. Current support: ['Deletion', 'Insertion']. | |||
| is_accumulate (bool): whether to accumulate the former perturbations to | |||
| the later perturbations. | |||
| perturb_pixel_per_step (int, optional): number of pixel to perturb | |||
| for each perturbation. If perturb_pixel_per_step is None, actual | |||
| perturb_pixel_per_step will be calculate by: | |||
| num_image_pixel * perturb_percent / num_perturb_steps. | |||
| Default: None | |||
| num_perturbations (int, optional): number of perturbations. If | |||
| num_perturbations if None, it will be calculated by: | |||
| num_image_pixel * perturb_percent / perturb_pixel_per_step. | |||
| Default: None | |||
| """ | |||
| def __init__(self, | |||
| perturb_percent: float, | |||
| perturb_mode: str, | |||
| is_accumulate: bool, | |||
| perturb_pixel_per_step: Optional[int] = None, | |||
| num_perturbations: Optional[int] = None): | |||
| self._perturb_percent = perturb_percent | |||
| self._perturb_mode = perturb_mode | |||
| self._pixel_per_step = perturb_pixel_per_step | |||
| self._num_perturbations = num_perturbations | |||
| self._is_accumulate = is_accumulate | |||
| @staticmethod | |||
| def _assign(x: _Array, y: _Array, masks: _Array): | |||
| """Assign values to perturb pixels on perturbations.""" | |||
| check_value_type("masks dtype", masks.dtype, type(np.dtype(bool))) | |||
| for i in range(x.shape[0]): | |||
| x[i][:, masks[i]] = y[:, masks[i]] | |||
| def _generate_mask(self, saliency_rank: _Array) -> _Array: | |||
| """Generate mask for perturbations based on given saliency ranks.""" | |||
| if len(saliency_rank.shape) != 2: | |||
| raise ValueError(f'The param "saliency_rank" should be 2-dim, but receive {len(saliency_rank.shape)}.') | |||
| num_pixels = saliency_rank.shape[0] * saliency_rank.shape[1] | |||
| if self._pixel_per_step: | |||
| pixel_per_step = self._pixel_per_step | |||
| num_perturbations = math.floor( | |||
| num_pixels * self._perturb_percent / self._pixel_per_step) | |||
| elif self._num_perturbations: | |||
| pixel_per_step = math.floor( | |||
| num_pixels * self._perturb_percent / self._num_perturbations) | |||
| num_perturbations = self._num_perturbations | |||
| else: | |||
| raise ValueError("Must provide either pixel_per_step or num_perturbations.") | |||
| masks = np.zeros( | |||
| (num_perturbations, saliency_rank.shape[0], saliency_rank.shape[1]), | |||
| dtype=np.bool) | |||
| low_bound = 0 | |||
| up_bound = low_bound + pixel_per_step | |||
| factor = 0 if self._is_accumulate else 1 | |||
| for i in range(num_perturbations): | |||
| masks[i, ((saliency_rank >= low_bound) | |||
| & (saliency_rank < up_bound))] = True | |||
| low_bound = up_bound * factor | |||
| up_bound += pixel_per_step | |||
| if len(masks.shape) == 3: | |||
| return masks | |||
| raise ValueError(f'Invalid masks shape {len(masks.shape)}, expect 3-dim.') | |||
| def __call__(self, | |||
| inputs: _Array, | |||
| saliency: _Array, | |||
| reference: _Array, | |||
| return_mask: bool = False, | |||
| ) -> Union[_Array, Tuple[_Array, ...]]: | |||
| """ | |||
| Generate perturbations of given array. | |||
| Args: | |||
| inputs (_Array): input array to perturb | |||
| saliency (_Array): saliency map | |||
| return_mask (bool): whether return the mask for generating | |||
| the perturbation. The mask can be used to calculate | |||
| average feature importance of pixels perturbed at each step. | |||
| Return: | |||
| perturbations (_Array) | |||
| masks (_Array): return when return_mask is set to True. | |||
| """ | |||
| if not np.array_equal(inputs.shape, reference.shape): | |||
| raise ValueError('reference must have the same shape as inputs.') | |||
| saliency_rank = rank_pixels(saliency, descending=True) | |||
| masks = self._generate_mask(saliency_rank) | |||
| num_perturbations = masks.shape[0] | |||
| if self._perturb_mode == 'Insertion': | |||
| inputs, reference = reference, inputs | |||
| perturbations = np.tile( | |||
| inputs, (num_perturbations, *[1] * len(inputs.shape))) | |||
| Perturb._assign(perturbations, reference, masks) | |||
| if return_mask: | |||
| return perturbations, masks | |||
| return perturbations | |||
| class _FaithfulnessHelper: | |||
| """Base class for faithfulness calculator.""" | |||
| _support = [Constant, GaussianBlur] | |||
| @@ -240,27 +68,15 @@ class _FaithfulnessHelper: | |||
| raise ValueError( | |||
| 'The param "perturb_method" should be one of {}.'.format([x.__name__ for x in self._support])) | |||
| self._perturb = Perturb(perturb_percent=perturb_percent, | |||
| perturb_mode=perturb_mode, | |||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||
| num_perturbations=num_perturbations, | |||
| is_accumulate=is_accumulate) | |||
| self._ablation = AblationWithSaliency(perturb_mode=perturb_mode, | |||
| perturb_percent=perturb_percent, | |||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||
| num_perturbations=num_perturbations, | |||
| is_accumulate=is_accumulate) | |||
| calc_faithfulness: Callable | |||
| """ | |||
| Method used to calculate faithfulness for given inputs, target label, | |||
| saliency. Derive class should implement this method. | |||
| Args: | |||
| inputs (_Array): sample to calculate faithfulness score | |||
| model (_Module): model to explanation | |||
| targets (_Label): label to explanation on. | |||
| saliency (_Array): Saliency map of given inputs and targets from the | |||
| explainer. | |||
| Return: | |||
| - faithfulness (float): faithfulness score | |||
| """ | |||
| def calc_faithfulness(self, inputs, model, targets, saliency): | |||
| """Calc faithfulness.""" | |||
| raise NotImplementedError | |||
| class NaiveFaithfulness(_FaithfulnessHelper): | |||
| @@ -304,14 +120,13 @@ class NaiveFaithfulness(_FaithfulnessHelper): | |||
| perturb_pixel_per_step: Optional[int] = None, | |||
| num_perturbations: Optional[int] = None, | |||
| **kwargs): | |||
| super(NaiveFaithfulness, self).__init__( | |||
| perturb_percent=perturb_percent, | |||
| perturb_mode='Deletion', | |||
| perturb_method=perturb_method, | |||
| is_accumulate=is_accumulate, | |||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||
| num_perturbations=num_perturbations, | |||
| **kwargs) | |||
| super().__init__(perturb_percent=perturb_percent, | |||
| perturb_mode='Deletion', | |||
| perturb_method=perturb_method, | |||
| is_accumulate=is_accumulate, | |||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||
| num_perturbations=num_perturbations, | |||
| **kwargs) | |||
| def calc_faithfulness(self, | |||
| inputs: _Array, | |||
| @@ -336,16 +151,21 @@ class NaiveFaithfulness(_FaithfulnessHelper): | |||
| log.warning("The saliency map is zero everywhere. The correlation will be set to zero.") | |||
| correlation = 0 | |||
| return np.array([correlation], np.float) | |||
| batch_size = inputs.shape[0] | |||
| reference = self._get_reference(inputs) | |||
| perturbations, masks = self._perturb( | |||
| inputs, saliency, reference, return_mask=True) | |||
| masks = self._ablation.generate_mask(saliency, inputs.shape[1]) | |||
| perturbations = self._ablation(inputs, reference, masks) | |||
| feature_importance = _calc_feature_importance(saliency, masks) | |||
| perturbations = perturbations.reshape(-1, *perturbations.shape[2:]) | |||
| perturbations = ms.Tensor(perturbations, dtype=ms.float32) | |||
| predictions = model(perturbations).asnumpy()[:, targets] | |||
| predictions = model(perturbations)[:, targets].asnumpy() | |||
| predictions = predictions.reshape(*feature_importance.shape) | |||
| faithfulness = calc_correlation(feature_importance, predictions) | |||
| return np.array([faithfulness], np.float) | |||
| faithfulness = -np.corrcoef(feature_importance, predictions) | |||
| faithfulness = np.diag(faithfulness[:batch_size, batch_size:]) | |||
| return faithfulness | |||
| class DeletionAUC(_FaithfulnessHelper): | |||
| @@ -385,20 +205,19 @@ class DeletionAUC(_FaithfulnessHelper): | |||
| perturb_pixel_per_step: Optional[int] = None, | |||
| num_perturbations: Optional[int] = None, | |||
| **kwargs): | |||
| super(DeletionAUC, self).__init__( | |||
| perturb_percent=perturb_percent, | |||
| perturb_mode='Deletion', | |||
| perturb_method=perturb_method, | |||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||
| num_perturbations=num_perturbations, | |||
| is_accumulate=True, | |||
| **kwargs) | |||
| super().__init__(perturb_percent=perturb_percent, | |||
| perturb_mode='Deletion', | |||
| perturb_method=perturb_method, | |||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||
| num_perturbations=num_perturbations, | |||
| is_accumulate=True, | |||
| **kwargs) | |||
| def calc_faithfulness(self, | |||
| inputs: _Array, | |||
| model: _Module, | |||
| targets: _Label, | |||
| saliency: _Array) -> np.ndarray: | |||
| saliency: _Array) -> _Array: | |||
| """ | |||
| Calculate faithfulness through deletion AUC. | |||
| @@ -414,14 +233,17 @@ class DeletionAUC(_FaithfulnessHelper): | |||
| """ | |||
| reference = self._get_reference(inputs) | |||
| perturbations = self._perturb(inputs, saliency, reference) | |||
| masks = self._ablation.generate_mask(saliency, inputs.shape[1]) | |||
| perturbations = self._ablation(inputs, reference, masks) | |||
| perturbations = perturbations.reshape(-1, *perturbations.shape[2:]) | |||
| perturbations = ms.Tensor(perturbations, dtype=ms.float32) | |||
| predictions = model(perturbations).asnumpy()[:, targets] | |||
| input_tensor = op.ExpandDims()(ms.Tensor(inputs, ms.float32), 0) | |||
| predictions = predictions.reshape((inputs.shape[0], -1)) | |||
| input_tensor = ms.Tensor(inputs, ms.float32) | |||
| original_output = model(input_tensor).asnumpy()[:, targets] | |||
| auc = calc_auc(original_output - predictions) | |||
| return np.array([1 - auc]) | |||
| auc = calc_auc(original_output.squeeze() - predictions.squeeze()) | |||
| return np.array([1 - auc], np.float) | |||
| class InsertionAUC(_FaithfulnessHelper): | |||
| @@ -462,20 +284,19 @@ class InsertionAUC(_FaithfulnessHelper): | |||
| perturb_pixel_per_step: Optional[int] = None, | |||
| num_perturbations: Optional[int] = None, | |||
| **kwargs): | |||
| super(InsertionAUC, self).__init__( | |||
| perturb_percent=perturb_percent, | |||
| perturb_mode='Insertion', | |||
| perturb_method=perturb_method, | |||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||
| num_perturbations=num_perturbations, | |||
| is_accumulate=True, | |||
| **kwargs) | |||
| super().__init__(perturb_percent=perturb_percent, | |||
| perturb_mode='Insertion', | |||
| perturb_method=perturb_method, | |||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||
| num_perturbations=num_perturbations, | |||
| is_accumulate=True, | |||
| **kwargs) | |||
| def calc_faithfulness(self, | |||
| inputs: _Array, | |||
| model: _Module, | |||
| targets: _Label, | |||
| saliency: _Array) -> np.ndarray: | |||
| saliency: _Array) -> _Array: | |||
| """ | |||
| Calculate faithfulness through insertion AUC. | |||
| @@ -491,17 +312,21 @@ class InsertionAUC(_FaithfulnessHelper): | |||
| """ | |||
| reference = self._get_reference(inputs) | |||
| perturbations = self._perturb(inputs, saliency, reference) | |||
| masks = self._ablation.generate_mask(saliency, inputs.shape[1]) | |||
| perturbations = self._ablation(inputs, reference, masks) | |||
| perturbations = perturbations.reshape(-1, *perturbations.shape[2:]) | |||
| perturbations = ms.Tensor(perturbations, dtype=ms.float32) | |||
| predictions = model(perturbations).asnumpy()[:, targets] | |||
| base_tensor = op.ExpandDims()(ms.Tensor(reference, ms.float32), 0) | |||
| predictions = predictions.reshape((inputs.shape[0], -1)) | |||
| base_tensor = ms.Tensor(reference, ms.float32) | |||
| base_outputs = model(base_tensor).asnumpy()[:, targets] | |||
| auc = calc_auc(predictions - base_outputs) | |||
| return np.array([auc]) | |||
| auc = calc_auc(predictions.squeeze() - base_outputs.squeeze()) | |||
| return np.array([auc], np.float) | |||
| class Faithfulness(AttributionMetric): | |||
| class Faithfulness(LabelSensitiveMetric): | |||
| """ | |||
| Provides evaluation on faithfulness on XAI explanations. | |||
| @@ -604,10 +429,6 @@ class Faithfulness(AttributionMetric): | |||
| inputs = format_tensor_to_ndarray(inputs) | |||
| saliency = format_tensor_to_ndarray(saliency) | |||
| inputs = inputs.squeeze(axis=0) | |||
| saliency = saliency.squeeze() | |||
| if len(saliency.shape) != 2: | |||
| raise ValueError('Squeezed saliency map is expected to 2D, but receive {}.'.format(len(saliency.shape))) | |||
| model = nn.SequentialCell([explainer.model, self._activation_fn]) | |||
| faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=model, | |||
| targets=targets, saliency=saliency) | |||
| @@ -16,7 +16,7 @@ | |||
| import numpy as np | |||
| from mindspore.train._utils import check_value_type | |||
| from .metric import AttributionMetric | |||
| from .metric import LabelSensitiveMetric | |||
| from ..._operators import maximum, reshape, Tensor | |||
| from ..._utils import format_tensor_to_ndarray | |||
| @@ -37,7 +37,7 @@ def _mask_out_saliency(saliency, threshold): | |||
| return mask_out | |||
| class Localization(AttributionMetric): | |||
| class Localization(LabelSensitiveMetric): | |||
| r""" | |||
| Provides evaluation on the localization capability of XAI methods. | |||
| @@ -13,12 +13,20 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Base class for XAI metrics.""" | |||
| import copy | |||
| from typing import Callable | |||
| import numpy as np | |||
| import mindspore as ms | |||
| from mindspore import log as logger | |||
| from mindspore.train._utils import check_value_type | |||
| from ..._operators import Tensor | |||
| from ..._utils import format_tensor_to_ndarray | |||
| from ...explanation._attribution._attribution import Attribution | |||
| from ...explanation._attribution.attribution import Attribution | |||
| _Explainer = Attribution | |||
| def verify_argument(inputs, arg_name): | |||
| @@ -46,8 +54,77 @@ def verify_targets(targets, num_labels): | |||
| class AttributionMetric: | |||
| """Super class of XAI metric class used in classification scenarios.""" | |||
| def __init__(self, num_labels=None): | |||
| self._verify_params(num_labels) | |||
| def __init__(self): | |||
| self._explainer = None | |||
| evaluate: Callable | |||
| """ | |||
| This method evaluates the explainer on the given attribution and returns the evaluation results. | |||
| Derived class should implement this method according to specific algorithms of the metric. | |||
| """ | |||
| def _record_explainer(self, explainer: _Explainer): | |||
| """Record the explainer in current evaluation.""" | |||
| if self._explainer is None: | |||
| self._explainer = explainer | |||
| elif self._explainer is not explainer: | |||
| logger.info('Provided explainer is not the same as previously evaluted one. Please reset the evaluated ' | |||
| 'results. Previous explainer: %s, current explainer: %s', self._explainer, explainer) | |||
| self._explainer = explainer | |||
| class LabelAgnosticMetric(AttributionMetric): | |||
| """Super class add functions for label-agnostic metric.""" | |||
| def __init__(self): | |||
| super().__init__() | |||
| self._global_results = [] | |||
| @property | |||
| def performance(self) -> float: | |||
| """ | |||
| Return the average evaluation result. | |||
| Return: | |||
| float, averaged result. If no result is aggregate in the global_results, 0.0 will be returned. | |||
| """ | |||
| if not self._global_results: | |||
| return 0.0 | |||
| results_sum = sum(self._global_results) | |||
| count = len(self._global_results) | |||
| return results_sum / count | |||
| def aggregate(self, result): | |||
| """Aggregate single evaluation result to global results.""" | |||
| if isinstance(result, float): | |||
| self._global_results.append(result) | |||
| elif isinstance(result, (ms.Tensor, np.ndarray)): | |||
| result = format_tensor_to_ndarray(result) | |||
| self._global_results.append(float(result)) | |||
| else: | |||
| raise TypeError('result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result)) | |||
| def get_results(self): | |||
| """Return the gloabl results.""" | |||
| return self._global_results.copy() | |||
| def reset(self): | |||
| """Reset global results.""" | |||
| self._global_results.clear() | |||
| def _check_evaluate_param(self, explainer, inputs): | |||
| """Check the evaluate parameters.""" | |||
| check_value_type('explainer', explainer, Attribution) | |||
| self._record_explainer(explainer) | |||
| verify_argument(inputs, 'inputs') | |||
| class LabelSensitiveMetric(AttributionMetric): | |||
| """Super class add functions for label-sensitive metrics.""" | |||
| def __init__(self, num_labels: int): | |||
| super().__init__() | |||
| LabelSensitiveMetric._verify_params(num_labels) | |||
| self._num_labels = num_labels | |||
| self._global_results = {i: [] for i in range(num_labels)} | |||
| @@ -57,10 +134,6 @@ class AttributionMetric: | |||
| if num_labels < 1: | |||
| raise ValueError("Argument num_labels must be parsed with a integer > 0.") | |||
| def evaluate(self, explainer, inputs, targets, saliency=None): | |||
| """This function evaluates on a single sample and return the result.""" | |||
| raise NotImplementedError | |||
| def aggregate(self, result, targets): | |||
| """Aggregates single result to global_results.""" | |||
| if isinstance(result, float): | |||
| @@ -120,11 +193,12 @@ class AttributionMetric: | |||
| def get_results(self): | |||
| """Global result of the metric can be return""" | |||
| return self._global_results | |||
| return copy.deepcopy(self._global_results) | |||
| def _check_evaluate_param(self, explainer, inputs, targets, saliency): | |||
| """Check the evaluate parameters.""" | |||
| check_value_type('explainer', explainer, Attribution) | |||
| self._record_explainer(explainer) | |||
| verify_argument(inputs, 'inputs') | |||
| output = explainer.model(inputs) | |||
| check_value_type("output of explainer model", output, Tensor) | |||
| @@ -0,0 +1,134 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Robustness.""" | |||
| from typing import Optional, Union | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import log | |||
| from .metric import LabelSensitiveMetric | |||
| from ...explanation._attribution import Attribution | |||
| from ...explanation._attribution._perturbation.replacement import RandomPerturb | |||
| _Array = np.ndarray | |||
| _Label = Union[ms.Tensor, int] | |||
| class Robustness(LabelSensitiveMetric): | |||
| """ | |||
| Robustness perturbs the inputs by adding random noise and choose the maximum sensitivity as evaluation score from | |||
| the perturbations. | |||
| Args: | |||
| num_labels (int): Number of classes in the dataset. | |||
| Examples: | |||
| >>> from mindspore.explainer.benchmark import Robustness | |||
| >>> num_labels = 100 | |||
| >>> robustness = Robustness(num_labels) | |||
| """ | |||
| def __init__(self, num_labels: int, activation_fn=nn.Softmax()): | |||
| super().__init__(num_labels) | |||
| self._perturb = RandomPerturb() | |||
| self._num_perturbations = 100 # number of perturbations used in evaluation | |||
| self._threshold = 0.1 # threshold to generate perturbation | |||
| self._activation_fn = activation_fn | |||
| def evaluate(self, | |||
| explainer: Attribution, | |||
| inputs: Tensor, | |||
| targets: _Label, | |||
| saliency: Optional[Tensor] = None | |||
| ) -> _Array: | |||
| """ | |||
| Evaluate robustness on single sample. | |||
| Note: | |||
| Currently only single sample (:math:`N=1`) at each call is supported. | |||
| Args: | |||
| explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`. | |||
| inputs (Tensor): A data sample, a 4D tensor of shape :math:`(N, C, H, W)`. | |||
| targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer. | |||
| If `targets` is a 1D tensor, its length should be the same as `inputs`. | |||
| saliency (Tensor, optional): The saliency map to be evaluated, a 4D tensor of shape :math:`(N, 1, H, W)`. | |||
| If it is None, the parsed `explainer` will generate the saliency map with `inputs` and `targets` and | |||
| continue the evaluation. Default: None. | |||
| Returns: | |||
| numpy.ndarray, 1D array of shape :math:`(N,)`, result of localization evaluated on `explainer`. | |||
| Raises: | |||
| ValueError: If batch_size is larger than 1. | |||
| Examples: | |||
| >>> # init an explainer, the network should contain the output activation function. | |||
| >>> from mindspore.explainer.explanation import Gradient | |||
| >>> from mindspore.explainer.benchmark import Robustness | |||
| >>> gradient = Gradient(network) | |||
| >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> target_label = 5 | |||
| >>> robustness = Robustness(num_labels=10) | |||
| >>> res = robustness.evaluate(gradient, input_x, target_label) | |||
| """ | |||
| self._check_evaluate_param(explainer, inputs, targets, saliency) | |||
| if inputs.shape[0] > 1: | |||
| raise ValueError('Robustness only support a sample each time, but receive {}'.format(inputs.shape[0])) | |||
| inputs_np = inputs.asnumpy() | |||
| if isinstance(targets, int): | |||
| targets = ms.Tensor(targets, ms.int32) | |||
| if saliency is None: | |||
| saliency = explainer(inputs, targets) | |||
| saliency_np = saliency.asnumpy() | |||
| norm = np.sqrt(np.sum(np.square(saliency_np), axis=tuple(range(1, len(saliency_np.shape))))) | |||
| if norm == 0: | |||
| log.warning('Get saliency norm equals 0, robustness return NaN for zero-norm saliency currently.') | |||
| return np.array([np.nan]) | |||
| perturbations = [] | |||
| for sample in inputs_np: | |||
| sample = np.expand_dims(sample, axis=0) | |||
| perturbations_per_input = [] | |||
| for _ in range(self._num_perturbations): | |||
| perturbation = self._perturb(sample) | |||
| perturbations_per_input.append(perturbation) | |||
| perturbations_per_input = np.vstack(perturbations_per_input) | |||
| perturbations.append(perturbations_per_input) | |||
| perturbations = np.stack(perturbations, axis=0) | |||
| perturbations = np.reshape(perturbations, (-1,) + inputs_np.shape[1:]) | |||
| perturbations = ms.Tensor(perturbations, ms.float32) | |||
| repeated_targets = np.repeat(targets.asnumpy(), repeats=self._num_perturbations, axis=0) | |||
| repeated_targets = ms.Tensor(repeated_targets, ms.int32) | |||
| saliency_of_perturbations = explainer(perturbations, repeated_targets) | |||
| perturbations_saliency = saliency_of_perturbations.asnumpy() | |||
| repeated_saliency = np.repeat(saliency_np, repeats=self._num_perturbations, axis=0) | |||
| sensitivities = np.sum((repeated_saliency - perturbations_saliency) ** 2, | |||
| axis=tuple(range(1, len(repeated_saliency.shape)))) | |||
| max_sensitivity = np.max(sensitivities.reshape((norm.shape[0], -1)), axis=1) / norm | |||
| robustness_res = 1 / np.exp(max_sensitivity) | |||
| return robustness_res | |||
| @@ -14,9 +14,10 @@ | |||
| # ============================================================================ | |||
| """Predefined Attribution explainers.""" | |||
| from ._attribution._backprop.gradcam import GradCAM | |||
| from ._attribution._backprop.gradient import Gradient | |||
| from ._attribution._backprop.gradcam import GradCAM | |||
| from ._attribution._backprop.modified_relu import Deconvolution, GuidedBackprop | |||
| from ._attribution._perturbation.occlusion import Occlusion | |||
| from ._attribution._perturbation.rise import RISE | |||
| __all__ = [ | |||
| @@ -24,5 +25,6 @@ __all__ = [ | |||
| 'Deconvolution', | |||
| 'GuidedBackprop', | |||
| 'GradCAM', | |||
| 'RISE' | |||
| 'Occlusion', | |||
| 'RISE', | |||
| ] | |||
| @@ -13,15 +13,9 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Predefined Attribution explainers.""" | |||
| from ._backprop.gradcam import GradCAM | |||
| from ._backprop.gradient import Gradient | |||
| from ._backprop.modified_relu import Deconvolution, GuidedBackprop | |||
| from ._perturbation.rise import RISE | |||
| from .attribution import Attribution | |||
| __all__ = [ | |||
| 'Gradient', | |||
| 'Deconvolution', | |||
| 'GuidedBackprop', | |||
| 'GradCAM', | |||
| 'RISE' | |||
| 'Attribution' | |||
| ] | |||
| @@ -13,12 +13,3 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Backprop-base _attribution explainer.""" | |||
| from .gradient import Gradient | |||
| from .gradcam import GradCAM | |||
| from .modified_relu import Deconvolution, GuidedBackprop | |||
| __all__ = ['Gradient', | |||
| 'GradCAM', | |||
| 'Deconvolution', | |||
| 'GuidedBackprop'] | |||
| @@ -22,7 +22,6 @@ from .intermediate_layer import IntermediateLayerAttribution | |||
| from ...._utils import ForwardProbe, retrieve_layer, unify_inputs, unify_targets | |||
| def _gradcam_aggregation(attributions): | |||
| """ | |||
| Aggregate the gradient and activation to get the final _attribution. | |||
| @@ -76,10 +75,7 @@ class GradCAM(IntermediateLayerAttribution): | |||
| >>> gradcam = GradCAM(net, layer=layer_name) | |||
| """ | |||
| def __init__( | |||
| self, | |||
| network, | |||
| layer=""): | |||
| def __init__(self, network, layer=""): | |||
| super(GradCAM, self).__init__(network, layer) | |||
| self._saliency_cell = retrieve_layer(self._backward_model, target_layer=layer) | |||
| @@ -16,12 +16,11 @@ | |||
| from copy import deepcopy | |||
| from mindspore import nn | |||
| from mindspore.ops import operations as op | |||
| from mindspore.train._utils import check_value_type | |||
| from ...._operators import reshape, sqrt, Tensor | |||
| from .._attribution import Attribution | |||
| from ..attribution import Attribution | |||
| from .backprop_utils import compute_gradients | |||
| from ...._utils import unify_inputs, unify_targets | |||
| from ...._utils import abs_max, unify_inputs, unify_targets | |||
| def _get_hook(bntype, cache): | |||
| @@ -41,16 +40,6 @@ def _get_hook(bntype, cache): | |||
| return reset_gradient | |||
| def _abs_max(gradients): | |||
| """ | |||
| Transform gradients to saliency through abs then take max along | |||
| channels. | |||
| """ | |||
| gradients = op.Abs()(gradients) | |||
| saliency = op.ReduceMax(keep_dims=True)(gradients, axis=1) | |||
| return saliency | |||
| class Gradient(Attribution): | |||
| r""" | |||
| Provides Gradient explanation method. | |||
| @@ -85,8 +74,7 @@ class Gradient(Attribution): | |||
| self._backward_model.set_grad(False) | |||
| self._hook_bn() | |||
| self._grad_op = compute_gradients | |||
| self._aggregation_fn = _abs_max | |||
| self._aggregation_fn = abs_max | |||
| def __call__(self, inputs, targets): | |||
| """ | |||
| @@ -13,7 +13,3 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ Perturbation-based _attribution explainer. """ | |||
| from .rise import RISE | |||
| __all__ = ['RISE'] | |||
| @@ -0,0 +1,182 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Modules to ablate images.""" | |||
| __all__ = [ | |||
| 'Ablation', | |||
| 'AblationWithSaliency', | |||
| ] | |||
| import math | |||
| from functools import reduce | |||
| from typing import Optional, Union | |||
| import numpy as np | |||
| from .replacement import Constant | |||
| from ...._utils import rank_pixels | |||
| class Ablation: | |||
| """Base class to ablate image based on given replacement.""" | |||
| def __init__(self, perturb_mode: str): | |||
| self._perturb_mode = perturb_mode | |||
| def __call__(self, | |||
| inputs: np.array, | |||
| reference: Union[np.array, float], | |||
| masks: np.array | |||
| ) -> np.array: | |||
| """ | |||
| Generate perturbations of given array. | |||
| Args: | |||
| inputs (np.ndarray): Input array to perturb. The first dim of inputs is assumed to be the batch size, i.e., | |||
| number of samples. | |||
| reference (np.ndarray or float): Array of values to replace the elements in the original inputs. The shape | |||
| of reference must math the inputs. If scalar is provided, the perturbed elements will be assigned the | |||
| given value.. | |||
| masks (np.ndarray): Several boolean array to mark the perturbed positions. True marks the pixels to be | |||
| perturbed, otherwise the pixels will be kept. The shape of masks is assumed to be | |||
| [batch_size, num_perturbations, inputs_shape[1:]]. | |||
| Return: | |||
| perturbations (np.ndarray) | |||
| """ | |||
| if isinstance(reference, float): | |||
| reference = Constant(base_value=reference)(inputs) | |||
| if not np.array_equal(inputs.shape, reference.shape): | |||
| raise ValueError('reference must have the same shape as inputs.') | |||
| num_perturbations = masks.shape[1] | |||
| if self._perturb_mode == 'Insertion': | |||
| inputs, reference = reference, inputs | |||
| perturbations = np.repeat(inputs[:, None, :], num_perturbations, 1) | |||
| reference = np.repeat(reference[:, None, :], num_perturbations, 1) | |||
| Ablation._assign(perturbations, reference, masks) | |||
| return perturbations | |||
| @staticmethod | |||
| def _assign(original_array: np.ndarray, replacement: np.ndarray, masks: np.ndarray): | |||
| """Assign values to perturb pixels on perturbations.""" | |||
| if masks.dtype != bool: | |||
| raise TypeError('The param "masks" should be an array of bool, but receive {}'.format(masks.dtype)) | |||
| if not np.array_equal(original_array.shape, masks.shape): | |||
| raise ValueError('masks must have the shape {} same as [batch_size, num_perturbations, inputs.shape[1:],' | |||
| 'but receive {}.'.format(original_array.shape, masks.shape)) | |||
| original_array[masks] = replacement[masks] | |||
| class AblationWithSaliency(Ablation): | |||
| """ | |||
| Perturbation generator to generate perturbations for a given array. | |||
| Args: | |||
| perturb_percent (float): percentage of pixels to perturb | |||
| perturb_mode (str): specify perturbing mode, through deleting or | |||
| inserting pixels. Current support: ['Deletion', 'Insertion']. | |||
| is_accumulate (bool): whether to accumulate the former perturbations to | |||
| the later perturbations. | |||
| perturb_pixel_per_step (int, optional): number of pixel to perturb | |||
| for each perturbation. If perturb_pixel_per_step is None, actual | |||
| perturb_pixel_per_step will be calculate by: | |||
| num_image_pixel * perturb_percent / num_perturb_steps. | |||
| Default: None | |||
| num_perturbations (int, optional): number of perturbations. If | |||
| num_perturbations if None, it will be calculated by: | |||
| num_image_pixel * perturb_percent / perturb_pixel_per_step. | |||
| Default: None | |||
| """ | |||
| def __init__(self, | |||
| perturb_mode: str, | |||
| perturb_percent: float = 1.0, | |||
| is_accumulate: bool = False, | |||
| perturb_pixel_per_step: Optional[int] = None, | |||
| num_perturbations: Optional[int] = None): | |||
| super().__init__(perturb_mode) | |||
| self._perturb_percent = perturb_percent | |||
| self._perturb_mode = perturb_mode | |||
| self._pixel_per_step = perturb_pixel_per_step | |||
| self._num_perturbations = num_perturbations | |||
| self._is_accumulate = is_accumulate | |||
| def generate_mask(self, | |||
| saliency: np.ndarray, | |||
| num_channels: Optional[int] = None | |||
| ) -> np.ndarray: | |||
| """ | |||
| Generate mask for perturbations based on given saliency ranks. | |||
| Args: | |||
| saliency (np.ndarray): Perturbing masks will be generated based on the given saliency map. The shape of | |||
| saliency is expected to be: [batch_size, optional(num_channels), *spatial_size]. If multi-channel | |||
| saliency is provided, an averaged saliency will be taken to calculate pixel order in spatial dimension. | |||
| num_channels (optional[int]): Number of channels of the input data. In order to match the shape of inputs, | |||
| num_channels should be provided when input data have channels dimension, even if num_channel. If None is | |||
| provided, the inputs is assumed to be no-channel data, and the generated mask will have no channel | |||
| dimension. Default: None. | |||
| Return: | |||
| mask (np.ndarray): boolen mask for generate perturbations. | |||
| """ | |||
| batch_size = saliency.shape[0] | |||
| expected_num_dim = len(saliency.shape) + 1 | |||
| has_channel = num_channels is not None | |||
| num_channels = 1 if num_channels is None else num_channels | |||
| if has_channel: | |||
| saliency = saliency.mean(axis=1) | |||
| saliency_rank = rank_pixels(saliency, descending=True) | |||
| num_pixels = reduce(lambda x, y: x * y, saliency.shape[1:]) | |||
| if self._pixel_per_step: | |||
| pixel_per_step = self._pixel_per_step | |||
| num_perturbations = math.floor(num_pixels * self._perturb_percent / self._pixel_per_step) | |||
| elif self._num_perturbations: | |||
| pixel_per_step = math.floor(num_pixels * self._perturb_percent / self._num_perturbations) | |||
| num_perturbations = self._num_perturbations | |||
| else: | |||
| raise ValueError("Must provide either pixel_per_step or num_perturbations.") | |||
| masks = np.zeros((batch_size, num_perturbations, num_channels, saliency_rank.shape[1], saliency_rank.shape[2]), | |||
| dtype=np.bool) | |||
| factor = 0 if self._is_accumulate else 1 | |||
| for i in range(batch_size): | |||
| low_bound = 0 | |||
| up_bound = low_bound + pixel_per_step | |||
| for j in range(num_perturbations): | |||
| masks[i, j, :, ((saliency_rank[i] >= low_bound) & (saliency_rank[i] < up_bound))] = True | |||
| low_bound = up_bound + factor | |||
| up_bound += pixel_per_step | |||
| masks = masks if has_channel else np.squeeze(masks, axis=2) | |||
| if len(masks.shape) == expected_num_dim: | |||
| return masks | |||
| raise ValueError(f'Invalid masks shape {len(masks.shape)}, expect {expected_num_dim}-dim.') | |||
| @@ -0,0 +1,166 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Occlusion explainer.""" | |||
| import math | |||
| from typing import Tuple, Union | |||
| import numpy as np | |||
| from numpy.lib.stride_tricks import as_strided | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.nn import Cell | |||
| from .ablation import Ablation | |||
| from .perturbation import PerturbationAttribution | |||
| from .replacement import Constant | |||
| from ...._utils import abs_max | |||
| _Array = np.ndarray | |||
| _Label = Union[int, Tensor] | |||
| def _generate_patches(array, window_size, stride): | |||
| """View as windows.""" | |||
| if not isinstance(array, np.ndarray): | |||
| raise TypeError("`array` must be a numpy ndarray") | |||
| arr_shape = np.array(array.shape) | |||
| window_size = np.array(window_size, dtype=arr_shape.dtype) | |||
| slices = tuple(slice(None, None, st) for st in stride) | |||
| window_strides = np.array(array.strides) | |||
| indexing_strides = array[slices].strides | |||
| win_indices_shape = (((np.array(array.shape) - np.array(window_size)) // np.array(stride)) + 1) | |||
| new_shape = tuple(list(win_indices_shape) + list(window_size)) | |||
| strides = tuple(list(indexing_strides) + list(window_strides)) | |||
| patches = as_strided(array, shape=new_shape, strides=strides) | |||
| return patches | |||
| class Occlusion(PerturbationAttribution): | |||
| r""" | |||
| Occlusion uses a sliding window to replace the pixels with a reference value (e.g. constant value), and computes | |||
| the output difference w.r.t the original output. The output difference caused by perturbed pixels are assigned as | |||
| feature importance to those pixels. For pixels involved in multiple sliding windows, the feature importance is the | |||
| averaged differences from multiple sliding windows. | |||
| For more details, please refer to the original paper via: `<https://arxiv.org/abs/1311.2901>`_. | |||
| Args: | |||
| network (Cell): Specify the black-box model to be explained. | |||
| Inputs: | |||
| inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. | |||
| targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer. | |||
| If it is a 1D tensor, its length should be the same as `inputs`. | |||
| Outputs: | |||
| Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. | |||
| Example: | |||
| >>> from mindspore.explainer.explanation import Occlusion | |||
| >>> net = resnet50(10) | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> occlusion = Occlusion(net) | |||
| >>> x = ms.Tensor(np.random.rand([1, 3, 224, 224]), ms.float32) | |||
| >>> label = 1 | |||
| >>> saliency = occlusion(x, label) | |||
| """ | |||
| def __init__(self, network: Cell, activation_fn: Cell = nn.Softmax()): | |||
| super().__init__(network, activation_fn) | |||
| self._ablation = Ablation(perturb_mode='Deletion') | |||
| self._aggregation_fn = abs_max | |||
| self._get_replacement = Constant(base_value=0.0) | |||
| self._num_sample_per_dim = 32 # specify the number of perturbations each dimension. | |||
| self._num_per_eval = 32 # number of perturbations each evaluation step. | |||
| def __call__(self, inputs: Tensor, targets: _Label) -> Tensor: | |||
| """Call function for 'Occlusion'.""" | |||
| self._verify_data(inputs, targets) | |||
| inputs = inputs.asnumpy() | |||
| targets = targets.asnumpy() if isinstance(targets, Tensor) else np.array([targets] * inputs.shape[0], np.int) | |||
| # If spatial size of input data is smaller than self._num_sample_per_dim, window_size and strides will set to | |||
| # `(C, 3, 3)` and `(C, 1, 1)` separately. | |||
| window_size = tuple( | |||
| [inputs.shape[1]] | |||
| + [x % self._num_sample_per_dim if x > self._num_sample_per_dim else 3 for x in inputs.shape[2:]]) | |||
| strides = tuple( | |||
| [inputs.shape[1]] | |||
| + [x // self._num_sample_per_dim if x > self._num_sample_per_dim else 1 for x in inputs.shape[2:]]) | |||
| model = nn.SequentialCell([self._model, self._activation_fn]) | |||
| original_outputs = model(Tensor(inputs, ms.float32)).asnumpy()[np.arange(len(targets)), targets] | |||
| total_attribution = np.zeros_like(inputs) | |||
| weights = np.ones_like(inputs) | |||
| masks = Occlusion._generate_masks(inputs, window_size, strides) | |||
| num_perturbations = masks.shape[1] | |||
| original_outputs_repeat = np.repeat(original_outputs, repeats=num_perturbations, axis=0) | |||
| reference = self._get_replacement(inputs) | |||
| occluded_inputs = self._ablation(inputs, reference, masks) | |||
| targets_repeat = np.repeat(targets, repeats=num_perturbations, axis=0) | |||
| occluded_inputs = occluded_inputs.reshape((-1, *inputs.shape[1:])) | |||
| if occluded_inputs.shape[0] > self._num_per_eval: | |||
| cal_time = math.ceil(occluded_inputs.shape[0] / self._num_per_eval) | |||
| occluded_outputs = [] | |||
| for i in range(cal_time): | |||
| occluded_input = occluded_inputs[i*self._num_per_eval | |||
| :min((i+1) * self._num_per_eval, occluded_inputs.shape[0])] | |||
| target = targets_repeat[i*self._num_per_eval | |||
| :min((i+1) * self._num_per_eval, occluded_inputs.shape[0])] | |||
| occluded_output = model(Tensor(occluded_input)).asnumpy()[np.arange(target.shape[0]), target] | |||
| occluded_outputs.append(occluded_output) | |||
| occluded_outputs = np.concatenate(occluded_outputs) | |||
| else: | |||
| occluded_outputs = model(Tensor(occluded_inputs)).asnumpy()[np.arange(len(targets_repeat)), targets_repeat] | |||
| outputs_diff = original_outputs_repeat - occluded_outputs | |||
| outputs_diff = outputs_diff.reshape(inputs.shape[0], -1) | |||
| total_attribution += ( | |||
| outputs_diff.reshape(outputs_diff.shape + (1,) * (len(masks.shape) - 2)) * masks).sum(axis=1).clip(1e-6) | |||
| weights += masks.sum(axis=1) | |||
| attribution = self._aggregation_fn(ms.Tensor(total_attribution / weights)) | |||
| return attribution | |||
| @staticmethod | |||
| def _generate_masks(inputs: Tensor, window_size: Tuple[int, ...], strides: Tuple[int, ...]) -> _Array: | |||
| """Generate masks to perturb contiguous regions.""" | |||
| total_dim = np.prod(inputs.shape[1:]).item() | |||
| template = np.arange(total_dim).reshape(inputs.shape[1:]) | |||
| indices = _generate_patches(template, window_size, strides) | |||
| num_perturbations = indices.reshape((-1,) + window_size).shape[0] | |||
| indices = indices.reshape(num_perturbations, -1) | |||
| mask = np.zeros((num_perturbations, total_dim), dtype=np.bool) | |||
| for i in range(num_perturbations): | |||
| mask[i, indices[i]] = True | |||
| mask = mask.reshape((num_perturbations,) + inputs.shape[1:]) | |||
| masks = np.tile(mask, reps=(inputs.shape[0],) + (1,) * len(mask.shape)) | |||
| return masks | |||
| @@ -18,7 +18,7 @@ | |||
| from mindspore.train._utils import check_value_type | |||
| from mindspore.nn import Cell | |||
| from .._attribution import Attribution | |||
| from ..attribution import Attribution | |||
| from ...._operators import softmax | |||
| @@ -0,0 +1,85 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Modules to generate perturbations.""" | |||
| import numpy as np | |||
| from scipy.ndimage.filters import gaussian_filter | |||
| _Array = np.ndarray | |||
| __all__ = [ | |||
| 'BaseReplacement', | |||
| 'Constant', | |||
| 'GaussianBlur', | |||
| 'RandomPerturb', | |||
| ] | |||
| class BaseReplacement: | |||
| """ | |||
| Base class of generator for generating different replacement for perturbations. | |||
| Args: | |||
| kwargs: Optional args for generating replacement. Derived class need to | |||
| add necessary arg names and default value to '_necessary_args'. | |||
| If the argument has no default value, the value should be set to | |||
| 'EMPTY' to mark the required args. Initializing an object will | |||
| check the given kwargs w.r.t '_necessary_args'. | |||
| Raise: | |||
| ValueError: Raise when provided kwargs not contain necessary arg names with 'EMPTY' mark. | |||
| """ | |||
| _necessary_args = {} | |||
| def __init__(self, **kwargs): | |||
| self._replace_args = self._necessary_args.copy() | |||
| for key, value in self._replace_args.items(): | |||
| if key in kwargs.keys(): | |||
| self._replace_args[key] = kwargs[key] | |||
| elif key not in kwargs.keys() and value == 'EMPTY': | |||
| raise ValueError(f"Missing keyword arg {key} for {self.__class__.__name__}.") | |||
| def __call__(self, inputs): | |||
| raise NotImplementedError() | |||
| class Constant(BaseReplacement): | |||
| """Generator to provide constant-value replacement for perturbations.""" | |||
| _necessary_args = {'base_value': 'EMPTY'} | |||
| def __call__(self, inputs: _Array) -> _Array: | |||
| replacement = np.ones_like(inputs, dtype=np.float32) | |||
| replacement *= self._replace_args['base_value'] | |||
| return replacement | |||
| class GaussianBlur(BaseReplacement): | |||
| """Generator to provided gaussian blurred inputs for perturbation""" | |||
| _necessary_args = {'sigma': 0.7} | |||
| def __call__(self, inputs: _Array) -> _Array: | |||
| sigma = self._replace_args['sigma'] | |||
| replacement = gaussian_filter(inputs, sigma=sigma) | |||
| return replacement | |||
| class RandomPerturb(BaseReplacement): | |||
| """Generator to provide replacement by randomly adding noise.""" | |||
| _necessary_args = {'radius': 0.2} | |||
| def __call__(self, inputs: _Array) -> _Array: | |||
| radius = self._replace_args['radius'] | |||
| outputs = inputs + (2 * np.random.rand(*inputs.shape) - 1) * radius | |||
| return outputs | |||
| @@ -64,6 +64,9 @@ class RISE(PerturbationAttribution): | |||
| activation_fn=nn.Softmax(), | |||
| perturbation_per_eval=32): | |||
| super(RISE, self).__init__(network, activation_fn) | |||
| check_value_type('perturbation_per-eval', perturbation_per_eval, int) | |||
| if perturbation_per_eval <= 0: | |||
| raise ValueError('perturbation_per_eval should be postive integer.') | |||
| self._perturbation_per_eval = perturbation_per_eval | |||
| self._num_masks = 6000 # number of masks to be sampled | |||
| @@ -156,12 +159,11 @@ class RISE(PerturbationAttribution): | |||
| targets = self._unify_targets(inputs, targets) | |||
| attr_classes = [] | |||
| for idx, target in enumerate(targets): | |||
| dtype = inputs.dtype | |||
| attr_np_idx = attr_np[idx] | |||
| attr_idx = attr_np_idx[target] | |||
| attr_classes.append(attr_idx) | |||
| return op.Tensor(attr_classes, dtype=dtype) | |||
| return op.Tensor(attr_classes, dtype=inputs.dtype) | |||
| @staticmethod | |||
| def _verify_data(inputs, targets): | |||
| @@ -183,7 +185,7 @@ class RISE(PerturbationAttribution): | |||
| def _unify_targets(inputs, targets): | |||
| """To unify targets to be 2D numpy.ndarray.""" | |||
| if isinstance(targets, int): | |||
| return np.array([[targets] for i in inputs]).astype(np.int) | |||
| return np.array([[targets] for _ in inputs]).astype(np.int) | |||
| if isinstance(targets, Tensor): | |||
| if not targets.shape: | |||
| return np.array([[targets.asnumpy()] for _ in inputs]).astype(np.int) | |||
| @@ -16,8 +16,10 @@ | |||
| from typing import Callable | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore.train._utils import check_value_type | |||
| from mindspore.nn import Cell | |||
| class Attribution: | |||
| """ | |||
| @@ -26,15 +28,20 @@ class Attribution: | |||
| The explainers which explanation through attributing the relevance scores should inherit this class. | |||
| Args: | |||
| network (Cell): The black-box model to explain. | |||
| network (nn.Cell): The black-box model to explanation. | |||
| """ | |||
| def __init__(self, network): | |||
| check_value_type("network", network, Cell) | |||
| check_value_type("network", network, nn.Cell) | |||
| self._model = network | |||
| self._model.set_train(False) | |||
| self._model.set_grad(False) | |||
| @staticmethod | |||
| def _verify_model(model): | |||
| """Verify the input `network` for __init__ function.""" | |||
| if not isinstance(model, nn.Cell): | |||
| raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.") | |||
| __call__: Callable | |||
| """ | |||
| @@ -51,4 +58,17 @@ class Attribution: | |||
| @property | |||
| def model(self): | |||
| """Return the model.""" | |||
| return self._model | |||
| @staticmethod | |||
| def _verify_data(inputs, targets): | |||
| """Verify the validity of the parsed inputs.""" | |||
| check_value_type('inputs', inputs, ms.Tensor) | |||
| if len(inputs.shape) != 4: | |||
| raise ValueError('Argument inputs must be 4D Tensor') | |||
| check_value_type('targets', targets, (ms.Tensor, int)) | |||
| if isinstance(targets, ms.Tensor): | |||
| if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != len(inputs)): | |||
| raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, ' | |||
| 'it should have the same length as inputs.') | |||