From 2e4b686408442ba92eeec406d87b12e80204cb74 Mon Sep 17 00:00:00 2001 From: lixiaohui Date: Thu, 3 Dec 2020 15:17:54 +0800 Subject: [PATCH] refactor explain core code --- .../explainer/_image_classification_runner.py | 17 +++- mindspore/explainer/_utils.py | 2 +- .../_attribution/class_sensitivity.py | 4 +- .../benchmark/_attribution/faithfulness.py | 22 +++-- .../benchmark/_attribution/metric.py | 4 +- .../benchmark/_attribution/robustness.py | 32 ++++--- .../_attribution/_backprop/backprop_utils.py | 2 +- .../_attribution/_backprop/gradcam.py | 36 ++++---- .../_attribution/_backprop/gradient.py | 33 +++---- .../_attribution/_backprop/modified_relu.py | 20 +++- .../_attribution/_perturbation/occlusion.py | 31 ++++--- .../_perturbation/perturbation.py | 8 +- .../_attribution/_perturbation/rise.py | 92 +++++++++---------- .../explanation/_attribution/attribution.py | 16 ++-- 14 files changed, 181 insertions(+), 138 deletions(-) diff --git a/mindspore/explainer/_image_classification_runner.py b/mindspore/explainer/_image_classification_runner.py index 9c42a546fd..8895bc1f3e 100644 --- a/mindspore/explainer/_image_classification_runner.py +++ b/mindspore/explainer/_image_classification_runner.py @@ -64,7 +64,10 @@ class ImageClassificationRunner: should provides [images], [images, labels] or [images, labels, bboxes] as columns. The label list must share the exact same length and order of the network outputs. network (Cell): The network(with logit outputs) to be explained. - activation_fn (Cell): The activation function for converting network's output to probabilities. + activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For + single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks, + `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long as + when combining this function with network, the final output is the probability of the input. Examples: >>> from mindspore.explainer import ImageClassificationRunner @@ -302,6 +305,8 @@ class ImageClassificationRunner: ds.config.set_seed(self._DATASET_SEED) for idx, next_element in enumerate(self._dataset): now = time() + self._spaced_print("Start running {}-th explanation data for {}......".format( + idx, exp.__class__.__name__), end='') self._run_exp_step(next_element, exp, sample_id_labels, summary) self._spaced_print("Finish writing {}-th explanation data for {}. Time elapsed: " "{:.3f} s".format(idx, exp.__class__.__name__, time() - now), end='') @@ -320,12 +325,17 @@ class ImageClassificationRunner: ds.config.set_seed(self._DATASET_SEED) for idx, next_element in enumerate(self._dataset): now = time() + self._spaced_print("Start running {}-th explanation data for {}......".format( + idx, exp.__class__.__name__), end='') saliency_dict_lst = self._run_exp_step(next_element, exp, sample_id_labels, summary) self._spaced_print( "Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format( idx, exp.__class__.__name__, time() - now), end='') for bench in self._benchmarkers: now = time() + self._spaced_print( + "Start running {}-th batch {} data for {}......".format( + idx, bench.__class__.__name__, exp.__class__.__name__), end='') self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst) self._spaced_print( "Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format( @@ -496,7 +506,7 @@ class ImageClassificationRunner: if explainer.__class__ in explainer_classes: raise ValueError(f"Repeated {explainer.__class__.__name__} explainer! " "Please make sure all explainers' class is distinct.") - if explainer.model != self._network: + if explainer.network is not self._network: raise ValueError(f"The network of {explainer.__class__.__name__} explainer is different " "instance from network of runner. Please make sure they are the same " "instance.") @@ -717,4 +727,5 @@ class ImageClassificationRunner: @classmethod def _spaced_print(cls, message, *args, **kwargs): """Spaced message printing.""" - print(cls._SPACER.format(message), *args, **kwargs) + # workaround to print logs starting new line in case line width mismatch. + print(cls._SPACER.format(message)) diff --git a/mindspore/explainer/_utils.py b/mindspore/explainer/_utils.py index 91bd80080e..212a5bc48a 100644 --- a/mindspore/explainer/_utils.py +++ b/mindspore/explainer/_utils.py @@ -226,7 +226,7 @@ def calc_correlation(x: Union[ms.Tensor, np.ndarray], if np.all(x == 0) or np.all(y == 0): return np.float(0) - faithfulness = -np.corrcoef(x, y)[0, 1] + faithfulness = np.corrcoef(x, y)[0, 1] return faithfulness diff --git a/mindspore/explainer/benchmark/_attribution/class_sensitivity.py b/mindspore/explainer/benchmark/_attribution/class_sensitivity.py index 80f013ebe7..bf4bf959e7 100644 --- a/mindspore/explainer/benchmark/_attribution/class_sensitivity.py +++ b/mindspore/explainer/benchmark/_attribution/class_sensitivity.py @@ -55,12 +55,12 @@ class ClassSensitivity(LabelAgnosticMetric): >>> # prepare your explainer to be evaluated, e.g., Gradient. >>> gradient = Gradient(network) >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) - >>> class_sensitivity = ClassSensitivity() + >>> # class_sensitivity is a ClassSensitivity instance >>> res = class_sensitivity.evaluate(gradient, input_x) """ self._check_evaluate_param(explainer, inputs) - outputs = explainer.model(inputs) + outputs = explainer.network(inputs) max_confidence_label = ops.argmax(outputs) min_confidence_label = ops.argmin(outputs) diff --git a/mindspore/explainer/benchmark/_attribution/faithfulness.py b/mindspore/explainer/benchmark/_attribution/faithfulness.py index bcaf02a7d3..cfbfeff077 100644 --- a/mindspore/explainer/benchmark/_attribution/faithfulness.py +++ b/mindspore/explainer/benchmark/_attribution/faithfulness.py @@ -18,9 +18,9 @@ from typing import Callable, Optional, Union import numpy as np -from mindspore import log import mindspore as ms -import mindspore.nn as nn +from mindspore import log, nn +from mindspore.train._utils import check_value_type from .metric import LabelSensitiveMetric from ..._utils import calc_auc, format_tensor_to_ndarray from ...explanation._attribution import Attribution as _Attribution @@ -358,22 +358,26 @@ class Faithfulness(LabelSensitiveMetric): Args: num_labels (int): Number of labels. + activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For + single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks, + `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long as + when combining this function with network, the final output is the probability of the input. metric (str, optional): The specifi metric to quantify faithfulness. Options: "DeletionAUC", "InsertionAUC", "NaiveFaithfulness". Default: 'NaiveFaithfulness'. - activation_fn (Cell, optional): The activation function that transforms the network output to a probability. - Default: nn.Softmax(). Examples: + >>> from mindspore import nn >>> from mindspore.explainer.benchmark import Faithfulness >>> # init a `Faithfulness` object >>> num_labels = 10 >>> metric = "InsertionAUC" - >>> faithfulness = Faithfulness(num_labels, metric) + >>> activation_fn = nn.Softmax() + >>> faithfulness = Faithfulness(num_labels, activation_fn, metric) """ _methods = [NaiveFaithfulness, DeletionAUC, InsertionAUC] - def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness", activation_fn=nn.Softmax()): + def __init__(self, num_labels, activation_fn, metric="NaiveFaithfulness"): super(Faithfulness, self).__init__(num_labels) perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument @@ -382,7 +386,9 @@ class Faithfulness(LabelSensitiveMetric): num_perturb_steps = 100 # separate the perturbation progress in to 100 steps. base_value = 0.0 # the pixel value set for the perturbed pixels + check_value_type("activation_fn", activation_fn, nn.Cell) self._activation_fn = activation_fn + self._verify_metrics(metric) for method in self._methods: if metric == method.__name__: @@ -437,8 +443,8 @@ class Faithfulness(LabelSensitiveMetric): inputs = format_tensor_to_ndarray(inputs) saliency = format_tensor_to_ndarray(saliency) - model = nn.SequentialCell([explainer.model, self._activation_fn]) - faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=model, + full_network = nn.SequentialCell([explainer.network, self._activation_fn]) + faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=full_network, targets=targets, saliency=saliency) return (1 + faithfulness) / 2 diff --git a/mindspore/explainer/benchmark/_attribution/metric.py b/mindspore/explainer/benchmark/_attribution/metric.py index 7fc7cc6b12..efa2b42335 100644 --- a/mindspore/explainer/benchmark/_attribution/metric.py +++ b/mindspore/explainer/benchmark/_attribution/metric.py @@ -204,9 +204,9 @@ class LabelSensitiveMetric(AttributionMetric): check_value_type('explainer', explainer, Attribution) self._record_explainer(explainer) verify_argument(inputs, 'inputs') - output = explainer.model(inputs) + output = explainer.network(inputs) check_value_type("output of explainer model", output, Tensor) - output_dim = explainer.model(inputs).shape[1] + output_dim = explainer.network(inputs).shape[1] if output_dim != self._num_labels: raise ValueError("The output dimension of of black-box model in explainer does not match the dimension " "of num_labels set in the __init__, please check explainer and num_labels again.") diff --git a/mindspore/explainer/benchmark/_attribution/robustness.py b/mindspore/explainer/benchmark/_attribution/robustness.py index 927cc6523b..f2f2fe39c6 100644 --- a/mindspore/explainer/benchmark/_attribution/robustness.py +++ b/mindspore/explainer/benchmark/_attribution/robustness.py @@ -18,6 +18,7 @@ import numpy as np import mindspore as ms import mindspore.nn as nn +from mindspore.train._utils import check_value_type from mindspore import log from .metric import LabelSensitiveMetric from ...explanation._attribution._perturbation.replacement import RandomPerturb @@ -30,17 +31,24 @@ class Robustness(LabelSensitiveMetric): Args: num_labels (int): Number of classes in the dataset. + activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For + single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks, + `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long as + when combining this function with network, the final output is the probability of the input. + Examples: - >>> # Initialize a Robustness benchmarker passing num_labels of the dataset. + >>> from mindspore import nn >>> from mindspore.explainer.benchmark import Robustness - >>> num_labels = 100 - >>> robustness = Robustness(num_labels) + >>> # Initialize a Robustness benchmarker passing num_labels of the dataset. + >>> num_labels = 10 + >>> activation_fn = nn.Softmax() + >>> robustness = Robustness(num_labels, activation_fn) """ - def __init__(self, num_labels, activation_fn=nn.Softmax()): + def __init__(self, num_labels, activation_fn): super().__init__(num_labels) - + check_value_type("activation_fn", activation_fn, nn.Cell) self._perturb = RandomPerturb() self._num_perturbations = 10 # number of perturbations used in evaluation self._threshold = 0.1 # threshold to generate perturbation @@ -69,6 +77,8 @@ class Robustness(LabelSensitiveMetric): ValueError: If batch_size is larger than 1. Examples: + >>> import numpy as np + >>> import mindspore as ms >>> from mindspore.explainer.explanation import Gradient >>> from mindspore.explainer.benchmark import Robustness >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net @@ -80,7 +90,7 @@ class Robustness(LabelSensitiveMetric): >>> gradient = Gradient(network) >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) >>> target_label = ms.Tensor([0], ms.int32) - >>> robustness = Robustness(num_labels=10) + >>> # robustness is a Robustness instance >>> res = robustness.evaluate(gradient, input_x, target_label) """ @@ -100,13 +110,13 @@ class Robustness(LabelSensitiveMetric): log.warning('Get saliency norm equals 0, robustness return NaN for zero-norm saliency currently.') norm[norm == 0] = np.nan - model = nn.SequentialCell([explainer.model, self._activation_fn]) - original_outputs = model(inputs).asnumpy() + full_network = nn.SequentialCell([explainer.network, self._activation_fn]) + original_outputs = full_network(inputs).asnumpy() sensitivities = [] for _ in range(self._num_perturbations): perturbations = [] for j, sample in enumerate(inputs_np): - perturbation_on_single_sample = self._perturb_with_threshold(model, + perturbation_on_single_sample = self._perturb_with_threshold(full_network, np.expand_dims(sample, axis=0), original_outputs[j]) perturbations.append(perturbation_on_single_sample) @@ -120,7 +130,7 @@ class Robustness(LabelSensitiveMetric): robustness_res = 1 / np.exp(max_sensitivity) return robustness_res - def _perturb_with_threshold(self, model: nn.Cell, sample: np.ndarray, original_output: np.ndarray) -> np.ndarray: + def _perturb_with_threshold(self, network: nn.Cell, sample: np.ndarray, original_output: np.ndarray) -> np.ndarray: """ Generate the perturbation until the L2-distance between original_output and perturbation_output is lower than the given self._threshold or until the attempt reaches the max_attempt_time. @@ -130,7 +140,7 @@ class Robustness(LabelSensitiveMetric): perturbation = None for _ in range(max_attempt_time): perturbation = self._perturb(sample) - perturbation_output = self._activation_fn(model(ms.Tensor(sample, ms.float32))).asnumpy() + perturbation_output = self._activation_fn(network(ms.Tensor(sample, ms.float32))).asnumpy() perturb_error = np.linalg.norm(original_output - perturbation_output) if perturb_error <= self._threshold: return perturbation diff --git a/mindspore/explainer/explanation/_attribution/_backprop/backprop_utils.py b/mindspore/explainer/explanation/_attribution/_backprop/backprop_utils.py index df822e9a32..78f1dc5744 100644 --- a/mindspore/explainer/explanation/_attribution/_backprop/backprop_utils.py +++ b/mindspore/explainer/explanation/_attribution/_backprop/backprop_utils.py @@ -39,7 +39,7 @@ def compute_gradients(model, inputs, targets=None, weights=None): raise ValueError('Must provide one of targets or weights') if weights is None: targets = unify_targets(targets) - output = model(*inputs).asnumpy() + output = model(*inputs) num_categories = output.shape[-1] weights = generate_one_hot(targets, num_categories) diff --git a/mindspore/explainer/explanation/_attribution/_backprop/gradcam.py b/mindspore/explainer/explanation/_attribution/_backprop/gradcam.py index fb18189c6c..f84f17b395 100644 --- a/mindspore/explainer/explanation/_attribution/_backprop/gradcam.py +++ b/mindspore/explainer/explanation/_attribution/_backprop/gradcam.py @@ -64,16 +64,30 @@ class GradCAM(IntermediateLayerAttribution): layer for better practice. If it is '', the explantion will be generated at the input layer. Default: ''. + Inputs: + - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. + - **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer. + If it is a 1D tensor, its length should be the same as `inputs`. + + Outputs: + Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. + Examples: + >>> import numpy as np + >>> import mindspore as ms >>> from mindspore.explainer.explanation import GradCAM >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net - >>> network = resnet50(10) # please refer to model_zoo + >>> # load a trained network + >>> net = resnet50(10) >>> param_dict = load_checkpoint("resnet50.ckpt") >>> load_param_into_net(net, param_dict) >>> # specify a layer name to generate explanation, usually the layer can be set as the last conv layer. >>> layer_name = 'layer4' >>> # init GradCAM with a trained network and specify the layer to obtain attribution >>> gradcam = GradCAM(net, layer=layer_name) + >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) + >>> label = 5 + >>> saliency = gradcam(inputs, label) """ def __init__(self, network, layer=""): @@ -100,25 +114,7 @@ class GradCAM(IntermediateLayerAttribution): self._intermediate_grad = grad_input def __call__(self, inputs, targets): - """ - Call function for `GradCAM`. - - Args: - inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. - targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer. - If it is a 1D tensor, its length should be the same as `inputs`. - - Returns: - Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. - - Examples: - >>> import mindspore as ms - >>> import numpy as np - >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) - >>> label = 5 - >>> # gradcam is a GradCAM object, parse data and the target label to be explained and get the attribution - >>> saliency = gradcam(inputs, label) - """ + """Call function for `GradCAM`.""" self._verify_data(inputs, targets) self._hook_cell() diff --git a/mindspore/explainer/explanation/_attribution/_backprop/gradient.py b/mindspore/explainer/explanation/_attribution/_backprop/gradient.py index 302c22348b..9ee70ba4cc 100644 --- a/mindspore/explainer/explanation/_attribution/_backprop/gradient.py +++ b/mindspore/explainer/explanation/_attribution/_backprop/gradient.py @@ -59,7 +59,17 @@ class Gradient(Attribution): Args: network (Cell): The black-box model to be explained. + Inputs: + - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. + - **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer. + If it is a 1D tensor, its length should be the same as `inputs`. + + Outputs: + Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. + Examples: + >>> import numpy as np + >>> import mindspore as ms >>> from mindspore.explainer.explanation import Gradient >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net >>> # init Gradient with a trained network @@ -67,6 +77,9 @@ class Gradient(Attribution): >>> param_dict = load_checkpoint("resnet50.ckpt") >>> load_param_into_net(net, param_dict) >>> gradient = Gradient(net) + >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) + >>> label = 5 + >>> saliency = gradient(inputs, label) """ def __init__(self, network): @@ -79,25 +92,7 @@ class Gradient(Attribution): self._aggregation_fn = abs_max def __call__(self, inputs, targets): - """ - Call function for `Gradient`. - - Args: - inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. - targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer. - If it is a 1D tensor, its length should be the same as `inputs`. - - Returns: - Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. - - Examples: - >>> import mindspore as ms - >>> import numpy as np - >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) - >>> label = 5 - >>> # gradient is a Gradient object, parse data and the target label to be explained and get the attribution - >>> saliency = gradient(inputs, label) - """ + """Call function for `Gradient`.""" self._verify_data(inputs, targets) inputs = unify_inputs(inputs) targets = unify_targets(targets) diff --git a/mindspore/explainer/explanation/_attribution/_backprop/modified_relu.py b/mindspore/explainer/explanation/_attribution/_backprop/modified_relu.py index 04df261e85..f84291b53a 100644 --- a/mindspore/explainer/explanation/_attribution/_backprop/modified_relu.py +++ b/mindspore/explainer/explanation/_attribution/_backprop/modified_relu.py @@ -96,15 +96,23 @@ class Deconvolution(ModifiedReLU): Args: network (Cell): The black-box model to be explained. + Inputs: + - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. + - **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer. + If it is a 1D tensor, its length should be the same as `inputs`. + + Outputs: + Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. + Examples: >>> import numpy as np >>> import mindspore as ms >>> from mindspore.explainer.explanation import Deconvolution >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net + >>> # init Deconvolution with a trained network. >>> net = resnet50(10) # please refer to model_zoo >>> param_dict = load_checkpoint("resnet50.ckpt") >>> load_param_into_net(net, param_dict) - >>> # init Deconvolution with a trained network. >>> deconvolution = Deconvolution(net) >>> # parse data and the target label to be explained and get the saliency map >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) @@ -134,15 +142,23 @@ class GuidedBackprop(ModifiedReLU): Args: network (Cell): The black-box model to be explained. + Inputs: + - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. + - **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer. + If it is a 1D tensor, its length should be the same as `inputs`. + + Outputs: + Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. + Examples: >>> import numpy as np >>> import mindspore as ms >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net >>> from mindspore.explainer.explanation import GuidedBackprop + >>> # init GuidedBackprop with a trained network. >>> net = resnet50(10) # please refer to model_zoo >>> param_dict = load_checkpoint("resnet50.ckpt") >>> load_param_into_net(net, param_dict) - >>> # init GuidedBackprop with a trained network. >>> gbp = GuidedBackprop(net) >>> # parse data and the target label to be explained and get the saliency map >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) diff --git a/mindspore/explainer/explanation/_attribution/_perturbation/occlusion.py b/mindspore/explainer/explanation/_attribution/_perturbation/occlusion.py index 59563dd688..acd35f3559 100644 --- a/mindspore/explainer/explanation/_attribution/_perturbation/occlusion.py +++ b/mindspore/explainer/explanation/_attribution/_perturbation/occlusion.py @@ -47,7 +47,7 @@ def _generate_patches(array, window_size, stride): class Occlusion(PerturbationAttribution): - r""" + """ Occlusion uses a sliding window to replace the pixels with a reference value (e.g. constant value), and computes the output difference w.r.t the original output. The output difference caused by perturbed pixels are assigned as feature importance to those pixels. For pixels involved in multiple sliding windows, the feature importance is the @@ -56,7 +56,14 @@ class Occlusion(PerturbationAttribution): For more details, please refer to the original paper via: ``_. Args: - network (Cell): Specify the black-box model to be explained. + network (Cell): The black-box model to be explained. + activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For + single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks, + `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long as + when combining this function with network, the final output is the probability of the input. + perturbation_per_eval (int, optional): Number of perturbations for each inference during inferring the + perturbed samples. Within the memory capacity, usually the larger this number is, the faster the + explanation is obtained. Default: 32. Inputs: - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. @@ -67,27 +74,29 @@ class Occlusion(PerturbationAttribution): Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. Example: + >>> import numpy as np + >>> import mindspore as ms >>> from mindspore.explainer.explanation import Occlusion >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net >>> # prepare your network and load the trained checkpoint file, e.g., resnet50. >>> network = resnet50(10) >>> param_dict = load_checkpoint("resnet50.ckpt") >>> load_param_into_net(network, param_dict) - >>> # initialize Occlusion explainer and pass the pretrained model - >>> occlusion = Occlusion(network) + >>> # initialize Occlusion explainer with the pretrained model and activation function + >>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities + >>> occlusion = Occlusion(network, activation_fn=activation_fn) >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) >>> label = ms.Tensor([1], ms.int32) >>> saliency = occlusion(input_x, label) """ - def __init__(self, network, activation_fn=nn.Softmax()): - super().__init__(network, activation_fn) + def __init__(self, network, activation_fn, perturbation_per_eval=32): + super().__init__(network, activation_fn, perturbation_per_eval) self._ablation = Ablation(perturb_mode='Deletion') self._aggregation_fn = abs_max self._get_replacement = Constant(base_value=0.0) self._num_sample_per_dim = 32 # specify the number of perturbations each dimension. - self._num_per_eval = 2 # number of perturbations generate for each sample per evaluation step. def __call__(self, inputs, targets): """Call function for 'Occlusion'.""" @@ -99,9 +108,9 @@ class Occlusion(PerturbationAttribution): batch_size = inputs_np.shape[0] window_size, strides = self._get_window_size_and_strides(inputs_np) - model = nn.SequentialCell([self._model, self._activation_fn]) + full_network = nn.SequentialCell([self._network, self._activation_fn]) - original_outputs = model(ms.Tensor(inputs, ms.float32)).asnumpy()[np.arange(batch_size), targets_np] + original_outputs = full_network(ms.Tensor(inputs, ms.float32)).asnumpy()[np.arange(batch_size), targets_np] total_attribution = np.zeros_like(inputs_np) weights = np.ones_like(inputs_np) @@ -111,13 +120,13 @@ class Occlusion(PerturbationAttribution): count = 0 while count < num_perturbations: - ith_masks = masks[:, count:min(count+self._num_per_eval, num_perturbations)] + ith_masks = masks[:, count:min(count+self._perturbation_per_eval, num_perturbations)] actual_num_eval = ith_masks.shape[1] num_samples = batch_size * actual_num_eval occluded_inputs = self._ablation(inputs_np, reference, ith_masks) occluded_inputs = occluded_inputs.reshape((-1, *inputs_np.shape[1:])) targets_repeat = np.repeat(targets_np, repeats=actual_num_eval, axis=0) - occluded_outputs = model( + occluded_outputs = full_network( ms.Tensor(occluded_inputs, ms.float32)).asnumpy()[np.arange(num_samples), targets_repeat] original_outputs_repeat = np.repeat(original_outputs, repeats=actual_num_eval, axis=0) outputs_diff = original_outputs_repeat - occluded_outputs diff --git a/mindspore/explainer/explanation/_attribution/_perturbation/perturbation.py b/mindspore/explainer/explanation/_attribution/_perturbation/perturbation.py index 991f526eaf..c28df5e312 100644 --- a/mindspore/explainer/explanation/_attribution/_perturbation/perturbation.py +++ b/mindspore/explainer/explanation/_attribution/_perturbation/perturbation.py @@ -19,7 +19,6 @@ from mindspore.train._utils import check_value_type from mindspore.nn import Cell from ..attribution import Attribution -from ...._operators import softmax class PerturbationAttribution(Attribution): @@ -31,8 +30,13 @@ class PerturbationAttribution(Attribution): def __init__(self, network, - activation_fn=softmax(), + activation_fn, + perturbation_per_eval, ): super(PerturbationAttribution, self).__init__(network) check_value_type("activation_fn", activation_fn, Cell) self._activation_fn = activation_fn + check_value_type('perturbation_per_eval', perturbation_per_eval, int) + if perturbation_per_eval <= 0: + raise ValueError('Argument perturbation_per_eval should be a positive integer.') + self._perturbation_per_eval = perturbation_per_eval diff --git a/mindspore/explainer/explanation/_attribution/_perturbation/rise.py b/mindspore/explainer/explanation/_attribution/_perturbation/rise.py index 3e710e29a6..89583ac17a 100644 --- a/mindspore/explainer/explanation/_attribution/_perturbation/rise.py +++ b/mindspore/explainer/explanation/_attribution/_perturbation/rise.py @@ -14,11 +14,12 @@ # ============================================================================ """RISE.""" import math +import random import numpy as np +from mindspore.ops.operations import Concat from mindspore import Tensor -from mindspore import nn from mindspore.train._utils import check_value_type from .perturbation import PerturbationAttribution @@ -36,41 +37,57 @@ class RISE(PerturbationAttribution): node of interest: .. math:: - E_{RISE}(I, f)_c = \sum_{i}f_c(I\odot M_i) M_i + attribution = \sum_{i}f_c(I\odot M_i) M_i For more details, please refer to the original paper via: `RISE `_. Args: network (Cell): The black-box model to be explained. - activation_fn (Cell, optional): The activation layer that transforms logits to prediction probabilities. For + activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks, `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long as when combining this function with network, the final output is the probability of the input. - Default: `nn.Softmax`. perturbation_per_eval (int, optional): Number of perturbations for each inference during inferring the - perturbed samples. Default: 32. + perturbed samples. Within the memory capacity, usually the larger this number is, the faster the + explanation is obtained. Default: 32. + + Inputs: + - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. + - **targets** (Tensor, int) - The labels of interest to be explained. When `targets` is an integer, + all of the inputs will generates attribution map w.r.t this integer. When `targets` is a tensor, it + should be of shape :math:`(N, l)` (l being the number of labels for each sample) or :math:`(N,)` :math:`()`. + + Outputs: + Tensor, a 4D tensor of shape :math:`(N, ?, H, W)`. Examples: + >>> import numpy as np + >>> import mindspore as ms >>> from mindspore.explainer.explanation import RISE >>> from mindspore.nn import Sigmoid >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net - >>> # init RISE with a trained network - >>> net = resnet50(10) # please refer to model_zoo + >>> # prepare your network and load the trained checkpoint file, e.g., resnet50. + >>> network = resnet50(10) >>> param_dict = load_checkpoint("resnet50.ckpt") - >>> load_param_into_net(net, param_dict) - >>> # init RISE with specified activation function - >>> rise = RISE(net, activation_fn=Sigmoid()) - """ + >>> load_param_into_net(network, param_dict) + >>> # initialize RISE explainer with the pretrained model and activation function + >>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities + >>> rise = RISE(network, activation_fn=activation_fn) + >>> # given an instance of RISE, saliency map can be generate + >>> inputs = ms.Tensor(np.random.rand(2, 3, 224, 224), ms.float32) + >>> # when `targets` is an integer + >>> targets = 5 + >>> saliency = rise(inputs, targets) + >>> # `targets` can also be a 2D tensor + >>> targets = ms.Tensor([[5], [1]]) + >>> saliency = rise(inputs, targets) +""" def __init__(self, network, - activation_fn=nn.Softmax(), + activation_fn, perturbation_per_eval=32): - super(RISE, self).__init__(network, activation_fn) - check_value_type('perturbation_per-eval', perturbation_per_eval, int) - if perturbation_per_eval <= 0: - raise ValueError('perturbation_per_eval should be postive integer.') - self._perturbation_per_eval = perturbation_per_eval + super(RISE, self).__init__(network, activation_fn, perturbation_per_eval) self._num_masks = 6000 # number of masks to be sampled self._mask_probability = 0.2 # ratio of inputs to be masked @@ -93,47 +110,26 @@ class RISE(PerturbationAttribution): self._resize_mode) # Pack operator not available for GPU, thus transfer to numpy first - upsample_np = upsample.asnumpy() masks_lst = [] - for sample in upsample_np: - shift_x = np.random.randint(0, mask_size[0] + 1) - shift_y = np.random.randint(0, mask_size[1] + 1) + for sample in upsample: + shift_x = random.randint(0, mask_size[0]) + shift_y = random.randint(0, mask_size[1]) masks_lst.append(sample[:, shift_x: shift_x + height, shift_y:shift_y + width]) - masks = op.Tensor(np.array(masks_lst), data.dtype) + + concat = Concat() + masks = concat(tuple(masks_lst)) + masks = op.reshape(masks, (batch_size, -1, height, width)) return masks def __call__(self, inputs, targets): - """ - Generates attribution maps for inputs. - - Args: - inputs (Tensor): Input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. - targets (int, Tensor): The labels of interest to be explained. When `targets` is an integer, - all of the inputs will generates attribution map w.r.t this integer. When `targets` is a tensor, it - should be of shape :math:`(N, ?)` or :math:`(N,)` :math:`()`. - - Returns: - Tensor, a 4D tensor of shape :math:`(N, ?, H, W)` or :math:`(N, 1, H, W)`. - - Examples: - >>> import mindspore as ms - >>> import numpy as np - >>> # given an instance of RISE, saliency map can be generate - >>> inputs = ms.Tensor(np.random.rand(2, 3, 224, 224), ms.float32) - >>> # when `targets` is an integer - >>> targets = 5 - >>> saliency = rise(inputs, targets) - >>> # `targets` can also be a tensor - >>> targets = ms.Tensor([[5], [1]]) - >>> saliency = rise(inputs, targets) - """ + """Generates attribution maps for inputs.""" self._verify_data(inputs, targets) height, width = inputs.shape[2], inputs.shape[3] batch_size = inputs.shape[0] if self._num_classes is None: - logits = self.model(inputs) + logits = self.network(inputs) num_classes = logits.shape[1] self._num_classes = num_classes @@ -151,7 +147,7 @@ class RISE(PerturbationAttribution): masks = self._generate_masks(data, bs) masked_input = masks * data + (1 - masks) * bg_data - weights = self._activation_fn(self.model(masked_input)) + weights = self._activation_fn(self.network(masked_input)) while len(weights.shape) > 2: weights = op.mean(weights, axis=2) weights = op.reshape(weights, diff --git a/mindspore/explainer/explanation/_attribution/attribution.py b/mindspore/explainer/explanation/_attribution/attribution.py index b72b840675..9b9725cc63 100644 --- a/mindspore/explainer/explanation/_attribution/attribution.py +++ b/mindspore/explainer/explanation/_attribution/attribution.py @@ -28,19 +28,19 @@ class Attribution: The explainers which explanation through attributing the relevance scores should inherit this class. Args: - network (nn.Cell): The black-box model to explanation. + network (nn.Cell): The black-box model to be explained. """ def __init__(self, network): check_value_type("network", network, nn.Cell) - self._model = network - self._model.set_train(False) - self._model.set_grad(False) + self._network = network + self._network.set_train(False) + self._network.set_grad(False) @staticmethod - def _verify_model(model): + def _verify_network(network): """Verify the input `network` for __init__ function.""" - if not isinstance(model, nn.Cell): + if not isinstance(network, nn.Cell): raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.") __call__: Callable @@ -57,9 +57,9 @@ class Attribution: """ @property - def model(self): + def network(self): """Return the model.""" - return self._model + return self._network @staticmethod def _verify_data(inputs, targets):