| @@ -159,13 +159,14 @@ class ExplainRunner: | |||
| label probability distribution :math:`P(y|x)`. Default: Softmax(). | |||
| Examples: | |||
| >>> from mindspore.explainer import ExplainRunner | |||
| >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient | |||
| >>> from mindspore.nn import Sigmoid | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # obtain dataset object | |||
| >>> dataset = get_dataset() | |||
| >>> classes = ["cat", "dog", ...] | |||
| >>> # load checkpoint to a network, e.g. resnet50 | |||
| >>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10 | |||
| >>> dataset = get_dataset('/path/to/Cifar10_dataset') | |||
| >>> classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'turck'] | |||
| >>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10 | |||
| >>> param_dict = load_checkpoint("checkpoint.ckpt") | |||
| >>> net = resnet50(len(classes)) | |||
| >>> load_param_into_net(net, param_dict) | |||
| @@ -204,7 +205,7 @@ class ExplainRunner: | |||
| check_value_type("activation_fn", activation_fn, Cell) | |||
| self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn]) | |||
| next_element = dataset.create_tuple_iterator().get_next() | |||
| next_element = next(dataset.create_tuple_iterator()) | |||
| inputs, _, _ = self._unpack_next_element(next_element) | |||
| prop_test = self._model(inputs) | |||
| check_value_type("output of model im explainer", prop_test, ms.Tensor) | |||
| @@ -314,7 +315,7 @@ class ExplainRunner: | |||
| dataset (`ds`): the user parsed dataset. | |||
| benchmarkers (list[`AttributionMetric`]): the user parsed benchmarkers. | |||
| """ | |||
| next_element = dataset.create_tuple_iterator().get_next() | |||
| next_element = next(dataset.create_tuple_iterator()) | |||
| if len(next_element) not in [1, 2, 3]: | |||
| raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]" | |||
| @@ -611,29 +612,27 @@ class ExplainRunner: | |||
| inputs, labels, _ = self._unpack_next_element(next_element) | |||
| for idx, inp in enumerate(inputs): | |||
| inp = _EXPAND_DIMS(inp, 0) | |||
| saliency_dict = saliency_dict_lst[idx] | |||
| for label, saliency in saliency_dict.items(): | |||
| if isinstance(benchmarker, Localization): | |||
| _, _, bboxes = self._unpack_next_element(next_element, True) | |||
| if label in labels[idx]: | |||
| res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label], | |||
| saliency=saliency) | |||
| if np.any(res == np.nan): | |||
| res = np.zeros_like(res) | |||
| if isinstance(benchmarker, LabelAgnosticMetric): | |||
| res = benchmarker.evaluate(explainer, inp) | |||
| res[np.isnan(res)] = 0.0 | |||
| benchmarker.aggregate(res) | |||
| else: | |||
| saliency_dict = saliency_dict_lst[idx] | |||
| for label, saliency in saliency_dict.items(): | |||
| if isinstance(benchmarker, Localization): | |||
| _, _, bboxes = self._unpack_next_element(next_element, True) | |||
| if label in labels[idx]: | |||
| res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label], | |||
| saliency=saliency) | |||
| res[np.isnan(res)] = 0.0 | |||
| benchmarker.aggregate(res, label) | |||
| elif isinstance(benchmarker, LabelSensitiveMetric): | |||
| res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) | |||
| res[np.isnan(res)] = 0.0 | |||
| benchmarker.aggregate(res, label) | |||
| elif isinstance(benchmarker, LabelSensitiveMetric): | |||
| res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) | |||
| if np.any(res == np.nan): | |||
| res = np.zeros_like(res) | |||
| benchmarker.aggregate(res, label) | |||
| elif isinstance(benchmarker, LabelAgnosticMetric): | |||
| res = benchmarker.evaluate(explainer, inp) | |||
| if np.any(res == np.nan): | |||
| res = np.zeros_like(res) | |||
| benchmarker.aggregate(res) | |||
| else: | |||
| raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but' | |||
| 'receive {}'.format(type(benchmarker))) | |||
| else: | |||
| raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but' | |||
| 'receive {}'.format(type(benchmarker))) | |||
| def _save_original_image(self, sample_id: int, image): | |||
| """Save an image to summary directory.""" | |||