From: @yuhanshi (tags/v1.1.0)
@@ -22,15 +22,14 @@ from ..._utils import calc_correlation
class ClassSensitivity(LabelAgnosticMetric):
    r"""
    """
    Class sensitivity metric used to evaluate attribution-based explanations.

    Reasonable attribution-based explainers are expected to generate distinct saliency maps for different labels,
    especially for labels of the highest and the lowest confidence. Class sensitivity evaluates the explainer through
    especially for labels of the highest and the lowest confidence. ClassSensitivity evaluates the explainer through
    computing the correlation between the saliency maps of the highest-confidence and lowest-confidence labels. An
    explainer with better class sensitivity will receive a lower correlation score. To make the evaluation results
    intuitive, the returned score is the negative of the correlation, normalized.
    """

    def evaluate(self, explainer, inputs):
@@ -46,12 +45,18 @@ class ClassSensitivity(LabelAgnosticMetric):
        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore.explainer.benchmark import ClassSensitivity
            >>> from mindspore.explainer.explanation import Gradient
            >>> model = resnet(10)
            >>> gradient = Gradient(model)
            >>> x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
            >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
            >>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
            >>> network = resnet50(10)
            >>> param_dict = load_checkpoint("resnet50.ckpt")
            >>> load_param_into_net(network, param_dict)
            >>> # prepare your explainer to be evaluated, e.g., Gradient.
            >>> gradient = Gradient(network)
            >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
            >>> class_sensitivity = ClassSensitivity()
            >>> res = class_sensitivity.evaluate(gradient, x)
            >>> res = class_sensitivity.evaluate(gradient, input_x)
        """
        self._check_evaluate_param(explainer, inputs)
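For reviewers of this hunk, the score described in the class docstring can be summarized by a minimal NumPy sketch. The function name `class_sensitivity_score` and the exact normalization are illustrative assumptions; the metric itself relies on `calc_correlation` from `..._utils` and may normalize differently:

```python
import numpy as np

def class_sensitivity_score(saliency_max_label, saliency_min_label):
    """Illustrative sketch: negated correlation of two saliency maps, rescaled to [0, 1]."""
    a = saliency_max_label.reshape(-1).astype(np.float64)
    b = saliency_min_label.reshape(-1).astype(np.float64)
    correlation = np.corrcoef(a, b)[0, 1]  # Pearson correlation of the flattened maps
    # Negate so that more distinct maps score higher, then map [-1, 1] onto [0, 1].
    return (1.0 - correlation) / 2.0
```

A well-behaved explainer should produce near-uncorrelated maps for the highest- and lowest-confidence labels, giving a score toward the upper end of the range.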
@@ -32,6 +32,7 @@ class Robustness(LabelSensitiveMetric):
        num_labels (int): Number of classes in the dataset.

    Examples:
        >>> # Initialize a Robustness benchmarker passing num_labels of the dataset.
        >>> from mindspore.explainer.benchmark import Robustness
        >>> num_labels = 100
        >>> robustness = Robustness(num_labels)
@@ -41,7 +42,7 @@ class Robustness(LabelSensitiveMetric):
        super().__init__(num_labels)
        self._perturb = RandomPerturb()
        self._num_perturbations = 100  # number of perturbations used in evaluation
        self._num_perturbations = 10  # number of perturbations used in evaluation
        self._threshold = 0.1  # threshold to generate perturbation
        self._activation_fn = activation_fn
@@ -68,12 +69,17 @@ class Robustness(LabelSensitiveMetric):
            ValueError: If batch_size is larger than 1.

        Examples:
            >>> # init an explainer, the network should contain the output activation function.
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore.explainer.explanation import Gradient
            >>> from mindspore.explainer.benchmark import Robustness
            >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
            >>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
            >>> network = resnet50(10)
            >>> param_dict = load_checkpoint("resnet50.ckpt")
            >>> load_param_into_net(network, param_dict)
            >>> # prepare your explainer to be evaluated, e.g., Gradient.
            >>> gradient = Gradient(network)
            >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
            >>> target_label = 5
            >>> target_label = ms.Tensor([0], ms.int32)
            >>> robustness = Robustness(num_labels=10)
            >>> res = robustness.evaluate(gradient, input_x, target_label)
        """
@@ -84,39 +90,48 @@ class Robustness(LabelSensitiveMetric):
        inputs_np = inputs.asnumpy()
        if isinstance(targets, int):
            targets = ms.Tensor(targets, ms.int32)
            targets = ms.Tensor([targets], ms.int32)
        if saliency is None:
            saliency = explainer(inputs, targets)
        saliency_np = saliency.asnumpy()

        norm = np.sqrt(np.sum(np.square(saliency_np), axis=tuple(range(1, len(saliency_np.shape)))))
        if norm == 0:
        if (norm == 0).any():
            log.warning('The saliency norm equals 0; robustness currently returns NaN for zero-norm saliency.')
            return np.array([np.nan])
        perturbations = []
        for sample in inputs_np:
            sample = np.expand_dims(sample, axis=0)
            perturbations_per_input = []
            for _ in range(self._num_perturbations):
                perturbation = self._perturb(sample)
                perturbations_per_input.append(perturbation)
            perturbations_per_input = np.vstack(perturbations_per_input)
            perturbations.append(perturbations_per_input)
        perturbations = np.stack(perturbations, axis=0)
        perturbations = np.reshape(perturbations, (-1,) + inputs_np.shape[1:])
        perturbations = ms.Tensor(perturbations, ms.float32)
        repeated_targets = np.repeat(targets.asnumpy(), repeats=self._num_perturbations, axis=0)
        repeated_targets = ms.Tensor(repeated_targets, ms.int32)
        saliency_of_perturbations = explainer(perturbations, repeated_targets)
        perturbations_saliency = saliency_of_perturbations.asnumpy()
        repeated_saliency = np.repeat(saliency_np, repeats=self._num_perturbations, axis=0)
        sensitivities = np.sum((repeated_saliency - perturbations_saliency) ** 2,
                               axis=tuple(range(1, len(repeated_saliency.shape))))
        max_sensitivity = np.max(sensitivities.reshape((norm.shape[0], -1)), axis=1) / norm
        norm[norm == 0] = np.nan

        model = nn.SequentialCell([explainer.model, self._activation_fn])
        original_outputs = model(inputs).asnumpy()

        sensitivities = []
        for _ in range(self._num_perturbations):
            perturbations = []
            for j, sample in enumerate(inputs_np):
                perturbation_on_single_sample = self._perturb_with_threshold(model,
                                                                             np.expand_dims(sample, axis=0),
                                                                             original_outputs[j])
                perturbations.append(perturbation_on_single_sample)
            perturbations = np.vstack(perturbations)
            perturbations_saliency = explainer(ms.Tensor(perturbations, ms.float32), targets).asnumpy()
            sensitivity = np.sum((perturbations_saliency - saliency_np) ** 2,
                                 axis=tuple(range(1, len(saliency_np.shape))))
            sensitivities.append(sensitivity)
        sensitivities = np.stack(sensitivities, axis=-1)
        max_sensitivity = np.max(sensitivities, axis=1) / norm
        robustness_res = 1 / np.exp(max_sensitivity)
        return robustness_res
    def _perturb_with_threshold(self, model: nn.Cell, sample: np.ndarray, original_output: np.ndarray) -> np.ndarray:
        """
        Generate a perturbation until the L2-distance between original_output and perturbation_output is lower than
        the given self._threshold, or until the number of attempts reaches max_attempt_time.
        """
        # the maximum number of attempts to get a perturbation with perturb_error lower than self._threshold
        max_attempt_time = 3
        perturbation = None
        for _ in range(max_attempt_time):
            perturbation = self._perturb(sample)
            # `model` already ends with the activation function, so its output is compared to original_output directly.
            perturbation_output = model(ms.Tensor(perturbation, ms.float32)).asnumpy()
            perturb_error = np.linalg.norm(original_output - perturbation_output)
            if perturb_error <= self._threshold:
                return perturbation
        return perturbation
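For reference, the score computed by the loop above reduces to the following NumPy sketch. The function name `robustness_from_sensitivities` and the array layout are assumptions made for illustration; the sketch mirrors the `1 / exp(max_sensitivity)` formula in this hunk rather than the library's exact API:

```python
import numpy as np

def robustness_from_sensitivities(saliency, perturbed_saliencies):
    """Illustrative sketch of the robustness score.

    saliency: (N, C, H, W) saliency maps of the original inputs.
    perturbed_saliencies: list of (N, C, H, W) saliency maps, one per perturbation.
    """
    reduce_axes = tuple(range(1, saliency.ndim))
    norm = np.sqrt(np.sum(np.square(saliency), axis=reduce_axes))  # (N,)
    norm[norm == 0] = np.nan  # zero-norm saliency yields a NaN score
    # Squared L2 distance between original and perturbed saliency, per perturbation.
    sensitivities = np.stack(
        [np.sum((s - saliency) ** 2, axis=reduce_axes) for s in perturbed_saliencies],
        axis=-1)  # (N, num_perturbations)
    max_sensitivity = np.max(sensitivities, axis=1) / norm
    return 1 / np.exp(max_sensitivity)  # higher means more robust
```

Bounding each perturbation's output shift with `_perturb_with_threshold` keeps the perturbations small in model-output space, so `max_sensitivity` measures how unstable the saliency is under near-equivalent inputs.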
@@ -14,14 +14,11 @@
# ============================================================================
"""Occlusion explainer."""
import math
import numpy as np
from numpy.lib.stride_tricks import as_strided
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from .ablation import Ablation
from .perturbation import PerturbationAttribution
from .replacement import Constant
@@ -62,8 +59,8 @@ class Occlusion(PerturbationAttribution):
        network (Cell): Specify the black-box model to be explained.

    Inputs:
        inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
        targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
        - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
        - **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer.
          If it is a 1D tensor, its length should be the same as `inputs`.

    Outputs:
@@ -72,13 +69,15 @@ class Occlusion(PerturbationAttribution):
    Example:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore.explainer.explanation import Occlusion
        >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
        >>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
        >>> network = resnet50(10)
        >>> param_dict = load_checkpoint("resnet50.ckpt")
        >>> load_param_into_net(network, param_dict)
        >>> # initialize Occlusion explainer and pass the pretrained model
        >>> occlusion = Occlusion(network)
        >>> x = Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
        >>> label = 1
        >>> saliency = occlusion(x, label)
        >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
        >>> label = ms.Tensor([1], ms.int32)
        >>> saliency = occlusion(input_x, label)
    """

    def __init__(self, network, activation_fn=nn.Softmax()):
@@ -88,62 +87,63 @@ class Occlusion(PerturbationAttribution):
        self._aggregation_fn = abs_max
        self._get_replacement = Constant(base_value=0.0)
        self._num_sample_per_dim = 32  # specify the number of perturbations along each dimension.
        self._num_per_eval = 32  # number of perturbations each evaluation step.
        self._num_per_eval = 2  # number of perturbations generated for each sample per evaluation step.

    def __call__(self, inputs, targets):
        """Call function for 'Occlusion'."""
        self._verify_data(inputs, targets)
        inputs = inputs.asnumpy()
        targets = targets.asnumpy() if isinstance(targets, Tensor) else np.array([targets] * inputs.shape[0], np.int)
        inputs_np = inputs.asnumpy()
        targets_np = targets.asnumpy() if isinstance(targets, ms.Tensor) else np.array([targets], np.int)
        # If spatial size of input data is smaller than self._num_sample_per_dim, window_size and strides will set to
        # `(C, 3, 3)` and `(C, 1, 1)` separately.
        window_size = tuple(
            [inputs.shape[1]]
            + [x % self._num_sample_per_dim if x > self._num_sample_per_dim else 3 for x in inputs.shape[2:]])
        strides = tuple(
            [inputs.shape[1]]
            + [x // self._num_sample_per_dim if x > self._num_sample_per_dim else 1 for x in inputs.shape[2:]])
        batch_size = inputs_np.shape[0]
        window_size, strides = self._get_window_size_and_strides(inputs_np)
        model = nn.SequentialCell([self._model, self._activation_fn])
        original_outputs = model(Tensor(inputs, ms.float32)).asnumpy()[np.arange(len(targets)), targets]
        original_outputs = model(ms.Tensor(inputs, ms.float32)).asnumpy()[np.arange(batch_size), targets_np]
        total_attribution = np.zeros_like(inputs)
        weights = np.ones_like(inputs)
        masks = Occlusion._generate_masks(inputs, window_size, strides)
        total_attribution = np.zeros_like(inputs_np)
        weights = np.ones_like(inputs_np)
        masks = Occlusion._generate_masks(inputs_np, window_size, strides)
        num_perturbations = masks.shape[1]
        original_outputs_repeat = np.repeat(original_outputs, repeats=num_perturbations, axis=0)
        reference = self._get_replacement(inputs)
        occluded_inputs = self._ablation(inputs, reference, masks)
        targets_repeat = np.repeat(targets, repeats=num_perturbations, axis=0)
        occluded_inputs = occluded_inputs.reshape((-1, *inputs.shape[1:]))
        if occluded_inputs.shape[0] > self._num_per_eval:
            cal_time = math.ceil(occluded_inputs.shape[0] / self._num_per_eval)
            occluded_outputs = []
            for i in range(cal_time):
                occluded_input = occluded_inputs[i*self._num_per_eval
                                                 :min((i+1) * self._num_per_eval, occluded_inputs.shape[0])]
                target = targets_repeat[i*self._num_per_eval
                                        :min((i+1) * self._num_per_eval, occluded_inputs.shape[0])]
                occluded_output = model(Tensor(occluded_input)).asnumpy()[np.arange(target.shape[0]), target]
                occluded_outputs.append(occluded_output)
            occluded_outputs = np.concatenate(occluded_outputs)
        else:
            occluded_outputs = model(Tensor(occluded_inputs)).asnumpy()[np.arange(len(targets_repeat)), targets_repeat]
        outputs_diff = original_outputs_repeat - occluded_outputs
        outputs_diff = outputs_diff.reshape(inputs.shape[0], -1)
        total_attribution += (
            outputs_diff.reshape(outputs_diff.shape + (1,) * (len(masks.shape) - 2)) * masks).sum(axis=1).clip(1e-6)
        weights += masks.sum(axis=1)
        attribution = self._aggregation_fn(Tensor(total_attribution / weights))
        reference = self._get_replacement(inputs_np)
        count = 0
        while count < num_perturbations:
            ith_masks = masks[:, count:min(count+self._num_per_eval, num_perturbations)]
            actual_num_eval = ith_masks.shape[1]
            num_samples = batch_size * actual_num_eval
            occluded_inputs = self._ablation(inputs_np, reference, ith_masks)
            occluded_inputs = occluded_inputs.reshape((-1, *inputs_np.shape[1:]))
            targets_repeat = np.repeat(targets_np, repeats=actual_num_eval, axis=0)
            occluded_outputs = model(
                ms.Tensor(occluded_inputs, ms.float32)).asnumpy()[np.arange(num_samples), targets_repeat]
            original_outputs_repeat = np.repeat(original_outputs, repeats=actual_num_eval, axis=0)
            outputs_diff = original_outputs_repeat - occluded_outputs
            total_attribution += (
                outputs_diff.reshape(ith_masks.shape[:2] + (1,) * (len(masks.shape) - 2)) * ith_masks).sum(axis=1)
            weights += ith_masks.sum(axis=1)
            count += actual_num_eval
        attribution = self._aggregation_fn(ms.Tensor(total_attribution / weights, ms.float32))
        return attribution
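To restate the loop above for reviewers: each window of the input is replaced by a constant reference value, the drop in the target-class score is recorded, and the drops are accumulated over all windows and divided by how often each pixel was occluded. A minimal single-image NumPy sketch, assuming a `predict(batch) -> (N, num_classes)` callable (the name and the plain nested loops are illustrative, not the library's API):

```python
import numpy as np

def occlusion_attribution(image, predict, target, window=7, stride=7, base_value=0.0):
    """Illustrative sketch: occlusion saliency for one (C, H, W) image."""
    channels, height, width = image.shape
    original_score = predict(image[None])[0, target]
    attribution = np.zeros_like(image, dtype=np.float64)
    weights = np.ones_like(image, dtype=np.float64)
    for top in range(0, height - window + 1, stride):
        for left in range(0, width - window + 1, stride):
            occluded = image.copy()
            occluded[:, top:top + window, left:left + window] = base_value
            score_drop = original_score - predict(occluded[None])[0, target]
            attribution[:, top:top + window, left:left + window] += score_drop
            weights[:, top:top + window, left:left + window] += 1.0
    return attribution / weights
```

The chunked `while` loop in the diff serves the same purpose as the nested loops here, but evaluates `self._num_per_eval` occluded copies per forward pass to bound memory use.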

    def _get_window_size_and_strides(self, inputs):
        """
        Return window_size and strides.

        If the spatial size of the input data is smaller than self._num_sample_per_dim, window_size and strides
        will be set to `(C, 3, 3)` and `(C, 1, 1)` respectively. Otherwise, window_size and strides will be
        generated adaptively to match self._num_sample_per_dim.
        """
        window_size = tuple(
            [inputs.shape[1]]
            + [x // self._num_sample_per_dim if x > self._num_sample_per_dim else 3 for x in inputs.shape[2:]])
        strides = tuple(
            [inputs.shape[1]]
            + [x // self._num_sample_per_dim if x > self._num_sample_per_dim else 1 for x in inputs.shape[2:]])
        return window_size, strides

    @staticmethod
    def _generate_masks(inputs, window_size, strides):
        """Generate masks to perturb contiguous regions."""
@@ -72,3 +72,6 @@ class Attribution:
            if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != len(inputs)):
                raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
                                 'it should have the same length as inputs.')
        elif inputs.shape[0] != 1:
            raise ValueError('If targets is an int, the batch_size of inputs should be 1. Received batch_size {}.'
                             .format(inputs.shape[0]))
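In other words (a restatement, not part of the diff): `targets` may be a 1D `int32` tensor whose length equals the batch size, a 0D tensor, or a plain Python int; passing an int while `inputs` holds more than one sample now raises the `ValueError` added here, whereas the removed Occlusion line above used to broadcast the int label across the whole batch.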