@@ -23,8 +23,9 @@ from scipy.stats import beta
from PIL import Image

import mindspore as ms
import mindspore.dataset as ds
from mindspore import context
from mindspore import log
import mindspore.dataset as ds
from mindspore.dataset import Dataset
from mindspore.nn import Cell, SequentialCell
from mindspore.ops.operations import ExpandDims
@@ -147,10 +148,12 @@ class ImageClassificationRunner:
        self._sample_index = -1

        self._full_network = SequentialCell([self._network, activation_fn])
        self._full_network.set_train(False)

        self._manifest = None

        self._verify_data_n_settings(check_data_n_network=True)
        self._verify_data_n_settings(check_data_n_network=True,
                                     check_environment=True)

    def register_saliency(self,
                          explainers,
@@ -159,7 +162,7 @@ class ImageClassificationRunner:
        Register saliency explanation instances.

        Note:
            This function call not be invoked more then once on each runner.
            This function can not be invoked more than once on each runner.

        Args:
            explainers (list[Attribution]): The explainers to be evaluated,
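For orientation, a minimal end-to-end sketch of this API (assuming the MindSpore 1.2-era package layout and constructor keywords; `net`, `dataset` and `classes` are placeholders) registers the explainers exactly once and then runs:

    from mindspore.nn import Softmax
    from mindspore.explainer import ImageClassificationRunner
    from mindspore.explainer.explanation import GradCAM, GuidedBackprop
    from mindspore.explainer.benchmark import Faithfulness

    # `net`, `dataset` and `classes` are hypothetical: a trained classifier,
    # an evaluation dataset and the list of label names.
    explainers = [GradCAM(net), GuidedBackprop(net)]
    benchmarkers = [Faithfulness(num_labels=len(classes), activation_fn=Softmax())]

    runner = ImageClassificationRunner(summary_dir="./summary_dir", data=(dataset, classes),
                                       network=net, activation_fn=Softmax())
    runner.register_saliency(explainers=explainers, benchmarkers=benchmarkers)  # once per runner
    runner.run()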
@@ -192,7 +195,7 @@ class ImageClassificationRunner:
        self._benchmarkers = benchmarkers

        try:
            self._verify_data_n_settings(check_saliency=True)
            self._verify_data_n_settings(check_saliency=True, check_environment=True)
        except (ValueError, TypeError):
            self._explainers = None
            self._benchmarkers = None
@@ -204,7 +207,7 @@ class ImageClassificationRunner:

        Notes:
            Input images are required to be in 3-channel format and the length of the short side must be equal to or
            greater than 56 pixels.
            greater than 56 pixels. This function can not be invoked more than once on each runner.

        Raises:
            ValueError: Be raised for any data or settings' value problem.
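Since hierarchical occlusion only accepts 3-channel images whose short side is at least 56 pixels, the dataset pipeline has to guarantee both. A rough sketch using MindSpore's vision transforms (the directory path and target size are placeholders):

    import mindspore.dataset as ds
    import mindspore.dataset.vision.c_transforms as vision

    dataset = ds.ImageFolderDataset("/path/to/eval_images")  # placeholder path
    # Decode to RGB (3 channels); Resize with a single int scales the short side,
    # so any value >= 56 satisfies the check.
    dataset = dataset.map(operations=[vision.Decode(), vision.Resize(224), vision.HWC2CHW()],
                          input_columns="image")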
@@ -216,7 +219,7 @@ class ImageClassificationRunner:
        self._hoc_searcher = hoc.Searcher(self._full_network)

        try:
            self._verify_data_n_settings(check_hoc=True)
            self._verify_data_n_settings(check_hoc=True, check_environment=True)
        except ValueError:
            self._hoc_searcher = None
            raise
@@ -229,7 +232,8 @@ class ImageClassificationRunner:
        Please refer to the documentation of mindspore.nn.probability.toolbox.uncertainty_evaluation for the
        details. The actual output is standard deviation of the classification predictions and the corresponding
        95% confidence intervals. Users have to invoke register_saliency() as well for the uncertainty results are
        going to be shown on the saliency map page in MindInsight.
        going to be shown on the saliency map page in MindInsight. This function can not be invoked more than once
        on each runner.

        Raises:
            RuntimeError: Be raised if the function was called already.
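Because the uncertainty results only surface on the saliency map page, the intended call order pairs the two registrations on the same runner. A sketch, assuming register_uncertainty() takes no arguments and reusing the hypothetical runner and explainers from the earlier sketch:

    runner.register_uncertainty()                    # uses the runner's own network
    runner.register_saliency(explainers=explainers)  # required, or the uncertainty output is never displayed
    runner.run()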
@@ -271,8 +275,20 @@ class ImageClassificationRunner:
            self._save_metadata(summary)

            imageid_labels = self._run_inference(summary)
            sample_count = self._sample_index
            if self._is_saliency_registered:
                self._run_saliency(summary, imageid_labels)
                if not self._manifest["saliency_map"]:
                    raise RuntimeError(
                        f"No saliency map was generated in {sample_count} samples. "
                        f"Please make sure the dataset, labels, activation function and network are properly trained "
                        f"and configured.")

            if self._is_hoc_registered and not self._manifest["hierarchical_occlusion"]:
                raise RuntimeError(
                    f"No Hierarchical Occlusion result was found in {sample_count} samples. "
                    f"Please make sure the dataset, labels, activation function and network are properly trained "
                    f"and configured.")

        self._save_manifest()

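With the new manifest checks, a run that produces no saliency map or no HOC result for any sample now fails loudly instead of writing an empty summary. Callers can surface that directly, for instance:

    try:
        runner.run()
    except RuntimeError as err:
        # Raised when none of the samples yielded a saliency map or an HOC result.
        print(f"Explanation run produced no results: {err}")
        raise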
@@ -484,8 +500,9 @@ class ImageClassificationRunner:
                compiled_mask = hoc.compile_mask(str_mask, sample_input)
                try:
                    edit_tree, layer_outputs = self._hoc_searcher.search(sample_input, label_idx, compiled_mask)
                except hoc.NoValidResultError as ex:
                    log.error(f"HOC cannot find result for sample:{sample_id} error:{ex}")
                except hoc.NoValidResultError:
                    log.warning(f"No Hierarchical Occlusion result was found in sample#{sample_id} "
                                f"label:{self._labels[label_idx]}, skipped.")
                    continue
                has_rec = True
                hoc_rec = explain.hoc.add()
@@ -512,7 +529,7 @@ class ImageClassificationRunner:
            next_element (Tuple): Data of one step
            explainer (_Attribution): An Attribution object to generate saliency maps.
            sample_id_labels (dict): A dict that maps the sample id and its union labels.
            summary (SummaryRecord): The summary object to store the data
            summary (SummaryRecord): The summary object to store the data.

        Returns:
            list, List of dict that maps label to its corresponding saliency map.
@@ -613,16 +630,27 @@ class ImageClassificationRunner:
                sds[i] = 0
        return itl_lows, itl_his, sds

    def _verify_data(self):
        """Verify dataset and labels."""
        next_element = next(self._dataset.create_tuple_iterator())
    def _verify_labels(self):
        """Verify labels."""
        label_set = set()
        if not self._labels:
            raise ValueError(f"The label list provided is empty.")
        for i, label in enumerate(self._labels):
            if label.strip() == "":
                raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is "
                                 f"no empty label.")
            if label in label_set:
                raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.")
            label_set.add(label)

        if len(next_element) not in [1, 2, 3]:
    def _verify_ds_sample(self, sample):
        """Verify a dataset sample."""
        if len(sample) not in [1, 2, 3]:
            raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
                             " as columns.")

        if len(next_element) == 3:
            inputs, labels, bboxes = next_element
        if len(sample) == 3:
            inputs, labels, bboxes = sample
            if bboxes.shape[-1] != 4:
                raise ValueError("The third element of dataset should be bounding boxes with shape of "
                                 "[batch_size, num_ground_truth, 4].")
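Taken together, `_verify_labels` and `_verify_ds_sample` pin down what the `data` argument must look like: a non-empty list of unique, non-blank label names, plus a dataset whose rows carry one to three columns (images, optionally labels, optionally bounding boxes). A tiny sketch of a conforming input, with entirely made-up data:

    import numpy as np
    import mindspore.dataset as ds

    classes = ["airplane", "car", "bird"]   # unique, non-empty label names

    def generator():
        for _ in range(4):
            image = np.random.rand(3, 224, 224).astype(np.float32)  # CHW, 3 channels
            label = np.array([0, 1, 0], dtype=np.int32)             # one-hot style; plain indices also pass
            yield image, label

    dataset = ds.GeneratorDataset(generator, column_names=["image", "label"])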
@@ -631,10 +659,10 @@ class ImageClassificationRunner:
            if any([isinstance(bench, Localization) for bench in self._benchmarkers]):
                raise ValueError("The dataset must provide bboxes if Localization is to be computed.")

        if len(next_element) == 2:
            inputs, labels = next_element
        if len(next_element) == 1:
            inputs = next_element[0]
        if len(sample) == 2:
            inputs, labels = sample
        if len(sample) == 1:
            inputs = sample[0]

        if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
            raise ValueError(
@@ -644,7 +672,7 @@ class ImageClassificationRunner:
            log.warning(
                "Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
                " dimension as batch data.".format(inputs.shape))
        if len(next_element) > 1:
        if len(sample) > 1:
            if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
                raise ValueError(
                    "Labels shape {} is unrecognizable: outputs should not have more than two dimensions"
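In other words, the verifier accepts either a single image of shape (C, H, W) or a batch of shape (N, C, H, W) with C in {1, 3, 4}, and a lone 3-D tensor is simply promoted to a batch of one, roughly:

    import numpy as np

    image = np.random.rand(3, 224, 224)    # (C, H, W): accepted, with a warning
    batch = np.expand_dims(image, axis=0)  # -> (1, 3, 224, 224), the shape the runner works on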
@@ -654,24 +682,26 @@ class ImageClassificationRunner:
            if inputs.shape[-3] != 3:
                raise ValueError(
                    "Hierarchical occlusion is registered, images must be in 3 channels format, but "
                    "{} channels is encountered.".format(inputs.shape[-3]))
                    "{} channel(s) is(are) encountered.".format(inputs.shape[-3]))
            short_side = min(inputs.shape[-2:])
            if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN:
                raise ValueError(
                    "Hierarchical occlusion is registered, images' short side must be equal to or greater than "
                    "{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side))

    def _verify_data(self):
        """Verify dataset and labels."""
        self._verify_labels()

        try:
            sample = next(self._dataset.create_tuple_iterator())
        except StopIteration:
            raise ValueError("The dataset provided is empty.")

        self._verify_ds_sample(sample)

    def _verify_network(self):
        """Verify the network."""
        label_set = set()
        for i, label in enumerate(self._labels):
            if label.strip() == "":
                raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is "
                                 f"no empty label.")
            if label in label_set:
                raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.")
            label_set.add(label)

        next_element = next(self._dataset.create_tuple_iterator())
        inputs, _, _ = self._unpack_next_element(next_element)
        prop_test = self._full_network(inputs)
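The refactored `_verify_data` probes one sample up front, so an empty dataset fails with a clear message instead of a bare StopIteration. The same guard works as a quick self-check before handing a dataset to the runner (assuming `dataset` is any MindSpore dataset object):

    try:
        first = next(dataset.create_tuple_iterator())
    except StopIteration:
        raise ValueError("The dataset provided is empty.")
    print(f"first sample has {len(first)} column(s)")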
@@ -709,7 +739,8 @@ class ImageClassificationRunner:
                                check_registration=False,
                                check_data_n_network=False,
                                check_saliency=False,
                                check_hoc=False):
                                check_hoc=False,
                                check_environment=False):
        """
        Verify the validity of dataset and other settings.

@@ -719,23 +750,40 @@ class ImageClassificationRunner:
            check_data_n_network (bool): Set it True for checking data and network.
            check_saliency (bool): Set it True for checking saliency related settings.
            check_hoc (bool): Set it True for checking HOC related settings.
            check_environment (bool): Set it True for checking environment conditions.

        Raises:
            ValueError: Be raised for any data or settings' value problem.
            TypeError: Be raised for any data or settings' type problem.
            RuntimeError: Be raised for any runtime problem.
        """
        if check_all:
            check_registration = True
            check_data_n_network = True
            check_saliency = True
            check_hoc = True
            check_environment = True

        if check_environment:
            device_target = context.get_context('device_target')
            if device_target not in ("Ascend", "GPU"):
                raise RuntimeError(f"Unsupported device_target: '{device_target}', "
                                   f"only 'Ascend' or 'GPU' is supported. "
                                   f"Please call context.set_context(device_target='Ascend') or "
                                   f"context.set_context(device_target='GPU').")
        if check_environment or check_saliency:
            if self._is_saliency_registered:
                mode = context.get_context('mode')
                if mode != context.PYNATIVE_MODE:
                    raise RuntimeError("Context mode: GRAPH_MODE is not supported, "
                                       "please call context.set_context(mode=context.PYNATIVE_MODE).")

        if check_registration:
            if self._is_uncertainty_registered and not self._is_saliency_registered:
                raise ValueError("Function register_uncertainty() is called but register_saliency() is not.")
            if not self._is_saliency_registered and not self._is_hoc_registered:
                raise ValueError("No explanation module was registered, user should at least call register_saliency() "
                                 "or register_hierarchical_occlusion() once with proper arguments")
            if self._is_uncertainty_registered and not self._is_saliency_registered:
                raise ValueError("Function register_uncertainty() is invoked but register_saliency() is not.")
                                 "or register_hierarchical_occlusion() once with proper arguments.")

        if check_data_n_network or check_saliency or check_hoc:
            self._verify_data()
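Since the environment check requires an Ascend or GPU device target, and PyNative mode whenever saliency explainers are registered, the caller is expected to configure the context before constructing the runner; for example (pick the device_target that matches your hardware):

    from mindspore import context

    # Saliency explainers run the network eagerly, so PyNative mode is required;
    # only "Ascend" and "GPU" device targets pass the check.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")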