Browse Source

correct mistakes in docstring and fix bugs when no benchmark is passed.

tags/v1.1.0
YuhanShi53 5 years ago
parent
commit
c24fabcd8e
1 changed files with 10 additions and 7 deletions
  1. +10
    -7
      mindspore/explainer/_runner.py

+ 10
- 7
mindspore/explainer/_runner.py View File

@@ -161,13 +161,14 @@ class ExplainRunner:
Examples:
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
>>> from mindspore.nn import Sigmoid
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
>>> # obtain dataset object
>>> dataset = get_dataset()
>>> classes = ["cat", "dog", ...]
>>> # load checkpoint to a network, e.g. resnet50
>>> param_dict = load_checkpoint("checkpoint.ckpt")
>>> net = resnet50(len(classes))
>>> load_parama_into_net(net, param_dict)
>>> load_param_into_net(net, param_dict)
>>> gbp = GuidedBackprop(net)
>>> gradient = Gradient(net)
>>> runner = ExplainRunner("./")
@@ -180,6 +181,9 @@ class ExplainRunner:
raise ValueError("Argument `dataset` should be a tuple with length = 2.")

dataset, classes = dataset
if benchmarkers is None:
benchmarkers = []

self._verify_data_form(dataset, benchmarkers)
self._classes = classes

@@ -191,7 +195,7 @@ class ExplainRunner:
if not isinstance(exp, Attribution):
raise TypeError("Argument `explainers` should be a list of objects of classes in "
"`mindspore.explainer.explanation`.")
if benchmarkers is not None:
if benchmarkers:
check_value_type("benchmarkers", benchmarkers, list)
for bench in benchmarkers:
if not isinstance(bench, AttributionMetric):
@@ -234,7 +238,7 @@ class ExplainRunner:
explain.metadata.label.extend(classes)
exp_names = [exp.__class__.__name__ for exp in explainers]
explain.metadata.explain_method.extend(exp_names)
if benchmarkers is not None:
if benchmarkers:
bench_names = [bench.__class__.__name__ for bench in benchmarkers]
explain.metadata.benchmark_method.extend(bench_names)

@@ -249,7 +253,7 @@ class ExplainRunner:
print(spacer.format("Finish running and writing inference data. "
"Time elapsed: {:.3f} s".format(time() - now)))

if benchmarkers is None or not benchmarkers:
if not benchmarkers:
for exp in explainers:
start = time()
print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
@@ -322,9 +326,8 @@ class ExplainRunner:
raise ValueError("The third element of dataset should be bounding boxes with shape of "
"[batch_size, num_ground_truth, 4].")
else:
if benchmarkers is not None:
if True in [isinstance(bench, Localization) for bench in benchmarkers]:
raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
if any(map(lambda benchmarker: isinstance(benchmarker, Localization), benchmarkers)):
raise ValueError("The dataset must provide bboxes if Localization is to be computed.")

if len(next_element) == 2:
inputs, labels = next_element


Loading…
Cancel
Save