Merge pull request !7778 from lixiaohui33/feature_explain_coretags/v1.1.0
| @@ -17,8 +17,9 @@ from time import time | |||
| from typing import Tuple, List, Optional | |||
| import numpy as np | |||
| from mindspore.train.summary_pb2 import Explain | |||
| from mindspore.train._utils import check_value_type | |||
| from mindspore.train.summary_pb2 import Explain | |||
| import mindspore as ms | |||
| import mindspore.dataset as ds | |||
| from mindspore import log | |||
| @@ -71,6 +72,7 @@ class ExplainRunner: | |||
| """ | |||
| def __init__(self, summary_dir: Optional[str] = "./"): | |||
| check_value_type("summary_dir", summary_dir, str) | |||
| self._summary_dir = summary_dir | |||
| self._count = 0 | |||
| self._classes = None | |||
| @@ -123,14 +125,21 @@ class ExplainRunner: | |||
| for exp in explainers: | |||
| if not isinstance(exp, Attribution) or not isinstance(explainers, list): | |||
| raise TypeError("Argument explainers should be a list of objects of classes in " | |||
| "`mindspore.explainer.explanation._attribution`.") | |||
| "`mindspore.explainer.explanation`.") | |||
| if benchmarkers is not None: | |||
| for bench in benchmarkers: | |||
| if not isinstance(bench, AttributionMetric) or not isinstance(explainers, list): | |||
| raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation" | |||
| "`mindspore.explainer.benchmark._attribution`.") | |||
| "`mindspore.explainer.benchmark`.") | |||
| self._model = explainers[0].model | |||
| next_element = dataset.create_tuple_iterator().get_next() | |||
| inputs, _, _ = self._unpack_next_element(next_element) | |||
| prop_test = self._model(inputs) | |||
| check_value_type("output of model im explainer", prop_test, ms.Tensor) | |||
| if prop_test.shape[1] > len(self._classes): | |||
| raise ValueError("The dimension of model output should not exceed the length of dataset classes. Please " | |||
| "check dataset classes or the black-box model in the explainer again.") | |||
| with SummaryRecord(self._summary_dir) as summary: | |||
| print("Start running and writing......") | |||
| @@ -29,6 +29,7 @@ __all__ = [ | |||
| ] | |||
| from typing import Tuple, Union | |||
| import math | |||
| import numpy as np | |||
| from PIL import Image | |||
| @@ -204,7 +205,8 @@ def calc_correlation(x: Union[ms.Tensor, np.ndarray], | |||
| x = format_tensor_to_ndarray(x) | |||
| y = format_tensor_to_ndarray(y) | |||
| faithfulness = -np.corrcoef(x, y)[0, 1] | |||
| if math.isnan(faithfulness): | |||
| return np.float(0) | |||
| return faithfulness | |||
| @@ -232,7 +234,6 @@ def rank_pixels(inputs: _Array, descending: bool = True) -> _Array: | |||
| >> np.array([[2, 3, 4], [1, 0, 5]]) | |||
| rank_pixels(x, descending=False) | |||
| >> np.array([[3, 2, 0], [4, 5, 1]]) | |||
| """ | |||
| if len(inputs.shape) != 2: | |||
| raise ValueError('Only support 2D array currently') | |||
| @@ -339,6 +339,7 @@ class NaiveFaithfulness(_FaithfulnessHelper): | |||
| perturbations = ms.Tensor(perturbations, dtype=ms.float32) | |||
| predictions = model(perturbations).asnumpy()[:, targets] | |||
| faithfulness = calc_correlation(feature_importance, predictions) | |||
| normalized_faithfulness = (faithfulness + 1) / 2 | |||
| return np.array([normalized_faithfulness], np.float) | |||
| @@ -90,7 +90,7 @@ class Localization(AttributionMetric): | |||
| Evaluate localization on a single data sample. | |||
| Args: | |||
| explainer (Explanation): The explainer to be evaluated, see `mindspore/explainer/explanation`. | |||
| explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`. | |||
| inputs (Tensor): data sample. Currently only support single sample at each call. | |||
| targets (int): target label to evaluate on. | |||
| saliency (Tensor): A saliency tensor. | |||
| @@ -113,7 +113,7 @@ class Localization(AttributionMetric): | |||
| >>> saliency = gradient(inputs, targets) | |||
| >>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks) | |||
| """ | |||
| self._check_evaluate_param(explainer, inputs, targets, saliency) | |||
| self._check_evaluate_param_with_mask(explainer, inputs, targets, saliency, mask) | |||
| mask_np = format_tensor_to_ndarray(mask)[0] | |||
| @@ -141,6 +141,10 @@ class Localization(AttributionMetric): | |||
| def _check_evaluate_param_with_mask(self, explainer, inputs, targets, saliency, mask): | |||
| self._check_evaluate_param(explainer, inputs, targets, saliency) | |||
| check_value_type('mask', mask, (Tensor, np.ndarray)) | |||
| if len(inputs.shape) != 4: | |||
| raise ValueError('Argument mask must be 4D Tensor') | |||
| if mask is None: | |||
| raise ValueError('To compute localization, mask must be provided.') | |||
| check_value_type('mask', mask, (Tensor, np.ndarray)) | |||
| if len(mask.shape) != 4 or len(mask) != len(inputs): | |||
| raise ValueError("The input mask must be 4-dimensional (1, 1, h, w) with same length of inputs.") | |||
| @@ -47,9 +47,16 @@ class AttributionMetric: | |||
| """Super class of XAI metric class used in classification scenarios.""" | |||
| def __init__(self, num_labels=None): | |||
| self._verify_params(num_labels) | |||
| self._num_labels = num_labels | |||
| self._global_results = {i: [] for i in range(num_labels)} | |||
| @staticmethod | |||
| def _verify_params(num_labels): | |||
| check_value_type("num_labels", num_labels, int) | |||
| if num_labels < 1: | |||
| raise ValueError("Argument num_labels must be parsed with a integer > 0.") | |||
| def evaluate(self, explainer, inputs, targets, saliency=None): | |||
| """This function evaluates on a single sample and return the result.""" | |||
| raise NotImplementedError | |||
| @@ -119,5 +126,11 @@ class AttributionMetric: | |||
| """Check the evaluate parameters.""" | |||
| check_value_type('explainer', explainer, Attribution) | |||
| verify_argument(inputs, 'inputs') | |||
| output = explainer.model(inputs) | |||
| check_value_type("output of explainer model", output, Tensor) | |||
| output_dim = explainer.model(inputs).shape[1] | |||
| if output_dim > self._num_labels: | |||
| raise ValueError("The output dimension of of black-box model in explainer should not exceed the dimension " | |||
| "of num_labels set in the __init__, please set num_labels larger.") | |||
| verify_targets(targets, self._num_labels) | |||
| check_value_type('saliency', saliency, (Tensor, type(None))) | |||