Merge pull request !7656 from lixiaohui33/feature_explain_coretags/v1.1.0
| @@ -118,19 +118,11 @@ message Explain { | |||
| } | |||
| message Benchmark{ | |||
| message TotalScore{ | |||
| optional string benchmark_method = 1; | |||
| optional float score = 2; | |||
| } | |||
| message LabelScore{ | |||
| repeated float score = 1; | |||
| optional string benchmark_method = 2; | |||
| } | |||
| optional string explain_method = 1; | |||
| repeated TotalScore total_score = 2; | |||
| repeated LabelScore label_score = 3; | |||
| } | |||
| optional string benchmark_method = 1; | |||
| optional string explain_method = 2; | |||
| optional float total_score = 3; | |||
| repeated float label_score = 4; | |||
| } | |||
| message Metadata{ | |||
| repeated string label = 1; | |||
| @@ -0,0 +1,19 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Provide ExplainRunner High-level API.""" | |||
| from ._runner import ExplainRunner | |||
| __all__ = ['ExplainRunner'] | |||
| @@ -0,0 +1,261 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Packaged operations based on MindSpore.""" | |||
| from typing import List, Tuple, Union, Callable | |||
| import numpy as np | |||
| import mindspore | |||
| from mindspore import nn | |||
| import mindspore.ops.operations as op | |||
| _Axis = Union[int, Tuple[int, ...], List[int]] | |||
| _Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]] | |||
| _Number = Union[int, float, np.integer, np.floating] | |||
| _Shape = Union[int, Tuple[int, ...]] | |||
| Tensor = mindspore.Tensor | |||
| __all__ = [ | |||
| 'absolute', | |||
| 'arange', | |||
| 'argmax', | |||
| 'argmin', | |||
| 'argsort', | |||
| 'assign', | |||
| 'intersection', | |||
| 'matmul', | |||
| 'maximum', | |||
| 'minimum', | |||
| 'mean', | |||
| 'mul', | |||
| 'sort', | |||
| 'squeeze', | |||
| 'tile', | |||
| 'reshape', | |||
| 'zeros', | |||
| 'zeros_like', | |||
| 'softmax', | |||
| 'Tensor', | |||
| 'summation' | |||
| ] | |||
| def absolute(inputs: Tensor) -> Tensor: | |||
| """Get the absolute value of a tensor value.""" | |||
| abs_op = op.Abs() | |||
| outputs = abs_op(inputs) | |||
| return outputs | |||
| def arange( | |||
| start: _Number, | |||
| end: _Number, | |||
| step: _Number = 1, | |||
| dtype: mindspore.dtype = None) -> Tensor: | |||
| """Get the arange value of tensor.""" | |||
| nums = np.arange(start=start, stop=end, step=step, dtype=np.int32) | |||
| nums = mindspore.Tensor(nums, dtype=dtype) | |||
| return nums | |||
| def argmax(inputs: Tensor, axis: int = -1, keep_dims: bool = False) -> Tensor: | |||
| """Returns the indices of the maximum values along an axis.""" | |||
| inputs_np = inputs.asnumpy() | |||
| outputs = np.argmax(inputs_np, axis=axis) | |||
| if keep_dims: | |||
| outputs = np.expand_dims(outputs, axis=axis) | |||
| return mindspore.Tensor(outputs, mindspore.int32) | |||
| def argmin(inputs: Tensor, axis: int = -1, keep_dims: bool = False) -> Tensor: | |||
| """Returns the indices of the minimum values along an axis.""" | |||
| inputs_np = inputs.asnumpy() | |||
| outputs = np.argmin(inputs_np, axis=axis) | |||
| if keep_dims: | |||
| outputs = np.expand_dims(outputs, axis=axis) | |||
| return mindspore.Tensor(outputs, mindspore.int32) | |||
| def argsort(inputs: Tensor, axis: int = -1, descending: bool = False) -> Tensor: | |||
| """Returns the indices that would sort an array.""" | |||
| inputs_np = inputs.asnumpy() | |||
| factor = -1 if descending else 1 | |||
| indices_np = np.argsort(factor * inputs_np, axis=axis) | |||
| indices = mindspore.Tensor(indices_np, dtype=mindspore.int32) | |||
| return indices | |||
| def assign(inputs: Tensor, idx: _Idx, value: Tensor) -> Tensor: | |||
| """Assign a tensor value to the given tensor and index.""" | |||
| inputs_np = inputs.asnumpy() | |||
| if isinstance(idx, Tensor): | |||
| idx = idx.asnumpy() | |||
| value_np = value.asnumpy() | |||
| inputs_np[idx] = value_np | |||
| outputs = mindspore.Tensor(inputs_np) | |||
| return outputs | |||
| def intersection(*inputs: Tensor) -> Tensor: | |||
| """Get the intersection value by the given tensor list.""" | |||
| outputs_np = np.ones_like(inputs[0]) | |||
| for inp in inputs: | |||
| outputs_np &= inp.asnumpy() | |||
| outputs = mindspore.Tensor(outputs_np) | |||
| return outputs | |||
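| # Note: `intersection` performs a bitwise AND on the underlying NumPy arrays, | |||
| # so the given tensors are expected to hold boolean (or integer) values. | |||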
| def matmul(inputs_x: Tensor, inputs_y: Tensor) -> Tensor: | |||
| """Multiplies matrix `inputs_x` and matrix `inputs_y`.""" | |||
| matmul_op = op.MatMul() | |||
| outputs = matmul_op(inputs_x, inputs_y) | |||
| return outputs | |||
| def maximum(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor: | |||
| """Reduce a dimension of a tensor by the maximum value in this dimension.""" | |||
| max_op = op.ReduceMax(keep_dims) | |||
| outputs = max_op(inputs, axis) | |||
| return outputs | |||
| def minimum(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor: | |||
| """Reduce a dimension of a tensor by the minimum value in the dimension.""" | |||
| max_op = op.ReduceMin(keep_dims) | |||
| outputs = max_op(inputs, axis) | |||
| return outputs | |||
| def mean(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor: | |||
| """Reduce a dimension of a tensor by averaging all elements in the dimension.""" | |||
| mean_op = op.ReduceMean(keep_dims) | |||
| outputs = mean_op(inputs, axis) | |||
| return outputs | |||
| def mul(inputs_x: Tensor, inputs_y: Tensor) -> Tensor: | |||
| """ | |||
| Multiplies two tensors element-wise. | |||
| Inputs of `input_x` and `input_y` comply with the implicit type conversion rules to make the data types consistent. | |||
| The inputs must be two tensors or one tensor and one scalar. | |||
| When the inputs are two tensors, | |||
| dtypes of them cannot be both bool, and the shapes of them could be broadcast. | |||
| When the inputs are one tensor and one scalar, | |||
| the scalar could only be a constant. | |||
| Inputs: | |||
| - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or | |||
| a bool or a tensor whose data type is number or bool. | |||
| - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or | |||
| a bool when the first input is a tensor or a tensor whose data type is number or bool. | |||
| Outputs: | |||
| Tensor, the shape is the same as the one after broadcasting, | |||
| and the data type is the one with higher precision or higher digits among the two inputs. | |||
| """ | |||
| mul_op = op.Mul() | |||
| outputs = mul_op(inputs_x, inputs_y) | |||
| return outputs | |||
| def sort(inputs: Tensor, axis: _Axis = -1, descending: bool = False) -> Tensor: | |||
| """Return a sorted copy of an array.""" | |||
| inputs_np = inputs.asnumpy() | |||
| outputs_np = np.sort(inputs_np, axis=axis) | |||
| if descending: | |||
| outputs_np = np.flip(outputs_np, axis=axis) | |||
| outputs = mindspore.Tensor(outputs_np) | |||
| return outputs | |||
| def squeeze(inputs: Tensor, axis: _Axis = ()): | |||
| """Returns a tensor with the same type but dimensions of 1 are removed based on `axis`.""" | |||
| squeeze_op = op.Squeeze(axis) | |||
| outputs = squeeze_op(inputs) | |||
| return outputs | |||
| def tile(inputs: Tensor, shape: Tuple[int, ...]) -> Tensor: | |||
| """Replicates a tensor with given multiples times.""" | |||
| tile_op = op.Tile() | |||
| outputs = tile_op(inputs, shape) | |||
| return outputs | |||
| def reshape(inputs: Tensor, shape: _Shape) -> Tensor: | |||
| """Reshapes input tensor with the same values based on a given shape tuple.""" | |||
| if isinstance(shape, int): | |||
| shape = (shape,) | |||
| return op.Reshape()(inputs, shape) | |||
| def zeros(shape: _Shape, dtype: mindspore.dtype = None) -> Tensor: | |||
| """Return a new array of given shape and type, filled with zeros.""" | |||
| outputs = np.zeros(shape) | |||
| return mindspore.Tensor(outputs, dtype=dtype) | |||
| def zeros_like(inputs: Tensor, dtype: mindspore.dtype = None) -> Tensor: | |||
| """Return an array of zeros with the same shape and type as a given array.""" | |||
| inputs_np = inputs.asnumpy() | |||
| outputs_np = np.zeros_like(inputs_np) | |||
| outputs = mindspore.Tensor(outputs_np, dtype) | |||
| return outputs | |||
| def random(shape: _Shape, dtype: mindspore.dtype = None) -> Tensor: | |||
| """Return random floats in the half-open interval [0.0, 1.0).""" | |||
| outputs_np = np.random.random(shape) | |||
| outputs = mindspore.Tensor(outputs_np, dtype) | |||
| return outputs | |||
| def randint(low: int, high: int, shape: _Shape, dtype: mindspore.dtype = mindspore.int8) -> Tensor: | |||
| """Return random integers from `low` (inclusive) to `high` (exclusive).""" | |||
| outputs_np = np.random.randint(low, high, size=shape) | |||
| outputs = mindspore.Tensor(outputs_np, dtype=dtype) | |||
| return outputs | |||
| def softmax(axis: int) -> Callable: | |||
| """Softmax activation function.""" | |||
| func = nn.Softmax(axis=axis) | |||
| return func | |||
| def summation(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor: | |||
| """Reduce a dimension of a tensor by summing all elements in the dimension.""" | |||
| sum_op = op.ReduceSum(keep_dims) | |||
| outputs = sum_op(inputs, axis) | |||
| return outputs | |||
| def stack(inputs: List[Tensor], axis: int) -> Tensor: | |||
| """Packs a list of tensors in specified axis.""" | |||
| pack_op = op.Pack(axis) | |||
| outputs = pack_op(inputs) | |||
| return outputs | |||
| def sqrt(inputs: Tensor) -> Tensor: | |||
| """Returns square root of a tensor element-wise.""" | |||
| sqrt_op = op.Sqrt() | |||
| return sqrt_op(inputs) | |||
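| # A minimal usage sketch of the wrappers above (illustrative only; the tensor | |||
| # values and shapes here are made up): | |||
| # | |||
| # x = Tensor(np.array([[1., -2.], [3., -4.]], np.float32)) | |||
| # absolute(x)                  # element-wise absolute value | |||
| # summation(x, axis=1)         # row sums, shape (2,) | |||
| # argsort(x, descending=True)  # per-row indices in descending order | |||
| # reshape(x, (4,))             # flatten to a 1-D tensor | |||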
| @@ -0,0 +1,481 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Runner.""" | |||
| from time import time | |||
| from typing import Tuple, List, Optional | |||
| import numpy as np | |||
| from mindspore.train.summary_pb2 import Explain | |||
| import mindspore as ms | |||
| import mindspore.dataset as ds | |||
| from mindspore import log | |||
| from mindspore.ops.operations import ExpandDims | |||
| from mindspore.train.summary._summary_adapter import _convert_image_format, _make_image | |||
| from mindspore.train.summary.summary_record import SummaryRecord | |||
| from .benchmark import Localization | |||
| from .benchmark._attribution.metric import AttributionMetric | |||
| from .explanation._attribution._attribution import Attribution | |||
| _EXPAND_DIMS = ExpandDims() | |||
| _CMAP_0 = np.reshape(np.array([55, 25, 86, 255]), [1, 1, 4]) / 255 | |||
| _CMAP_1 = np.reshape(np.array([255, 255, 0, 255]), [1, 1, 4]) / 255 | |||
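| # Note: _CMAP_0 and _CMAP_1 are the endpoints of a simple two-color map, dark | |||
| # purple for low saliency and yellow for high saliency; _make_rgba below blends | |||
| # linearly between them and reuses the normalized saliency as the alpha channel. | |||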
| def _normalize(img_np): | |||
| """Normalize the image in the numpy array to be in [0, 255]. """ | |||
| max_ = img_np.max() | |||
| min_ = img_np.min() | |||
| normed = (img_np - min_) / (max_ - min_).clip(min=1e-10) | |||
| return (normed * 255).astype(np.uint8) | |||
| def _make_rgba(saliency): | |||
| """Make rgba image for saliency map.""" | |||
| saliency = saliency.asnumpy().squeeze() | |||
| saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min()).clip(1e-10) | |||
| rgba = np.empty((saliency.shape[0], saliency.shape[1], 4)) | |||
| rgba[:, :, :] = np.expand_dims(saliency, 2) | |||
| rgba = rgba * _CMAP_1 + (1 - rgba) * _CMAP_0 | |||
| rgba[:, :, -1] = saliency * 1 | |||
| return rgba | |||
| class ExplainRunner: | |||
| """ | |||
| High-level API for users to generate results with the explanation methods and the evaluation methods. | |||
| After generating results with the explanation methods and the evaluation methods, the results will be | |||
| written into a specified file with `mindspore.train.summary.SummaryRecord`. The stored content can be | |||
| viewed using MindInsight. | |||
| Args: | |||
| summary_dir (str): The directory path to save the summary files which store the generated results. | |||
| Default: "./" | |||
| Examples: | |||
| >>> # init a runner with a specified directory | |||
| >>> summary_dir = "summary_dir" | |||
| >>> runner = ExplainRunner(summary_dir) | |||
| """ | |||
| def __init__(self, summary_dir: Optional[str] = "./"): | |||
| self._summary_dir = summary_dir | |||
| self._count = 0 | |||
| self._classes = None | |||
| self._model = None | |||
| def run(self, | |||
| dataset: Tuple, | |||
| explainers: List, | |||
| benchmarkers: Optional[List] = None): | |||
| """ | |||
| Generate results and write them into the summary files in `self.summary_dir`. | |||
| Args: | |||
| dataset (tuple): A tuple that contains `mindspore.dataset` object for iteration and its labels. | |||
| - dataset[0], a `mindspore.dataset` object to provide data to explain. | |||
| - dataset[1], a list of string that specifies the label names of the dataset. | |||
| explainers (list): A list of explanation objects to generate _attribution results. | |||
| benchmarkers (list): A list of benchmark objects to generate evaluation results. Default: None | |||
| Examples: | |||
| >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient | |||
| >>> # obtain dataset object | |||
| >>> dataset = get_dataset() | |||
| >>> classes = ["cat", "dog", ...] | |||
| >>> # load checkpoint to a network, e.g. resnet50 | |||
| >>> param_dict = load_checkpoint("checkpoint.ckpt") | |||
| >>> net = resnet50(len(classes)) | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> # bind net with its output activation | |||
| >>> model = nn.SequentialCell([net, nn.Sigmoid()]) | |||
| >>> gbp = GuidedBackprop(model) | |||
| >>> gradient = Gradient(model) | |||
| >>> runner = ExplainRunner("./") | |||
| >>> explainers = [gbp, gradient] | |||
| >>> runner.run((dataset, classes), explainers) | |||
| """ | |||
| if not isinstance(dataset, tuple): | |||
| raise TypeError("Argument `dataset` must be a tuple.") | |||
| if len(dataset) != 2: | |||
| raise ValueError("Argument `dataset` should be a tuple with length = 2.") | |||
| dataset, classes = dataset | |||
| self._verify_data_form(dataset, benchmarkers) | |||
| self._classes = classes | |||
| if explainers is None or not explainers: | |||
| raise ValueError("Argument `explainers` can neither be None nor empty.") | |||
| if not isinstance(explainers, list): | |||
| raise TypeError("Argument `explainers` should be a list of objects of classes in " | |||
| "`mindspore.explainer.explanation._attribution`.") | |||
| for exp in explainers: | |||
| if not isinstance(exp, Attribution): | |||
| raise TypeError("Argument `explainers` should be a list of objects of classes in " | |||
| "`mindspore.explainer.explanation._attribution`.") | |||
| if benchmarkers is not None: | |||
| if not isinstance(benchmarkers, list): | |||
| raise TypeError("Argument `benchmarkers` should be a list of objects of classes in " | |||
| "`mindspore.explainer.benchmark._attribution`.") | |||
| for bench in benchmarkers: | |||
| if not isinstance(bench, AttributionMetric): | |||
| raise TypeError("Argument `benchmarkers` should be a list of objects of classes in " | |||
| "`mindspore.explainer.benchmark._attribution`.") | |||
| self._model = explainers[0].model | |||
| with SummaryRecord(self._summary_dir) as summary: | |||
| print("Start running and writing......") | |||
| begin = time() | |||
| print("Start writing metadata.") | |||
| explain = Explain() | |||
| explain.metadata.label.extend(classes) | |||
| exp_names = [exp.__class__.__name__ for exp in explainers] | |||
| explain.metadata.explain_method.extend(exp_names) | |||
| if benchmarkers is not None: | |||
| bench_names = [bench.__class__.__name__ for bench in benchmarkers] | |||
| explain.metadata.benchmark_method.extend(bench_names) | |||
| summary.add_value("explainer", "metadata", explain) | |||
| summary.record(1) | |||
| print("Finish writing metadata.") | |||
| now = time() | |||
| print("Start running and writing inference data......") | |||
| imageid_labels = self._run_inference(dataset, summary) | |||
| print("Finish running and writing inference data. Time elapsed: {}s".format(time() - now)) | |||
| if benchmarkers is None: | |||
| for exp in explainers: | |||
| start = time() | |||
| print("Start running and writing explanation data for {}......".format(exp.__class__.__name__)) | |||
| self._count = 0 | |||
| ds.config.set_seed(58) | |||
| for idx, next_element in enumerate(dataset): | |||
| now = time() | |||
| self._run_exp_step(next_element, exp, imageid_labels, summary) | |||
| print("Finish writing {}-th explanation data. Time elapsed: {}".format( | |||
| idx, time() - now)) | |||
| print("Finish running and writing explanation data for {}. Time elapsed: {}".format( | |||
| exp.__class__.__name__, time() - start)) | |||
| else: | |||
| for exp in explainers: | |||
| explain = Explain() | |||
| for bench in benchmarkers: | |||
| bench.reset() | |||
| print(f"Start running and writing explanation and benchmark data for {exp.__class__.__name__}.") | |||
| self._count = 0 | |||
| start = time() | |||
| ds.config.set_seed(58) | |||
| for idx, next_element in enumerate(dataset): | |||
| now = time() | |||
| saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary) | |||
| print("Finish writing {}-th batch explanation data. Time elapsed: {}s".format( | |||
| idx, time() - now)) | |||
| for bench in benchmarkers: | |||
| now = time() | |||
| self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst) | |||
| print("Finish running {}-th batch benchmark data for {}. Time elapsed: {}s".format( | |||
| idx, bench.__class__.__name__, time() - now)) | |||
| for bench in benchmarkers: | |||
| benchmark = explain.benchmark.add() | |||
| benchmark.explain_method = exp.__class__.__name__ | |||
| benchmark.benchmark_method = bench.__class__.__name__ | |||
| benchmark.total_score = bench.performance | |||
| benchmark.label_score.extend(bench.class_performances) | |||
| print("Finish running and writing explanation and benchmark data for {}. " | |||
| "Time elapsed: {}s".format(exp.__class__.__name__, time() - start)) | |||
| summary.add_value('explainer', 'benchmark', explain) | |||
| summary.record(1) | |||
| print("Finish running and writing. Total time elapsed: {}s".format(time() - begin)) | |||
| @staticmethod | |||
| def _verify_data_form(dataset, benchmarkers): | |||
| """ | |||
| Verify the validity of dataset. | |||
| Args: | |||
| dataset (`ds`): the user-parsed dataset. | |||
| benchmarkers (list[`AttributionMetric`]): the user-parsed benchmarkers. | |||
| """ | |||
| next_element = dataset.create_tuple_iterator().get_next() | |||
| if len(next_element) not in [1, 2, 3]: | |||
| raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]" | |||
| " as columns.") | |||
| if len(next_element) == 3: | |||
| inputs, labels, bboxes = next_element | |||
| if bboxes.shape[-1] != 4: | |||
| raise ValueError("The third element of dataset should be bounding boxes with shape of " | |||
| "[batch_size, num_ground_truth, 4].") | |||
| else: | |||
| if benchmarkers and any(isinstance(bench, Localization) for bench in benchmarkers): | |||
| raise ValueError("The dataset must provide bboxes if Localization is to be computed.") | |||
| if len(next_element) == 2: | |||
| inputs, labels = next_element | |||
| if len(next_element) == 1: | |||
| inputs = next_element[0] | |||
| if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]: | |||
| raise ValueError( | |||
| "Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format( | |||
| inputs.shape)) | |||
| if len(inputs.shape) == 3: | |||
| log.warning( | |||
| "Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th" | |||
| " dimension as batch data.".format(inputs.shape)) | |||
| if len(next_element) > 1: | |||
| if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1: | |||
| raise ValueError( | |||
| "Labels shape {} is unrecognizable: labels should not have more than two dimensions" | |||
| " with length greater than 1.".format(labels.shape)) | |||
| def _transform_data(self, inputs, labels, bboxes, ifbbox): | |||
| """ | |||
| Transform the data from one iteration of dataset to a unifying form for the follow-up operations. | |||
| Args: | |||
| inputs (Tensor): the image data | |||
| labels (Tensor): the labels | |||
| bboxes (Tensor): the bounding boxes data | |||
| ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t. | |||
| label id will be returned. If False, the returned bboxes are the parsed bboxes. | |||
| Returns: | |||
| inputs (Tensor): the image data, unified to a 4D Tensor. | |||
| labels (List[List[int]]): the ground truth labels. | |||
| bboxes (Union[List[Dict], None, Tensor]): the bounding boxes | |||
| """ | |||
| inputs = ms.Tensor(inputs, ms.float32) | |||
| if len(inputs.shape) == 3: | |||
| inputs = _EXPAND_DIMS(inputs, 0) | |||
| if isinstance(labels, ms.Tensor): | |||
| labels = ms.Tensor(labels, ms.int32) | |||
| labels = _EXPAND_DIMS(labels, 0) | |||
| if isinstance(bboxes, ms.Tensor): | |||
| bboxes = ms.Tensor(bboxes, ms.int32) | |||
| bboxes = _EXPAND_DIMS(bboxes, 0) | |||
| input_len = len(inputs) | |||
| if bboxes is not None and ifbbox: | |||
| bboxes = ms.Tensor(bboxes, ms.int32) | |||
| masks_lst = [] | |||
| labels = labels.asnumpy().reshape([input_len, -1]) | |||
| bboxes = bboxes.asnumpy().reshape([input_len, -1, 4]) | |||
| for idx, label in enumerate(labels): | |||
| height, width = inputs[idx].shape[-2], inputs[idx].shape[-1] | |||
| masks = {} | |||
| for j, label_item in enumerate(label): | |||
| target = int(label_item) | |||
| if -1 < target < len(self._classes): | |||
| if target not in masks: | |||
| mask = np.zeros((1, 1, height, width)) | |||
| else: | |||
| mask = masks[target] | |||
| x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int) | |||
| mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1 | |||
| masks[target] = mask | |||
| masks_lst.append(masks) | |||
| bboxes = masks_lst | |||
| labels = ms.Tensor(labels, ms.int32) | |||
| if len(labels.shape) == 1: | |||
| labels_lst = [[int(i)] for i in labels.asnumpy()] | |||
| else: | |||
| labels = labels.asnumpy().reshape([input_len, -1]) | |||
| labels_lst = [] | |||
| for item in labels: | |||
| labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._classes)))) | |||
| labels = labels_lst | |||
| return inputs, labels, bboxes | |||
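| # For example (hypothetical values): with ifbbox=True, one image with label 2 | |||
| # and bbox [x_min, y_min, x_len, y_len] = [10, 20, 30, 40] yields | |||
| # bboxes = [{2: mask}], where mask is a (1, 1, H, W) array that is 1 inside | |||
| # the box region and 0 elsewhere. | |||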
| def _unpack_next_element(self, next_element, ifbbox=False): | |||
| """ | |||
| Unpack a single iteration of dataset. | |||
| Args: | |||
| next_element (Tuple): a single element iterated from dataset object. | |||
| ifbbox (bool): whether to preprocess bboxes in self._transform_data. | |||
| Returns: | |||
| Tuple, a unified Tuple contains image_data, labels, and bounding boxes. | |||
| """ | |||
| if len(next_element) == 3: | |||
| inputs, labels, bboxes = next_element | |||
| elif len(next_element) == 2: | |||
| inputs, labels = next_element | |||
| bboxes = None | |||
| else: | |||
| inputs = next_element[0] | |||
| labels = [[] for x in inputs] | |||
| bboxes = None | |||
| inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox) | |||
| return inputs, labels, bboxes | |||
| @staticmethod | |||
| def _make_label_batch(labels): | |||
| """ | |||
| Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max | |||
| length of all the rows in labels. | |||
| Args: | |||
| labels (List[List]): the union labels of a data batch. | |||
| Returns: | |||
| 2D Tensor. | |||
| """ | |||
| max_len = max([len(l) for l in labels]) | |||
| batch_labels = np.zeros((len(labels), max_len)) | |||
| for idx, _ in enumerate(batch_labels): | |||
| length = len(labels[idx]) | |||
| batch_labels[idx, :length] = np.array(labels[idx]) | |||
| return ms.Tensor(batch_labels, ms.int32) | |||
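| # For example (hypothetical values): labels = [[1], [0, 2]] is zero-padded to | |||
| # the longest row and returned as the int32 Tensor [[1, 0], [0, 2]]. | |||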
| def _run_inference(self, dataset, summary, threshold=0.5): | |||
| """ | |||
| Run inference for the dataset and write the inference related data into summary. | |||
| Args: | |||
| dataset (`ds`): the parsed dataset | |||
| summary (`SummaryRecord`): the summary object to store the data | |||
| threshold (float): the threshold for prediction. | |||
| Returns: | |||
| imageid_labels (dict): a dict that maps image_id and the union of its ground truth and predicted labels. | |||
| """ | |||
| imageid_labels = {} | |||
| ds.config.set_seed(58) | |||
| self._count = 0 | |||
| for j, next_element in enumerate(dataset): | |||
| now = time() | |||
| inputs, labels, _ = self._unpack_next_element(next_element) | |||
| prob = self._model(inputs).asnumpy() | |||
| for idx, inp in enumerate(inputs): | |||
| gt_labels = labels[idx] | |||
| gt_probs = [float(prob[idx][i]) for i in gt_labels] | |||
| data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW') | |||
| _, _, _, image_string = _make_image(_normalize(data_np)) | |||
| predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]] | |||
| predicted_probs = [float(prob[idx][i]) for i in predicted_labels] | |||
| union_labs = list(set(gt_labels + predicted_labels)) | |||
| imageid_labels[str(self._count)] = union_labs | |||
| explain = Explain() | |||
| explain.image_id = str(self._count) | |||
| explain.image_data = image_string | |||
| summary.add_value("explainer", "image", explain) | |||
| explain = Explain() | |||
| explain.image_id = str(self._count) | |||
| explain.ground_truth_label.extend(gt_labels) | |||
| explain.inference.ground_truth_prob.extend(gt_probs) | |||
| explain.inference.predicted_label.extend(predicted_labels) | |||
| explain.inference.predicted_prob.extend(predicted_probs) | |||
| summary.add_value("explainer", "inference", explain) | |||
| summary.record(1) | |||
| self._count += 1 | |||
| print("Finish running and writing {}-th batch inference data. Time elapsed: {}s".format(j, time() - now)) | |||
| return imageid_labels | |||
| def _run_exp_step(self, next_element, explainer, imageid_labels, summary): | |||
| """ | |||
| Run the explanation for each step and write explanation results into summary. | |||
| Args: | |||
| next_element (Tuple): data of one step | |||
| explainer (_Attribution): an Attribution object to generate saliency maps. | |||
| imageid_labels (dict): a dict that maps the image_id and its union labels. | |||
| summary (SummaryRecord): the summary object to store the data | |||
| Returns: | |||
| List of dict that maps label to its corresponding saliency map. | |||
| """ | |||
| inputs, labels, _ = self._unpack_next_element(next_element) | |||
| count = self._count | |||
| unions = [] | |||
| for _ in range(len(labels)): | |||
| unions_labels = imageid_labels[str(count)] | |||
| unions.append(unions_labels) | |||
| count += 1 | |||
| batch_unions = self._make_label_batch(unions) | |||
| saliency_dict_lst = [] | |||
| batch_saliency_full = [] | |||
| for i in range(len(batch_unions[0])): | |||
| batch_saliency = explainer(inputs, batch_unions[:, i]) | |||
| batch_saliency_full.append(batch_saliency) | |||
| for idx, union in enumerate(unions): | |||
| saliency_dict = {} | |||
| explain = Explain() | |||
| explain.image_id = str(self._count) | |||
| for k, lab in enumerate(union): | |||
| saliency = batch_saliency_full[k][idx:idx + 1] | |||
| saliency_dict[lab] = saliency | |||
| saliency_np = _make_rgba(saliency) | |||
| _, _, _, saliency_string = _make_image(_normalize(saliency_np)) | |||
| explanation = explain.explanation.add() | |||
| explanation.explain_method = explainer.__class__.__name__ | |||
| explanation.label = lab | |||
| explanation.heatmap = saliency_string | |||
| summary.add_value("explainer", "explanation", explain) | |||
| summary.record(1) | |||
| self._count += 1 | |||
| saliency_dict_lst.append(saliency_dict) | |||
| return saliency_dict_lst | |||
| def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst): | |||
| """ | |||
| Run the explanation and evaluation for each step and write explanation results into summary. | |||
| Args: | |||
| next_element (Tuple): Data of one step. | |||
| explainer (`_Attribution`): An Attribution object to generate saliency maps. | |||
| benchmarker (`AttributionMetric`): A benchmark object to evaluate the explainer. | |||
| saliency_dict_lst (list): A list of dicts that map each label to its saliency map. | |||
| """ | |||
| inputs, labels, _ = self._unpack_next_element(next_element) | |||
| for idx, inp in enumerate(inputs): | |||
| inp = _EXPAND_DIMS(inp, 0) | |||
| saliency_dict = saliency_dict_lst[idx] | |||
| for label, saliency in saliency_dict.items(): | |||
| if isinstance(benchmarker, Localization): | |||
| _, _, bboxes = self._unpack_next_element(next_element, True) | |||
| if label in labels[idx]: | |||
| res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label], | |||
| saliency=saliency) | |||
| benchmarker.aggregate(res, label) | |||
| else: | |||
| res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) | |||
| benchmarker.aggregate(res, label) | |||
| @@ -0,0 +1,285 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Utils for MindExplain""" | |||
| __all__ = [ | |||
| 'ForwardProbe', | |||
| 'calc_auc', | |||
| 'calc_correlation', | |||
| 'format_tensor_to_ndarray', | |||
| 'generate_one_hot', | |||
| 'rank_pixels', | |||
| 'resize', | |||
| 'retrieve_layer_by_name', | |||
| 'retrieve_layer', | |||
| 'unify_inputs', | |||
| 'unify_targets' | |||
| ] | |||
| from typing import Tuple, Union | |||
| import numpy as np | |||
| from PIL import Image | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| import mindspore.ops.operations as op | |||
| _Array = np.ndarray | |||
| _Module = nn.Cell | |||
| _Tensor = ms.Tensor | |||
| def generate_one_hot(indices, depth): | |||
| r""" | |||
| Simple wrapper of the OneHot operation, the on_value and off_value are fixed to 1.0 | |||
| and 0.0. | |||
| """ | |||
| on_value = ms.Tensor(1.0, ms.float32) | |||
| off_value = ms.Tensor(0.0, ms.float32) | |||
| weights = op.OneHot()(indices, depth, on_value, off_value) | |||
| return weights | |||
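| # For example: generate_one_hot(ms.Tensor([1, 0], ms.int32), 3) returns the | |||
| # float32 tensor [[0., 1., 0.], [1., 0., 0.]]. | |||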
| def unify_inputs(inputs) -> tuple: | |||
| """Unify inputs of explainer.""" | |||
| if isinstance(inputs, tuple): | |||
| return inputs | |||
| if isinstance(inputs, ms.Tensor): | |||
| inputs = (inputs,) | |||
| elif isinstance(inputs, np.ndarray): | |||
| inputs = (ms.Tensor(inputs),) | |||
| else: | |||
| raise TypeError( | |||
| 'inputs must be one of [tuple, ms.Tensor or np.ndarray], ' | |||
| 'but get {}'.format(type(inputs))) | |||
| return inputs | |||
| def unify_targets(targets) -> ms.Tensor: | |||
| """Unify targets labels of explainer.""" | |||
| if isinstance(targets, ms.Tensor): | |||
| return targets | |||
| if isinstance(targets, list): | |||
| targets = ms.Tensor(targets, dtype=ms.int32) | |||
| elif isinstance(targets, int): | |||
| targets = ms.Tensor([targets], dtype=ms.int32) | |||
| else: | |||
| raise TypeError( | |||
| 'targets must be one of [int, list or ms.Tensor], ' | |||
| 'but get {}'.format(type(targets))) | |||
| return targets | |||
| def retrieve_layer_by_name(model: _Module, layer_name: str): | |||
| """ | |||
| Retrieve the layer in the model by the given layer_name. | |||
| Args: | |||
| model (_Module): model which contains the target layer | |||
| layer_name (str): name of target layer | |||
| Return: | |||
| - target_layer (_Module) | |||
| Raise: | |||
| ValueError: raised if a module with the given layer_name is not found | |||
| in the model. | |||
| """ | |||
| if not isinstance(layer_name, str): | |||
| raise TypeError('layer_name should be type of str, but receive {}.' | |||
| .format(type(layer_name))) | |||
| if not layer_name: | |||
| return model | |||
| for name, cell in model.cells_and_names(): | |||
| if name == layer_name: | |||
| return cell | |||
| raise ValueError( | |||
| 'Cannot match {}, please provide a target layer ' | |||
| 'in the given model.'.format(layer_name)) | |||
| def retrieve_layer(model: _Module, target_layer: Union[str, _Module] = ''): | |||
| """ | |||
| Retrieve the layer in the model. | |||
| 'target' can be either a layer name or a Cell object. Given the layer name, | |||
| the method will search through the model and return the matched layer. If a | |||
| Cell object is provided, it will check whether the given layer exists | |||
| in the model. If target layer is not found in the model, ValueError will | |||
| be raised. | |||
| Args: | |||
| model (_Module): the model to retrieve the target layer | |||
| target_layer (Union[str, _Module]): target layer to retrieve. Can be | |||
| either string (layer name) or the Cell object. If '' is provided, | |||
| the input model will be returned. | |||
| Return: | |||
| target layer (_Module) | |||
| """ | |||
| if isinstance(target_layer, str): | |||
| target_layer = retrieve_layer_by_name(model, target_layer) | |||
| return target_layer | |||
| if isinstance(target_layer, _Module): | |||
| for _, cell in model.cells_and_names(): | |||
| if target_layer is cell: | |||
| return target_layer | |||
| raise ValueError( | |||
| 'Model not contain cell {}, fail to probe.'.format(target_layer) | |||
| ) | |||
| raise TypeError('target_layer must have type of str or ms.nn.Cell, ' | |||
| 'but receive {}'.format(type(target_layer))) | |||
| class ForwardProbe: | |||
| """ | |||
| Probe to capture output of specific layer in a given model. | |||
| Args: | |||
| target_layer (_Module): name of target layer or just provide the | |||
| target layer. | |||
| """ | |||
| def __init__(self, target_layer: _Module): | |||
| self._target_layer = target_layer | |||
| self._original_construct = self._target_layer.construct | |||
| self._intermediate_tensor = None | |||
| @property | |||
| def value(self): | |||
| return self._intermediate_tensor | |||
| def __enter__(self): | |||
| self._target_layer.construct = self._new_construct | |||
| return self | |||
| def __exit__(self, *_): | |||
| self._target_layer.construct = self._original_construct | |||
| self._intermediate_tensor = None | |||
| return False | |||
| def _new_construct(self, *inputs): | |||
| outputs = self._original_construct(*inputs) | |||
| self._intermediate_tensor = outputs | |||
| return outputs | |||
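| # A minimal usage sketch (illustrative; `model`, `layer` and `inputs` are | |||
| # assumed to be a Cell, one of its sub-Cells and a suitable input Tensor). | |||
| # The probe temporarily swaps the layer's construct method, and the captured | |||
| # value is cleared on exit, so read it inside the block: | |||
| # | |||
| # with ForwardProbe(layer) as probe: | |||
| #     _ = model(inputs) | |||
| #     feature_map = probe.value | |||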
| def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray: | |||
| """Unify `mindspore.Tensor` and `np.ndarray` to `np.ndarray`. """ | |||
| if isinstance(x, ms.Tensor): | |||
| x = x.asnumpy() | |||
| if not isinstance(x, np.ndarray): | |||
| raise TypeError('input should be one of [ms.Tensor or np.ndarray],' | |||
| ' but receive {}'.format(type(x))) | |||
| return x | |||
| def calc_correlation(x: Union[ms.Tensor, np.ndarray], | |||
| y: Union[ms.Tensor, np.ndarray]) -> float: | |||
| """Calculate Pearson correlation coefficient between two arrays. """ | |||
| x = format_tensor_to_ndarray(x) | |||
| y = format_tensor_to_ndarray(y) | |||
| faithfulness = -np.corrcoef(x, y)[0, 1] | |||
| return faithfulness | |||
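| # Note: the negation makes the score positive when pixels with higher feature | |||
| # importance lead to lower predictions once perturbed, which is the behaviour | |||
| # a faithful saliency map should show. | |||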
| def calc_auc(x: _Array) -> float: | |||
| """Calculate the Aera under Curve.""" | |||
| # take the mean over the patch dimensions if the model is a fully convolutional model | |||
| if len(x.shape) == 4: | |||
| x = np.mean(x, axis=(2, 3)) | |||
| auc = (x.sum() - x[0] - x[-1]) / len(x) | |||
| return float(auc) | |||
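| # Worked example (hypothetical values): for x = np.array([1.0, 0.8, 0.4, 0.2]), | |||
| # auc = (2.4 - 1.0 - 0.2) / 4 = 0.3, a discrete approximation of the area under | |||
| # the probability-variation curve with both endpoints dropped. | |||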
| def rank_pixels(inputs: _Array, descending: bool = True) -> _Array: | |||
| """ | |||
| Generate the rank order for every pixel in a 2D array. | |||
| The rank order starts from 0 and ends at (num_pixel - 1). If descending is True, | |||
| the rank order will be generated in descending order, otherwise in ascending | |||
| order. | |||
| Example: | |||
| x = np.array([[4., 3., 1.], [5., 9., 1.]]) | |||
| rank_pixels(x, descending=True) | |||
| >> np.array([[2, 3, 4], [1, 0, 5]]) | |||
| rank_pixels(x, descending=False) | |||
| >> np.array([[3, 2, 0], [4, 5, 1]]) | |||
| """ | |||
| if len(inputs.shape) != 2: | |||
| raise ValueError('Only support 2D array currently') | |||
| flatten_saliency = inputs.reshape(-1) | |||
| factor = -1 if descending else 1 | |||
| sorted_arg = np.argsort(factor * flatten_saliency, axis=0) | |||
| flatten_rank = np.zeros_like(sorted_arg) | |||
| flatten_rank[sorted_arg] = np.arange(0, sorted_arg.shape[0]) | |||
| rank_map = flatten_rank.reshape(inputs.shape) | |||
| return rank_map | |||
| def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor: | |||
| """ | |||
| Resize the intermediate layer _attribution to the same size as inputs. | |||
| Args: | |||
| inputs (ms.Tensor): the input tensor to be resized | |||
| size (Tuple[int, int]): the target size to resize to | |||
| mode (str): the resize mode. Options: 'nearest_neighbor', 'bilinear' | |||
| Returns: | |||
| outputs (ms.Tensor): the resized tensor. | |||
| Raises: | |||
| ValueError: the resize mode is not in ['nearest_neighbor', | |||
| 'bilinear']. | |||
| """ | |||
| h, w = size | |||
| if mode == 'nearest_neighbor': | |||
| resize_nn = op.ResizeNearestNeighbor((h, w)) | |||
| outputs = resize_nn(inputs) | |||
| elif mode == 'bilinear': | |||
| inputs_np = inputs.asnumpy() | |||
| inputs_np = np.transpose(inputs_np, [0, 2, 3, 1]) | |||
| array_lst = [] | |||
| for inp in inputs_np: | |||
| array = (np.repeat(inp, 3, axis=2) * 255).astype(np.uint8) | |||
| image = Image.fromarray(array) | |||
| image = image.resize(size, resample=Image.BILINEAR) | |||
| array = np.asarray(image).astype(np.float32) / 255 | |||
| array_lst.append(array[:, :, 0:1]) | |||
| resized_np = np.transpose(array_lst, [0, 3, 1, 2]) | |||
| outputs = ms.Tensor(resized_np, inputs.dtype) | |||
| else: | |||
| raise ValueError('Unsupported resize mode {}'.format(mode)) | |||
| return outputs | |||
| @@ -0,0 +1,23 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Predefined XAI metrics.""" | |||
| from ._attribution.faithfulness import Faithfulness | |||
| from ._attribution.localization import Localization | |||
| __all__ = [ | |||
| "Faithfulness", | |||
| "Localization" | |||
| ] | |||
| @@ -0,0 +1,23 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Predefined XAI metrics""" | |||
| from .faithfulness import Faithfulness | |||
| from .localization import Localization | |||
| __all__ = [ | |||
| "Faithfulness", | |||
| "Localization" | |||
| ] | |||
| @@ -0,0 +1,593 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Faithfulness""" | |||
| import math | |||
| from typing import Callable, Optional, Union, Tuple | |||
| import numpy as np | |||
| from scipy.ndimage.filters import gaussian_filter | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| import mindspore.ops.operations as op | |||
| from .metric import AttributionMetric | |||
| from ..._utils import calc_correlation, calc_auc, format_tensor_to_ndarray, rank_pixels | |||
| from ...explanation._attribution._attribution import Attribution as _Attribution | |||
| _Array = np.ndarray | |||
| _Explainer = Union[_Attribution, Callable] | |||
| _Label = Union[int, ms.Tensor] | |||
| _Module = nn.Cell | |||
| def _calc_feature_importance(saliency: _Array, masks: _Array) -> _Array: | |||
| """Calculate feature important w.r.t given masks.""" | |||
| feature_importance = [] | |||
| num_perturbations = masks.shape[0] | |||
| for i in range(num_perturbations): | |||
| patch_feature_importance = saliency[masks[i]].sum() / masks[i].sum() | |||
| feature_importance.append(patch_feature_importance) | |||
| feature_importance = np.array(feature_importance, dtype=np.float32) | |||
| return feature_importance | |||
| class _BaseReplacement: | |||
| """ | |||
| Base class of generator for generating different replacement for perturbations. | |||
| Args: | |||
| kwargs: Optional args for generating replacement. Derived classes need to | |||
| add the necessary arg names and default values to '_necessary_args'. | |||
| If an argument has no default value, the value should be set to | |||
| 'EMPTY' to mark it as required. Initializing an object will | |||
| check the given kwargs w.r.t. '_necessary_args'. | |||
| Raise: | |||
| ValueError: Raised when the provided kwargs do not contain a necessary arg name marked with 'EMPTY'. | |||
| """ | |||
| _necessary_args = {} | |||
| def __init__(self, **kwargs): | |||
| self._replace_args = self._necessary_args.copy() | |||
| for key, value in self._replace_args.items(): | |||
| if key in kwargs.keys(): | |||
| self._replace_args[key] = kwargs[key] | |||
| elif key not in kwargs.keys() and value == 'EMPTY': | |||
| raise ValueError(f"Missing keyword arg {key} for {self.__class__.__name__}.") | |||
| __call__: Callable | |||
| """ | |||
| Generate replacement for perturbations. Derived classes should override this | |||
| function to generate different replacements for perturbing. | |||
| Args: | |||
| inputs (_Array): Array to be perturb. | |||
| Returns: | |||
| - replacement (_Array): Array that provides alternative pixels for every | |||
| position in the given inputs. The returned array should have the same | |||
| shape as inputs. | |||
| """ | |||
| class Constant(_BaseReplacement): | |||
| """ Generator to provide constant-value replacement for perturbations """ | |||
| _necessary_args = {'base_value': 'EMPTY'} | |||
| def __call__(self, inputs: _Array) -> _Array: | |||
| replacement = np.ones_like(inputs, dtype=np.float32) | |||
| replacement *= self._replace_args['base_value'] | |||
| return replacement | |||
| class GaussianBlur(_BaseReplacement): | |||
| """ Generator to provided gaussian blurred inputs for perturbation. """ | |||
| _necessary_args = {'sigma': 0.7} | |||
| def __call__(self, inputs: _Array) -> _Array: | |||
| sigma = self._replace_args['sigma'] | |||
| replacement = gaussian_filter(inputs, sigma=sigma) | |||
| return replacement | |||
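| # Sketch of a custom replacement generator (hypothetical, for illustration): | |||
| # a subclass only declares `_necessary_args` and implements __call__ returning | |||
| # an array of the same shape as `inputs`. | |||
| # | |||
| # class UniformNoise(_BaseReplacement): | |||
| #     _necessary_args = {'low': 0.0, 'high': 'EMPTY'}  # 'high' is required | |||
| # | |||
| #     def __call__(self, inputs: _Array) -> _Array: | |||
| #         return np.random.uniform(self._replace_args['low'], | |||
| #                                  self._replace_args['high'], | |||
| #                                  size=inputs.shape).astype(np.float32) | |||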
| class Perturb: | |||
| """ | |||
| Perturbation generator to generate perturbations for a given array. | |||
| Args: | |||
| perturb_percent (float): percentage of pixels to perturb | |||
| perturb_mode (str): specify perturbing mode, through deleting or | |||
| inserting pixels. Current support: ['Deletion', 'Insertion']. | |||
| is_accumulate (bool): whether to accumulate the former perturbations to | |||
| the later perturbations. | |||
| perturb_pixel_per_step (int, optional): number of pixels to perturb | |||
| for each perturbation. If perturb_pixel_per_step is None, the actual | |||
| perturb_pixel_per_step will be calculated as: | |||
| num_image_pixel * perturb_percent / num_perturb_steps. | |||
| Default: None | |||
| num_perturbations (int, optional): number of perturbations. If | |||
| num_perturbations is None, it will be calculated as: | |||
| num_image_pixel * perturb_percent / perturb_pixel_per_step. | |||
| Default: None | |||
| """ | |||
| def __init__(self, | |||
| perturb_percent: float, | |||
| perturb_mode: str, | |||
| is_accumulate: bool, | |||
| perturb_pixel_per_step: Optional[int] = None, | |||
| num_perturbations: Optional[int] = None): | |||
| self._perturb_percent = perturb_percent | |||
| self._perturb_mode = perturb_mode | |||
| self._pixel_per_step = perturb_pixel_per_step | |||
| self._num_perturbations = num_perturbations | |||
| self._is_accumulate = is_accumulate | |||
| @staticmethod | |||
| def _assign(x: _Array, y: _Array, masks: _Array): | |||
| """Assign values to perturb pixels on perturbations.""" | |||
| if masks.dtype != bool: | |||
| raise TypeError('The param "masks" should be an array of bool, but receive {}' | |||
| .format(masks.dtype)) | |||
| for i in range(x.shape[0]): | |||
| x[i][:, masks[i]] = y[:, masks[i]] | |||
| def _generate_mask(self, saliency_rank: _Array) -> _Array: | |||
| """Generate mask for perturbations based on given saliency ranks.""" | |||
| if len(saliency_rank.shape) != 2: | |||
| raise ValueError(f'The param "saliency_rank" should be 2-dim, but receive {len(saliency_rank.shape)}.') | |||
| num_pixels = saliency_rank.shape[0] * saliency_rank.shape[1] | |||
| if self._pixel_per_step: | |||
| pixel_per_step = self._pixel_per_step | |||
| num_perturbations = math.floor( | |||
| num_pixels * self._perturb_percent / self._pixel_per_step) | |||
| elif self._num_perturbations: | |||
| pixel_per_step = math.floor( | |||
| num_pixels * self._perturb_percent / self._num_perturbations) | |||
| num_perturbations = self._num_perturbations | |||
| else: | |||
| raise ValueError("Must provide either pixel_per_step or num_perturbations.") | |||
| masks = np.zeros( | |||
| (num_perturbations, saliency_rank.shape[0], saliency_rank.shape[1]), | |||
| dtype=bool) | |||
| low_bound = 0 | |||
| up_bound = low_bound + pixel_per_step | |||
| factor = 0 if self._is_accumulate else 1 | |||
| for i in range(num_perturbations): | |||
| masks[i, ((saliency_rank >= low_bound) | |||
| & (saliency_rank < up_bound))] = True | |||
| low_bound = up_bound * factor | |||
| up_bound += pixel_per_step | |||
| if len(masks.shape) == 3: | |||
| return masks | |||
| raise ValueError(f'Invalid masks shape {len(masks.shape)}, expect 3-dim.') | |||
| def __call__(self, | |||
| inputs: _Array, | |||
| saliency: _Array, | |||
| reference: _Array, | |||
| return_mask: bool = False, | |||
| ) -> Union[_Array, Tuple[_Array, ...]]: | |||
| """ | |||
| Generate perturbations of given array. | |||
| Args: | |||
| inputs (_Array): input array to perturb | |||
| saliency (_Array): saliency map | |||
| reference (_Array): an array of the same shape as inputs that provides | |||
| the replacement pixel values. | |||
| return_mask (bool): whether to return the masks used for generating | |||
| the perturbations. The masks can be used to calculate the | |||
| average feature importance of pixels perturbed at each step. | |||
| Return: | |||
| perturbations (_Array) | |||
| masks (_Array): return when return_mask is set to True. | |||
| """ | |||
| if not np.array_equal(inputs.shape, reference.shape): | |||
| raise ValueError('reference must have the same shape as inputs.') | |||
| saliency_rank = rank_pixels(saliency, descending=True) | |||
| masks = self._generate_mask(saliency_rank) | |||
| num_perturbations = masks.shape[0] | |||
| if self._perturb_mode == 'Insertion': | |||
| inputs, reference = reference, inputs | |||
| perturbations = np.tile( | |||
| inputs, (num_perturbations, *[1] * len(inputs.shape))) | |||
| Perturb._assign(perturbations, reference, masks) | |||
| if return_mask: | |||
| return perturbations, masks | |||
| return perturbations | |||
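| # Minimal usage sketch (illustrative; `image` is a (C, H, W) array and | |||
| # `saliency` its (H, W) saliency map): | |||
| # | |||
| # perturber = Perturb(perturb_percent=0.5, perturb_mode='Deletion', | |||
| #                     is_accumulate=True, perturb_pixel_per_step=100) | |||
| # reference = np.zeros_like(image)  # e.g. a constant baseline | |||
| # perturbed, masks = perturber(image, saliency, reference, return_mask=True) | |||
| # # perturbed: (num_perturbations, C, H, W); masks: (num_perturbations, H, W) | |||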
| class _FaithfulnessHelper: | |||
| """Base class for faithfulness calculator.""" | |||
| _support = [Constant, GaussianBlur] | |||
| def __init__(self, | |||
| perturb_percent: float, | |||
| perturb_mode: str, | |||
| perturb_method: str, | |||
| is_accumulate: bool, | |||
| perturb_pixel_per_step: Optional[int] = None, | |||
| num_perturbations: Optional[int] = None, | |||
| **kwargs): | |||
| self._get_reference = None | |||
| for method in self._support: | |||
| if perturb_method == method.__name__: | |||
| self._get_reference = method(**kwargs) | |||
| if self._get_reference is None: | |||
| raise ValueError( | |||
| 'The param "perturb_method" should be one of {}.'.format([x.__name__ for x in self._support])) | |||
| self._perturb = Perturb(perturb_percent=perturb_percent, | |||
| perturb_mode=perturb_mode, | |||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||
| num_perturbations=num_perturbations, | |||
| is_accumulate=is_accumulate) | |||
| calc_faithfulness: Callable | |||
| """ | |||
| Method used to calculate faithfulness for given inputs, target label, and | |||
| saliency. Derived classes should implement this method. | |||
| Args: | |||
| inputs (_Array): sample to calculate faithfulness score | |||
| model (_Module): model to explanation | |||
| targets (_Label): label to explanation on. | |||
| saliency (_Array): Saliency map of given inputs and targets from the | |||
| explainer. | |||
| Return: | |||
| - faithfulness (float): faithfulness score | |||
| """ | |||
| class NaiveFaithfulness(_FaithfulnessHelper): | |||
| """ | |||
| Calculator for naive faithfulness. | |||
| For naive faithfulness, the metric replaces several pixels of the original image | |||
| by a specific method for each perturbation. The metric predicts on the perturbed | |||
| images and records a series of probabilities. Then it calculates the | |||
| correlation between the probability distribution and the averaged feature importance. | |||
| A higher correlation indicates better faithfulness. | |||
| Args: | |||
| perturb_percent (float): percentage of pixels to perturb | |||
| perturb_method (str): specify the method to replace the pixel. | |||
| Current support: ['Constant', 'GaussianBlur'] | |||
| is_accumulate (bool): whether to accumulate the former perturbations to | |||
| the later perturbations. | |||
| Default: False. | |||
| perturb_pixel_per_step (Optional[int]): number of pixels to perturb | |||
| for each perturbation. If perturb_pixel_per_step is None, the actual | |||
| perturb_pixel_per_step will be calculated as: | |||
| num_image_pixel * perturb_percent / num_perturb_steps. | |||
| Default: None | |||
| num_perturbations (Optional[int]): number of perturbations. If | |||
| num_perturbations is None, it will be calculated as: | |||
| num_image_pixel * perturb_percent / perturb_pixel_per_step. | |||
| Default: None | |||
| kwargs: specific perturb_method will require | |||
| different arguments. Below lists required args for each method. | |||
| 'Constant': base_value (int) | |||
| 'GaussianBlur': sigma (float): 0.7 | |||
| """ | |||
| def __init__(self, | |||
| perturb_percent: float, | |||
| perturb_method: str, | |||
| is_accumulate: bool = False, | |||
| perturb_pixel_per_step: Optional[int] = None, | |||
| num_perturbations: Optional[int] = None, | |||
| **kwargs): | |||
| super(NaiveFaithfulness, self).__init__( | |||
| perturb_percent=perturb_percent, | |||
| perturb_mode='Deletion', | |||
| perturb_method=perturb_method, | |||
| is_accumulate=is_accumulate, | |||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||
| num_perturbations=num_perturbations, | |||
| **kwargs) | |||
| def calc_faithfulness(self, | |||
| inputs: _Array, | |||
| model: _Module, | |||
| targets: _Label, | |||
| saliency: _Array) -> np.ndarray: | |||
| """ | |||
| Calculate naive faithfulness. | |||
| Args: | |||
| inputs (_Array): sample to calculate faithfulness score | |||
| model (_Module): model to explanation | |||
| targets (_Label): label to explanation on. | |||
| saliency (_Array): Saliency map of given inputs and targets from the | |||
| explainer. | |||
| Return: | |||
| - faithfulness (np.ndarray): faithfulness score | |||
| """ | |||
| reference = self._get_reference(inputs) | |||
| perturbations, masks = self._perturb( | |||
| inputs, saliency, reference, return_mask=True) | |||
| feature_importance = _calc_feature_importance(saliency, masks) | |||
| perturbations = ms.Tensor(perturbations, dtype=ms.float32) | |||
| predictions = model(perturbations).asnumpy()[:, targets] | |||
| faithfulness = calc_correlation(feature_importance, predictions) | |||
| normalized_faithfulness = (faithfulness + 1) / 2 | |||
| return np.array([normalized_faithfulness], np.float32) | |||
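| # Usage sketch (illustrative; `net` is assumed to be a probability-output | |||
| # nn.Cell, `image_np` a (C, H, W) array and `saliency_np` its (H, W) map): | |||
| # | |||
| # helper = NaiveFaithfulness(perturb_percent=0.5, perturb_method='Constant', | |||
| #                            num_perturbations=32, base_value=0.0) | |||
| # score = helper.calc_faithfulness(image_np, net, targets=5, | |||
| #                                  saliency=saliency_np) | |||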
| class DeletionAUC(_FaithfulnessHelper): | |||
| """ Calculator for deletion AUC. | |||
| For deletion AUC, the metric accumulatively replaces pixels of the original | |||
| image through the specified 'perturb_method', predicts on the perturbed images | |||
| and records a series of probabilities. The metric then calculates the AUC of | |||
| the probability variation curve during perturbations. Faithfulness is defined | |||
| as (1 - deletion_AUC). A higher score indicates better faithfulness of the | |||
| explanation. | |||
| Args: | |||
| perturb_percent (float): percentage of pixels to perturb | |||
| perturb_method (str): specify the method to replace the pixel. | |||
| Current support: ['Constant', 'GaussianBlur'] | |||
| perturb_pixel_per_step (Optional[int]): number of pixels to perturb | |||
| for each perturbation. If perturb_pixel_per_step is None, the actual | |||
| perturb_pixel_per_step will be calculated as: | |||
| num_image_pixel * perturb_percent / num_perturb_steps. | |||
| Default: None | |||
| num_perturbations (Optional[int]): number of perturbations. If | |||
| num_perturbations is None, it will be calculated as: | |||
| num_image_pixel * perturb_percent / perturb_pixel_per_step. | |||
| Default: None | |||
| kwargs: each perturb_method requires | |||
| different arguments. The required arguments for each method are listed below. | |||
| 'Constant': base_value (int) | |||
| 'GaussianBlur': sigma (float): 0.7 | |||
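| Examples: | |||
| >>> # a minimal instantiation sketch; 'GaussianBlur' requires the sigma | |||
| >>> # kwarg listed above, and the values here are illustrative assumptions | |||
| >>> deletion_auc = DeletionAUC(perturb_percent=0.5, | |||
| ...                            perturb_method='GaussianBlur', | |||
| ...                            sigma=0.7) | |||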
| """ | |||
| def __init__(self, | |||
| perturb_percent: float, | |||
| perturb_method: str, | |||
| perturb_pixel_per_step: Optional[int] = None, | |||
| num_perturbations: Optional[int] = None, | |||
| **kwargs): | |||
| super(DeletionAUC, self).__init__( | |||
| perturb_percent=perturb_percent, | |||
| perturb_mode='Deletion', | |||
| perturb_method=perturb_method, | |||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||
| num_perturbations=num_perturbations, | |||
| is_accumulate=True, | |||
| **kwargs) | |||
| def calc_faithfulness(self, | |||
| inputs: _Array, | |||
| model: _Module, | |||
| targets: _Label, | |||
| saliency: _Array) -> np.ndarray: | |||
| """ | |||
| Calculate faithfulness through deletion AUC. | |||
| Args: | |||
| inputs (_Array): sample to calculate the faithfulness score on. | |||
| model (_Module): model to be explained. | |||
| targets (_Label): target label to explain on. | |||
| saliency (_Array): saliency map of the given inputs and targets from the | |||
| explainer. | |||
| Returns: | |||
| - faithfulness (float): faithfulness score | |||
| """ | |||
| reference = self._get_reference(inputs) | |||
| perturbations = self._perturb(inputs, saliency, reference) | |||
| perturbations = ms.Tensor(perturbations, dtype=ms.float32) | |||
| predictions = model(perturbations).asnumpy()[:, targets] | |||
| input_tensor = op.ExpandDims()(ms.Tensor(inputs, ms.float32), 0) | |||
| original_output = model(input_tensor).asnumpy()[:, targets] | |||
| auc = calc_auc(original_output - predictions) | |||
| return np.array([1 - auc]) | |||
| class InsertionAUC(_FaithfulnessHelper): | |||
| """ Calculator for insertion AUC. | |||
| For Insertion AUC, the metric accumulatively replaces pixels of the reference | |||
| image with pixels from the original image, like inserting pixels from the original | |||
| image into the reference. The reference is generated through the specified | |||
| 'perturb_method'. The metric predicts on the perturbed images and records a | |||
| series of probabilities. The metric then calculates the AUC of the probability | |||
| variation curve during perturbations. Faithfulness is defined as the insertion | |||
| AUC. A higher score indicates better faithfulness of the explanation. | |||
| Args: | |||
| perturb_percent (float): percentage of pixels to perturb | |||
| perturb_method (str): specify the method to replace the pixel. | |||
| Current support: ['Constant', 'GaussianBlur'] | |||
| perturb_pixel_per_step (Optional[int]): number of pixels to perturb | |||
| at each perturbation step. If perturb_pixel_per_step is None, the actual | |||
| perturb_pixel_per_step will be calculated by: | |||
| num_image_pixel * perturb_percent / num_perturb_steps. | |||
| Default: None. | |||
| num_perturbations (Optional[int]): number of perturbations. If | |||
| num_perturbations is None, it will be calculated by: | |||
| num_image_pixel * perturb_percent / perturb_pixel_per_step. | |||
| Default: None. | |||
| kwargs: each perturb_method requires | |||
| different arguments. The required arguments for each method are listed below. | |||
| 'Constant': base_value (int) | |||
| 'GaussianBlur': sigma (float): 0.7 | |||
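| Examples: | |||
| >>> # a minimal instantiation sketch, mirroring DeletionAUC above; | |||
| >>> # the values here are illustrative assumptions | |||
| >>> insertion_auc = InsertionAUC(perturb_percent=0.5, | |||
| ...                              perturb_method='Constant', | |||
| ...                              base_value=0.0) | |||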
| """ | |||
| def __init__(self, | |||
| perturb_percent: float, | |||
| perturb_method: str, | |||
| perturb_pixel_per_step: Optional[int] = None, | |||
| num_perturbations: Optional[int] = None, | |||
| **kwargs): | |||
| super(InsertionAUC, self).__init__( | |||
| perturb_percent=perturb_percent, | |||
| perturb_mode='Insertion', | |||
| perturb_method=perturb_method, | |||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||
| num_perturbations=num_perturbations, | |||
| is_accumulate=True, | |||
| **kwargs) | |||
| def calc_faithfulness(self, | |||
| inputs: _Array, | |||
| model: _Module, | |||
| targets: _Label, | |||
| saliency: _Array) -> np.ndarray: | |||
| """ | |||
| Calculate faithfulness through insertion AUC. | |||
| Args: | |||
| inputs (_Array): sample to calculate the faithfulness score on. | |||
| model (_Module): model to be explained. | |||
| targets (_Label): target label to explain on. | |||
| saliency (_Array): saliency map of the given inputs and targets from the | |||
| explainer. | |||
| Returns: | |||
| - faithfulness (float): faithfulness score | |||
| """ | |||
| reference = self._get_reference(inputs) | |||
| perturbations = self._perturb(inputs, saliency, reference) | |||
| perturbations = ms.Tensor(perturbations, dtype=ms.float32) | |||
| predictions = model(perturbations).asnumpy()[:, targets] | |||
| base_tensor = op.ExpandDims()(ms.Tensor(reference, ms.float32), 0) | |||
| base_outputs = model(base_tensor).asnumpy()[:, targets] | |||
| auc = calc_auc(predictions - base_outputs) | |||
| return np.array([auc]) | |||
| class Faithfulness(AttributionMetric): | |||
| """ | |||
| Provides evaluation on the faithfulness of XAI explanations. | |||
| Faithfulness first generates a saliency map with the given explainer and then calculates faithfulness based on | |||
| the chosen faithfulness metric. | |||
| Args: | |||
| num_labels (int): number of labels. | |||
| metric (str): the specific metric to quantify faithfulness. | |||
| Options: 'DeletionAUC', 'InsertionAUC', 'NaiveFaithfulness'. | |||
| Default: 'NaiveFaithfulness'. | |||
| Examples: | |||
| >>> # init a `Faithfulness` object | |||
| >>> num_labels = 10 | |||
| >>> metric = "InsertionAUC" | |||
| >>> faithfulness = Faithfulness(num_labels, metric) | |||
| """ | |||
| _methods = [NaiveFaithfulness, DeletionAUC, InsertionAUC] | |||
| def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness"): | |||
| super(Faithfulness, self).__init__(num_labels) | |||
| perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument | |||
| perturb_method = "Constant" # perturbation method, all the perturbed pixels will be set to constant | |||
| num_perturb_pixel_per_step = None # number of pixels for each perturbation step | |||
| num_perturb_steps = 100 # split the perturbation progress into 100 steps. | |||
| base_value = 0.0 # the pixel value set for the perturbed pixels | |||
| self._verify_metrics(metric) | |||
| for method in self._methods: | |||
| if metric == method.__name__: | |||
| self._faithfulness_helper = method( | |||
| perturb_percent=perturb_percent, | |||
| perturb_method=perturb_method, | |||
| perturb_pixel_per_step=num_perturb_pixel_per_step, | |||
| num_perturbations=num_perturb_steps, | |||
| base_value=base_value | |||
| ) | |||
| def evaluate(self, explainer, inputs, targets, saliency=None): | |||
| """ | |||
| Evaluate faithfulness on a single data sample. | |||
| Args: | |||
| explainer (Explainer): The explainer instance to be evaluated. | |||
| See `mindspore/explainer/explanation` for the `Explainer` objects. | |||
| inputs (Tensor): data sample. Currently only a single sample is supported at each call. | |||
| targets (Union[int, Tensor]): A target label to evaluate on. | |||
| saliency (Tensor): A saliency tensor. | |||
| Returns: | |||
| np.ndarray: result of faithfulness evaluated on explainer. | |||
| Notes: | |||
| To apply `Faithfulness` to evaluate an explainer, this explainer must be initialized with a network that | |||
| contains the output activation function. Otherwise, the results will not be correct. | |||
| Examples: | |||
| >>> # init an explainer, the network should contain the output activation function. | |||
| >>> network = nn.SequentialCell([resnet50, nn.Sigmoid()]) | |||
| >>> gradient = Gradient(network) | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> targets = 5 | |||
| >>> # usage 1: input the explainer and the data to be explained, | |||
| >>> # calculate the faithfulness with the specified metric | |||
| >>> res = faithfulness.evaluate(gradient, inputs, targets) | |||
| >>> # usage 2: input the generated saliency map | |||
| >>> saliency = gradient(inputs, targets) | |||
| >>> res = faithfulness.evaluate(gradient, inputs, targets, saliency) | |||
| """ | |||
| self._check_evaluate_param(explainer, inputs, targets, saliency) | |||
| if saliency is None: | |||
| saliency = explainer(inputs, targets) | |||
| inputs = format_tensor_to_ndarray(inputs) | |||
| saliency = format_tensor_to_ndarray(saliency) | |||
| inputs = inputs.squeeze(axis=0) | |||
| saliency = saliency.squeeze() | |||
| if len(saliency.shape) != 2: | |||
| raise ValueError('Squeezed saliency map is expected to be 2D, but receive {}D.'.format(len(saliency.shape))) | |||
| faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=explainer.model, | |||
| targets=targets, saliency=saliency) | |||
| return faithfulness | |||
| def _verify_metrics(self, metric: str): | |||
| supports = [x.__name__ for x in self._methods] | |||
| if metric not in supports: | |||
| raise ValueError("Metric should be one of {}.".format(supports)) | |||
| @@ -0,0 +1,146 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Localization metrics.""" | |||
| import numpy as np | |||
| from mindspore.train._utils import check_value_type | |||
| from .metric import AttributionMetric | |||
| from ..._operators import maximum, reshape, Tensor | |||
| from ..._utils import format_tensor_to_ndarray | |||
| def _get_max_position(saliency): | |||
| """Get the position of the max pixel of the saliency map.""" | |||
| saliency = saliency.asnumpy() | |||
| w = saliency.shape[3] | |||
| saliency = np.reshape(saliency, (len(saliency), -1)) | |||
| max_arg = np.argmax(saliency, axis=1) | |||
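| # convert the flattened argmax back to (row, column) coordinates | |||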
| return max_arg // w, max_arg - (max_arg // w) * w | |||
| def _mask_out_saliency(saliency, threshold): | |||
| """Keep the saliency map with value greater than threshold.""" | |||
| max_value = maximum(saliency) | |||
| mask_out = saliency > (reshape(max_value, (len(saliency), -1, 1, 1)) * threshold) | |||
| return mask_out | |||
| class Localization(AttributionMetric): | |||
| """ | |||
| Provides evaluation on the localization capability of XAI methods. | |||
| We support two metrics for the evaluation of localization capability: "PointingGame" and "IoSR". | |||
| For metric "PointingGame", the localization capability is calculated as the ratio of data in which the max position | |||
| of their saliency maps lies within the bounding boxes. Specifically, for a single datum, given the saliency map and | |||
| its bounding box, if the max point of its saliency map lies within the bounding box, the evaluation result is 1 | |||
| otherwise 0. | |||
| For metric "IoSR" (Intersection over Salient Region), the localization capability is calculated as the intersection | |||
| of the bounding box and the salient region over the area of the salient region. | |||
| Args: | |||
| num_labels (int): number of classes in the dataset. | |||
| metric (str): specific metric to calculate localization capability. | |||
| Options: "PointingGame", "IoSR". | |||
| Default: "PointingGame". | |||
| Examples: | |||
| >>> from mindspore.explainer.benchmark import Localization | |||
| >>> num_labels = 100 | |||
| >>> localization = Localization(num_labels, "PointingGame") | |||
| """ | |||
| def __init__(self, | |||
| num_labels, | |||
| metric="PointingGame" | |||
| ): | |||
| super(Localization, self).__init__(num_labels) | |||
| self._verify_metrics(metric) | |||
| self._metric = metric | |||
| # Argument for the specific metric: for "PointingGame" it is an integer indicating the tolerance | |||
| # radius (default: 15), while for "IoSR" it is a float | |||
| # indicating the threshold used to choose the salient region (default: 0.5). | |||
| if self._metric == "PointingGame": | |||
| self._metric_arg = 15 | |||
| else: | |||
| self._metric_arg = 0.5 | |||
| @staticmethod | |||
| def _verify_metrics(metric): | |||
| """Verify the user defined metric.""" | |||
| supports = ["PointingGame", "IoSR"] | |||
| if metric not in supports: | |||
| raise ValueError("Metric should be one of {}".format(supports)) | |||
| def evaluate(self, explainer, inputs, targets, saliency=None, mask=None): | |||
| """ | |||
| Evaluate localization on a single data sample. | |||
| Args: | |||
| explainer (Explainer): The explainer to be evaluated, see `mindspore/explainer/explanation`. | |||
| inputs (Tensor): data sample. Currently only a single sample is supported at each call. | |||
| targets (int): target label to evaluate on. | |||
| saliency (Tensor): A saliency tensor. | |||
| mask (Union[Tensor, np.ndarray]): ground truth bounding box/masks for the inputs w.r.t targets. | |||
| Returns: | |||
| np.ndarray, result of localization evaluated on explainer | |||
| Examples: | |||
| >>> # init an explainer, the network should contain the output activation function. | |||
| >>> gradient = Gradient(network) | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> masks = np.zeros((1, 1, 224, 224)) | |||
| >>> masks[:, :, 65: 100, 65: 100] = 1 | |||
| >>> targets = 5 | |||
| >>> # usage 1: input the explainer and the data to be explained, | |||
| >>> # calculate the localization score with the specified metric | |||
| >>> res = localization.evaluate(gradient, inputs, targets, mask=masks) | |||
| >>> # usage 2: input the generated saliency map | |||
| >>> saliency = gradient(inputs, targets) | |||
| >>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks) | |||
| """ | |||
| self._check_evaluate_param(explainer, inputs, targets, saliency) | |||
| mask_np = format_tensor_to_ndarray(mask)[0] | |||
| if saliency is None: | |||
| saliency = explainer(inputs, targets) | |||
| if self._metric == "PointingGame": | |||
| point = _get_max_position(saliency) | |||
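| # build a circular tolerance region of radius self._metric_arg around the max point | |||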
| x, y = np.meshgrid( | |||
| (np.arange(mask_np.shape[1]) - point[0]) ** 2, | |||
| (np.arange(mask_np.shape[2]) - point[1]) ** 2) | |||
| max_region = (x + y) < self._metric_arg ** 2 | |||
| # if max_region has overlap with mask_np return 1 otherwise 0. | |||
| result = 1 if (mask_np.astype(bool) & max_region).any() else 0 | |||
| elif self._metric == "IoSR": | |||
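| # IoSR: overlap between the ground-truth mask and the salient region, over the salient-region area | |||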
| mask_out = _mask_out_saliency(saliency, self._metric_arg) | |||
| mask_out_np = format_tensor_to_ndarray(mask_out) | |||
| overlap = np.sum(mask_np.astype(bool) & mask_out_np.astype(bool)) | |||
| saliency_area = np.sum(mask_out_np) | |||
| result = overlap / saliency_area.clip(min=1e-10) | |||
| return np.array([result], np.float) | |||
| def _check_evaluate_param_with_mask(self, explainer, inputs, targets, saliency, mask): | |||
| self._check_evaluate_param(explainer, inputs, targets, saliency) | |||
| check_value_type('mask', mask, (Tensor, np.ndarray)) | |||
| if len(mask.shape) != 4: | |||
| raise ValueError('Argument mask must be a 4D Tensor.') | |||
| @@ -0,0 +1,123 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Base class for XAI metrics.""" | |||
| import numpy as np | |||
| from mindspore.train._utils import check_value_type | |||
| from ..._operators import Tensor | |||
| from ..._utils import format_tensor_to_ndarray | |||
| from ...explanation._attribution._attribution import Attribution | |||
| def verify_argument(inputs, arg_name): | |||
| """Verify the validity of the parsed arguments.""" | |||
| check_value_type(arg_name, inputs, Tensor) | |||
| if len(inputs.shape) != 4: | |||
| raise ValueError('Argument {} must be a 4D Tensor.'.format(arg_name)) | |||
| if len(inputs) > 1: | |||
| raise ValueError('Support single data evaluation only, but got {}.'.format(len(inputs))) | |||
| def verify_targets(targets, num_labels): | |||
| """Verify the validity of the parsed targets.""" | |||
| check_value_type('targets', targets, (int, Tensor)) | |||
| if isinstance(targets, Tensor): | |||
| if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != 1): | |||
| raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, ' | |||
| 'it should have the length = 1 as we only support single evaluation now.') | |||
| targets = int(targets.asnumpy()[0]) if len(targets.shape) == 1 else int(targets.asnumpy()) | |||
| if targets > num_labels - 1 or targets < 0: | |||
| raise ValueError('Parsed targets exceed the label range.') | |||
| class AttributionMetric: | |||
| """Super class of XAI metric class used in classification scenarios.""" | |||
| def __init__(self, num_labels=None): | |||
| self._num_labels = num_labels | |||
| self._global_results = {i: [] for i in range(num_labels)} | |||
| def evaluate(self, explainer, inputs, targets, saliency=None): | |||
| """This function evaluates on a single sample and return the result.""" | |||
| raise NotImplementedError | |||
| def aggregate(self, result, targets): | |||
| """Aggregates single result to global_results.""" | |||
| if isinstance(result, float): | |||
| if isinstance(targets, int): | |||
| self._global_results[targets].append(result) | |||
| else: | |||
| target_np = format_tensor_to_ndarray(targets) | |||
| if len(target_np) > 1: | |||
| raise ValueError("One result can not be aggreated to multiple targets.") | |||
| else: | |||
| result_np = format_tensor_to_ndarray(result) | |||
| if isinstance(targets, int): | |||
| for res in result_np: | |||
| self._global_results[targets].append(float(res)) | |||
| else: | |||
| target_np = format_tensor_to_ndarray(targets) | |||
| if len(target_np) != len(result_np): | |||
| raise ValueError("Length of result does not match with length of targets.") | |||
| for tar, res in zip(target_np, result_np): | |||
| self._global_results[int(tar)].append(float(res)) | |||
| def reset(self): | |||
| """Resets global_result.""" | |||
| self._global_results = {i: [] for i in range(self._num_labels)} | |||
| @property | |||
| def class_performances(self): | |||
| """ | |||
| Get the class performances by global result. | |||
| Returns: | |||
| (:class:`np.ndarray`): :attr:`num_labels`-dimensional vector | |||
| containing per-class performance. | |||
| """ | |||
| count = np.array( | |||
| [len(self._global_results[i]) for i in range(self._num_labels)]) | |||
| result_sum = np.array( | |||
| [sum(self._global_results[i]) for i in range(self._num_labels)]) | |||
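| # clip avoids division by zero for labels that received no results | |||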
| return result_sum / count.clip(min=1) | |||
| @property | |||
| def performance(self): | |||
| """ | |||
| Get the performance by global result. | |||
| Returns: | |||
| (:class:`float`): mean performance. | |||
| """ | |||
| count = sum( | |||
| [len(self._global_results[i]) for i in range(self._num_labels)]) | |||
| result_sum = sum( | |||
| [sum(self._global_results[i]) for i in range(self._num_labels)]) | |||
| if count == 0: | |||
| return 0 | |||
| return result_sum / count | |||
| def get_results(self): | |||
| """Global result of the metric can be return""" | |||
| return self._global_results | |||
| def _check_evaluate_param(self, explainer, inputs, targets, saliency): | |||
| """Check the evaluate parameters.""" | |||
| check_value_type('explainer', explainer, Attribution) | |||
| verify_argument(inputs, 'inputs') | |||
| verify_targets(targets, self._num_labels) | |||
| check_value_type('saliency', saliency, (Tensor, type(None))) | |||
| @@ -0,0 +1,26 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Predefined Attribution explainers.""" | |||
| from ._attribution._backprop.gradcam import GradCAM | |||
| from ._attribution._backprop.gradient import Gradient | |||
| from ._attribution._backprop.modified_relu import Deconvolution, GuidedBackprop | |||
| __all__ = [ | |||
| 'Gradient', | |||
| 'Deconvolution', | |||
| 'GuidedBackprop', | |||
| 'GradCAM', | |||
| ] | |||
| @@ -0,0 +1,25 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Predefined Attribution explainers.""" | |||
| from ._backprop.gradcam import GradCAM | |||
| from ._backprop.gradient import Gradient | |||
| from ._backprop.modified_relu import Deconvolution, GuidedBackprop | |||
| __all__ = [ | |||
| 'Gradient', | |||
| 'Deconvolution', | |||
| 'GuidedBackprop', | |||
| 'GradCAM', | |||
| ] | |||
| @@ -0,0 +1,60 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Attribution.""" | |||
| from typing import Callable | |||
| import mindspore as ms | |||
| class Attribution: | |||
| r""" | |||
| Base class for attributing the saliency score. | |||
| Explainers that generate explanations by attributing relevance scores | |||
| should inherit this class. | |||
| Args: | |||
| network (ms.nn.Cell): The black-box model to be explained. | |||
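| Examples: | |||
| >>> # a sketch of a minimal subclass; `MyExplainer` is a hypothetical name | |||
| >>> class MyExplainer(Attribution): | |||
| ...     def __call__(self, inputs, targets): | |||
| ...         # compute and return a saliency map of `inputs` w.r.t. `targets` | |||
| ...         raise NotImplementedError | |||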
| """ | |||
| def __init__(self, network): | |||
| self._verify_model(network) | |||
| self._model = network | |||
| @staticmethod | |||
| def _verify_model(model): | |||
| """ | |||
| Verify the input `network` for __init__ function. | |||
| """ | |||
| if not isinstance(model, ms.nn.Cell): | |||
| raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.") | |||
| __call__: Callable | |||
| """ | |||
| The explainers return the explanations when called directly on the input. | |||
| Derived classes should override this implementation for their | |||
| algorithms. | |||
| Args: | |||
| input (ms.Tensor): Input tensor to be explained. | |||
| Returns: | |||
| - saliency map (ms.Tensor): saliency map of the input. | |||
| """ | |||
| @property | |||
| def model(self): | |||
| return self._model | |||
| @@ -0,0 +1,24 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Backprop-base _attribution explainer.""" | |||
| from .gradient import Gradient | |||
| from .gradcam import GradCAM | |||
| from .modified_relu import Deconvolution, GuidedBackprop | |||
| __all__ = ['Gradient', | |||
| 'GradCAM', | |||
| 'Deconvolution', | |||
| 'GuidedBackprop'] | |||
| @@ -0,0 +1,49 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Providing utility functions.""" | |||
| from mindspore.ops.composite import GradOperation | |||
| from ...._utils import unify_inputs, unify_targets, generate_one_hot | |||
| def compute_gradients(model, inputs, targets=None, weights=None): | |||
| r""" | |||
| Compute the gradient of output w.r.t input. | |||
| Args: | |||
| model (`ms.nn.Cell`): Differentiable black-box model. | |||
| inputs (`ms.Tensor`): Input to calculate gradient and explanation. | |||
| targets (int, optional): Target label id specifying which category to compute gradient. Default: None. | |||
| weights (`ms.Tensor`, optional): Custom weights for computing gradients. The shape of weights should match the | |||
| model outputs. If None is provided, a one-hot weight vector with ones at the target positions will be used instead. | |||
| Default: None. | |||
| Returns: | |||
| saliency map (ms.Tensor): Gradient back-propagated to the input. | |||
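| Examples: | |||
| >>> # a usage sketch; assumes `net` is a classification `ms.nn.Cell` and | |||
| >>> # `inputs` is a 4D `ms.Tensor` batch | |||
| >>> grads = compute_gradients(net, inputs, targets=5) | |||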
| """ | |||
| inputs = unify_inputs(inputs) | |||
| if targets is None and weights is None: | |||
| raise ValueError('Must provide one of targets or weights') | |||
| if weights is None: | |||
| targets = unify_targets(targets) | |||
| output = model(*inputs).asnumpy() | |||
| num_categories = output.shape[-1] | |||
| weights = generate_one_hot(targets, num_categories) | |||
| grad_op = GradOperation( | |||
| get_all=True, get_by_list=False, sens_param=True)(model) | |||
| gradients = grad_op(*inputs, weights) | |||
| return gradients[0] | |||
| @@ -0,0 +1,141 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ GradCAM and GuidedGradCAM. """ | |||
| from mindspore.ops import operations as op | |||
| from .backprop_utils import compute_gradients | |||
| from .intermediate_layer import IntermediateLayerAttribution | |||
| from ...._utils import ForwardProbe, retrieve_layer, unify_inputs, unify_targets | |||
| def _gradcam_aggregation(attributions): | |||
| """ | |||
| Aggregate the gradient and activation to get the final attribution. | |||
| Args: | |||
| attributions (Tensor): the attribution with channel dimension. | |||
| Returns: | |||
| Tensor: the attribution with the channel dimension aggregated. | |||
| """ | |||
| sum_ = op.ReduceSum(keep_dims=True) | |||
| relu_ = op.ReLU() | |||
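| # channel-wise sum of the weighted activations followed by ReLU, as in the GradCAM paper | |||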
| attributions = relu_(sum_(attributions, 1)) | |||
| return attributions | |||
| class GradCAM(IntermediateLayerAttribution): | |||
| r""" | |||
| Provides GradCAM explanation method. | |||
| GradCAM generates saliency map at intermediate layer. | |||
| .. math:: | |||
| \alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial{y^c}}{\partial{A_{i,j}^k}} | |||
| L_{GradCAM} = ReLU(\sum_k \alpha_k^c A^k) | |||
| For more details, please refer to the original paper: GradCAM | |||
| [https://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf] | |||
| Args: | |||
| network (Cell): The black-box model to be explained. | |||
| layer (str): The layer name to generate the explanation at. Default: ''. | |||
| If default, the explanation will be generated at the input layer. | |||
| Examples: | |||
| >>> net = resnet50(10) | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(), | |||
| >>> # you may also use the net itself. | |||
| >>> net = nn.SequentialCell([net, nn.Sigmoid()]) | |||
| >>> # specify a layer name to generate explanation, usually the layer can be set as the last conv layer. | |||
| >>> layer_name = '0.layer4' | |||
| >>> # init GradCAM with a trained network and the layer to obtain the explanation at | |||
| >>> gradcam = GradCAM(net, layer=layer_name) | |||
| >>> # parse data and the target label to be explained and get the saliency map | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> label = 5 | |||
| >>> saliency = gradcam(inputs, label) | |||
| """ | |||
| def __init__( | |||
| self, | |||
| network, | |||
| layer=""): | |||
| super(GradCAM, self).__init__(network, layer) | |||
| self._saliency_cell = retrieve_layer(self._backward_model, target_layer=layer) | |||
| self._avgpool = op.ReduceMean(keep_dims=True) | |||
| self._intermediate_grad = None | |||
| self._aggregation_fn = _gradcam_aggregation | |||
| self._resize_mode = 'bilinear' | |||
| def _hook_cell(self): | |||
| if self._saliency_cell: | |||
| self._saliency_cell.register_backward_hook(self._cell_hook_fn) | |||
| self._saliency_cell.enable_hook = True | |||
| self._intermediate_grad = None | |||
| def _cell_hook_fn(self, _, grad_input, grad_output): | |||
| """ | |||
| Hook function to deal with the backward gradient. | |||
| The arguments are set as required by Cell.register_backward_hook. | |||
| """ | |||
| self._intermediate_grad = grad_input | |||
| def __call__(self, inputs, targets): | |||
| """ | |||
| Call function for `GradCAM`. | |||
| Args: | |||
| inputs (Tensor): The input data to be explained, 4D Tensor. | |||
| targets (Union[Tensor, int]): The label of interest. It should be a 1D or 0D Tensor, or an integer. | |||
| If `targets` is a 1D Tensor, its length should be the same as `inputs`. | |||
| """ | |||
| self._verify_data(inputs, targets) | |||
| self._hook_cell() | |||
| with ForwardProbe(self._saliency_cell) as probe: | |||
| inputs = unify_inputs(inputs) | |||
| targets = unify_targets(targets) | |||
| gradients = compute_gradients(self._backward_model, *inputs, targets) | |||
| # get intermediate activation | |||
| activation = (probe.value,) | |||
| if self._layer == "": | |||
| activation = inputs | |||
| self._intermediate_grad = unify_inputs(gradients) | |||
| if self._intermediate_grad is not None: | |||
| # average pooling on gradients | |||
| intermediate_grad = unify_inputs( | |||
| self._avgpool(self._intermediate_grad[0], (2, 3))) | |||
| else: | |||
| raise ValueError("Gradient for intermediate layer is not " | |||
| "obtained") | |||
| mul = op.Mul() | |||
| attribution = self._aggregation_fn( | |||
| mul(*intermediate_grad, *activation)) | |||
| if self._resize: | |||
| attribution = self._resize_fn(attribution, *inputs, | |||
| mode=self._resize_mode) | |||
| self._intermediate_grad = None | |||
| return attribution | |||
| @@ -0,0 +1,129 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Gradient explainer.""" | |||
| from copy import deepcopy | |||
| from mindspore import nn | |||
| from mindspore.ops import operations as op | |||
| from mindspore.train._utils import check_value_type | |||
| from ...._operators import reshape, sqrt, Tensor | |||
| from .._attribution import Attribution | |||
| from .backprop_utils import compute_gradients | |||
| from ...._utils import unify_inputs, unify_targets | |||
| def _get_hook(bntype, cache): | |||
| """Provide backward hook function for BatchNorm layer in eval mode.""" | |||
| var, gamma, eps = cache | |||
| if bntype == "2d": | |||
| var = reshape(var, (1, -1, 1, 1)) | |||
| gamma = reshape(gamma, (1, -1, 1, 1)) | |||
| elif bntype == "1d": | |||
| var = reshape(var, (1, -1, 1)) | |||
| gamma = reshape(gamma, (1, -1, 1)) | |||
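| # in eval mode BatchNorm is a fixed affine map; rescale the gradient by gamma / sqrt(var + eps) | |||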
| def reset_gradient(_, grad_input, grad_output): | |||
| grad_output = grad_input[0] * gamma / sqrt(var + eps) | |||
| return grad_output | |||
| return reset_gradient | |||
| def _abs_max(gradients): | |||
| """ | |||
| Transform gradients to saliency by taking the absolute value and then the | |||
| maximum along the channel dimension. | |||
| """ | |||
| gradients = op.Abs()(gradients) | |||
| saliency = op.ReduceMax(keep_dims=True)(gradients, axis=1) | |||
| return saliency | |||
| class Gradient(Attribution): | |||
| r""" | |||
| Provides Gradient explanation method. | |||
| Gradient is the simplest attribution method which uses the naive gradients of outputs w.r.t inputs as the | |||
| explanation. | |||
| .. math:: | |||
| attribution = \frac{\partial{y}}{\partial{x}} | |||
| Args: | |||
| network (Cell): The black-box model to be explained. | |||
| Examples: | |||
| >>> net = resnet50(10) | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(), | |||
| >>> # you may also use the net itself. The saliency map might be slightly different for softmax activation. | |||
| >>> net = nn.SequentialCell([net, nn.Sigmoid()]) | |||
| >>> # init Gradient with a trained network. | |||
| >>> gradient = Gradient(net) | |||
| >>> # parse data and the target label to be explained and get the saliency map | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> label = 5 | |||
| >>> saliency = gradient(inputs, label) | |||
| """ | |||
| def __init__(self, network): | |||
| super(Gradient, self).__init__(network) | |||
| self._backward_model = deepcopy(network) | |||
| self._backward_model.set_train(False) | |||
| self._backward_model.set_grad(False) | |||
| self._hook_bn() | |||
| self._grad_op = compute_gradients | |||
| self._aggregation_fn = _abs_max | |||
| def __call__(self, inputs, targets): | |||
| """ | |||
| Call function for `Gradient`. | |||
| Args: | |||
| inputs (Tensor): The input data to be explained, 4D Tensor. | |||
| targets (Union[Tensor, int]): The label of interest. It should be a 1D or 0D Tensor, or an integer. | |||
| If `targets` is a 1D `Tensor`, its length should be the same as `inputs`. | |||
| """ | |||
| self._verify_data(inputs, targets) | |||
| inputs = unify_inputs(inputs) | |||
| targets = unify_targets(targets) | |||
| gradient = self._grad_op(self._backward_model, *inputs, targets) | |||
| saliency = self._aggregation_fn(gradient) | |||
| return saliency | |||
| def _hook_bn(self): | |||
| """Hook BatchNorm layer for `self._backward_model.`""" | |||
| for _, cell in self._backward_model.cells_and_names(): | |||
| if isinstance(cell, nn.BatchNorm2d): | |||
| cache = (cell.moving_variance, cell.gamma, cell.eps) | |||
| cell.register_backward_hook(_get_hook("2d", cache=cache)) | |||
| elif isinstance(cell, nn.BatchNorm1d): | |||
| cache = (cell.moving_variance, cell.gamma, cell.eps) | |||
| cell.register_backward_hook(_get_hook("1d", cache=cache)) | |||
| @staticmethod | |||
| def _verify_data(inputs, targets): | |||
| """Verify the validity of the parsed inputs.""" | |||
| check_value_type('inputs', inputs, Tensor) | |||
| if len(inputs.shape) != 4: | |||
| raise ValueError('Argument inputs must be 4D Tensor') | |||
| check_value_type('targets', targets, (Tensor, int)) | |||
| if isinstance(targets, Tensor): | |||
| if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != len(inputs)): | |||
| raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, ' | |||
| 'it should have the same length as inputs.') | |||
| @@ -0,0 +1,47 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Base class IntermediateLayerAttribution""" | |||
| from .gradient import Gradient | |||
| from ...._utils import resize as resize_fn | |||
| class IntermediateLayerAttribution(Gradient): | |||
| """ | |||
| Base class for generating attribution maps at an intermediate layer. | |||
| Args: | |||
| network (nn.Cell): DNN model to be explained. | |||
| layer (str, optional): string that specifies the layer to generate | |||
| the intermediate attribution. When using the default value, the input layer | |||
| will be used. Default: ''. | |||
| """ | |||
| def __init__(self, network, layer=''): | |||
| super(IntermediateLayerAttribution, self).__init__(network) | |||
| # Whether to resize the attribution map to the input size. | |||
| self._resize = True | |||
| # string that specifies the resize mode. Default: 'nearest_neighbor'. | |||
| self._resize_mode = 'nearest_neighbor' | |||
| self._layer = layer | |||
| @staticmethod | |||
| def _resize_fn(attributions, inputs, mode): | |||
| """Resize the intermediate layer _attribution to the same size as inputs.""" | |||
| height, width = inputs.shape[2], inputs.shape[3] | |||
| return resize_fn(attributions, (height, width), mode) | |||
| @@ -0,0 +1,117 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Explainer with modified ReLU.""" | |||
| import mindspore.nn as nn | |||
| import mindspore.ops.operations as op | |||
| from .gradient import Gradient | |||
| from ...._utils import ( | |||
| unify_inputs, | |||
| unify_targets, | |||
| ) | |||
| class ModifiedReLU(Gradient): | |||
| """Basic class for modified ReLU explanation.""" | |||
| def __init__(self, network, use_relu_backprop=False): | |||
| super(ModifiedReLU, self).__init__(network) | |||
| self.use_relu_backprop = use_relu_backprop | |||
| self.hooked_list = [] | |||
| def __call__(self, inputs, targets): | |||
| self._verify_data(inputs, targets) | |||
| inputs = unify_inputs(inputs) | |||
| targets = unify_targets(targets) | |||
| self._hook_relu_backward() | |||
| gradients = self._grad_op(self._backward_model, inputs, targets) | |||
| saliency = self._aggregation_fn(gradients) | |||
| return saliency | |||
| def _hook_relu_backward(self): | |||
| """Set backward hook for ReLU layers.""" | |||
| for _, cell in self._backward_model.cells_and_names(): | |||
| if isinstance(cell, nn.ReLU): | |||
| cell.register_backward_hook(self._backward_hook) | |||
| self.hooked_list.append(cell) | |||
| def _backward_hook(self, _, grad_inputs, grad_outputs): | |||
| """Hook function for ReLU layers.""" | |||
| inputs = grad_inputs if self.use_relu_backprop else grad_outputs | |||
| relu = op.ReLU() | |||
| if isinstance(inputs, tuple): | |||
| return relu(*inputs) | |||
| return relu(inputs) | |||
| class Deconvolution(ModifiedReLU): | |||
| """ | |||
| Provides Deconvolution explanation method. | |||
| To use `Deconvolution`, the `ReLU` operations in the network must be implemented with `mindspore.nn.Cell` objects | |||
| rather than `mindspore.ops.operations.ReLU`. Otherwise, the results will not be correct. | |||
| Args: | |||
| network (Cell): The black-box model to be explained. | |||
| Examples: | |||
| >>> net = resnet50(10) | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(), | |||
| >>> # you may also use the net itself. The saliency map might be slightly different for softmax activation. | |||
| >>> net = nn.SequentialCell([net, nn.Sigmoid()]) | |||
| >>> # init Deconvolution with a trained network. | |||
| >>> deconvolution = Deconvolution(net) | |||
| >>> # parse data and the target label to be explained and get the saliency map | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> label = 5 | |||
| >>> saliency = deconvolution(inputs, label) | |||
| """ | |||
| def __init__(self, network): | |||
| super(Deconvolution, self).__init__(network, use_relu_backprop=True) | |||
| class GuidedBackprop(ModifiedReLU): | |||
| """ | |||
| Provides Guided-Backpropagation explanation method. | |||
| To use `GuidedBackprop`, the `ReLU` operations in the network must be implemented with `mindspore.nn.Cell` objects | |||
| rather than `mindspore.ops.operations.ReLU`. Otherwise, the results will not be correct. | |||
| Args: | |||
| network (Cell): The black-box model to be explained. | |||
| Examples: | |||
| >>> net = resnet50(10) | |||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(), | |||
| >>> # you may also use the net itself. The saliency map might be slightly different for softmax activation. | |||
| >>> net = nn.SequentialCell([net, nn.Sigmoid()]) | |||
| >>> # init GuidedBackprop with a trained network. | |||
| >>> gbp = GuidedBackprop(net) | |||
| >>> # parse data and the target label to be explained and get the saliency map | |||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||
| >>> label = 5 | |||
| >>> saliency = gbp(inputs, label) | |||
| """ | |||
| def __init__(self, network): | |||
| super(GuidedBackprop, self).__init__(network, use_relu_backprop=False) | |||