From 33e63fae2b23c74231ef1dcbe8ac18fd056620fc Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 29 Jan 2021 10:42:42 +0800 Subject: [PATCH] add envirnoment checkings and change HOC error log to warning log fix typo enhance error message rearrange import statements swap checking order fix typo add no result checking enhance error message breakdown verfiy_data() fix comments typo and add set_train(False) --- .../explainer/_image_classification_runner.py | 118 ++++++++++++------ 1 file changed, 83 insertions(+), 35 deletions(-) diff --git a/mindspore/explainer/_image_classification_runner.py b/mindspore/explainer/_image_classification_runner.py index 1f856e5aa4..d61f719811 100644 --- a/mindspore/explainer/_image_classification_runner.py +++ b/mindspore/explainer/_image_classification_runner.py @@ -23,8 +23,9 @@ from scipy.stats import beta from PIL import Image import mindspore as ms -import mindspore.dataset as ds +from mindspore import context from mindspore import log +import mindspore.dataset as ds from mindspore.dataset import Dataset from mindspore.nn import Cell, SequentialCell from mindspore.ops.operations import ExpandDims @@ -147,10 +148,12 @@ class ImageClassificationRunner: self._sample_index = -1 self._full_network = SequentialCell([self._network, activation_fn]) + self._full_network.set_train(False) self._manifest = None - self._verify_data_n_settings(check_data_n_network=True) + self._verify_data_n_settings(check_data_n_network=True, + check_environment=True) def register_saliency(self, explainers, @@ -159,7 +162,7 @@ class ImageClassificationRunner: Register saliency explanation instances. Note: - This function call not be invoked more then once on each runner. + This function can not be invoked more than once on each runner. Args: explainers (list[Attribution]): The explainers to be evaluated, @@ -192,7 +195,7 @@ class ImageClassificationRunner: self._benchmarkers = benchmarkers try: - self._verify_data_n_settings(check_saliency=True) + self._verify_data_n_settings(check_saliency=True, check_environment=True) except (ValueError, TypeError): self._explainers = None self._benchmarkers = None @@ -204,7 +207,7 @@ class ImageClassificationRunner: Notes: Input images are required to be in 3 channels formats and the length of side short must be equals to or - greater than 56 pixels. + greater than 56 pixels. This function can not be invoked more than once on each runner. Raises: ValueError: Be raised for any data or settings' value problem. @@ -216,7 +219,7 @@ class ImageClassificationRunner: self._hoc_searcher = hoc.Searcher(self._full_network) try: - self._verify_data_n_settings(check_hoc=True) + self._verify_data_n_settings(check_hoc=True, check_environment=True) except ValueError: self._hoc_searcher = None raise @@ -229,7 +232,8 @@ class ImageClassificationRunner: Please refer to the documentation of mindspore.nn.probability.toolbox.uncertainty_evaluation for the details. The actual output is standard deviation of the classification predictions and the corresponding 95% confidence intervals. Users have to invoke register_saliency() as well for the uncertainty results are - going to be shown on the saliency map page in MindInsight. + going to be shown on the saliency map page in MindInsight. This function can not be invoked more then once + on each runner. Raises: RuntimeError: Be raised if the function was called already. @@ -271,8 +275,20 @@ class ImageClassificationRunner: self._save_metadata(summary) imageid_labels = self._run_inference(summary) + sample_count = self._sample_index if self._is_saliency_registered: self._run_saliency(summary, imageid_labels) + if not self._manifest["saliency_map"]: + raise RuntimeError( + f"No saliency map was generated in {sample_count} samples. " + f"Please make sure the dataset, labels, activation function and network are properly trained " + f"and configured.") + + if self._is_hoc_registered and not self._manifest["hierarchical_occlusion"]: + raise RuntimeError( + f"No Hierarchical Occlusion result was found in {sample_count} samples. " + f"Please make sure the dataset, labels, activation function and network are properly trained " + f"and configured.") self._save_manifest() @@ -484,8 +500,9 @@ class ImageClassificationRunner: compiled_mask = hoc.compile_mask(str_mask, sample_input) try: edit_tree, layer_outputs = self._hoc_searcher.search(sample_input, label_idx, compiled_mask) - except hoc.NoValidResultError as ex: - log.error(f"HOC cannot find result for sample:{sample_id} error:{ex}") + except hoc.NoValidResultError: + log.warning(f"No Hierarchical Occlusion result was found in sample#{sample_id} " + f"label:{self._labels[label_idx]}, skipped.") continue has_rec = True hoc_rec = explain.hoc.add() @@ -512,7 +529,7 @@ class ImageClassificationRunner: next_element (Tuple): Data of one step explainer (_Attribution): An Attribution object to generate saliency maps. sample_id_labels (dict): A dict that maps the sample id and its union labels. - summary (SummaryRecord): The summary object to store the data + summary (SummaryRecord): The summary object to store the data. Returns: list, List of dict that maps label to its corresponding saliency map. @@ -613,16 +630,27 @@ class ImageClassificationRunner: sds[i] = 0 return itl_lows, itl_his, sds - def _verify_data(self): - """Verify dataset and labels.""" - next_element = next(self._dataset.create_tuple_iterator()) + def _verify_labels(self): + """Verify labels.""" + label_set = set() + if not self._labels: + raise ValueError(f"The label list provided is empty.") + for i, label in enumerate(self._labels): + if label.strip() == "": + raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is " + f"no empty label.") + if label in label_set: + raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.") + label_set.add(label) - if len(next_element) not in [1, 2, 3]: + def _verify_ds_sample(self, sample): + """Verify a dataset sample.""" + if len(sample) not in [1, 2, 3]: raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]" " as columns.") - if len(next_element) == 3: - inputs, labels, bboxes = next_element + if len(sample) == 3: + inputs, labels, bboxes = sample if bboxes.shape[-1] != 4: raise ValueError("The third element of dataset should be bounding boxes with shape of " "[batch_size, num_ground_truth, 4].") @@ -631,10 +659,10 @@ class ImageClassificationRunner: if any([isinstance(bench, Localization) for bench in self._benchmarkers]): raise ValueError("The dataset must provide bboxes if Localization is to be computed.") - if len(next_element) == 2: - inputs, labels = next_element - if len(next_element) == 1: - inputs = next_element[0] + if len(sample) == 2: + inputs, labels = sample + if len(sample) == 1: + inputs = sample[0] if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]: raise ValueError( @@ -644,7 +672,7 @@ class ImageClassificationRunner: log.warning( "Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th" " dimension as batch data.".format(inputs.shape)) - if len(next_element) > 1: + if len(sample) > 1: if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1: raise ValueError( "Labels shape {} is unrecognizable: outputs should not have more than two dimensions" @@ -654,24 +682,26 @@ class ImageClassificationRunner: if inputs.shape[-3] != 3: raise ValueError( "Hierarchical occlusion is registered, images must be in 3 channels format, but " - "{} channels is encountered.".format(inputs.shape[-3])) + "{} channel(s) is(are) encountered.".format(inputs.shape[-3])) short_side = min(inputs.shape[-2:]) if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN: raise ValueError( "Hierarchical occlusion is registered, images' short side must be equals to or greater then " "{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side)) + def _verify_data(self): + """Verify dataset and labels.""" + self._verify_labels() + + try: + sample = next(self._dataset.create_tuple_iterator()) + except StopIteration: + raise ValueError("The dataset provided is empty.") + + self._verify_ds_sample(sample) + def _verify_network(self): """Verify the network.""" - label_set = set() - for i, label in enumerate(self._labels): - if label.strip() == "": - raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is " - f"no empty label.") - if label in label_set: - raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.") - label_set.add(label) - next_element = next(self._dataset.create_tuple_iterator()) inputs, _, _ = self._unpack_next_element(next_element) prop_test = self._full_network(inputs) @@ -709,7 +739,8 @@ class ImageClassificationRunner: check_registration=False, check_data_n_network=False, check_saliency=False, - check_hoc=False): + check_hoc=False, + check_environment=False): """ Verify the validity of dataset and other settings. @@ -719,23 +750,40 @@ class ImageClassificationRunner: check_data_n_network (bool): Set it True for checking data and network. check_saliency (bool): Set it True for checking saliency related settings. check_hoc (bool): Set it True for checking HOC related settings. + check_environment (bool): Set it True for checking environment conditions. Raises: ValueError: Be raised for any data or settings' value problem. TypeError: Be raised for any data or settings' type problem. + RuntimeError: Be raised for any runtime problem. """ if check_all: check_registration = True check_data_n_network = True check_saliency = True check_hoc = True + check_environment = True + + if check_environment: + device_target = context.get_context('device_target') + if device_target not in ("Ascend", "GPU"): + raise RuntimeError(f"Unsupported device_target: '{device_target}', " + f"only 'Ascend' or 'GPU' is supported. " + f"Please call context.set_context(device_target='Ascend') or " + f"context.set_context(device_target='GPU').") + if check_environment or check_saliency: + if self._is_saliency_registered: + mode = context.get_context('mode') + if mode != context.PYNATIVE_MODE: + raise RuntimeError("Context mode: GRAPH_MODE is not supported, " + "please call context.set_context(mode=context.PYNATIVE_MODE).") if check_registration: + if self._is_uncertainty_registered and not self._is_saliency_registered: + raise ValueError("Function register_uncertainty() is called but register_saliency() is not.") if not self._is_saliency_registered and not self._is_hoc_registered: raise ValueError("No explanation module was registered, user should at least call register_saliency() " - "or register_hierarchical_occlusion() once with proper arguments") - if self._is_uncertainty_registered and not self._is_saliency_registered: - raise ValueError("Function register_uncertainty() is invoked but register_saliency() is not.") + "or register_hierarchical_occlusion() once with proper arguments.") if check_data_n_network or check_saliency or check_hoc: self._verify_data()