@@ -13,10 +13,15 @@
# limitations under the License.
# ============================================================================
"""Runner."""
import os
import re
import traceback
from time import time
from typing import Tuple, List, Optional
import numpy as np
from scipy.stats import beta
from PIL import Image
from mindspore.train._utils import check_value_type
from mindspore.train.summary_pb2 import Explain
@@ -24,34 +29,75 @@ import mindspore as ms
import mindspore.dataset as ds
from mindspore import log
from mindspore.ops.operations import ExpandDims
from mindspore.train.summary._summary_adapter import _convert_image_format
from mindspore.train.summary.summary_record import SummaryRecord
from mindspore.nn.probability.toolbox import UncertaintyEvaluation
from .benchmark import Localization
from .benchmark._attribution.metric import AttributionMetric
from .explanation._attribution._attribution import Attribution
# datafile directory names
_DATAFILE_DIRNAME_PREFIX = "_explain_"
_ORIGINAL_IMAGE_DIRNAME = "origin_images"
_HEATMAP_DIRNAME = "heatmap"
# maximum number of samples per directory
_SAMPLE_PER_DIR = 1000
_EXPAND_DIMS = ExpandDims()
def _normalize(img_np):
    """Normalize the numpy image to the range of [0, 1]."""
    max_ = img_np.max()
    min_ = img_np.min()
    normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
    return normed
def _np_to_image(img_np, mode):
"""Convert numpy array to PIL image."""
return Image.fromarray(np.uint8(img_np*255), mode=mode)
def _calc_prob_interval(volume, probs, prob_vars):
"""Compute the confidence interval of probability."""
if not isinstance(probs, np.ndarray):
probs = np.asarray(probs)
if not isinstance(prob_vars, np.ndarray):
prob_vars = np.asarray(prob_vars)
one_minus_probs = 1 - probs
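    # Moment-match a Beta(alpha, beta) distribution to each probability: for mean p and
    # variance v, alpha = p^2 * (1 - p) / v - p and beta = alpha * (1 - p) / p, so the
    # interval below covers the central `volume` mass of the fitted distribution.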
alpha_coef = (np.square(probs) * one_minus_probs / prob_vars) - probs
beta_coef = alpha_coef * one_minus_probs / probs
intervals = beta.interval(volume, alpha_coef, beta_coef)
    # avoid invalid results due to extremely small values of prob_vars
lows = []
highs = []
for i, low in enumerate(intervals[0]):
high = intervals[1][i]
if prob_vars[i] <= 0 or \
not np.isfinite(low) or low > probs[i] or \
not np.isfinite(high) or high < probs[i]:
low = probs[i]
high = probs[i]
lows.append(low)
highs.append(high)
    return lows, highs
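# Illustrative note (numbers chosen for this sketch only): a probability of 0.8 with an
# epistemic variance of 1e-4 is matched to roughly Beta(1279.2, 319.8), so
# _calc_prob_interval(0.95, [0.8], [1e-4]) returns an interval that brackets 0.8.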
def _get_id_dirname(sample_id: int):
"""Get the name of parent directory of the image id."""
return str(int(sample_id/_SAMPLE_PER_DIR)*_SAMPLE_PER_DIR)
def _extract_timestamp(filename: str):
"""Extract timestamp from summary filename."""
matched = re.search(r"summary\.(\d+)", filename)
if matched:
return int(matched.group(1))
return None
class ExplainRunner:
@@ -78,11 +124,14 @@ class ExplainRunner:
        self._count = 0
        self._classes = None
        self._model = None
self._uncertainty = None
self._summary_timestamp = None
    def run(self,
            dataset: Tuple,
            explainers: List,
benchmarkers: Optional[List] = None,
uncertainty: Optional[UncertaintyEvaluation] = None):
"""
"""
Genereates results and writes results into the summary files in `summary_dir` specified during the object
Genereates results and writes results into the summary files in `summary_dir` specified during the object
initialization.
initialization.
@@ -98,7 +147,8 @@ class ExplainRunner:
                `mindspore.explainer.explanation`.
            benchmarkers (list[Benchmark], optional): A list of benchmark objects to generate evaluation results.
                Default: None
uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference
uncertainty of samples.
        Examples:
            >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
            >>> # obtain dataset object
@@ -145,15 +195,31 @@ class ExplainRunner:
        inputs, _, _ = self._unpack_next_element(next_element)
        prop_test = self._model(inputs)
        check_value_type("output of model in explainer", prop_test, ms.Tensor)
        if prop_test.shape[1] != len(self._classes):
            raise ValueError("The dimension of model output does not match the length of dataset classes. Please "
                             "check dataset classes or the black-box model in the explainer again.")
if uncertainty is not None:
check_value_type("uncertainty", uncertainty, UncertaintyEvaluation)
prop_var_test = uncertainty.eval_epistemic_uncertainty(inputs)
check_value_type("output of uncertainty", prop_var_test, np.ndarray)
if prop_var_test.shape[1] != len(self._classes):
raise ValueError("The dimension of uncertainty output does not match the length of dataset classes"
"classes. Please check dataset classes or the black-box model in the explainer again.")
self._uncertainty = uncertainty
else:
self._uncertainty = None
        with SummaryRecord(self._summary_dir) as summary:
            print("Start running and writing......")
            begin = time()
            print("Start writing metadata.")
self._summary_timestamp = _extract_timestamp(summary.event_file_name)
if self._summary_timestamp is None:
raise RuntimeError("Cannot extract timestamp from summary filename!"
" It should contains a timestamp of 10 digits.")
            explain = Explain()
            explain.metadata.label.extend(classes)
            exp_names = [exp.__class__.__name__ for exp in explainers]
@@ -341,7 +407,7 @@ class ExplainRunner:
            bboxes = None
        else:
            inputs = next_element[0]
            labels = [[] for _ in inputs]
            bboxes = None
        inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
        return inputs, labels, bboxes
@@ -359,7 +425,7 @@ class ExplainRunner:
            2D Tensor.
        """
        max_len = max([len(label) for label in labels])
        batch_labels = np.zeros((len(labels), max_len))
        for idx, _ in enumerate(batch_labels):
@@ -368,7 +434,7 @@ class ExplainRunner:
        return ms.Tensor(batch_labels, ms.int32)

    def _run_inference(self, dataset, summary, threshold=0.5):
        """
        Run inference for the dataset and write the inference related data into summary.
@@ -387,30 +453,64 @@ class ExplainRunner:
            now = time()
            inputs, labels, _ = self._unpack_next_element(next_element)
            prob = self._model(inputs).asnumpy()
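            # when an UncertaintyEvaluation is supplied, the per-class epistemic variance is
            # evaluated here and its square root is reported as the probability's standard deviation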
if self._uncertainty is not None:
prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
prob_sd = np.sqrt(prob_var)
else:
prob_var = prob_sd = None
            for idx, inp in enumerate(inputs):
                gt_labels = labels[idx]
                gt_probs = [float(prob[idx][i]) for i in gt_labels]
                data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
original_image = _np_to_image(_normalize(data_np), mode='RGB')
original_image_path = self._save_original_image(self._count, original_image)
                predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
                predicted_probs = [float(prob[idx][i]) for i in predicted_labels]
has_uncertainty = False
gt_prob_sds = gt_prob_itl95_lows = gt_prob_itl95_his = None
predicted_prob_sds = predicted_prob_itl95_lows = predicted_prob_itl95_his = None
if prob_var is not None:
gt_prob_sds = [float(prob_sd[idx][i]) for i in gt_labels]
predicted_prob_sds = [float(prob_sd[idx][i]) for i in predicted_labels]
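                    # interval fitting may fail for degenerate inputs; in that case the
                    # uncertainty fields are simply skipped for this sample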
try:
gt_prob_itl95_lows, gt_prob_itl95_his = \
_calc_prob_interval(0.95, gt_probs, [float(prob_var[idx][i]) for i in gt_labels])
predicted_prob_itl95_lows, predicted_prob_itl95_his = \
_calc_prob_interval(0.95, predicted_probs, [float(prob_var[idx][i])
for i in predicted_labels])
has_uncertainty = True
                    except ValueError:
                        log.error(traceback.format_exc())
                        log.error("Error occurred when calculating the uncertainty intervals.")
                union_labs = list(set(gt_labels + predicted_labels))
                imageid_labels[str(self._count)] = union_labs
                explain = Explain()
                explain.sample_id = self._count
                explain.image_path = original_image_path
                summary.add_value("explainer", "sample", explain)
                explain = Explain()
                explain.sample_id = self._count
                explain.ground_truth_label.extend(gt_labels)
                explain.inference.ground_truth_prob.extend(gt_probs)
                explain.inference.predicted_label.extend(predicted_labels)
                explain.inference.predicted_prob.extend(predicted_probs)
if has_uncertainty:
explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
explain.inference.ground_truth_prob_itl95_low.extend(gt_prob_itl95_lows)
explain.inference.ground_truth_prob_itl95_hi.extend(gt_prob_itl95_his)
explain.inference.predicted_prob_sd.extend(predicted_prob_sds)
explain.inference.predicted_prob_itl95_low.extend(predicted_prob_itl95_lows)
explain.inference.predicted_prob_itl95_hi.extend(predicted_prob_itl95_his)
summary.add_value("explainer", "inference", explain)
summary.add_value("explainer", "inference", explain)
summary.record(1)
summary.record(1)
@@ -451,20 +551,20 @@ class ExplainRunner:
        for idx, union in enumerate(unions):
            saliency_dict = {}
            explain = Explain()
            explain.sample_id = self._count
            for k, lab in enumerate(union):
                saliency = batch_saliency_full[k][idx:idx + 1]
                saliency_dict[lab] = saliency
saliency_np = _normalize(saliency.asnumpy().squeeze())
saliency_image = _np_to_image(saliency_np, mode='L')
heatmap_path = self._save_heatmap(explainer.__class__.__name__, lab, self._count, saliency_image)
                explanation = explain.explanation.add()
                explanation.explain_method = explainer.__class__.__name__
                explanation.heatmap_path = heatmap_path
                explanation.label = lab

            summary.add_value("explainer", "explanation", explain)
            summary.record(1)
@@ -496,3 +596,30 @@ class ExplainRunner:
                else:
                    res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
                    benchmarker.aggregate(res, label)
def _save_original_image(self, sample_id: int, image):
"""Save an image to summary directory."""
id_dirname = _get_id_dirname(sample_id)
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
_ORIGINAL_IMAGE_DIRNAME,
id_dirname)
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
relative_path = os.path.join(relative_dir, f"{sample_id}.jpg")
save_path = os.path.join(self._summary_dir, relative_path)
with open(save_path, "wb") as file:
image.save(file)
return relative_path
def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image):
"""Save heatmap image to summary directory."""
id_dirname = _get_id_dirname(sample_id)
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
_HEATMAP_DIRNAME,
explain_method,
id_dirname)
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
relative_path = os.path.join(relative_dir, f"{sample_id}_{class_id}.jpg")
save_path = os.path.join(self._summary_dir, relative_path)
with open(save_path, "wb") as file:
image.save(file, optimize=True)
return relative_path