diff --git a/mindspore/ccsrc/utils/summary.proto b/mindspore/ccsrc/utils/summary.proto
index b3cb0c6795..3b325bd3dd 100644
--- a/mindspore/ccsrc/utils/summary.proto
+++ b/mindspore/ccsrc/utils/summary.proto
@@ -109,12 +109,18 @@ message Explain {
     repeated float ground_truth_prob = 1;
     repeated int32 predicted_label = 2;
     repeated float predicted_prob = 3;
+    repeated float ground_truth_prob_sd = 4;
+    repeated float ground_truth_prob_itl95_low = 5;
+    repeated float ground_truth_prob_itl95_hi = 6;
+    repeated float predicted_prob_sd = 7;
+    repeated float predicted_prob_itl95_low = 8;
+    repeated float predicted_prob_itl95_hi = 9;
   }
 
   message Explanation{
     optional string explain_method = 1;
     optional int32 label = 2;
-    optional bytes heatmap = 3;
+    optional string heatmap_path = 3;
   }
 
   message Benchmark{
@@ -130,11 +136,10 @@ message Explain {
     repeated string benchmark_method = 3;
   }
 
-  optional string image_id = 1;    // The Metadata and image id and benchmark must have one fill in
-  optional bytes image_data = 2;
+  optional int32 sample_id = 1;
+  optional string image_path = 2;  // At least one of metadata and image path must be filled in
   repeated int32 ground_truth_label = 3;
-
   optional Inference inference = 4;
   repeated Explanation explanation = 5;
   repeated Benchmark benchmark = 6;
 
diff --git a/mindspore/explainer/_runner.py b/mindspore/explainer/_runner.py
index a3b4899e3b..51080a6ac1 100644
--- a/mindspore/explainer/_runner.py
+++ b/mindspore/explainer/_runner.py
@@ -13,10 +13,15 @@
 # limitations under the License.
 # ============================================================================
 """Runner."""
+import os
+import re
+import traceback
 from time import time
 from typing import Tuple, List, Optional
 
 import numpy as np
+from scipy.stats import beta
+from PIL import Image
 
 from mindspore.train._utils import check_value_type
 from mindspore.train.summary_pb2 import Explain
@@ -24,34 +29,75 @@ import mindspore as ms
 import mindspore.dataset as ds
 from mindspore import log
 from mindspore.ops.operations import ExpandDims
-from mindspore.train.summary._summary_adapter import _convert_image_format, _make_image
+from mindspore.train.summary._summary_adapter import _convert_image_format
 from mindspore.train.summary.summary_record import SummaryRecord
+from mindspore.nn.probability.toolbox import UncertaintyEvaluation
 
 from .benchmark import Localization
 from .benchmark._attribution.metric import AttributionMetric
 from .explanation._attribution._attribution import Attribution
 
+# datafile directory names
+_DATAFILE_DIRNAME_PREFIX = "_explain_"
+_ORIGINAL_IMAGE_DIRNAME = "origin_images"
+_HEATMAP_DIRNAME = "heatmap"
+# max. number of samples per directory
+_SAMPLE_PER_DIR = 1000
+
+
 _EXPAND_DIMS = ExpandDims()
-_CMAP_0 = np.reshape(np.array([55, 25, 86, 255]), [1, 1, 4]) / 255
-_CMAP_1 = np.reshape(np.array([255, 255, 0, 255]), [1, 1, 4]) / 255
 
 
 def _normalize(img_np):
-    """Normalize the image in the numpy array to be in [0, 255]. """
+    """Normalize the numpy image to the range of [0, 1]."""
""" max_ = img_np.max() min_ = img_np.min() normed = (img_np - min_) / (max_ - min_).clip(min=1e-10) - return (normed * 255).astype(np.uint8) + return normed + + +def _np_to_image(img_np, mode): + """Convert numpy array to PIL image.""" + return Image.fromarray(np.uint8(img_np*255), mode=mode) + + +def _calc_prob_interval(volume, probs, prob_vars): + """Compute the confidence interval of probability.""" + if not isinstance(probs, np.ndarray): + probs = np.asarray(probs) + if not isinstance(prob_vars, np.ndarray): + prob_vars = np.asarray(prob_vars) + one_minus_probs = 1 - probs + alpha_coef = (np.square(probs) * one_minus_probs / prob_vars) - probs + beta_coef = alpha_coef * one_minus_probs / probs + intervals = beta.interval(volume, alpha_coef, beta_coef) + # avoid invalid result due to extreme small value of prob_vars + lows = [] + highs = [] + for i, low in enumerate(intervals[0]): + high = intervals[1][i] + if prob_vars[i] <= 0 or \ + not np.isfinite(low) or low > probs[i] or \ + not np.isfinite(high) or high < probs[i]: + low = probs[i] + high = probs[i] + lows.append(low) + highs.append(high) -def _make_rgba(saliency): - """Make rgba image for saliency map.""" - saliency = saliency.asnumpy().squeeze() - saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min()).clip(1e-10) - rgba = np.empty((saliency.shape[0], saliency.shape[1], 4)) - rgba[:, :, :] = np.expand_dims(saliency, 2) - rgba = rgba * _CMAP_1 + (1 - rgba) * _CMAP_0 - rgba[:, :, -1] = saliency * 1 - return rgba + return lows, highs + + +def _get_id_dirname(sample_id: int): + """Get the name of parent directory of the image id.""" + return str(int(sample_id/_SAMPLE_PER_DIR)*_SAMPLE_PER_DIR) + + +def _extract_timestamp(filename: str): + """Extract timestamp from summary filename.""" + matched = re.search(r"summary\.(\d+)", filename) + if matched: + return int(matched.group(1)) + return None class ExplainRunner: @@ -78,11 +124,14 @@ class ExplainRunner: self._count = 0 self._classes = None self._model = None + self._uncertainty = None + self._summary_timestamp = None def run(self, dataset: Tuple, explainers: List, - benchmarkers: Optional[List] = None): + benchmarkers: Optional[List] = None, + uncertainty: Optional[UncertaintyEvaluation] = None): """ Genereates results and writes results into the summary files in `summary_dir` specified during the object initialization. @@ -98,7 +147,8 @@ class ExplainRunner: `mindspore.explainer.explanation`. benchmarkers (list[Benchmark], optional): A list of benchmark objects to generate evaluation results. Default: None - + uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference + uncertainty of samples. Examples: >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient >>> # obtain dataset object @@ -145,15 +195,31 @@ class ExplainRunner: inputs, _, _ = self._unpack_next_element(next_element) prop_test = self._model(inputs) check_value_type("output of model im explainer", prop_test, ms.Tensor) - if prop_test.shape[1] > len(self._classes): - raise ValueError("The dimension of model output should not exceed the length of dataset classes. Please " + if prop_test.shape[1] != len(self._classes): + raise ValueError("The dimension of model output does not match the length of dataset classes. 
Please " "check dataset classes or the black-box model in the explainer again.") + if uncertainty is not None: + check_value_type("uncertainty", uncertainty, UncertaintyEvaluation) + prop_var_test = uncertainty.eval_epistemic_uncertainty(inputs) + check_value_type("output of uncertainty", prop_var_test, np.ndarray) + if prop_var_test.shape[1] != len(self._classes): + raise ValueError("The dimension of uncertainty output does not match the length of dataset classes" + "classes. Please check dataset classes or the black-box model in the explainer again.") + self._uncertainty = uncertainty + else: + self._uncertainty = None + with SummaryRecord(self._summary_dir) as summary: print("Start running and writing......") begin = time() print("Start writing metadata.") + self._summary_timestamp = _extract_timestamp(summary.event_file_name) + if self._summary_timestamp is None: + raise RuntimeError("Cannot extract timestamp from summary filename!" + " It should contains a timestamp of 10 digits.") + explain = Explain() explain.metadata.label.extend(classes) exp_names = [exp.__class__.__name__ for exp in explainers] @@ -341,7 +407,7 @@ class ExplainRunner: bboxes = None else: inputs = next_element[0] - labels = [[] for x in inputs] + labels = [[] for _ in inputs] bboxes = None inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox) return inputs, labels, bboxes @@ -359,7 +425,7 @@ class ExplainRunner: 2D Tensor. """ - max_len = max([len(l) for l in labels]) + max_len = max([len(label) for label in labels]) batch_labels = np.zeros((len(labels), max_len)) for idx, _ in enumerate(batch_labels): @@ -368,7 +434,7 @@ class ExplainRunner: return ms.Tensor(batch_labels, ms.int32) - def _run_inference(self, dataset, summary, threshod=0.5): + def _run_inference(self, dataset, summary, threshold=0.5): """ Run inference for the dataset and write the inference related data into summary. 
@@ -387,30 +453,64 @@ class ExplainRunner:
             now = time()
             inputs, labels, _ = self._unpack_next_element(next_element)
             prob = self._model(inputs).asnumpy()
+            if self._uncertainty is not None:
+                prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
+                prob_sd = np.sqrt(prob_var)
+            else:
+                prob_var = prob_sd = None
+
             for idx, inp in enumerate(inputs):
                 gt_labels = labels[idx]
                 gt_probs = [float(prob[idx][i]) for i in gt_labels]
 
                 data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
-                _, _, _, image_string = _make_image(_normalize(data_np))
+                original_image = _np_to_image(_normalize(data_np), mode='RGB')
+                original_image_path = self._save_original_image(self._count, original_image)
 
-                predicted_labels = [int(i) for i in (prob[idx] > threshod).nonzero()[0]]
+                predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
                 predicted_probs = [float(prob[idx][i]) for i in predicted_labels]
 
+                has_uncertainty = False
+                gt_prob_sds = gt_prob_itl95_lows = gt_prob_itl95_his = None
+                predicted_prob_sds = predicted_prob_itl95_lows = predicted_prob_itl95_his = None
+                if prob_var is not None:
+                    gt_prob_sds = [float(prob_sd[idx][i]) for i in gt_labels]
+                    predicted_prob_sds = [float(prob_sd[idx][i]) for i in predicted_labels]
+                    try:
+                        gt_prob_itl95_lows, gt_prob_itl95_his = \
+                            _calc_prob_interval(0.95, gt_probs, [float(prob_var[idx][i]) for i in gt_labels])
+                        predicted_prob_itl95_lows, predicted_prob_itl95_his = \
+                            _calc_prob_interval(0.95, predicted_probs, [float(prob_var[idx][i])
+                                                                        for i in predicted_labels])
+                        has_uncertainty = True
+                    except ValueError:
+                        log.error(traceback.format_exc())
+                        log.error("Error on calculating uncertainty")
+
                 union_labs = list(set(gt_labels + predicted_labels))
                 imageid_labels[str(self._count)] = union_labs
 
                 explain = Explain()
-                explain.image_id = str(self._count)
-                explain.image_data = image_string
-                summary.add_value("explainer", "image", explain)
+                explain.sample_id = self._count
+                explain.image_path = original_image_path
+                summary.add_value("explainer", "sample", explain)
 
                 explain = Explain()
-                explain.image_id = str(self._count)
+                explain.sample_id = self._count
                 explain.ground_truth_label.extend(gt_labels)
                 explain.inference.ground_truth_prob.extend(gt_probs)
                 explain.inference.predicted_label.extend(predicted_labels)
                 explain.inference.predicted_prob.extend(predicted_probs)
+
+                if has_uncertainty:
+                    explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
+                    explain.inference.ground_truth_prob_itl95_low.extend(gt_prob_itl95_lows)
+                    explain.inference.ground_truth_prob_itl95_hi.extend(gt_prob_itl95_his)
+
+                    explain.inference.predicted_prob_sd.extend(predicted_prob_sds)
+                    explain.inference.predicted_prob_itl95_low.extend(predicted_prob_itl95_lows)
+                    explain.inference.predicted_prob_itl95_hi.extend(predicted_prob_itl95_his)
+
                 summary.add_value("explainer", "inference", explain)
 
                 summary.record(1)
@@ -451,20 +551,20 @@ class ExplainRunner:
             for idx, union in enumerate(unions):
                 saliency_dict = {}
                 explain = Explain()
-                explain.image_id = str(self._count)
+                explain.sample_id = self._count
                 for k, lab in enumerate(union):
                     saliency = batch_saliency_full[k][idx:idx + 1]
                     saliency_dict[lab] = saliency
 
-                    saliency_np = _make_rgba(saliency)
-                    _, _, _, saliency_string = _make_image(_normalize(saliency_np))
+                    saliency_np = _normalize(saliency.asnumpy().squeeze())
+                    saliency_image = _np_to_image(saliency_np, mode='L')
+                    heatmap_path = self._save_heatmap(explainer.__class__.__name__, lab, self._count, saliency_image)
 
                     explanation = explain.explanation.add()
                     explanation.explain_method = explainer.__class__.__name__
-
+                    explanation.heatmap_path = heatmap_path
                     explanation.label = lab
-                    explanation.heatmap = saliency_string
 
                     summary.add_value("explainer", "explanation", explain)
                     summary.record(1)
@@ -496,3 +596,30 @@ class ExplainRunner:
             else:
                 res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
                 benchmarker.aggregate(res, label)
+
+    def _save_original_image(self, sample_id: int, image):
+        """Save an image to summary directory."""
+        id_dirname = _get_id_dirname(sample_id)
+        relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
+                                    _ORIGINAL_IMAGE_DIRNAME,
+                                    id_dirname)
+        os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
+        relative_path = os.path.join(relative_dir, f"{sample_id}.jpg")
+        save_path = os.path.join(self._summary_dir, relative_path)
+        with open(save_path, "wb") as file:
+            image.save(file)
+        return relative_path
+
+    def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image):
+        """Save heatmap image to summary directory."""
+        id_dirname = _get_id_dirname(sample_id)
+        relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
+                                    _HEATMAP_DIRNAME,
+                                    explain_method,
+                                    id_dirname)
+        os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
+        relative_path = os.path.join(relative_dir, f"{sample_id}_{class_id}.jpg")
+        save_path = os.path.join(self._summary_dir, relative_path)
+        with open(save_path, "wb") as file:
+            image.save(file, optimize=True)
+        return relative_path
diff --git a/mindspore/train/summary/_explain_adapter.py b/mindspore/train/summary/_explain_adapter.py
index 156aae530a..8812f82a9a 100644
--- a/mindspore/train/summary/_explain_adapter.py
+++ b/mindspore/train/summary/_explain_adapter.py
@@ -28,8 +28,8 @@ def check_explain_proto(explain):
     if not isinstance(explain, Explain):
         raise TypeError(f'Plugin explainer expects a {Explain.__name__} value.')
 
-    if not explain.image_id and not explain.metadata.label and not explain.benchmark:
-        raise ValueError(f'The Metadata and image id and benchmark must have one fill in.')
+    if not explain.image_path and not explain.inference and not explain.metadata.label and not explain.benchmark:
+        raise ValueError('One of metadata, image path, inference or benchmark has to be filled in.')
 
 
 def package_explain_event(explain_str):
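
For reference, the interval logic added in _calc_prob_interval is a method-of-moments Beta fit: for Beta(a, b) the mean is a/(a+b) and the variance is a*b/((a+b)^2*(a+b+1)), so matching a mean p and variance v gives a = p^2*(1-p)/v - p and b = a*(1-p)/p, and the reported range is the central interval holding the requested probability mass. The snippet below is a minimal standalone sketch of that derivation; it assumes only NumPy and scipy.stats.beta, and the name beta_interval_sketch is illustrative, not part of the patch.

import numpy as np
from scipy.stats import beta


def beta_interval_sketch(volume, probs, prob_vars):
    """Fit Beta(a, b) to each (mean, variance) pair and return the central `volume` interval."""
    probs = np.asarray(probs, dtype=float)
    prob_vars = np.asarray(prob_vars, dtype=float)
    a = np.square(probs) * (1 - probs) / prob_vars - probs  # a = p^2 * (1 - p) / v - p
    b = a * (1 - probs) / probs                             # b = a * (1 - p) / p
    return beta.interval(volume, a, b)


# a probability of 0.8 with variance 1e-4 gives a tight 95% interval around 0.8
low, high = beta_interval_sketch(0.95, [0.8], [1e-4])
print(low, high)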
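
From the caller's side, the new uncertainty argument of ExplainRunner.run plugs in as sketched below. This is only a usage sketch: net, train_dataset, dataset and classes are placeholders assumed to be prepared by the user, and the UncertaintyEvaluation constructor arguments shown are an assumption, so check its own documentation for the exact signature.

from mindspore.explainer import ExplainRunner
from mindspore.explainer.explanation import Gradient
from mindspore.nn.probability.toolbox import UncertaintyEvaluation

# `net`, `train_dataset`, `dataset` and `classes` are assumed to be defined elsewhere
uncertainty = UncertaintyEvaluation(model=net,                      # constructor arguments are an assumption
                                    train_dataset=train_dataset,
                                    task_type='classification',
                                    num_classes=len(classes))
runner = ExplainRunner("./summary_dir")
# epistemic standard deviations and 95% intervals are recorded next to the inference results
runner.run((dataset, classes), [Gradient(net)], uncertainty=uncertainty)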