diff --git a/mindspore/ccsrc/utils/summary.proto b/mindspore/ccsrc/utils/summary.proto index 3b325bd3dd..32602978eb 100644 --- a/mindspore/ccsrc/utils/summary.proto +++ b/mindspore/ccsrc/utils/summary.proto @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -121,7 +121,7 @@ message Explain { optional string explain_method = 1; optional int32 label = 2; optional string heatmap_path = 3; - } + } message Benchmark{ optional string benchmark_method = 1; @@ -131,10 +131,21 @@ message Explain { } message Metadata{ - repeated string label = 1; - repeated string explain_method = 2; - repeated string benchmark_method = 3; - } + repeated string label = 1; + repeated string explain_method = 2; + repeated string benchmark_method = 3; + } + + message HocLayer { + optional float prob = 1; + repeated int32 box = 2; // List of repeated x, y, w, h + } + + message Hoc { + optional int32 label = 1; + optional string mask = 2; + repeated HocLayer layer = 3; + } optional int32 sample_id = 1; optional string image_path = 2; // The Metadata and image path must have one fill in @@ -146,4 +157,6 @@ message Explain { optional Metadata metadata = 7; optional string status = 8; // enum value: run, end + + repeated Hoc hoc = 9; // hierarchical occlusion counterfactual } \ No newline at end of file diff --git a/mindspore/explainer/_image_classification_runner.py b/mindspore/explainer/_image_classification_runner.py index 0163ae989f..36a0094164 100644 --- a/mindspore/explainer/_image_classification_runner.py +++ b/mindspore/explainer/_image_classification_runner.py @@ -15,9 +15,11 @@ """Image Classification Runner.""" import os import re +import json from time import time import numpy as np +from scipy.stats import beta from PIL import Image import mindspore as ms @@ -30,10 +32,15 @@ from mindspore.train._utils import check_value_type from mindspore.train.summary._summary_adapter import _convert_image_format from mindspore.train.summary.summary_record import SummaryRecord from mindspore.train.summary_pb2 import Explain -from .benchmark import Localization -from .explanation import RISE -from .benchmark._attribution.metric import AttributionMetric, LabelSensitiveMetric, LabelAgnosticMetric -from .explanation._attribution.attribution import Attribution +from mindspore.nn.probability.toolbox.uncertainty_evaluation import UncertaintyEvaluation +from mindspore.explainer.benchmark import Localization +from mindspore.explainer.benchmark._attribution.metric import AttributionMetric +from mindspore.explainer.benchmark._attribution.metric import LabelSensitiveMetric +from mindspore.explainer.benchmark._attribution.metric import LabelAgnosticMetric +from mindspore.explainer.explanation import RISE +from mindspore.explainer.explanation._attribution.attribution import Attribution +from mindspore.explainer.explanation._counterfactual import hierarchical_occlusion as hoc + _EXPAND_DIMS = ExpandDims() @@ -97,6 +104,8 @@ class ImageClassificationRunner: _DATAFILE_DIRNAME_PREFIX = "_explain_" _ORIGINAL_IMAGE_DIRNAME = "origin_images" _HEATMAP_DIRNAME = "heatmap" + # specfial filenames + _MANIFEST_FILENAME = "manifest.json" # max. no. of sample per directory _SAMPLE_PER_DIR = 1000 # seed for fixing the iterating order of the dataset @@ -132,11 +141,15 @@ class ImageClassificationRunner: self._network = network self._explainers = None self._benchmarkers = None + self._uncertainty = None + self._hoc_searcher = None self._summary_timestamp = None self._sample_index = -1 self._full_network = SequentialCell([self._network, activation_fn]) + self._manifest = None + self._verify_data_n_settings(check_data_n_network=True) def register_saliency(self, @@ -185,6 +198,50 @@ class ImageClassificationRunner: self._benchmarkers = None raise + def register_hierarchical_occlusion(self): + """ + Register hierarchical occlusion instances. + + Notes: + Input images are required to be in 3 channels formats and the length of side short must be equals to or + greater than 56 pixels. + + Raises: + ValueError: Be raised for any data or settings' value problem. + RuntimeError: Be raised if the function was called already. + """ + if self._hoc_searcher is not None: + raise RuntimeError("Function register_hierarchical_occlusion() was invoked already.") + + self._hoc_searcher = hoc.Searcher(self._full_network) + + try: + self._verify_data_n_settings(check_hoc=True) + except ValueError: + self._hoc_searcher = None + raise + + def register_uncertainty(self): + """ + Register uncertainty instance to compute the epistemic uncertainty base on the Bayes' theorem. + + Notes: + Please refer to the documentation of mindspore.nn.probability.toolbox.uncertainty_evaluation for the + details. The actual output is standard deviation of the classification predictions and the corresponding + 95% confidence intervals. Users have to invoke register_saliency() as well for the uncertainty results are + going to be shown on the saliency map page in MindInsight. + + Raises: + RuntimeError: Be raised if the function was called already. + """ + if self._uncertainty is not None: + raise RuntimeError("Function register_uncertainty() was invoked already.") + + self._uncertainty = UncertaintyEvaluation(model=self._full_network, + train_dataset=None, + task_type='classification', + num_classes=len(self._labels)) + def run(self): """ Run the explain job and save the result as a summary in summary_dir. @@ -198,7 +255,10 @@ class ImageClassificationRunner: RuntimeError: Be raised for any runtime problem. """ self._verify_data_n_settings(check_all=True) - + self._manifest = {"saliency_map": False, + "benchmark": False, + "uncertainty": False, + "hierarchical_occlusion": False} with SummaryRecord(self._summary_dir, raise_exception=True) as summary: print("Start running and writing......") begin = time() @@ -214,13 +274,25 @@ class ImageClassificationRunner: if self._is_saliency_registered: self._run_saliency(summary, imageid_labels) + self._save_manifest() + print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin)) + @property + def _is_hoc_registered(self): + """Check if HOC module is registered.""" + return self._hoc_searcher is not None + @property def _is_saliency_registered(self): """Check if saliency module is registered.""" return bool(self._explainers) + @property + def _is_uncertainty_registered(self): + """Check if uncertainty module is registered.""" + return self._uncertainty is not None + def _save_metadata(self, summary): """Save metadata of the explain job to summary.""" print("Start writing metadata......") @@ -245,12 +317,13 @@ class ImageClassificationRunner: Run inference for the dataset and write the inference related data into summary. Args: - summary (SummaryRecord): The summary object to store the data + summary (SummaryRecord): The summary object to store the data. threshold (float): The threshold for prediction. Returns: dict, The map of sample d to the union of its ground truth and predicted labels. """ + has_uncertainty_rec = False sample_id_labels = {} self._sample_index = 0 ds.config.set_seed(self._DATASET_SEED) @@ -259,10 +332,20 @@ class ImageClassificationRunner: inputs, labels, _ = self._unpack_next_element(next_element) prob = self._full_network(inputs).asnumpy() + if self._uncertainty is not None: + prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs) + else: + prob_var = None + for idx, inp in enumerate(inputs): gt_labels = labels[idx] gt_probs = [float(prob[idx][i]) for i in gt_labels] + if prob_var is not None: + gt_prob_vars = [float(prob_var[idx][i]) for i in gt_labels] + gt_itl_lows, gt_itl_his, gt_prob_sds = \ + self._calc_beta_intervals(gt_probs, gt_prob_vars) + data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW') original_image = _np_to_image(_normalize(data_np), mode='RGB') original_image_path = self._save_original_image(self._sample_index, original_image) @@ -270,6 +353,11 @@ class ImageClassificationRunner: predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]] predicted_probs = [float(prob[idx][i]) for i in predicted_labels] + if prob_var is not None: + predicted_prob_vars = [float(prob_var[idx][i]) for i in predicted_labels] + predicted_itl_lows, predicted_itl_his, predicted_prob_sds = \ + self._calc_beta_intervals(predicted_probs, predicted_prob_vars) + union_labs = list(set(gt_labels + predicted_labels)) sample_id_labels[str(self._sample_index)] = union_labs @@ -285,14 +373,29 @@ class ImageClassificationRunner: explain.inference.predicted_label.extend(predicted_labels) explain.inference.predicted_prob.extend(predicted_probs) - summary.add_value("explainer", "inference", explain) + if prob_var is not None: + explain.inference.ground_truth_prob_sd.extend(gt_prob_sds) + explain.inference.ground_truth_prob_itl95_low.extend(gt_itl_lows) + explain.inference.ground_truth_prob_itl95_hi.extend(gt_itl_his) + explain.inference.predicted_prob_sd.extend(predicted_prob_sds) + explain.inference.predicted_prob_itl95_low.extend(predicted_itl_lows) + explain.inference.predicted_prob_itl95_hi.extend(predicted_itl_his) + has_uncertainty_rec = True + + summary.add_value("explainer", "inference", explain) summary.record(1) + if self._is_hoc_registered: + self._run_hoc(summary, self._sample_index, inputs[idx], prob[idx]) + self._sample_index += 1 self._spaced_print("Finish running and writing {}-th batch inference data." - " Time elapsed: {:.3f} s".format(j, time() - now), - end='') + " Time elapsed: {:.3f} s".format(j, time() - now)) + + if has_uncertainty_rec: + self._manifest["uncertainty"] = True + return sample_id_labels def _run_saliency(self, summary, sample_id_labels): @@ -306,10 +409,10 @@ class ImageClassificationRunner: for idx, next_element in enumerate(self._dataset): now = time() self._spaced_print("Start running {}-th explanation data for {}......".format( - idx, exp.__class__.__name__), end='') + idx, exp.__class__.__name__)) self._run_exp_step(next_element, exp, sample_id_labels, summary) self._spaced_print("Finish writing {}-th explanation data for {}. Time elapsed: " - "{:.3f} s".format(idx, exp.__class__.__name__, time() - now), end='') + "{:.3f} s".format(idx, exp.__class__.__name__, time() - now)) self._spaced_print( "Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format( exp.__class__.__name__, time() - start)) @@ -326,20 +429,20 @@ class ImageClassificationRunner: for idx, next_element in enumerate(self._dataset): now = time() self._spaced_print("Start running {}-th explanation data for {}......".format( - idx, exp.__class__.__name__), end='') + idx, exp.__class__.__name__)) saliency_dict_lst = self._run_exp_step(next_element, exp, sample_id_labels, summary) self._spaced_print( "Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format( - idx, exp.__class__.__name__, time() - now), end='') + idx, exp.__class__.__name__, time() - now)) for bench in self._benchmarkers: now = time() self._spaced_print( "Start running {}-th batch {} data for {}......".format( - idx, bench.__class__.__name__, exp.__class__.__name__), end='') + idx, bench.__class__.__name__, exp.__class__.__name__)) self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst) self._spaced_print( "Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format( - idx, bench.__class__.__name__, exp.__class__.__name__, time() - now), end='') + idx, bench.__class__.__name__, exp.__class__.__name__, time() - now)) for bench in self._benchmarkers: benchmark = explain.benchmark.add() @@ -355,6 +458,52 @@ class ImageClassificationRunner: summary.add_value('explainer', 'benchmark', explain) summary.record(1) + def _run_hoc(self, summary, sample_id, sample_input, prob): + """ + Run HOC search for a sample image, and then save the result to summary. + + Args: + summary (SummaryRecord): The summary object to store the data. + sample_id (int): The sample ID. + sample_input (Union[Tensor, np.ndarray]): Sample image tensor in CHW or NCWH(N=1). + prob (Union[Tensor, np.ndarray]): List of sample's classification prediction output, HOC will run for + labels with prediction output strictly larger then HOC searcher's threshold(0.5 by default). + """ + if isinstance(sample_input, ms.Tensor): + sample_input = sample_input.asnumpy() + if len(sample_input.shape) == 3: + sample_input = np.expand_dims(sample_input, axis=0) + has_rec = False + explain = Explain() + explain.sample_id = sample_id + str_mask = hoc.auto_str_mask(sample_input) + compiled_mask = None + for label_idx, label_prob in enumerate(prob): + if label_prob > self._hoc_searcher.threshold: + if compiled_mask is None: + compiled_mask = hoc.compile_mask(str_mask, sample_input) + try: + edit_tree, layer_outputs = self._hoc_searcher.search(sample_input, label_idx, compiled_mask) + except hoc.NoValidResultError as ex: + log.error(f"HOC cannot find result for sample:{sample_id} error:{ex}") + continue + has_rec = True + hoc_rec = explain.hoc.add() + hoc_rec.label = label_idx + hoc_rec.mask = str_mask + layer_count = edit_tree.max_layer + 1 + for layer in range(layer_count): + steps = edit_tree.get_layer_or_leaf_steps(layer) + layer_output = layer_outputs[layer] + hoc_layer = hoc_rec.layer.add() + hoc_layer.prob = layer_output + for step in steps: + hoc_layer.box.extend(list(step.box)) + if has_rec: + summary.add_value("explainer", "hoc", explain) + summary.record(1) + self._manifest['hierarchical_occlusion'] = True + def _run_exp_step(self, next_element, explainer, sample_id_labels, summary): """ Run the explanation for each step and write explanation results into summary. @@ -368,6 +517,7 @@ class ImageClassificationRunner: Returns: list, List of dict that maps label to its corresponding saliency map. """ + has_saliency_rec = False inputs, labels, _ = self._unpack_next_element(next_element) sample_index = self._sample_index unions = [] @@ -406,11 +556,17 @@ class ImageClassificationRunner: explanation.heatmap_path = heatmap_path explanation.label = lab + has_saliency_rec = True + summary.add_value("explainer", "explanation", explain) summary.record(1) self._sample_index += 1 saliency_dict_lst.append(saliency_dict) + + if has_saliency_rec: + self._manifest['saliency_map'] = True + return saliency_dict_lst def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst): @@ -436,6 +592,26 @@ class ImageClassificationRunner: else: raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but' 'receive {}'.format(type(benchmarker))) + self._manifest['benchmark'] = True + + @staticmethod + def _calc_beta_intervals(means, variances, prob=0.95): + """Calculate confidence interval of beta distributions.""" + if not isinstance(means, np.ndarray): + means = np.array(means) + if not isinstance(variances, np.ndarray): + variances = np.array(variances) + with np.errstate(divide='ignore'): + coef_a = ((means ** 2) * (1 - means) / variances) - means + coef_b = (coef_a * (1 - means)) / means + itl_lows, itl_his = beta.interval(prob, coef_a, coef_b) + sds = np.sqrt(variances) + for i in range(itl_lows.shape[0]): + if not np.isfinite(sds[i]) or not np.isfinite(itl_lows[i]) or not np.isfinite(itl_his[i]): + itl_lows[i] = means[i] + itl_his[i] = means[i] + sds[i] = 0 + return itl_lows, itl_his, sds def _verify_data(self): """Verify dataset and labels.""" @@ -474,6 +650,17 @@ class ImageClassificationRunner: "Labels shape {} is unrecognizable: outputs should not have more than two dimensions" " with length greater than 1.".format(labels.shape)) + if self._is_hoc_registered: + if inputs.shape[-3] != 3: + raise ValueError( + "Hierarchical occlusion is registered, images must be in 3 channels format, but " + "{} channels is encountered.".format(inputs.shape[-3])) + short_side = min(inputs.shape[-2:]) + if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN: + raise ValueError( + "Hierarchical occlusion is registered, images' short side must be equals to or greater then " + "{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side)) + def _verify_network(self): """Verify the network.""" label_set = set() @@ -521,7 +708,8 @@ class ImageClassificationRunner: check_all=False, check_registration=False, check_data_n_network=False, - check_saliency=False): + check_saliency=False, + check_hoc=False): """ Verify the validity of dataset and other settings. @@ -530,6 +718,7 @@ class ImageClassificationRunner: check_registration (bool): Set it True for checking registrations, check if it is enough to invoke run(). check_data_n_network (bool): Set it True for checking data and network. check_saliency (bool): Set it True for checking saliency related settings. + check_hoc (bool): Set it True for checking HOC related settings. Raises: ValueError: Be raised for any data or settings' value problem. @@ -539,13 +728,16 @@ class ImageClassificationRunner: check_registration = True check_data_n_network = True check_saliency = True + check_hoc = True if check_registration: - if not self._is_saliency_registered: - raise ValueError("No explanation module was registered, user should at least call register_saliency()" - " once with proper explanation instances") + if not self._is_saliency_registered and not self._is_hoc_registered: + raise ValueError("No explanation module was registered, user should at least call register_saliency() " + "or register_hierarchical_occlusion() once with proper arguments") + if self._is_uncertainty_registered and not self._is_saliency_registered: + raise ValueError("Function register_uncertainty() is invoked but register_saliency() is not.") - if check_data_n_network or check_saliency: + if check_data_n_network or check_saliency or check_hoc: self._verify_data() if check_data_n_network: @@ -658,6 +850,18 @@ class ImageClassificationRunner: return ms.Tensor(batch_labels, ms.int32) + def _save_manifest(self): + """Save manifest.json underneath datafile directory.""" + if self._manifest is None: + raise RuntimeError("Manifest not yet be initialized.") + path_tokens = [self._summary_dir, + self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp)] + abs_dir_path = self._create_subdir(*path_tokens) + save_path = os.path.join(abs_dir_path, self._MANIFEST_FILENAME) + with open(save_path, 'w') as file: + json.dump(self._manifest, file, indent=4) + os.chmod(save_path, self._FILE_MODE) + def _save_original_image(self, sample_id, image): """Save an image to summary directory.""" id_dirname = self._get_sample_dirname(sample_id) @@ -720,7 +924,7 @@ class ImageClassificationRunner: return None @classmethod - def _spaced_print(cls, message, *args, **kwargs): + def _spaced_print(cls, message): """Spaced message printing.""" # workaround to print logs starting new line in case line width mismatch. print(cls._SPACER.format(message)) diff --git a/mindspore/explainer/explanation/_counterfactual/__init__.py b/mindspore/explainer/explanation/_counterfactual/__init__.py new file mode 100644 index 0000000000..4b8d0b727b --- /dev/null +++ b/mindspore/explainer/explanation/_counterfactual/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Counterfactual modules.""" diff --git a/mindspore/explainer/explanation/_counterfactual/hierarchical_occlusion.py b/mindspore/explainer/explanation/_counterfactual/hierarchical_occlusion.py new file mode 100644 index 0000000000..0a5880f11e --- /dev/null +++ b/mindspore/explainer/explanation/_counterfactual/hierarchical_occlusion.py @@ -0,0 +1,979 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Hierarchical occlusion edit tree searcher.""" +from enum import Enum +import copy +import re +import math + +import numpy as np +from scipy.ndimage import gaussian_filter + +from mindspore import nn +from mindspore import Tensor +from mindspore.ops import Squeeze +from mindspore.train._utils import check_value_type + + +AUTO_LAYER_MAX = 3 # maximum number of layer by auto settings +AUTO_WIN_SIZE_MIN = 28 # minimum window size by auto settings +AUTO_WIN_SIZE_DIV = 2 # denominator of windows size calculations by auto settings +AUTO_STRIDE_DIV = 5 # denominator of stride calculations by auto settings +AUTO_MASK_GAUSSIAN_RADIUS_DIV = 25 # denominator of gaussian mask radius calculations by auto settings +DEFAULT_THRESHOLD = 0.5 # default target prediction threshold +DEFAULT_BATCH_SIZE = 64 # default batch size for batch inference search +MASK_GAUSSIAN_RE = r'^gaussian:(\d+)$' # gaussian mask string pattern + +# minimum length of input images' short side with auto settings +AUTO_IMAGE_SHORT_SIDE_MIN = AUTO_WIN_SIZE_MIN * AUTO_WIN_SIZE_DIV + + +def is_valid_str_mask(mask): + """Check if it is a valid string mask.""" + check_value_type('mask', mask, str) + match = re.match(MASK_GAUSSIAN_RE, mask) + return match and int(match.group(1)) > 0 + + +def compile_mask(mask, image): + """Compile mask to a ready to use object.""" + if mask is None: + return compile_str_mask(auto_str_mask(image), image) + check_value_type('mask', mask, (str, tuple, float, np.ndarray)) + if isinstance(mask, str): + return compile_str_mask(mask, image) + + if isinstance(mask, tuple): + _check_iterable_type('mask', mask, tuple, float) + elif isinstance(mask, np.ndarray): + if len(image.shape) == 4 and len(mask.shape) == 3: + mask = np.expand_dims(mask, axis=0) + elif len(image.shape) == 3 and len(mask.shape) == 4 and mask.shape[0] == 1: + mask = mask.squeeze(0) + if image.shape != mask.shape: + raise ValueError("Image and mask is not match in shape.") + return mask + + +def auto_str_mask(image): + """Generate auto string mask for the image.""" + check_value_type('image', image, np.ndarray) + short_side = np.min(image.shape[-2:]) + radius = int(round(short_side/AUTO_MASK_GAUSSIAN_RADIUS_DIV)) + if radius == 0: + raise ValueError(f"Input image's short side:{short_side} is too small for auto mask, " + f"at least {AUTO_MASK_GAUSSIAN_RADIUS_DIV}pixels is required.") + return f'gaussian:{radius}' + + +def compile_str_mask(mask, image): + """Concert string mask to numpy.ndarray.""" + check_value_type('mask', mask, str) + check_value_type('image', image, np.ndarray) + match = re.match(MASK_GAUSSIAN_RE, mask) + if match: + radius = int(match.group(1)) + if radius > 0: + sigma = [0] * len(image.shape) + sigma[-2] = radius + sigma[-1] = radius + return gaussian_filter(image, sigma=sigma, mode='nearest') + raise ValueError(f"Invalid string mask: '{mask}'.") + + +class EditStep: + """ + Edit step that describes a box region, also represents an edit tree. + + Args: + layer (int): Layer number, -1 is root layer, 0 or above is normal edit layer. + box (tuple[int, int, int, int]): Tuple of x, y, width, height. + """ + def __init__(self, layer, box): + self.layer = layer + self.box = box + self.network_output = 0 + self.step_change = 0 + self.children = None + + @property + def x(self): + """X-coordinate of the box.""" + return self.box[0] + + @property + def y(self): + """Y-coordinate of the box.""" + return self.box[1] + + @property + def width(self): + """Width of the box.""" + return self.box[2] + + @property + def height(self): + """Height of the box.""" + return self.box[3] + + @property + def is_leaf(self): + """Returns True if no child edit step.""" + return not self.children + + @property + def leaf_steps(self): + """Returns all leaf edit steps in the tree.""" + if self.is_leaf: + return [self] + steps = [] + for child in self.children: + steps.extend(child.leaf_steps) + return steps + + @property + def max_layer(self): + """Maximum layer number in the edit tree.""" + if self.is_leaf: + return self.layer + layer = self.layer + for child in self.children: + child_max_layer = child.max_layer + if child_max_layer > layer: + layer = child_max_layer + return layer + + def add_child(self, child): + """Add a child edit step.""" + if self.children is None: + self.children = [child] + else: + self.children.append(child) + + def remove_all_children(self): + """Remove all child steps.""" + self.children = None + + def get_layer_or_leaf_steps(self, layer): + """Get all edit steps of the layer and all leaf edit steps above the layer.""" + if self.layer == layer or (self.layer < layer and self.is_leaf): + return [self] + steps = [] + if self.layer < layer and self.children: + for child in self.children: + steps.extend(child.get_layer_or_leaf_steps(layer)) + return steps + + def get_layer_steps(self, layer): + """Get all edit steps of the layer.""" + if self.layer == layer: + return [self] + steps = [] + if self.layer < layer and self.children: + for child in self.children: + steps.extend(child.get_layer_steps(layer)) + return steps + + @classmethod + def apply(cls, + image, + mask, + edit_steps, + by_masking=False, + inplace=False): + """ + Apply edit steps. + + Args: + image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format. + mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be + str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9. + tuple[float, float, float]: RGB solid color mask, + float: Grey scale solid color mask. + numpy.ndarray: Image mask in CHW or NCHW(N=1) format. + edit_steps (list[EditStep], optional): Edit steps to be applied. + by_masking (bool): Whether it is masking mode. + inplace (bool): Whether the modification is going to take place in the input image tensor. False to + construct a new image tensor as result. + + Returns: + numpy.ndarray, the result image tensor. + + Raises: + TypeError: Be raised for any argument or data type problem. + ValueError: Be raised for any argument or data value problem. + """ + if by_masking: + return cls.apply_masking(image, mask, edit_steps, inplace) + return cls.apply_unmasking(image, mask, edit_steps, inplace) + + @staticmethod + def apply_masking(image, + mask, + edit_steps, + inplace=False): + """ + Apply edit steps in masking mode. + + Args: + image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format. + mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be + str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9. + tuple[float, float, float]: RGB solid color mask, + float: Grey scale solid color mask. + numpy.ndarray: Image mask in CHW or NCHW(N=1) format. + edit_steps (list[EditStep], optional): Edit steps to be applied. + inplace (bool): Whether the modification is going to take place in the input image tensor. False to + construct a new image tensor as result. + + Returns: + numpy.ndarray, the result image tensor. + + Raises: + TypeError: Be raised for any argument or data type problem. + ValueError: Be raised for any argument or data value problem. + """ + check_value_type('image', image, np.ndarray) + check_value_type('mask', mask, (str, tuple, float, np.ndarray)) + if isinstance(mask, tuple): + _check_iterable_type('mask', mask, tuple, float) + + if edit_steps is not None: + _check_iterable_type('edit_steps', edit_steps, (tuple, list), EditStep) + + mask = compile_mask(mask, image) + + if inplace: + background = image + else: + background = np.copy(image) + + if not edit_steps: + return background + + for step in edit_steps: + + x_max = step.x + step.width + y_max = step.y + step.height + + if x_max > background.shape[-1]: + x_max = background.shape[-1] + + if y_max > background.shape[-2]: + y_max = background.shape[-2] + + if x_max <= step.x or y_max <= step.y: + continue + + if isinstance(mask, np.ndarray): + background[..., step.y:y_max, step.x:x_max] = mask[..., step.y:y_max, step.x:x_max] + else: + if isinstance(mask, (int, float)): + mask = (mask, mask, mask) + for c in range(3): + background[..., c, step.y:y_max, step.x:x_max] = mask[c] + return background + + @staticmethod + def apply_unmasking(image, + mask, + edit_steps, + inplace=False): + """ + Apply edit steps in unmasking mode. + + Args: + image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format. + mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be + str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9. + tuple[float, float, float]: RGB solid color mask, + float: Grey scale solid color mask. + numpy.ndarray: Image mask in CHW or NCHW(N=1) format. + edit_steps (list[EditStep]): Edit steps to be applied. + inplace (bool): Whether the modification is going to take place in the input mask tensor. False to + construct a new image tensor as result. + + Returns: + numpy.ndarray, the result image tensor. + + Raises: + TypeError: Be raised for any argument or data type problem. + ValueError: Be raised for any argument or data value problem. + """ + check_value_type('image', image, np.ndarray) + check_value_type('mask', mask, (str, tuple, float, np.ndarray)) + if isinstance(mask, tuple): + _check_iterable_type('mask', mask, tuple, float) + + if edit_steps is not None: + _check_iterable_type('edit_steps', edit_steps, (tuple, list), EditStep) + + mask = compile_mask(mask, image) + + if isinstance(mask, np.ndarray): + if inplace: + background = mask + else: + background = np.copy(mask) + else: + if inplace: + raise ValueError('Inplace cannot be True when mask is not a numpy.ndarray') + + background = np.zeros_like(image) + if isinstance(mask, (int, float)): + background.fill(mask) + else: + for c in range(3): + background[..., c, :, :] = mask[c] + + if not edit_steps: + return background + + for step in edit_steps: + + x_max = step.x + step.width + y_max = step.y + step.height + + if x_max > background.shape[-1]: + x_max = background.shape[-1] + + if y_max > background.shape[-2]: + y_max = background.shape[-2] + + if x_max <= step.x or y_max <= step.y: + continue + + background[..., step.y:y_max, step.x:x_max] = image[..., step.y:y_max, step.x:x_max] + + return background + + +class NoValidResultError(RuntimeError): + """Error for no edit step layer's network output meet the threshold.""" + + +class OriginalOutputError(RuntimeError): + """Error for network output of the original image is not strictly larger than the threshold.""" + + +class Searcher: + """ + Edit step searcher. + + Args: + network (Cell): Image tensor in CHW or NCHW(N=1) format. + win_sizes (Union(list[int], optional): Moving square window size (length of side) of layers, + None means by auto calcuation. + strides (Union(list[int], optional): Stride of layers, None means by auto calcuation. + threshold (float): Threshold network output value of the target class. + by_masking (bool): Whether it is masking mode. + + Examples: + >>> from mindspore import nn + >>> from mindspore.explainer.explanation._counterfactual.hierarchical_occlusion import Searcher, EditStep + >>> + >>> from user_defined import load_network, load_sample_image + >>> + >>> + >>> network = nn.SequentialCell([load_network(), nn.Sigmoid()]) + >>> + >>> # single image in CHW or NCHW(N=1) numpy.ndarray tensor, typical dimension is 224x224 + >>> image = load_sample_image() + >>> # target class index + >>> class_idx = 5 + >>> + >>> # by default, maximum 3 search layers, auto calculate window sizes and strides + >>> searcher = Searcher(network) + >>> + >>> edit_tree, layer_outputs = searcher.search(image, class_idx) + >>> # get the outcome image of the deepest layer in CHW(or NCHW(N=1) if input image is NCHW) format + >>> outcome = EditStep.apply(image, searcher.compiled_mask, edit_tree.leaf_steps) + """ + + def __init__(self, + network, + win_sizes=None, + strides=None, + threshold=DEFAULT_THRESHOLD, + by_masking=False): + + check_value_type('network', network, nn.Cell) + + if win_sizes is not None: + _check_iterable_type('win_sizes', win_sizes, list, int) + if not win_sizes: + raise ValueError('Argument win_sizes is empty.') + + for i in range(1, len(win_sizes)): + if win_sizes[i] >= win_sizes[i-1]: + raise ValueError('Argument win_sizes is not strictly descending.') + + if win_sizes[-1] <= 0: + raise ValueError('Argument win_sizes has non-positive number.') + elif strides is not None: + raise ValueError('Argument win_sizes cannot be None if strides is not None.') + + if strides is not None: + _check_iterable_type('strides', strides, list, int) + for i in range(1, len(strides)): + if strides[i] >= strides[i-1]: + raise ValueError('Argument win_sizes is not strictly descending.') + + if strides[-1] <= 0: + raise ValueError('Argument strides has non-positive number.') + + if len(strides) != len(win_sizes): + raise ValueError('Length of strides and win_sizes is not equal.') + elif win_sizes is not None: + raise ValueError('Argument strides cannot be None if win_sizes is not None.') + + self._network = copy.deepcopy(network) + self._compiled_mask = None + self._threshold = threshold + self._win_sizes = copy.copy(win_sizes) if win_sizes else None + self._strides = copy.copy(strides) if strides else None + self._by_masking = by_masking + + @property + def network(self): + """Get the network.""" + return self._network + + @property + def by_masking(self): + """Check if it is masking mode.""" + return self._by_masking + + @property + def threshold(self): + """The network output threshold to stop the search.""" + return self._threshold + + @property + def win_sizes(self): + """Windows sizes in pixels.""" + return self._win_sizes + + @property + def strides(self): + """Strides in pixels.""" + return self._strides + + @property + def compiled_mask(self): + """The compiled mask after a successful search() call.""" + return self._compiled_mask + + def search(self, image, class_idx, mask=None): + """ + Search smallest sufficient/destruction region on an image. + + Args: + image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format. + class_idx (int): Target class index. + mask (Union[str, tuple[float, float, float], float], optional): The mask, type can be + str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9. + tuple[float, float, float]: RGB solid color mask, + float: Grey scale solid color mask. + None: By auto calculation. + + Returns: + tuple[EditStep, list[float]], the root edit step and network output of each layer after applied the + layer steps. + + Raise: + TypeError: Be raised for any argument or data type problem. + ValueError: Be raised for any argument or data value problem. + NoValidResultError: Be raised if no valid result was found. + OriginalOutputError: Be raised if network output of the original image is not strictly larger than + the threshold. + """ + check_value_type('image', image, (Tensor, np.ndarray)) + + if isinstance(image, Tensor): + image = image.asnumpy() + + if len(image.shape) == 4: + if image.shape[0] != 1: + raise ValueError("Argument image's batch size is not 1.") + elif len(image.shape) == 3: + image = np.expand_dims(image, axis=0) + else: + raise ValueError("Argument image is not in CHW or NCHW(N=1) format.") + + check_value_type('class_idx', class_idx, int) + + if class_idx < 0: + raise ValueError("Argument class_idx is less then zero.") + + self._compiled_mask = compile_mask(mask, image) + + short_side = np.min(image.shape[-2:]) + if self._win_sizes is None: + win_sizes, strides = self._auto_win_sizes_strides(short_side) + else: + win_sizes, strides = self._win_sizes, self._strides + + if short_side <= win_sizes[0]: + raise ValueError(f"Input image's short side is shorter then or " + f"equals to the first window size:{win_sizes[0]}.") + + self._network.set_train(False) + + # the search result will be store as a edit tree that attached to the root step. + root_step = EditStep(-1, (0, 0, image.shape[-1], image.shape[-2])) + root_job = _SearchJob(by_masking=self._by_masking, + class_idx=class_idx, + win_sizes=win_sizes, + strides=strides, + layer=0, + search_field=root_step.box, + pre_edit_steps=None, + parent_step=root_step) + self._process_root_job(image, root_job) + + # the leaf layer's network output may not meet the threshold, + # we have to cutoff the unqualified layers + layer_count = root_step.max_layer + 1 + if layer_count == 0: + raise NoValidResultError("No edit step layer was found.") + + # gather the network output of each layer + layer_outputs = [None] * layer_count + for layer in range(layer_count): + steps = root_step.get_layer_or_leaf_steps(layer) + if not steps: + continue + masked_image = EditStep.apply(image, self._compiled_mask, steps, by_masking=self._by_masking) + output = self._network(Tensor(masked_image)) + output = output[0, class_idx].asnumpy().item() + layer_outputs[layer] = output + + # determine which layer we have to cutoff + cutoff_layer = None + for layer in reversed(range(layer_count)): + if layer_outputs[layer] is not None and self._is_threshold_met(layer_outputs[layer]): + cutoff_layer = layer + break + + if cutoff_layer is None or root_step.is_leaf: + raise NoValidResultError(f"No edit step layer's network output meet the threshold: {self._threshold}.") + + # cutoff the layer by removing all children of the layer's steps. + steps = root_step.get_layer_steps(cutoff_layer) + for step in steps: + step.remove_all_children() + layer_outputs = layer_outputs[:cutoff_layer + 1] + + return root_step, layer_outputs + + def _process_root_job(self, sample_input, root_job): + """ + Process job queue. + + Args: + sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format. + root_job (_SearchJob): Root search job. + """ + job_queue = [root_job] + while job_queue: + job = job_queue.pop(0) + sub_job_queue = [] + job_edit_steps, stop_reason = self._process_job(job, sample_input, sub_job_queue) + + if stop_reason in (self._StopReason.THRESHOLD_MET, self._StopReason.STEP_CHANGE_MET): + for step in job_edit_steps: + job.parent_step.add_child(step) + job_queue.extend(sub_job_queue) + + def _process_job(self, job, sample_input, job_queue): + """ + Process a job. + + Args: + job (_SearchJob): Search job to be processed. + sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format. + job_queue (list[_SearchJob]): Job queue. + + Returns: + tuple[list[EditStep], _StopReason], result edit stop and the stop reason. + """ + edit_steps = [] + + # make the network output with the original image is strictly larger than the threshold + if job.layer == 0: + original_output = self._network(Tensor(sample_input))[0, job.class_idx].asnumpy().item() + if original_output <= self._threshold: + raise OriginalOutputError(f'The original output is not strictly larger the threshold: ' + f'{self._threshold}') + + # applying the pre-edit steps from the parent steps + if job.pre_edit_steps: + # use the latest leaf steps to increase the accuracy + leaf_steps = [] + for step in job.pre_edit_steps: + leaf_steps.extend(step.leaf_steps) + pre_edit_steps = leaf_steps + else: + pre_edit_steps = None + workpiece = EditStep.apply(sample_input, + self._compiled_mask, + pre_edit_steps, + self._by_masking) + + job.on_start(sample_input, workpiece, self._compiled_mask, self._network) + start_output = self._network(Tensor(workpiece))[0, job.class_idx].asnumpy().item() + last_output = start_output + + # greedy search loop + while True: + + if self._is_threshold_met(last_output): + return edit_steps, self._StopReason.THRESHOLD_MET + + try: + best_edit = job.find_best_edit() + except _NoNewStepError: + return edit_steps, self._StopReason.NO_NEW_STEP + except _RepeatedStepError: + return edit_steps, self._StopReason.REPEATED_STEP + + best_edit.step_change = best_edit.network_output - last_output + + if job.layer < job.layer_count - 1 and self._is_greedy(best_edit.step_change): + # create net layer search job if new edit step is valid and not yet reaching + # the final layer + if job.pre_edit_steps: + pre_edit_steps = list(job.pre_edit_steps) + pre_edit_steps.extend(edit_steps) + else: + pre_edit_steps = list(edit_steps) + + sub_job = job.create_sub_job(best_edit, pre_edit_steps) + job_queue.append(sub_job) + + edit_steps.append(best_edit) + + if job.layer > 0: + # stop if the step change meet the parent step change only after layer 0 + change = best_edit.network_output - start_output + if self._is_step_change_met(job.parent_step.step_change, change): + return edit_steps, self._StopReason.STEP_CHANGE_MET + + last_output = best_edit.network_output + + def _is_threshold_met(self, network_output): + """Check if the threshold was met.""" + if self._by_masking: + return network_output <= self._threshold + return network_output >= self._threshold + + def _is_step_change_met(self, target, step_change): + """Check if the change target was met.""" + if self._by_masking: + return step_change <= target + return step_change >= target + + def _is_greedy(self, step_change): + """Check if it is a greedy step.""" + if self._by_masking: + return step_change < 0 + return step_change > 0 + + @classmethod + def _auto_win_sizes_strides(cls, short_side): + """ + Calculate auto window sizes and strides. + + Args: + short_side (int): Length of search space. + + Returns: + tuple[list[int], list[int]], window sizes and strides. + """ + win_sizes = [] + strides = [] + cur_len = int(short_side/AUTO_WIN_SIZE_DIV) + while len(win_sizes) < AUTO_LAYER_MAX and cur_len >= AUTO_WIN_SIZE_MIN: + stride = int(cur_len/AUTO_STRIDE_DIV) + if stride <= 0: + break + win_sizes.append(cur_len) + strides.append(stride) + cur_len = int(cur_len/AUTO_WIN_SIZE_DIV) + if not win_sizes: + raise ValueError(f"Image's short side is less then {AUTO_IMAGE_SHORT_SIDE_MIN}, " + f"unable to calculates auto settings.") + return win_sizes, strides + + class _StopReason(Enum): + """Stop reason of search job.""" + THRESHOLD_MET = 0 # threshold was met. + STEP_CHANGE_MET = 1 # parent step change was met. + NO_NEW_STEP = 2 # no new step was found. + REPEATED_STEP = 3 # repeated step was found. + + +def _check_iterable_type(arg_name, arg_value, container_type, elem_types): + """Concert iterable argument data type.""" + check_value_type(arg_name, arg_value, container_type) + for elem in arg_value: + check_value_type(arg_name + ' element', elem, elem_types) + + +class _NoNewStepError(Exception): + """Error for no new step was found.""" + + +class _RepeatedStepError(Exception): + """Error for repeated step was found.""" + + +class _SearchJob: + """ + Search job. + + Args: + by_masking (bool): Whether it is masking mode. + class_idx (int): Target class index. + win_sizes (list[int]): Moving square window size (length of side) of layers. + strides (list[int]): Strides of layers. + layer (int): Layer number. + search_field (tuple[int, int, int, int]): Search field in x, y, width, height format. + pre_edit_steps (list[EditStep], optional): Edit steps to be applied before searching. + parent_step (EditStep): Parent edit step. + batch_size (int): Batch size of batched inferences. + """ + + def __init__(self, + by_masking, + class_idx, + win_sizes, + strides, + layer, + search_field, + pre_edit_steps, + parent_step, + batch_size=DEFAULT_BATCH_SIZE): + + if layer >= len(win_sizes): + raise ValueError('Layer is larger then number of window sizes.') + + self.by_masking = by_masking + self.class_idx = class_idx + self.win_sizes = win_sizes + self.strides = strides + self.layer = layer + self.search_field = search_field + self.pre_edit_steps = pre_edit_steps + self.parent_step = parent_step + self.batch_size = batch_size + self.network = None + self.mask = None + self.original_input = None + + self._workpiece = None + self._found_best_edits = None + self._found_uvs = None + self._u_pixels = None + self._v_pixels = None + + @property + def layer_count(self): + """Number of layers.""" + return len(self.win_sizes) + + def on_start(self, original_input, workpiece, mask, network): + """ + Notification of the start of the search job. + + Args: + original_input (numpy.ndarray): The original image tensor in CHW or NCHW(N=1) format. + workpiece (numpy.ndarray): The intermediate image tensor in CHW or NCHW(N=1) format. + mask (Union[tuple[float, float, float], float, numpy.ndarray]): The mask, type can be + tuple[float, float, float]: RGB solid color mask, + float: Grey scale solid color mask. + numpy.ndarray: Image mask, has same format of original_input. + network (nn.Cell): Classification network. + """ + self.original_input = original_input + self.mask = mask + self.network = network + + self._workpiece = workpiece + self._found_best_edits = [] + self._found_uvs = [] + self._u_pixels = self._calc_uv_pixels(self.search_field[0], self.search_field[2]) + self._v_pixels = self._calc_uv_pixels(self.search_field[1], self.search_field[3]) + + def create_sub_job(self, parent_step, pre_edit_steps): + """Create next layer search job.""" + return self.__class__(by_masking=self.by_masking, + class_idx=self.class_idx, + win_sizes=self.win_sizes, + strides=self.strides, + layer=self.layer + 1, + search_field=copy.copy(parent_step.box), + pre_edit_steps=pre_edit_steps, + parent_step=parent_step, + batch_size=self.batch_size) + + def find_best_edit(self): + """ + Find the next best edit step. + + Returns: + EditStep, the next best edit step. + """ + workpiece = self._workpiece + if len(workpiece.shape) == 3: + workpiece = np.expand_dims(workpiece, axis=0) + + # generate input tensors with shifted masked/unmasked region and pack into a batch + squeeze = Squeeze() + best_new_workpiece = None + best_output = None + best_edit = None + best_uv = None + batch = np.repeat(workpiece, repeats=self.batch_size, axis=0) + batch_uvs = [] + batch_steps = [] + batch_i = 0 + win_size = self.win_sizes[self.layer] + for u, x in enumerate(self._u_pixels): + for v, y in enumerate(self._v_pixels): + if (u, v) in self._found_uvs: + continue + + edit_step = EditStep(self.layer, (x, y, win_size, win_size)) + + if self.by_masking: + EditStep.apply(batch[batch_i], + self.mask, + [edit_step], + self.by_masking, + inplace=True) + else: + EditStep.apply(self.original_input, + batch[batch_i], + [edit_step], + self.by_masking, + inplace=True) + + batch_i += 1 + batch_uvs.append((u, v)) + batch_steps.append(edit_step) + if batch_i == self.batch_size: + # the batch is full, inference and empty it + batch_output = self.network(Tensor(batch)) + batch_output = batch_output[:, self.class_idx] + if len(batch_output.shape) > 1: + batch_output = squeeze(batch_output) + if self.by_masking: + batch_best_i = np.argmin(batch_output.asnumpy()) + else: + batch_best_i = np.argmax(batch_output.asnumpy()) + batch_best_output = batch_output[int(batch_best_i)].asnumpy().item() + + if best_output is None or self._is_output0_better(batch_best_output, best_output): + best_output = batch_best_output + best_uv = batch_uvs[batch_best_i] + best_edit = batch_steps[batch_best_i] + best_new_workpiece = batch[batch_best_i] + + batch = np.repeat(workpiece, repeats=self.batch_size, axis=0) + batch_uvs = [] + batch_i = 0 + + if batch_i > 0: + # don't forget the last half full batch + batch_output = self.network(Tensor(batch)) + batch_output = batch_output[:, self.class_idx] + if len(batch_output.shape) > 1: + batch_output = squeeze(batch_output) + if self.by_masking: + batch_best_i = np.argmin(batch_output.asnumpy()[:batch_i, ...]) + else: + batch_best_i = np.argmax(batch_output.asnumpy()[:batch_i, ...]) + + batch_best_output = batch_output[int(batch_best_i)].asnumpy().item() + if best_output is None or self._is_output0_better(batch_best_output, best_output): + best_output = batch_best_output + best_uv = batch_uvs[batch_best_i] + best_edit = batch_steps[batch_best_i] + best_new_workpiece = batch[batch_best_i] + + if best_edit is None: + raise _NoNewStepError + + if best_uv in self._found_uvs: + raise _RepeatedStepError + + self._found_uvs.append(best_uv) + self._found_best_edits.append(best_edit) + best_edit.network_output = best_output + + # continue on the best workpiece in the next function call + self._workpiece = best_new_workpiece + + return best_edit + + def _is_output0_better(self, output0, output1): + """Check if the network output0 is better.""" + if self.by_masking: + return output0 < output1 + return output0 > output1 + + def _calc_uv_pixels(self, begin, length): + """ + Calculate the pixel coordinate of shifts. + + Args: + begin (int): The beginning pixel coordinate of search field. + length (int): The length of search field. + + Returns: + list[int], pixel coordinate of shifts. + """ + win_size = self.win_sizes[self.layer] + stride = self.strides[self.layer] + shift_count = self._calc_shift_count(length, win_size, stride) + pixels = [0] * shift_count + for i in range(shift_count): + if i == shift_count - 1: + pixels[i] = begin + length - win_size + else: + pixels[i] = begin + i*stride + return pixels + + @staticmethod + def _calc_shift_count(length, win_size, stride): + """ + Calculate the number of shifts in search field. + + Args: + length (int): The length of search field. + win_size (int): The length of sides of moving window. + stride (int): The stride. + + Returns: + int, number of shifts. + """ + if length <= win_size or win_size < stride or stride <= 0: + raise ValueError("Invalid length, win_size or stride.") + count = int(math.ceil((length - win_size)/stride)) + if (count - 1)*stride + win_size < length: + return count + 1 + return count diff --git a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py index f0ecef7f8a..eb6e24c1d5 100644 --- a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py +++ b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -84,7 +84,7 @@ class UncertaintyEvaluation: def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1, epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False): - self.epi_model = model + self.epi_model = deepcopy(model) self.ale_model = deepcopy(model) self.epi_train_dataset = train_dataset self.ale_train_dataset = train_dataset