Merge pull request !7656 from lixiaohui33/feature_explain_coretags/v1.1.0
@@ -118,19 +118,11 @@ message Explain {
     }
     message Benchmark{
-        message TotalScore{
-            optional string benchmark_method = 1;
-            optional float score = 2;
-        }
-        message LabelScore{
-            repeated float score = 1;
-            optional string benchmark_method = 2;
-        }
-        optional string explain_method = 1;
-        repeated TotalScore total_score = 2;
-        repeated LabelScore label_score = 3;
+        optional string benchmark_method = 1;
+        optional string explain_method = 2;
+        optional float total_score = 3;
+        repeated float label_score = 4;
     }
     message Metadata{
         repeated string label = 1;
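For orientation, the flattened Benchmark message above is filled once per explainer/benchmarker pair; a minimal sketch of populating it (field names taken from the hunk, score values illustrative):

    from mindspore.train.summary_pb2 import Explain

    explain = Explain()
    benchmark = explain.benchmark.add()
    benchmark.benchmark_method = "Faithfulness"  # evaluation method name
    benchmark.explain_method = "Gradient"        # explanation method name
    benchmark.total_score = 0.83                 # aggregated score (illustrative)
    benchmark.label_score.extend([0.80, 0.86])   # per-label scores (illustrative)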
@@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Provide the ExplainRunner high-level API."""
from ._runner import ExplainRunner

__all__ = ['ExplainRunner']
@@ -0,0 +1,261 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Packaged operations based on MindSpore."""
from typing import List, Tuple, Union, Callable

import numpy as np

import mindspore
from mindspore import nn
import mindspore.ops.operations as op

_Axis = Union[int, Tuple[int, ...], List[int]]
_Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]]
_Number = Union[int, float, np.int, np.float]
_Shape = Union[int, Tuple[int, ...]]
Tensor = mindspore.Tensor

__all__ = [
    'absolute',
    'arange',
    'argmax',
    'argmin',
    'argsort',
    'assign',
    'intersection',
    'matmul',
    'maximum',
    'minimum',
    'mean',
    'mul',
    'sort',
    'squeeze',
    'tile',
    'reshape',
    'zeros',
    'zeros_like',
    'softmax',
    'Tensor',
    'summation'
]
def absolute(inputs: Tensor) -> Tensor:
    """Get the element-wise absolute value of a tensor."""
    abs_op = op.Abs()
    outputs = abs_op(inputs)
    return outputs


def arange(
        start: _Number,
        end: _Number,
        step: _Number = 1,
        dtype: mindspore.dtype = None) -> Tensor:
    """Return evenly spaced values within the half-open interval [start, end) as a tensor."""
    # Let NumPy infer the value type instead of forcing int32, which would
    # silently truncate float starts and steps.
    nums = np.arange(start=start, stop=end, step=step)
    nums = mindspore.Tensor(nums, dtype=dtype)
    return nums


def argmax(inputs: Tensor, axis: int = -1, keep_dims: bool = False) -> Tensor:
    """Return the indices of the maximum values along an axis."""
    inputs_np = inputs.asnumpy()
    outputs = np.argmax(inputs_np, axis=axis)
    if keep_dims:
        outputs = np.expand_dims(outputs, axis=axis)
    return mindspore.Tensor(outputs, mindspore.int32)


def argmin(inputs: Tensor, axis: int = -1, keep_dims: bool = False) -> Tensor:
    """Return the indices of the minimum values along an axis."""
    inputs_np = inputs.asnumpy()
    outputs = np.argmin(inputs_np, axis=axis)
    if keep_dims:
        outputs = np.expand_dims(outputs, axis=axis)
    return mindspore.Tensor(outputs, mindspore.int32)


def argsort(inputs: Tensor, axis: int = -1, descending: bool = False) -> Tensor:
    """Return the indices that would sort an array."""
    inputs_np = inputs.asnumpy()
    factor = -1 if descending else 1
    indices_np = np.argsort(factor * inputs_np, axis=axis)
    indices = mindspore.Tensor(indices_np, dtype=mindspore.int32)
    return indices


def assign(inputs: Tensor, idx: _Idx, value: Tensor) -> Tensor:
    """Assign a tensor value to the given tensor at the given index."""
    inputs_np = inputs.asnumpy()
    if isinstance(idx, Tensor):
        idx = idx.asnumpy()
    value_np = value.asnumpy()
    inputs_np[idx] = value_np
    outputs = mindspore.Tensor(inputs_np)
    return outputs


def intersection(*inputs: Tensor) -> Tensor:
    """Get the element-wise intersection of the given tensors."""
    # Work on a NumPy copy: np.ones_like on a mindspore.Tensor would not
    # produce a writable NumPy array for the in-place `&=` below.
    outputs_np = np.ones_like(inputs[0].asnumpy())
    for inp in inputs:
        outputs_np &= inp.asnumpy()
    outputs = mindspore.Tensor(outputs_np)
    return outputs


def matmul(inputs_x: Tensor, inputs_y: Tensor) -> Tensor:
    """Multiplies matrix `inputs_x` by matrix `inputs_y`."""
    matmul_op = op.MatMul()
    outputs = matmul_op(inputs_x, inputs_y)
    return outputs


def maximum(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor:
    """Reduce a dimension of a tensor by the maximum value in this dimension."""
    max_op = op.ReduceMax(keep_dims)
    outputs = max_op(inputs, axis)
    return outputs


def minimum(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor:
    """Reduce a dimension of a tensor by the minimum value in the dimension."""
    min_op = op.ReduceMin(keep_dims)  # renamed from `max_op`, which was misleading
    outputs = min_op(inputs, axis)
    return outputs


def mean(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor:
    """Reduce a dimension of a tensor by averaging all elements in the dimension."""
    mean_op = op.ReduceMean(keep_dims)
    outputs = mean_op(inputs, axis)
    return outputs
def mul(inputs_x: Tensor, inputs_y: Tensor) -> Tensor:
    """
    Multiplies two tensors element-wise.

    Inputs `inputs_x` and `inputs_y` comply with the implicit type conversion rules to make the data types
    consistent. The inputs must be two tensors, or one tensor and one scalar.
    When the inputs are two tensors, their dtypes cannot both be bool, and their shapes must be broadcastable.
    When the inputs are one tensor and one scalar, the scalar can only be a constant.

    Inputs:
        - **inputs_x** (Union[Tensor, Number, bool]) - The first input is a number,
          a bool, or a tensor whose data type is number or bool.
        - **inputs_y** (Union[Tensor, Number, bool]) - The second input is a number,
          a bool when the first input is a tensor, or a tensor whose data type is number or bool.

    Outputs:
        Tensor, the shape is the same as the one after broadcasting,
        and the data type is the one with higher precision or more digits among the two inputs.
    """
    mul_op = op.Mul()
    outputs = mul_op(inputs_x, inputs_y)
    return outputs


def sort(inputs: Tensor, axis: _Axis = -1, descending: bool = False) -> Tensor:
    """Return a sorted copy of an array."""
    inputs_np = inputs.asnumpy()
    outputs_np = np.sort(inputs_np, axis=axis)
    if descending:
        outputs_np = np.flip(outputs_np, axis=axis)
    outputs = mindspore.Tensor(outputs_np)
    return outputs


def squeeze(inputs: Tensor, axis: _Axis = ()) -> Tensor:
    """Return a tensor with the same type but with the dimensions of size 1 removed based on `axis`."""
    squeeze_op = op.Squeeze(axis)
    outputs = squeeze_op(inputs)
    return outputs


def tile(inputs: Tensor, shape: Tuple[int, ...]) -> Tensor:
    """Replicate a tensor by the given number of times along each dimension."""
    tile_op = op.Tile()
    outputs = tile_op(inputs, shape)
    return outputs


def reshape(inputs: Tensor, shape: _Shape) -> Tensor:
    """Reshape the input tensor with the same values based on a given shape tuple."""
    if isinstance(shape, int):
        shape = (shape,)
    return op.Reshape()(inputs, shape)


def zeros(shape: _Shape, dtype: mindspore.dtype = None) -> Tensor:
    """Return a new tensor of the given shape and type, filled with zeros."""
    outputs = np.zeros(shape)
    return mindspore.Tensor(outputs, dtype=dtype)


def zeros_like(inputs: Tensor, dtype: mindspore.dtype = None) -> Tensor:
    """Return a tensor of zeros with the same shape and type as the given tensor."""
    inputs_np = inputs.asnumpy()
    outputs_np = np.zeros_like(inputs_np)
    outputs = mindspore.Tensor(outputs_np, dtype)
    return outputs


def random(shape: _Shape, dtype: mindspore.dtype = None) -> Tensor:
    """Return random floats in the half-open interval [0.0, 1.0)."""
    outputs_np = np.random.random(shape)
    outputs = mindspore.Tensor(outputs_np, dtype)
    return outputs


def randint(low: int, high: int, shape: _Shape, dtype: mindspore.dtype = mindspore.int8) -> Tensor:
    """Return random integers from `low` (inclusive) to `high` (exclusive)."""
    outputs_np = np.random.randint(low, high, size=shape)
    outputs = mindspore.Tensor(outputs_np, dtype=dtype)
    return outputs


def softmax(axis: int) -> Callable:
    """Return a Softmax activation function for the given axis."""
    func = nn.Softmax(axis=axis)
    return func


def summation(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor:
    """Reduce a dimension of a tensor by summing all elements in the dimension."""
    sum_op = op.ReduceSum(keep_dims)
    outputs = sum_op(inputs, axis)
    return outputs


def stack(inputs: List[Tensor], axis: int) -> Tensor:
    """Pack a list of tensors along the specified axis."""
    pack_op = op.Pack(axis)
    outputs = pack_op(inputs)
    return outputs


def sqrt(inputs: Tensor) -> Tensor:
    """Return the square root of a tensor element-wise."""
    sqrt_op = op.Sqrt()
    return sqrt_op(inputs)
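These wrappers mix MindSpore graph operators with NumPy round-trips. A minimal sketch of how a few of them compose (values illustrative; assumes the module above is imported as `ops`):

    import numpy as np
    from mindspore import Tensor

    x = Tensor(np.array([[3., 1.], [2., 4.]], np.float32))
    order = ops.argsort(x, axis=-1, descending=True)   # per-row sort indices via NumPy
    row_max = ops.maximum(x, axis=-1, keep_dims=True)  # ReduceMax along the last axis
    probs = ops.softmax(axis=-1)(x)                    # softmax() returns an nn.Softmax cell
    total = ops.summation(x)                           # ReduceSum over all axes by default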
@@ -0,0 +1,481 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Runner."""
from time import time
from typing import Tuple, List, Optional

import numpy as np

from mindspore.train.summary_pb2 import Explain
import mindspore as ms
import mindspore.dataset as ds
from mindspore import log
from mindspore.ops.operations import ExpandDims
from mindspore.train.summary._summary_adapter import _convert_image_format, _make_image
from mindspore.train.summary.summary_record import SummaryRecord

from .benchmark import Localization
from .benchmark._attribution.metric import AttributionMetric
from .explanation._attribution._attribution import Attribution

_EXPAND_DIMS = ExpandDims()
_CMAP_0 = np.reshape(np.array([55, 25, 86, 255]), [1, 1, 4]) / 255
_CMAP_1 = np.reshape(np.array([255, 255, 0, 255]), [1, 1, 4]) / 255


def _normalize(img_np):
    """Normalize the numpy image array so that its values lie in [0, 255]."""
    max_ = img_np.max()
    min_ = img_np.min()
    normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
    return (normed * 255).astype(np.uint8)


def _make_rgba(saliency):
    """Make an RGBA image for the saliency map."""
    saliency = saliency.asnumpy().squeeze()
    saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min()).clip(1e-10)
    rgba = np.empty((saliency.shape[0], saliency.shape[1], 4))
    rgba[:, :, :] = np.expand_dims(saliency, 2)
    rgba = rgba * _CMAP_1 + (1 - rgba) * _CMAP_0
    rgba[:, :, -1] = saliency
    return rgba
class ExplainRunner:
    """
    High-level API for users to generate results with the explanation methods and the evaluation methods.

    After generating results with the explanation methods and the evaluation methods, the results will be written into
    a specified file with `mindspore.train.summary.SummaryRecord`. The stored content can be viewed using MindInsight.

    Args:
        summary_dir (str): The directory path to save the summary files which store the generated results.
            Default: "./"

    Examples:
        >>> # init a runner with a specified directory
        >>> summary_dir = "summary_dir"
        >>> runner = ExplainRunner(summary_dir)
    """

    def __init__(self, summary_dir: Optional[str] = "./"):
        self._summary_dir = summary_dir
        self._count = 0
        self._classes = None
        self._model = None

    def run(self,
            dataset: Tuple,
            explainers: List,
            benchmarkers: Optional[List] = None):
        """
        Generate results and write them into the summary files in `self._summary_dir`.

        Args:
            dataset (tuple): A tuple that contains a `mindspore.dataset` object for iteration and its labels.

                - dataset[0], a `mindspore.dataset` object to provide data to explain.
                - dataset[1], a list of strings that specifies the label names of the dataset.

            explainers (list): A list of explanation objects to generate attribution results.
            benchmarkers (list): A list of benchmark objects to generate evaluation results. Default: None

        Examples:
            >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
            >>> # obtain dataset object
            >>> dataset = get_dataset()
            >>> classes = ["cat", "dog", ...]
            >>> # load checkpoint to a network, e.g. resnet50
            >>> param_dict = load_checkpoint("checkpoint.ckpt")
            >>> net = resnet50(len(classes))
            >>> load_param_into_net(net, param_dict)
            >>> # bind net with its output activation
            >>> model = nn.SequentialCell([net, nn.Sigmoid()])
            >>> gbp = GuidedBackprop(model)
            >>> gradient = Gradient(model)
            >>> runner = ExplainRunner("./")
            >>> explainers = [gbp, gradient]
            >>> runner.run((dataset, classes), explainers)
        """
        if not isinstance(dataset, tuple):
            raise TypeError("Argument `dataset` must be a tuple.")
        if len(dataset) != 2:
            raise ValueError("Argument `dataset` should be a tuple with length = 2.")

        dataset, classes = dataset
        self._verify_data_form(dataset, benchmarkers)
        self._classes = classes

        # Validate the container types once, before iterating over the elements;
        # the original loop re-checked `explainers` (and, by a copy-paste slip,
        # checked `explainers` instead of `benchmarkers`) on every iteration.
        if not isinstance(explainers, list) or not explainers:
            raise ValueError("Argument `explainers` must be a non-empty list.")
        for exp in explainers:
            if not isinstance(exp, Attribution):
                raise TypeError("Argument `explainers` should be a list of objects of classes in "
                                "`mindspore.explainer.explanation._attribution`.")
        if benchmarkers is not None:
            if not isinstance(benchmarkers, list):
                raise TypeError("Argument `benchmarkers` must be a list.")
            for bench in benchmarkers:
                if not isinstance(bench, AttributionMetric):
                    raise TypeError("Argument `benchmarkers` should be a list of objects of classes in "
                                    "`mindspore.explainer.benchmark._attribution`.")

        self._model = explainers[0].model
        with SummaryRecord(self._summary_dir) as summary:
            print("Start running and writing......")
            begin = time()
            print("Start writing metadata.")

            explain = Explain()
            explain.metadata.label.extend(classes)
            exp_names = [exp.__class__.__name__ for exp in explainers]
            explain.metadata.explain_method.extend(exp_names)
            if benchmarkers is not None:
                bench_names = [bench.__class__.__name__ for bench in benchmarkers]
                explain.metadata.benchmark_method.extend(bench_names)

            summary.add_value("explainer", "metadata", explain)
            summary.record(1)
            print("Finish writing metadata.")

            now = time()
            print("Start running and writing inference data......")
            imageid_labels = self._run_inference(dataset, summary)
            print("Finish running and writing inference data. Time elapsed: {}s".format(time() - now))

            if benchmarkers is None:
                for exp in explainers:
                    start = time()
                    print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
                    self._count = 0
                    ds.config.set_seed(58)
                    for idx, next_element in enumerate(dataset):
                        now = time()
                        self._run_exp_step(next_element, exp, imageid_labels, summary)
                        print("Finish writing {}-th batch explanation data. Time elapsed: {}s".format(
                            idx, time() - now))
                    print("Finish running and writing explanation data for {}. Time elapsed: {}s".format(
                        exp.__class__.__name__, time() - start))
            else:
                for exp in explainers:
                    explain = Explain()
                    for bench in benchmarkers:
                        bench.reset()
                    print(f"Start running and writing explanation and benchmark data for {exp.__class__.__name__}.")
                    self._count = 0
                    start = time()
                    ds.config.set_seed(58)
                    for idx, next_element in enumerate(dataset):
                        now = time()
                        saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary)
                        print("Finish writing {}-th batch explanation data. Time elapsed: {}s".format(
                            idx, time() - now))
                        for bench in benchmarkers:
                            now = time()
                            self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
                            print("Finish running {}-th batch benchmark data for {}. Time elapsed: {}s".format(
                                idx, bench.__class__.__name__, time() - now))

                    for bench in benchmarkers:
                        benchmark = explain.benchmark.add()
                        benchmark.explain_method = exp.__class__.__name__
                        benchmark.benchmark_method = bench.__class__.__name__
                        benchmark.total_score = bench.performance
                        benchmark.label_score.extend(bench.class_performances)

                    print("Finish running and writing explanation and benchmark data for {}. "
                          "Time elapsed: {}s".format(exp.__class__.__name__, time() - start))
                    summary.add_value('explainer', 'benchmark', explain)
                    summary.record(1)
            print("Finish running and writing. Total time elapsed: {}s".format(time() - begin))
    @staticmethod
    def _verify_data_form(dataset, benchmarkers):
        """
        Verify the validity of the dataset.

        Args:
            dataset (`ds`): the user parsed dataset.
            benchmarkers (list[`AttributionMetric`]): the user parsed benchmarkers.
        """
        next_element = dataset.create_tuple_iterator().get_next()

        if len(next_element) not in [1, 2, 3]:
            raise ValueError("The dataset should provide [images] or [images, labels] or [images, labels, bboxes]"
                             " as columns.")

        if len(next_element) == 3:
            inputs, labels, bboxes = next_element
            if bboxes.shape[-1] != 4:
                raise ValueError("The third element of dataset should be bounding boxes with shape of "
                                 "[batch_size, num_ground_truth, 4].")
        else:
            # Guard against `benchmarkers` being None before iterating over it.
            if benchmarkers is not None and any(isinstance(bench, Localization) for bench in benchmarkers):
                raise ValueError("The dataset must provide bboxes if Localization is to be computed.")

        if len(next_element) == 2:
            inputs, labels = next_element
        if len(next_element) == 1:
            inputs = next_element[0]

        if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
            raise ValueError(
                "Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
                    inputs.shape))
        if len(inputs.shape) == 3:
            log.warning(
                "Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
                " dimension as batch data.".format(inputs.shape))

        if len(next_element) > 1:
            if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
                raise ValueError(
                    "Labels shape {} is unrecognizable: labels should not have more than two dimensions"
                    " with length greater than 1.".format(labels.shape))
    def _transform_data(self, inputs, labels, bboxes, ifbbox):
        """
        Transform the data from one iteration of the dataset to a unified form for the follow-up operations.

        Args:
            inputs (Tensor): the image data
            labels (Tensor): the labels
            bboxes (Tensor): the bounding boxes data
            ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes
                w.r.t label id will be returned. If False, the returned bboxes is the parsed bboxes.

        Returns:
            inputs (Tensor): the image data, unified to a 4D Tensor.
            labels (List[List[int]]): the ground truth labels.
            bboxes (Union[List[Dict], None, Tensor]): the bounding boxes
        """
        inputs = ms.Tensor(inputs, ms.float32)
        if len(inputs.shape) == 3:
            inputs = _EXPAND_DIMS(inputs, 0)
            if isinstance(labels, ms.Tensor):
                labels = ms.Tensor(labels, ms.int32)
                labels = _EXPAND_DIMS(labels, 0)
            if isinstance(bboxes, ms.Tensor):
                bboxes = ms.Tensor(bboxes, ms.int32)
                bboxes = _EXPAND_DIMS(bboxes, 0)

        input_len = len(inputs)
        if bboxes is not None and ifbbox:
            bboxes = ms.Tensor(bboxes, ms.int32)
            masks_lst = []
            labels = labels.asnumpy().reshape([input_len, -1])
            bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
            for idx, label in enumerate(labels):
                height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
                masks = {}
                for j, label_item in enumerate(label):
                    target = int(label_item)
                    if -1 < target < len(self._classes):
                        if target not in masks:
                            mask = np.zeros((1, 1, height, width))
                        else:
                            mask = masks[target]
                        x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
                        mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
                        masks[target] = mask
                masks_lst.append(masks)
            bboxes = masks_lst

        labels = ms.Tensor(labels, ms.int32)
        if len(labels.shape) == 1:
            labels_lst = [[int(i)] for i in labels.asnumpy()]
        else:
            labels = labels.asnumpy().reshape([input_len, -1])
            labels_lst = []
            for item in labels:
                labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._classes))))
        labels = labels_lst
        return inputs, labels, bboxes
    def _unpack_next_element(self, next_element, ifbbox=False):
        """
        Unpack a single iteration of the dataset.

        Args:
            next_element (Tuple): a single element iterated from the dataset object.
            ifbbox (bool): whether to preprocess bboxes in self._transform_data.

        Returns:
            Tuple, a unified tuple that contains image_data, labels, and bounding boxes.
        """
        if len(next_element) == 3:
            inputs, labels, bboxes = next_element
        elif len(next_element) == 2:
            inputs, labels = next_element
            bboxes = None
        else:
            inputs = next_element[0]
            labels = [[] for _ in inputs]
            bboxes = None
        inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
        return inputs, labels, bboxes

    @staticmethod
    def _make_label_batch(labels):
        """
        Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max
        length of all the rows in labels.

        Args:
            labels (List[List]): the union labels of a data batch.

        Returns:
            2D Tensor.
        """
        max_len = max([len(label) for label in labels])
        batch_labels = np.zeros((len(labels), max_len))

        for idx, _ in enumerate(batch_labels):
            length = len(labels[idx])
            batch_labels[idx, :length] = np.array(labels[idx])

        return ms.Tensor(batch_labels, ms.int32)
    def _run_inference(self, dataset, summary, threshold=0.5):
        """
        Run inference for the dataset and write the inference related data into summary.

        Args:
            dataset (`ds`): the parsed dataset
            summary (`SummaryRecord`): the summary object to store the data
            threshold (float): the threshold for prediction.

        Returns:
            imageid_labels (dict): a dict that maps image_id to the union of its ground truth and predicted labels.
        """
        imageid_labels = {}
        ds.config.set_seed(58)
        self._count = 0
        for j, next_element in enumerate(dataset):
            now = time()
            inputs, labels, _ = self._unpack_next_element(next_element)
            prob = self._model(inputs).asnumpy()
            for idx, inp in enumerate(inputs):
                gt_labels = labels[idx]
                gt_probs = [float(prob[idx][i]) for i in gt_labels]

                data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
                _, _, _, image_string = _make_image(_normalize(data_np))

                predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
                predicted_probs = [float(prob[idx][i]) for i in predicted_labels]

                union_labs = list(set(gt_labels + predicted_labels))
                imageid_labels[str(self._count)] = union_labs

                explain = Explain()
                explain.image_id = str(self._count)
                explain.image_data = image_string
                summary.add_value("explainer", "image", explain)

                explain = Explain()
                explain.image_id = str(self._count)
                explain.ground_truth_label.extend(gt_labels)
                explain.inference.ground_truth_prob.extend(gt_probs)
                explain.inference.predicted_label.extend(predicted_labels)
                explain.inference.predicted_prob.extend(predicted_probs)
                summary.add_value("explainer", "inference", explain)

                summary.record(1)

                self._count += 1
            print("Finish running and writing {}-th batch inference data. Time elapsed: {}s".format(j, time() - now))
        return imageid_labels
    def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
        """
        Run the explanation for each step and write explanation results into summary.

        Args:
            next_element (Tuple): data of one step
            explainer (_Attribution): an Attribution object to generate saliency maps.
            imageid_labels (dict): a dict that maps the image_id and its union labels.
            summary (SummaryRecord): the summary object to store the data

        Returns:
            List of dict that maps label to its corresponding saliency map.
        """
        inputs, labels, _ = self._unpack_next_element(next_element)
        count = self._count
        unions = []
        for _ in range(len(labels)):
            unions_labels = imageid_labels[str(count)]
            unions.append(unions_labels)
            count += 1

        batch_unions = self._make_label_batch(unions)
        saliency_dict_lst = []

        batch_saliency_full = []
        for i in range(len(batch_unions[0])):
            batch_saliency = explainer(inputs, batch_unions[:, i])
            batch_saliency_full.append(batch_saliency)

        for idx, union in enumerate(unions):
            saliency_dict = {}
            explain = Explain()
            explain.image_id = str(self._count)
            for k, lab in enumerate(union):
                saliency = batch_saliency_full[k][idx:idx + 1]
                saliency_dict[lab] = saliency

                saliency_np = _make_rgba(saliency)
                _, _, _, saliency_string = _make_image(_normalize(saliency_np))

                explanation = explain.explanation.add()
                explanation.explain_method = explainer.__class__.__name__
                explanation.label = lab
                explanation.heatmap = saliency_string

            summary.add_value("explainer", "explanation", explain)
            summary.record(1)

            self._count += 1
            saliency_dict_lst.append(saliency_dict)
        return saliency_dict_lst
    def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
        """
        Run the explanation and evaluation for each step and aggregate the results into the benchmarker.

        Args:
            next_element (Tuple): data of one step
            explainer (`_Attribution`): an Attribution object to generate saliency maps.
            benchmarker (`AttributionMetric`): the metric object that aggregates the evaluation results.
            saliency_dict_lst (List[dict]): a list of dicts that map each label to its saliency map.
        """
        inputs, labels, _ = self._unpack_next_element(next_element)
        for idx, inp in enumerate(inputs):
            inp = _EXPAND_DIMS(inp, 0)
            saliency_dict = saliency_dict_lst[idx]
            for label, saliency in saliency_dict.items():
                if isinstance(benchmarker, Localization):
                    _, _, bboxes = self._unpack_next_element(next_element, True)
                    if label in labels[idx]:
                        res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
                                                   saliency=saliency)
                        benchmarker.aggregate(res, label)
                else:
                    res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
                    benchmarker.aggregate(res, label)
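Putting the runner together with the metrics exported below, a minimal end-to-end sketch (the network constructor, checkpoint path, dataset pipeline, and the Faithfulness constructor arguments are placeholders):

    from mindspore import nn
    from mindspore.train.serialization import load_checkpoint, load_param_into_net
    from mindspore.explainer.explanation import Gradient
    from mindspore.explainer.benchmark import Faithfulness

    classes = ["cat", "dog"]                          # illustrative label names
    param_dict = load_checkpoint("checkpoint.ckpt")   # hypothetical checkpoint file
    net = resnet50(len(classes))                      # assumed network constructor
    load_param_into_net(net, param_dict)
    model = nn.SequentialCell([net, nn.Sigmoid()])    # bind the output activation

    explainers = [Gradient(model)]
    benchmarkers = [Faithfulness(num_labels=len(classes))]  # hypothetical args

    runner = ExplainRunner("./")
    runner.run((dataset, classes), explainers, benchmarkers)  # `dataset` assumed prepared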
@@ -0,0 +1,285 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for MindExplain."""

__all__ = [
    'ForwardProbe',
    'calc_auc',
    'calc_correlation',
    'format_tensor_to_ndarray',
    'generate_one_hot',
    'rank_pixels',
    'resize',
    'retrieve_layer_by_name',
    'retrieve_layer',
    'unify_inputs',
    'unify_targets'
]

from typing import Tuple, Union

import numpy as np
from PIL import Image

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops.operations as op

_Array = np.ndarray
_Module = nn.Cell
_Tensor = ms.Tensor
def generate_one_hot(indices, depth):
    r"""
    Simple wrapper of the OneHot operation; the on_value and off_value are fixed
    to 1.0 and 0.0.
    """
    on_value = ms.Tensor(1.0, ms.float32)
    off_value = ms.Tensor(0.0, ms.float32)
    weights = op.OneHot()(indices, depth, on_value, off_value)
    return weights


def unify_inputs(inputs) -> tuple:
    """Unify inputs of explainer."""
    if isinstance(inputs, tuple):
        return inputs
    if isinstance(inputs, ms.Tensor):
        inputs = (inputs,)
    elif isinstance(inputs, np.ndarray):
        inputs = (ms.Tensor(inputs),)
    else:
        raise TypeError(
            'inputs must be one of [tuple, ms.Tensor or np.ndarray], '
            'but got {}'.format(type(inputs)))
    return inputs


def unify_targets(targets) -> ms.Tensor:
    """Unify target labels of explainer."""
    if isinstance(targets, ms.Tensor):
        return targets
    if isinstance(targets, list):
        targets = ms.Tensor(targets, dtype=ms.int32)
    # `elif` is required here: with a plain `if`, a list input already converted
    # above would fail the int check and fall through to the TypeError.
    elif isinstance(targets, int):
        targets = ms.Tensor([targets], dtype=ms.int32)
    else:
        raise TypeError(
            'targets must be one of [int, list or ms.Tensor], '
            'but got {}'.format(type(targets)))
    return targets
def retrieve_layer_by_name(model: _Module, layer_name: str):
    """
    Retrieve the layer in the model by the given layer_name.

    Args:
        model (_Module): model which contains the target layer
        layer_name (str): name of target layer

    Return:
        - target_layer (_Module)

    Raise:
        ValueError: if a module with the given layer_name is not found
            in the model.
    """
    if not isinstance(layer_name, str):
        raise TypeError('layer_name should be type of str, but received {}.'
                        .format(type(layer_name)))
    if not layer_name:
        return model

    for name, cell in model.cells_and_names():
        if name == layer_name:
            return cell

    raise ValueError(
        'Cannot match {}, please provide a target layer '
        'in the given model.'.format(layer_name))


def retrieve_layer(model: _Module, target_layer: Union[str, _Module] = ''):
    """
    Retrieve the layer in the model.

    'target_layer' can be either a layer name or a Cell object. Given the layer
    name, the method will search through the model and return the matched layer.
    If a Cell object is provided, it will check whether the given layer exists
    in the model. If the target layer is not found in the model, a ValueError
    will be raised.

    Args:
        model (_Module): the model to retrieve the target layer from
        target_layer (Union[str, _Module]): target layer to retrieve. Can be
            either a string (layer name) or a Cell object. If '' is provided,
            the input model will be returned.

    Return:
        target layer (_Module)
    """
    if isinstance(target_layer, str):
        target_layer = retrieve_layer_by_name(model, target_layer)
        return target_layer

    if isinstance(target_layer, _Module):
        for _, cell in model.cells_and_names():
            if target_layer is cell:
                return target_layer
        raise ValueError(
            'Model does not contain cell {}, fail to probe.'.format(target_layer))
    raise TypeError('target_layer must have type of str or ms.nn.Cell, '
                    'but received {}'.format(type(target_layer)))
class ForwardProbe:
    """
    Probe to capture the output of a specific layer in a given model.

    Args:
        target_layer (_Module): the target layer (Cell) whose forward output
            is to be captured.
    """

    def __init__(self, target_layer: _Module):
        self._target_layer = target_layer
        self._original_construct = self._target_layer.construct
        self._intermediate_tensor = None

    @property
    def value(self):
        """The intermediate tensor captured during the last forward pass."""
        return self._intermediate_tensor

    def __enter__(self):
        self._target_layer.construct = self._new_construct
        return self

    def __exit__(self, *_):
        self._target_layer.construct = self._original_construct
        self._intermediate_tensor = None
        return False

    def _new_construct(self, *inputs):
        outputs = self._original_construct(*inputs)
        self._intermediate_tensor = outputs
        return outputs


def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray:
    """Unify `mindspore.Tensor` and `np.ndarray` to `np.ndarray`."""
    if isinstance(x, ms.Tensor):
        x = x.asnumpy()

    if not isinstance(x, np.ndarray):
        raise TypeError('input should be one of [ms.Tensor or np.ndarray],'
                        ' but received {}'.format(type(x)))
    return x
def calc_correlation(x: Union[ms.Tensor, np.ndarray],
                     y: Union[ms.Tensor, np.ndarray]) -> float:
    """Calculate the negative Pearson correlation coefficient between two arrays."""
    x = format_tensor_to_ndarray(x)
    y = format_tensor_to_ndarray(y)
    # The sign is flipped on purpose: for deletion-style perturbations the raw
    # correlation between feature importance and predicted probability is negative.
    faithfulness = -np.corrcoef(x, y)[0, 1]
    return faithfulness


def calc_auc(x: _Array) -> float:
    """Calculate the Area Under the Curve."""
    # take the mean over the spatial axes for multiple patches if the model is fully convolutional
    if len(x.shape) == 4:
        # Reduce both spatial axes at once; reducing axis 2 first would shift
        # axis 3 out of range.
        x = np.mean(x, axis=(2, 3))
    auc = (x.sum() - x[0] - x[-1]) / len(x)
    return float(auc)
def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
    """
    Generate the rank order for every pixel in a 2D array.

    The rank order ranges from 0 to (num_pixel - 1). If descending is True, the
    rank order is generated in descending order, otherwise in ascending order.

    Example:
        x = np.array([[4., 3., 1.], [5., 9., 1.]])
        rank_pixels(x, descending=True)
        >> np.array([[2, 3, 4], [1, 0, 5]])
        rank_pixels(x, descending=False)
        >> np.array([[3, 2, 0], [4, 5, 1]])
    """
    if len(inputs.shape) != 2:
        raise ValueError('Only 2D arrays are currently supported.')

    flatten_saliency = inputs.reshape(-1)
    factor = -1 if descending else 1
    sorted_arg = np.argsort(factor * flatten_saliency, axis=0)
    flatten_rank = np.zeros_like(sorted_arg)
    flatten_rank[sorted_arg] = np.arange(0, sorted_arg.shape[0])
    rank_map = flatten_rank.reshape(inputs.shape)
    return rank_map
def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor:
    """
    Resize the intermediate layer attribution to the same size as the inputs.

    Args:
        inputs (ms.Tensor): the input tensor to be resized
        size (Tuple[int, int]): the target size to resize to, as (height, width)
        mode (str): the resize mode. Options: 'nearest_neighbor', 'bilinear'

    Returns:
        outputs (ms.Tensor): the resized tensor.

    Raises:
        ValueError: the resize mode is not in ['nearest_neighbor', 'bilinear'].
    """
    h, w = size
    if mode == 'nearest_neighbor':
        resize_nn = op.ResizeNearestNeighbor((h, w))
        outputs = resize_nn(inputs)

    elif mode == 'bilinear':
        inputs_np = inputs.asnumpy()
        inputs_np = np.transpose(inputs_np, [0, 2, 3, 1])
        array_lst = []
        for inp in inputs_np:
            array = (np.repeat(inp, 3, axis=2) * 255).astype(np.uint8)
            image = Image.fromarray(array)
            # PIL expects (width, height), while `size` is (height, width).
            image = image.resize((w, h), resample=Image.BILINEAR)
            array = np.asarray(image).astype(np.float32) / 255
            array_lst.append(array[:, :, 0:1])
        resized_np = np.transpose(array_lst, [0, 3, 1, 2])
        outputs = ms.Tensor(resized_np, inputs.dtype)
    else:
        raise ValueError('Unsupported resize mode {}.'.format(mode))

    return outputs
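A minimal usage sketch of the probe utilities (the network constructor, layer name, and `inputs` tensor are placeholders):

    net = resnet50(10)                       # assumed network constructor
    layer = retrieve_layer(net, 'layer4')    # hypothetical layer name
    with ForwardProbe(layer) as probe:
        logits = net(inputs)                 # forward pass through the probed model
        feature_map = probe.value            # intermediate output captured by the probe
    # Outside the `with` block, the original construct is restored and the value cleared.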
@@ -0,0 +1,23 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Predefined XAI metrics."""
from ._attribution.faithfulness import Faithfulness
from ._attribution.localization import Localization

__all__ = [
    "Faithfulness",
    "Localization"
]
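Given these exports, the metrics can be imported directly from the package (the package root `mindspore.explainer` is assumed from the PR layout):

    from mindspore.explainer.benchmark import Faithfulness, Localization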
@@ -0,0 +1,23 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Predefined XAI metrics."""
from .faithfulness import Faithfulness
from .localization import Localization

__all__ = [
    "Faithfulness",
    "Localization"
]
@@ -0,0 +1,593 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Faithfulness."""
import math
from typing import Callable, Optional, Union, Tuple

import numpy as np
from scipy.ndimage.filters import gaussian_filter

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops.operations as op

from .metric import AttributionMetric
from ..._utils import calc_correlation, calc_auc, format_tensor_to_ndarray, rank_pixels
from ...explanation._attribution._attribution import Attribution as _Attribution

_Array = np.ndarray
_Explainer = Union[_Attribution, Callable]
_Label = Union[int, ms.Tensor]
_Module = nn.Cell
def _calc_feature_importance(saliency: _Array, masks: _Array) -> _Array:
    """Calculate feature importance w.r.t the given masks."""
    feature_importance = []
    num_perturbations = masks.shape[0]
    for i in range(num_perturbations):
        patch_feature_importance = saliency[masks[i]].sum() / masks[i].sum()
        feature_importance.append(patch_feature_importance)

    feature_importance = np.array(feature_importance, dtype=np.float32)
    return feature_importance
class _BaseReplacement:
    """
    Base class of generators that produce replacement values for perturbations.

    Args:
        kwargs: Optional args for generating the replacement. Derived classes need to
            add the necessary arg names and default values to '_necessary_args'.
            If an argument has no default value, the value should be set to
            'EMPTY' to mark it as required. Initializing an object will
            check the given kwargs w.r.t '_necessary_args'.

    Raise:
        ValueError: Raised when the provided kwargs do not contain the necessary
            arg names marked with 'EMPTY'.
    """
    _necessary_args = {}

    def __init__(self, **kwargs):
        self._replace_args = self._necessary_args.copy()
        for key, value in self._replace_args.items():
            if key in kwargs.keys():
                self._replace_args[key] = kwargs[key]
            elif key not in kwargs.keys() and value == 'EMPTY':
                raise ValueError(f"Missing keyword arg {key} for {self.__class__.__name__}.")

    __call__: Callable
    """
    Generate the replacement for perturbations. Derived classes should overwrite this
    function to generate different replacements for perturbing.

    Args:
        inputs (_Array): Array to be perturbed.

    Returns:
        - replacement (_Array): Array that provides alternative pixels for every
          position in the given inputs. The returned array should have the same
          shape as inputs.
    """


class Constant(_BaseReplacement):
    """Generator to provide a constant-value replacement for perturbations."""
    _necessary_args = {'base_value': 'EMPTY'}

    def __call__(self, inputs: _Array) -> _Array:
        replacement = np.ones_like(inputs, dtype=np.float32)
        replacement *= self._replace_args['base_value']
        return replacement
class GaussianBlur(_BaseReplacement):
    """Generator to provide Gaussian-blurred inputs for perturbations."""
    _necessary_args = {'sigma': 0.7}

    def __call__(self, inputs: _Array) -> _Array:
        sigma = self._replace_args['sigma']
        replacement = gaussian_filter(inputs, sigma=sigma)
        return replacement
class Perturb:
    """
    Perturbation generator to generate perturbations for a given array.

    Args:
        perturb_percent (float): percentage of pixels to perturb
        perturb_mode (str): specify the perturbing mode, through deleting or
            inserting pixels. Current support: ['Deletion', 'Insertion'].
        is_accumulate (bool): whether to accumulate the former perturbations into
            the later perturbations.
        perturb_pixel_per_step (int, optional): number of pixels to perturb
            for each perturbation. If perturb_pixel_per_step is None, the actual
            perturb_pixel_per_step will be calculated by:
            num_image_pixel * perturb_percent / num_perturb_steps.
            Default: None
        num_perturbations (int, optional): number of perturbations. If
            num_perturbations is None, it will be calculated by:
            num_image_pixel * perturb_percent / perturb_pixel_per_step.
            Default: None
    """

    def __init__(self,
                 perturb_percent: float,
                 perturb_mode: str,
                 is_accumulate: bool,
                 perturb_pixel_per_step: Optional[int] = None,
                 num_perturbations: Optional[int] = None):
        self._perturb_percent = perturb_percent
        self._perturb_mode = perturb_mode
        self._pixel_per_step = perturb_pixel_per_step
        self._num_perturbations = num_perturbations
        self._is_accumulate = is_accumulate

    @staticmethod
    def _assign(x: _Array, y: _Array, masks: _Array):
        """Assign replacement values to the masked pixels of the perturbations."""
        if masks.dtype != bool:
            raise TypeError('The param "masks" should be an array of bool, but received {}'
                            .format(masks.dtype))
        for i in range(x.shape[0]):
            x[i][:, masks[i]] = y[:, masks[i]]
    def _generate_mask(self, saliency_rank: _Array) -> _Array:
        """Generate masks for perturbations based on the given saliency ranks."""
        if len(saliency_rank.shape) != 2:
            raise ValueError(f'The param "saliency_rank" should be 2-dim, but receive {len(saliency_rank.shape)}.')

        num_pixels = saliency_rank.shape[0] * saliency_rank.shape[1]
        if self._pixel_per_step:
            pixel_per_step = self._pixel_per_step
            num_perturbations = math.floor(
                num_pixels * self._perturb_percent / self._pixel_per_step)
        elif self._num_perturbations:
            pixel_per_step = math.floor(
                num_pixels * self._perturb_percent / self._num_perturbations)
            num_perturbations = self._num_perturbations
        else:
            raise ValueError("Must provide either pixel_per_step or num_perturbations.")

        masks = np.zeros(
            (num_perturbations, saliency_rank.shape[0], saliency_rank.shape[1]),
            dtype=np.bool)
        low_bound = 0
        up_bound = low_bound + pixel_per_step
        factor = 0 if self._is_accumulate else 1

        for i in range(num_perturbations):
            masks[i, ((saliency_rank >= low_bound)
                      & (saliency_rank < up_bound))] = True
            low_bound = up_bound * factor
            up_bound += pixel_per_step

        if len(masks.shape) == 3:
            return masks
        raise ValueError(f'Invalid masks shape {len(masks.shape)}, expect 3-dim.')
    def __call__(self,
                 inputs: _Array,
                 saliency: _Array,
                 reference: _Array,
                 return_mask: bool = False,
                 ) -> Union[_Array, Tuple[_Array, ...]]:
        """
        Generate perturbations of the given array.

        Args:
            inputs (_Array): input array to perturb
            saliency (_Array): saliency map
            reference (_Array): array that provides the replacement values for the perturbed pixels
            return_mask (bool): whether to return the masks used for generating
                the perturbations. The masks can be used to calculate the
                average feature importance of pixels perturbed at each step.

        Return:
            perturbations (_Array)
            masks (_Array): returned when return_mask is set to True.
        """
        if not np.array_equal(inputs.shape, reference.shape):
            raise ValueError('reference must have the same shape as inputs.')

        saliency_rank = rank_pixels(saliency, descending=True)
        masks = self._generate_mask(saliency_rank)
        num_perturbations = masks.shape[0]

        if self._perturb_mode == 'Insertion':
            inputs, reference = reference, inputs

        perturbations = np.tile(
            inputs, (num_perturbations, *[1] * len(inputs.shape)))

        Perturb._assign(perturbations, reference, masks)

        if return_mask:
            return perturbations, masks
        return perturbations
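# A minimal sketch of Perturb on one sample, kept as a comment so the module
# stays import-clean (shapes and values are illustrative):
#
#     perturber = Perturb(perturb_percent=0.5, perturb_mode='Deletion',
#                         is_accumulate=False, perturb_pixel_per_step=2)
#     reference = Constant(base_value=0.0)(image)   # constant replacement values
#     perturbations, masks = perturber(image, saliency, reference, return_mask=True)
#     # perturbations: (num_perturbations, *image.shape); masks: (num_perturbations, H, W)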
class _FaithfulnessHelper:
    """Base class for faithfulness calculators."""
    _support = [Constant, GaussianBlur]

    def __init__(self,
                 perturb_percent: float,
                 perturb_mode: str,
                 perturb_method: str,
                 is_accumulate: bool,
                 perturb_pixel_per_step: Optional[int] = None,
                 num_perturbations: Optional[int] = None,
                 **kwargs):
        self._get_reference = None
        for method in self._support:
            if perturb_method == method.__name__:
                self._get_reference = method(**kwargs)
        if self._get_reference is None:
            raise ValueError(
                'The param "perturb_method" should be one of {}.'.format([x.__name__ for x in self._support]))

        self._perturb = Perturb(perturb_percent=perturb_percent,
                                perturb_mode=perturb_mode,
                                perturb_pixel_per_step=perturb_pixel_per_step,
                                num_perturbations=num_perturbations,
                                is_accumulate=is_accumulate)

    calc_faithfulness: Callable
    """
    Method used to calculate faithfulness for the given inputs, target label, and
    saliency. Derived classes should implement this method.

    Args:
        inputs (_Array): sample to calculate the faithfulness score for
        model (_Module): the model to explain
        targets (_Label): the target label to explain on.
        saliency (_Array): saliency map of the given inputs and targets from the
            explainer.

    Return:
        - faithfulness (float): faithfulness score
    """
| class NaiveFaithfulness(_FaithfulnessHelper): | |||||
| """ | |||||
| Calculator for naive faithfulness. | |||||
| Naive faithfulness, the metric replace several pixels on original image by | |||||
| specific method for each perturbations. The metric predicts on the perturbed | |||||
| images and record a series of probabilities. Then calculates the | |||||
| correlation between prob distribution and averaged feature importance. | |||||
| Higher correlation indicates better faithfulness. | |||||
| Args: | |||||
| perturb_percent (float): percentage of pixels to perturb | |||||
| perturb_method (str): specify the method to replace the pixel. | |||||
| Current support: ['Constant', 'GaussianBlur'] | |||||
| is_accumulate (bool): whether to accumulate the former perturbations to | |||||
| the later perturbations. | |||||
| Default: False. | |||||
| perturb_pixel_per_step (Optional[int]): number of pixel to perturb | |||||
| for each perturbation. If perturb_pixel_per_step is None, actual | |||||
| perturb_pixel_per_step will be calculate by: | |||||
| num_image_pixel * perturb_percent / num_perturb_steps. | |||||
| Default: None | |||||
| num_perturbations (Optional[int]): number of perturbations. If | |||||
| num_perturbations if None, it will be calculated by: | |||||
| num_image_pixel * perturb_percent / perturb_pixel_per_step. | |||||
| Default: None | |||||
| kwargs: specific perturb_method will require | |||||
| different arguments. Below lists required args for each method. | |||||
| 'Constant': base_value (int) | |||||
| 'GaussianBlur': sigma (float): 0.7 | |||||
| """ | |||||
| def __init__(self, | |||||
| perturb_percent: float, | |||||
| perturb_method: str, | |||||
| is_accumulate: bool = False, | |||||
| perturb_pixel_per_step: Optional[int] = None, | |||||
| num_perturbations: Optional[int] = None, | |||||
| **kwargs): | |||||
| super(NaiveFaithfulness, self).__init__( | |||||
| perturb_percent=perturb_percent, | |||||
| perturb_mode='Deletion', | |||||
| perturb_method=perturb_method, | |||||
| is_accumulate=is_accumulate, | |||||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||||
| num_perturbations=num_perturbations, | |||||
| **kwargs) | |||||
| def calc_faithfulness(self, | |||||
| inputs: _Array, | |||||
| model: _Module, | |||||
| targets: _Label, | |||||
| saliency: _Array) -> np.ndarray: | |||||
| """ | |||||
| Calculate naive faithfulness. | |||||
| Args: | |||||
| inputs (_Array): sample to calculate faithfulness score on. | |||||
| model (_Module): model to explain. | |||||
| targets (_Label): target label to explain on. | |||||
| saliency (_Array): saliency map of the given inputs and targets from the | |||||
| explainer. | |||||
| Return: | |||||
| - faithfulness (np.ndarray): faithfulness score | |||||
| """ | |||||
| reference = self._get_reference(inputs) | |||||
| perturbations, masks = self._perturb( | |||||
| inputs, saliency, reference, return_mask=True) | |||||
| feature_importance = _calc_feature_importance(saliency, masks) | |||||
| perturbations = ms.Tensor(perturbations, dtype=ms.float32) | |||||
| predictions = model(perturbations).asnumpy()[:, targets] | |||||
| faithfulness = calc_correlation(feature_importance, predictions) | |||||
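| # map the correlation coefficient from [-1, 1] to [0, 1] | |||||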
| normalized_faithfulness = (faithfulness + 1) / 2 | |||||
| return np.array([normalized_faithfulness], np.float) | |||||
| class DeletionAUC(_FaithfulnessHelper): | |||||
| """ Calculator for deletion AUC. | |||||
| For Deletion AUC, the metric accumulatively replaces pixels of the original | |||||
| image through the specified 'perturb_method', predicts on the perturbed | |||||
| images and records a series of probabilities. The metric then calculates the | |||||
| AUC of the probability variation curve during perturbations. Faithfulness is | |||||
| defined as (1 - deletion_AUC). A higher score indicates better faithfulness | |||||
| of the explanation. | |||||
| Args: | |||||
| perturb_percent (float): percentage of pixels to perturb | |||||
| perturb_method (str): specify the method to replace the pixel. | |||||
| Current support: ['Constant', 'GaussianBlur'] | |||||
| perturb_pixel_per_step (Optional[int]): number of pixels to perturb | |||||
| at each perturbation step. If perturb_pixel_per_step is None, it | |||||
| will be calculated as: | |||||
| num_image_pixel * perturb_percent / num_perturb_steps. | |||||
| Default: None | |||||
| num_perturbations (Optional[int]): number of perturbations. If | |||||
| num_perturbations is None, it will be calculated as: | |||||
| num_image_pixel * perturb_percent / perturb_pixel_per_step. | |||||
| Default: None | |||||
| kwargs: each perturb_method requires different arguments. Required | |||||
| args for each method are listed below. | |||||
| 'Constant': base_value (int) | |||||
| 'GaussianBlur': sigma (float): 0.7 | |||||
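| Examples: | |||||
| >>> # a minimal usage sketch (hypothetical values; the helper is normally | |||||
| >>> # constructed internally by `Faithfulness` rather than used directly): | |||||
| >>> helper = DeletionAUC(perturb_percent=0.5, perturb_method='Constant', num_perturbations=100, base_value=0.0) | |||||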
| """ | |||||
| def __init__(self, | |||||
| perturb_percent: float, | |||||
| perturb_method: str, | |||||
| perturb_pixel_per_step: Optional[int] = None, | |||||
| num_perturbations: Optional[int] = None, | |||||
| **kwargs): | |||||
| super(DeletionAUC, self).__init__( | |||||
| perturb_percent=perturb_percent, | |||||
| perturb_mode='Deletion', | |||||
| perturb_method=perturb_method, | |||||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||||
| num_perturbations=num_perturbations, | |||||
| is_accumulate=True, | |||||
| **kwargs) | |||||
| def calc_faithfulness(self, | |||||
| inputs: _Array, | |||||
| model: _Module, | |||||
| targets: _Label, | |||||
| saliency: _Array) -> np.ndarray: | |||||
| """ | |||||
| Calculate faithfulness through deletion AUC. | |||||
| Args: | |||||
| inputs (_Array): sample to calculate faithfulness score on. | |||||
| model (_Module): model to explain. | |||||
| targets (_Label): target label to explain on. | |||||
| saliency (_Array): saliency map of the given inputs and targets from the | |||||
| explainer. | |||||
| Return: | |||||
| - faithfulness (np.ndarray): faithfulness score | |||||
| """ | |||||
| reference = self._get_reference(inputs) | |||||
| perturbations = self._perturb(inputs, saliency, reference) | |||||
| perturbations = ms.Tensor(perturbations, dtype=ms.float32) | |||||
| predictions = model(perturbations).asnumpy()[:, targets] | |||||
| input_tensor = op.ExpandDims()(ms.Tensor(inputs, ms.float32), 0) | |||||
| original_output = model(input_tensor).asnumpy()[:, targets] | |||||
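| # AUC of the probability-drop curve (original prediction minus perturbed predictions) | |||||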
| auc = calc_auc(original_output - predictions) | |||||
| return np.array([1 - auc]) | |||||
| class InsertionAUC(_FaithfulnessHelper): | |||||
| """ Calculator for insertion AUC. | |||||
| For Insertion AUC, the metric accumulatively replaces pixels of a reference | |||||
| image by pixels from the original image, like inserting pixels from the | |||||
| original image into the reference. The reference is generated through the | |||||
| specified 'perturb_method'. The metric predicts on the perturbed images and | |||||
| records a series of probabilities, then calculates the AUC of the probability | |||||
| variation curve during perturbations. Faithfulness is defined as the insertion | |||||
| AUC. A higher score indicates better faithfulness of the explanation. | |||||
| Args: | |||||
| perturb_percent (float): percentage of pixels to perturb | |||||
| perturb_method (str): specify the method to replace the pixel. | |||||
| Current support: ['Constant', 'GaussianBlur'] | |||||
| perturb_pixel_per_step (Optional[int]): number of pixels to perturb | |||||
| at each perturbation step. If perturb_pixel_per_step is None, it | |||||
| will be calculated as: | |||||
| num_image_pixel * perturb_percent / num_perturb_steps. | |||||
| Default: None | |||||
| num_perturbations (Optional[int]): number of perturbations. If | |||||
| num_perturbations is None, it will be calculated as: | |||||
| num_image_pixel * perturb_percent / perturb_pixel_per_step. | |||||
| Default: None | |||||
| kwargs: each perturb_method requires different arguments. Required | |||||
| args for each method are listed below. | |||||
| 'Constant': base_value (int) | |||||
| 'GaussianBlur': sigma (float): 0.7 | |||||
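| Examples: | |||||
| >>> # a minimal usage sketch (hypothetical values; the helper is normally | |||||
| >>> # constructed internally by `Faithfulness` rather than used directly): | |||||
| >>> helper = InsertionAUC(perturb_percent=0.5, perturb_method='GaussianBlur', num_perturbations=100, sigma=0.7) | |||||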
| """ | |||||
| def __init__(self, | |||||
| perturb_percent: float, | |||||
| perturb_method: str, | |||||
| perturb_pixel_per_step: Optional[int] = None, | |||||
| num_perturbations: Optional[int] = None, | |||||
| **kwargs): | |||||
| super(InsertionAUC, self).__init__( | |||||
| perturb_percent=perturb_percent, | |||||
| perturb_mode='Insertion', | |||||
| perturb_method=perturb_method, | |||||
| perturb_pixel_per_step=perturb_pixel_per_step, | |||||
| num_perturbations=num_perturbations, | |||||
| is_accumulate=True, | |||||
| **kwargs) | |||||
| def calc_faithfulness(self, | |||||
| inputs: _Array, | |||||
| model: _Module, | |||||
| targets: _Label, | |||||
| saliency: _Array) -> np.ndarray: | |||||
| """ | |||||
| Calculate faithfulness through insertion AUC. | |||||
| Args: | |||||
| inputs (_Array): sample to calculate faithfulness score on. | |||||
| model (_Module): model to explain. | |||||
| targets (_Label): target label to explain on. | |||||
| saliency (_Array): saliency map of the given inputs and targets from the | |||||
| explainer. | |||||
| Return: | |||||
| - faithfulness (np.ndarray): faithfulness score | |||||
| """ | |||||
| reference = self._get_reference(inputs) | |||||
| perturbations = self._perturb(inputs, saliency, reference) | |||||
| perturbations = ms.Tensor(perturbations, dtype=ms.float32) | |||||
| predictions = model(perturbations).asnumpy()[:, targets] | |||||
| base_tensor = op.ExpandDims()(ms.Tensor(reference, ms.float32), 0) | |||||
| base_outputs = model(base_tensor).asnumpy()[:, targets] | |||||
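| # AUC of the probability-gain curve over the reference prediction | |||||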
| auc = calc_auc(predictions - base_outputs) | |||||
| return np.array([auc]) | |||||
| class Faithfulness(AttributionMetric): | |||||
| """ | |||||
| Provides evaluation of the faithfulness of XAI explanations. | |||||
| Faithfulness first generates a saliency map with the given explainer and then calculates faithfulness based on | |||||
| the specified faithfulness metric. | |||||
| Args: | |||||
| num_labels (int): number of labels. | |||||
| metric (str): the specific metric to quantify faithfulness. | |||||
| Options: 'DeletionAUC', 'InsertionAUC', 'NaiveFaithfulness'. | |||||
| Default: 'NaiveFaithfulness'. | |||||
| Examples: | |||||
| >>> # init a `Faithfulness` object | |||||
| >>> num_labels = 10 | |||||
| >>> metric = "InsertionAUC" | |||||
| >>> faithfulness = Faithfulness(num_labels, metric) | |||||
| """ | |||||
| _methods = [NaiveFaithfulness, DeletionAUC, InsertionAUC] | |||||
| def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness"): | |||||
| super(Faithfulness, self).__init__(num_labels) | |||||
| perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument | |||||
| perturb_method = "Constant" # perturbation method, all the perturbed pixels will be set to constant | |||||
| num_perturb_pixel_per_step = None # number of pixels for each perturbation step | |||||
| num_perturb_steps = 100 # separate the perturbation progress into 100 steps. | |||||
| base_value = 0.0 # the pixel value set for the perturbed pixels | |||||
| self._verify_metrics(metric) | |||||
| for method in self._methods: | |||||
| if metric == method.__name__: | |||||
| self._faithfulness_helper = method( | |||||
| perturb_percent=perturb_percent, | |||||
| perturb_method=perturb_method, | |||||
| perturb_pixel_per_step=num_perturb_pixel_per_step, | |||||
| num_perturbations=num_perturb_steps, | |||||
| base_value=base_value | |||||
| ) | |||||
| def evaluate(self, explainer, inputs, targets, saliency=None): | |||||
| """ | |||||
| Evaluate faithfulness on a single data sample. | |||||
| Args: | |||||
| explainer (Explainer): An explainer instance; see | |||||
| `mindspore/explainer/explanation` for available explainers. | |||||
| inputs (Tensor): data sample. Currently only a single sample is supported at each call. | |||||
| targets (Union[int, Tensor]): A target label to evaluate on. | |||||
| saliency (Tensor): A saliency tensor. | |||||
| Return: | |||||
| np.ndarray: result of faithfulness evaluated on explainer. | |||||
| Notes: | |||||
| To apply `Faithfulness` to evaluate an explainer, the explainer must be initialized with a network that | |||||
| contains the output activation function. Otherwise, the results will not be correct. | |||||
| Examples: | |||||
| >>> # init an explainer, the network should contain the output activation function. | |||||
| >>> network = nn.SequentialCell([resnet50, nn.Sigmoid()]) | |||||
| >>> gradient = Gradient(network) | |||||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||||
| >>> targets = 5 | |||||
| >>> # usage 1: input the explainer and the data to be explained, | |||||
| >>> # calculate the faithfulness with the specified metric | |||||
| >>> res = faithfulness.evaluate(gradient, inputs, targets) | |||||
| >>> # usage 2: input the generated saliency map | |||||
| >>> saliency = gradient(inputs, targets) | |||||
| >>> res = faithfulness.evaluate(gradient, inputs, targets, saliency) | |||||
| """ | |||||
| self._check_evaluate_param(explainer, inputs, targets, saliency) | |||||
| if saliency is None: | |||||
| saliency = explainer(inputs, targets) | |||||
| inputs = format_tensor_to_ndarray(inputs) | |||||
| saliency = format_tensor_to_ndarray(saliency) | |||||
| inputs = inputs.squeeze(axis=0) | |||||
| saliency = saliency.squeeze() | |||||
| if len(saliency.shape) != 2: | |||||
| raise ValueError('Squeezed saliency map is expected to be 2D, but got {}D.'.format(len(saliency.shape))) | |||||
| faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=explainer.model, | |||||
| targets=targets, saliency=saliency) | |||||
| return faithfulness | |||||
| def _verify_metrics(self, metric: str): | |||||
| supports = [x.__name__ for x in self._methods] | |||||
| if metric not in supports: | |||||
| raise ValueError("Metric should be one of {}.".format(supports)) | |||||
| @@ -0,0 +1,146 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Localization metrics.""" | |||||
| import numpy as np | |||||
| from mindspore.train._utils import check_value_type | |||||
| from .metric import AttributionMetric | |||||
| from ..._operators import maximum, reshape, Tensor | |||||
| from ..._utils import format_tensor_to_ndarray | |||||
| def _get_max_position(saliency): | |||||
| """Get the position of the max pixel of the saliency map.""" | |||||
| saliency = saliency.asnumpy() | |||||
| w = saliency.shape[3] | |||||
| saliency = np.reshape(saliency, (len(saliency), -1)) | |||||
| max_arg = np.argmax(saliency, axis=1) | |||||
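| # convert the flattened argmax back to 2D coordinates (row, column) | |||||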
| return max_arg // w, max_arg - (max_arg // w) * w | |||||
| def _mask_out_saliency(saliency, threshold): | |||||
| """Keep the saliency map with value greater than threshold.""" | |||||
| max_value = maximum(saliency) | |||||
| mask_out = saliency > (reshape(max_value, (len(saliency), -1, 1, 1)) * threshold) | |||||
| return mask_out | |||||
| class Localization(AttributionMetric): | |||||
| """ | |||||
| Provides evaluation on the localization capability of XAI methods. | |||||
| We support two metrics for the evaluation of localization capability: "PointingGame" and "IoSR". | |||||
| For metric "PointingGame", the localization capability is calculated as the ratio of data in which the max position | |||||
| of their saliency maps lies within the bounding boxes. Specifically, for a single datum, given the saliency map and | |||||
| its bounding box, if the max point of its saliency map lies within the bounding box, the evaluation result is 1 | |||||
| otherwise 0. | |||||
| For metric "IoSR" (Intersection over Salient Region), the localization capability is calculated as the intersection | |||||
| of the bounding box and the salient region over the area of the salient region. | |||||
| Args: | |||||
| num_labels (int): number of classes in the dataset. | |||||
| metric (str): specific metric to calculate localization capability. | |||||
| Options: "PointingGame", "IoSR". | |||||
| Default: "PointingGame". | |||||
| Examples: | |||||
| >>> from mindspore.explainer.benchmark import Localization | |||||
| >>> num_labels = 100 | |||||
| >>> localization = Localization(num_labels, "PointingGame") | |||||
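| >>> # hedged sketch of the underlying ideas (NumPy-style pseudocode): | |||||
| >>> # PointingGame: hit (1) if the saliency argmax falls within the tolerance | |||||
| >>> # radius of the ground-truth region, otherwise miss (0). | |||||
| >>> # IoSR: (mask & (saliency > 0.5 * saliency.max())).sum() / salient_area | |||||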
| """ | |||||
| def __init__(self, | |||||
| num_labels, | |||||
| metric="PointingGame" | |||||
| ): | |||||
| super(Localization, self).__init__(num_labels) | |||||
| self._verify_metrics(metric) | |||||
| self._metric = metric | |||||
| # Argument for the specific metric: for "PointingGame" it is an integer indicating the tolerance | |||||
| # (in pixels) of "PointingGame", while for "IoSR" it is a float number | |||||
| # indicating the threshold to choose the salient region. Defaults: 15 and 0.5 respectively. | |||||
| if self._metric == "PointingGame": | |||||
| self._metric_arg = 15 | |||||
| else: | |||||
| self._metric_arg = 0.5 | |||||
| @staticmethod | |||||
| def _verify_metrics(metric): | |||||
| """Verify the user defined metric.""" | |||||
| supports = ["PointingGame", "IoSR"] | |||||
| if metric not in supports: | |||||
| raise ValueError("Metric should be one of {}".format(supports)) | |||||
| def evaluate(self, explainer, inputs, targets, saliency=None, mask=None): | |||||
| """ | |||||
| Evaluate localization on a single data sample. | |||||
| Args: | |||||
| explainer (Explainer): The explainer to be evaluated, see `mindspore/explainer/explanation`. | |||||
| inputs (Tensor): data sample. Currently only a single sample is supported at each call. | |||||
| targets (int): target label to evaluate on. | |||||
| saliency (Tensor): A saliency tensor. | |||||
| mask (Union[Tensor, np.ndarray]): ground truth bounding box/masks for the inputs w.r.t targets. | |||||
| Returns: | |||||
| np.ndarray, result of localization evaluated on explainer | |||||
| Examples: | |||||
| >>> # init an explainer, the network should contain the output activation function. | |||||
| >>> gradient = Gradient(network) | |||||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||||
| >>> masks = np.zeros((1, 1, 224, 224)) | |||||
| >>> masks[:, :, 65: 100, 65: 100] = 1 | |||||
| >>> targets = 5 | |||||
| >>> # usage 1: input the explainer and the data to be explained, | |||||
| >>> # calculate the localization result with the specified metric | |||||
| >>> res = localization.evaluate(gradient, inputs, targets, mask=masks) | |||||
| >>> # usage 2: input the generated saliency map | |||||
| >>> saliency = gradient(inputs, targets) | |||||
| >>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks) | |||||
| """ | |||||
| self._check_evaluate_param(explainer, inputs, targets, saliency) | |||||
| mask_np = format_tensor_to_ndarray(mask)[0] | |||||
| if saliency is None: | |||||
| saliency = explainer(inputs, targets) | |||||
| if self._metric == "PointingGame": | |||||
| point = _get_max_position(saliency) | |||||
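| # build a disk of radius `self._metric_arg` (the tolerance) centered at the max position | |||||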
| x, y = np.meshgrid( | |||||
| (np.arange(mask_np.shape[1]) - point[0]) ** 2, | |||||
| (np.arange(mask_np.shape[2]) - point[1]) ** 2) | |||||
| max_region = (x + y) < self._metric_arg ** 2 | |||||
| # if max_region has overlap with mask_np return 1 otherwise 0. | |||||
| result = 1 if (mask_np.astype(bool) & max_region).any() else 0 | |||||
| elif self._metric == "IoSR": | |||||
| mask_out = _mask_out_saliency(saliency, self._metric_arg) | |||||
| mask_out_np = format_tensor_to_ndarray(mask_out) | |||||
| overlap = np.sum(mask_np.astype(bool) & mask_out_np.astype(bool)) | |||||
| saliency_area = np.sum(mask_out_np) | |||||
| result = overlap / saliency_area.clip(min=1e-10) | |||||
| return np.array([result], np.float) | |||||
| def _check_evaluate_param_with_mask(self, explainer, inputs, targets, saliency, mask): | |||||
| self._check_evaluate_param(explainer, inputs, targets, saliency) | |||||
| check_value_type('mask', mask, (Tensor, np.ndarray)) | |||||
| if len(mask.shape) != 4: | |||||
| raise ValueError('Argument mask must be a 4D Tensor.') | |||||
| @@ -0,0 +1,123 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Base class for XAI metrics.""" | |||||
| import numpy as np | |||||
| from mindspore.train._utils import check_value_type | |||||
| from ..._operators import Tensor | |||||
| from ..._utils import format_tensor_to_ndarray | |||||
| from ...explanation._attribution._attribution import Attribution | |||||
| def verify_argument(inputs, arg_name): | |||||
| """Verify the validity of the parsed arguments.""" | |||||
| check_value_type(arg_name, inputs, Tensor) | |||||
| if len(inputs.shape) != 4: | |||||
| raise ValueError('Argument {} must be a 4D Tensor.'.format(arg_name)) | |||||
| if len(inputs) > 1: | |||||
| raise ValueError('Support single data evaluation only, but got {}.'.format(len(inputs))) | |||||
| def verify_targets(targets, num_labels): | |||||
| """Verify the validity of the parsed targets.""" | |||||
| check_value_type('targets', targets, (int, Tensor)) | |||||
| if isinstance(targets, Tensor): | |||||
| if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != 1): | |||||
| raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, ' | |||||
| 'its length should be 1 as we only support single sample evaluation now.') | |||||
| targets = int(targets.asnumpy()[0]) if len(targets.shape) == 1 else int(targets.asnumpy()) | |||||
| if targets > num_labels - 1 or targets < 0: | |||||
| raise ValueError('Parsed targets exceed the label range.') | |||||
| class AttributionMetric: | |||||
| """Super class of XAI metric class used in classification scenarios.""" | |||||
| def __init__(self, num_labels=None): | |||||
| self._num_labels = num_labels | |||||
| self._global_results = {i: [] for i in range(num_labels)} | |||||
| def evaluate(self, explainer, inputs, targets, saliency=None): | |||||
| """This function evaluates on a single sample and return the result.""" | |||||
| raise NotImplementedError | |||||
| def aggregate(self, result, targets): | |||||
| """Aggregates single result to global_results.""" | |||||
| if isinstance(result, float): | |||||
| if isinstance(targets, int): | |||||
| self._global_results[targets].append(result) | |||||
| else: | |||||
| target_np = format_tensor_to_ndarray(targets) | |||||
| if len(target_np) > 1: | |||||
| raise ValueError("One result cannot be aggregated to multiple targets.") | |||||
| self._global_results[int(target_np[0])].append(result) | |||||
| else: | |||||
| result_np = format_tensor_to_ndarray(result) | |||||
| if isinstance(targets, int): | |||||
| for res in result_np: | |||||
| self._global_results[targets].append(float(res)) | |||||
| else: | |||||
| target_np = format_tensor_to_ndarray(targets) | |||||
| if len(target_np) != len(result_np): | |||||
| raise ValueError("Length of result does not match with length of targets.") | |||||
| for tar, res in zip(target_np, result_np): | |||||
| self._global_results[int(tar)].append(float(res)) | |||||
| def reset(self): | |||||
| """Resets global_result.""" | |||||
| self._global_results = {i: [] for i in range(self._num_labels)} | |||||
| @property | |||||
| def class_performances(self): | |||||
| """ | |||||
| Get the per-class performances from the global results. | |||||
| Returns: | |||||
| (:class:`np.ndarray`): :attr:`num_labels`-dimensional vector | |||||
| containing per-class performance. | |||||
| """ | |||||
| count = np.array( | |||||
| [len(self._global_results[i]) for i in range(self._num_labels)]) | |||||
| result_sum = np.array( | |||||
| [sum(self._global_results[i]) for i in range(self._num_labels)]) | |||||
| return result_sum / count.clip(min=1) | |||||
| @property | |||||
| def performance(self): | |||||
| """ | |||||
| Get the overall performance from the global results. | |||||
| Returns: | |||||
| (:class:`float`): mean performance. | |||||
| """ | |||||
| count = sum( | |||||
| [len(self._global_results[i]) for i in range(self._num_labels)]) | |||||
| result_sum = sum( | |||||
| [sum(self._global_results[i]) for i in range(self._num_labels)]) | |||||
| if count == 0: | |||||
| return 0 | |||||
| return result_sum / count | |||||
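| # A minimal aggregation sketch (hypothetical numbers, for illustration only): | |||||
| #   metric.aggregate(np.array([0.8]), targets=0) | |||||
| #   metric.aggregate(np.array([0.4]), targets=1) | |||||
| #   metric.performance          # -> 0.6, the mean over all aggregated results | |||||
| #   metric.class_performances   # -> per-class means, 0 for classes with no results | |||||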
| def get_results(self): | |||||
| """Global result of the metric can be return""" | |||||
| return self._global_results | |||||
| def _check_evaluate_param(self, explainer, inputs, targets, saliency): | |||||
| """Check the evaluate parameters.""" | |||||
| check_value_type('explainer', explainer, Attribution) | |||||
| verify_argument(inputs, 'inputs') | |||||
| verify_targets(targets, self._num_labels) | |||||
| check_value_type('saliency', saliency, (Tensor, type(None))) | |||||
| @@ -0,0 +1,26 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Predefined Attribution explainers.""" | |||||
| from ._attribution._backprop.gradcam import GradCAM | |||||
| from ._attribution._backprop.gradient import Gradient | |||||
| from ._attribution._backprop.modified_relu import Deconvolution, GuidedBackprop | |||||
| __all__ = [ | |||||
| 'Gradient', | |||||
| 'Deconvolution', | |||||
| 'GuidedBackprop', | |||||
| 'GradCAM', | |||||
| ] | |||||
| @@ -0,0 +1,25 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Predefined Attribution explainers.""" | |||||
| from ._backprop.gradcam import GradCAM | |||||
| from ._backprop.gradient import Gradient | |||||
| from ._backprop.modified_relu import Deconvolution, GuidedBackprop | |||||
| __all__ = [ | |||||
| 'Gradient', | |||||
| 'Deconvolution', | |||||
| 'GuidedBackprop', | |||||
| 'GradCAM', | |||||
| ] | |||||
| @@ -0,0 +1,60 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Attribution.""" | |||||
| from typing import Callable | |||||
| import mindspore as ms | |||||
| class Attribution: | |||||
| r""" | |||||
| Base class for attributing saliency scores. | |||||
| Explainers that generate explanations by attributing relevance scores | |||||
| should inherit this class. | |||||
| Args: | |||||
| network (ms.nn.Cell): The black-box model to be explained. | |||||
| """ | |||||
| def __init__(self, network): | |||||
| self._verify_model(network) | |||||
| self._model = network | |||||
| @staticmethod | |||||
| def _verify_model(model): | |||||
| """ | |||||
| Verify the input `network` for __init__ function. | |||||
| """ | |||||
| if not isinstance(model, ms.nn.Cell): | |||||
| raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.") | |||||
| __call__: Callable | |||||
| """ | |||||
| Explainers return explanations when called directly on the inputs. | |||||
| Derived classes should override this implementation for different | |||||
| algorithms. | |||||
| Args: | |||||
| input (ms.Tensor): Input tensor to be explained. | |||||
| Returns: | |||||
| - saliency map (ms.Tensor): saliency map of the input. | |||||
| """ | |||||
| @property | |||||
| def model(self): | |||||
| return self._model | |||||
| @@ -0,0 +1,24 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Backprop-base _attribution explainer.""" | |||||
| from .gradient import Gradient | |||||
| from .gradcam import GradCAM | |||||
| from .modified_relu import Deconvolution, GuidedBackprop | |||||
| __all__ = ['Gradient', | |||||
| 'GradCAM', | |||||
| 'Deconvolution', | |||||
| 'GuidedBackprop'] | |||||
| @@ -0,0 +1,49 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Providing utility functions.""" | |||||
| from mindspore.ops.composite import GradOperation | |||||
| from ...._utils import unify_inputs, unify_targets, generate_one_hot | |||||
| def compute_gradients(model, inputs, targets=None, weights=None): | |||||
| r""" | |||||
| Compute the gradient of the output w.r.t. the input. | |||||
| Args: | |||||
| model (`ms.nn.Cell`): Differentiable black-box model. | |||||
| inputs (`ms.Tensor`): Input to calculate gradient and explanation. | |||||
| targets (int, optional): Target label id specifying which category to compute the gradient for. Default: None. | |||||
| weights (`ms.Tensor`, optional): Custom weights for computing gradients. The shape of weights should match the | |||||
| model outputs. If None is provided, a one-hot weight with ones at the target positions will be used | |||||
| instead. Default: None. | |||||
| Returns: | |||||
| saliency map (ms.Tensor): Gradient back-propagated to the input. | |||||
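| Examples: | |||||
| >>> # a minimal sketch (assumes `net` is an `ms.nn.Cell` classifier and | |||||
| >>> # `ms`/`np` are imported as usual): | |||||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||||
| >>> gradients = compute_gradients(net, inputs, targets=5) | |||||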
| """ | |||||
| inputs = unify_inputs(inputs) | |||||
| if targets is None and weights is None: | |||||
| raise ValueError('Must provide one of targets or weights') | |||||
| if weights is None: | |||||
| targets = unify_targets(targets) | |||||
| output = model(*inputs).asnumpy() | |||||
| num_categories = output.shape[-1] | |||||
| weights = generate_one_hot(targets, num_categories) | |||||
| grad_op = GradOperation( | |||||
| get_all=True, get_by_list=False, sens_param=True)(model) | |||||
| gradients = grad_op(*inputs, weights) | |||||
| return gradients[0] | |||||
| @@ -0,0 +1,141 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ GradCAM and GuidedGradCAM. """ | |||||
| from mindspore.ops import operations as op | |||||
| from .backprop_utils import compute_gradients | |||||
| from .intermediate_layer import IntermediateLayerAttribution | |||||
| from ...._utils import ForwardProbe, retrieve_layer, unify_inputs, unify_targets | |||||
| def _gradcam_aggregation(attributions): | |||||
| """ | |||||
| Aggregate the gradient and activation to get the final _attribution. | |||||
| Args: | |||||
| attributions (Tensor): the _attribution with channel dimension. | |||||
| Returns: | |||||
| Tensor: the _attribution with channel dimension aggregated. | |||||
| """ | |||||
| sum_ = op.ReduceSum(keep_dims=True) | |||||
| relu_ = op.ReLU() | |||||
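| # sum over the channel dimension, then apply ReLU, following the Grad-CAM formulation | |||||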
| attributions = relu_(sum_(attributions, 1)) | |||||
| return attributions | |||||
| class GradCAM(IntermediateLayerAttribution): | |||||
| r""" | |||||
| Provides GradCAM explanation method. | |||||
| GradCAM generates saliency maps at an intermediate layer. | |||||
| .. math:: | |||||
| \alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial{y^c}}{\partial{A_{i,j}^k}} | |||||
| L_{GradCAM} = ReLU(\sum_k \alpha_k^c A^k) | |||||
| For more details, please refer to the original paper: GradCAM | |||||
| [https://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf] | |||||
| Args: | |||||
| network (Cell): The black-box model to be explained. | |||||
| layer (str): The layer name to generate the explanation at. Default: ''. | |||||
| If default, the explanation will be generated at the input layer. | |||||
| Examples: | |||||
| >>> net = resnet50(10) | |||||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||||
| >>> load_param_into_net(net, param_dict) | |||||
| >>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(), | |||||
| >>> # you may also use the net itself. | |||||
| >>> net = nn.SequentialCell([net, nn.Sigmoid()]) | |||||
| >>> # specify a layer name to generate explanation, usually the layer can be set as the last conv layer. | |||||
| >>> layer_name = '0.layer4' | |||||
| >>> # init GradCAM with a trained network and specify the layer to obtain | |||||
| >>> gradcam = GradCAM(net, layer=layer_name) | |||||
| >>> # parse data and the target label to be explained and get the saliency map | |||||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||||
| >>> label = 5 | |||||
| >>> saliency = gradcam(inputs, label) | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| network, | |||||
| layer=""): | |||||
| super(GradCAM, self).__init__(network, layer) | |||||
| self._saliency_cell = retrieve_layer(self._backward_model, target_layer=layer) | |||||
| self._avgpool = op.ReduceMean(keep_dims=True) | |||||
| self._intermediate_grad = None | |||||
| self._aggregation_fn = _gradcam_aggregation | |||||
| self._resize_mode = 'bilinear' | |||||
| def _hook_cell(self): | |||||
| if self._saliency_cell: | |||||
| self._saliency_cell.register_backward_hook(self._cell_hook_fn) | |||||
| self._saliency_cell.enable_hook = True | |||||
| self._intermediate_grad = None | |||||
| def _cell_hook_fn(self, _, grad_input, grad_output): | |||||
| """ | |||||
| Hook function to deal with the backward gradient. | |||||
| The arguments are set as required by Cell.register_backward_hook. | |||||
| """ | |||||
| self._intermediate_grad = grad_input | |||||
| def __call__(self, inputs, targets): | |||||
| """ | |||||
| Call function for `GradCAM`. | |||||
| Args: | |||||
| inputs (Tensor): The input data to be explained, 4D Tensor. | |||||
| targets (Union[Tensor, int]): The label of interest. It should be a 1D or 0D Tensor, or an integer. | |||||
| If `targets` is a 1D Tensor, its length should be the same as `inputs`. | |||||
| """ | |||||
| self._verify_data(inputs, targets) | |||||
| self._hook_cell() | |||||
| with ForwardProbe(self._saliency_cell) as probe: | |||||
| inputs = unify_inputs(inputs) | |||||
| targets = unify_targets(targets) | |||||
| gradients = compute_gradients(self._backward_model, *inputs, targets) | |||||
| # get intermediate activation | |||||
| activation = (probe.value,) | |||||
| if self._layer == "": | |||||
| activation = inputs | |||||
| self._intermediate_grad = unify_inputs(gradients) | |||||
| if self._intermediate_grad is not None: | |||||
| # average pooling on gradients | |||||
| intermediate_grad = unify_inputs( | |||||
| self._avgpool(self._intermediate_grad[0], (2, 3))) | |||||
| else: | |||||
| raise ValueError("Gradient for intermediate layer is not " | |||||
| "obtained") | |||||
| mul = op.Mul() | |||||
| attribution = self._aggregation_fn( | |||||
| mul(*intermediate_grad, *activation)) | |||||
| if self._resize: | |||||
| attribution = self._resize_fn(attribution, *inputs, | |||||
| mode=self._resize_mode) | |||||
| self._intermediate_grad = None | |||||
| return attribution | |||||
| @@ -0,0 +1,129 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Gradient explainer.""" | |||||
| from copy import deepcopy | |||||
| from mindspore import nn | |||||
| from mindspore.ops import operations as op | |||||
| from mindspore.train._utils import check_value_type | |||||
| from ...._operators import reshape, sqrt, Tensor | |||||
| from .._attribution import Attribution | |||||
| from .backprop_utils import compute_gradients | |||||
| from ...._utils import unify_inputs, unify_targets | |||||
| def _get_hook(bntype, cache): | |||||
| """Provide backward hook function for BatchNorm layer in eval mode.""" | |||||
| var, gamma, eps = cache | |||||
| if bntype == "2d": | |||||
| var = reshape(var, (1, -1, 1, 1)) | |||||
| gamma = reshape(gamma, (1, -1, 1, 1)) | |||||
| elif bntype == "1d": | |||||
| var = reshape(var, (1, -1, 1)) | |||||
| gamma = reshape(gamma, (1, -1, 1)) | |||||
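| # rescale the upstream gradient by the frozen BatchNorm's affine factor gamma / sqrt(var + eps) | |||||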
| def reset_gradient(_, grad_input, grad_output): | |||||
| grad_output = grad_input[0] * gamma / sqrt(var + eps) | |||||
| return grad_output | |||||
| return reset_gradient | |||||
| def _abs_max(gradients): | |||||
| """ | |||||
| Transform gradients to saliency by taking the absolute value, then the | |||||
| maximum along the channel axis. | |||||
| """ | |||||
| gradients = op.Abs()(gradients) | |||||
| saliency = op.ReduceMax(keep_dims=True)(gradients, axis=1) | |||||
| return saliency | |||||
| class Gradient(Attribution): | |||||
| r""" | |||||
| Provides Gradient explanation method. | |||||
| Gradient is the simplest attribution method which uses the naive gradients of outputs w.r.t inputs as the | |||||
| explanation. | |||||
| .. math:: | |||||
| attribution = \frac{\partial{y}}{\partial{x}} | |||||
| Args: | |||||
| network (Cell): The black-box model to be explained. | |||||
| Examples: | |||||
| >>> net = resnet50(10) | |||||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||||
| >>> load_param_into_net(net, param_dict) | |||||
| >>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(), | |||||
| >>> # you may also use the net itself. The saliency map might be slightly different for softmax activation. | |||||
| >>> net = nn.SequentialCell([net, nn.Sigmoid()]) | |||||
| >>> # init Gradient with a trained network. | |||||
| >>> gradient = Gradient(net) | |||||
| >>> # parse data and the target label to be explained and get the saliency map | |||||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||||
| >>> label = 5 | |||||
| >>> saliency = gradient(inputs, label) | |||||
| """ | |||||
| def __init__(self, network): | |||||
| super(Gradient, self).__init__(network) | |||||
| self._backward_model = deepcopy(network) | |||||
| self._backward_model.set_train(False) | |||||
| self._backward_model.set_grad(False) | |||||
| self._hook_bn() | |||||
| self._grad_op = compute_gradients | |||||
| self._aggregation_fn = _abs_max | |||||
| def __call__(self, inputs, targets): | |||||
| """ | |||||
| Call function for `Gradient`. | |||||
| Args: | |||||
| inputs (Tensor): The input data to be explained, 4D Tensor. | |||||
| targets (Union[Tensor, int]): The label of interest. It should be a 1D or 0D Tensor, or an integer. | |||||
| If `targets` is a 1D `Tensor`, its length should be the same as `inputs`. | |||||
| """ | |||||
| self._verify_data(inputs, targets) | |||||
| inputs = unify_inputs(inputs) | |||||
| targets = unify_targets(targets) | |||||
| gradient = self._grad_op(self._backward_model, *inputs, targets) | |||||
| saliency = self._aggregation_fn(gradient) | |||||
| return saliency | |||||
| def _hook_bn(self): | |||||
| """Hook BatchNorm layer for `self._backward_model.`""" | |||||
| for _, cell in self._backward_model.cells_and_names(): | |||||
| if isinstance(cell, nn.BatchNorm2d): | |||||
| cache = (cell.moving_variance, cell.gamma, cell.eps) | |||||
| cell.register_backward_hook(_get_hook("2d", cache=cache)) | |||||
| elif isinstance(cell, nn.BatchNorm1d): | |||||
| cache = (cell.moving_variance, cell.gamma, cell.eps) | |||||
| cell.register_backward_hook(_get_hook("1d", cache=cache)) | |||||
| @staticmethod | |||||
| def _verify_data(inputs, targets): | |||||
| """Verify the validity of the parsed inputs.""" | |||||
| check_value_type('inputs', inputs, Tensor) | |||||
| if len(inputs.shape) != 4: | |||||
| raise ValueError('Argument inputs must be 4D Tensor') | |||||
| check_value_type('targets', targets, (Tensor, int)) | |||||
| if isinstance(targets, Tensor): | |||||
| if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != len(inputs)): | |||||
| raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, ' | |||||
| 'it should have the same length as inputs.') | |||||
| @@ -0,0 +1,47 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Base class IntermediateLayerAttribution""" | |||||
| from .gradient import Gradient | |||||
| from ...._utils import resize as resize_fn | |||||
| class IntermediateLayerAttribution(Gradient): | |||||
| """ | |||||
| Base class for generating _attribution map at intermediate layer. | |||||
| Args: | |||||
| network (nn.Cell): DNN model to be explained. | |||||
| layer (str, optional): string that specifies the layer to generate | |||||
| intermediate _attribution. When using default value, the input layer | |||||
| will be specified. Default: ''. | |||||
| """ | |||||
| def __init__(self, network, layer=''): | |||||
| super(IntermediateLayerAttribution, self).__init__(network) | |||||
| # Whether to resize the _attribution map to the input size. | |||||
| self._resize = True | |||||
| # string that specifies the resize mode. Default: 'nearest_neighbor'. | |||||
| self._resize_mode = 'nearest_neighbor' | |||||
| self._layer = layer | |||||
| @staticmethod | |||||
| def _resize_fn(attributions, inputs, mode): | |||||
| """Resize the intermediate layer _attribution to the same size as inputs.""" | |||||
| height, width = inputs.shape[2], inputs.shape[3] | |||||
| return resize_fn(attributions, (height, width), mode) | |||||
| @@ -0,0 +1,117 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Explainer with modified ReLU.""" | |||||
| import mindspore.nn as nn | |||||
| import mindspore.ops.operations as op | |||||
| from .gradient import Gradient | |||||
| from ...._utils import ( | |||||
| unify_inputs, | |||||
| unify_targets, | |||||
| ) | |||||
| class ModifiedReLU(Gradient): | |||||
| """Basic class for modified ReLU explanation.""" | |||||
| def __init__(self, network, use_relu_backprop=False): | |||||
| super(ModifiedReLU, self).__init__(network) | |||||
| self.use_relu_backprop = use_relu_backprop | |||||
| self.hooked_list = [] | |||||
| def __call__(self, inputs, targets): | |||||
| self._verify_data(inputs, targets) | |||||
| inputs = unify_inputs(inputs) | |||||
| targets = unify_targets(targets) | |||||
| self._hook_relu_backward() | |||||
| gradients = self._grad_op(self._backward_model, inputs, targets) | |||||
| saliency = self._aggregation_fn(gradients) | |||||
| return saliency | |||||
| def _hook_relu_backward(self): | |||||
| """Set backward hook for ReLU layers.""" | |||||
| for _, cell in self._backward_model.cells_and_names(): | |||||
| if isinstance(cell, nn.ReLU): | |||||
| cell.register_backward_hook(self._backward_hook) | |||||
| self.hooked_list.append(cell) | |||||
| def _backward_hook(self, _, grad_inputs, grad_outputs): | |||||
| """Hook function for ReLU layers.""" | |||||
| inputs = grad_inputs if self.use_relu_backprop else grad_outputs | |||||
| relu = op.ReLU() | |||||
| if isinstance(inputs, tuple): | |||||
| return relu(*inputs) | |||||
| return relu(inputs) | |||||
| class Deconvolution(ModifiedReLU): | |||||
| """ | |||||
| Deconvolution explanation. | |||||
| To use `Deconvolution`, the `ReLU` operations in the network must be implemented with the `mindspore.nn.ReLU` | |||||
| cell rather than `mindspore.ops.operations.ReLU`. Otherwise, the results will not be correct. | |||||
| Args: | |||||
| network (Cell): The black-box model to be explained. | |||||
| Examples: | |||||
| >>> net = resnet50(10) | |||||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||||
| >>> load_param_into_net(net, param_dict) | |||||
| >>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(), | |||||
| >>> # you may also use the net itself. The saliency map might be slightly different for softmax activation. | |||||
| >>> net = nn.SequentialCell([net, nn.Sigmoid()]) | |||||
| >>> # init Deconvolution with a trained network. | |||||
| >>> deconvolution = Deconvolution(net) | |||||
| >>> # parse data and the target label to be explained and get the saliency map | |||||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||||
| >>> label = 5 | |||||
| >>> saliency = deconvolution(inputs, label) | |||||
| """ | |||||
| def __init__(self, network): | |||||
| super(Deconvolution, self).__init__(network, use_relu_backprop=True) | |||||
| class GuidedBackprop(ModifiedReLU): | |||||
| """ | |||||
| Guided-Backpropagation explanation. | |||||
| To use `GuidedBackprop`, the `ReLU` operations in the network must be implemented with the `mindspore.nn.ReLU` | |||||
| cell rather than `mindspore.ops.operations.ReLU`. Otherwise, the results will not be correct. | |||||
| Args: | |||||
| network (Cell): The black-box model to be explained. | |||||
| Examples: | |||||
| >>> net = resnet50(10) | |||||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||||
| >>> load_param_into_net(net, param_dict) | |||||
| >>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(), | |||||
| >>> # you may also use the net itself. The saliency map might be slightly different for softmax activation. | |||||
| >>> net = nn.SequentialCell([net, nn.Sigmoid()]) | |||||
| >>> # init GuidedBackprop with a trained network. | |||||
| >>> gbp = GuidedBackprop(net) | |||||
| >>> # parse data and the target label to be explained and get the saliency map | |||||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | |||||
| >>> label = 5 | |||||
| >>> saliency = gbp(inputs, label) | |||||
| """ | |||||
| def __init__(self, network): | |||||
| super(GuidedBackprop, self).__init__(network, use_relu_backprop=False) | |||||