From 1632d44711507638b51f494b2947939144b4a633 Mon Sep 17 00:00:00 2001 From: lixiaohui Date: Thu, 29 Oct 2020 15:56:07 +0800 Subject: [PATCH] fix benchmark=[] bug --- mindspore/explainer/_runner.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mindspore/explainer/_runner.py b/mindspore/explainer/_runner.py index 6e715dcb6a..4defc8364c 100644 --- a/mindspore/explainer/_runner.py +++ b/mindspore/explainer/_runner.py @@ -110,8 +110,7 @@ class ExplainRunner: >>> runner.run((dataset, classes), explainers) """ - if not isinstance(dataset, tuple): - raise TypeError("Argument `dataset` must be a tuple.") + check_value_type("dataset", dataset, tuple) if len(dataset) != 2: raise ValueError("Argument `dataset` should be a tuple with length = 2.") @@ -119,16 +118,18 @@ class ExplainRunner: self._verify_data_form(dataset, benchmarkers) self._classes = classes - if explainers is None or not explainers: - raise ValueError("Argument `explainers` can neither be None nor empty.") + check_value_type("explainers", explainers, list) + if not explainers: + raise ValueError("Argument `explainers` must be a non-empty list") for exp in explainers: - if not isinstance(exp, Attribution) or not isinstance(explainers, list): + if not isinstance(exp, Attribution): raise TypeError("Argument explainers should be a list of objects of classes in " "`mindspore.explainer.explanation`.") if benchmarkers is not None: + check_value_type("benchmarkers", benchmarkers, list) for bench in benchmarkers: - if not isinstance(bench, AttributionMetric) or not isinstance(explainers, list): + if not isinstance(bench, AttributionMetric): raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation" "`mindspore.explainer.benchmark`.") @@ -164,7 +165,7 @@ class ExplainRunner: imageid_labels = self._run_inference(dataset, summary) print("Finish running and writing inference data. Time elapsed: {}s".format(time() - now)) - if benchmarkers is None: + if benchmarkers is None or not benchmarkers: for exp in explainers: start = time() print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))