diff --git a/mindspore/explainer/_runner.py b/mindspore/explainer/_runner.py index 1bc062a46d..abe7ba8a66 100644 --- a/mindspore/explainer/_runner.py +++ b/mindspore/explainer/_runner.py @@ -159,13 +159,14 @@ class ExplainRunner: label probability distribution :math:`P(y|x)`. Default: Softmax(). Examples: + >>> from mindspore.explainer import ExplainRunner >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient >>> from mindspore.nn import Sigmoid >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net - >>> # obtain dataset object - >>> dataset = get_dataset() - >>> classes = ["cat", "dog", ...] - >>> # load checkpoint to a network, e.g. resnet50 + >>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10 + >>> dataset = get_dataset('/path/to/Cifar10_dataset') + >>> classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'turck'] + >>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10 >>> param_dict = load_checkpoint("checkpoint.ckpt") >>> net = resnet50(len(classes)) >>> load_param_into_net(net, param_dict) @@ -204,7 +205,7 @@ class ExplainRunner: check_value_type("activation_fn", activation_fn, Cell) self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn]) - next_element = dataset.create_tuple_iterator().get_next() + next_element = next(dataset.create_tuple_iterator()) inputs, _, _ = self._unpack_next_element(next_element) prop_test = self._model(inputs) check_value_type("output of model im explainer", prop_test, ms.Tensor) @@ -314,7 +315,7 @@ class ExplainRunner: dataset (`ds`): the user parsed dataset. benchmarkers (list[`AttributionMetric`]): the user parsed benchmarkers. """ - next_element = dataset.create_tuple_iterator().get_next() + next_element = next(dataset.create_tuple_iterator()) if len(next_element) not in [1, 2, 3]: raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]" @@ -611,29 +612,27 @@ class ExplainRunner: inputs, labels, _ = self._unpack_next_element(next_element) for idx, inp in enumerate(inputs): inp = _EXPAND_DIMS(inp, 0) - saliency_dict = saliency_dict_lst[idx] - for label, saliency in saliency_dict.items(): - if isinstance(benchmarker, Localization): - _, _, bboxes = self._unpack_next_element(next_element, True) - if label in labels[idx]: - res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label], - saliency=saliency) - if np.any(res == np.nan): - res = np.zeros_like(res) + if isinstance(benchmarker, LabelAgnosticMetric): + res = benchmarker.evaluate(explainer, inp) + res[np.isnan(res)] = 0.0 + benchmarker.aggregate(res) + else: + saliency_dict = saliency_dict_lst[idx] + for label, saliency in saliency_dict.items(): + if isinstance(benchmarker, Localization): + _, _, bboxes = self._unpack_next_element(next_element, True) + if label in labels[idx]: + res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label], + saliency=saliency) + res[np.isnan(res)] = 0.0 + benchmarker.aggregate(res, label) + elif isinstance(benchmarker, LabelSensitiveMetric): + res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) + res[np.isnan(res)] = 0.0 benchmarker.aggregate(res, label) - elif isinstance(benchmarker, LabelSensitiveMetric): - res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) - if np.any(res == np.nan): - res = np.zeros_like(res) - benchmarker.aggregate(res, label) - elif isinstance(benchmarker, LabelAgnosticMetric): - res = benchmarker.evaluate(explainer, inp) - if np.any(res == np.nan): - res = np.zeros_like(res) - benchmarker.aggregate(res) - else: - raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but' - 'receive {}'.format(type(benchmarker))) + else: + raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but' + 'receive {}'.format(type(benchmarker))) def _save_original_image(self, sample_id: int, image): """Save an image to summary directory."""