|
|
|
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
"""Runner.""" |
|
|
|
import os |
|
|
|
import re |
|
|
|
import traceback |
|
|
|
from time import time |
|
|
|
from typing import Tuple, List, Optional |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
from scipy.stats import beta |
|
|
|
from PIL import Image |
|
|
|
|
|
|
|
from mindspore.train._utils import check_value_type |
|
|
|
from mindspore.train.summary_pb2 import Explain |
|
|
|
import mindspore as ms
|
|
|
import mindspore.dataset as ds |
|
|
|
from mindspore import log |
|
|
|
from mindspore.ops.operations import ExpandDims |
|
|
|
|
|
|
from mindspore.train.summary._summary_adapter import _convert_image_format |
|
|
|
from mindspore.train.summary.summary_record import SummaryRecord |
|
|
|
from mindspore.nn.probability.toolbox import UncertaintyEvaluation |
|
|
|
from .benchmark import Localization |
|
|
|
from .benchmark._attribution.metric import AttributionMetric |
|
|
|
from .explanation._attribution._attribution import Attribution |
|
|
|
|
|
|
|
# datafile directory names |
|
|
|
_DATAFILE_DIRNAME_PREFIX = "_explain_" |
|
|
|
_ORIGINAL_IMAGE_DIRNAME = "origin_images" |
|
|
|
_HEATMAP_DIRNAME = "heatmap" |
|
|
|
# maximum number of samples per directory
|
|
|
_SAMPLE_PER_DIR = 1000 |
|
|
|
|
|
|
|
|
|
|
|
_EXPAND_DIMS = ExpandDims() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _normalize(img_np): |
|
|
|
"""Normalize the image in the numpy array to be in [0, 255]. """ |
|
|
|
"""Normalize the numpy image to the range of [0, 1]. """ |
|
|
|
max_ = img_np.max() |
|
|
|
min_ = img_np.min() |
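# clip the denominator so constant images (max_ == min_) do not cause division by zero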
|
|
|
normed = (img_np - min_) / (max_ - min_).clip(min=1e-10) |
|
|
|
|
|
|
return normed |
|
|
|
|
|
|
|
|
|
|
|
def _np_to_image(img_np, mode): |
|
|
|
"""Convert numpy array to PIL image.""" |
|
|
|
return Image.fromarray(np.uint8(img_np*255), mode=mode) |
|
|
|
|
|
|
|
|
|
|
|
def _calc_prob_interval(volume, probs, prob_vars): |
|
|
|
"""Compute the confidence interval of probability.""" |
|
|
|
if not isinstance(probs, np.ndarray): |
|
|
|
probs = np.asarray(probs) |
|
|
|
if not isinstance(prob_vars, np.ndarray): |
|
|
|
prob_vars = np.asarray(prob_vars) |
|
|
|
one_minus_probs = 1 - probs |
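# Moment matching: treat each probability as Beta(a, b) distributed with mean m = a/(a+b) and
# variance v = a*b / ((a+b)^2 * (a+b+1)); solving for the parameters gives
# a = m*(m*(1-m)/v - 1) and b = a*(1-m)/m, computed below as alpha_coef and beta_coef.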
|
|
|
alpha_coef = (np.square(probs) * one_minus_probs / prob_vars) - probs |
|
|
|
beta_coef = alpha_coef * one_minus_probs / probs |
|
|
|
intervals = beta.interval(volume, alpha_coef, beta_coef) |
|
|
|
|
|
|
|
# avoid invalid results due to extremely small values of prob_vars
|
|
|
lows = [] |
|
|
|
highs = [] |
|
|
|
for i, low in enumerate(intervals[0]): |
|
|
|
high = intervals[1][i] |
|
|
|
if prob_vars[i] <= 0 or \ |
|
|
|
not np.isfinite(low) or low > probs[i] or \ |
|
|
|
not np.isfinite(high) or high < probs[i]: |
|
|
|
low = probs[i] |
|
|
|
high = probs[i] |
|
|
|
lows.append(low) |
|
|
|
highs.append(high) |
|
|
|
|
|
|
|
|
|
|
return lows, highs |
|
|
|
|
|
|
|
|
|
|
|
def _get_id_dirname(sample_id: int): |
|
|
|
"""Get the name of parent directory of the image id.""" |
|
|
|
return str(sample_id // _SAMPLE_PER_DIR * _SAMPLE_PER_DIR)
|
|
|
|
|
|
|
|
|
|
|
def _extract_timestamp(filename: str): |
|
|
|
"""Extract timestamp from summary filename.""" |
|
|
|
matched = re.search(r"summary\.(\d+)", filename) |
|
|
|
if matched: |
|
|
|
return int(matched.group(1)) |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
class ExplainRunner: |
|
|
|
|
|
|
self._count = 0 |
|
|
|
self._classes = None |
|
|
|
self._model = None |
|
|
|
self._uncertainty = None |
|
|
|
self._summary_timestamp = None |
|
|
|
|
|
|
|
def run(self, |
|
|
|
dataset: Tuple, |
|
|
|
explainers: List, |
|
|
|
|
|
|
benchmarkers: Optional[List] = None, |
|
|
|
uncertainty: Optional[UncertaintyEvaluation] = None): |
|
|
|
""" |
|
|
|
Generates results and writes them into the summary files in `summary_dir` specified during the object initialization.
|
|
|
|
|
|
`mindspore.explainer.explanation`. |
|
|
|
benchmarkers (list[Benchmark], optional): A list of benchmark objects to generate evaluation results. |
|
|
|
Default: None |
|
|
|
|
|
|
|
uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference
uncertainty of samples. Default: None.
|
|
|
Examples: |
|
|
|
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient |
|
|
|
>>> # obtain dataset object |
|
|
|
|
|
|
inputs, _, _ = self._unpack_next_element(next_element) |
|
|
|
prop_test = self._model(inputs) |
|
|
|
check_value_type("output of model im explainer", prop_test, ms.Tensor) |
|
|
|
|
|
|
if prop_test.shape[1] != len(self._classes): |
|
|
|
raise ValueError("The dimension of model output does not match the length of dataset classes. Please " |
|
|
|
"check dataset classes or the black-box model in the explainer again.") |
|
|
|
|
|
|
|
if uncertainty is not None: |
|
|
|
check_value_type("uncertainty", uncertainty, UncertaintyEvaluation) |
|
|
|
prop_var_test = uncertainty.eval_epistemic_uncertainty(inputs) |
|
|
|
check_value_type("output of uncertainty", prop_var_test, np.ndarray) |
|
|
|
if prop_var_test.shape[1] != len(self._classes): |
|
|
|
raise ValueError("The dimension of uncertainty output does not match the length of dataset classes" |
|
|
|
"classes. Please check dataset classes or the black-box model in the explainer again.") |
|
|
|
self._uncertainty = uncertainty |
|
|
|
else: |
|
|
|
self._uncertainty = None |
|
|
|
|
|
|
|
with SummaryRecord(self._summary_dir) as summary: |
|
|
|
print("Start running and writing......") |
|
|
|
begin = time() |
|
|
|
print("Start writing metadata.") |
|
|
|
|
|
|
|
self._summary_timestamp = _extract_timestamp(summary.event_file_name) |
|
|
|
if self._summary_timestamp is None: |
|
|
|
raise RuntimeError("Cannot extract timestamp from summary filename!" |
|
|
|
" It should contains a timestamp of 10 digits.") |
|
|
|
|
|
|
|
explain = Explain() |
|
|
|
explain.metadata.label.extend(classes) |
|
|
|
exp_names = [exp.__class__.__name__ for exp in explainers] |
|
|
|
|
|
|
bboxes = None |
|
|
|
else: |
|
|
|
inputs = next_element[0] |
|
|
|
|
|
|
labels = [[] for _ in inputs] |
|
|
|
bboxes = None |
|
|
|
inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox) |
|
|
|
return inputs, labels, bboxes |
|
|
|
|
|
|
2D Tensor. |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
max_len = max([len(label) for label in labels]) |
|
|
|
batch_labels = np.zeros((len(labels), max_len)) |
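# rows for samples with fewer labels than max_len keep their trailing zeros as padding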
|
|
|
|
|
|
|
for idx, _ in enumerate(batch_labels): |
|
|
|
|
|
|
|
|
|
|
return ms.Tensor(batch_labels, ms.int32) |
|
|
|
|
|
|
|
|
|
|
def _run_inference(self, dataset, summary, threshold=0.5): |
|
|
|
""" |
|
|
|
Run inference for the dataset and write the inference-related data into the summary.
|
|
|
|
|
|
|
|
|
|
now = time() |
|
|
|
inputs, labels, _ = self._unpack_next_element(next_element) |
|
|
|
prob = self._model(inputs).asnumpy() |
|
|
|
if self._uncertainty is not None: |
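# eval_epistemic_uncertainty returns a per-class variance; its square root is reported as the standard deviation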
|
|
|
prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs) |
|
|
|
prob_sd = np.sqrt(prob_var) |
|
|
|
else: |
|
|
|
prob_var = prob_sd = None |
|
|
|
|
|
|
|
for idx, inp in enumerate(inputs): |
|
|
|
gt_labels = labels[idx] |
|
|
|
gt_probs = [float(prob[idx][i]) for i in gt_labels] |
|
|
|
|
|
|
|
data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW') |
|
|
|
|
|
|
original_image = _np_to_image(_normalize(data_np), mode='RGB') |
|
|
|
original_image_path = self._save_original_image(self._count, original_image) |
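# multi-label prediction: every class whose probability exceeds the threshold is treated as predicted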
|
|
|
|
|
|
|
|
|
|
predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]] |
|
|
|
predicted_probs = [float(prob[idx][i]) for i in predicted_labels] |
|
|
|
|
|
|
|
has_uncertainty = False |
|
|
|
gt_prob_sds = gt_prob_itl95_lows = gt_prob_itl95_his = None |
|
|
|
predicted_prob_sds = predicted_prob_itl95_lows = predicted_prob_itl95_his = None |
|
|
|
if prob_var is not None: |
|
|
|
gt_prob_sds = [float(prob_sd[idx][i]) for i in gt_labels] |
|
|
|
predicted_prob_sds = [float(prob_sd[idx][i]) for i in predicted_labels] |
|
|
|
try: |
|
|
|
gt_prob_itl95_lows, gt_prob_itl95_his = \ |
|
|
|
_calc_prob_interval(0.95, gt_probs, [float(prob_var[idx][i]) for i in gt_labels]) |
|
|
|
predicted_prob_itl95_lows, predicted_prob_itl95_his = \ |
|
|
|
_calc_prob_interval(0.95, predicted_probs, [float(prob_var[idx][i]) |
|
|
|
for i in predicted_labels]) |
|
|
|
has_uncertainty = True |
|
|
|
except ValueError: |
|
|
|
log.error(traceback.format_exc()) |
|
|
|
log.error("Error on calculating uncertainty") |
|
|
|
|
|
|
|
union_labs = list(set(gt_labels + predicted_labels)) |
|
|
|
imageid_labels[str(self._count)] = union_labs |
|
|
|
|
|
|
|
explain = Explain() |
|
|
|
|
|
|
explain.sample_id = self._count |
|
|
|
explain.image_path = original_image_path |
|
|
|
summary.add_value("explainer", "sample", explain) |
|
|
|
|
|
|
|
explain = Explain() |
|
|
|
|
|
|
explain.sample_id = self._count |
|
|
|
explain.ground_truth_label.extend(gt_labels) |
|
|
|
explain.inference.ground_truth_prob.extend(gt_probs) |
|
|
|
explain.inference.predicted_label.extend(predicted_labels) |
|
|
|
explain.inference.predicted_prob.extend(predicted_probs) |
|
|
|
|
|
|
|
if has_uncertainty: |
|
|
|
explain.inference.ground_truth_prob_sd.extend(gt_prob_sds) |
|
|
|
explain.inference.ground_truth_prob_itl95_low.extend(gt_prob_itl95_lows) |
|
|
|
explain.inference.ground_truth_prob_itl95_hi.extend(gt_prob_itl95_his) |
|
|
|
|
|
|
|
explain.inference.predicted_prob_sd.extend(predicted_prob_sds) |
|
|
|
explain.inference.predicted_prob_itl95_low.extend(predicted_prob_itl95_lows) |
|
|
|
explain.inference.predicted_prob_itl95_hi.extend(predicted_prob_itl95_his) |
|
|
|
|
|
|
|
summary.add_value("explainer", "inference", explain) |
|
|
|
|
|
|
|
summary.record(1) |
|
|
|
|
|
|
for idx, union in enumerate(unions): |
|
|
|
saliency_dict = {} |
|
|
|
explain = Explain() |
|
|
|
|
|
|
explain.sample_id = self._count |
|
|
|
for k, lab in enumerate(union): |
|
|
|
saliency = batch_saliency_full[k][idx:idx + 1] |
|
|
|
|
|
|
|
saliency_dict[lab] = saliency |
|
|
|
|
|
|
|
|
|
|
saliency_np = _normalize(saliency.asnumpy().squeeze()) |
|
|
|
saliency_image = _np_to_image(saliency_np, mode='L') |
|
|
|
heatmap_path = self._save_heatmap(explainer.__class__.__name__, lab, self._count, saliency_image) |
|
|
|
|
|
|
|
explanation = explain.explanation.add() |
|
|
|
explanation.explain_method = explainer.__class__.__name__ |
|
|
|
|
|
|
|
explanation.heatmap_path = heatmap_path |
|
|
|
explanation.label = lab |
|
|
|
|
|
|
|
|
|
|
summary.add_value("explainer", "explanation", explain) |
|
|
|
summary.record(1) |
|
|
|
|
|
|
else: |
|
|
|
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) |
|
|
|
benchmarker.aggregate(res, label) |
|
|
|
|
|
|
|
def _save_original_image(self, sample_id: int, image): |
|
|
|
"""Save an image to summary directory.""" |
|
|
|
id_dirname = _get_id_dirname(sample_id) |
|
|
|
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp), |
|
|
|
_ORIGINAL_IMAGE_DIRNAME, |
|
|
|
id_dirname) |
|
|
|
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True) |
|
|
|
relative_path = os.path.join(relative_dir, f"{sample_id}.jpg") |
|
|
|
save_path = os.path.join(self._summary_dir, relative_path) |
|
|
|
with open(save_path, "wb") as file: |
|
|
|
image.save(file) |
|
|
|
return relative_path |
|
|
|
|
|
|
|
def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image): |
|
|
|
"""Save heatmap image to summary directory.""" |
|
|
|
id_dirname = _get_id_dirname(sample_id) |
|
|
|
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp), |
|
|
|
_HEATMAP_DIRNAME, |
|
|
|
explain_method, |
|
|
|
id_dirname) |
|
|
|
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True) |
|
|
|
relative_path = os.path.join(relative_dir, f"{sample_id}_{class_id}.jpg") |
|
|
|
save_path = os.path.join(self._summary_dir, relative_path) |
|
|
|
with open(save_path, "wb") as file: |
|
|
|
image.save(file, optimize=True) |
|
|
|
return relative_path |