|
|
|
@@ -159,13 +159,14 @@ class ExplainRunner: |
|
|
|
label probability distribution :math:`P(y|x)`. Default: Softmax(). |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> from mindspore.explainer import ExplainRunner |
|
|
|
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient |
|
|
|
>>> from mindspore.nn import Sigmoid |
|
|
|
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net |
|
|
|
>>> # obtain dataset object |
|
|
|
>>> dataset = get_dataset() |
|
|
|
>>> classes = ["cat", "dog", ...] |
|
|
|
>>> # load checkpoint to a network, e.g. resnet50 |
|
|
|
>>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10 |
|
|
|
>>> dataset = get_dataset('/path/to/Cifar10_dataset') |
|
|
|
>>> classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'turck'] |
|
|
|
>>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10 |
|
|
|
>>> param_dict = load_checkpoint("checkpoint.ckpt") |
|
|
|
>>> net = resnet50(len(classes)) |
|
|
|
>>> load_param_into_net(net, param_dict) |
|
|
|
@@ -204,7 +205,7 @@ class ExplainRunner: |
|
|
|
check_value_type("activation_fn", activation_fn, Cell) |
|
|
|
|
|
|
|
self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn]) |
|
|
|
next_element = dataset.create_tuple_iterator().get_next() |
|
|
|
next_element = next(dataset.create_tuple_iterator()) |
|
|
|
inputs, _, _ = self._unpack_next_element(next_element) |
|
|
|
prop_test = self._model(inputs) |
|
|
|
check_value_type("output of model im explainer", prop_test, ms.Tensor) |
|
|
|
@@ -314,7 +315,7 @@ class ExplainRunner: |
|
|
|
dataset (`ds`): the user parsed dataset. |
|
|
|
benchmarkers (list[`AttributionMetric`]): the user parsed benchmarkers. |
|
|
|
""" |
|
|
|
next_element = dataset.create_tuple_iterator().get_next() |
|
|
|
next_element = next(dataset.create_tuple_iterator()) |
|
|
|
|
|
|
|
if len(next_element) not in [1, 2, 3]: |
|
|
|
raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]" |
|
|
|
@@ -611,29 +612,27 @@ class ExplainRunner: |
|
|
|
inputs, labels, _ = self._unpack_next_element(next_element) |
|
|
|
for idx, inp in enumerate(inputs): |
|
|
|
inp = _EXPAND_DIMS(inp, 0) |
|
|
|
saliency_dict = saliency_dict_lst[idx] |
|
|
|
for label, saliency in saliency_dict.items(): |
|
|
|
if isinstance(benchmarker, Localization): |
|
|
|
_, _, bboxes = self._unpack_next_element(next_element, True) |
|
|
|
if label in labels[idx]: |
|
|
|
res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label], |
|
|
|
saliency=saliency) |
|
|
|
if np.any(res == np.nan): |
|
|
|
res = np.zeros_like(res) |
|
|
|
if isinstance(benchmarker, LabelAgnosticMetric): |
|
|
|
res = benchmarker.evaluate(explainer, inp) |
|
|
|
res[np.isnan(res)] = 0.0 |
|
|
|
benchmarker.aggregate(res) |
|
|
|
else: |
|
|
|
saliency_dict = saliency_dict_lst[idx] |
|
|
|
for label, saliency in saliency_dict.items(): |
|
|
|
if isinstance(benchmarker, Localization): |
|
|
|
_, _, bboxes = self._unpack_next_element(next_element, True) |
|
|
|
if label in labels[idx]: |
|
|
|
res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label], |
|
|
|
saliency=saliency) |
|
|
|
res[np.isnan(res)] = 0.0 |
|
|
|
benchmarker.aggregate(res, label) |
|
|
|
elif isinstance(benchmarker, LabelSensitiveMetric): |
|
|
|
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) |
|
|
|
res[np.isnan(res)] = 0.0 |
|
|
|
benchmarker.aggregate(res, label) |
|
|
|
elif isinstance(benchmarker, LabelSensitiveMetric): |
|
|
|
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) |
|
|
|
if np.any(res == np.nan): |
|
|
|
res = np.zeros_like(res) |
|
|
|
benchmarker.aggregate(res, label) |
|
|
|
elif isinstance(benchmarker, LabelAgnosticMetric): |
|
|
|
res = benchmarker.evaluate(explainer, inp) |
|
|
|
if np.any(res == np.nan): |
|
|
|
res = np.zeros_like(res) |
|
|
|
benchmarker.aggregate(res) |
|
|
|
else: |
|
|
|
raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but' |
|
|
|
'receive {}'.format(type(benchmarker))) |
|
|
|
else: |
|
|
|
raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but' |
|
|
|
'receive {}'.format(type(benchmarker))) |
|
|
|
|
|
|
|
def _save_original_image(self, sample_id: int, image): |
|
|
|
"""Save an image to summary directory.""" |
|
|
|
|