diff --git a/mindspore/explainer/_image_classification_runner.py b/mindspore/explainer/_image_classification_runner.py index 1f856e5aa4..d61f719811 100644 --- a/mindspore/explainer/_image_classification_runner.py +++ b/mindspore/explainer/_image_classification_runner.py @@ -23,8 +23,9 @@ from scipy.stats import beta from PIL import Image import mindspore as ms -import mindspore.dataset as ds +from mindspore import context from mindspore import log +import mindspore.dataset as ds from mindspore.dataset import Dataset from mindspore.nn import Cell, SequentialCell from mindspore.ops.operations import ExpandDims @@ -147,10 +148,12 @@ class ImageClassificationRunner: self._sample_index = -1 self._full_network = SequentialCell([self._network, activation_fn]) + self._full_network.set_train(False) self._manifest = None - self._verify_data_n_settings(check_data_n_network=True) + self._verify_data_n_settings(check_data_n_network=True, + check_environment=True) def register_saliency(self, explainers, @@ -159,7 +162,7 @@ class ImageClassificationRunner: Register saliency explanation instances. Note: - This function call not be invoked more then once on each runner. + This function can not be invoked more than once on each runner. Args: explainers (list[Attribution]): The explainers to be evaluated, @@ -192,7 +195,7 @@ class ImageClassificationRunner: self._benchmarkers = benchmarkers try: - self._verify_data_n_settings(check_saliency=True) + self._verify_data_n_settings(check_saliency=True, check_environment=True) except (ValueError, TypeError): self._explainers = None self._benchmarkers = None @@ -204,7 +207,7 @@ class ImageClassificationRunner: Notes: Input images are required to be in 3 channels formats and the length of side short must be equals to or - greater than 56 pixels. + greater than 56 pixels. This function can not be invoked more than once on each runner. Raises: ValueError: Be raised for any data or settings' value problem. @@ -216,7 +219,7 @@ class ImageClassificationRunner: self._hoc_searcher = hoc.Searcher(self._full_network) try: - self._verify_data_n_settings(check_hoc=True) + self._verify_data_n_settings(check_hoc=True, check_environment=True) except ValueError: self._hoc_searcher = None raise @@ -229,7 +232,8 @@ class ImageClassificationRunner: Please refer to the documentation of mindspore.nn.probability.toolbox.uncertainty_evaluation for the details. The actual output is standard deviation of the classification predictions and the corresponding 95% confidence intervals. Users have to invoke register_saliency() as well for the uncertainty results are - going to be shown on the saliency map page in MindInsight. + going to be shown on the saliency map page in MindInsight. This function can not be invoked more then once + on each runner. Raises: RuntimeError: Be raised if the function was called already. @@ -271,8 +275,20 @@ class ImageClassificationRunner: self._save_metadata(summary) imageid_labels = self._run_inference(summary) + sample_count = self._sample_index if self._is_saliency_registered: self._run_saliency(summary, imageid_labels) + if not self._manifest["saliency_map"]: + raise RuntimeError( + f"No saliency map was generated in {sample_count} samples. " + f"Please make sure the dataset, labels, activation function and network are properly trained " + f"and configured.") + + if self._is_hoc_registered and not self._manifest["hierarchical_occlusion"]: + raise RuntimeError( + f"No Hierarchical Occlusion result was found in {sample_count} samples. " + f"Please make sure the dataset, labels, activation function and network are properly trained " + f"and configured.") self._save_manifest() @@ -484,8 +500,9 @@ class ImageClassificationRunner: compiled_mask = hoc.compile_mask(str_mask, sample_input) try: edit_tree, layer_outputs = self._hoc_searcher.search(sample_input, label_idx, compiled_mask) - except hoc.NoValidResultError as ex: - log.error(f"HOC cannot find result for sample:{sample_id} error:{ex}") + except hoc.NoValidResultError: + log.warning(f"No Hierarchical Occlusion result was found in sample#{sample_id} " + f"label:{self._labels[label_idx]}, skipped.") continue has_rec = True hoc_rec = explain.hoc.add() @@ -512,7 +529,7 @@ class ImageClassificationRunner: next_element (Tuple): Data of one step explainer (_Attribution): An Attribution object to generate saliency maps. sample_id_labels (dict): A dict that maps the sample id and its union labels. - summary (SummaryRecord): The summary object to store the data + summary (SummaryRecord): The summary object to store the data. Returns: list, List of dict that maps label to its corresponding saliency map. @@ -613,16 +630,27 @@ class ImageClassificationRunner: sds[i] = 0 return itl_lows, itl_his, sds - def _verify_data(self): - """Verify dataset and labels.""" - next_element = next(self._dataset.create_tuple_iterator()) + def _verify_labels(self): + """Verify labels.""" + label_set = set() + if not self._labels: + raise ValueError(f"The label list provided is empty.") + for i, label in enumerate(self._labels): + if label.strip() == "": + raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is " + f"no empty label.") + if label in label_set: + raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.") + label_set.add(label) - if len(next_element) not in [1, 2, 3]: + def _verify_ds_sample(self, sample): + """Verify a dataset sample.""" + if len(sample) not in [1, 2, 3]: raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]" " as columns.") - if len(next_element) == 3: - inputs, labels, bboxes = next_element + if len(sample) == 3: + inputs, labels, bboxes = sample if bboxes.shape[-1] != 4: raise ValueError("The third element of dataset should be bounding boxes with shape of " "[batch_size, num_ground_truth, 4].") @@ -631,10 +659,10 @@ class ImageClassificationRunner: if any([isinstance(bench, Localization) for bench in self._benchmarkers]): raise ValueError("The dataset must provide bboxes if Localization is to be computed.") - if len(next_element) == 2: - inputs, labels = next_element - if len(next_element) == 1: - inputs = next_element[0] + if len(sample) == 2: + inputs, labels = sample + if len(sample) == 1: + inputs = sample[0] if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]: raise ValueError( @@ -644,7 +672,7 @@ class ImageClassificationRunner: log.warning( "Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th" " dimension as batch data.".format(inputs.shape)) - if len(next_element) > 1: + if len(sample) > 1: if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1: raise ValueError( "Labels shape {} is unrecognizable: outputs should not have more than two dimensions" @@ -654,24 +682,26 @@ class ImageClassificationRunner: if inputs.shape[-3] != 3: raise ValueError( "Hierarchical occlusion is registered, images must be in 3 channels format, but " - "{} channels is encountered.".format(inputs.shape[-3])) + "{} channel(s) is(are) encountered.".format(inputs.shape[-3])) short_side = min(inputs.shape[-2:]) if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN: raise ValueError( "Hierarchical occlusion is registered, images' short side must be equals to or greater then " "{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side)) + def _verify_data(self): + """Verify dataset and labels.""" + self._verify_labels() + + try: + sample = next(self._dataset.create_tuple_iterator()) + except StopIteration: + raise ValueError("The dataset provided is empty.") + + self._verify_ds_sample(sample) + def _verify_network(self): """Verify the network.""" - label_set = set() - for i, label in enumerate(self._labels): - if label.strip() == "": - raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is " - f"no empty label.") - if label in label_set: - raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.") - label_set.add(label) - next_element = next(self._dataset.create_tuple_iterator()) inputs, _, _ = self._unpack_next_element(next_element) prop_test = self._full_network(inputs) @@ -709,7 +739,8 @@ class ImageClassificationRunner: check_registration=False, check_data_n_network=False, check_saliency=False, - check_hoc=False): + check_hoc=False, + check_environment=False): """ Verify the validity of dataset and other settings. @@ -719,23 +750,40 @@ class ImageClassificationRunner: check_data_n_network (bool): Set it True for checking data and network. check_saliency (bool): Set it True for checking saliency related settings. check_hoc (bool): Set it True for checking HOC related settings. + check_environment (bool): Set it True for checking environment conditions. Raises: ValueError: Be raised for any data or settings' value problem. TypeError: Be raised for any data or settings' type problem. + RuntimeError: Be raised for any runtime problem. """ if check_all: check_registration = True check_data_n_network = True check_saliency = True check_hoc = True + check_environment = True + + if check_environment: + device_target = context.get_context('device_target') + if device_target not in ("Ascend", "GPU"): + raise RuntimeError(f"Unsupported device_target: '{device_target}', " + f"only 'Ascend' or 'GPU' is supported. " + f"Please call context.set_context(device_target='Ascend') or " + f"context.set_context(device_target='GPU').") + if check_environment or check_saliency: + if self._is_saliency_registered: + mode = context.get_context('mode') + if mode != context.PYNATIVE_MODE: + raise RuntimeError("Context mode: GRAPH_MODE is not supported, " + "please call context.set_context(mode=context.PYNATIVE_MODE).") if check_registration: + if self._is_uncertainty_registered and not self._is_saliency_registered: + raise ValueError("Function register_uncertainty() is called but register_saliency() is not.") if not self._is_saliency_registered and not self._is_hoc_registered: raise ValueError("No explanation module was registered, user should at least call register_saliency() " - "or register_hierarchical_occlusion() once with proper arguments") - if self._is_uncertainty_registered and not self._is_saliency_registered: - raise ValueError("Function register_uncertainty() is invoked but register_saliency() is not.") + "or register_hierarchical_occlusion() once with proper arguments.") if check_data_n_network or check_saliency or check_hoc: self._verify_data()