diff --git a/mindspore/explainer/_runner.py b/mindspore/explainer/_runner.py index 639973a92d..1bc062a46d 100644 --- a/mindspore/explainer/_runner.py +++ b/mindspore/explainer/_runner.py @@ -161,13 +161,14 @@ class ExplainRunner: Examples: >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient >>> from mindspore.nn import Sigmoid + >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net >>> # obtain dataset object >>> dataset = get_dataset() >>> classes = ["cat", "dog", ...] >>> # load checkpoint to a network, e.g. resnet50 >>> param_dict = load_checkpoint("checkpoint.ckpt") >>> net = resnet50(len(classes)) - >>> load_parama_into_net(net, param_dict) + >>> load_param_into_net(net, param_dict) >>> gbp = GuidedBackprop(net) >>> gradient = Gradient(net) >>> runner = ExplainRunner("./") @@ -180,6 +181,9 @@ class ExplainRunner: raise ValueError("Argument `dataset` should be a tuple with length = 2.") dataset, classes = dataset + if benchmarkers is None: + benchmarkers = [] + self._verify_data_form(dataset, benchmarkers) self._classes = classes @@ -191,7 +195,7 @@ class ExplainRunner: if not isinstance(exp, Attribution): raise TypeError("Argument `explainers` should be a list of objects of classes in " "`mindspore.explainer.explanation`.") - if benchmarkers is not None: + if benchmarkers: check_value_type("benchmarkers", benchmarkers, list) for bench in benchmarkers: if not isinstance(bench, AttributionMetric): @@ -234,7 +238,7 @@ class ExplainRunner: explain.metadata.label.extend(classes) exp_names = [exp.__class__.__name__ for exp in explainers] explain.metadata.explain_method.extend(exp_names) - if benchmarkers is not None: + if benchmarkers: bench_names = [bench.__class__.__name__ for bench in benchmarkers] explain.metadata.benchmark_method.extend(bench_names) @@ -249,7 +253,7 @@ class ExplainRunner: print(spacer.format("Finish running and writing inference data. " "Time elapsed: {:.3f} s".format(time() - now))) - if benchmarkers is None or not benchmarkers: + if not benchmarkers: for exp in explainers: start = time() print("Start running and writing explanation data for {}......".format(exp.__class__.__name__)) @@ -322,9 +326,8 @@ class ExplainRunner: raise ValueError("The third element of dataset should be bounding boxes with shape of " "[batch_size, num_ground_truth, 4].") else: - if benchmarkers is not None: - if True in [isinstance(bench, Localization) for bench in benchmarkers]: - raise ValueError("The dataset must provide bboxes if Localization is to be computed.") + if any(map(lambda benchmarker: isinstance(benchmarker, Localization), benchmarkers)): + raise ValueError("The dataset must provide bboxes if Localization is to be computed.") if len(next_element) == 2: inputs, labels = next_element