| @@ -77,16 +77,16 @@ class ImageClassificationRunner: | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10 | |||
| >>> dataset = get_dataset('/path/to/Cifar10_dataset') | |||
| >>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'turck'] | |||
| >>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] | |||
| >>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10 | |||
| >>> param_dict = load_checkpoint("checkpoint.ckpt") | |||
| >>> net = resnet50(len(classes)) | |||
| >>> net = resnet50(len(labels)) | |||
| >>> activation_fn = Softmax() | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> gbp = GuidedBackprop(net) | |||
| >>> gradient = Gradient(net) | |||
| >>> explainers = [gbp, gradient] | |||
| >>> faithfulness = Faithfulness(len(labels), "NaiveFaithfulness", activation_fn) | |||
| >>> faithfulness = Faithfulness(len(labels), activation_fn, "NaiveFaithfulness") | |||
| >>> benchmarkers = [faithfulness] | |||
| >>> runner = ImageClassificationRunner("./summary_dir", (dataset, labels), net, activation_fn) | |||
| >>> runner.register_saliency(explainers=explainers, benchmarkers=benchmarkers) | |||
| @@ -16,6 +16,7 @@ | |||
| import numpy as np | |||
| from mindspore.explainer.explanation import RISE | |||
| from .metric import LabelAgnosticMetric | |||
| from ... import _operators as ops | |||
| from ..._utils import calc_correlation | |||
| @@ -55,7 +56,7 @@ class ClassSensitivity(LabelAgnosticMetric): | |||
| >>> # prepare your explainer to be evaluated, e.g., Gradient. | |||
| >>> gradient = Gradient(network) | |||
| >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> # class_sensitivity is a ClassSensitivity instance | |||
| >>> class_sensitivity = ClassSensitivity() | |||
| >>> res = class_sensitivity.evaluate(gradient, input_x) | |||
| """ | |||
| self._check_evaluate_param(explainer, inputs) | |||
| @@ -64,9 +65,14 @@ class ClassSensitivity(LabelAgnosticMetric): | |||
| max_confidence_label = ops.argmax(outputs) | |||
| min_confidence_label = ops.argmin(outputs) | |||
| max_confidence_saliency = explainer(inputs, max_confidence_label).asnumpy() | |||
| min_confidence_saliency = explainer(inputs, min_confidence_label).asnumpy() | |||
| if isinstance(explainer, RISE): | |||
| labels = ops.stack([max_confidence_label, min_confidence_label], axis=1) | |||
| full_saliency = explainer(inputs, labels) | |||
| max_confidence_saliency = full_saliency[:, max_confidence_label].asnumpy() | |||
| min_confidence_saliency = full_saliency[:, min_confidence_label].asnumpy() | |||
| else: | |||
| max_confidence_saliency = explainer(inputs, max_confidence_label).asnumpy() | |||
| min_confidence_saliency = explainer(inputs, min_confidence_label).asnumpy() | |||
| correlations = [] | |||
| for i in range(inputs.shape[0]): | |||
| @@ -14,11 +14,9 @@ | |||
| # ============================================================================ | |||
| """RISE.""" | |||
| import math | |||
| import random | |||
| import numpy as np | |||
| from mindspore.ops.operations import Concat | |||
| from mindspore import Tensor | |||
| from mindspore.train._utils import check_value_type | |||
| @@ -107,18 +105,13 @@ class RISE(PerturbationAttribution): | |||
| up_size = (height + mask_size[0], width + mask_size[1]) | |||
| mask = np.random.random((batch_size, 1) + mask_size) < self._mask_probability | |||
| upsample = resize(op.Tensor(mask, data.dtype), up_size, | |||
| self._resize_mode) | |||
| # Pack operator not available for GPU, thus transfer to numpy first | |||
| masks_lst = [] | |||
| for sample in upsample: | |||
| shift_x = random.randint(0, mask_size[0]) | |||
| shift_y = random.randint(0, mask_size[1]) | |||
| masks_lst.append(sample[:, shift_x: shift_x + height, shift_y:shift_y + width]) | |||
| concat = Concat() | |||
| masks = concat(tuple(masks_lst)) | |||
| masks = op.reshape(masks, (batch_size, -1, height, width)) | |||
| self._resize_mode).asnumpy() | |||
| shift_x = np.random.randint(0, mask_size[0] + 1, size=batch_size) | |||
| shift_y = np.random.randint(0, mask_size[1] + 1, size=batch_size) | |||
| masks = [sample[:, x_i: x_i + height, y_i: y_i + width] for sample, x_i, y_i | |||
| in zip(upsample, shift_x, shift_y)] | |||
| masks = Tensor(np.array(masks), data.dtype) | |||
| return masks | |||
| def __call__(self, inputs, targets): | |||
| @@ -157,11 +150,8 @@ class RISE(PerturbationAttribution): | |||
| attr_np = attr_np / self._num_masks | |||
| targets = self._unify_targets(inputs, targets) | |||
| attr_classes = [] | |||
| for idx, target in enumerate(targets): | |||
| attr_np_idx = attr_np[idx] | |||
| attr_idx = attr_np_idx[target] | |||
| attr_classes.append(attr_idx) | |||
| attr_classes = [att_i[target] for att_i, target in zip(attr_np, targets)] | |||
| return op.Tensor(attr_classes, dtype=inputs.dtype) | |||