From 2161c85d5297ab31051652700ae370d042408ff2 Mon Sep 17 00:00:00 2001
From: lixiaohui
Date: Tue, 17 Nov 2020 19:26:08 +0800
Subject: [PATCH] add rise explanation method, modify runner and faithfulness

---
 mindspore/explainer/_operators.py             |   5 +-
 mindspore/explainer/_runner.py                | 109 ++++++----
 .../benchmark/_attribution/faithfulness.py    |  28 ++-
 mindspore/explainer/explanation/__init__.py   |   2 +
 .../explanation/_attribution/__init__.py      |   2 +
 .../explanation/_attribution/_attribution.py  |  20 +-
 .../_attribution/_backprop/gradcam.py         |   4 +-
 .../_attribution/_perturbation/__init__.py    |  19 ++
 .../_perturbation/perturbation.py             |  38 ++++
 .../_attribution/_perturbation/rise.py        | 194 ++++++++++++++++++
 10 files changed, 344 insertions(+), 77 deletions(-)
 create mode 100644 mindspore/explainer/explanation/_attribution/_perturbation/__init__.py
 create mode 100644 mindspore/explainer/explanation/_attribution/_perturbation/perturbation.py
 create mode 100644 mindspore/explainer/explanation/_attribution/_perturbation/rise.py

diff --git a/mindspore/explainer/_operators.py b/mindspore/explainer/_operators.py
index 761fc9c4ce..011daa9024 100644
--- a/mindspore/explainer/_operators.py
+++ b/mindspore/explainer/_operators.py
@@ -18,9 +18,8 @@ from typing import List, Tuple, Union, Callable
 import numpy as np
 
 import mindspore
-from mindspore import nn
 import mindspore.ops.operations as op
-
+from mindspore import nn
 
 _Axis = Union[int, Tuple[int, ...], List[int]]
 _Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]]
@@ -235,7 +234,7 @@ def randint(low: int, high: int, shape: _Shape, dtype: mindspore.dtype = mindspo
     return outputs
 
 
-def softmax(axis: int) -> Callable:
+def softmax(axis: int = -1) -> Callable:
     """Softmax activation function."""
     func = nn.Softmax(axis=axis)
     return func
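A short note on the `softmax` change above: giving `axis` a default of -1 (the class axis for `(N, num_classes)` logits) lets callers build the activation with no arguments, which the perturbation code added later in this patch relies on. A minimal sketch, assuming the private helper module is imported as shown:

    import numpy as np
    import mindspore as ms
    from mindspore.explainer._operators import softmax

    logits = ms.Tensor(np.random.rand(2, 10), ms.float32)  # (batch, classes)
    probs = softmax()(logits)        # new default: softmax over the last axis
    same = softmax(axis=-1)(logits)  # equivalent explicit form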
diff --git a/mindspore/explainer/_runner.py b/mindspore/explainer/_runner.py
index 51080a6ac1..252a7f25f3 100644
--- a/mindspore/explainer/_runner.py
+++ b/mindspore/explainer/_runner.py
@@ -20,20 +20,23 @@ from time import time
 from typing import Tuple, List, Optional
 
 import numpy as np
-from scipy.stats import beta
 from PIL import Image
+from scipy.stats import beta
 
-from mindspore.train._utils import check_value_type
-from mindspore.train.summary_pb2 import Explain
 import mindspore as ms
 import mindspore.dataset as ds
 from mindspore import log
+from mindspore.nn import Softmax, Cell
+from mindspore.nn.probability.toolbox import UncertaintyEvaluation
 from mindspore.ops.operations import ExpandDims
+from mindspore.train._utils import check_value_type
 from mindspore.train.summary._summary_adapter import _convert_image_format
 from mindspore.train.summary.summary_record import SummaryRecord
-from mindspore.nn.probability.toolbox import UncertaintyEvaluation
+from mindspore.train.summary_pb2 import Explain
+
 from .benchmark import Localization
 from .benchmark._attribution.metric import AttributionMetric
+from .explanation import RISE
 from .explanation._attribution._attribution import Attribution
 
 # datafile directory names
@@ -43,8 +46,8 @@ _HEATMAP_DIRNAME = "heatmap"
 # max. no. of samples per directory
 _SAMPLE_PER_DIR = 1000
-
 _EXPAND_DIMS = ExpandDims()
+_SEED = 58  # set a seed to fix the iterating order of the dataset
 
 
 def _normalize(img_np):
@@ -57,7 +60,7 @@ def _normalize(img_np):
 
 def _np_to_image(img_np, mode):
     """Convert numpy array to PIL image."""
-    return Image.fromarray(np.uint8(img_np*255), mode=mode)
+    return Image.fromarray(np.uint8(img_np * 255), mode=mode)
 
 
 def _calc_prob_interval(volume, probs, prob_vars):
@@ -89,7 +92,7 @@ def _calc_prob_interval(volume, probs, prob_vars):
 
 def _get_id_dirname(sample_id: int):
     """Get the name of parent directory of the image id."""
-    return str(int(sample_id/_SAMPLE_PER_DIR)*_SAMPLE_PER_DIR)
+    return str(int(sample_id / _SAMPLE_PER_DIR) * _SAMPLE_PER_DIR)
 
 
 def _extract_timestamp(filename: str):
@@ -107,6 +110,9 @@ class ExplainRunner:
     After generating results with the explanation methods and the evaluation methods, the results will be written into
     a specified file with `mindspore.summary.SummaryRecord`. The stored content can be viewed using MindInsight.
 
+    Update in 2020.11: The storage structure and data format were adjusted. Summary files generated by earlier
+    versions are deprecated and are not supported by the current version of MindInsight.
+
     Args:
         summary_dir (str, optional): The directory path to save the summary files which store the generated results.
             Default: "./"
@@ -131,7 +137,8 @@ class ExplainRunner:
             dataset: Tuple,
             explainers: List,
             benchmarkers: Optional[List] = None,
-            uncertainty: Optional[UncertaintyEvaluation] = None):
+            uncertainty: Optional[UncertaintyEvaluation] = None,
+            activation_fn: Optional[Cell] = Softmax()):
         """
         Generates results and writes them into the summary files in `summary_dir` specified during the object
         initialization.
@@ -149,8 +156,12 @@ class ExplainRunner:
                 Default: None
             uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference
                 uncertainty of samples.
+            activation_fn (Cell, optional): The activation layer that transforms the output of the network to
+                a label probability distribution :math:`P(y|x)`. Default: Softmax().
+
         Examples:
             >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
+            >>> from mindspore.nn import Sigmoid
             >>> # obtain dataset object
             >>> dataset = get_dataset()
             >>> classes = ["cat", "dog", ...]
@@ -158,13 +169,11 @@ class ExplainRunner:
             >>> param_dict = load_checkpoint("checkpoint.ckpt")
             >>> net = resnet50(len(classes))
             >>> load_param_into_net(net, param_dict)
-            >>> # bind net with its output activation
-            >>> model = nn.SequentialCell([net, nn.Sigmoid()])
-            >>> gbp = GuidedBackprop(model)
-            >>> gradient = Gradient(model)
+            >>> gbp = GuidedBackprop(net)
+            >>> gradient = Gradient(net)
             >>> runner = ExplainRunner("./")
             >>> explainers = [gbp, gradient]
-            >>> runner.run((dataset, classes), explainers)
+            >>> runner.run((dataset, classes), explainers, activation_fn=Sigmoid())
         """
 
         check_value_type("dataset", dataset, tuple)
@@ -181,16 +190,17 @@ class ExplainRunner:
 
         for exp in explainers:
             if not isinstance(exp, Attribution):
-                raise TypeError("Argument explainers should be a list of objects of classes in "
+                raise TypeError("Argument `explainers` should be a list of objects of classes in "
                                 "`mindspore.explainer.explanation`.")
         if benchmarkers is not None:
             check_value_type("benchmarkers", benchmarkers, list)
             for bench in benchmarkers:
                 if not isinstance(bench, AttributionMetric):
-                    raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation"
+                    raise TypeError("Argument `benchmarkers` should be a list of objects of classes in "
                                     "`mindspore.explainer.benchmark`.")
+        check_value_type("activation_fn", activation_fn, Cell)
 
-        self._model = explainers[0].model
+        self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn])
         next_element = dataset.create_tuple_iterator().get_next()
         inputs, _, _ = self._unpack_next_element(next_element)
         prop_test = self._model(inputs)
@@ -211,9 +221,10 @@ class ExplainRunner:
             self._uncertainty = None
 
         with SummaryRecord(self._summary_dir) as summary:
+            spacer = '{:120}\r'
             print("Start running and writing......")
             begin = time()
-            print("Start writing metadata.")
+            print("Start writing metadata......")
 
             self._summary_timestamp = _extract_timestamp(summary.event_file_name)
             if self._summary_timestamp is None:
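A note on the activation handling introduced in the hunks above: explainers are now constructed from the raw-logit network, and `run` fuses the probability layer itself via `SequentialCell`. In rough terms (names follow the docstring example; `net` outputs raw logits):

    import mindspore.nn as nn

    # before this patch: the caller fused the activation into the model
    model = nn.SequentialCell([net, nn.Sigmoid()])
    explainers = [GuidedBackprop(model), Gradient(model)]

    # after this patch: explainers take the bare network; the runner builds the
    # probability model internally, exactly as the SequentialCell hunk above does
    explainers = [GuidedBackprop(net), Gradient(net)]
    runner.run((dataset, classes), explainers, activation_fn=nn.Sigmoid())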
Time elapsed: {:.3f} s".format( + exp.__class__.__name__, time() - start))) else: for exp in explainers: explain = Explain() for bench in benchmarkers: bench.reset() - print(f"Start running and writing explanation and benchmark data for {exp.__class__.__name__}.") + print(f"Start running and writing explanation and " + f"benchmark data for {exp.__class__.__name__}......") self._count = 0 start = time() - ds.config.set_seed(58) + ds.config.set_seed(_SEED) for idx, next_element in enumerate(dataset): now = time() saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary) - print("Finish writing {}-th batch explanation data. Time elapsed: {}s".format( - idx, time() - now)) + print(spacer.format( + "Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format( + idx, exp.__class__.__name__, time() - now)), end='') for bench in benchmarkers: now = time() self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst) - print("Finish running {}-th batch benchmark data for {}. Time elapsed: {}s".format( - idx, bench.__class__.__name__, time() - now)) + print(spacer.format( + "Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format( + idx, bench.__class__.__name__, exp.__class__.__name__, time() - now)), end='') for bench in benchmarkers: benchmark = explain.benchmark.add() @@ -279,11 +295,11 @@ class ExplainRunner: benchmark.total_score = bench.performance benchmark.label_score.extend(bench.class_performances) - print("Finish running and writing explanation and benchmark data for {}. " - "Time elapsed: {}s".format(exp.__class__.__name__, time() - start)) + print(spacer.format("Finish running and writing explanation and benchmark data for {}. " + "Time elapsed: {:.3f} s".format(exp.__class__.__name__, time() - start))) summary.add_value('explainer', 'benchmark', explain) summary.record(1) - print("Finish running and writing. Total time elapsed: {}s".format(time() - begin)) + print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin)) @staticmethod def _verify_data_form(dataset, benchmarkers): @@ -446,8 +462,9 @@ class ExplainRunner: Returns: imageid_labels (dict): a dict that maps image_id and the union of its ground truth and predicted labels. """ + spacer = '{:120}\r' imageid_labels = {} - ds.config.set_seed(58) + ds.config.set_seed(_SEED) self._count = 0 for j, next_element in enumerate(dataset): now = time() @@ -516,7 +533,9 @@ class ExplainRunner: summary.record(1) self._count += 1 - print("Finish running and writing {}-th batch inference data. Time elapsed: {}s".format(j, time() - now)) + print(spacer.format("Finish running and writing {}-th batch inference data." 
+ " Time elapsed: {:.3f} s".format(j, time() - now)), + end='') return imageid_labels def _run_exp_step(self, next_element, explainer, imageid_labels, summary): @@ -543,18 +562,22 @@ class ExplainRunner: batch_unions = self._make_label_batch(unions) saliency_dict_lst = [] - batch_saliency_full = [] - for i in range(len(batch_unions[0])): - batch_saliency = explainer(inputs, batch_unions[:, i]) - batch_saliency_full.append(batch_saliency) + if isinstance(explainer, RISE): + batch_saliency_full = explainer(inputs, batch_unions) + else: + batch_saliency_full = [] + for i in range(len(batch_unions[0])): + batch_saliency = explainer(inputs, batch_unions[:, i]) + batch_saliency_full.append(batch_saliency) + concat = ms.ops.operations.Concat(1) + batch_saliency_full = concat(tuple(batch_saliency_full)) for idx, union in enumerate(unions): saliency_dict = {} explain = Explain() explain.sample_id = self._count for k, lab in enumerate(union): - saliency = batch_saliency_full[k][idx:idx + 1] - + saliency = batch_saliency_full[idx:idx + 1, k:k + 1] saliency_dict[lab] = saliency saliency_np = _normalize(saliency.asnumpy().squeeze()) @@ -600,7 +623,7 @@ class ExplainRunner: def _save_original_image(self, sample_id: int, image): """Save an image to summary directory.""" id_dirname = _get_id_dirname(sample_id) - relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp), + relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp), _ORIGINAL_IMAGE_DIRNAME, id_dirname) os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True) @@ -613,7 +636,7 @@ class ExplainRunner: def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image): """Save heatmap image to summary directory.""" id_dirname = _get_id_dirname(sample_id) - relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp), + relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp), _HEATMAP_DIRNAME, explain_method, id_dirname) diff --git a/mindspore/explainer/benchmark/_attribution/faithfulness.py b/mindspore/explainer/benchmark/_attribution/faithfulness.py index 9fa17c65ae..6a7dcfbf93 100644 --- a/mindspore/explainer/benchmark/_attribution/faithfulness.py +++ b/mindspore/explainer/benchmark/_attribution/faithfulness.py @@ -21,6 +21,7 @@ from scipy.ndimage.filters import gaussian_filter from mindspore import log import mindspore as ms +from mindspore.train._utils import check_value_type import mindspore.nn as nn import mindspore.ops.operations as op from .metric import AttributionMetric @@ -140,9 +141,7 @@ class Perturb: @staticmethod def _assign(x: _Array, y: _Array, masks: _Array): """Assign values to perturb pixels on perturbations.""" - if masks.dtype != bool: - raise TypeError('The param "masks" should be an array of bool, but receive {}' - .format(masks.dtype)) + check_value_type("masks dtype", masks.dtype, type(np.dtype(bool))) for i in range(x.shape[0]): x[i][:, masks[i]] = y[:, masks[i]] @@ -336,8 +335,7 @@ class NaiveFaithfulness(_FaithfulnessHelper): if not np.count_nonzero(saliency): log.warning("The saliency map is zero everywhere. 
@@ -600,7 +623,7 @@ class ExplainRunner:
     def _save_original_image(self, sample_id: int, image):
         """Save an image to summary directory."""
         id_dirname = _get_id_dirname(sample_id)
-        relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
+        relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
                                     _ORIGINAL_IMAGE_DIRNAME,
                                     id_dirname)
         os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
@@ -613,7 +636,7 @@ class ExplainRunner:
     def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image):
         """Save heatmap image to summary directory."""
         id_dirname = _get_id_dirname(sample_id)
-        relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
+        relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
                                     _HEATMAP_DIRNAME,
                                     explain_method,
                                     id_dirname)
diff --git a/mindspore/explainer/benchmark/_attribution/faithfulness.py b/mindspore/explainer/benchmark/_attribution/faithfulness.py
index 9fa17c65ae..6a7dcfbf93 100644
--- a/mindspore/explainer/benchmark/_attribution/faithfulness.py
+++ b/mindspore/explainer/benchmark/_attribution/faithfulness.py
@@ -21,6 +21,7 @@ from scipy.ndimage.filters import gaussian_filter
 
 from mindspore import log
 import mindspore as ms
+from mindspore.train._utils import check_value_type
 import mindspore.nn as nn
 import mindspore.ops.operations as op
 from .metric import AttributionMetric
@@ -140,9 +141,7 @@ class Perturb:
     @staticmethod
     def _assign(x: _Array, y: _Array, masks: _Array):
         """Assign values to perturb pixels on perturbations."""
-        if masks.dtype != bool:
-            raise TypeError('The param "masks" should be an array of bool, but receive {}'
-                            .format(masks.dtype))
+        check_value_type("masks dtype", masks.dtype, type(np.dtype(bool)))
         for i in range(x.shape[0]):
             x[i][:, masks[i]] = y[:, masks[i]]
 
@@ -336,8 +335,7 @@ class NaiveFaithfulness(_FaithfulnessHelper):
         if not np.count_nonzero(saliency):
             log.warning("The saliency map is zero everywhere. The correlation will be set to zero.")
             correlation = 0
-            normalized_faithfulness = (correlation + 1) / 2
-            return np.array([normalized_faithfulness], np.float)
+            return np.array([correlation], np.float)
         reference = self._get_reference(inputs)
         perturbations, masks = self._perturb(
             inputs, saliency, reference, return_mask=True)
@@ -347,8 +345,7 @@ class NaiveFaithfulness(_FaithfulnessHelper):
         predictions = model(perturbations).asnumpy()[:, targets]
 
         faithfulness = calc_correlation(feature_importance, predictions)
-        normalized_faithfulness = (faithfulness + 1) / 2
-        return np.array([normalized_faithfulness], np.float)
+        return np.array([faithfulness], np.float)
 
 
 class DeletionAUC(_FaithfulnessHelper):
@@ -533,6 +530,8 @@ class Faithfulness(AttributionMetric):
         metric (str, optional): The specific metric to quantify faithfulness.
             Options: "DeletionAUC", "InsertionAUC", "NaiveFaithfulness".
             Default: 'NaiveFaithfulness'.
+        activation_fn (Cell, optional): The activation function that transforms the network output to a probability.
+            Default: nn.Softmax().
 
     Examples:
         >>> from mindspore.explainer.benchmark import Faithfulness
@@ -543,7 +542,7 @@ class Faithfulness(AttributionMetric):
     """
     _methods = [NaiveFaithfulness, DeletionAUC, InsertionAUC]
 
-    def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness"):
+    def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness", activation_fn=nn.Softmax()):
         super(Faithfulness, self).__init__(num_labels)
 
         perturb_percent = 0.5  # ratio of pixels to be perturbed, future argument
@@ -552,6 +551,7 @@ class Faithfulness(AttributionMetric):
         num_perturb_steps = 100  # separate the perturbation progress into 100 steps
         base_value = 0.0  # the pixel value set for the perturbed pixels
 
+        self._activation_fn = activation_fn
         self._verify_metrics(metric)
         for method in self._methods:
             if metric == method.__name__:
@@ -568,9 +568,7 @@ class Faithfulness(AttributionMetric):
         Evaluate faithfulness on a single data sample.
 
         Note:
-            To apply `Faithfulness` to evaluate an explainer, this explainer must be initialized with a network that
-            contains the output activation function. Otherwise, the results will not be correct. Currently only single
-            sample (:math:`N=1`) at each call is supported.
+            Currently only single sample (:math:`N=1`) at each call is supported.
 
         Args:
             explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`.
@@ -586,7 +584,7 @@ class Faithfulness(AttributionMetric):
 
         Examples:
-            >>> # init an explainer, the network should contain the output activation function.
-            >>> network = nn.SequentialCell([resnet50, nn.Sigmoid()])
+            >>> # init an explainer with a network that outputs logits (no output activation needed)
+            >>> network = resnet50(20)
             >>> gradient = Gradient(network)
             >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
             >>> targets = 5
@@ -610,10 +608,10 @@ class Faithfulness(AttributionMetric):
         saliency = saliency.squeeze()
         if len(saliency.shape) != 2:
             raise ValueError('Squeezed saliency map is expected to be 2D, but got {}D.'.format(len(saliency.shape)))
-
-        faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=explainer.model,
+        model = nn.SequentialCell([explainer.model, self._activation_fn])
+        faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=model,
                                                                    targets=targets, saliency=saliency)
-        return faithfulness
+        return (1 + faithfulness) / 2
 
     def _verify_metrics(self, metric: str):
         supports = [x.__name__ for x in self._methods]
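The normalization that used to live inside `NaiveFaithfulness.calc_faithfulness` now happens once in `Faithfulness.evaluate`: the helper returns the raw correlation in [-1, 1], and the public metric maps it to [0, 1]. A quick numeric check of the mapping:

    # correlation f -> reported score (1 + f) / 2
    for f in (-1.0, 0.0, 0.62):
        print(f, (1 + f) / 2)   # -1.0 -> 0.0, 0.0 -> 0.5, 0.62 -> 0.81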
diff --git a/mindspore/explainer/explanation/__init__.py b/mindspore/explainer/explanation/__init__.py
index 239fb00e45..f2dc4b464b 100644
--- a/mindspore/explainer/explanation/__init__.py
+++ b/mindspore/explainer/explanation/__init__.py
@@ -17,10 +17,12 @@
 from ._attribution._backprop.gradcam import GradCAM
 from ._attribution._backprop.gradient import Gradient
 from ._attribution._backprop.modified_relu import Deconvolution, GuidedBackprop
+from ._attribution._perturbation.rise import RISE
 
 __all__ = [
     'Gradient',
     'Deconvolution',
     'GuidedBackprop',
     'GradCAM',
+    'RISE'
 ]
diff --git a/mindspore/explainer/explanation/_attribution/__init__.py b/mindspore/explainer/explanation/_attribution/__init__.py
index 1f4ddaf1b1..0c1348bb36 100644
--- a/mindspore/explainer/explanation/_attribution/__init__.py
+++ b/mindspore/explainer/explanation/_attribution/__init__.py
@@ -16,10 +16,12 @@
 from ._backprop.gradcam import GradCAM
 from ._backprop.gradient import Gradient
 from ._backprop.modified_relu import Deconvolution, GuidedBackprop
+from ._perturbation.rise import RISE
 
 __all__ = [
     'Gradient',
     'Deconvolution',
     'GuidedBackprop',
     'GradCAM',
+    'RISE'
 ]
diff --git a/mindspore/explainer/explanation/_attribution/_attribution.py b/mindspore/explainer/explanation/_attribution/_attribution.py
index 78e4131103..00c5072a89 100644
--- a/mindspore/explainer/explanation/_attribution/_attribution.py
+++ b/mindspore/explainer/explanation/_attribution/_attribution.py
@@ -16,33 +16,25 @@
 
 from typing import Callable
 
-import mindspore as ms
+from mindspore.train._utils import check_value_type
+from mindspore.nn import Cell
 
 
 class Attribution:
-    r"""
+    """
     Base class for attributing saliency scores.
 
-    The explainers which explanation through attributing the relevance scores
-    should inherit this class.
+    Explainers that generate explanations by attributing relevance scores should inherit from this class.
 
     Args:
-        network (ms.nn.Cell): The black-box model to explanation.
+        network (Cell): The black-box model to explain.
     """
 
     def __init__(self, network):
-        self._verify_model(network)
+        check_value_type("network", network, Cell)
         self._model = network
         self._model.set_train(False)
         self._model.set_grad(False)
 
-    @staticmethod
-    def _verify_model(model):
-        """
-        Verify the input `network` for __init__ function.
-        """
-        if not isinstance(model, ms.nn.Cell):
-            raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.")
-
     __call__: Callable
     """
- """ - if not isinstance(model, ms.nn.Cell): - raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.") - __call__: Callable """ diff --git a/mindspore/explainer/explanation/_attribution/_backprop/gradcam.py b/mindspore/explainer/explanation/_attribution/_backprop/gradcam.py index ecf07d78e0..155abe5869 100644 --- a/mindspore/explainer/explanation/_attribution/_backprop/gradcam.py +++ b/mindspore/explainer/explanation/_attribution/_backprop/gradcam.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ -""" GradCAM and GuidedGradCAM. """ +"""GradCAM.""" from mindspore.ops import operations as op @@ -98,7 +98,7 @@ class GradCAM(IntermediateLayerAttribution): """ Hook function to deal with the backward gradient. - The arguments are set as required by Cell.register_back_hook + The arguments are set as required by `Cell.register_backward_hook`. """ self._intermediate_grad = grad_input diff --git a/mindspore/explainer/explanation/_attribution/_perturbation/__init__.py b/mindspore/explainer/explanation/_attribution/_perturbation/__init__.py new file mode 100644 index 0000000000..62720b4e86 --- /dev/null +++ b/mindspore/explainer/explanation/_attribution/_perturbation/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" Perturbation-based _attribution explainer. """ + +from .rise import RISE + +__all__ = ['RISE'] diff --git a/mindspore/explainer/explanation/_attribution/_perturbation/perturbation.py b/mindspore/explainer/explanation/_attribution/_perturbation/perturbation.py new file mode 100644 index 0000000000..204326de48 --- /dev/null +++ b/mindspore/explainer/explanation/_attribution/_perturbation/perturbation.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Base class `PerturbationAttribtuion`""" + +from mindspore.train._utils import check_value_type +from mindspore.nn import Cell + +from .._attribution import Attribution +from ...._operators import softmax + + +class PerturbationAttribution(Attribution): + """ + Base class for perturbation-based attribution methods. + + All perturbation-based _attribution methods extend from this class. 
+ """ + + def __init__(self, + network, + activation_fn=softmax(), + ): + super(PerturbationAttribution, self).__init__(network) + check_value_type("activation_fn", activation_fn, Cell) + self._activation_fn = activation_fn diff --git a/mindspore/explainer/explanation/_attribution/_perturbation/rise.py b/mindspore/explainer/explanation/_attribution/_perturbation/rise.py new file mode 100644 index 0000000000..1ab4640557 --- /dev/null +++ b/mindspore/explainer/explanation/_attribution/_perturbation/rise.py @@ -0,0 +1,194 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""RISE.""" +import math + +import numpy as np + +from mindspore import Tensor +from mindspore import nn +from mindspore.train._utils import check_value_type + +from .perturbation import PerturbationAttribution +from .... import _operators as op +from ...._utils import resize + + +class RISE(PerturbationAttribution): + r""" + RISE: Randomized Input Sampling for Explanation of Black-box Model. + + RISE is a perturbation-based method that generates attribution maps by sampling on multiple random binary masks. + The original image is randomly masked, and then fed into the black-box model to get predictions. The final + attribution map is the weighted sum of these random masks, with the weights being the corresponding output on the + node of interest: + + .. math:: + E_{RISE}(I, f)_c = \sum_{i}f_c(I\odot M_i) M_i + + For more details, please refer to the original paper via: `RISE `_. + + Args: + network (Cell): The black-box model to be explained. + activation_fn (Cell, optional): The activation layer that transforms logits to prediction probabilities. For + single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks, + `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long as + when combining this function with network, the final output is the probability of the input. + Default: `nn.Softmax`. + perturbation_per_eval (int, optional): Number of perturbations for each inference during inferring the + perturbed samples. Default: 32. 
+    def _generate_masks(self, data, batch_size):
+        """Generate a batch of binary masks for data."""
+
+        height, width = data.shape[2], data.shape[3]
+
+        mask_size = (self._down_sample_size, self._down_sample_size)
+
+        up_size = (height + mask_size[0], width + mask_size[1])
+        mask = np.random.random((batch_size, 1) + mask_size) < self._mask_probability
+        upsample = resize(op.Tensor(mask, data.dtype), up_size,
+                          self._resize_mode)
+
+        # Pack operator is not available for GPU, thus transfer to numpy first
+        upsample_np = upsample.asnumpy()
+        masks_lst = []
+        for sample in upsample_np:
+            # randomly crop a (height, width) window so the mask grid is shifted per sample
+            shift_x = np.random.randint(0, mask_size[0] + 1)
+            shift_y = np.random.randint(0, mask_size[1] + 1)
+            masks_lst.append(sample[:, shift_x: shift_x + height, shift_y:shift_y + width])
+        masks = op.Tensor(np.array(masks_lst), data.dtype)
+        return masks
+
+    def __call__(self, inputs, targets):
+        """
+        Generates attribution maps for inputs.
+
+        Args:
+            inputs (Tensor): Input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
+            targets (int, Tensor): The labels of interest to be explained. When `targets` is an integer,
+                all of the inputs will generate attribution maps w.r.t. this integer. When `targets` is a tensor, it
+                should be of shape :math:`(N, ?)`, :math:`(N,)` or :math:`()`.
+
+        Returns:
+            Tensor, a 4D tensor of shape :math:`(N, ?, H, W)` or :math:`(N, 1, H, W)`.
+
+        Examples:
+            >>> # given an instance of RISE, saliency maps can be generated
+            >>> inputs = ms.Tensor(np.random.rand(2, 3, 224, 224), ms.float32)
+            >>> # when `targets` is an integer
+            >>> targets = 5
+            >>> saliency = rise(inputs, targets)
+            >>> # `targets` can also be a tensor
+            >>> targets = ms.Tensor([[5], [1]])
+            >>> saliency = rise(inputs, targets)
+        """
+        self._verify_data(inputs, targets)
+        height, width = inputs.shape[2], inputs.shape[3]
+
+        batch_size = inputs.shape[0]
+
+        if self._num_classes is None:
+            logits = self.model(inputs)
+            num_classes = logits.shape[1]
+            self._num_classes = num_classes
+
+        # Due to the unsupported Op of slice assignment, we use numpy array here
+        attr_np = np.zeros(shape=(batch_size, self._num_classes, height, width))
+
+        cal_times = math.ceil(self._num_masks / self._perturbation_per_eval)
+
+        for idx, data in enumerate(inputs):
+            bg_data = data * 0 + self._base_value
+            for j in range(cal_times):
+                bs = min(self._num_masks - j * self._perturbation_per_eval,
+                         self._perturbation_per_eval)
+                data = op.reshape(data, (1, -1, height, width))
+                masks = self._generate_masks(data, bs)
+
+                masked_input = masks * data + (1 - masks) * bg_data
+                weights = self._activation_fn(self.model(masked_input))
+                while len(weights.shape) > 2:
+                    weights = op.mean(weights, axis=2)
+                weights = op.reshape(weights,
+                                     (bs, self._num_classes, 1, 1))
+
+                attr_np[idx] += op.summation(weights * masks, axis=0).asnumpy()
+
+        attr_np = attr_np / self._num_masks
+        targets = self._unify_targets(inputs, targets)
+
+        # keep dtype outside the loop so it is defined even for an empty batch
+        dtype = inputs.dtype
+        attr_classes = []
+        for idx, target in enumerate(targets):
+            attr_np_idx = attr_np[idx]
+            attr_idx = attr_np_idx[target]
+            attr_classes.append(attr_idx)
+
+        return op.Tensor(attr_classes, dtype=dtype)
+
+    @staticmethod
+    def _verify_data(inputs, targets):
+        """Verify the validity of the parsed inputs."""
+        check_value_type('inputs', inputs, Tensor)
+        if len(inputs.shape) != 4:
+            raise ValueError('Argument `inputs` must be a 4D Tensor.')
+        check_value_type('targets', targets, (Tensor, int, tuple, list))
+        if isinstance(targets, Tensor):
+            if len(targets.shape) > 2:
+                raise ValueError('Dimension invalid. If `targets` is a Tensor, it should be 0D, 1D or 2D. '
+                                 'But got {}D.'.format(len(targets.shape)))
+            if targets.shape and len(targets) != len(inputs):
+                raise ValueError(
+                    'If `targets` is a 1D or 2D Tensor, it should have the same length as `inputs` {}. '
+                    'But got {}.'.format(len(inputs), len(targets)))
+
+    @staticmethod
+    def _unify_targets(inputs, targets):
+        """Unify `targets` to a 2D numpy.ndarray."""
+        if isinstance(targets, int):
+            return np.array([[targets] for _ in inputs]).astype(np.int)
+        if isinstance(targets, Tensor):
+            if not targets.shape:
+                return np.array([[targets.asnumpy()] for _ in inputs]).astype(np.int)
+            if len(targets.shape) == 1:
+                return np.array([[t.asnumpy()] for t in targets]).astype(np.int)
+            if len(targets.shape) == 2:
+                return np.array([t.asnumpy() for t in targets]).astype(np.int)
+        return targets
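Taken together, the pieces of this patch are meant to be used roughly as follows. A hedged end-to-end sketch: `resnet50`, `classes`, `dataset`, and the checkpoint path are placeholders, and the top-level `ExplainRunner` re-export is assumed.

    import mindspore.nn as nn
    from mindspore import load_checkpoint, load_param_into_net
    from mindspore.explainer import ExplainRunner
    from mindspore.explainer.explanation import RISE, Gradient
    from mindspore.explainer.benchmark import Faithfulness

    net = resnet50(len(classes))                      # placeholder network and class list
    load_param_into_net(net, load_checkpoint("checkpoint.ckpt"))

    # explainers now receive the raw-logit network; activations are appended internally
    explainers = [RISE(net, activation_fn=nn.Softmax()), Gradient(net)]
    benchmarkers = [Faithfulness(num_labels=len(classes), activation_fn=nn.Softmax())]

    runner = ExplainRunner("./")
    runner.run((dataset, classes), explainers, benchmarkers, activation_fn=nn.Softmax())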