|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Runner."""
- from time import time
- from typing import Tuple, List, Optional
-
- import numpy as np
-
- from mindspore.train._utils import check_value_type
- from mindspore.train.summary_pb2 import Explain
- import mindspore as ms
- import mindspore.dataset as ds
- from mindspore import log
- from mindspore.ops.operations import ExpandDims
- from mindspore.train.summary._summary_adapter import _convert_image_format, _make_image
- from mindspore.train.summary.summary_record import SummaryRecord
- from .benchmark import Localization
- from .benchmark._attribution.metric import AttributionMetric
- from .explanation._attribution._attribution import Attribution
-
# Shared operator instance used to add a leading (batch) axis to tensors.
_EXPAND_DIMS = ExpandDims()
# RGBA end points (scaled to [0, 1], shaped 1x1x4 so they broadcast over H x W x 4 images) between which
# `_make_rgba` interpolates: `_CMAP_0` is used where the normalised saliency is 0, `_CMAP_1` where it is 1.
_CMAP_0 = np.array([55, 25, 86, 255]).reshape(1, 1, 4) / 255
_CMAP_1 = np.array([255, 255, 0, 255]).reshape(1, 1, 4) / 255
-
-
def _normalize(img_np):
    """
    Linearly rescale an image so that its minimum maps to 0 and its maximum to 255.

    Args:
        img_np (numpy.ndarray): the image data, of any shape.

    Returns:
        numpy.ndarray, an array of dtype uint8 with the same shape as `img_np`. A constant image maps to all
        zeros, and an empty input yields an empty uint8 array.
    """
    img_np = np.asarray(img_np)
    if img_np.size == 0:
        # min()/max() are undefined on empty arrays; there is nothing to rescale.
        return img_np.astype(np.uint8)
    if not np.issubdtype(img_np.dtype, np.floating):
        # Fixed-width integers wrap around when subtracting (e.g. int8: 127 - (-128)), so widen first.
        img_np = img_np.astype(np.float64)
    max_ = img_np.max()
    min_ = img_np.min()
    # The lower bound on the range avoids a division by zero for constant images.
    normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
    return (normed * 255).astype(np.uint8)
-
-
def _make_rgba(saliency):
    """
    Render a saliency map as an RGBA image.

    The map is min-max normalised to [0, 1]; the colour of every pixel is interpolated between `_CMAP_0`
    (normalised value 0) and `_CMAP_1` (normalised value 1), and the alpha channel is set to the normalised
    value itself.

    Args:
        saliency: an object whose `asnumpy()` returns an array that squeezes to two dimensions.

    Returns:
        numpy.ndarray, an array of shape (H, W, 4).
    """
    heat = saliency.asnumpy().squeeze()
    heat = (heat - heat.min()) / (heat.max() - heat.min()).clip(1e-10)
    n_rows, n_cols = heat.shape[0], heat.shape[1]
    # Replicate the normalised map over the four channels to use it as the blending weight.
    weight = np.empty((n_rows, n_cols, 4))
    weight[...] = heat[:, :, np.newaxis]
    blended = weight * _CMAP_1 + (1 - weight) * _CMAP_0
    # The alpha channel carries the normalised saliency instead of the blended alpha.
    blended[:, :, -1] = heat * 1
    return blended
-
-
class ExplainRunner:
    """
    A high-level API for users to generate and store results of the explanation methods and the evaluation methods.

    After generating results with the explanation methods and the evaluation methods, the results will be written into
    a specified file with `mindspore.summary.SummaryRecord`. The stored content can be viewed using MindInsight.

    Args:
        summary_dir (str, optional): The directory path to save the summary files which store the generated results.
            Default: "./"

    Examples:
        >>> from mindspore.explainer import ExplainRunner
        >>> # init a runner with a specified directory
        >>> summary_dir = "summary_dir"
        >>> runner = ExplainRunner(summary_dir)
    """

    def __init__(self, summary_dir: Optional[str] = "./"):
        check_value_type("summary_dir", summary_dir, str)
        self._summary_dir = summary_dir
        # Running index of the image being processed within the current pass over the dataset; its string
        # form is used as the image id in the summary records.
        self._count = 0
        # Label names of the dataset, set by `run`.
        self._classes = None
        # Model under explanation, taken from the first explainer in `run`.
        self._model = None

    def run(self,
            dataset: Tuple,
            explainers: List,
            benchmarkers: Optional[List] = None):
        """
        Generates results and writes results into the summary files in `summary_dir` specified during the object
        initialization.

        Args:
            dataset (tuple): A tuple that contains `mindspore.dataset` object for iteration and its labels.

                - dataset[0]: A `mindspore.dataset` object to provide data to explain.
                - dataset[1]: A list of string that specifies the label names of the dataset.

            explainers (list[Explanation]): A list of explanation objects to generate attribution results. Explanation
                object is an instance initialized with the explanation methods in module
                `mindspore.explainer.explanation`.
            benchmarkers (list[Benchmark], optional): A list of benchmark objects to generate evaluation results.
                Default: None

        Raises:
            ValueError: If `dataset` does not have exactly two elements, if `explainers` is empty, if the data
                provided by the dataset is rejected by `_verify_data_form`, or if the second dimension of the
                model output is larger than the number of label names.
            TypeError: If an element of `explainers` is not an `Attribution` object or an element of
                `benchmarkers` is not an `AttributionMetric` object.

        Examples:
            >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
            >>> # obtain dataset object
            >>> dataset = get_dataset()
            >>> classes = ["cat", "dog", ...]
            >>> # load checkpoint to a network, e.g. resnet50
            >>> param_dict = load_checkpoint("checkpoint.ckpt")
            >>> net = resnet50(len(classes))
            >>> load_param_into_net(net, param_dict)
            >>> # bind net with its output activation
            >>> model = nn.SequentialCell([net, nn.Sigmoid()])
            >>> gbp = GuidedBackprop(model)
            >>> gradient = Gradient(model)
            >>> runner = ExplainRunner("./")
            >>> explainers = [gbp, gradient]
            >>> runner.run((dataset, classes), explainers)
        """
        check_value_type("dataset", dataset, tuple)
        if len(dataset) != 2:
            raise ValueError("Argument `dataset` should be a tuple with length = 2.")
        dataset, classes = dataset

        check_value_type("explainers", explainers, list)
        if not explainers:
            raise ValueError("Argument `explainers` must be a non-empty list")
        for exp in explainers:
            if not isinstance(exp, Attribution):
                raise TypeError("Argument explainers should be a list of objects of classes in "
                                "`mindspore.explainer.explanation`.")

        if benchmarkers is not None:
            check_value_type("benchmarkers", benchmarkers, list)
            for bench in benchmarkers:
                if not isinstance(bench, AttributionMetric):
                    raise TypeError("Argument benchmarkers should be a list of objects of classes in "
                                    "`mindspore.explainer.benchmark`.")

        # The data check inspects `benchmarkers`, so it runs only after their type has been validated.
        self._verify_data_form(dataset, benchmarkers)
        self._classes = classes

        # Probe the model of the first explainer with one element of the dataset to validate its output.
        self._model = explainers[0].model
        next_element = dataset.create_tuple_iterator().get_next()
        inputs, _, _ = self._unpack_next_element(next_element)
        prop_test = self._model(inputs)
        check_value_type("output of model in explainer", prop_test, ms.Tensor)
        if prop_test.shape[1] > len(self._classes):
            raise ValueError("The dimension of model output should not exceed the length of dataset classes. Please "
                             "check dataset classes or the black-box model in the explainer again.")

        with SummaryRecord(self._summary_dir) as summary:
            print("Start running and writing......")
            begin = time()

            self._write_metadata(summary, classes, explainers, benchmarkers)

            now = time()
            print("Start running and writing inference data......")
            imageid_labels = self._run_inference(dataset, summary)
            print("Finish running and writing inference data. Time elapsed: {}s".format(time() - now))

            for exp in explainers:
                if benchmarkers:
                    self._explain_and_benchmark(dataset, exp, benchmarkers, imageid_labels, summary)
                else:
                    self._explain_only(dataset, exp, imageid_labels, summary)
            print("Finish running and writing. Total time elapsed: {}s".format(time() - begin))

    @staticmethod
    def _write_metadata(summary, classes, explainers, benchmarkers):
        """
        Write the label names and the class names of the explainers and benchmarkers into summary.

        Args:
            summary (`SummaryRecord`): the summary object to store the data.
            classes (list[str]): the label names of the dataset.
            explainers (list): the explainers passed to `run`.
            benchmarkers (Union[list, None]): the benchmarkers passed to `run`.
        """
        print("Start writing metadata.")

        explain = Explain()
        explain.metadata.label.extend(classes)
        explain.metadata.explain_method.extend([exp.__class__.__name__ for exp in explainers])
        if benchmarkers is not None:
            explain.metadata.benchmark_method.extend([bench.__class__.__name__ for bench in benchmarkers])

        summary.add_value("explainer", "metadata", explain)
        summary.record(1)

        print("Finish writing metadata.")

    def _explain_only(self, dataset, explainer, imageid_labels, summary):
        """
        Run one explainer over the whole dataset and write its saliency maps into summary.

        Args:
            dataset (`ds`): the dataset to iterate.
            explainer (`Attribution`): the explainer to run.
            imageid_labels (dict): a dict that maps the image_id and its union labels.
            summary (`SummaryRecord`): the summary object to store the data.
        """
        exp_name = explainer.__class__.__name__
        start = time()
        print("Start running and writing explanation data for {}......".format(exp_name))
        # Restart the image counter and re-seed the dataset config before the pass, as `_run_inference` does;
        # presumably this keeps the iteration order (and so the image ids) aligned between passes.
        self._count = 0
        ds.config.set_seed(58)
        for idx, next_element in enumerate(dataset):
            now = time()
            self._run_exp_step(next_element, explainer, imageid_labels, summary)
            print("Finish writing {}-th explanation data. Time elapsed: {}".format(
                idx, time() - now))
        print("Finish running and writing explanation data for {}. Time elapsed: {}".format(
            exp_name, time() - start))

    def _explain_and_benchmark(self, dataset, explainer, benchmarkers, imageid_labels, summary):
        """
        Run one explainer over the whole dataset, evaluate it with every benchmarker and write both the saliency
        maps and the benchmark scores into summary.

        Args:
            dataset (`ds`): the dataset to iterate.
            explainer (`Attribution`): the explainer to run.
            benchmarkers (list[`AttributionMetric`]): the non-empty list of benchmarkers to evaluate with.
            imageid_labels (dict): a dict that maps the image_id and its union labels.
            summary (`SummaryRecord`): the summary object to store the data.
        """
        exp_name = explainer.__class__.__name__
        explain = Explain()
        # Clear the state of every benchmarker before evaluating this explainer.
        for bench in benchmarkers:
            bench.reset()
        print(f"Start running and writing explanation and benchmark data for {exp_name}.")
        # Restart the image counter and re-seed the dataset config before the pass, as `_run_inference` does;
        # presumably this keeps the iteration order (and so the image ids) aligned between passes.
        self._count = 0
        start = time()
        ds.config.set_seed(58)
        for idx, next_element in enumerate(dataset):
            now = time()
            saliency_dict_lst = self._run_exp_step(next_element, explainer, imageid_labels, summary)
            print("Finish writing {}-th batch explanation data. Time elapsed: {}s".format(
                idx, time() - now))
            for bench in benchmarkers:
                now = time()
                self._run_exp_benchmark_step(next_element, explainer, bench, saliency_dict_lst)
                print("Finish running {}-th batch benchmark data for {}. Time elapsed: {}s".format(
                    idx, bench.__class__.__name__, time() - now))

        # One benchmark record per (explainer, benchmarker) pair.
        for bench in benchmarkers:
            benchmark = explain.benchmark.add()
            benchmark.explain_method = exp_name
            benchmark.benchmark_method = bench.__class__.__name__

            benchmark.total_score = bench.performance
            benchmark.label_score.extend(bench.class_performances)

        print("Finish running and writing explanation and benchmark data for {}. "
              "Time elapsed: {}s".format(exp_name, time() - start))
        summary.add_value('explainer', 'benchmark', explain)
        summary.record(1)

    @staticmethod
    def _verify_data_form(dataset, benchmarkers):
        """
        Verify the validity of dataset by inspecting the first element it provides.

        Args:
            dataset (`ds`): the user parsed dataset.
            benchmarkers (Union[list[`AttributionMetric`], None]): the user parsed benchmarkers. None is treated
                as an empty list.

        Raises:
            ValueError: If the dataset does not provide 1, 2 or 3 columns, if the last dimension of the bounding
                boxes is not 4, if a `Localization` benchmarker is given but the dataset provides no bounding
                boxes, or if the shape of the images or of the labels is not supported.
        """
        next_element = dataset.create_tuple_iterator().get_next()

        num_columns = len(next_element)
        if num_columns not in (1, 2, 3):
            raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
                             " as columns.")

        labels = None
        if num_columns == 3:
            inputs, labels, bboxes = next_element
            if bboxes.shape[-1] != 4:
                raise ValueError("The third element of dataset should be bounding boxes with shape of "
                                 "[batch_size, num_ground_truth, 4].")
        else:
            # `benchmarkers` defaults to None in `run`, so guard the iteration against it.
            if any(isinstance(bench, Localization) for bench in (benchmarkers or ())):
                raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
            if num_columns == 2:
                inputs, labels = next_element
            else:
                inputs = next_element[0]

        if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
            raise ValueError(
                "Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
                    inputs.shape))
        if len(inputs.shape) == 3:
            log.warning(
                "Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
                " dimension as batch data.".format(inputs.shape))

        if labels is not None:
            if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
                raise ValueError(
                    "Labels shape {} is unrecognizable: labels should not have more than two dimensions"
                    " with length greater than 1.".format(labels.shape))

    def _transform_data(self, inputs, labels, bboxes, ifbbox):
        """
        Transform the data from one iteration of dataset to a unifying form for the follow-up operations.

        Args:
            inputs (Tensor): the image data
            labels (Tensor): the labels
            bboxes (Tensor): the bounding boxes data
            ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t label
                id will be returned. If False, the returned bboxes is the parsed bboxes.

        Returns:
            inputs (Tensor): the image data, unified to a 4D Tensor.
            labels (List[List[int]]): the ground truth labels.
            bboxes (Union[List[Dict], None, Tensor]): the bounding boxes
        """
        inputs = ms.Tensor(inputs, ms.float32)
        if len(inputs.shape) == 3:
            # A single un-batched sample: add the batch axis to every tensor that was provided.
            inputs = _EXPAND_DIMS(inputs, 0)
            if isinstance(labels, ms.Tensor):
                labels = ms.Tensor(labels, ms.int32)
                labels = _EXPAND_DIMS(labels, 0)
            if isinstance(bboxes, ms.Tensor):
                bboxes = ms.Tensor(bboxes, ms.int32)
                bboxes = _EXPAND_DIMS(bboxes, 0)

        input_len = len(inputs)
        if bboxes is not None and ifbbox:
            # Replace the raw boxes by one dict per image that maps a label id to a binary mask of shape
            # (1, 1, height, width); boxes that share a label are merged into the same mask.
            bboxes = ms.Tensor(bboxes, ms.int32)
            masks_lst = []
            labels = labels.asnumpy().reshape([input_len, -1])
            bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
            for idx, label in enumerate(labels):
                height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
                masks = {}
                for j, label_item in enumerate(label):
                    target = int(label_item)
                    # Label ids outside [0, number of classes) are skipped.
                    if -1 < target < len(self._classes):
                        if target not in masks:
                            mask = np.zeros((1, 1, height, width))
                        else:
                            mask = masks[target]
                        x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
                        # NOTE(review): the x range indexes the height axis and the y range the width axis;
                        # confirm this matches the bbox convention of the dataset.
                        mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
                        masks[target] = mask

                masks_lst.append(masks)
            bboxes = masks_lst

        labels = ms.Tensor(labels, ms.int32)
        if len(labels.shape) == 1:
            # One label per image.
            labels_lst = [[int(i)] for i in labels.asnumpy()]
        else:
            # Several labels per image: keep the distinct ids that fall inside the class range.
            labels = labels.asnumpy().reshape([input_len, -1])
            labels_lst = []
            for item in labels:
                labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._classes))))
        labels = labels_lst
        return inputs, labels, bboxes

    def _unpack_next_element(self, next_element, ifbbox=False):
        """
        Unpack a single iteration of dataset.

        Args:
            next_element (Tuple): a single element iterated from dataset object.
            ifbbox (bool): whether to preprocess bboxes in self._transform_data.

        Returns:
            Tuple, a unified Tuple contains image_data, labels, and bounding boxes.
        """
        bboxes = None
        if len(next_element) == 3:
            inputs, labels, bboxes = next_element
        elif len(next_element) == 2:
            inputs, labels = next_element
        else:
            inputs = next_element[0]
            # No labels provided: one empty label list per entry along the first axis of the inputs.
            labels = [[] for _ in inputs]
        return self._transform_data(inputs, labels, bboxes, ifbbox)

    @staticmethod
    def _make_label_batch(labels):
        """
        Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max
        length of all the rows in labels. Rows shorter than m are right-padded with 0.

        Args:
            labels (List[List]): the union labels of a data batch.

        Returns:
            2D Tensor.
        """
        max_len = max(len(row) for row in labels)
        batch_labels = np.zeros((len(labels), max_len))

        for idx, row in enumerate(labels):
            batch_labels[idx, :len(row)] = np.array(row)

        return ms.Tensor(batch_labels, ms.int32)

    def _run_inference(self, dataset, summary, threshod=0.5):
        """
        Run inference for the dataset and write the inference related data into summary.

        Args:
            dataset (`ds`): the parsed dataset
            summary (`SummaryRecord`): the summary object to store the data
            threshod (float): the threshold for prediction; labels whose output is greater than it count as
                predicted. The misspelled parameter name is kept for backward compatibility. Default: 0.5

        Returns:
            imageid_labels (dict): a dict that maps image_id and the union of its ground truth and predicted labels.
        """
        imageid_labels = {}
        ds.config.set_seed(58)
        self._count = 0
        for j, next_element in enumerate(dataset):
            now = time()
            inputs, labels, _ = self._unpack_next_element(next_element)
            prob = self._model(inputs).asnumpy()
            for idx, inp in enumerate(inputs):
                image_id = str(self._count)

                gt_labels = labels[idx]
                gt_probs = [float(prob[idx][i]) for i in gt_labels]

                data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
                _, _, _, image_string = _make_image(_normalize(data_np))

                predicted_labels = [int(i) for i in (prob[idx] > threshod).nonzero()[0]]
                predicted_probs = [float(prob[idx][i]) for i in predicted_labels]

                # Remember every label that is either a ground truth or predicted for this image.
                imageid_labels[image_id] = list(set(gt_labels + predicted_labels))

                explain = Explain()
                explain.image_id = image_id
                explain.image_data = image_string
                summary.add_value("explainer", "image", explain)

                explain = Explain()
                explain.image_id = image_id
                explain.ground_truth_label.extend(gt_labels)
                explain.inference.ground_truth_prob.extend(gt_probs)
                explain.inference.predicted_label.extend(predicted_labels)
                explain.inference.predicted_prob.extend(predicted_probs)
                summary.add_value("explainer", "inference", explain)

                summary.record(1)

                self._count += 1
            print("Finish running and writing {}-th batch inference data. Time elapsed: {}s".format(j, time() - now))
        return imageid_labels

    def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
        """
        Run the explanation for each step and write explanation results into summary.

        Args:
            next_element (Tuple): data of one step
            explainer (_Attribution): an Attribution object to generate saliency maps.
            imageid_labels (dict): a dict that maps the image_id and its union labels.
            summary (SummaryRecord): the summary object to store the data

        Returns:
            List of dict that maps label to its corresponding saliency map.
        """
        inputs, labels, _ = self._unpack_next_element(next_element)
        # Look up the union labels of every image of this batch by its running image id.
        unions = [imageid_labels[str(self._count + offset)] for offset in range(len(labels))]

        batch_unions = self._make_label_batch(unions)

        # Run the explainer once per label column, on the whole batch.
        batch_saliency_full = []
        for i in range(len(batch_unions[0])):
            batch_saliency = explainer(inputs, batch_unions[:, i])
            batch_saliency_full.append(batch_saliency)

        saliency_dict_lst = []
        for idx, union in enumerate(unions):
            saliency_dict = {}
            explain = Explain()
            explain.image_id = str(self._count)
            for k, lab in enumerate(union):
                # Slice instead of index so that the leading axis of the saliency map is kept.
                saliency = batch_saliency_full[k][idx:idx + 1]

                saliency_dict[lab] = saliency

                saliency_np = _make_rgba(saliency)
                _, _, _, saliency_string = _make_image(_normalize(saliency_np))

                explanation = explain.explanation.add()
                explanation.explain_method = explainer.__class__.__name__

                explanation.label = lab
                explanation.heatmap = saliency_string

            summary.add_value("explainer", "explanation", explain)
            summary.record(1)

            self._count += 1
            saliency_dict_lst.append(saliency_dict)
        return saliency_dict_lst

    def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
        """
        Run the evaluation for each step and aggregate the results in the benchmarker.

        Args:
            next_element (Tuple): Data of one step
            explainer (`_Attribution`): An Attribution object to generate saliency maps.
            benchmarker (`AttributionMetric`): The benchmark object that evaluates the explainer and aggregates
                the results.
            saliency_dict_lst (list[dict]): One dict per image of the step that maps a label to its saliency map,
                as returned by `_run_exp_step`.
        """
        inputs, labels, _ = self._unpack_next_element(next_element)

        is_localization = isinstance(benchmarker, Localization)
        masks = None
        if is_localization:
            # The masks only depend on the batch, so build them once rather than once per label.
            _, _, masks = self._unpack_next_element(next_element, True)

        for idx, inp in enumerate(inputs):
            inp = _EXPAND_DIMS(inp, 0)
            for label, saliency in saliency_dict_lst[idx].items():
                if is_localization:
                    # Localization is only evaluated on the ground truth labels of the image.
                    if label in labels[idx]:
                        res = benchmarker.evaluate(explainer, inp, targets=label, mask=masks[idx][label],
                                                   saliency=saliency)
                        benchmarker.aggregate(res, label)
                else:
                    res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
                    benchmarker.aggregate(res, label)
|