Merge pull request !7778 from lixiaohui33/feature_explain_coretags/v1.1.0
| @@ -17,8 +17,9 @@ from time import time | |||||
| from typing import Tuple, List, Optional | from typing import Tuple, List, Optional | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore.train.summary_pb2 import Explain | |||||
| from mindspore.train._utils import check_value_type | |||||
| from mindspore.train.summary_pb2 import Explain | |||||
| import mindspore as ms | import mindspore as ms | ||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| from mindspore import log | from mindspore import log | ||||
| @@ -71,6 +72,7 @@ class ExplainRunner: | |||||
| """ | """ | ||||
| def __init__(self, summary_dir: Optional[str] = "./"): | def __init__(self, summary_dir: Optional[str] = "./"): | ||||
| check_value_type("summary_dir", summary_dir, str) | |||||
| self._summary_dir = summary_dir | self._summary_dir = summary_dir | ||||
| self._count = 0 | self._count = 0 | ||||
| self._classes = None | self._classes = None | ||||
| @@ -123,14 +125,21 @@ class ExplainRunner: | |||||
| for exp in explainers: | for exp in explainers: | ||||
| if not isinstance(exp, Attribution) or not isinstance(explainers, list): | if not isinstance(exp, Attribution) or not isinstance(explainers, list): | ||||
| raise TypeError("Argument explainers should be a list of objects of classes in " | raise TypeError("Argument explainers should be a list of objects of classes in " | ||||
| "`mindspore.explainer.explanation._attribution`.") | |||||
| "`mindspore.explainer.explanation`.") | |||||
| if benchmarkers is not None: | if benchmarkers is not None: | ||||
| for bench in benchmarkers: | for bench in benchmarkers: | ||||
| if not isinstance(bench, AttributionMetric) or not isinstance(explainers, list): | if not isinstance(bench, AttributionMetric) or not isinstance(explainers, list): | ||||
| raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation" | raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation" | ||||
| "`mindspore.explainer.benchmark._attribution`.") | |||||
| "`mindspore.explainer.benchmark`.") | |||||
| self._model = explainers[0].model | self._model = explainers[0].model | ||||
| next_element = dataset.create_tuple_iterator().get_next() | |||||
| inputs, _, _ = self._unpack_next_element(next_element) | |||||
| prop_test = self._model(inputs) | |||||
| check_value_type("output of model im explainer", prop_test, ms.Tensor) | |||||
| if prop_test.shape[1] > len(self._classes): | |||||
| raise ValueError("The dimension of model output should not exceed the length of dataset classes. Please " | |||||
| "check dataset classes or the black-box model in the explainer again.") | |||||
| with SummaryRecord(self._summary_dir) as summary: | with SummaryRecord(self._summary_dir) as summary: | ||||
| print("Start running and writing......") | print("Start running and writing......") | ||||
| @@ -29,6 +29,7 @@ __all__ = [ | |||||
| ] | ] | ||||
| from typing import Tuple, Union | from typing import Tuple, Union | ||||
| import math | |||||
| import numpy as np | import numpy as np | ||||
| from PIL import Image | from PIL import Image | ||||
| @@ -204,7 +205,8 @@ def calc_correlation(x: Union[ms.Tensor, np.ndarray], | |||||
| x = format_tensor_to_ndarray(x) | x = format_tensor_to_ndarray(x) | ||||
| y = format_tensor_to_ndarray(y) | y = format_tensor_to_ndarray(y) | ||||
| faithfulness = -np.corrcoef(x, y)[0, 1] | faithfulness = -np.corrcoef(x, y)[0, 1] | ||||
| if math.isnan(faithfulness): | |||||
| return np.float(0) | |||||
| return faithfulness | return faithfulness | ||||
| @@ -232,7 +234,6 @@ def rank_pixels(inputs: _Array, descending: bool = True) -> _Array: | |||||
| >> np.array([[2, 3, 4], [1, 0, 5]]) | >> np.array([[2, 3, 4], [1, 0, 5]]) | ||||
| rank_pixels(x, descending=False) | rank_pixels(x, descending=False) | ||||
| >> np.array([[3, 2, 0], [4, 5, 1]]) | >> np.array([[3, 2, 0], [4, 5, 1]]) | ||||
| """ | """ | ||||
| if len(inputs.shape) != 2: | if len(inputs.shape) != 2: | ||||
| raise ValueError('Only support 2D array currently') | raise ValueError('Only support 2D array currently') | ||||
| @@ -339,6 +339,7 @@ class NaiveFaithfulness(_FaithfulnessHelper): | |||||
| perturbations = ms.Tensor(perturbations, dtype=ms.float32) | perturbations = ms.Tensor(perturbations, dtype=ms.float32) | ||||
| predictions = model(perturbations).asnumpy()[:, targets] | predictions = model(perturbations).asnumpy()[:, targets] | ||||
| faithfulness = calc_correlation(feature_importance, predictions) | faithfulness = calc_correlation(feature_importance, predictions) | ||||
| normalized_faithfulness = (faithfulness + 1) / 2 | normalized_faithfulness = (faithfulness + 1) / 2 | ||||
| return np.array([normalized_faithfulness], np.float) | return np.array([normalized_faithfulness], np.float) | ||||
| @@ -90,7 +90,7 @@ class Localization(AttributionMetric): | |||||
| Evaluate localization on a single data sample. | Evaluate localization on a single data sample. | ||||
| Args: | Args: | ||||
| explainer (Explanation): The explainer to be evaluated, see `mindspore/explainer/explanation`. | |||||
| explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`. | |||||
| inputs (Tensor): data sample. Currently only support single sample at each call. | inputs (Tensor): data sample. Currently only support single sample at each call. | ||||
| targets (int): target label to evaluate on. | targets (int): target label to evaluate on. | ||||
| saliency (Tensor): A saliency tensor. | saliency (Tensor): A saliency tensor. | ||||
| @@ -113,7 +113,7 @@ class Localization(AttributionMetric): | |||||
| >>> saliency = gradient(inputs, targets) | >>> saliency = gradient(inputs, targets) | ||||
| >>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks) | >>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks) | ||||
| """ | """ | ||||
| self._check_evaluate_param(explainer, inputs, targets, saliency) | |||||
| self._check_evaluate_param_with_mask(explainer, inputs, targets, saliency, mask) | |||||
| mask_np = format_tensor_to_ndarray(mask)[0] | mask_np = format_tensor_to_ndarray(mask)[0] | ||||
| @@ -141,6 +141,10 @@ class Localization(AttributionMetric): | |||||
| def _check_evaluate_param_with_mask(self, explainer, inputs, targets, saliency, mask): | def _check_evaluate_param_with_mask(self, explainer, inputs, targets, saliency, mask): | ||||
| self._check_evaluate_param(explainer, inputs, targets, saliency) | self._check_evaluate_param(explainer, inputs, targets, saliency) | ||||
| check_value_type('mask', mask, (Tensor, np.ndarray)) | |||||
| if len(inputs.shape) != 4: | if len(inputs.shape) != 4: | ||||
| raise ValueError('Argument mask must be 4D Tensor') | raise ValueError('Argument mask must be 4D Tensor') | ||||
| if mask is None: | |||||
| raise ValueError('To compute localization, mask must be provided.') | |||||
| check_value_type('mask', mask, (Tensor, np.ndarray)) | |||||
| if len(mask.shape) != 4 or len(mask) != len(inputs): | |||||
| raise ValueError("The input mask must be 4-dimensional (1, 1, h, w) with same length of inputs.") | |||||
| @@ -47,9 +47,16 @@ class AttributionMetric: | |||||
| """Super class of XAI metric class used in classification scenarios.""" | """Super class of XAI metric class used in classification scenarios.""" | ||||
| def __init__(self, num_labels=None): | def __init__(self, num_labels=None): | ||||
| self._verify_params(num_labels) | |||||
| self._num_labels = num_labels | self._num_labels = num_labels | ||||
| self._global_results = {i: [] for i in range(num_labels)} | self._global_results = {i: [] for i in range(num_labels)} | ||||
| @staticmethod | |||||
| def _verify_params(num_labels): | |||||
| check_value_type("num_labels", num_labels, int) | |||||
| if num_labels < 1: | |||||
| raise ValueError("Argument num_labels must be parsed with a integer > 0.") | |||||
| def evaluate(self, explainer, inputs, targets, saliency=None): | def evaluate(self, explainer, inputs, targets, saliency=None): | ||||
| """This function evaluates on a single sample and return the result.""" | """This function evaluates on a single sample and return the result.""" | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @@ -119,5 +126,11 @@ class AttributionMetric: | |||||
| """Check the evaluate parameters.""" | """Check the evaluate parameters.""" | ||||
| check_value_type('explainer', explainer, Attribution) | check_value_type('explainer', explainer, Attribution) | ||||
| verify_argument(inputs, 'inputs') | verify_argument(inputs, 'inputs') | ||||
| output = explainer.model(inputs) | |||||
| check_value_type("output of explainer model", output, Tensor) | |||||
| output_dim = explainer.model(inputs).shape[1] | |||||
| if output_dim > self._num_labels: | |||||
| raise ValueError("The output dimension of of black-box model in explainer should not exceed the dimension " | |||||
| "of num_labels set in the __init__, please set num_labels larger.") | |||||
| verify_targets(targets, self._num_labels) | verify_targets(targets, self._num_labels) | ||||
| check_value_type('saliency', saliency, (Tensor, type(None))) | check_value_type('saliency', saliency, (Tensor, type(None))) | ||||