From: @lixiaohui33 Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -18,9 +18,8 @@ from typing import List, Tuple, Union, Callable | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore | import mindspore | ||||
| from mindspore import nn | |||||
| import mindspore.ops.operations as op | import mindspore.ops.operations as op | ||||
| from mindspore import nn | |||||
| _Axis = Union[int, Tuple[int, ...], List[int]] | _Axis = Union[int, Tuple[int, ...], List[int]] | ||||
| _Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]] | _Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]] | ||||
| @@ -235,7 +234,7 @@ def randint(low: int, high: int, shape: _Shape, dtype: mindspore.dtype = mindspo | |||||
| return outputs | return outputs | ||||
| def softmax(axis: int) -> Callable: | |||||
| def softmax(axis: int = -1) -> Callable: | |||||
| """Softmax activation function.""" | """Softmax activation function.""" | ||||
| func = nn.Softmax(axis=axis) | func = nn.Softmax(axis=axis) | ||||
| return func | return func | ||||
| @@ -20,20 +20,23 @@ from time import time | |||||
| from typing import Tuple, List, Optional | from typing import Tuple, List, Optional | ||||
| import numpy as np | import numpy as np | ||||
| from scipy.stats import beta | |||||
| from PIL import Image | from PIL import Image | ||||
| from scipy.stats import beta | |||||
| from mindspore.train._utils import check_value_type | |||||
| from mindspore.train.summary_pb2 import Explain | |||||
| import mindspore as ms | import mindspore as ms | ||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| from mindspore import log | from mindspore import log | ||||
| from mindspore.nn import Softmax, Cell | |||||
| from mindspore.nn.probability.toolbox import UncertaintyEvaluation | |||||
| from mindspore.ops.operations import ExpandDims | from mindspore.ops.operations import ExpandDims | ||||
| from mindspore.train._utils import check_value_type | |||||
| from mindspore.train.summary._summary_adapter import _convert_image_format | from mindspore.train.summary._summary_adapter import _convert_image_format | ||||
| from mindspore.train.summary.summary_record import SummaryRecord | from mindspore.train.summary.summary_record import SummaryRecord | ||||
| from mindspore.nn.probability.toolbox import UncertaintyEvaluation | |||||
| from mindspore.train.summary_pb2 import Explain | |||||
| from .benchmark import Localization | from .benchmark import Localization | ||||
| from .benchmark._attribution.metric import AttributionMetric | from .benchmark._attribution.metric import AttributionMetric | ||||
| from .explanation import RISE | |||||
| from .explanation._attribution._attribution import Attribution | from .explanation._attribution._attribution import Attribution | ||||
| # datafile directory names | # datafile directory names | ||||
| @@ -43,8 +46,8 @@ _HEATMAP_DIRNAME = "heatmap" | |||||
| # max. no. of sample per directory | # max. no. of sample per directory | ||||
| _SAMPLE_PER_DIR = 1000 | _SAMPLE_PER_DIR = 1000 | ||||
| _EXPAND_DIMS = ExpandDims() | _EXPAND_DIMS = ExpandDims() | ||||
| _SEED = 58 # set a seed to fix the iterating order of the dataset | |||||
| def _normalize(img_np): | def _normalize(img_np): | ||||
| @@ -57,7 +60,7 @@ def _normalize(img_np): | |||||
| def _np_to_image(img_np, mode): | def _np_to_image(img_np, mode): | ||||
| """Convert numpy array to PIL image.""" | """Convert numpy array to PIL image.""" | ||||
| return Image.fromarray(np.uint8(img_np*255), mode=mode) | |||||
| return Image.fromarray(np.uint8(img_np * 255), mode=mode) | |||||
| def _calc_prob_interval(volume, probs, prob_vars): | def _calc_prob_interval(volume, probs, prob_vars): | ||||
| @@ -89,7 +92,7 @@ def _calc_prob_interval(volume, probs, prob_vars): | |||||
| def _get_id_dirname(sample_id: int): | def _get_id_dirname(sample_id: int): | ||||
| """Get the name of parent directory of the image id.""" | """Get the name of parent directory of the image id.""" | ||||
| return str(int(sample_id/_SAMPLE_PER_DIR)*_SAMPLE_PER_DIR) | |||||
| return str(int(sample_id / _SAMPLE_PER_DIR) * _SAMPLE_PER_DIR) | |||||
| def _extract_timestamp(filename: str): | def _extract_timestamp(filename: str): | ||||
| @@ -107,6 +110,9 @@ class ExplainRunner: | |||||
| After generating results with the explanation methods and the evaluation methods, the results will be written into | After generating results with the explanation methods and the evaluation methods, the results will be written into | ||||
| a specified file with `mindspore.summary.SummaryRecord`. The stored content can be viewed using MindInsight. | a specified file with `mindspore.summary.SummaryRecord`. The stored content can be viewed using MindInsight. | ||||
| Update in 2020.11: Adjust the storage structure and format of the data. Summary files generated by previous version | |||||
| will be deprecated and will not be supported in MindInsight of current version. | |||||
| Args: | Args: | ||||
| summary_dir (str, optional): The directory path to save the summary files which store the generated results. | summary_dir (str, optional): The directory path to save the summary files which store the generated results. | ||||
| Default: "./" | Default: "./" | ||||
| @@ -131,7 +137,8 @@ class ExplainRunner: | |||||
| dataset: Tuple, | dataset: Tuple, | ||||
| explainers: List, | explainers: List, | ||||
| benchmarkers: Optional[List] = None, | benchmarkers: Optional[List] = None, | ||||
| uncertainty: Optional[UncertaintyEvaluation] = None): | |||||
| uncertainty: Optional[UncertaintyEvaluation] = None, | |||||
| activation_fn: Optional[Cell] = Softmax()): | |||||
| """ | """ | ||||
| Genereates results and writes results into the summary files in `summary_dir` specified during the object | Genereates results and writes results into the summary files in `summary_dir` specified during the object | ||||
| initialization. | initialization. | ||||
| @@ -149,8 +156,12 @@ class ExplainRunner: | |||||
| Default: None | Default: None | ||||
| uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference | uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference | ||||
| uncertainty of samples. | uncertainty of samples. | ||||
| activation_fn (Cell, optional): The activation layer that transforms the output of the network to | |||||
| label probability distribution :math:`P(y|x)`. Default: Softmax(). | |||||
| Examples: | Examples: | ||||
| >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient | >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient | ||||
| >>> from mindspore.nn import Sigmoid | |||||
| >>> # obtain dataset object | >>> # obtain dataset object | ||||
| >>> dataset = get_dataset() | >>> dataset = get_dataset() | ||||
| >>> classes = ["cat", "dog", ...] | >>> classes = ["cat", "dog", ...] | ||||
| @@ -158,13 +169,11 @@ class ExplainRunner: | |||||
| >>> param_dict = load_checkpoint("checkpoint.ckpt") | >>> param_dict = load_checkpoint("checkpoint.ckpt") | ||||
| >>> net = resnet50(len(classes)) | >>> net = resnet50(len(classes)) | ||||
| >>> load_parama_into_net(net, param_dict) | >>> load_parama_into_net(net, param_dict) | ||||
| >>> # bind net with its output activation | |||||
| >>> model = nn.SequentialCell([net, nn.Sigmoid()]) | |||||
| >>> gbp = GuidedBackprop(model) | |||||
| >>> gradient = Gradient(model) | |||||
| >>> gbp = GuidedBackprop(net) | |||||
| >>> gradient = Gradient(net) | |||||
| >>> runner = ExplainRunner("./") | >>> runner = ExplainRunner("./") | ||||
| >>> explainers = [gbp, gradient] | >>> explainers = [gbp, gradient] | ||||
| >>> runner.run((dataset, classes), explainers) | |||||
| >>> runner.run((dataset, classes), explainers, activation_fn=Sigmoid()) | |||||
| """ | """ | ||||
| check_value_type("dataset", dataset, tuple) | check_value_type("dataset", dataset, tuple) | ||||
| @@ -181,16 +190,17 @@ class ExplainRunner: | |||||
| for exp in explainers: | for exp in explainers: | ||||
| if not isinstance(exp, Attribution): | if not isinstance(exp, Attribution): | ||||
| raise TypeError("Argument explainers should be a list of objects of classes in " | |||||
| raise TypeError("Argument `explainers` should be a list of objects of classes in " | |||||
| "`mindspore.explainer.explanation`.") | "`mindspore.explainer.explanation`.") | ||||
| if benchmarkers is not None: | if benchmarkers is not None: | ||||
| check_value_type("benchmarkers", benchmarkers, list) | check_value_type("benchmarkers", benchmarkers, list) | ||||
| for bench in benchmarkers: | for bench in benchmarkers: | ||||
| if not isinstance(bench, AttributionMetric): | if not isinstance(bench, AttributionMetric): | ||||
| raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation" | |||||
| raise TypeError("Argument `benchmarkers` should be a list of objects of classes in explanation" | |||||
| "`mindspore.explainer.benchmark`.") | "`mindspore.explainer.benchmark`.") | ||||
| check_value_type("activation_fn", activation_fn, Cell) | |||||
| self._model = explainers[0].model | |||||
| self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn]) | |||||
| next_element = dataset.create_tuple_iterator().get_next() | next_element = dataset.create_tuple_iterator().get_next() | ||||
| inputs, _, _ = self._unpack_next_element(next_element) | inputs, _, _ = self._unpack_next_element(next_element) | ||||
| prop_test = self._model(inputs) | prop_test = self._model(inputs) | ||||
| @@ -211,9 +221,10 @@ class ExplainRunner: | |||||
| self._uncertainty = None | self._uncertainty = None | ||||
| with SummaryRecord(self._summary_dir) as summary: | with SummaryRecord(self._summary_dir) as summary: | ||||
| spacer = '{:120}\r' | |||||
| print("Start running and writing......") | print("Start running and writing......") | ||||
| begin = time() | begin = time() | ||||
| print("Start writing metadata.") | |||||
| print("Start writing metadata......") | |||||
| self._summary_timestamp = _extract_timestamp(summary.event_file_name) | self._summary_timestamp = _extract_timestamp(summary.event_file_name) | ||||
| if self._summary_timestamp is None: | if self._summary_timestamp is None: | ||||
| @@ -234,42 +245,47 @@ class ExplainRunner: | |||||
| print("Finish writing metadata.") | print("Finish writing metadata.") | ||||
| now = time() | now = time() | ||||
| print("Start running and writing inference data......") | |||||
| print("Start running and writing inference data.....") | |||||
| imageid_labels = self._run_inference(dataset, summary) | imageid_labels = self._run_inference(dataset, summary) | ||||
| print("Finish running and writing inference data. Time elapsed: {}s".format(time() - now)) | |||||
| print(spacer.format("Finish running and writing inference data. " | |||||
| "Time elapsed: {:.3f} s".format(time() - now))) | |||||
| if benchmarkers is None or not benchmarkers: | if benchmarkers is None or not benchmarkers: | ||||
| for exp in explainers: | for exp in explainers: | ||||
| start = time() | start = time() | ||||
| print("Start running and writing explanation data for {}......".format(exp.__class__.__name__)) | print("Start running and writing explanation data for {}......".format(exp.__class__.__name__)) | ||||
| self._count = 0 | self._count = 0 | ||||
| ds.config.set_seed(58) | |||||
| ds.config.set_seed(_SEED) | |||||
| for idx, next_element in enumerate(dataset): | for idx, next_element in enumerate(dataset): | ||||
| now = time() | now = time() | ||||
| self._run_exp_step(next_element, exp, imageid_labels, summary) | self._run_exp_step(next_element, exp, imageid_labels, summary) | ||||
| print("Finish writing {}-th explanation data. Time elapsed: {}".format( | |||||
| idx, time() - now)) | |||||
| print("Finish running and writing explanation data for {}. Time elapsed: {}".format( | |||||
| exp.__class__.__name__, time() - start)) | |||||
| print(spacer.format("Finish writing {}-th explanation data for {}. Time elapsed: " | |||||
| "{:.3f} s".format(idx, time() - now, exp.__class__.__name__)), end='') | |||||
| print(spacer.format( | |||||
| "Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format( | |||||
| exp.__class__.__name__, time() - start))) | |||||
| else: | else: | ||||
| for exp in explainers: | for exp in explainers: | ||||
| explain = Explain() | explain = Explain() | ||||
| for bench in benchmarkers: | for bench in benchmarkers: | ||||
| bench.reset() | bench.reset() | ||||
| print(f"Start running and writing explanation and benchmark data for {exp.__class__.__name__}.") | |||||
| print(f"Start running and writing explanation and " | |||||
| f"benchmark data for {exp.__class__.__name__}......") | |||||
| self._count = 0 | self._count = 0 | ||||
| start = time() | start = time() | ||||
| ds.config.set_seed(58) | |||||
| ds.config.set_seed(_SEED) | |||||
| for idx, next_element in enumerate(dataset): | for idx, next_element in enumerate(dataset): | ||||
| now = time() | now = time() | ||||
| saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary) | saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary) | ||||
| print("Finish writing {}-th batch explanation data. Time elapsed: {}s".format( | |||||
| idx, time() - now)) | |||||
| print(spacer.format( | |||||
| "Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format( | |||||
| idx, exp.__class__.__name__, time() - now)), end='') | |||||
| for bench in benchmarkers: | for bench in benchmarkers: | ||||
| now = time() | now = time() | ||||
| self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst) | self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst) | ||||
| print("Finish running {}-th batch benchmark data for {}. Time elapsed: {}s".format( | |||||
| idx, bench.__class__.__name__, time() - now)) | |||||
| print(spacer.format( | |||||
| "Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format( | |||||
| idx, bench.__class__.__name__, exp.__class__.__name__, time() - now)), end='') | |||||
| for bench in benchmarkers: | for bench in benchmarkers: | ||||
| benchmark = explain.benchmark.add() | benchmark = explain.benchmark.add() | ||||
| @@ -279,11 +295,11 @@ class ExplainRunner: | |||||
| benchmark.total_score = bench.performance | benchmark.total_score = bench.performance | ||||
| benchmark.label_score.extend(bench.class_performances) | benchmark.label_score.extend(bench.class_performances) | ||||
| print("Finish running and writing explanation and benchmark data for {}. " | |||||
| "Time elapsed: {}s".format(exp.__class__.__name__, time() - start)) | |||||
| print(spacer.format("Finish running and writing explanation and benchmark data for {}. " | |||||
| "Time elapsed: {:.3f} s".format(exp.__class__.__name__, time() - start))) | |||||
| summary.add_value('explainer', 'benchmark', explain) | summary.add_value('explainer', 'benchmark', explain) | ||||
| summary.record(1) | summary.record(1) | ||||
| print("Finish running and writing. Total time elapsed: {}s".format(time() - begin)) | |||||
| print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin)) | |||||
| @staticmethod | @staticmethod | ||||
| def _verify_data_form(dataset, benchmarkers): | def _verify_data_form(dataset, benchmarkers): | ||||
| @@ -446,8 +462,9 @@ class ExplainRunner: | |||||
| Returns: | Returns: | ||||
| imageid_labels (dict): a dict that maps image_id and the union of its ground truth and predicted labels. | imageid_labels (dict): a dict that maps image_id and the union of its ground truth and predicted labels. | ||||
| """ | """ | ||||
| spacer = '{:120}\r' | |||||
| imageid_labels = {} | imageid_labels = {} | ||||
| ds.config.set_seed(58) | |||||
| ds.config.set_seed(_SEED) | |||||
| self._count = 0 | self._count = 0 | ||||
| for j, next_element in enumerate(dataset): | for j, next_element in enumerate(dataset): | ||||
| now = time() | now = time() | ||||
| @@ -516,7 +533,9 @@ class ExplainRunner: | |||||
| summary.record(1) | summary.record(1) | ||||
| self._count += 1 | self._count += 1 | ||||
| print("Finish running and writing {}-th batch inference data. Time elapsed: {}s".format(j, time() - now)) | |||||
| print(spacer.format("Finish running and writing {}-th batch inference data." | |||||
| " Time elapsed: {:.3f} s".format(j, time() - now)), | |||||
| end='') | |||||
| return imageid_labels | return imageid_labels | ||||
| def _run_exp_step(self, next_element, explainer, imageid_labels, summary): | def _run_exp_step(self, next_element, explainer, imageid_labels, summary): | ||||
| @@ -543,18 +562,22 @@ class ExplainRunner: | |||||
| batch_unions = self._make_label_batch(unions) | batch_unions = self._make_label_batch(unions) | ||||
| saliency_dict_lst = [] | saliency_dict_lst = [] | ||||
| batch_saliency_full = [] | |||||
| for i in range(len(batch_unions[0])): | |||||
| batch_saliency = explainer(inputs, batch_unions[:, i]) | |||||
| batch_saliency_full.append(batch_saliency) | |||||
| if isinstance(explainer, RISE): | |||||
| batch_saliency_full = explainer(inputs, batch_unions) | |||||
| else: | |||||
| batch_saliency_full = [] | |||||
| for i in range(len(batch_unions[0])): | |||||
| batch_saliency = explainer(inputs, batch_unions[:, i]) | |||||
| batch_saliency_full.append(batch_saliency) | |||||
| concat = ms.ops.operations.Concat(1) | |||||
| batch_saliency_full = concat(tuple(batch_saliency_full)) | |||||
| for idx, union in enumerate(unions): | for idx, union in enumerate(unions): | ||||
| saliency_dict = {} | saliency_dict = {} | ||||
| explain = Explain() | explain = Explain() | ||||
| explain.sample_id = self._count | explain.sample_id = self._count | ||||
| for k, lab in enumerate(union): | for k, lab in enumerate(union): | ||||
| saliency = batch_saliency_full[k][idx:idx + 1] | |||||
| saliency = batch_saliency_full[idx:idx + 1, k:k + 1] | |||||
| saliency_dict[lab] = saliency | saliency_dict[lab] = saliency | ||||
| saliency_np = _normalize(saliency.asnumpy().squeeze()) | saliency_np = _normalize(saliency.asnumpy().squeeze()) | ||||
| @@ -600,7 +623,7 @@ class ExplainRunner: | |||||
| def _save_original_image(self, sample_id: int, image): | def _save_original_image(self, sample_id: int, image): | ||||
| """Save an image to summary directory.""" | """Save an image to summary directory.""" | ||||
| id_dirname = _get_id_dirname(sample_id) | id_dirname = _get_id_dirname(sample_id) | ||||
| relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp), | |||||
| relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp), | |||||
| _ORIGINAL_IMAGE_DIRNAME, | _ORIGINAL_IMAGE_DIRNAME, | ||||
| id_dirname) | id_dirname) | ||||
| os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True) | os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True) | ||||
| @@ -613,7 +636,7 @@ class ExplainRunner: | |||||
| def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image): | def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image): | ||||
| """Save heatmap image to summary directory.""" | """Save heatmap image to summary directory.""" | ||||
| id_dirname = _get_id_dirname(sample_id) | id_dirname = _get_id_dirname(sample_id) | ||||
| relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp), | |||||
| relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp), | |||||
| _HEATMAP_DIRNAME, | _HEATMAP_DIRNAME, | ||||
| explain_method, | explain_method, | ||||
| id_dirname) | id_dirname) | ||||
| @@ -21,6 +21,7 @@ from scipy.ndimage.filters import gaussian_filter | |||||
| from mindspore import log | from mindspore import log | ||||
| import mindspore as ms | import mindspore as ms | ||||
| from mindspore.train._utils import check_value_type | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| import mindspore.ops.operations as op | import mindspore.ops.operations as op | ||||
| from .metric import AttributionMetric | from .metric import AttributionMetric | ||||
| @@ -140,9 +141,7 @@ class Perturb: | |||||
| @staticmethod | @staticmethod | ||||
| def _assign(x: _Array, y: _Array, masks: _Array): | def _assign(x: _Array, y: _Array, masks: _Array): | ||||
| """Assign values to perturb pixels on perturbations.""" | """Assign values to perturb pixels on perturbations.""" | ||||
| if masks.dtype != bool: | |||||
| raise TypeError('The param "masks" should be an array of bool, but receive {}' | |||||
| .format(masks.dtype)) | |||||
| check_value_type("masks dtype", masks.dtype, type(np.dtype(bool))) | |||||
| for i in range(x.shape[0]): | for i in range(x.shape[0]): | ||||
| x[i][:, masks[i]] = y[:, masks[i]] | x[i][:, masks[i]] = y[:, masks[i]] | ||||
| @@ -336,8 +335,7 @@ class NaiveFaithfulness(_FaithfulnessHelper): | |||||
| if not np.count_nonzero(saliency): | if not np.count_nonzero(saliency): | ||||
| log.warning("The saliency map is zero everywhere. The correlation will be set to zero.") | log.warning("The saliency map is zero everywhere. The correlation will be set to zero.") | ||||
| correlation = 0 | correlation = 0 | ||||
| normalized_faithfulness = (correlation + 1) / 2 | |||||
| return np.array([normalized_faithfulness], np.float) | |||||
| return np.array([correlation], np.float) | |||||
| reference = self._get_reference(inputs) | reference = self._get_reference(inputs) | ||||
| perturbations, masks = self._perturb( | perturbations, masks = self._perturb( | ||||
| inputs, saliency, reference, return_mask=True) | inputs, saliency, reference, return_mask=True) | ||||
| @@ -347,8 +345,7 @@ class NaiveFaithfulness(_FaithfulnessHelper): | |||||
| predictions = model(perturbations).asnumpy()[:, targets] | predictions = model(perturbations).asnumpy()[:, targets] | ||||
| faithfulness = calc_correlation(feature_importance, predictions) | faithfulness = calc_correlation(feature_importance, predictions) | ||||
| normalized_faithfulness = (faithfulness + 1) / 2 | |||||
| return np.array([normalized_faithfulness], np.float) | |||||
| return np.array([faithfulness], np.float) | |||||
| class DeletionAUC(_FaithfulnessHelper): | class DeletionAUC(_FaithfulnessHelper): | ||||
| @@ -533,6 +530,8 @@ class Faithfulness(AttributionMetric): | |||||
| metric (str, optional): The specifi metric to quantify faithfulness. | metric (str, optional): The specifi metric to quantify faithfulness. | ||||
| Options: "DeletionAUC", "InsertionAUC", "NaiveFaithfulness". | Options: "DeletionAUC", "InsertionAUC", "NaiveFaithfulness". | ||||
| Default: 'NaiveFaithfulness'. | Default: 'NaiveFaithfulness'. | ||||
| activation_fn (Cell, optional): The activation function that transforms the network output to a probability. | |||||
| Default: nn.Softmax(). | |||||
| Examples: | Examples: | ||||
| >>> from mindspore.explainer.benchmark import Faithfulness | >>> from mindspore.explainer.benchmark import Faithfulness | ||||
| @@ -543,7 +542,7 @@ class Faithfulness(AttributionMetric): | |||||
| """ | """ | ||||
| _methods = [NaiveFaithfulness, DeletionAUC, InsertionAUC] | _methods = [NaiveFaithfulness, DeletionAUC, InsertionAUC] | ||||
| def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness"): | |||||
| def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness", activation_fn=nn.Softmax()): | |||||
| super(Faithfulness, self).__init__(num_labels) | super(Faithfulness, self).__init__(num_labels) | ||||
| perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument | perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument | ||||
| @@ -552,6 +551,7 @@ class Faithfulness(AttributionMetric): | |||||
| num_perturb_steps = 100 # separate the perturbation progress in to 100 steps. | num_perturb_steps = 100 # separate the perturbation progress in to 100 steps. | ||||
| base_value = 0.0 # the pixel value set for the perturbed pixels | base_value = 0.0 # the pixel value set for the perturbed pixels | ||||
| self._activation_fn = activation_fn | |||||
| self._verify_metrics(metric) | self._verify_metrics(metric) | ||||
| for method in self._methods: | for method in self._methods: | ||||
| if metric == method.__name__: | if metric == method.__name__: | ||||
| @@ -568,9 +568,7 @@ class Faithfulness(AttributionMetric): | |||||
| Evaluate faithfulness on a single data sample. | Evaluate faithfulness on a single data sample. | ||||
| Note: | Note: | ||||
| To apply `Faithfulness` to evaluate an explainer, this explainer must be initialized with a network that | |||||
| contains the output activation function. Otherwise, the results will not be correct. Currently only single | |||||
| sample (:math:`N=1`) at each call is supported. | |||||
| Currently only single sample (:math:`N=1`) at each call is supported. | |||||
| Args: | Args: | ||||
| explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`. | explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`. | ||||
| @@ -586,7 +584,7 @@ class Faithfulness(AttributionMetric): | |||||
| Examples: | Examples: | ||||
| >>> # init an explainer, the network should contain the output activation function. | >>> # init an explainer, the network should contain the output activation function. | ||||
| >>> network = nn.SequentialCell([resnet50, nn.Sigmoid()]) | |||||
| >>> network = resnet50(20) | |||||
| >>> gradient = Gradient(network) | >>> gradient = Gradient(network) | ||||
| >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | >>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32) | ||||
| >>> targets = 5 | >>> targets = 5 | ||||
| @@ -610,10 +608,10 @@ class Faithfulness(AttributionMetric): | |||||
| saliency = saliency.squeeze() | saliency = saliency.squeeze() | ||||
| if len(saliency.shape) != 2: | if len(saliency.shape) != 2: | ||||
| raise ValueError('Squeezed saliency map is expected to 2D, but receive {}.'.format(len(saliency.shape))) | raise ValueError('Squeezed saliency map is expected to 2D, but receive {}.'.format(len(saliency.shape))) | ||||
| faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=explainer.model, | |||||
| model = nn.SequentialCell([explainer.model, self._activation_fn]) | |||||
| faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=model, | |||||
| targets=targets, saliency=saliency) | targets=targets, saliency=saliency) | ||||
| return faithfulness | |||||
| return (1 + faithfulness) / 2 | |||||
| def _verify_metrics(self, metric: str): | def _verify_metrics(self, metric: str): | ||||
| supports = [x.__name__ for x in self._methods] | supports = [x.__name__ for x in self._methods] | ||||
| @@ -17,10 +17,12 @@ | |||||
| from ._attribution._backprop.gradcam import GradCAM | from ._attribution._backprop.gradcam import GradCAM | ||||
| from ._attribution._backprop.gradient import Gradient | from ._attribution._backprop.gradient import Gradient | ||||
| from ._attribution._backprop.modified_relu import Deconvolution, GuidedBackprop | from ._attribution._backprop.modified_relu import Deconvolution, GuidedBackprop | ||||
| from ._attribution._perturbation.rise import RISE | |||||
| __all__ = [ | __all__ = [ | ||||
| 'Gradient', | 'Gradient', | ||||
| 'Deconvolution', | 'Deconvolution', | ||||
| 'GuidedBackprop', | 'GuidedBackprop', | ||||
| 'GradCAM', | 'GradCAM', | ||||
| 'RISE' | |||||
| ] | ] | ||||
| @@ -16,10 +16,12 @@ | |||||
| from ._backprop.gradcam import GradCAM | from ._backprop.gradcam import GradCAM | ||||
| from ._backprop.gradient import Gradient | from ._backprop.gradient import Gradient | ||||
| from ._backprop.modified_relu import Deconvolution, GuidedBackprop | from ._backprop.modified_relu import Deconvolution, GuidedBackprop | ||||
| from ._perturbation.rise import RISE | |||||
| __all__ = [ | __all__ = [ | ||||
| 'Gradient', | 'Gradient', | ||||
| 'Deconvolution', | 'Deconvolution', | ||||
| 'GuidedBackprop', | 'GuidedBackprop', | ||||
| 'GradCAM', | 'GradCAM', | ||||
| 'RISE' | |||||
| ] | ] | ||||
| @@ -16,33 +16,25 @@ | |||||
| from typing import Callable | from typing import Callable | ||||
| import mindspore as ms | |||||
| from mindspore.train._utils import check_value_type | |||||
| from mindspore.nn import Cell | |||||
| class Attribution: | class Attribution: | ||||
| r""" | |||||
| """ | |||||
| Basic class of attributing the salient score | Basic class of attributing the salient score | ||||
| The explainers which explanation through attributing the relevance scores | |||||
| should inherit this class. | |||||
| The explainers which explanation through attributing the relevance scores should inherit this class. | |||||
| Args: | Args: | ||||
| network (ms.nn.Cell): The black-box model to explanation. | |||||
| network (Cell): The black-box model to explain. | |||||
| """ | """ | ||||
| def __init__(self, network): | def __init__(self, network): | ||||
| self._verify_model(network) | |||||
| check_value_type("network", network, Cell) | |||||
| self._model = network | self._model = network | ||||
| self._model.set_train(False) | self._model.set_train(False) | ||||
| self._model.set_grad(False) | self._model.set_grad(False) | ||||
| @staticmethod | |||||
| def _verify_model(model): | |||||
| """ | |||||
| Verify the input `network` for __init__ function. | |||||
| """ | |||||
| if not isinstance(model, ms.nn.Cell): | |||||
| raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.") | |||||
| __call__: Callable | __call__: Callable | ||||
| """ | """ | ||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """ GradCAM and GuidedGradCAM. """ | |||||
| """GradCAM.""" | |||||
| from mindspore.ops import operations as op | from mindspore.ops import operations as op | ||||
| @@ -98,7 +98,7 @@ class GradCAM(IntermediateLayerAttribution): | |||||
| """ | """ | ||||
| Hook function to deal with the backward gradient. | Hook function to deal with the backward gradient. | ||||
| The arguments are set as required by Cell.register_back_hook | |||||
| The arguments are set as required by `Cell.register_backward_hook`. | |||||
| """ | """ | ||||
| self._intermediate_grad = grad_input | self._intermediate_grad = grad_input | ||||
| @@ -0,0 +1,19 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ Perturbation-based _attribution explainer. """ | |||||
| from .rise import RISE | |||||
| __all__ = ['RISE'] | |||||
| @@ -0,0 +1,38 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Base class `PerturbationAttribtuion`""" | |||||
| from mindspore.train._utils import check_value_type | |||||
| from mindspore.nn import Cell | |||||
| from .._attribution import Attribution | |||||
| from ...._operators import softmax | |||||
| class PerturbationAttribution(Attribution): | |||||
| """ | |||||
| Base class for perturbation-based attribution methods. | |||||
| All perturbation-based _attribution methods extend from this class. | |||||
| """ | |||||
| def __init__(self, | |||||
| network, | |||||
| activation_fn=softmax(), | |||||
| ): | |||||
| super(PerturbationAttribution, self).__init__(network) | |||||
| check_value_type("activation_fn", activation_fn, Cell) | |||||
| self._activation_fn = activation_fn | |||||
| @@ -0,0 +1,194 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """RISE.""" | |||||
| import math | |||||
| import numpy as np | |||||
| from mindspore import Tensor | |||||
| from mindspore import nn | |||||
| from mindspore.train._utils import check_value_type | |||||
| from .perturbation import PerturbationAttribution | |||||
| from .... import _operators as op | |||||
| from ...._utils import resize | |||||
| class RISE(PerturbationAttribution): | |||||
| r""" | |||||
| RISE: Randomized Input Sampling for Explanation of Black-box Model. | |||||
| RISE is a perturbation-based method that generates attribution maps by sampling on multiple random binary masks. | |||||
| The original image is randomly masked, and then fed into the black-box model to get predictions. The final | |||||
| attribution map is the weighted sum of these random masks, with the weights being the corresponding output on the | |||||
| node of interest: | |||||
| .. math:: | |||||
| E_{RISE}(I, f)_c = \sum_{i}f_c(I\odot M_i) M_i | |||||
| For more details, please refer to the original paper via: `RISE <https://arxiv.org/abs/1806.07421>`_. | |||||
| Args: | |||||
| network (Cell): The black-box model to be explained. | |||||
| activation_fn (Cell, optional): The activation layer that transforms logits to prediction probabilities. For | |||||
| single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks, | |||||
| `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long as | |||||
| when combining this function with network, the final output is the probability of the input. | |||||
| Default: `nn.Softmax`. | |||||
| perturbation_per_eval (int, optional): Number of perturbations for each inference during inferring the | |||||
| perturbed samples. Default: 32. | |||||
| Examples: | |||||
| >>> from mindspore.explainer.explanation import RISE | |||||
| >>> net = resnet50(10) | |||||
| >>> param_dict = load_checkpoint("resnet50.ckpt") | |||||
| >>> load_param_into_net(net, param_dict) | |||||
| >>> # init RISE with specified activation function | |||||
| >>> rise = RISE(net, activation_fn=nn.layer.Sigmoid()) | |||||
| """ | |||||
| def __init__(self, | |||||
| network, | |||||
| activation_fn=nn.Softmax(), | |||||
| perturbation_per_eval=32): | |||||
| super(RISE, self).__init__(network, activation_fn) | |||||
| self._perturbation_per_eval = perturbation_per_eval | |||||
| self._num_masks = 6000 # number of masks to be sampled | |||||
| self._mask_probability = 0.2 # ratio of inputs to be masked | |||||
| self._down_sample_size = 10 # the original size of binary masks | |||||
| self._resize_mode = 'bilinear' # mode choice to resize the down-sized binary masks to size of the inputs | |||||
| self._perturbation_mode = 'constant' # setting the perturbed pixels to a constant value | |||||
| self._base_value = 0 # setting the perturbed pixels to this constant value | |||||
| self._num_classes = None # placeholder of self._num_classes just for future assignment in other methods | |||||
| def _generate_masks(self, data, batch_size): | |||||
| """Generate a batch of binary masks for data.""" | |||||
| height, width = data.shape[2], data.shape[3] | |||||
| mask_size = (self._down_sample_size, self._down_sample_size) | |||||
| up_size = (height + mask_size[0], width + mask_size[1]) | |||||
| mask = np.random.random((batch_size, 1) + mask_size) < self._mask_probability | |||||
| upsample = resize(op.Tensor(mask, data.dtype), up_size, | |||||
| self._resize_mode) | |||||
| # Pack operator not available for GPU, thus transfer to numpy first | |||||
| upsample_np = upsample.asnumpy() | |||||
| masks_lst = [] | |||||
| for sample in upsample_np: | |||||
| shift_x = np.random.randint(0, mask_size[0] + 1) | |||||
| shift_y = np.random.randint(0, mask_size[1] + 1) | |||||
| masks_lst.append(sample[:, shift_x: shift_x + height, shift_y:shift_y + width]) | |||||
| masks = op.Tensor(np.array(masks_lst), data.dtype) | |||||
| return masks | |||||
| def __call__(self, inputs, targets): | |||||
| """ | |||||
| Generates attribution maps for inputs. | |||||
| Args: | |||||
| inputs (Tensor): Input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`. | |||||
| targets (int, Tensor): The labels of interest to be explained. When `targets` is an integer, | |||||
| all of the inputs will generates attribution map w.r.t this integer. When `targets` is a tensor, it | |||||
| should be of shape :math:`(N, ?)` or :math:`(N,)` :math:`()`. | |||||
| Returns: | |||||
| Tensor, a 4D tensor of shape :math:`(N, ?, H, W)` or :math:`(N, 1, H, W)`. | |||||
| Examples: | |||||
| >>> # given an instance of RISE, saliency map can be generate | |||||
| >>> inputs = ms.Tensor(np.random.rand([2, 3, 224, 224]), ms.float32) | |||||
| >>> # when `targets` is an integer | |||||
| >>> targets = 5 | |||||
| >>> saliency = rise(inputs, targets) | |||||
| >>> # `targets` can also be a tensor | |||||
| >>> targets = ms.Tensor([[5], [1]]) | |||||
| >>> saliency = rise(inputs, targets) | |||||
| >>> | |||||
| """ | |||||
| self._verify_data(inputs, targets) | |||||
| height, width = inputs.shape[2], inputs.shape[3] | |||||
| batch_size = inputs.shape[0] | |||||
| if self._num_classes is None: | |||||
| logits = self.model(inputs) | |||||
| num_classes = logits.shape[1] | |||||
| self._num_classes = num_classes | |||||
| # Due to the unsupported Op of slice assignment, we use numpy array here | |||||
| attr_np = np.zeros(shape=(batch_size, self._num_classes, height, width)) | |||||
| cal_times = math.ceil(self._num_masks / self._perturbation_per_eval) | |||||
| for idx, data in enumerate(inputs): | |||||
| bg_data = data * 0 + self._base_value | |||||
| for j in range(cal_times): | |||||
| bs = min(self._num_masks - j * self._perturbation_per_eval, | |||||
| self._perturbation_per_eval) | |||||
| data = op.reshape(data, (1, -1, height, width)) | |||||
| masks = self._generate_masks(data, bs) | |||||
| masked_input = masks * data + (1 - masks) * bg_data | |||||
| weights = self._activation_fn(self.model(masked_input)) | |||||
| while len(weights.shape) > 2: | |||||
| weights = op.mean(weights, axis=2) | |||||
| weights = op.reshape(weights, | |||||
| (bs, self._num_classes, 1, 1)) | |||||
| attr_np[idx] += op.summation(weights * masks, axis=0).asnumpy() | |||||
| attr_np = attr_np / self._num_masks | |||||
| targets = self._unify_targets(inputs, targets) | |||||
| attr_classes = [] | |||||
| for idx, target in enumerate(targets): | |||||
| dtype = inputs.dtype | |||||
| attr_np_idx = attr_np[idx] | |||||
| attr_idx = attr_np_idx[target] | |||||
| attr_classes.append(attr_idx) | |||||
| return op.Tensor(attr_classes, dtype=dtype) | |||||
| @staticmethod | |||||
| def _verify_data(inputs, targets): | |||||
| """Verify the validity of the parsed inputs.""" | |||||
| check_value_type('inputs', inputs, Tensor) | |||||
| if len(inputs.shape) != 4: | |||||
| raise ValueError('Argument inputs must be 4D Tensor') | |||||
| check_value_type('targets', targets, (Tensor, int, tuple, list)) | |||||
| if isinstance(targets, Tensor): | |||||
| if len(targets.shape) > 2: | |||||
| raise ValueError('Dimension invalid. If `targets` is a Tensor, it should be 0D, 1D or 2D. ' | |||||
| 'But got {}D.'.format(len(targets.shape))) | |||||
| if targets.shape and len(targets) != len(inputs): | |||||
| raise ValueError( | |||||
| 'If `targets` is a 2D, 1D Tensor, it should have the same length as inputs {}. But got {}'.format( | |||||
| len(inputs), len(targets))) | |||||
| @staticmethod | |||||
| def _unify_targets(inputs, targets): | |||||
| """To unify targets to be 2D numpy.ndarray.""" | |||||
| if isinstance(targets, int): | |||||
| return np.array([[targets] for i in inputs]).astype(np.int) | |||||
| if isinstance(targets, Tensor): | |||||
| if not targets.shape: | |||||
| return np.array([[targets.asnumpy()] for _ in inputs]).astype(np.int) | |||||
| if len(targets.shape) == 1: | |||||
| return np.array([[t.asnumpy()] for t in targets]).astype(np.int) | |||||
| if len(targets.shape) == 2: | |||||
| return np.array([t.asnumpy() for t in targets]).astype(np.int) | |||||
| return targets | |||||