|
|
@@ -161,13 +161,14 @@ class ExplainRunner: |
|
|
Examples: |
|
|
Examples: |
|
|
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient |
|
|
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient |
|
|
>>> from mindspore.nn import Sigmoid |
|
|
>>> from mindspore.nn import Sigmoid |
|
|
|
|
|
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net |
|
|
>>> # obtain dataset object |
|
|
>>> # obtain dataset object |
|
|
>>> dataset = get_dataset() |
|
|
>>> dataset = get_dataset() |
|
|
>>> classes = ["cat", "dog", ...] |
|
|
>>> classes = ["cat", "dog", ...] |
|
|
>>> # load checkpoint to a network, e.g. resnet50 |
|
|
>>> # load checkpoint to a network, e.g. resnet50 |
|
|
>>> param_dict = load_checkpoint("checkpoint.ckpt") |
|
|
>>> param_dict = load_checkpoint("checkpoint.ckpt") |
|
|
>>> net = resnet50(len(classes)) |
|
|
>>> net = resnet50(len(classes)) |
|
|
>>> load_parama_into_net(net, param_dict) |
|
|
|
|
|
|
|
|
>>> load_param_into_net(net, param_dict) |
|
|
>>> gbp = GuidedBackprop(net) |
|
|
>>> gbp = GuidedBackprop(net) |
|
|
>>> gradient = Gradient(net) |
|
|
>>> gradient = Gradient(net) |
|
|
>>> runner = ExplainRunner("./") |
|
|
>>> runner = ExplainRunner("./") |
|
|
@@ -180,6 +181,9 @@ class ExplainRunner: |
|
|
raise ValueError("Argument `dataset` should be a tuple with length = 2.") |
|
|
raise ValueError("Argument `dataset` should be a tuple with length = 2.") |
|
|
|
|
|
|
|
|
dataset, classes = dataset |
|
|
dataset, classes = dataset |
|
|
|
|
|
if benchmarkers is None: |
|
|
|
|
|
benchmarkers = [] |
|
|
|
|
|
|
|
|
self._verify_data_form(dataset, benchmarkers) |
|
|
self._verify_data_form(dataset, benchmarkers) |
|
|
self._classes = classes |
|
|
self._classes = classes |
|
|
|
|
|
|
|
|
@@ -191,7 +195,7 @@ class ExplainRunner: |
|
|
if not isinstance(exp, Attribution): |
|
|
if not isinstance(exp, Attribution): |
|
|
raise TypeError("Argument `explainers` should be a list of objects of classes in " |
|
|
raise TypeError("Argument `explainers` should be a list of objects of classes in " |
|
|
"`mindspore.explainer.explanation`.") |
|
|
"`mindspore.explainer.explanation`.") |
|
|
if benchmarkers is not None: |
|
|
|
|
|
|
|
|
if benchmarkers: |
|
|
check_value_type("benchmarkers", benchmarkers, list) |
|
|
check_value_type("benchmarkers", benchmarkers, list) |
|
|
for bench in benchmarkers: |
|
|
for bench in benchmarkers: |
|
|
if not isinstance(bench, AttributionMetric): |
|
|
if not isinstance(bench, AttributionMetric): |
|
|
@@ -234,7 +238,7 @@ class ExplainRunner: |
|
|
explain.metadata.label.extend(classes) |
|
|
explain.metadata.label.extend(classes) |
|
|
exp_names = [exp.__class__.__name__ for exp in explainers] |
|
|
exp_names = [exp.__class__.__name__ for exp in explainers] |
|
|
explain.metadata.explain_method.extend(exp_names) |
|
|
explain.metadata.explain_method.extend(exp_names) |
|
|
if benchmarkers is not None: |
|
|
|
|
|
|
|
|
if benchmarkers: |
|
|
bench_names = [bench.__class__.__name__ for bench in benchmarkers] |
|
|
bench_names = [bench.__class__.__name__ for bench in benchmarkers] |
|
|
explain.metadata.benchmark_method.extend(bench_names) |
|
|
explain.metadata.benchmark_method.extend(bench_names) |
|
|
|
|
|
|
|
|
@@ -249,7 +253,7 @@ class ExplainRunner: |
|
|
print(spacer.format("Finish running and writing inference data. " |
|
|
print(spacer.format("Finish running and writing inference data. " |
|
|
"Time elapsed: {:.3f} s".format(time() - now))) |
|
|
"Time elapsed: {:.3f} s".format(time() - now))) |
|
|
|
|
|
|
|
|
if benchmarkers is None or not benchmarkers: |
|
|
|
|
|
|
|
|
if not benchmarkers: |
|
|
for exp in explainers: |
|
|
for exp in explainers: |
|
|
start = time() |
|
|
start = time() |
|
|
print("Start running and writing explanation data for {}......".format(exp.__class__.__name__)) |
|
|
print("Start running and writing explanation data for {}......".format(exp.__class__.__name__)) |
|
|
@@ -322,9 +326,8 @@ class ExplainRunner: |
|
|
raise ValueError("The third element of dataset should be bounding boxes with shape of " |
|
|
raise ValueError("The third element of dataset should be bounding boxes with shape of " |
|
|
"[batch_size, num_ground_truth, 4].") |
|
|
"[batch_size, num_ground_truth, 4].") |
|
|
else: |
|
|
else: |
|
|
if benchmarkers is not None: |
|
|
|
|
|
if True in [isinstance(bench, Localization) for bench in benchmarkers]: |
|
|
|
|
|
raise ValueError("The dataset must provide bboxes if Localization is to be computed.") |
|
|
|
|
|
|
|
|
if any(map(lambda benchmarker: isinstance(benchmarker, Localization), benchmarkers)): |
|
|
|
|
|
raise ValueError("The dataset must provide bboxes if Localization is to be computed.") |
|
|
|
|
|
|
|
|
if len(next_element) == 2: |
|
|
if len(next_element) == 2: |
|
|
inputs, labels = next_element |
|
|
inputs, labels = next_element |
|
|
|