From: @lixiaohui33 Reviewed-by: Signed-off-by:pull/13509/MERGE
| @@ -83,21 +83,30 @@ class ImageClassificationRunner: | |||
| >>> from mindspore.explainer.benchmark import Faithfulness | |||
| >>> from mindspore.nn import Softmax | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10 | |||
| >>> dataset = get_dataset('/path/to/Cifar10_dataset') | |||
| >>> | |||
| >>> # The detail of AlexNet is shown in model_zoo.official.cv.alexnet.src.alexnet.py | |||
| >>> net = AlexNet(10) | |||
| >>> # Load the checkpoint | |||
| >>> param_dict = load_checkpoint("/path/to/checkpoint") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> | |||
| >>> # Prepare the dataset for explaining and evaluation. | |||
| >>> # The detail of create_dataset_cifar10 method is shown in model_zoo.official.cv.alexnet.src.dataset.py | |||
| >>> | |||
| >>> dataset = create_dataset_cifar10("/path/to/cifar/dataset", 1) | |||
| >>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] | |||
| >>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10 | |||
| >>> param_dict = load_checkpoint("checkpoint.ckpt") | |||
| >>> net = resnet50(len(labels)) | |||
| >>> | |||
| >>> activation_fn = Softmax() | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> gbp = GuidedBackprop(net) | |||
| >>> gradient = Gradient(net) | |||
| >>> explainers = [gbp, gradient] | |||
| >>> faithfulness = Faithfulness(len(labels), activation_fn, "NaiveFaithfulness") | |||
| >>> benchmarkers = [faithfulness] | |||
| >>> | |||
| >>> runner = ImageClassificationRunner("./summary_dir", (dataset, labels), net, activation_fn) | |||
| >>> runner.register_saliency(explainers=explainers, benchmarkers=benchmarkers) | |||
| >>> runner.register_uncertainty() | |||
| >>> runner.register_hierarchical_occlusion() | |||
| >>> runner.run() | |||
| """ | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -250,9 +250,9 @@ def summation(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tens | |||
| def stack(inputs: List[Tensor], axis: int) -> Tensor: | |||
| """Packs a list of tensors in specified axis.""" | |||
| pack_op = op.Pack(axis) | |||
| outputs = pack_op(inputs) | |||
| """Stacks a list of tensors in specified axis.""" | |||
| stack_op = op.Stack(axis) | |||
| outputs = stack_op(inputs) | |||
| return outputs | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -104,16 +104,14 @@ def retrieve_layer_by_name(model: _Module, layer_name: str): | |||
| Retrieve the layer in the model by the given layer_name. | |||
| Args: | |||
| model (_Module): model which contains the target layer | |||
| layer_name (str): name of target layer | |||
| model (Cell): Model which contains the target layer. | |||
| layer_name (str): Name of target layer. | |||
| Return: | |||
| - target_layer (_Module) | |||
| Raise: | |||
| ValueError: if module with given layer_name is not found in the model, | |||
| raise ValueError. | |||
| Returns: | |||
| Cell, the target layer. | |||
| Raises: | |||
| ValueError: If module with given layer_name is not found in the model. | |||
| """ | |||
| if not isinstance(layer_name, str): | |||
| raise TypeError('layer_name should be type of str, but receive {}.' | |||
| @@ -146,13 +144,14 @@ def retrieve_layer(model: _Module, target_layer: Union[str, _Module] = ''): | |||
| be raised. | |||
| Args: | |||
| model (_Module): the model to retrieve the target layer | |||
| target_layer (Union[str, _Module]): target layer to retrieve. Can be | |||
| either string (layer name) or the Cell object. If '' is provided, | |||
| the input model will be returned. | |||
| model (Cell): Model which contains the target layer. | |||
| target_layer (str, Cell): Name of target layer or the target layer instance. | |||
| Returns: | |||
| Cell, the target layer. | |||
| Return: | |||
| target layer (_Module) | |||
| Raises: | |||
| ValueError: If module with given layer_name is not found in the model. | |||
| """ | |||
| if isinstance(target_layer, str): | |||
| target_layer = retrieve_layer_by_name(model, target_layer) | |||
| @@ -174,9 +173,7 @@ class ForwardProbe: | |||
| Probe to capture output of specific layer in a given model. | |||
| Args: | |||
| target_layer (_Module): name of target layer or just provide the | |||
| target layer. | |||
| target_layer (str, Cell): Name of target layer or the target layer instance. | |||
| """ | |||
| def __init__(self, target_layer: _Module): | |||
| @@ -204,7 +201,7 @@ class ForwardProbe: | |||
| def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray: | |||
| """Unify `mindspore.Tensor` and `np.ndarray` to `np.ndarray`. """ | |||
| """Unify Tensor and numpy.array to numpy.array.""" | |||
| if isinstance(x, ms.Tensor): | |||
| x = x.asnumpy() | |||
| @@ -231,7 +228,7 @@ def calc_correlation(x: Union[ms.Tensor, np.ndarray], | |||
| def calc_auc(x: _Array) -> _Array: | |||
| """Calculate the Aera under Curve.""" | |||
| """Calculate the Area under Curve.""" | |||
| # take mean for multiple patches if the model is fully convolutional model | |||
| if len(x.shape) == 4: | |||
| x = np.mean(np.mean(x, axis=2), axis=3) | |||
| @@ -242,18 +239,11 @@ def calc_auc(x: _Array) -> _Array: | |||
| def rank_pixels(inputs: _Array, descending: bool = True) -> _Array: | |||
| """ | |||
| Generate rank order fo every pixel in an 2D array. | |||
| Generate rank order for every pixel in an 2D array. | |||
| The rank order start from 0 to (num_pixel-1). If descending is True, the | |||
| rank order will generate in a descending order, otherwise in ascending | |||
| order. | |||
| Example: | |||
| x = np.array([[4., 3., 1.], [5., 9., 1.]]) | |||
| rank_pixels(x, descending=True) | |||
| >> np.array([[2, 3, 4], [1, 0, 5]]) | |||
| rank_pixels(x, descending=False) | |||
| >> np.array([[3, 2, 0], [4, 5, 1]]) | |||
| """ | |||
| if len(inputs.shape) < 2 or len(inputs.shape) > 3: | |||
| raise ValueError('Only support 2D or 3D inputs currently.') | |||
| @@ -275,16 +265,15 @@ def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor: | |||
| Resize the intermediate layer _attribution to the same size as inputs. | |||
| Args: | |||
| inputs (ms.Tensor): the input tensor to be resized | |||
| size (tupleint]): the targeted size resize to | |||
| mode (str): the resize mode. Options: 'nearest_neighbor', 'bilinear' | |||
| inputs (Tensor): The input tensor to be resized. | |||
| size (tuple[int]): The targeted size resize to. | |||
| mode (str): The resize mode. Options: 'nearest_neighbor', 'bilinear'. | |||
| Returns: | |||
| outputs (ms.Tensor): the resized tensor. | |||
| Tensor, the resized tensor. | |||
| Raises: | |||
| ValueError: the resize mode is not in ['nearest_neighbor', | |||
| 'bilinear']. | |||
| ValueError: the resize mode is not in ['nearest_neighbor', 'bilinear']. | |||
| """ | |||
| h, w = size | |||
| if mode == 'nearest_neighbor': | |||
| @@ -305,6 +294,6 @@ def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor: | |||
| resized_np = np.transpose(array_lst, [0, 3, 1, 2]) | |||
| outputs = ms.Tensor(resized_np, inputs.dtype) | |||
| else: | |||
| raise ValueError('Unsupported resize mode {}'.format(mode)) | |||
| raise ValueError('Unsupported resize mode {}.'.format(mode)) | |||
| return outputs | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -45,19 +45,19 @@ class ClassSensitivity(LabelAgnosticMetric): | |||
| numpy.ndarray, 1D array of shape :math:`(N,)`, result of class sensitivity evaluated on `explainer`. | |||
| Examples: | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.benchmark import ClassSensitivity | |||
| >>> from mindspore.explainer.explanation import Gradient | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # prepare your network and load the trained checkpoint file, e.g., resnet50. | |||
| >>> network = resnet50(10) | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(network, param_dict) | |||
| >>> | |||
| >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py | |||
| >>> net = LeNet5(10, num_channel=3) | |||
| >>> # prepare your explainer to be evaluated, e.g., Gradient. | |||
| >>> gradient = Gradient(network) | |||
| >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> gradient = Gradient(net) | |||
| >>> input_x = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32) | |||
| >>> class_sensitivity = ClassSensitivity() | |||
| >>> res = class_sensitivity.evaluate(gradient, input_x) | |||
| >>> print(res) | |||
| """ | |||
| self._check_evaluate_param(explainer, inputs) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -419,16 +419,20 @@ class Faithfulness(LabelSensitiveMetric): | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import Gradient | |||
| >>> # init an explainer with a trained network, e.g., resnet50 | |||
| >>> gradient = Gradient(network) | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> | |||
| >>> | |||
| >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py | |||
| >>> net = LeNet5(10, num_channel=3) | |||
| >>> gradient = Gradient(net) | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32) | |||
| >>> targets = 5 | |||
| >>> # usage 1: input the explainer and the data to be explained, | |||
| >>> # calculate the faithfulness with the specified metric | |||
| >>> # faithfulness is a Faithfulness instance | |||
| >>> res = faithfulness.evaluate(gradient, inputs, targets) | |||
| >>> # usage 2: input the generated saliency map | |||
| >>> saliency = gradient(inputs, targets) | |||
| >>> res = faithfulness.evaluate(gradient, inputs, targets, saliency) | |||
| >>> print(res) | |||
| """ | |||
| self._check_evaluate_param(explainer, inputs, targets, saliency) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -61,7 +61,7 @@ class Localization(LabelSensitiveMetric): | |||
| Examples: | |||
| >>> from mindspore.explainer.benchmark import Localization | |||
| >>> num_labels = 100 | |||
| >>> num_labels = 10 | |||
| >>> localization = Localization(num_labels, "PointingGame") | |||
| """ | |||
| @@ -113,18 +113,22 @@ class Localization(LabelSensitiveMetric): | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import Gradient | |||
| >>> # init an explainer with a trained network, e.g., resnet50 | |||
| >>> gradient = Gradient(network) | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> masks = np.zeros([1, 1, 224, 224]) | |||
| >>> masks[:, :, 65: 100, 65: 100] = 1 | |||
| >>> | |||
| >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py | |||
| >>> net = LeNet5(10, num_channel=3) | |||
| >>> gradient = Gradient(net) | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32) | |||
| >>> masks = np.zeros([1, 1, 32, 32]) | |||
| >>> masks[:, :, 10: 20, 10: 20] = 1 | |||
| >>> targets = 5 | |||
| >>> # usage 1: input the explainer and the data to be explained, | |||
| >>> # calculate the faithfulness with the specified metric | |||
| >>> # localization is a Localization instance | |||
| >>> res = localization.evaluate(gradient, inputs, targets, mask=masks) | |||
| >>> print(res) | |||
| >>> # usage 2: input the generated saliency map | |||
| >>> saliency = gradient(inputs, targets) | |||
| >>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks) | |||
| >>> print(res) | |||
| """ | |||
| self._check_evaluate_param_with_mask(explainer, inputs, targets, saliency, mask) | |||
| @@ -69,7 +69,7 @@ class AttributionMetric: | |||
| if self._explainer is None: | |||
| self._explainer = explainer | |||
| elif self._explainer is not explainer: | |||
| logger.info('Provided explainer is not the same as previously evaluted one. Please reset the evaluated ' | |||
| logger.info('Provided explainer is not the same as previously evaluated one. Please reset the evaluated ' | |||
| 'results. Previous explainer: %s, current explainer: %s', self._explainer, explainer) | |||
| self._explainer = explainer | |||
| @@ -107,7 +107,7 @@ class LabelAgnosticMetric(AttributionMetric): | |||
| raise TypeError('result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result)) | |||
| def get_results(self): | |||
| """Return the gloabl results.""" | |||
| """Return the global results.""" | |||
| return self._global_results.copy() | |||
| def reset(self): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -80,18 +80,16 @@ class Robustness(LabelSensitiveMetric): | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import Gradient | |||
| >>> from mindspore.explainer.benchmark import Robustness | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # prepare your network and load the trained checkpoint file, e.g., resnet50. | |||
| >>> network = resnet50(10) | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(network, param_dict) | |||
| >>> | |||
| >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py | |||
| >>> net = LeNet5(10, num_channel=3) | |||
| >>> # prepare your explainer to be evaluated, e.g., Gradient. | |||
| >>> gradient = Gradient(network) | |||
| >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> gradient = Gradient(net) | |||
| >>> input_x = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32) | |||
| >>> target_label = ms.Tensor([0], ms.int32) | |||
| >>> # robustness is a Robustness instance | |||
| >>> res = robustness.evaluate(gradient, input_x, target_label) | |||
| >>> print(res) | |||
| """ | |||
| self._check_evaluate_param(explainer, inputs, targets, saliency) | |||
| @@ -24,15 +24,15 @@ def get_bp_weights(model, inputs, targets=None, weights=None): | |||
| Compute the gradient of output w.r.t input. | |||
| Args: | |||
| model (`ms.nn.Cell`): Differentiable black-box model. | |||
| inputs (`ms.Tensor`): Input to calculate gradient and explanation. | |||
| model (Cell): Differentiable black-box model. | |||
| inputs (Tensor): Input to calculate gradient and explanation. | |||
| targets (int, optional): Target label id specifying which category to compute gradient. Default: None. | |||
| weights (`ms.Tensor`, optional): Custom weights for computing gradients. The shape of weights should match the | |||
| model outputs. If None is provided, an one-hot weights with one in targets positions will be used instead. | |||
| weights (Tensor, optional): Custom weights for computing gradients. The shape of weights should match the model | |||
| outputs. If None is provided, an one-hot weights with one in targets positions will be used instead. | |||
| Default: None. | |||
| Returns: | |||
| saliency map (ms.Tensor): Gradient back-propagated to the input. | |||
| Tensor, signal to be back-propagated to the input. | |||
| """ | |||
| inputs = unify_inputs(inputs) | |||
| if targets is None and weights is None: | |||
| @@ -61,7 +61,7 @@ class GradCAM(IntermediateLayerAttribution): | |||
| Args: | |||
| network (Cell): The black-box model to be explained. | |||
| layer (str, optional): The layer name to generate the explanation, usually chosen as the last convolutional | |||
| layer for better practice. If it is '', the explantion will be generated at the input layer. | |||
| layer for better practice. If it is '', the explanation will be generated at the input layer. | |||
| Default: ''. | |||
| Inputs: | |||
| @@ -76,18 +76,17 @@ class GradCAM(IntermediateLayerAttribution): | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import GradCAM | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # load a trained network | |||
| >>> net = resnet50(10) | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> | |||
| >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py | |||
| >>> net = LeNet5(10, num_channel=3) | |||
| >>> # specify a layer name to generate explanation, usually the layer can be set as the last conv layer. | |||
| >>> layer_name = 'layer4' | |||
| >>> layer_name = 'conv2' | |||
| >>> # init GradCAM with a trained network and specify the layer to obtain attribution | |||
| >>> gradcam = GradCAM(net, layer=layer_name) | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32) | |||
| >>> label = 5 | |||
| >>> saliency = gradcam(inputs, label) | |||
| >>> print(saliency.shape) | |||
| """ | |||
| def __init__(self, network, layer=""): | |||
| @@ -15,32 +15,14 @@ | |||
| """Gradient explainer.""" | |||
| from copy import deepcopy | |||
| from mindspore import nn | |||
| from mindspore.train._utils import check_value_type | |||
| from mindspore.explainer._operators import reshape, sqrt, Tensor | |||
| from mindspore.explainer._operators import Tensor | |||
| from mindspore.explainer._utils import abs_max, unify_inputs, unify_targets | |||
| from .. import Attribution | |||
| from .backprop_utils import get_bp_weights, GradNet | |||
| def _get_hook(bntype, cache): | |||
| """Provide backward hook function for BatchNorm layer in eval mode.""" | |||
| var, gamma, eps = cache | |||
| if bntype == "2d": | |||
| var = reshape(var, (1, -1, 1, 1)) | |||
| gamma = reshape(gamma, (1, -1, 1, 1)) | |||
| elif bntype == "1d": | |||
| var = reshape(var, (1, -1, 1)) | |||
| gamma = reshape(gamma, (1, -1, 1)) | |||
| def reset_gradient(_, grad_input, grad_output): | |||
| grad_output = grad_input[0] * gamma / sqrt(var + eps) | |||
| return grad_output | |||
| return reset_gradient | |||
| class Gradient(Attribution): | |||
| r""" | |||
| Provides Gradient explanation method. | |||
| @@ -72,15 +54,14 @@ class Gradient(Attribution): | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import Gradient | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # init Gradient with a trained network | |||
| >>> net = resnet50(10) # please refer to model_zoo | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> | |||
| >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py | |||
| >>> net = LeNet5(10, num_channel=3) | |||
| >>> gradient = Gradient(net) | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32) | |||
| >>> label = 5 | |||
| >>> saliency = gradient(inputs, label) | |||
| >>> print(saliency.shape) | |||
| """ | |||
| def __init__(self, network): | |||
| @@ -88,7 +69,6 @@ class Gradient(Attribution): | |||
| self._backward_model = deepcopy(network) | |||
| self._backward_model.set_train(False) | |||
| self._backward_model.set_grad(False) | |||
| self._hook_bn() | |||
| self._grad_net = GradNet(self._backward_model) | |||
| self._aggregation_fn = abs_max | |||
| @@ -103,22 +83,19 @@ class Gradient(Attribution): | |||
| saliency = self._aggregation_fn(gradient) | |||
| return saliency | |||
| def _hook_bn(self): | |||
| """Hook BatchNorm layer for `self._backward_model.`""" | |||
| for _, cell in self._backward_model.cells_and_names(): | |||
| if isinstance(cell, nn.BatchNorm2d): | |||
| cache = (cell.moving_variance, cell.gamma, cell.eps) | |||
| cell.register_backward_hook(_get_hook("2d", cache=cache)) | |||
| elif isinstance(cell, nn.BatchNorm1d): | |||
| cache = (cell.moving_variance, cell.gamma, cell.eps) | |||
| cell.register_backward_hook(_get_hook("1d", cache=cache)) | |||
| @staticmethod | |||
| def _verify_data(inputs, targets): | |||
| """Verify the validity of the parsed inputs.""" | |||
| """ | |||
| Verify the validity of the parsed inputs. | |||
| Args: | |||
| inputs (Tensor): The inputs to be explained. | |||
| targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer. | |||
| If it is a 1D tensor, its length should be the same as `inputs`. | |||
| """ | |||
| check_value_type('inputs', inputs, Tensor) | |||
| if len(inputs.shape) != 4: | |||
| raise ValueError('Argument inputs must be 4D Tensor') | |||
| raise ValueError(f'Argument inputs must be 4D Tensor. But got {len(inputs.shape)}D Tensor.') | |||
| check_value_type('targets', targets, (Tensor, int)) | |||
| if isinstance(targets, Tensor): | |||
| if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != len(inputs)): | |||
| @@ -109,16 +109,14 @@ class Deconvolution(ModifiedReLU): | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import Deconvolution | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # init Deconvolution with a trained network. | |||
| >>> net = resnet50(10) # please refer to model_zoo | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py | |||
| >>> net = LeNet5(10, num_channel=3) | |||
| >>> deconvolution = Deconvolution(net) | |||
| >>> # parse data and the target label to be explained and get the saliency map | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32) | |||
| >>> label = 5 | |||
| >>> saliency = deconvolution(inputs, label) | |||
| >>> print(saliency.shape) | |||
| """ | |||
| def __init__(self, network): | |||
| @@ -154,17 +152,15 @@ class GuidedBackprop(ModifiedReLU): | |||
| Examples: | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> from mindspore.explainer.explanation import GuidedBackprop | |||
| >>> # init GuidedBackprop with a trained network. | |||
| >>> net = resnet50(10) # please refer to model_zoo | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py | |||
| >>> net = LeNet5(10, num_channel=3) | |||
| >>> gbp = GuidedBackprop(net) | |||
| >>> # parse data and the target label to be explained and get the saliency map | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> # feed data and the target label to be explained and get the saliency map | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32) | |||
| >>> label = 5 | |||
| >>> saliency = gbp(inputs, label) | |||
| >>> print(saliency.shape) | |||
| """ | |||
| def __init__(self, network): | |||
| @@ -130,7 +130,7 @@ class AblationWithSaliency(Ablation): | |||
| Generate mask for perturbations based on given saliency ranks. | |||
| Args: | |||
| saliency (np.ndarray): Perturbing masks will be generated based on the given saliency map. The shape of | |||
| saliency (numpy.array): Perturbing masks will be generated based on the given saliency map. The shape of | |||
| saliency is expected to be: [batch_size, optional(num_channels), *spatial_size]. If multi-channel | |||
| saliency is provided, an averaged saliency will be taken to calculate pixel order in spatial dimension. | |||
| num_channels (optional[int]): Number of channels of the input data. In order to match the shape of inputs, | |||
| @@ -139,7 +139,7 @@ class AblationWithSaliency(Ablation): | |||
| no channel dimension. Default: None. | |||
| Return: | |||
| mask (np.ndarray): boolen mask for generate perturbations. | |||
| numpy.array, boolean masks for perturbation generation. | |||
| """ | |||
| batch_size = saliency.shape[0] | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -71,17 +71,15 @@ class Occlusion(PerturbationAttribution): | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import Occlusion | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # prepare your network and load the trained checkpoint file, e.g., resnet50. | |||
| >>> network = resnet50(10) | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(network, param_dict) | |||
| >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py | |||
| >>> net = LeNet5(10, num_channel=3) | |||
| >>> # initialize Occlusion explainer with the pretrained model and activation function | |||
| >>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities | |||
| >>> occlusion = Occlusion(network, activation_fn=activation_fn) | |||
| >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> occlusion = Occlusion(net, activation_fn=activation_fn) | |||
| >>> input_x = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32) | |||
| >>> label = ms.Tensor([1], ms.int32) | |||
| >>> saliency = occlusion(input_x, label) | |||
| >>> print(saliency.shape) | |||
| """ | |||
| def __init__(self, network, activation_fn, perturbation_per_eval=32): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -62,23 +62,22 @@ class RISE(PerturbationAttribution): | |||
| >>> import numpy as np | |||
| >>> import mindspore as ms | |||
| >>> from mindspore.explainer.explanation import RISE | |||
| >>> from mindspore.nn import Sigmoid | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # prepare your network and load the trained checkpoint file, e.g., resnet50. | |||
| >>> network = resnet50(10) | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(network, param_dict) | |||
| >>> | |||
| >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py | |||
| >>> net = LeNet5(10, num_channel=3) | |||
| >>> # initialize RISE explainer with the pretrained model and activation function | |||
| >>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities | |||
| >>> rise = RISE(network, activation_fn=activation_fn) | |||
| >>> rise = RISE(net, activation_fn=activation_fn) | |||
| >>> # given an instance of RISE, saliency map can be generate | |||
| >>> inputs = ms.Tensor(np.random.rand(2, 3, 224, 224), ms.float32) | |||
| >>> inputs = ms.Tensor(np.random.rand(2, 3, 32, 32), ms.float32) | |||
| >>> # when `targets` is an integer | |||
| >>> targets = 5 | |||
| >>> saliency = rise(inputs, targets) | |||
| >>> print(saliency.shape) | |||
| >>> # `targets` can also be a 2D tensor | |||
| >>> targets = ms.Tensor([[5], [1]], ms.int32) | |||
| >>> saliency = rise(inputs, targets) | |||
| >>> print(saliency.shape) | |||
| """ | |||
| def __init__(self, | |||
| @@ -88,7 +87,7 @@ class RISE(PerturbationAttribution): | |||
| super(RISE, self).__init__(network, activation_fn, perturbation_per_eval) | |||
| self._num_masks = 6000 # number of masks to be sampled | |||
| self._mask_probability = 0.2 # ratio of inputs to be masked | |||
| self._mask_probability = 0.5 # ratio of inputs to be masked | |||
| self._down_sample_size = 10 # the original size of binary masks | |||
| self._resize_mode = 'bilinear' # mode choice to resize the down-sized binary masks to size of the inputs | |||
| self._perturbation_mode = 'constant' # setting the perturbed pixels to a constant value | |||
| @@ -127,7 +126,9 @@ class RISE(PerturbationAttribution): | |||
| self._num_classes = num_classes | |||
| # Due to the unsupported Op of slice assignment, we use numpy array here | |||
| attr_np = np.zeros(shape=(batch_size, self._num_classes, height, width)) | |||
| targets = self._unify_targets(inputs, targets) | |||
| attr_np = np.zeros(shape=(batch_size, targets.shape[1], height, width)) | |||
| cal_times = math.ceil(self._num_masks / self._perturbation_per_eval) | |||
| @@ -143,24 +144,21 @@ class RISE(PerturbationAttribution): | |||
| weights = self._activation_fn(self.network(masked_input)) | |||
| while len(weights.shape) > 2: | |||
| weights = op.mean(weights, axis=2) | |||
| weights = op.reshape(weights, | |||
| (bs, self._num_classes, 1, 1)) | |||
| attr_np[idx] += op.summation(weights * masks, axis=0).asnumpy() | |||
| weights = np.expand_dims(np.expand_dims(weights.asnumpy()[:, targets[idx]], 2), 3) | |||
| attr_np = attr_np / self._num_masks | |||
| targets = self._unify_targets(inputs, targets) | |||
| attr_np[idx] += np.sum(weights * masks.asnumpy(), axis=0) | |||
| attr_classes = [att_i[target] for att_i, target in zip(attr_np, targets)] | |||
| attr_np = attr_np / self._num_masks | |||
| return op.Tensor(attr_classes, dtype=inputs.dtype) | |||
| return op.Tensor(attr_np, dtype=inputs.dtype) | |||
| @staticmethod | |||
| def _verify_data(inputs, targets): | |||
| """Verify the validity of the parsed inputs.""" | |||
| check_value_type('inputs', inputs, Tensor) | |||
| if len(inputs.shape) != 4: | |||
| raise ValueError('Argument inputs must be 4D Tensor') | |||
| raise ValueError(f'Argument inputs must be 4D Tensor, but got {len(inputs.shape)}D Tensor.') | |||
| check_value_type('targets', targets, (Tensor, int, tuple, list)) | |||
| if isinstance(targets, Tensor): | |||
| if len(targets.shape) > 2: | |||
| @@ -168,7 +166,7 @@ class RISE(PerturbationAttribution): | |||
| 'But got {}D.'.format(len(targets.shape))) | |||
| if targets.shape and len(targets) != len(inputs): | |||
| raise ValueError( | |||
| 'If `targets` is a 2D, 1D Tensor, it should have the same length as inputs {}. But got {}'.format( | |||
| 'If `targets` is a 2D, 1D Tensor, it should have the same length as inputs {}. But got {}.'.format( | |||
| len(inputs), len(targets))) | |||
| @staticmethod | |||