Browse Source

add envirnoment checkings and change HOC error log to warning log

fix typo

enhance error message

rearrange import statements

swap checking order

fix typo

add no result checking

enhance error message

breakdown verfiy_data()

fix comments typo and add set_train(False)
tags/v1.2.0-rc1
unknown 4 years ago
parent
commit
33e63fae2b
1 changed files with 83 additions and 35 deletions
  1. +83
    -35
      mindspore/explainer/_image_classification_runner.py

+ 83
- 35
mindspore/explainer/_image_classification_runner.py View File

@@ -23,8 +23,9 @@ from scipy.stats import beta
from PIL import Image

import mindspore as ms
import mindspore.dataset as ds
from mindspore import context
from mindspore import log
import mindspore.dataset as ds
from mindspore.dataset import Dataset
from mindspore.nn import Cell, SequentialCell
from mindspore.ops.operations import ExpandDims
@@ -147,10 +148,12 @@ class ImageClassificationRunner:
self._sample_index = -1

self._full_network = SequentialCell([self._network, activation_fn])
self._full_network.set_train(False)

self._manifest = None

self._verify_data_n_settings(check_data_n_network=True)
self._verify_data_n_settings(check_data_n_network=True,
check_environment=True)

def register_saliency(self,
explainers,
@@ -159,7 +162,7 @@ class ImageClassificationRunner:
Register saliency explanation instances.

Note:
This function call not be invoked more then once on each runner.
This function can not be invoked more than once on each runner.

Args:
explainers (list[Attribution]): The explainers to be evaluated,
@@ -192,7 +195,7 @@ class ImageClassificationRunner:
self._benchmarkers = benchmarkers

try:
self._verify_data_n_settings(check_saliency=True)
self._verify_data_n_settings(check_saliency=True, check_environment=True)
except (ValueError, TypeError):
self._explainers = None
self._benchmarkers = None
@@ -204,7 +207,7 @@ class ImageClassificationRunner:

Notes:
Input images are required to be in 3 channels formats and the length of side short must be equals to or
greater than 56 pixels.
greater than 56 pixels. This function can not be invoked more than once on each runner.

Raises:
ValueError: Be raised for any data or settings' value problem.
@@ -216,7 +219,7 @@ class ImageClassificationRunner:
self._hoc_searcher = hoc.Searcher(self._full_network)

try:
self._verify_data_n_settings(check_hoc=True)
self._verify_data_n_settings(check_hoc=True, check_environment=True)
except ValueError:
self._hoc_searcher = None
raise
@@ -229,7 +232,8 @@ class ImageClassificationRunner:
Please refer to the documentation of mindspore.nn.probability.toolbox.uncertainty_evaluation for the
details. The actual output is standard deviation of the classification predictions and the corresponding
95% confidence intervals. Users have to invoke register_saliency() as well for the uncertainty results are
going to be shown on the saliency map page in MindInsight.
going to be shown on the saliency map page in MindInsight. This function can not be invoked more then once
on each runner.

Raises:
RuntimeError: Be raised if the function was called already.
@@ -271,8 +275,20 @@ class ImageClassificationRunner:
self._save_metadata(summary)

imageid_labels = self._run_inference(summary)
sample_count = self._sample_index
if self._is_saliency_registered:
self._run_saliency(summary, imageid_labels)
if not self._manifest["saliency_map"]:
raise RuntimeError(
f"No saliency map was generated in {sample_count} samples. "
f"Please make sure the dataset, labels, activation function and network are properly trained "
f"and configured.")

if self._is_hoc_registered and not self._manifest["hierarchical_occlusion"]:
raise RuntimeError(
f"No Hierarchical Occlusion result was found in {sample_count} samples. "
f"Please make sure the dataset, labels, activation function and network are properly trained "
f"and configured.")

self._save_manifest()

@@ -484,8 +500,9 @@ class ImageClassificationRunner:
compiled_mask = hoc.compile_mask(str_mask, sample_input)
try:
edit_tree, layer_outputs = self._hoc_searcher.search(sample_input, label_idx, compiled_mask)
except hoc.NoValidResultError as ex:
log.error(f"HOC cannot find result for sample:{sample_id} error:{ex}")
except hoc.NoValidResultError:
log.warning(f"No Hierarchical Occlusion result was found in sample#{sample_id} "
f"label:{self._labels[label_idx]}, skipped.")
continue
has_rec = True
hoc_rec = explain.hoc.add()
@@ -512,7 +529,7 @@ class ImageClassificationRunner:
next_element (Tuple): Data of one step
explainer (_Attribution): An Attribution object to generate saliency maps.
sample_id_labels (dict): A dict that maps the sample id and its union labels.
summary (SummaryRecord): The summary object to store the data
summary (SummaryRecord): The summary object to store the data.

Returns:
list, List of dict that maps label to its corresponding saliency map.
@@ -613,16 +630,27 @@ class ImageClassificationRunner:
sds[i] = 0
return itl_lows, itl_his, sds

def _verify_data(self):
"""Verify dataset and labels."""
next_element = next(self._dataset.create_tuple_iterator())
def _verify_labels(self):
"""Verify labels."""
label_set = set()
if not self._labels:
raise ValueError(f"The label list provided is empty.")
for i, label in enumerate(self._labels):
if label.strip() == "":
raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is "
f"no empty label.")
if label in label_set:
raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.")
label_set.add(label)

if len(next_element) not in [1, 2, 3]:
def _verify_ds_sample(self, sample):
"""Verify a dataset sample."""
if len(sample) not in [1, 2, 3]:
raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
" as columns.")

if len(next_element) == 3:
inputs, labels, bboxes = next_element
if len(sample) == 3:
inputs, labels, bboxes = sample
if bboxes.shape[-1] != 4:
raise ValueError("The third element of dataset should be bounding boxes with shape of "
"[batch_size, num_ground_truth, 4].")
@@ -631,10 +659,10 @@ class ImageClassificationRunner:
if any([isinstance(bench, Localization) for bench in self._benchmarkers]):
raise ValueError("The dataset must provide bboxes if Localization is to be computed.")

if len(next_element) == 2:
inputs, labels = next_element
if len(next_element) == 1:
inputs = next_element[0]
if len(sample) == 2:
inputs, labels = sample
if len(sample) == 1:
inputs = sample[0]

if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
raise ValueError(
@@ -644,7 +672,7 @@ class ImageClassificationRunner:
log.warning(
"Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
" dimension as batch data.".format(inputs.shape))
if len(next_element) > 1:
if len(sample) > 1:
if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
raise ValueError(
"Labels shape {} is unrecognizable: outputs should not have more than two dimensions"
@@ -654,24 +682,26 @@ class ImageClassificationRunner:
if inputs.shape[-3] != 3:
raise ValueError(
"Hierarchical occlusion is registered, images must be in 3 channels format, but "
"{} channels is encountered.".format(inputs.shape[-3]))
"{} channel(s) is(are) encountered.".format(inputs.shape[-3]))
short_side = min(inputs.shape[-2:])
if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN:
raise ValueError(
"Hierarchical occlusion is registered, images' short side must be equals to or greater then "
"{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side))

def _verify_data(self):
"""Verify dataset and labels."""
self._verify_labels()

try:
sample = next(self._dataset.create_tuple_iterator())
except StopIteration:
raise ValueError("The dataset provided is empty.")

self._verify_ds_sample(sample)

def _verify_network(self):
"""Verify the network."""
label_set = set()
for i, label in enumerate(self._labels):
if label.strip() == "":
raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is "
f"no empty label.")
if label in label_set:
raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.")
label_set.add(label)

next_element = next(self._dataset.create_tuple_iterator())
inputs, _, _ = self._unpack_next_element(next_element)
prop_test = self._full_network(inputs)
@@ -709,7 +739,8 @@ class ImageClassificationRunner:
check_registration=False,
check_data_n_network=False,
check_saliency=False,
check_hoc=False):
check_hoc=False,
check_environment=False):
"""
Verify the validity of dataset and other settings.

@@ -719,23 +750,40 @@ class ImageClassificationRunner:
check_data_n_network (bool): Set it True for checking data and network.
check_saliency (bool): Set it True for checking saliency related settings.
check_hoc (bool): Set it True for checking HOC related settings.
check_environment (bool): Set it True for checking environment conditions.

Raises:
ValueError: Be raised for any data or settings' value problem.
TypeError: Be raised for any data or settings' type problem.
RuntimeError: Be raised for any runtime problem.
"""
if check_all:
check_registration = True
check_data_n_network = True
check_saliency = True
check_hoc = True
check_environment = True

if check_environment:
device_target = context.get_context('device_target')
if device_target not in ("Ascend", "GPU"):
raise RuntimeError(f"Unsupported device_target: '{device_target}', "
f"only 'Ascend' or 'GPU' is supported. "
f"Please call context.set_context(device_target='Ascend') or "
f"context.set_context(device_target='GPU').")
if check_environment or check_saliency:
if self._is_saliency_registered:
mode = context.get_context('mode')
if mode != context.PYNATIVE_MODE:
raise RuntimeError("Context mode: GRAPH_MODE is not supported, "
"please call context.set_context(mode=context.PYNATIVE_MODE).")

if check_registration:
if self._is_uncertainty_registered and not self._is_saliency_registered:
raise ValueError("Function register_uncertainty() is called but register_saliency() is not.")
if not self._is_saliency_registered and not self._is_hoc_registered:
raise ValueError("No explanation module was registered, user should at least call register_saliency() "
"or register_hierarchical_occlusion() once with proper arguments")
if self._is_uncertainty_registered and not self._is_saliency_registered:
raise ValueError("Function register_uncertainty() is invoked but register_saliency() is not.")
"or register_hierarchical_occlusion() once with proper arguments.")

if check_data_n_network or check_saliency or check_hoc:
self._verify_data()


Loading…
Cancel
Save