From: @lixiaohui33
Reviewed-by:
Signed-off-by:
tags/v1.1.0
@@ -18,9 +18,8 @@ from typing import List, Tuple, Union, Callable
import numpy as np
import mindspore
from mindspore import nn
import mindspore.ops.operations as op
from mindspore import nn
_Axis = Union[int, Tuple[int, ...], List[int]]
_Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]]
@@ -235,7 +234,7 @@ def randint(low: int, high: int, shape: _Shape, dtype: mindspore.dtype = mindspo
return outputs
def softmax(axis: int) -> Callable:
def softmax(axis: int = -1) -> Callable:
"""Softmax activation function."""
func = nn.Softmax(axis=axis)
return func
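A minimal usage sketch of the new default axis (values are illustrative):
>>> logits = mindspore.Tensor(np.array([[1.0, 2.0, 3.0]]), mindspore.float32)
>>> probs = softmax()(logits)  # equivalent to softmax(axis=-1)(logits)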
@@ -20,20 +20,23 @@ from time import time
from typing import Tuple, List, Optional
import numpy as np
from scipy.stats import beta
from PIL import Image
from scipy.stats import beta
from mindspore.train._utils import check_value_type
from mindspore.train.summary_pb2 import Explain
import mindspore as ms
import mindspore.dataset as ds
from mindspore import log
from mindspore.nn import Softmax, Cell
from mindspore.nn.probability.toolbox import UncertaintyEvaluation
from mindspore.ops.operations import ExpandDims
from mindspore.train._utils import check_value_type
from mindspore.train.summary._summary_adapter import _convert_image_format
from mindspore.train.summary.summary_record import SummaryRecord
from mindspore.nn.probability.toolbox import UncertaintyEvaluation
from mindspore.train.summary_pb2 import Explain
from .benchmark import Localization
from .benchmark._attribution.metric import AttributionMetric
from .explanation import RISE
from .explanation._attribution._attribution import Attribution
# datafile directory names
@@ -43,8 +46,8 @@ _HEATMAP_DIRNAME = "heatmap"
# maximum number of samples per directory
_SAMPLE_PER_DIR = 1000
_EXPAND_DIMS = ExpandDims()
_SEED = 58  # set a seed to fix the iterating order of the dataset
def _normalize(img_np):
@@ -57,7 +60,7 @@ def _normalize(img_np):
def _np_to_image(img_np, mode):
"""Convert numpy array to PIL image."""
return Image.fromarray(np.uint8(img_np*255), mode=mode)
return Image.fromarray(np.uint8(img_np * 255), mode=mode)
def _calc_prob_interval(volume, probs, prob_vars):
@@ -89,7 +92,7 @@ def _calc_prob_interval(volume, probs, prob_vars):
def _get_id_dirname(sample_id: int):
"""Get the name of the parent directory for the given sample id."""
return str(int(sample_id/_SAMPLE_PER_DIR)*_SAMPLE_PER_DIR)
return str(int(sample_id / _SAMPLE_PER_DIR) * _SAMPLE_PER_DIR)
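For instance, with _SAMPLE_PER_DIR = 1000, sample ids 0-999 map to directory "0" and id 1234 maps to "1000":
>>> _get_id_dirname(999), _get_id_dirname(1234)
('0', '1000')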
def _extract_timestamp(filename: str):
@@ -107,6 +110,9 @@ class ExplainRunner:
After generating results with the explanation methods and the evaluation methods, the results will be written into
a specified file with `mindspore.summary.SummaryRecord`. The stored content can be viewed using MindInsight.
Update in 2020.11: Adjusted the storage structure and format of the data. Summary files generated by previous
versions are deprecated and cannot be viewed in the current version of MindInsight.
Args:
summary_dir (str, optional): The directory path to save the summary files which store the generated results.
Default: "./"
@@ -131,7 +137,8 @@ class ExplainRunner:
dataset: Tuple,
explainers: List,
benchmarkers: Optional[List] = None,
uncertainty: Optional[UncertaintyEvaluation] = None):
uncertainty: Optional[UncertaintyEvaluation] = None,
activation_fn: Optional[Cell] = Softmax()):
| """ | |||
| Genereates results and writes results into the summary files in `summary_dir` specified during the object | |||
| initialization. | |||
| @@ -149,8 +156,12 @@ class ExplainRunner: | |||
| Default: None | |||
| uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference | |||
| uncertainty of samples. | |||
| activation_fn (Cell, optional): The activation layer that transforms the output of the network to | |||
| label probability distribution :math:`P(y|x)`. Default: Softmax(). | |||
Examples:
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
>>> from mindspore.nn import Sigmoid
>>> # obtain dataset object
>>> dataset = get_dataset()
>>> classes = ["cat", "dog", ...]
@@ -158,13 +169,11 @@ class ExplainRunner:
>>> param_dict = load_checkpoint("checkpoint.ckpt")
>>> net = resnet50(len(classes))
>>> load_param_into_net(net, param_dict)
>>> # bind net with its output activation
>>> model = nn.SequentialCell([net, nn.Sigmoid()])
>>> gbp = GuidedBackprop(model)
>>> gradient = Gradient(model)
>>> gbp = GuidedBackprop(net)
>>> gradient = Gradient(net)
>>> runner = ExplainRunner("./")
>>> explainers = [gbp, gradient]
>>> runner.run((dataset, classes), explainers)
>>> runner.run((dataset, classes), explainers, activation_fn=Sigmoid())
"""
| check_value_type("dataset", dataset, tuple) | |||
| @@ -181,16 +190,17 @@ class ExplainRunner: | |||
| for exp in explainers: | |||
| if not isinstance(exp, Attribution): | |||
| raise TypeError("Argument explainers should be a list of objects of classes in " | |||
| raise TypeError("Argument `explainers` should be a list of objects of classes in " | |||
| "`mindspore.explainer.explanation`.") | |||
| if benchmarkers is not None: | |||
| check_value_type("benchmarkers", benchmarkers, list) | |||
| for bench in benchmarkers: | |||
| if not isinstance(bench, AttributionMetric): | |||
| raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation" | |||
| raise TypeError("Argument `benchmarkers` should be a list of objects of classes in explanation" | |||
| "`mindspore.explainer.benchmark`.") | |||
| check_value_type("activation_fn", activation_fn, Cell) | |||
self._model = explainers[0].model
self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn])
next_element = dataset.create_tuple_iterator().get_next()
inputs, _, _ = self._unpack_next_element(next_element)
prop_test = self._model(inputs)
@@ -211,9 +221,10 @@ class ExplainRunner:
self._uncertainty = None
with SummaryRecord(self._summary_dir) as summary:
spacer = '{:120}\r'
print("Start running and writing......")
begin = time()
print("Start writing metadata.")
print("Start writing metadata......")
self._summary_timestamp = _extract_timestamp(summary.event_file_name)
if self._summary_timestamp is None:
@@ -234,42 +245,47 @@ class ExplainRunner:
print("Finish writing metadata.")
now = time()
print("Start running and writing inference data......")
print("Start running and writing inference data.....")
imageid_labels = self._run_inference(dataset, summary)
print("Finish running and writing inference data. Time elapsed: {}s".format(time() - now))
print(spacer.format("Finish running and writing inference data. "
"Time elapsed: {:.3f} s".format(time() - now)))
if benchmarkers is None or not benchmarkers:
for exp in explainers:
start = time()
print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
self._count = 0
ds.config.set_seed(58)
ds.config.set_seed(_SEED)
for idx, next_element in enumerate(dataset):
now = time()
self._run_exp_step(next_element, exp, imageid_labels, summary)
print("Finish writing {}-th explanation data. Time elapsed: {}".format(
idx, time() - now))
print("Finish running and writing explanation data for {}. Time elapsed: {}".format(
exp.__class__.__name__, time() - start))
print(spacer.format("Finish writing {}-th explanation data for {}. Time elapsed: "
"{:.3f} s".format(idx, exp.__class__.__name__, time() - now)), end='')
print(spacer.format(
"Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format(
exp.__class__.__name__, time() - start)))
else:
for exp in explainers:
explain = Explain()
for bench in benchmarkers:
bench.reset()
print(f"Start running and writing explanation and benchmark data for {exp.__class__.__name__}.")
print(f"Start running and writing explanation and "
f"benchmark data for {exp.__class__.__name__}......")
self._count = 0
start = time()
ds.config.set_seed(58)
ds.config.set_seed(_SEED)
for idx, next_element in enumerate(dataset):
now = time()
saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary)
print("Finish writing {}-th batch explanation data. Time elapsed: {}s".format(
idx, time() - now))
print(spacer.format(
"Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
idx, exp.__class__.__name__, time() - now)), end='')
for bench in benchmarkers:
now = time()
self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
print("Finish running {}-th batch benchmark data for {}. Time elapsed: {}s".format(
idx, bench.__class__.__name__, time() - now))
print(spacer.format(
"Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
idx, bench.__class__.__name__, exp.__class__.__name__, time() - now)), end='')
for bench in benchmarkers:
benchmark = explain.benchmark.add()
@@ -279,11 +295,11 @@ class ExplainRunner:
benchmark.total_score = bench.performance
benchmark.label_score.extend(bench.class_performances)
print("Finish running and writing explanation and benchmark data for {}. "
"Time elapsed: {}s".format(exp.__class__.__name__, time() - start))
print(spacer.format("Finish running and writing explanation and benchmark data for {}. "
"Time elapsed: {:.3f} s".format(exp.__class__.__name__, time() - start)))
summary.add_value('explainer', 'benchmark', explain)
summary.record(1)
print("Finish running and writing. Total time elapsed: {}s".format(time() - begin))
print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin))
@staticmethod
def _verify_data_form(dataset, benchmarkers):
@@ -446,8 +462,9 @@ class ExplainRunner:
Returns:
imageid_labels (dict): A dict that maps each image_id to the union of its ground truth and predicted labels.
"""
spacer = '{:120}\r'
imageid_labels = {}
ds.config.set_seed(58)
ds.config.set_seed(_SEED)
self._count = 0
for j, next_element in enumerate(dataset):
now = time()
@@ -516,7 +533,9 @@ class ExplainRunner:
summary.record(1)
self._count += 1
print("Finish running and writing {}-th batch inference data. Time elapsed: {}s".format(j, time() - now))
print(spacer.format("Finish running and writing {}-th batch inference data."
" Time elapsed: {:.3f} s".format(j, time() - now)),
end='')
return imageid_labels
def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
@@ -543,18 +562,22 @@ class ExplainRunner:
batch_unions = self._make_label_batch(unions)
saliency_dict_lst = []
batch_saliency_full = []
for i in range(len(batch_unions[0])):
batch_saliency = explainer(inputs, batch_unions[:, i])
batch_saliency_full.append(batch_saliency)
if isinstance(explainer, RISE):
batch_saliency_full = explainer(inputs, batch_unions)
else:
batch_saliency_full = []
for i in range(len(batch_unions[0])):
batch_saliency = explainer(inputs, batch_unions[:, i])
batch_saliency_full.append(batch_saliency)
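# stack the per-label saliency maps along axis 1 into shape (N, L, H, W),
# matching the layout that RISE returns directly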
concat = ms.ops.operations.Concat(1)
batch_saliency_full = concat(tuple(batch_saliency_full))
for idx, union in enumerate(unions):
saliency_dict = {}
explain = Explain()
explain.sample_id = self._count
for k, lab in enumerate(union):
saliency = batch_saliency_full[k][idx:idx + 1]
saliency = batch_saliency_full[idx:idx + 1, k:k + 1]
saliency_dict[lab] = saliency
saliency_np = _normalize(saliency.asnumpy().squeeze())
@@ -600,7 +623,7 @@ class ExplainRunner:
def _save_original_image(self, sample_id: int, image):
"""Save an image to summary directory."""
id_dirname = _get_id_dirname(sample_id)
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
_ORIGINAL_IMAGE_DIRNAME,
id_dirname)
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
@@ -613,7 +636,7 @@ class ExplainRunner:
def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image):
"""Save heatmap image to summary directory."""
id_dirname = _get_id_dirname(sample_id)
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
_HEATMAP_DIRNAME,
explain_method,
id_dirname)
@@ -21,6 +21,7 @@ from scipy.ndimage.filters import gaussian_filter
from mindspore import log
import mindspore as ms
from mindspore.train._utils import check_value_type
import mindspore.nn as nn
import mindspore.ops.operations as op
from .metric import AttributionMetric
@@ -140,9 +141,7 @@ class Perturb:
@staticmethod
def _assign(x: _Array, y: _Array, masks: _Array):
"""Assign reference values to the masked pixels of the perturbations."""
if masks.dtype != bool:
raise TypeError('The param "masks" should be an array of bool, but receive {}'
.format(masks.dtype))
check_value_type("masks dtype", masks.dtype, type(np.dtype(bool)))
for i in range(x.shape[0]):
x[i][:, masks[i]] = y[:, masks[i]]
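A small sketch of the boolean-mask assignment above (shapes are illustrative):
>>> x = np.zeros((1, 3, 4))                        # one sample, 3 channels, 4 pixels
>>> y = np.ones((3, 4))                            # reference values
>>> masks = np.array([[False, True, False, True]])
>>> Perturb._assign(x, y, masks)
>>> x[0, :, 1]
array([1., 1., 1.])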
@@ -336,8 +335,7 @@ class NaiveFaithfulness(_FaithfulnessHelper):
if not np.count_nonzero(saliency):
log.warning("The saliency map is zero everywhere. The correlation will be set to zero.")
correlation = 0
normalized_faithfulness = (correlation + 1) / 2
return np.array([normalized_faithfulness], np.float)
return np.array([correlation], np.float)
reference = self._get_reference(inputs)
perturbations, masks = self._perturb(
inputs, saliency, reference, return_mask=True)
@@ -347,8 +345,7 @@ class NaiveFaithfulness(_FaithfulnessHelper):
predictions = model(perturbations).asnumpy()[:, targets]
faithfulness = calc_correlation(feature_importance, predictions)
normalized_faithfulness = (faithfulness + 1) / 2
return np.array([normalized_faithfulness], np.float)
return np.array([faithfulness], np.float)
class DeletionAUC(_FaithfulnessHelper):
@@ -533,6 +530,8 @@ class Faithfulness(AttributionMetric):
metric (str, optional): The specific metric to quantify faithfulness.
Options: "DeletionAUC", "InsertionAUC", "NaiveFaithfulness".
Default: 'NaiveFaithfulness'.
activation_fn (Cell, optional): The activation function that transforms the network output to a probability.
Default: nn.Softmax().
Examples:
>>> from mindspore.explainer.benchmark import Faithfulness
@@ -543,7 +542,7 @@ class Faithfulness(AttributionMetric):
"""
_methods = [NaiveFaithfulness, DeletionAUC, InsertionAUC]
def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness"):
def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness", activation_fn=nn.Softmax()):
super(Faithfulness, self).__init__(num_labels)
perturb_percent = 0.5  # ratio of pixels to be perturbed, future argument
@@ -552,6 +551,7 @@ class Faithfulness(AttributionMetric):
num_perturb_steps = 100  # separate the perturbation progress into 100 steps
base_value = 0.0  # the pixel value set for the perturbed pixels
self._activation_fn = activation_fn
self._verify_metrics(metric)
for method in self._methods:
if metric == method.__name__:
@@ -568,9 +568,7 @@ class Faithfulness(AttributionMetric):
Evaluate faithfulness on a single data sample.
Note:
To apply `Faithfulness` to evaluate an explainer, this explainer must be initialized with a network that
contains the output activation function. Otherwise, the results will not be correct. Currently only single
sample (:math:`N=1`) at each call is supported.
Currently only single sample (:math:`N=1`) at each call is supported.
Args:
explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`.
@@ -586,7 +584,7 @@ class Faithfulness(AttributionMetric):
Examples:
>>> # init an explainer, the network should contain the output activation function.
>>> network = nn.SequentialCell([resnet50, nn.Sigmoid()])
>>> network = resnet50(20)
>>> gradient = Gradient(network)
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
>>> targets = 5
@@ -610,10 +608,10 @@ class Faithfulness(AttributionMetric):
saliency = saliency.squeeze()
if len(saliency.shape) != 2:
raise ValueError('Squeezed saliency map is expected to be 2D, but got {}D.'.format(len(saliency.shape)))
faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=explainer.model,
model = nn.SequentialCell([explainer.model, self._activation_fn])
faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=model,
targets=targets, saliency=saliency)
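# the helpers now return the raw score in [-1, 1]; rescale it once here so the
# reported faithfulness lies in [0, 1] (e.g. a correlation of 0.8 maps to 0.9)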
return faithfulness
return (1 + faithfulness) / 2
def _verify_metrics(self, metric: str):
supports = [x.__name__ for x in self._methods]
@@ -17,10 +17,12 @@
from ._attribution._backprop.gradcam import GradCAM
from ._attribution._backprop.gradient import Gradient
from ._attribution._backprop.modified_relu import Deconvolution, GuidedBackprop
from ._attribution._perturbation.rise import RISE
__all__ = [
'Gradient',
'Deconvolution',
'GuidedBackprop',
'GradCAM',
'RISE'
]
@@ -16,10 +16,12 @@
from ._backprop.gradcam import GradCAM
from ._backprop.gradient import Gradient
from ._backprop.modified_relu import Deconvolution, GuidedBackprop
from ._perturbation.rise import RISE
__all__ = [
'Gradient',
'Deconvolution',
'GuidedBackprop',
'GradCAM',
'RISE'
]
@@ -16,33 +16,25 @@
from typing import Callable
import mindspore as ms
from mindspore.train._utils import check_value_type
from mindspore.nn import Cell
class Attribution:
r"""
"""
Base class for attributing saliency scores.
The explainers which explanation through attributing the relevance scores
should inherit this class.
Explainers that generate explanations by attributing relevance scores should inherit this class.
Args:
network (ms.nn.Cell): The black-box model to explanation.
network (Cell): The black-box model to explain.
"""
def __init__(self, network):
self._verify_model(network)
check_value_type("network", network, Cell)
self._model = network
self._model.set_train(False)
self._model.set_grad(False)
@staticmethod
def _verify_model(model):
"""
Verify the input `network` for __init__ function.
"""
if not isinstance(model, ms.nn.Cell):
raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.")
__call__: Callable
"""
@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
""" GradCAM and GuidedGradCAM. """
"""GradCAM."""
from mindspore.ops import operations as op
@@ -98,7 +98,7 @@ class GradCAM(IntermediateLayerAttribution):
"""
Hook function to deal with the backward gradient.
The arguments are set as required by Cell.register_back_hook
The arguments are set as required by `Cell.register_backward_hook`.
"""
self._intermediate_grad = grad_input
@@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
| """ Perturbation-based _attribution explainer. """ | |||
| from .rise import RISE | |||
| __all__ = ['RISE'] | |||
@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
| """Base class `PerturbationAttribtuion`""" | |||
| from mindspore.train._utils import check_value_type | |||
| from mindspore.nn import Cell | |||
| from .._attribution import Attribution | |||
| from ...._operators import softmax | |||
| class PerturbationAttribution(Attribution): | |||
| """ | |||
| Base class for perturbation-based attribution methods. | |||
| All perturbation-based _attribution methods extend from this class. | |||
| """ | |||
| def __init__(self, | |||
| network, | |||
| activation_fn=softmax(), | |||
| ): | |||
| super(PerturbationAttribution, self).__init__(network) | |||
| check_value_type("activation_fn", activation_fn, Cell) | |||
| self._activation_fn = activation_fn | |||
@@ -0,0 +1,194 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RISE."""
import math
import numpy as np
from mindspore import Tensor
from mindspore import nn
from mindspore.train._utils import check_value_type
from .perturbation import PerturbationAttribution
from .... import _operators as op
from ...._utils import resize
class RISE(PerturbationAttribution):
r"""
RISE: Randomized Input Sampling for Explanation of Black-box Models.
RISE is a perturbation-based method that generates attribution maps by sampling on multiple random binary masks.
The original image is randomly masked, and then fed into the black-box model to get predictions. The final
attribution map is the weighted sum of these random masks, with the weights being the corresponding output on the
node of interest:
.. math::
E_{RISE}(I, f)_c = \sum_{i}f_c(I\odot M_i) M_i
For more details, please refer to the original paper via: `RISE <https://arxiv.org/abs/1806.07421>`_.
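As a rough NumPy sketch of that weighted sum (mask count and sizes are illustrative):
>>> masks = (np.random.rand(100, 1, 224, 224) < 0.2).astype(np.float32)   # M_i
>>> scores = np.random.rand(100, 1, 1, 1).astype(np.float32)              # f_c(I * M_i)
>>> saliency = (scores * masks).sum(axis=0)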
Args:
network (Cell): The black-box model to be explained.
activation_fn (Cell, optional): The activation layer that transforms logits to prediction probabilities. For
single label classification tasks, `nn.Softmax` is usually applied. For multi-label classification tasks,
`nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long as,
when combined with the network, the final output is the probability of the input.
Default: `nn.Softmax`.
perturbation_per_eval (int, optional): Number of perturbations for each inference during inferring the
perturbed samples. Default: 32.
Examples:
>>> from mindspore.explainer.explanation import RISE
>>> net = resnet50(10)
>>> param_dict = load_checkpoint("resnet50.ckpt")
>>> load_param_into_net(net, param_dict)
>>> # init RISE with specified activation function
>>> rise = RISE(net, activation_fn=nn.layer.Sigmoid())
"""
def __init__(self,
network,
activation_fn=nn.Softmax(),
perturbation_per_eval=32):
super(RISE, self).__init__(network, activation_fn)
self._perturbation_per_eval = perturbation_per_eval
self._num_masks = 6000  # number of masks to be sampled
self._mask_probability = 0.2  # ratio of inputs to be masked
self._down_sample_size = 10  # edge size of the down-sampled binary masks
self._resize_mode = 'bilinear'  # interpolation mode for resizing the down-sampled binary masks to the input size
self._perturbation_mode = 'constant'  # setting the perturbed pixels to a constant value
self._base_value = 0  # the constant value assigned to perturbed pixels
self._num_classes = None  # placeholder, assigned on the first call
def _generate_masks(self, data, batch_size):
"""Generate a batch of binary masks for data."""
height, width = data.shape[2], data.shape[3]
mask_size = (self._down_sample_size, self._down_sample_size)
up_size = (height + mask_size[0], width + mask_size[1])
mask = np.random.random((batch_size, 1) + mask_size) < self._mask_probability
upsample = resize(op.Tensor(mask, data.dtype), up_size,
self._resize_mode)
# Pack operator not available for GPU, thus transfer to numpy first
upsample_np = upsample.asnumpy()
masks_lst = []
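# crop each upsampled mask at a random offset so the coarse grid is not
# aligned across samples, following the random-shift strategy of the RISE paper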
for sample in upsample_np:
shift_x = np.random.randint(0, mask_size[0] + 1)
shift_y = np.random.randint(0, mask_size[1] + 1)
masks_lst.append(sample[:, shift_x:shift_x + height, shift_y:shift_y + width])
masks = op.Tensor(np.array(masks_lst), data.dtype)
return masks
def __call__(self, inputs, targets):
"""
Generates attribution maps for inputs.
Args:
inputs (Tensor): Input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
targets (int, Tensor): The labels of interest to be explained. When `targets` is an integer, all of the
inputs will generate attribution maps w.r.t. this integer. When `targets` is a tensor, it should be
of shape :math:`(N, ?)`, :math:`(N,)` or :math:`()`.
Returns:
Tensor, a 4D tensor of shape :math:`(N, ?, H, W)` or :math:`(N, 1, H, W)`.
Examples:
>>> # given an instance of RISE, saliency maps can be generated
>>> inputs = ms.Tensor(np.random.rand(2, 3, 224, 224), ms.float32)
>>> # when `targets` is an integer
>>> targets = 5
>>> saliency = rise(inputs, targets)
>>> # `targets` can also be a tensor
>>> targets = ms.Tensor([[5], [1]])
>>> saliency = rise(inputs, targets)
"""
self._verify_data(inputs, targets)
height, width = inputs.shape[2], inputs.shape[3]
batch_size = inputs.shape[0]
if self._num_classes is None:
logits = self.model(inputs)
num_classes = logits.shape[1]
self._num_classes = num_classes
# Due to the unsupported Op of slice assignment, we use numpy array here
attr_np = np.zeros(shape=(batch_size, self._num_classes, height, width))
cal_times = math.ceil(self._num_masks / self._perturbation_per_eval)
for idx, data in enumerate(inputs):
bg_data = data * 0 + self._base_value
for j in range(cal_times):
bs = min(self._num_masks - j * self._perturbation_per_eval,
self._perturbation_per_eval)
data = op.reshape(data, (1, -1, height, width))
masks = self._generate_masks(data, bs)
masked_input = masks * data + (1 - masks) * bg_data
weights = self._activation_fn(self.model(masked_input))
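# average out any extra output dimensions so weights collapse to shape (bs, num_classes)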
while len(weights.shape) > 2:
weights = op.mean(weights, axis=2)
weights = op.reshape(weights,
(bs, self._num_classes, 1, 1))
attr_np[idx] += op.summation(weights * masks, axis=0).asnumpy()
attr_np = attr_np / self._num_masks
targets = self._unify_targets(inputs, targets)
attr_classes = []
for idx, target in enumerate(targets):
dtype = inputs.dtype
attr_np_idx = attr_np[idx]
attr_idx = attr_np_idx[target]
attr_classes.append(attr_idx)
return op.Tensor(attr_classes, dtype=dtype)
@staticmethod
def _verify_data(inputs, targets):
"""Verify the validity of the parsed inputs."""
check_value_type('inputs', inputs, Tensor)
if len(inputs.shape) != 4:
raise ValueError('Argument inputs must be a 4D Tensor.')
check_value_type('targets', targets, (Tensor, int, tuple, list))
if isinstance(targets, Tensor):
if len(targets.shape) > 2:
raise ValueError('Dimension invalid. If `targets` is a Tensor, it should be 0D, 1D or 2D. '
'But got {}D.'.format(len(targets.shape)))
if targets.shape and len(targets) != len(inputs):
raise ValueError(
'If `targets` is a 1D or 2D Tensor, its length should equal the number of inputs ({}). '
'But got {}.'.format(len(inputs), len(targets)))
@staticmethod
def _unify_targets(inputs, targets):
"""Unify targets to a 2D numpy.ndarray."""
if isinstance(targets, int):
return np.array([[targets] for _ in inputs]).astype(np.int)
if isinstance(targets, Tensor):
if not targets.shape:
return np.array([[targets.asnumpy()] for _ in inputs]).astype(np.int)
if len(targets.shape) == 1:
return np.array([[t.asnumpy()] for t in targets]).astype(np.int)
if len(targets.shape) == 2:
return np.array([t.asnumpy() for t in targets]).astype(np.int)
return targets
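A small sketch of the unification (`inputs` only supplies the batch size here):
>>> RISE._unify_targets(np.zeros((2, 3, 4, 4)), 5)
array([[5],
       [5]])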