From: @lixiaohui33 Reviewed-by: @wuxuejian Signed-off-by:tags/v1.1.0
| @@ -64,7 +64,10 @@ class ImageClassificationRunner: | |||
| should provides [images], [images, labels] or [images, labels, bboxes] as columns. The label list must | |||
| share the exact same length and order of the network outputs. | |||
| network (Cell): The network(with logit outputs) to be explained. | |||
| activation_fn (Cell): The activation function for converting network's output to probabilities. | |||
| activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For | |||
| single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks, | |||
| `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long as | |||
| when combining this function with network, the final output is the probability of the input. | |||
| Examples: | |||
| >>> from mindspore.explainer import ImageClassificationRunner | |||
| @@ -302,6 +305,8 @@ class ImageClassificationRunner: | |||
| ds.config.set_seed(self._DATASET_SEED) | |||
| for idx, next_element in enumerate(self._dataset): | |||
| now = time() | |||
| self._spaced_print("Start running {}-th explanation data for {}......".format( | |||
| idx, exp.__class__.__name__), end='') | |||
| self._run_exp_step(next_element, exp, sample_id_labels, summary) | |||
| self._spaced_print("Finish writing {}-th explanation data for {}. Time elapsed: " | |||
| "{:.3f} s".format(idx, exp.__class__.__name__, time() - now), end='') | |||
| @@ -320,12 +325,17 @@ class ImageClassificationRunner: | |||
| ds.config.set_seed(self._DATASET_SEED) | |||
| for idx, next_element in enumerate(self._dataset): | |||
| now = time() | |||
| self._spaced_print("Start running {}-th explanation data for {}......".format( | |||
| idx, exp.__class__.__name__), end='') | |||
| saliency_dict_lst = self._run_exp_step(next_element, exp, sample_id_labels, summary) | |||
| self._spaced_print( | |||
| "Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format( | |||
| idx, exp.__class__.__name__, time() - now), end='') | |||
| for bench in self._benchmarkers: | |||
| now = time() | |||
| self._spaced_print( | |||
| "Start running {}-th batch {} data for {}......".format( | |||
| idx, bench.__class__.__name__, exp.__class__.__name__), end='') | |||
| self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst) | |||
| self._spaced_print( | |||
| "Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format( | |||
| @@ -496,7 +506,7 @@ class ImageClassificationRunner: | |||
| if explainer.__class__ in explainer_classes: | |||
| raise ValueError(f"Repeated {explainer.__class__.__name__} explainer! " | |||
| "Please make sure all explainers' class is distinct.") | |||
| if explainer.model != self._network: | |||
| if explainer.network is not self._network: | |||
| raise ValueError(f"The network of {explainer.__class__.__name__} explainer is different " | |||
| "instance from network of runner. Please make sure they are the same " | |||
| "instance.") | |||
| @@ -717,4 +727,5 @@ class ImageClassificationRunner: | |||
| @classmethod | |||
| def _spaced_print(cls, message, *args, **kwargs): | |||
| """Spaced message printing.""" | |||
| print(cls._SPACER.format(message), *args, **kwargs) | |||
| # workaround to print logs starting new line in case line width mismatch. | |||
| print(cls._SPACER.format(message)) | |||
| @@ -226,7 +226,7 @@ def calc_correlation(x: Union[ms.Tensor, np.ndarray], | |||
| if np.all(x == 0) or np.all(y == 0): | |||
| return np.float(0) | |||
| faithfulness = -np.corrcoef(x, y)[0, 1] | |||
| faithfulness = np.corrcoef(x, y)[0, 1] | |||
| return faithfulness | |||
| @@ -55,12 +55,12 @@ class ClassSensitivity(LabelAgnosticMetric): | |||
| >>> # prepare your explainer to be evaluated, e.g., Gradient. | |||
| >>> gradient = Gradient(network) | |||
| >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> class_sensitivity = ClassSensitivity() | |||
| >>> # class_sensitivity is a ClassSensitivity instance | |||
| >>> res = class_sensitivity.evaluate(gradient, input_x) | |||
| """ | |||
| self._check_evaluate_param(explainer, inputs) | |||
| outputs = explainer.model(inputs) | |||
| outputs = explainer.network(inputs) | |||
| max_confidence_label = ops.argmax(outputs) | |||
| min_confidence_label = ops.argmin(outputs) | |||
| @@ -18,9 +18,9 @@ from typing import Callable, Optional, Union | |||
| import numpy as np | |||
| from mindspore import log | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import log, nn | |||
| from mindspore.train._utils import check_value_type | |||
| from .metric import LabelSensitiveMetric | |||
| from ..._utils import calc_auc, format_tensor_to_ndarray | |||
| from ...explanation._attribution import Attribution as _Attribution | |||
| @@ -358,22 +358,26 @@ class Faithfulness(LabelSensitiveMetric): | |||
| Args: | |||
| num_labels (int): Number of labels. | |||
| activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For | |||
| single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks, | |||
| `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long as | |||
| when combining this function with network, the final output is the probability of the input. | |||
| metric (str, optional): The specifi metric to quantify faithfulness. | |||
| Options: "DeletionAUC", "InsertionAUC", "NaiveFaithfulness". | |||
| Default: 'NaiveFaithfulness'. | |||
| activation_fn (Cell, optional): The activation function that transforms the network output to a probability. | |||
| Default: nn.Softmax(). | |||
| Examples: | |||
| >>> from mindspore import nn | |||
| >>> from mindspore.explainer.benchmark import Faithfulness | |||
| >>> # init a `Faithfulness` object | |||
| >>> num_labels = 10 | |||
| >>> metric = "InsertionAUC" | |||
| >>> faithfulness = Faithfulness(num_labels, metric) | |||
| >>> activation_fn = nn.Softmax() | |||
| >>> faithfulness = Faithfulness(num_labels, activation_fn, metric) | |||
| """ | |||
| _methods = [NaiveFaithfulness, DeletionAUC, InsertionAUC] | |||
| def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness", activation_fn=nn.Softmax()): | |||
| def __init__(self, num_labels, activation_fn, metric="NaiveFaithfulness"): | |||
| super(Faithfulness, self).__init__(num_labels) | |||
| perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument | |||
| @@ -382,7 +386,9 @@ class Faithfulness(LabelSensitiveMetric): | |||
| num_perturb_steps = 100 # separate the perturbation progress in to 100 steps. | |||
| base_value = 0.0 # the pixel value set for the perturbed pixels | |||
| check_value_type("activation_fn", activation_fn, nn.Cell) | |||
| self._activation_fn = activation_fn | |||
| self._verify_metrics(metric) | |||
| for method in self._methods: | |||
| if metric == method.__name__: | |||
| @@ -437,8 +443,8 @@ class Faithfulness(LabelSensitiveMetric): | |||
| inputs = format_tensor_to_ndarray(inputs) | |||
| saliency = format_tensor_to_ndarray(saliency) | |||
| model = nn.SequentialCell([explainer.model, self._activation_fn]) | |||
| faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=model, | |||
| full_network = nn.SequentialCell([explainer.network, self._activation_fn]) | |||
| faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=full_network, | |||
| targets=targets, saliency=saliency) | |||
| return (1 + faithfulness) / 2 | |||
| @@ -204,9 +204,9 @@ class LabelSensitiveMetric(AttributionMetric): | |||
| check_value_type('explainer', explainer, Attribution) | |||
| self._record_explainer(explainer) | |||
| verify_argument(inputs, 'inputs') | |||
| output = explainer.model(inputs) | |||
| output = explainer.network(inputs) | |||
| check_value_type("output of explainer model", output, Tensor) | |||
| output_dim = explainer.model(inputs).shape[1] | |||
| output_dim = explainer.network(inputs).shape[1] | |||
| if output_dim != self._num_labels: | |||
| raise ValueError("The output dimension of of black-box model in explainer does not match the dimension " | |||
| "of num_labels set in the __init__, please check explainer and num_labels again.") | |||
| @@ -18,6 +18,7 @@ import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore.train._utils import check_value_type | |||
| from mindspore import log | |||
| from .metric import LabelSensitiveMetric | |||
| from ...explanation._attribution._perturbation.replacement import RandomPerturb | |||
| @@ -30,17 +31,24 @@ class Robustness(LabelSensitiveMetric): | |||
| Args: | |||
| num_labels (int): Number of classes in the dataset. | |||
| activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For | |||
| single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks, | |||
| `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long as | |||
| when combining this function with network, the final output is the probability of the input. | |||
| Examples: | |||
| >>> # Initialize a Robustness benchmarker passing num_labels of the dataset. | |||
| >>> from mindspore import nn | |||
| >>> from mindspore.explainer.benchmark import Robustness | |||
| >>> num_labels = 100 | |||
| >>> robustness = Robustness(num_labels) | |||
| >>> # Initialize a Robustness benchmarker passing num_labels of the dataset. | |||
| >>> num_labels = 10 | |||
| >>> activation_fn = nn.Softmax() | |||
| >>> robustness = Robustness(num_labels, activation_fn) | |||
| """ | |||
| def __init__(self, num_labels, activation_fn=nn.Softmax()): | |||
| def __init__(self, num_labels, activation_fn): | |||
| super().__init__(num_labels) | |||
| check_value_type("activation_fn", activation_fn, nn.Cell) | |||
| self._perturb = RandomPerturb() | |||
| self._num_perturbations = 10 # number of perturbations used in evaluation | |||
| self._threshold = 0.1 # threshold to generate perturbation | |||
| @@ -69,6 +77,8 @@ class Robustness(LabelSensitiveMetric): | |||
| ValueError: If batch_size is larger than 1. | |||
| Examples: | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import Gradient | |||
| >>> from mindspore.explainer.benchmark import Robustness | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| @@ -80,7 +90,7 @@ class Robustness(LabelSensitiveMetric): | |||
| >>> gradient = Gradient(network) | |||
| >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> target_label = ms.Tensor([0], ms.int32) | |||
| >>> robustness = Robustness(num_labels=10) | |||
| >>> # robustness is a Robustness instance | |||
| >>> res = robustness.evaluate(gradient, input_x, target_label) | |||
| """ | |||
| @@ -100,13 +110,13 @@ class Robustness(LabelSensitiveMetric): | |||
| log.warning('Get saliency norm equals 0, robustness return NaN for zero-norm saliency currently.') | |||
| norm[norm == 0] = np.nan | |||
| model = nn.SequentialCell([explainer.model, self._activation_fn]) | |||
| original_outputs = model(inputs).asnumpy() | |||
| full_network = nn.SequentialCell([explainer.network, self._activation_fn]) | |||
| original_outputs = full_network(inputs).asnumpy() | |||
| sensitivities = [] | |||
| for _ in range(self._num_perturbations): | |||
| perturbations = [] | |||
| for j, sample in enumerate(inputs_np): | |||
| perturbation_on_single_sample = self._perturb_with_threshold(model, | |||
| perturbation_on_single_sample = self._perturb_with_threshold(full_network, | |||
| np.expand_dims(sample, axis=0), | |||
| original_outputs[j]) | |||
| perturbations.append(perturbation_on_single_sample) | |||
| @@ -120,7 +130,7 @@ class Robustness(LabelSensitiveMetric): | |||
| robustness_res = 1 / np.exp(max_sensitivity) | |||
| return robustness_res | |||
| def _perturb_with_threshold(self, model: nn.Cell, sample: np.ndarray, original_output: np.ndarray) -> np.ndarray: | |||
| def _perturb_with_threshold(self, network: nn.Cell, sample: np.ndarray, original_output: np.ndarray) -> np.ndarray: | |||
| """ | |||
| Generate the perturbation until the L2-distance between original_output and perturbation_output is lower than | |||
| the given self._threshold or until the attempt reaches the max_attempt_time. | |||
| @@ -130,7 +140,7 @@ class Robustness(LabelSensitiveMetric): | |||
| perturbation = None | |||
| for _ in range(max_attempt_time): | |||
| perturbation = self._perturb(sample) | |||
| perturbation_output = self._activation_fn(model(ms.Tensor(sample, ms.float32))).asnumpy() | |||
| perturbation_output = self._activation_fn(network(ms.Tensor(sample, ms.float32))).asnumpy() | |||
| perturb_error = np.linalg.norm(original_output - perturbation_output) | |||
| if perturb_error <= self._threshold: | |||
| return perturbation | |||
| @@ -39,7 +39,7 @@ def compute_gradients(model, inputs, targets=None, weights=None): | |||
| raise ValueError('Must provide one of targets or weights') | |||
| if weights is None: | |||
| targets = unify_targets(targets) | |||
| output = model(*inputs).asnumpy() | |||
| output = model(*inputs) | |||
| num_categories = output.shape[-1] | |||
| weights = generate_one_hot(targets, num_categories) | |||
| @@ -64,16 +64,30 @@ class GradCAM(IntermediateLayerAttribution): | |||
| layer for better practice. If it is '', the explantion will be generated at the input layer. | |||
| Default: ''. | |||
| Inputs: | |||
| - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. | |||
| - **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer. | |||
| If it is a 1D tensor, its length should be the same as `inputs`. | |||
| Outputs: | |||
| Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. | |||
| Examples: | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import GradCAM | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> network = resnet50(10) # please refer to model_zoo | |||
| >>> # load a trained network | |||
| >>> net = resnet50(10) | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> # specify a layer name to generate explanation, usually the layer can be set as the last conv layer. | |||
| >>> layer_name = 'layer4' | |||
| >>> # init GradCAM with a trained network and specify the layer to obtain attribution | |||
| >>> gradcam = GradCAM(net, layer=layer_name) | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> label = 5 | |||
| >>> saliency = gradcam(inputs, label) | |||
| """ | |||
| def __init__(self, network, layer=""): | |||
| @@ -100,25 +114,7 @@ class GradCAM(IntermediateLayerAttribution): | |||
| self._intermediate_grad = grad_input | |||
| def __call__(self, inputs, targets): | |||
| """ | |||
| Call function for `GradCAM`. | |||
| Args: | |||
| inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. | |||
| targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer. | |||
| If it is a 1D tensor, its length should be the same as `inputs`. | |||
| Returns: | |||
| Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. | |||
| Examples: | |||
| >>> import mindspore as ms | |||
| >>> import numpy as np | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> label = 5 | |||
| >>> # gradcam is a GradCAM object, parse data and the target label to be explained and get the attribution | |||
| >>> saliency = gradcam(inputs, label) | |||
| """ | |||
| """Call function for `GradCAM`.""" | |||
| self._verify_data(inputs, targets) | |||
| self._hook_cell() | |||
| @@ -59,7 +59,17 @@ class Gradient(Attribution): | |||
| Args: | |||
| network (Cell): The black-box model to be explained. | |||
| Inputs: | |||
| - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. | |||
| - **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer. | |||
| If it is a 1D tensor, its length should be the same as `inputs`. | |||
| Outputs: | |||
| Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. | |||
| Examples: | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import Gradient | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # init Gradient with a trained network | |||
| @@ -67,6 +77,9 @@ class Gradient(Attribution): | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> gradient = Gradient(net) | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> label = 5 | |||
| >>> saliency = gradient(inputs, label) | |||
| """ | |||
| def __init__(self, network): | |||
| @@ -79,25 +92,7 @@ class Gradient(Attribution): | |||
| self._aggregation_fn = abs_max | |||
| def __call__(self, inputs, targets): | |||
| """ | |||
| Call function for `Gradient`. | |||
| Args: | |||
| inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. | |||
| targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer. | |||
| If it is a 1D tensor, its length should be the same as `inputs`. | |||
| Returns: | |||
| Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. | |||
| Examples: | |||
| >>> import mindspore as ms | |||
| >>> import numpy as np | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> label = 5 | |||
| >>> # gradient is a Gradient object, parse data and the target label to be explained and get the attribution | |||
| >>> saliency = gradient(inputs, label) | |||
| """ | |||
| """Call function for `Gradient`.""" | |||
| self._verify_data(inputs, targets) | |||
| inputs = unify_inputs(inputs) | |||
| targets = unify_targets(targets) | |||
| @@ -96,15 +96,23 @@ class Deconvolution(ModifiedReLU): | |||
| Args: | |||
| network (Cell): The black-box model to be explained. | |||
| Inputs: | |||
| - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. | |||
| - **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer. | |||
| If it is a 1D tensor, its length should be the same as `inputs`. | |||
| Outputs: | |||
| Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. | |||
| Examples: | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import Deconvolution | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # init Deconvolution with a trained network. | |||
| >>> net = resnet50(10) # please refer to model_zoo | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> # init Deconvolution with a trained network. | |||
| >>> deconvolution = Deconvolution(net) | |||
| >>> # parse data and the target label to be explained and get the saliency map | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| @@ -134,15 +142,23 @@ class GuidedBackprop(ModifiedReLU): | |||
| Args: | |||
| network (Cell): The black-box model to be explained. | |||
| Inputs: | |||
| - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. | |||
| - **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer. | |||
| If it is a 1D tensor, its length should be the same as `inputs`. | |||
| Outputs: | |||
| Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. | |||
| Examples: | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> from mindspore.explainer.explanation import GuidedBackprop | |||
| >>> # init GuidedBackprop with a trained network. | |||
| >>> net = resnet50(10) # please refer to model_zoo | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> # init GuidedBackprop with a trained network. | |||
| >>> gbp = GuidedBackprop(net) | |||
| >>> # parse data and the target label to be explained and get the saliency map | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| @@ -47,7 +47,7 @@ def _generate_patches(array, window_size, stride): | |||
| class Occlusion(PerturbationAttribution): | |||
| r""" | |||
| """ | |||
| Occlusion uses a sliding window to replace the pixels with a reference value (e.g. constant value), and computes | |||
| the output difference w.r.t the original output. The output difference caused by perturbed pixels are assigned as | |||
| feature importance to those pixels. For pixels involved in multiple sliding windows, the feature importance is the | |||
| @@ -56,7 +56,14 @@ class Occlusion(PerturbationAttribution): | |||
| For more details, please refer to the original paper via: `<https://arxiv.org/abs/1311.2901>`_. | |||
| Args: | |||
| network (Cell): Specify the black-box model to be explained. | |||
| network (Cell): The black-box model to be explained. | |||
| activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For | |||
| single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks, | |||
| `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long as | |||
| when combining this function with network, the final output is the probability of the input. | |||
| perturbation_per_eval (int, optional): Number of perturbations for each inference during inferring the | |||
| perturbed samples. Within the memory capacity, usually the larger this number is, the faster the | |||
| explanation is obtained. Default: 32. | |||
| Inputs: | |||
| - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. | |||
| @@ -67,27 +74,29 @@ class Occlusion(PerturbationAttribution): | |||
| Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`. | |||
| Example: | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import Occlusion | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # prepare your network and load the trained checkpoint file, e.g., resnet50. | |||
| >>> network = resnet50(10) | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(network, param_dict) | |||
| >>> # initialize Occlusion explainer and pass the pretrained model | |||
| >>> occlusion = Occlusion(network) | |||
| >>> # initialize Occlusion explainer with the pretrained model and activation function | |||
| >>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities | |||
| >>> occlusion = Occlusion(network, activation_fn=activation_fn) | |||
| >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> label = ms.Tensor([1], ms.int32) | |||
| >>> saliency = occlusion(input_x, label) | |||
| """ | |||
| def __init__(self, network, activation_fn=nn.Softmax()): | |||
| super().__init__(network, activation_fn) | |||
| def __init__(self, network, activation_fn, perturbation_per_eval=32): | |||
| super().__init__(network, activation_fn, perturbation_per_eval) | |||
| self._ablation = Ablation(perturb_mode='Deletion') | |||
| self._aggregation_fn = abs_max | |||
| self._get_replacement = Constant(base_value=0.0) | |||
| self._num_sample_per_dim = 32 # specify the number of perturbations each dimension. | |||
| self._num_per_eval = 2 # number of perturbations generate for each sample per evaluation step. | |||
| def __call__(self, inputs, targets): | |||
| """Call function for 'Occlusion'.""" | |||
| @@ -99,9 +108,9 @@ class Occlusion(PerturbationAttribution): | |||
| batch_size = inputs_np.shape[0] | |||
| window_size, strides = self._get_window_size_and_strides(inputs_np) | |||
| model = nn.SequentialCell([self._model, self._activation_fn]) | |||
| full_network = nn.SequentialCell([self._network, self._activation_fn]) | |||
| original_outputs = model(ms.Tensor(inputs, ms.float32)).asnumpy()[np.arange(batch_size), targets_np] | |||
| original_outputs = full_network(ms.Tensor(inputs, ms.float32)).asnumpy()[np.arange(batch_size), targets_np] | |||
| total_attribution = np.zeros_like(inputs_np) | |||
| weights = np.ones_like(inputs_np) | |||
| @@ -111,13 +120,13 @@ class Occlusion(PerturbationAttribution): | |||
| count = 0 | |||
| while count < num_perturbations: | |||
| ith_masks = masks[:, count:min(count+self._num_per_eval, num_perturbations)] | |||
| ith_masks = masks[:, count:min(count+self._perturbation_per_eval, num_perturbations)] | |||
| actual_num_eval = ith_masks.shape[1] | |||
| num_samples = batch_size * actual_num_eval | |||
| occluded_inputs = self._ablation(inputs_np, reference, ith_masks) | |||
| occluded_inputs = occluded_inputs.reshape((-1, *inputs_np.shape[1:])) | |||
| targets_repeat = np.repeat(targets_np, repeats=actual_num_eval, axis=0) | |||
| occluded_outputs = model( | |||
| occluded_outputs = full_network( | |||
| ms.Tensor(occluded_inputs, ms.float32)).asnumpy()[np.arange(num_samples), targets_repeat] | |||
| original_outputs_repeat = np.repeat(original_outputs, repeats=actual_num_eval, axis=0) | |||
| outputs_diff = original_outputs_repeat - occluded_outputs | |||
| @@ -19,7 +19,6 @@ from mindspore.train._utils import check_value_type | |||
| from mindspore.nn import Cell | |||
| from ..attribution import Attribution | |||
| from ...._operators import softmax | |||
| class PerturbationAttribution(Attribution): | |||
| @@ -31,8 +30,13 @@ class PerturbationAttribution(Attribution): | |||
| def __init__(self, | |||
| network, | |||
| activation_fn=softmax(), | |||
| activation_fn, | |||
| perturbation_per_eval, | |||
| ): | |||
| super(PerturbationAttribution, self).__init__(network) | |||
| check_value_type("activation_fn", activation_fn, Cell) | |||
| self._activation_fn = activation_fn | |||
| check_value_type('perturbation_per_eval', perturbation_per_eval, int) | |||
| if perturbation_per_eval <= 0: | |||
| raise ValueError('Argument perturbation_per_eval should be a positive integer.') | |||
| self._perturbation_per_eval = perturbation_per_eval | |||
| @@ -14,11 +14,12 @@ | |||
| # ============================================================================ | |||
| """RISE.""" | |||
| import math | |||
| import random | |||
| import numpy as np | |||
| from mindspore.ops.operations import Concat | |||
| from mindspore import Tensor | |||
| from mindspore import nn | |||
| from mindspore.train._utils import check_value_type | |||
| from .perturbation import PerturbationAttribution | |||
| @@ -36,41 +37,57 @@ class RISE(PerturbationAttribution): | |||
| node of interest: | |||
| .. math:: | |||
| E_{RISE}(I, f)_c = \sum_{i}f_c(I\odot M_i) M_i | |||
| attribution = \sum_{i}f_c(I\odot M_i) M_i | |||
| For more details, please refer to the original paper via: `RISE <https://arxiv.org/abs/1806.07421>`_. | |||
| Args: | |||
| network (Cell): The black-box model to be explained. | |||
| activation_fn (Cell, optional): The activation layer that transforms logits to prediction probabilities. For | |||
| activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For | |||
| single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks, | |||
| `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long as | |||
| when combining this function with network, the final output is the probability of the input. | |||
| Default: `nn.Softmax`. | |||
| perturbation_per_eval (int, optional): Number of perturbations for each inference during inferring the | |||
| perturbed samples. Default: 32. | |||
| perturbed samples. Within the memory capacity, usually the larger this number is, the faster the | |||
| explanation is obtained. Default: 32. | |||
| Inputs: | |||
| - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. | |||
| - **targets** (Tensor, int) - The labels of interest to be explained. When `targets` is an integer, | |||
| all of the inputs will generates attribution map w.r.t this integer. When `targets` is a tensor, it | |||
| should be of shape :math:`(N, l)` (l being the number of labels for each sample) or :math:`(N,)` :math:`()`. | |||
| Outputs: | |||
| Tensor, a 4D tensor of shape :math:`(N, ?, H, W)`. | |||
| Examples: | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import RISE | |||
| >>> from mindspore.nn import Sigmoid | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # init RISE with a trained network | |||
| >>> net = resnet50(10) # please refer to model_zoo | |||
| >>> # prepare your network and load the trained checkpoint file, e.g., resnet50. | |||
| >>> network = resnet50(10) | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> # init RISE with specified activation function | |||
| >>> rise = RISE(net, activation_fn=Sigmoid()) | |||
| """ | |||
| >>> load_param_into_net(network, param_dict) | |||
| >>> # initialize RISE explainer with the pretrained model and activation function | |||
| >>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities | |||
| >>> rise = RISE(network, activation_fn=activation_fn) | |||
| >>> # given an instance of RISE, saliency map can be generate | |||
| >>> inputs = ms.Tensor(np.random.rand(2, 3, 224, 224), ms.float32) | |||
| >>> # when `targets` is an integer | |||
| >>> targets = 5 | |||
| >>> saliency = rise(inputs, targets) | |||
| >>> # `targets` can also be a 2D tensor | |||
| >>> targets = ms.Tensor([[5], [1]]) | |||
| >>> saliency = rise(inputs, targets) | |||
| """ | |||
| def __init__(self, | |||
| network, | |||
| activation_fn=nn.Softmax(), | |||
| activation_fn, | |||
| perturbation_per_eval=32): | |||
| super(RISE, self).__init__(network, activation_fn) | |||
| check_value_type('perturbation_per-eval', perturbation_per_eval, int) | |||
| if perturbation_per_eval <= 0: | |||
| raise ValueError('perturbation_per_eval should be postive integer.') | |||
| self._perturbation_per_eval = perturbation_per_eval | |||
| super(RISE, self).__init__(network, activation_fn, perturbation_per_eval) | |||
| self._num_masks = 6000 # number of masks to be sampled | |||
| self._mask_probability = 0.2 # ratio of inputs to be masked | |||
| @@ -93,47 +110,26 @@ class RISE(PerturbationAttribution): | |||
| self._resize_mode) | |||
| # Pack operator not available for GPU, thus transfer to numpy first | |||
| upsample_np = upsample.asnumpy() | |||
| masks_lst = [] | |||
| for sample in upsample_np: | |||
| shift_x = np.random.randint(0, mask_size[0] + 1) | |||
| shift_y = np.random.randint(0, mask_size[1] + 1) | |||
| for sample in upsample: | |||
| shift_x = random.randint(0, mask_size[0]) | |||
| shift_y = random.randint(0, mask_size[1]) | |||
| masks_lst.append(sample[:, shift_x: shift_x + height, shift_y:shift_y + width]) | |||
| masks = op.Tensor(np.array(masks_lst), data.dtype) | |||
| concat = Concat() | |||
| masks = concat(tuple(masks_lst)) | |||
| masks = op.reshape(masks, (batch_size, -1, height, width)) | |||
| return masks | |||
| def __call__(self, inputs, targets): | |||
| """ | |||
| Generates attribution maps for inputs. | |||
| Args: | |||
| inputs (Tensor): Input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. | |||
| targets (int, Tensor): The labels of interest to be explained. When `targets` is an integer, | |||
| all of the inputs will generates attribution map w.r.t this integer. When `targets` is a tensor, it | |||
| should be of shape :math:`(N, ?)` or :math:`(N,)` :math:`()`. | |||
| Returns: | |||
| Tensor, a 4D tensor of shape :math:`(N, ?, H, W)` or :math:`(N, 1, H, W)`. | |||
| Examples: | |||
| >>> import mindspore as ms | |||
| >>> import numpy as np | |||
| >>> # given an instance of RISE, saliency map can be generate | |||
| >>> inputs = ms.Tensor(np.random.rand(2, 3, 224, 224), ms.float32) | |||
| >>> # when `targets` is an integer | |||
| >>> targets = 5 | |||
| >>> saliency = rise(inputs, targets) | |||
| >>> # `targets` can also be a tensor | |||
| >>> targets = ms.Tensor([[5], [1]]) | |||
| >>> saliency = rise(inputs, targets) | |||
| """ | |||
| """Generates attribution maps for inputs.""" | |||
| self._verify_data(inputs, targets) | |||
| height, width = inputs.shape[2], inputs.shape[3] | |||
| batch_size = inputs.shape[0] | |||
| if self._num_classes is None: | |||
| logits = self.model(inputs) | |||
| logits = self.network(inputs) | |||
| num_classes = logits.shape[1] | |||
| self._num_classes = num_classes | |||
| @@ -151,7 +147,7 @@ class RISE(PerturbationAttribution): | |||
| masks = self._generate_masks(data, bs) | |||
| masked_input = masks * data + (1 - masks) * bg_data | |||
| weights = self._activation_fn(self.model(masked_input)) | |||
| weights = self._activation_fn(self.network(masked_input)) | |||
| while len(weights.shape) > 2: | |||
| weights = op.mean(weights, axis=2) | |||
| weights = op.reshape(weights, | |||
| @@ -28,19 +28,19 @@ class Attribution: | |||
| The explainers which explanation through attributing the relevance scores should inherit this class. | |||
| Args: | |||
| network (nn.Cell): The black-box model to explanation. | |||
| network (nn.Cell): The black-box model to be explained. | |||
| """ | |||
| def __init__(self, network): | |||
| check_value_type("network", network, nn.Cell) | |||
| self._model = network | |||
| self._model.set_train(False) | |||
| self._model.set_grad(False) | |||
| self._network = network | |||
| self._network.set_train(False) | |||
| self._network.set_grad(False) | |||
| @staticmethod | |||
| def _verify_model(model): | |||
| def _verify_network(network): | |||
| """Verify the input `network` for __init__ function.""" | |||
| if not isinstance(model, nn.Cell): | |||
| if not isinstance(network, nn.Cell): | |||
| raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.") | |||
| __call__: Callable | |||
| @@ -57,9 +57,9 @@ class Attribution: | |||
| """ | |||
| @property | |||
| def model(self): | |||
| def network(self): | |||
| """Return the model.""" | |||
| return self._model | |||
| return self._network | |||
| @staticmethod | |||
| def _verify_data(inputs, targets): | |||