| @@ -161,13 +161,14 @@ class ExplainRunner: | |||
| Examples: | |||
| >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient | |||
| >>> from mindspore.nn import Sigmoid | |||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| >>> # obtain dataset object | |||
| >>> dataset = get_dataset() | |||
| >>> classes = ["cat", "dog", ...] | |||
| >>> # load checkpoint to a network, e.g. resnet50 | |||
| >>> param_dict = load_checkpoint("checkpoint.ckpt") | |||
| >>> net = resnet50(len(classes)) | |||
| >>> load_parama_into_net(net, param_dict) | |||
| >>> load_param_into_net(net, param_dict) | |||
| >>> gbp = GuidedBackprop(net) | |||
| >>> gradient = Gradient(net) | |||
| >>> runner = ExplainRunner("./") | |||
| @@ -180,6 +181,9 @@ class ExplainRunner: | |||
| raise ValueError("Argument `dataset` should be a tuple with length = 2.") | |||
| dataset, classes = dataset | |||
| if benchmarkers is None: | |||
| benchmarkers = [] | |||
| self._verify_data_form(dataset, benchmarkers) | |||
| self._classes = classes | |||
| @@ -191,7 +195,7 @@ class ExplainRunner: | |||
| if not isinstance(exp, Attribution): | |||
| raise TypeError("Argument `explainers` should be a list of objects of classes in " | |||
| "`mindspore.explainer.explanation`.") | |||
| if benchmarkers is not None: | |||
| if benchmarkers: | |||
| check_value_type("benchmarkers", benchmarkers, list) | |||
| for bench in benchmarkers: | |||
| if not isinstance(bench, AttributionMetric): | |||
| @@ -234,7 +238,7 @@ class ExplainRunner: | |||
| explain.metadata.label.extend(classes) | |||
| exp_names = [exp.__class__.__name__ for exp in explainers] | |||
| explain.metadata.explain_method.extend(exp_names) | |||
| if benchmarkers is not None: | |||
| if benchmarkers: | |||
| bench_names = [bench.__class__.__name__ for bench in benchmarkers] | |||
| explain.metadata.benchmark_method.extend(bench_names) | |||
| @@ -249,7 +253,7 @@ class ExplainRunner: | |||
| print(spacer.format("Finish running and writing inference data. " | |||
| "Time elapsed: {:.3f} s".format(time() - now))) | |||
| if benchmarkers is None or not benchmarkers: | |||
| if not benchmarkers: | |||
| for exp in explainers: | |||
| start = time() | |||
| print("Start running and writing explanation data for {}......".format(exp.__class__.__name__)) | |||
| @@ -322,9 +326,8 @@ class ExplainRunner: | |||
| raise ValueError("The third element of dataset should be bounding boxes with shape of " | |||
| "[batch_size, num_ground_truth, 4].") | |||
| else: | |||
| if benchmarkers is not None: | |||
| if True in [isinstance(bench, Localization) for bench in benchmarkers]: | |||
| raise ValueError("The dataset must provide bboxes if Localization is to be computed.") | |||
| if any(map(lambda benchmarker: isinstance(benchmarker, Localization), benchmarkers)): | |||
| raise ValueError("The dataset must provide bboxes if Localization is to be computed.") | |||
| if len(next_element) == 2: | |||
| inputs, labels = next_element | |||