
Correct mistakes in the docstring and fix bugs when no benchmarkers are passed.
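The fix normalizes a missing `benchmarkers` argument to an empty list right after the dataset is unpacked, so every later guard can use a plain truthiness check instead of repeated `is not None` tests. A minimal sketch of the pattern, using a hypothetical free function `run` in place of the real `ExplainRunner.run`:

    def run(dataset, explainers, benchmarkers=None):
        # Hypothetical stand-in for ExplainRunner.run, shown only to illustrate the pattern.
        if benchmarkers is None:
            benchmarkers = []      # normalize once, near the top of the method

        if benchmarkers:           # replaces `if benchmarkers is not None:`
            for bench in benchmarkers:
                pass               # type checks and benchmark evaluation go here

        if not benchmarkers:       # replaces `if benchmarkers is None or not benchmarkers:`
            pass                   # explanation-only code path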

tags/v1.1.0
YuhanShi53 · 5 years ago
commit c24fabcd8e
1 changed file with 10 additions and 7 deletions

mindspore/explainer/_runner.py  +10 -7

@@ -161,13 +161,14 @@ class ExplainRunner:
         Examples:
             >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
             >>> from mindspore.nn import Sigmoid
+            >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
             >>> # obtain dataset object
             >>> dataset = get_dataset()
             >>> classes = ["cat", "dog", ...]
             >>> # load checkpoint to a network, e.g. resnet50
             >>> param_dict = load_checkpoint("checkpoint.ckpt")
             >>> net = resnet50(len(classes))
-            >>> load_parama_into_net(net, param_dict)
+            >>> load_param_into_net(net, param_dict)
             >>> gbp = GuidedBackprop(net)
             >>> gradient = Gradient(net)
             >>> runner = ExplainRunner("./")
@@ -180,6 +181,9 @@ class ExplainRunner:
             raise ValueError("Argument `dataset` should be a tuple with length = 2.")
 
         dataset, classes = dataset
+        if benchmarkers is None:
+            benchmarkers = []
+
         self._verify_data_form(dataset, benchmarkers)
         self._classes = classes
 
@@ -191,7 +195,7 @@ class ExplainRunner:
             if not isinstance(exp, Attribution):
                 raise TypeError("Argument `explainers` should be a list of objects of classes in "
                                 "`mindspore.explainer.explanation`.")
-        if benchmarkers is not None:
+        if benchmarkers:
             check_value_type("benchmarkers", benchmarkers, list)
             for bench in benchmarkers:
                 if not isinstance(bench, AttributionMetric):
@@ -234,7 +238,7 @@ class ExplainRunner:
         explain.metadata.label.extend(classes)
         exp_names = [exp.__class__.__name__ for exp in explainers]
         explain.metadata.explain_method.extend(exp_names)
-        if benchmarkers is not None:
+        if benchmarkers:
             bench_names = [bench.__class__.__name__ for bench in benchmarkers]
             explain.metadata.benchmark_method.extend(bench_names)
 
@@ -249,7 +253,7 @@ class ExplainRunner:
         print(spacer.format("Finish running and writing inference data. "
                             "Time elapsed: {:.3f} s".format(time() - now)))
 
-        if benchmarkers is None or not benchmarkers:
+        if not benchmarkers:
             for exp in explainers:
                 start = time()
                 print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
@@ -322,9 +326,8 @@ class ExplainRunner:
                 raise ValueError("The third element of dataset should be bounding boxes with shape of "
                                  "[batch_size, num_ground_truth, 4].")
         else:
-            if benchmarkers is not None:
-                if True in [isinstance(bench, Localization) for bench in benchmarkers]:
-                    raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
+            if any(map(lambda benchmarker: isinstance(benchmarker, Localization), benchmarkers)):
+                raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
 
         if len(next_element) == 2:
             inputs, labels = next_element
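Note on the last hunk: `any()` short-circuits on the first `Localization` found and skips the intermediate list of booleans that `True in [...]` required, and since `run` now normalizes `None` to `[]` before calling `_verify_data_form`, the outer `is not None` guard is no longer needed. A small standalone illustration; the benchmark classes below are stand-ins, not the real `mindspore.explainer.benchmark` ones:

    class Localization: ...       # stand-in benchmark class
    class Faithfulness: ...       # stand-in benchmark class

    benchmarkers = [Faithfulness(), Localization()]

    # Old style: builds the full boolean list before testing membership.
    old = True in [isinstance(bench, Localization) for bench in benchmarkers]

    # New style: evaluates lazily and stops at the first match.
    new = any(map(lambda bench: isinstance(bench, Localization), benchmarkers))

    assert old == new  # both True for this list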