| @@ -161,13 +161,14 @@ class ExplainRunner: | |||||
| Examples: | Examples: | ||||
| >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient | >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient | ||||
| >>> from mindspore.nn import Sigmoid | >>> from mindspore.nn import Sigmoid | ||||
| >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| >>> # obtain dataset object | >>> # obtain dataset object | ||||
| >>> dataset = get_dataset() | >>> dataset = get_dataset() | ||||
| >>> classes = ["cat", "dog", ...] | >>> classes = ["cat", "dog", ...] | ||||
| >>> # load checkpoint to a network, e.g. resnet50 | >>> # load checkpoint to a network, e.g. resnet50 | ||||
| >>> param_dict = load_checkpoint("checkpoint.ckpt") | >>> param_dict = load_checkpoint("checkpoint.ckpt") | ||||
| >>> net = resnet50(len(classes)) | >>> net = resnet50(len(classes)) | ||||
| >>> load_parama_into_net(net, param_dict) | |||||
| >>> load_param_into_net(net, param_dict) | |||||
| >>> gbp = GuidedBackprop(net) | >>> gbp = GuidedBackprop(net) | ||||
| >>> gradient = Gradient(net) | >>> gradient = Gradient(net) | ||||
| >>> runner = ExplainRunner("./") | >>> runner = ExplainRunner("./") | ||||
| @@ -180,6 +181,9 @@ class ExplainRunner: | |||||
| raise ValueError("Argument `dataset` should be a tuple with length = 2.") | raise ValueError("Argument `dataset` should be a tuple with length = 2.") | ||||
| dataset, classes = dataset | dataset, classes = dataset | ||||
| if benchmarkers is None: | |||||
| benchmarkers = [] | |||||
| self._verify_data_form(dataset, benchmarkers) | self._verify_data_form(dataset, benchmarkers) | ||||
| self._classes = classes | self._classes = classes | ||||
| @@ -191,7 +195,7 @@ class ExplainRunner: | |||||
| if not isinstance(exp, Attribution): | if not isinstance(exp, Attribution): | ||||
| raise TypeError("Argument `explainers` should be a list of objects of classes in " | raise TypeError("Argument `explainers` should be a list of objects of classes in " | ||||
| "`mindspore.explainer.explanation`.") | "`mindspore.explainer.explanation`.") | ||||
| if benchmarkers is not None: | |||||
| if benchmarkers: | |||||
| check_value_type("benchmarkers", benchmarkers, list) | check_value_type("benchmarkers", benchmarkers, list) | ||||
| for bench in benchmarkers: | for bench in benchmarkers: | ||||
| if not isinstance(bench, AttributionMetric): | if not isinstance(bench, AttributionMetric): | ||||
| @@ -234,7 +238,7 @@ class ExplainRunner: | |||||
| explain.metadata.label.extend(classes) | explain.metadata.label.extend(classes) | ||||
| exp_names = [exp.__class__.__name__ for exp in explainers] | exp_names = [exp.__class__.__name__ for exp in explainers] | ||||
| explain.metadata.explain_method.extend(exp_names) | explain.metadata.explain_method.extend(exp_names) | ||||
| if benchmarkers is not None: | |||||
| if benchmarkers: | |||||
| bench_names = [bench.__class__.__name__ for bench in benchmarkers] | bench_names = [bench.__class__.__name__ for bench in benchmarkers] | ||||
| explain.metadata.benchmark_method.extend(bench_names) | explain.metadata.benchmark_method.extend(bench_names) | ||||
| @@ -249,7 +253,7 @@ class ExplainRunner: | |||||
| print(spacer.format("Finish running and writing inference data. " | print(spacer.format("Finish running and writing inference data. " | ||||
| "Time elapsed: {:.3f} s".format(time() - now))) | "Time elapsed: {:.3f} s".format(time() - now))) | ||||
| if benchmarkers is None or not benchmarkers: | |||||
| if not benchmarkers: | |||||
| for exp in explainers: | for exp in explainers: | ||||
| start = time() | start = time() | ||||
| print("Start running and writing explanation data for {}......".format(exp.__class__.__name__)) | print("Start running and writing explanation data for {}......".format(exp.__class__.__name__)) | ||||
| @@ -322,9 +326,8 @@ class ExplainRunner: | |||||
| raise ValueError("The third element of dataset should be bounding boxes with shape of " | raise ValueError("The third element of dataset should be bounding boxes with shape of " | ||||
| "[batch_size, num_ground_truth, 4].") | "[batch_size, num_ground_truth, 4].") | ||||
| else: | else: | ||||
| if benchmarkers is not None: | |||||
| if True in [isinstance(bench, Localization) for bench in benchmarkers]: | |||||
| raise ValueError("The dataset must provide bboxes if Localization is to be computed.") | |||||
| if any(map(lambda benchmarker: isinstance(benchmarker, Localization), benchmarkers)): | |||||
| raise ValueError("The dataset must provide bboxes if Localization is to be computed.") | |||||
| if len(next_element) == 2: | if len(next_element) == 2: | ||||
| inputs, labels = next_element | inputs, labels = next_element | ||||