@@ -13,10 +13,15 @@
# limitations under the License.
# ============================================================================
"""Runner."""
import os
import re
import traceback
from time import time
from typing import Tuple, List, Optional
import numpy as np
from scipy.stats import beta
from PIL import Image
from mindspore.train._utils import check_value_type
from mindspore.train.summary_pb2 import Explain
@@ -24,34 +29,75 @@ import mindspore as ms
import mindspore.dataset as ds
from mindspore import log
from mindspore.ops.operations import ExpandDims
from mindspore.train.summary._summary_adapter import _convert_image_format, _make_image
from mindspore.train.summary._summary_adapter import _convert_image_format
from mindspore.train.summary.summary_record import SummaryRecord
from mindspore.nn.probability.toolbox import UncertaintyEvaluation
from .benchmark import Localization
from .benchmark._attribution.metric import AttributionMetric
from .explanation._attribution._attribution import Attribution
# datafile directory names
_DATAFILE_DIRNAME_PREFIX = "_explain_"
_ORIGINAL_IMAGE_DIRNAME = "origin_images"
_HEATMAP_DIRNAME = "heatmap"
# max. number of samples per directory
_SAMPLE_PER_DIR = 1000
_EXPAND_DIMS = ExpandDims()
_CMAP_0 = np.reshape(np.array([55, 25, 86, 255]), [1, 1, 4]) / 255
_CMAP_1 = np.reshape(np.array([255, 255, 0, 255]), [1, 1, 4]) / 255
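# RGBA colormap endpoints (dark purple and yellow), normalized to [0, 1]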
def _normalize(img_np):
"""Normalize the image in the numpy array to be in [0, 255]. """
"""Normalize the numpy image to the range of [0, 1 ]. """
max_ = img_np.max()
min_ = img_np.min()
normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
return (normed * 255).astype(np.uint8)
return normed
def _np_to_image(img_np, mode):
"""Convert numpy array to PIL image."""
return Image.fromarray(np.uint8(img_np*255), mode=mode)
def _calc_prob_interval(volume, probs, prob_vars):
"""Compute the confidence interval of probability."""
if not isinstance(probs, np.ndarray):
probs = np.asarray(probs)
if not isinstance(prob_vars, np.ndarray):
prob_vars = np.asarray(prob_vars)
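# Fit a Beta(alpha, beta) distribution to each probability by moment matching:
# alpha = p * (p * (1 - p) / var - 1), beta = (1 - p) * (p * (1 - p) / var - 1),
# then take the equal-tailed interval covering `volume` of the probability mass.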
one_minus_probs = 1 - probs
alpha_coef = (np.square(probs) * one_minus_probs / prob_vars) - probs
beta_coef = alpha_coef * one_minus_probs / probs
intervals = beta.interval(volume, alpha_coef, beta_coef)
# avoid invalid results caused by extremely small values of prob_vars
lows = []
highs = []
for i, low in enumerate(intervals[0]):
high = intervals[1][i]
if prob_vars[i] <= 0 or \
not np.isfinite(low) or low > probs[i] or \
not np.isfinite(high) or high < probs[i]:
low = probs[i]
high = probs[i]
lows.append(low)
highs.append(high)
def _make_rgba(saliency):
"""Make rgba image for saliency map."""
saliency = saliency.asnumpy().squeeze()
saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min()).clip(1e-10)
rgba = np.empty((saliency.shape[0], saliency.shape[1], 4))
rgba[:, :, :] = np.expand_dims(saliency, 2)
rgba = rgba * _CMAP_1 + (1 - rgba) * _CMAP_0
rgba[:, :, -1] = saliency * 1
return rgba
return lows, highs
def _get_id_dirname(sample_id: int):
"""Get the name of parent directory of the image id."""
return str(int(sample_id/_SAMPLE_PER_DIR)*_SAMPLE_PER_DIR)
def _extract_timestamp(filename: str):
"""Extract timestamp from summary filename."""
matched = re.search(r"summary\.(\d+)", filename)
if matched:
return int(matched.group(1))
return None
class ExplainRunner:
@@ -78,11 +124,14 @@ class ExplainRunner:
self._count = 0
self._classes = None
self._model = None
self._uncertainty = None
self._summary_timestamp = None
def run(self,
dataset: Tuple,
explainers: List,
benchmarkers: Optional[List] = None):
benchmarkers: Optional[List] = None,
uncertainty: Optional[UncertaintyEvaluation] = None):
"""
Generates results and writes them into the summary files in the `summary_dir` specified during the object
initialization.
@@ -98,7 +147,8 @@ class ExplainRunner:
`mindspore.explainer.explanation`.
benchmarkers (list[Benchmark], optional): A list of benchmark objects to generate evaluation results.
Default: None
uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference
uncertainty of samples. Default: None
Examples:
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
>>> # obtain dataset object
@@ -145,15 +195,31 @@ class ExplainRunner:
inputs, _, _ = self._unpack_next_element(next_element)
prop_test = self._model(inputs)
check_value_type("output of model im explainer", prop_test, ms.Tensor)
if prop_test.shape[1] > len(self._classes):
raise ValueError("The dimension of model output should not exceed the length of dataset classes. Please "
if prop_test.shape[1] != len(self._classes):
raise ValueError("The dimension of model output does not match the length of dataset classes. Please "
"check dataset classes or the black-box model in the explainer again.")
if uncertainty is not None:
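# the uncertainty evaluator is validated on a test batch: it must return one epistemic variance per class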
check_value_type("uncertainty", uncertainty, UncertaintyEvaluation)
prop_var_test = uncertainty.eval_epistemic_uncertainty(inputs)
check_value_type("output of uncertainty", prop_var_test, np.ndarray)
if prop_var_test.shape[1] != len(self._classes):
raise ValueError("The dimension of uncertainty output does not match the length of dataset classes"
"classes. Please check dataset classes or the black-box model in the explainer again.")
self._uncertainty = uncertainty
else:
self._uncertainty = None
with SummaryRecord(self._summary_dir) as summary:
print("Start running and writing......")
begin = time()
print("Start writing metadata.")
self._summary_timestamp = _extract_timestamp(summary.event_file_name)
if self._summary_timestamp is None:
raise RuntimeError("Cannot extract timestamp from summary filename!"
" It should contains a timestamp of 10 digits.")
explain = Explain()
explain.metadata.label.extend(classes)
exp_names = [exp.__class__.__name__ for exp in explainers]
@@ -341,7 +407,7 @@ class ExplainRunner:
bboxes = None
else:
inputs = next_element[0]
labels = [[] for x in inputs]
labels = [[] for _ in inputs]
bboxes = None
inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
return inputs, labels, bboxes
@@ -359,7 +425,7 @@ class ExplainRunner:
2D Tensor.
"""
max_len = max([len(l) for l in labels])
max_len = max([len(label) for label in labels])
batch_labels = np.zeros((len(labels), max_len))
for idx, _ in enumerate(batch_labels):
@@ -368,7 +434,7 @@ class ExplainRunner:
return ms.Tensor(batch_labels, ms.int32)
def _run_inference(self, dataset, summary, threshod=0.5):
def _run_inference(self, dataset, summary, threshold=0.5):
"""
Run inference for the dataset and write the inference-related data into the summary.
@@ -387,30 +453,64 @@ class ExplainRunner:
now = time()
inputs, labels, _ = self._unpack_next_element(next_element)
prob = self._model(inputs).asnumpy()
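# when an uncertainty evaluator is provided, also compute per-class epistemic variances and standard deviations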
if self._uncertainty is not None:
prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
prob_sd = np.sqrt(prob_var)
else:
prob_var = prob_sd = None
for idx, inp in enumerate(inputs):
gt_labels = labels[idx]
gt_probs = [float(prob[idx][i]) for i in gt_labels]
data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
_, _, _, image_string = _make_image(_normalize(data_np))
original_image = _np_to_image(_normalize(data_np), mode='RGB')
original_image_path = self._save_original_image(self._count, original_image)
predicted_labels = [int(i) for i in (prob[idx] > threshod).nonzero()[0]]
predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
predicted_probs = [float(prob[idx][i]) for i in predicted_labels]
has_uncertainty = False
gt_prob_sds = gt_prob_itl95_lows = gt_prob_itl95_his = None
predicted_prob_sds = predicted_prob_itl95_lows = predicted_prob_itl95_his = None
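# derive per-label standard deviations and 95% equal-tailed credible intervals from the epistemic variances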
if prob_var is not None:
gt_prob_sds = [float(prob_sd[idx][i]) for i in gt_labels]
predicted_prob_sds = [float(prob_sd[idx][i]) for i in predicted_labels]
try:
gt_prob_itl95_lows, gt_prob_itl95_his = \
_calc_prob_interval(0.95, gt_probs, [float(prob_var[idx][i]) for i in gt_labels])
predicted_prob_itl95_lows, predicted_prob_itl95_his = \
_calc_prob_interval(0.95, predicted_probs, [float(prob_var[idx][i])
for i in predicted_labels])
has_uncertainty = True
except ValueError:
log.error(traceback.format_exc())
log.error("Error on calculating uncertainty")
union_labs = list(set(gt_labels + predicted_labels))
imageid_labels[str(self._count)] = union_labs
explain = Explain()
explain.image_id = str(self._count)
explain.image_data = image_string
summary.add_value("explainer", "image", explain)
explain.sample_id = self._count
explain.image_path = original_image_path
summary.add_value("explainer", "sampl e", explain)
explain = Explain()
explain.image_id = str(self._count)
explain.sample_id = self._count
explain.ground_truth_label.extend(gt_labels)
explain.inference.ground_truth_prob.extend(gt_probs)
explain.inference.predicted_label.extend(predicted_labels)
explain.inference.predicted_prob.extend(predicted_probs)
if has_uncertainty:
explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
explain.inference.ground_truth_prob_itl95_low.extend(gt_prob_itl95_lows)
explain.inference.ground_truth_prob_itl95_hi.extend(gt_prob_itl95_his)
explain.inference.predicted_prob_sd.extend(predicted_prob_sds)
explain.inference.predicted_prob_itl95_low.extend(predicted_prob_itl95_lows)
explain.inference.predicted_prob_itl95_hi.extend(predicted_prob_itl95_his)
summary.add_value("explainer", "inference", explain)
summary.record(1)
@@ -451,20 +551,20 @@ class ExplainRunner:
for idx, union in enumerate(unions):
saliency_dict = {}
explain = Explain()
explain.image_id = str(self._count)
explain.sample_id = self._count
for k, lab in enumerate(union):
saliency = batch_saliency_full[k][idx:idx + 1]
saliency_dict[lab] = saliency
saliency_np = _make_rgba(saliency)
_, _, _, saliency_string = _make_image(_normalize(saliency_np))
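# rescale the saliency map to [0, 1], convert it to an 8-bit grayscale ('L' mode) image and save it as a heatmap file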
saliency_np = _normalize(saliency.asnumpy().squeeze())
saliency_image = _np_to_image(saliency_np, mode='L')
heatmap_path = self._save_heatmap(explainer.__class__.__name__, lab, self._count, saliency_image)
explanation = explain.explanation.add()
explanation.explain_method = explainer.__class__.__name__
explanation.heatmap_path = heatmap_path
explanation.label = lab
explanation.heatmap = saliency_string
summary.add_value("explainer", "explanation", explain)
summary.record(1)
@@ -496,3 +596,30 @@ class ExplainRunner:
else:
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
benchmarker.aggregate(res, label)
def _save_original_image(self, sample_id: int, image):
"""Save an image to summary directory."""
id_dirname = _get_id_dirname(sample_id)
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
_ORIGINAL_IMAGE_DIRNAME,
id_dirname)
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
relative_path = os.path.join(relative_dir, f"{sample_id}.jpg")
save_path = os.path.join(self._summary_dir, relative_path)
with open(save_path, "wb") as file:
image.save(file)
return relative_path
def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image):
"""Save heatmap image to summary directory."""
id_dirname = _get_id_dirname(sample_id)
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
_HEATMAP_DIRNAME,
explain_method,
id_dirname)
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
relative_path = os.path.join(relative_dir, f"{sample_id}_{class_id}.jpg")
save_path = os.path.join(self._summary_dir, relative_path)
with open(save_path, "wb") as file:
image.save(file, optimize=True)
return relative_path