
!8692 Support uncertainty, single channel heatmap, separated image datafiles

From: @ngtony
Reviewed-by: @jonyguo
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
parent commit d773b54fe7
3 changed files with 170 additions and 38 deletions:

  1. mindspore/ccsrc/utils/summary.proto (+9, -4)
  2. mindspore/explainer/_runner.py (+159, -32)
  3. mindspore/train/summary/_explain_adapter.py (+2, -2)

mindspore/ccsrc/utils/summary.proto (+9, -4)

@@ -109,12 +109,18 @@ message Explain {
     repeated float ground_truth_prob = 1;
     repeated int32 predicted_label = 2;
     repeated float predicted_prob = 3;
+    repeated float ground_truth_prob_sd = 4;
+    repeated float ground_truth_prob_itl95_low = 5;
+    repeated float ground_truth_prob_itl95_hi = 6;
+    repeated float predicted_prob_sd = 7;
+    repeated float predicted_prob_itl95_low = 8;
+    repeated float predicted_prob_itl95_hi = 9;
   }

   message Explanation{
     optional string explain_method = 1;
     optional int32 label = 2;
-    optional bytes heatmap = 3;
+    optional string heatmap_path = 3;
   }

   message Benchmark{
@@ -130,11 +136,10 @@ message Explain {
     repeated string benchmark_method = 3;
   }

-  optional string image_id = 1;  // The Metadata and image id and benchmark must have one fill in
-  optional bytes image_data = 2;
+  optional int32 sample_id = 1;
+  optional string image_path = 2;  // The Metadata and image path must have one fill in
   repeated int32 ground_truth_label = 3;

   optional Inference inference = 4;
   repeated Explanation explanation = 5;
   repeated Benchmark benchmark = 6;

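The new `Inference` fields carry, per label, a standard deviation and the bounds of a 95% confidence interval alongside the existing probabilities, while samples are now referenced by an integer `sample_id` and an on-disk `image_path` instead of embedded bytes. A minimal sketch of populating the new fields through the generated Python bindings, with made-up values:

# Illustrative only: all values are made up.
from mindspore.train.summary_pb2 import Explain

explain = Explain()
explain.sample_id = 42                                             # was image_id (string)
explain.image_path = "_explain_1605341795/origin_images/0/42.jpg"  # was image_data (bytes)
explain.inference.ground_truth_prob.extend([0.91])
explain.inference.ground_truth_prob_sd.extend([0.03])              # standard deviation
explain.inference.ground_truth_prob_itl95_low.extend([0.84])       # 95% interval, lower bound
explain.inference.ground_truth_prob_itl95_hi.extend([0.96])        # 95% interval, upper bound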

mindspore/explainer/_runner.py (+159, -32)

@@ -13,10 +13,15 @@
 # limitations under the License.
 # ============================================================================
 """Runner."""
+import os
+import re
+import traceback
 from time import time
 from typing import Tuple, List, Optional

 import numpy as np
+from scipy.stats import beta
+from PIL import Image

 from mindspore.train._utils import check_value_type
 from mindspore.train.summary_pb2 import Explain
@@ -24,34 +29,75 @@ import mindspore as ms
 import mindspore.dataset as ds
 from mindspore import log
 from mindspore.ops.operations import ExpandDims
-from mindspore.train.summary._summary_adapter import _convert_image_format, _make_image
+from mindspore.train.summary._summary_adapter import _convert_image_format
 from mindspore.train.summary.summary_record import SummaryRecord
+from mindspore.nn.probability.toolbox import UncertaintyEvaluation
 from .benchmark import Localization
 from .benchmark._attribution.metric import AttributionMetric
 from .explanation._attribution._attribution import Attribution

+# datafile directory names
+_DATAFILE_DIRNAME_PREFIX = "_explain_"
+_ORIGINAL_IMAGE_DIRNAME = "origin_images"
+_HEATMAP_DIRNAME = "heatmap"
+# max. number of samples per directory
+_SAMPLE_PER_DIR = 1000
+
 _EXPAND_DIMS = ExpandDims()
-_CMAP_0 = np.reshape(np.array([55, 25, 86, 255]), [1, 1, 4]) / 255
-_CMAP_1 = np.reshape(np.array([255, 255, 0, 255]), [1, 1, 4]) / 255


 def _normalize(img_np):
-    """Normalize the image in the numpy array to be in [0, 255]."""
+    """Normalize the numpy image to the range of [0, 1]."""
     max_ = img_np.max()
     min_ = img_np.min()
     normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
-    return (normed * 255).astype(np.uint8)
+    return normed


-def _make_rgba(saliency):
-    """Make rgba image for saliency map."""
-    saliency = saliency.asnumpy().squeeze()
-    saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min()).clip(1e-10)
-    rgba = np.empty((saliency.shape[0], saliency.shape[1], 4))
-    rgba[:, :, :] = np.expand_dims(saliency, 2)
-    rgba = rgba * _CMAP_1 + (1 - rgba) * _CMAP_0
-    rgba[:, :, -1] = saliency * 1
-    return rgba
+def _np_to_image(img_np, mode):
+    """Convert numpy array to PIL image."""
+    return Image.fromarray(np.uint8(img_np * 255), mode=mode)
+
+
+def _calc_prob_interval(volume, probs, prob_vars):
+    """Compute the confidence interval of probability."""
+    if not isinstance(probs, np.ndarray):
+        probs = np.asarray(probs)
+    if not isinstance(prob_vars, np.ndarray):
+        prob_vars = np.asarray(prob_vars)
+    one_minus_probs = 1 - probs
+    alpha_coef = (np.square(probs) * one_minus_probs / prob_vars) - probs
+    beta_coef = alpha_coef * one_minus_probs / probs
+    intervals = beta.interval(volume, alpha_coef, beta_coef)
+
+    # avoid invalid results due to extremely small values of prob_vars
+    lows = []
+    highs = []
+    for i, low in enumerate(intervals[0]):
+        high = intervals[1][i]
+        if prob_vars[i] <= 0 or \
+                not np.isfinite(low) or low > probs[i] or \
+                not np.isfinite(high) or high < probs[i]:
+            low = probs[i]
+            high = probs[i]
+        lows.append(low)
+        highs.append(high)
+
+    return lows, highs
+
+
+def _get_id_dirname(sample_id: int):
+    """Get the name of the parent directory for a sample id."""
+    return str(int(sample_id / _SAMPLE_PER_DIR) * _SAMPLE_PER_DIR)
+
+
+def _extract_timestamp(filename: str):
+    """Extract the timestamp from a summary filename."""
+    matched = re.search(r"summary\.(\d+)", filename)
+    if matched:
+        return int(matched.group(1))
+    return None


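`_calc_prob_interval` fits a Beta distribution to each probability by the method of moments: solving mean = a/(a+b) together with the Beta variance equation gives a = p^2*(1-p)/var - p and b = a*(1-p)/p, and `beta.interval` then returns the central interval. A quick sanity check of that fit, with illustrative numbers:

# Sanity check of the method-of-moments fit used by _calc_prob_interval
# (p and var are illustrative values, not taken from a real run).
import numpy as np
from scipy.stats import beta

p, var = 0.9, 0.001
a = p * p * (1 - p) / var - p        # alpha_coef in _calc_prob_interval
b = a * (1 - p) / p                  # beta_coef
assert np.isclose(a / (a + b), p)    # the fitted Beta has mean p
low, high = beta.interval(0.95, a, b)    # central 95% interval around p
print(low, high)                         # prints an interval around 0.9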


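The two helpers above shard datafiles into buckets of `_SAMPLE_PER_DIR` samples and recover the run timestamp from the event filename. Their expected behaviour, with a made-up filename:

# Expected behaviour of the helpers (the event filename below is made up).
_get_id_dirname(0)       # -> '0', for samples 0..999
_get_id_dirname(1024)    # -> '1000', for samples 1000..1999
_extract_timestamp("events.summary.1605341795.hostname_MS")    # -> 1605341795
_extract_timestamp("no_timestamp_here")                        # -> None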
@@ -78,11 +124,14 @@ class ExplainRunner:
         self._count = 0
         self._classes = None
         self._model = None
+        self._uncertainty = None
+        self._summary_timestamp = None

     def run(self,
             dataset: Tuple,
             explainers: List,
-            benchmarkers: Optional[List] = None):
+            benchmarkers: Optional[List] = None,
+            uncertainty: Optional[UncertaintyEvaluation] = None):
         """
         Generates results and writes results into the summary files in `summary_dir` specified during the object
         initialization.
@@ -98,7 +147,8 @@
                 `mindspore.explainer.explanation`.
             benchmarkers (list[Benchmark], optional): A list of benchmark objects to generate evaluation results.
                 Default: None
-
+            uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference
+                uncertainty of samples.
         Examples:
             >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
             >>> # obtain dataset object
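A sketch of calling the extended `run` with uncertainty enabled; `net`, `train_ds`, `dataset` and `classes` are placeholders, and the `ExplainRunner` import path is assumed:

# Hypothetical invocation; net, train_ds, dataset and classes are placeholders.
from mindspore.explainer import ExplainRunner            # assumed import path
from mindspore.explainer.explanation import Gradient
from mindspore.nn.probability.toolbox import UncertaintyEvaluation

runner = ExplainRunner("./summary_dir")
uncertainty = UncertaintyEvaluation(model=net, train_dataset=train_ds,
                                    task_type='classification',
                                    num_classes=len(classes))
runner.run((dataset, classes), [Gradient(net)], uncertainty=uncertainty)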
@@ -145,15 +195,31 @@
             inputs, _, _ = self._unpack_next_element(next_element)
             prop_test = self._model(inputs)
             check_value_type("output of model in explainer", prop_test, ms.Tensor)
-            if prop_test.shape[1] > len(self._classes):
-                raise ValueError("The dimension of model output should not exceed the length of dataset classes. Please "
+            if prop_test.shape[1] != len(self._classes):
+                raise ValueError("The dimension of model output does not match the length of dataset classes. Please "
                                  "check dataset classes or the black-box model in the explainer again.")

+            if uncertainty is not None:
+                check_value_type("uncertainty", uncertainty, UncertaintyEvaluation)
+                prop_var_test = uncertainty.eval_epistemic_uncertainty(inputs)
+                check_value_type("output of uncertainty", prop_var_test, np.ndarray)
+                if prop_var_test.shape[1] != len(self._classes):
+                    raise ValueError("The dimension of uncertainty output does not match the length of dataset "
+                                     "classes. Please check dataset classes or the black-box model in the "
+                                     "explainer again.")
+                self._uncertainty = uncertainty
+            else:
+                self._uncertainty = None
+
             with SummaryRecord(self._summary_dir) as summary:
                 print("Start running and writing......")
                 begin = time()
                 print("Start writing metadata.")

+                self._summary_timestamp = _extract_timestamp(summary.event_file_name)
+                if self._summary_timestamp is None:
+                    raise RuntimeError("Cannot extract timestamp from summary filename!"
+                                       " It should contain a timestamp of 10 digits.")
+
                 explain = Explain()
                 explain.metadata.label.extend(classes)
                 exp_names = [exp.__class__.__name__ for exp in explainers]
@@ -341,7 +407,7 @@
                 bboxes = None
             else:
                 inputs = next_element[0]
-                labels = [[] for x in inputs]
+                labels = [[] for _ in inputs]
                 bboxes = None
         inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
         return inputs, labels, bboxes
@@ -359,7 +425,7 @@
             2D Tensor.
         """

-        max_len = max([len(l) for l in labels])
+        max_len = max([len(label) for label in labels])
         batch_labels = np.zeros((len(labels), max_len))

         for idx, _ in enumerate(batch_labels):
@@ -368,7 +434,7 @@

         return ms.Tensor(batch_labels, ms.int32)

-    def _run_inference(self, dataset, summary, threshod=0.5):
+    def _run_inference(self, dataset, summary, threshold=0.5):
         """
         Run inference for the dataset and write the inference related data into summary.
@@ -387,30 +453,64 @@
                 now = time()
                 inputs, labels, _ = self._unpack_next_element(next_element)
                 prob = self._model(inputs).asnumpy()
+                if self._uncertainty is not None:
+                    prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
+                    prob_sd = np.sqrt(prob_var)
+                else:
+                    prob_var = prob_sd = None
+
                 for idx, inp in enumerate(inputs):
                     gt_labels = labels[idx]
                     gt_probs = [float(prob[idx][i]) for i in gt_labels]

                     data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
-                    _, _, _, image_string = _make_image(_normalize(data_np))
+                    original_image = _np_to_image(_normalize(data_np), mode='RGB')
+                    original_image_path = self._save_original_image(self._count, original_image)

-                    predicted_labels = [int(i) for i in (prob[idx] > threshod).nonzero()[0]]
+                    predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
                     predicted_probs = [float(prob[idx][i]) for i in predicted_labels]

+                    has_uncertainty = False
+                    gt_prob_sds = gt_prob_itl95_lows = gt_prob_itl95_his = None
+                    predicted_prob_sds = predicted_prob_itl95_lows = predicted_prob_itl95_his = None
+                    if prob_var is not None:
+                        gt_prob_sds = [float(prob_sd[idx][i]) for i in gt_labels]
+                        predicted_prob_sds = [float(prob_sd[idx][i]) for i in predicted_labels]
+                        try:
+                            gt_prob_itl95_lows, gt_prob_itl95_his = \
+                                _calc_prob_interval(0.95, gt_probs, [float(prob_var[idx][i]) for i in gt_labels])
+                            predicted_prob_itl95_lows, predicted_prob_itl95_his = \
+                                _calc_prob_interval(0.95, predicted_probs, [float(prob_var[idx][i])
+                                                                            for i in predicted_labels])
+                            has_uncertainty = True
+                        except ValueError:
+                            log.error(traceback.format_exc())
+                            log.error("Error on calculating uncertainty")
+
                     union_labs = list(set(gt_labels + predicted_labels))
                     imageid_labels[str(self._count)] = union_labs

                     explain = Explain()
-                    explain.image_id = str(self._count)
-                    explain.image_data = image_string
-                    summary.add_value("explainer", "image", explain)
+                    explain.sample_id = self._count
+                    explain.image_path = original_image_path
+                    summary.add_value("explainer", "sample", explain)

                     explain = Explain()
-                    explain.image_id = str(self._count)
+                    explain.sample_id = self._count
                     explain.ground_truth_label.extend(gt_labels)
                     explain.inference.ground_truth_prob.extend(gt_probs)
                     explain.inference.predicted_label.extend(predicted_labels)
                     explain.inference.predicted_prob.extend(predicted_probs)

+                    if has_uncertainty:
+                        explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
+                        explain.inference.ground_truth_prob_itl95_low.extend(gt_prob_itl95_lows)
+                        explain.inference.ground_truth_prob_itl95_hi.extend(gt_prob_itl95_his)
+
+                        explain.inference.predicted_prob_sd.extend(predicted_prob_sds)
+                        explain.inference.predicted_prob_itl95_low.extend(predicted_prob_itl95_lows)
+                        explain.inference.predicted_prob_itl95_hi.extend(predicted_prob_itl95_his)
+
                     summary.add_value("explainer", "inference", explain)

                     summary.record(1)
@@ -451,20 +551,20 @@ class ExplainRunner:
for idx, union in enumerate(unions): for idx, union in enumerate(unions):
saliency_dict = {} saliency_dict = {}
explain = Explain() explain = Explain()
explain.image_id = str(self._count)
explain.sample_id = self._count
for k, lab in enumerate(union): for k, lab in enumerate(union):
saliency = batch_saliency_full[k][idx:idx + 1] saliency = batch_saliency_full[k][idx:idx + 1]


saliency_dict[lab] = saliency saliency_dict[lab] = saliency


saliency_np = _make_rgba(saliency)
_, _, _, saliency_string = _make_image(_normalize(saliency_np))
saliency_np = _normalize(saliency.asnumpy().squeeze())
saliency_image = _np_to_image(saliency_np, mode='L')
heatmap_path = self._save_heatmap(explainer.__class__.__name__, lab, self._count, saliency_image)


explanation = explain.explanation.add() explanation = explain.explanation.add()
explanation.explain_method = explainer.__class__.__name__ explanation.explain_method = explainer.__class__.__name__
explanation.heatmap_path = heatmap_path
explanation.label = lab explanation.label = lab
explanation.heatmap = saliency_string


summary.add_value("explainer", "explanation", explain) summary.add_value("explainer", "explanation", explain)
summary.record(1) summary.record(1)
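The change above replaces the RGBA colormap rendering (`_make_rgba` plus `_make_image`) with a single-channel grayscale file written to disk; roughly:

# Roughly what the new path does with a saliency map (stand-in array).
import numpy as np
from PIL import Image

saliency = np.random.rand(224, 224)                          # stand-in saliency map
image = Image.fromarray(np.uint8(saliency * 255), mode='L')  # one channel
image.save("heatmap.jpg", optimize=True)                     # grayscale JPEG on disk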
@@ -496,3 +596,30 @@ class ExplainRunner:
else: else:
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
benchmarker.aggregate(res, label) benchmarker.aggregate(res, label)

def _save_original_image(self, sample_id: int, image):
"""Save an image to summary directory."""
id_dirname = _get_id_dirname(sample_id)
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
_ORIGINAL_IMAGE_DIRNAME,
id_dirname)
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
relative_path = os.path.join(relative_dir, f"{sample_id}.jpg")
save_path = os.path.join(self._summary_dir, relative_path)
with open(save_path, "wb") as file:
image.save(file)
return relative_path

def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image):
"""Save heatmap image to summary directory."""
id_dirname = _get_id_dirname(sample_id)
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
_HEATMAP_DIRNAME,
explain_method,
id_dirname)
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
relative_path = os.path.join(relative_dir, f"{sample_id}_{class_id}.jpg")
save_path = os.path.join(self._summary_dir, relative_path)
with open(save_path, "wb") as file:
image.save(file, optimize=True)
return relative_path
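Given `_SAMPLE_PER_DIR = 1000`, the two savers above produce a layout like the following under `summary_dir` (timestamp, explainer name and ids are illustrative):

summary_dir/
└── _explain_1605341795/
    ├── origin_images/
    │   ├── 0/                 # samples 0..999
    │   │   └── 42.jpg
    │   └── 1000/              # samples 1000..1999
    └── heatmap/
        └── Gradient/
            └── 0/
                └── 42_7.jpg   # <sample_id>_<class_id>.jpg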

mindspore/train/summary/_explain_adapter.py (+2, -2)

@@ -28,8 +28,8 @@ def check_explain_proto(explain):
     if not isinstance(explain, Explain):
         raise TypeError(f'Plugin explainer expects a {Explain.__name__} value.')

-    if not explain.image_id and not explain.metadata.label and not explain.benchmark:
-        raise ValueError(f'The Metadata and image id and benchmark must have one fill in.')
+    if not explain.image_path and not explain.inference and not explain.metadata.label and not explain.benchmark:
+        raise ValueError('One of metadata, image path, inference or benchmark has to be filled in.')


 def package_explain_event(explain_str):

