Browse Source

Add HOC modules and support uncertainty in runner

fix pylint issues

deepcopy model inside uncertatiny instance

update summary.proto for hoc and uncertainty

update summary.proto tab to spaces

enhance code by review comments

fix comment format

add uncertainty and saliency cross registration checking.

check registered with is none

group constants togather, enhance runner data checking

update copyright year

enhance comment wordings
tags/v1.2.0-rc1
unknown 4 years ago
parent
commit
d621b9d4ea
5 changed files with 1240 additions and 29 deletions
  1. +19
    -6
      mindspore/ccsrc/utils/summary.proto
  2. +225
    -21
      mindspore/explainer/_image_classification_runner.py
  3. +15
    -0
      mindspore/explainer/explanation/_counterfactual/__init__.py
  4. +979
    -0
      mindspore/explainer/explanation/_counterfactual/hierarchical_occlusion.py
  5. +2
    -2
      mindspore/nn/probability/toolbox/uncertainty_evaluation.py

+ 19
- 6
mindspore/ccsrc/utils/summary.proto View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -121,7 +121,7 @@ message Explain {
optional string explain_method = 1;
optional int32 label = 2;
optional string heatmap_path = 3;
}
}

message Benchmark{
optional string benchmark_method = 1;
@@ -131,10 +131,21 @@ message Explain {
}

message Metadata{
repeated string label = 1;
repeated string explain_method = 2;
repeated string benchmark_method = 3;
}
repeated string label = 1;
repeated string explain_method = 2;
repeated string benchmark_method = 3;
}

message HocLayer {
optional float prob = 1;
repeated int32 box = 2; // List of repeated x, y, w, h
}

message Hoc {
optional int32 label = 1;
optional string mask = 2;
repeated HocLayer layer = 3;
}

optional int32 sample_id = 1;
optional string image_path = 2; // The Metadata and image path must have one fill in
@@ -146,4 +157,6 @@ message Explain {

optional Metadata metadata = 7;
optional string status = 8; // enum value: run, end

repeated Hoc hoc = 9; // hierarchical occlusion counterfactual
}

+ 225
- 21
mindspore/explainer/_image_classification_runner.py View File

@@ -15,9 +15,11 @@
"""Image Classification Runner."""
import os
import re
import json
from time import time

import numpy as np
from scipy.stats import beta
from PIL import Image

import mindspore as ms
@@ -30,10 +32,15 @@ from mindspore.train._utils import check_value_type
from mindspore.train.summary._summary_adapter import _convert_image_format
from mindspore.train.summary.summary_record import SummaryRecord
from mindspore.train.summary_pb2 import Explain
from .benchmark import Localization
from .explanation import RISE
from .benchmark._attribution.metric import AttributionMetric, LabelSensitiveMetric, LabelAgnosticMetric
from .explanation._attribution.attribution import Attribution
from mindspore.nn.probability.toolbox.uncertainty_evaluation import UncertaintyEvaluation
from mindspore.explainer.benchmark import Localization
from mindspore.explainer.benchmark._attribution.metric import AttributionMetric
from mindspore.explainer.benchmark._attribution.metric import LabelSensitiveMetric
from mindspore.explainer.benchmark._attribution.metric import LabelAgnosticMetric
from mindspore.explainer.explanation import RISE
from mindspore.explainer.explanation._attribution.attribution import Attribution
from mindspore.explainer.explanation._counterfactual import hierarchical_occlusion as hoc


_EXPAND_DIMS = ExpandDims()

@@ -97,6 +104,8 @@ class ImageClassificationRunner:
_DATAFILE_DIRNAME_PREFIX = "_explain_"
_ORIGINAL_IMAGE_DIRNAME = "origin_images"
_HEATMAP_DIRNAME = "heatmap"
# specfial filenames
_MANIFEST_FILENAME = "manifest.json"
# max. no. of sample per directory
_SAMPLE_PER_DIR = 1000
# seed for fixing the iterating order of the dataset
@@ -132,11 +141,15 @@ class ImageClassificationRunner:
self._network = network
self._explainers = None
self._benchmarkers = None
self._uncertainty = None
self._hoc_searcher = None
self._summary_timestamp = None
self._sample_index = -1

self._full_network = SequentialCell([self._network, activation_fn])

self._manifest = None

self._verify_data_n_settings(check_data_n_network=True)

def register_saliency(self,
@@ -185,6 +198,50 @@ class ImageClassificationRunner:
self._benchmarkers = None
raise

def register_hierarchical_occlusion(self):
"""
Register hierarchical occlusion instances.

Notes:
Input images are required to be in 3 channels formats and the length of side short must be equals to or
greater than 56 pixels.

Raises:
ValueError: Be raised for any data or settings' value problem.
RuntimeError: Be raised if the function was called already.
"""
if self._hoc_searcher is not None:
raise RuntimeError("Function register_hierarchical_occlusion() was invoked already.")

self._hoc_searcher = hoc.Searcher(self._full_network)

try:
self._verify_data_n_settings(check_hoc=True)
except ValueError:
self._hoc_searcher = None
raise

def register_uncertainty(self):
"""
Register uncertainty instance to compute the epistemic uncertainty base on the Bayes' theorem.

Notes:
Please refer to the documentation of mindspore.nn.probability.toolbox.uncertainty_evaluation for the
details. The actual output is standard deviation of the classification predictions and the corresponding
95% confidence intervals. Users have to invoke register_saliency() as well for the uncertainty results are
going to be shown on the saliency map page in MindInsight.

Raises:
RuntimeError: Be raised if the function was called already.
"""
if self._uncertainty is not None:
raise RuntimeError("Function register_uncertainty() was invoked already.")

self._uncertainty = UncertaintyEvaluation(model=self._full_network,
train_dataset=None,
task_type='classification',
num_classes=len(self._labels))

def run(self):
"""
Run the explain job and save the result as a summary in summary_dir.
@@ -198,7 +255,10 @@ class ImageClassificationRunner:
RuntimeError: Be raised for any runtime problem.
"""
self._verify_data_n_settings(check_all=True)

self._manifest = {"saliency_map": False,
"benchmark": False,
"uncertainty": False,
"hierarchical_occlusion": False}
with SummaryRecord(self._summary_dir, raise_exception=True) as summary:
print("Start running and writing......")
begin = time()
@@ -214,13 +274,25 @@ class ImageClassificationRunner:
if self._is_saliency_registered:
self._run_saliency(summary, imageid_labels)

self._save_manifest()

print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin))

@property
def _is_hoc_registered(self):
"""Check if HOC module is registered."""
return self._hoc_searcher is not None

@property
def _is_saliency_registered(self):
"""Check if saliency module is registered."""
return bool(self._explainers)

@property
def _is_uncertainty_registered(self):
"""Check if uncertainty module is registered."""
return self._uncertainty is not None

def _save_metadata(self, summary):
"""Save metadata of the explain job to summary."""
print("Start writing metadata......")
@@ -245,12 +317,13 @@ class ImageClassificationRunner:
Run inference for the dataset and write the inference related data into summary.

Args:
summary (SummaryRecord): The summary object to store the data
summary (SummaryRecord): The summary object to store the data.
threshold (float): The threshold for prediction.

Returns:
dict, The map of sample d to the union of its ground truth and predicted labels.
"""
has_uncertainty_rec = False
sample_id_labels = {}
self._sample_index = 0
ds.config.set_seed(self._DATASET_SEED)
@@ -259,10 +332,20 @@ class ImageClassificationRunner:
inputs, labels, _ = self._unpack_next_element(next_element)
prob = self._full_network(inputs).asnumpy()

if self._uncertainty is not None:
prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
else:
prob_var = None

for idx, inp in enumerate(inputs):
gt_labels = labels[idx]
gt_probs = [float(prob[idx][i]) for i in gt_labels]

if prob_var is not None:
gt_prob_vars = [float(prob_var[idx][i]) for i in gt_labels]
gt_itl_lows, gt_itl_his, gt_prob_sds = \
self._calc_beta_intervals(gt_probs, gt_prob_vars)

data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
original_image = _np_to_image(_normalize(data_np), mode='RGB')
original_image_path = self._save_original_image(self._sample_index, original_image)
@@ -270,6 +353,11 @@ class ImageClassificationRunner:
predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
predicted_probs = [float(prob[idx][i]) for i in predicted_labels]

if prob_var is not None:
predicted_prob_vars = [float(prob_var[idx][i]) for i in predicted_labels]
predicted_itl_lows, predicted_itl_his, predicted_prob_sds = \
self._calc_beta_intervals(predicted_probs, predicted_prob_vars)

union_labs = list(set(gt_labels + predicted_labels))
sample_id_labels[str(self._sample_index)] = union_labs

@@ -285,14 +373,29 @@ class ImageClassificationRunner:
explain.inference.predicted_label.extend(predicted_labels)
explain.inference.predicted_prob.extend(predicted_probs)

summary.add_value("explainer", "inference", explain)
if prob_var is not None:
explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
explain.inference.ground_truth_prob_itl95_low.extend(gt_itl_lows)
explain.inference.ground_truth_prob_itl95_hi.extend(gt_itl_his)
explain.inference.predicted_prob_sd.extend(predicted_prob_sds)
explain.inference.predicted_prob_itl95_low.extend(predicted_itl_lows)
explain.inference.predicted_prob_itl95_hi.extend(predicted_itl_his)

has_uncertainty_rec = True

summary.add_value("explainer", "inference", explain)
summary.record(1)

if self._is_hoc_registered:
self._run_hoc(summary, self._sample_index, inputs[idx], prob[idx])

self._sample_index += 1
self._spaced_print("Finish running and writing {}-th batch inference data."
" Time elapsed: {:.3f} s".format(j, time() - now),
end='')
" Time elapsed: {:.3f} s".format(j, time() - now))

if has_uncertainty_rec:
self._manifest["uncertainty"] = True

return sample_id_labels

def _run_saliency(self, summary, sample_id_labels):
@@ -306,10 +409,10 @@ class ImageClassificationRunner:
for idx, next_element in enumerate(self._dataset):
now = time()
self._spaced_print("Start running {}-th explanation data for {}......".format(
idx, exp.__class__.__name__), end='')
idx, exp.__class__.__name__))
self._run_exp_step(next_element, exp, sample_id_labels, summary)
self._spaced_print("Finish writing {}-th explanation data for {}. Time elapsed: "
"{:.3f} s".format(idx, exp.__class__.__name__, time() - now), end='')
"{:.3f} s".format(idx, exp.__class__.__name__, time() - now))
self._spaced_print(
"Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format(
exp.__class__.__name__, time() - start))
@@ -326,20 +429,20 @@ class ImageClassificationRunner:
for idx, next_element in enumerate(self._dataset):
now = time()
self._spaced_print("Start running {}-th explanation data for {}......".format(
idx, exp.__class__.__name__), end='')
idx, exp.__class__.__name__))
saliency_dict_lst = self._run_exp_step(next_element, exp, sample_id_labels, summary)
self._spaced_print(
"Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
idx, exp.__class__.__name__, time() - now), end='')
idx, exp.__class__.__name__, time() - now))
for bench in self._benchmarkers:
now = time()
self._spaced_print(
"Start running {}-th batch {} data for {}......".format(
idx, bench.__class__.__name__, exp.__class__.__name__), end='')
idx, bench.__class__.__name__, exp.__class__.__name__))
self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
self._spaced_print(
"Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
idx, bench.__class__.__name__, exp.__class__.__name__, time() - now), end='')
idx, bench.__class__.__name__, exp.__class__.__name__, time() - now))

for bench in self._benchmarkers:
benchmark = explain.benchmark.add()
@@ -355,6 +458,52 @@ class ImageClassificationRunner:
summary.add_value('explainer', 'benchmark', explain)
summary.record(1)

def _run_hoc(self, summary, sample_id, sample_input, prob):
"""
Run HOC search for a sample image, and then save the result to summary.

Args:
summary (SummaryRecord): The summary object to store the data.
sample_id (int): The sample ID.
sample_input (Union[Tensor, np.ndarray]): Sample image tensor in CHW or NCWH(N=1).
prob (Union[Tensor, np.ndarray]): List of sample's classification prediction output, HOC will run for
labels with prediction output strictly larger then HOC searcher's threshold(0.5 by default).
"""
if isinstance(sample_input, ms.Tensor):
sample_input = sample_input.asnumpy()
if len(sample_input.shape) == 3:
sample_input = np.expand_dims(sample_input, axis=0)
has_rec = False
explain = Explain()
explain.sample_id = sample_id
str_mask = hoc.auto_str_mask(sample_input)
compiled_mask = None
for label_idx, label_prob in enumerate(prob):
if label_prob > self._hoc_searcher.threshold:
if compiled_mask is None:
compiled_mask = hoc.compile_mask(str_mask, sample_input)
try:
edit_tree, layer_outputs = self._hoc_searcher.search(sample_input, label_idx, compiled_mask)
except hoc.NoValidResultError as ex:
log.error(f"HOC cannot find result for sample:{sample_id} error:{ex}")
continue
has_rec = True
hoc_rec = explain.hoc.add()
hoc_rec.label = label_idx
hoc_rec.mask = str_mask
layer_count = edit_tree.max_layer + 1
for layer in range(layer_count):
steps = edit_tree.get_layer_or_leaf_steps(layer)
layer_output = layer_outputs[layer]
hoc_layer = hoc_rec.layer.add()
hoc_layer.prob = layer_output
for step in steps:
hoc_layer.box.extend(list(step.box))
if has_rec:
summary.add_value("explainer", "hoc", explain)
summary.record(1)
self._manifest['hierarchical_occlusion'] = True

def _run_exp_step(self, next_element, explainer, sample_id_labels, summary):
"""
Run the explanation for each step and write explanation results into summary.
@@ -368,6 +517,7 @@ class ImageClassificationRunner:
Returns:
list, List of dict that maps label to its corresponding saliency map.
"""
has_saliency_rec = False
inputs, labels, _ = self._unpack_next_element(next_element)
sample_index = self._sample_index
unions = []
@@ -406,11 +556,17 @@ class ImageClassificationRunner:
explanation.heatmap_path = heatmap_path
explanation.label = lab

has_saliency_rec = True

summary.add_value("explainer", "explanation", explain)
summary.record(1)

self._sample_index += 1
saliency_dict_lst.append(saliency_dict)

if has_saliency_rec:
self._manifest['saliency_map'] = True

return saliency_dict_lst

def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
@@ -436,6 +592,26 @@ class ImageClassificationRunner:
else:
raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
'receive {}'.format(type(benchmarker)))
self._manifest['benchmark'] = True

@staticmethod
def _calc_beta_intervals(means, variances, prob=0.95):
"""Calculate confidence interval of beta distributions."""
if not isinstance(means, np.ndarray):
means = np.array(means)
if not isinstance(variances, np.ndarray):
variances = np.array(variances)
with np.errstate(divide='ignore'):
coef_a = ((means ** 2) * (1 - means) / variances) - means
coef_b = (coef_a * (1 - means)) / means
itl_lows, itl_his = beta.interval(prob, coef_a, coef_b)
sds = np.sqrt(variances)
for i in range(itl_lows.shape[0]):
if not np.isfinite(sds[i]) or not np.isfinite(itl_lows[i]) or not np.isfinite(itl_his[i]):
itl_lows[i] = means[i]
itl_his[i] = means[i]
sds[i] = 0
return itl_lows, itl_his, sds

def _verify_data(self):
"""Verify dataset and labels."""
@@ -474,6 +650,17 @@ class ImageClassificationRunner:
"Labels shape {} is unrecognizable: outputs should not have more than two dimensions"
" with length greater than 1.".format(labels.shape))

if self._is_hoc_registered:
if inputs.shape[-3] != 3:
raise ValueError(
"Hierarchical occlusion is registered, images must be in 3 channels format, but "
"{} channels is encountered.".format(inputs.shape[-3]))
short_side = min(inputs.shape[-2:])
if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN:
raise ValueError(
"Hierarchical occlusion is registered, images' short side must be equals to or greater then "
"{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side))

def _verify_network(self):
"""Verify the network."""
label_set = set()
@@ -521,7 +708,8 @@ class ImageClassificationRunner:
check_all=False,
check_registration=False,
check_data_n_network=False,
check_saliency=False):
check_saliency=False,
check_hoc=False):
"""
Verify the validity of dataset and other settings.

@@ -530,6 +718,7 @@ class ImageClassificationRunner:
check_registration (bool): Set it True for checking registrations, check if it is enough to invoke run().
check_data_n_network (bool): Set it True for checking data and network.
check_saliency (bool): Set it True for checking saliency related settings.
check_hoc (bool): Set it True for checking HOC related settings.

Raises:
ValueError: Be raised for any data or settings' value problem.
@@ -539,13 +728,16 @@ class ImageClassificationRunner:
check_registration = True
check_data_n_network = True
check_saliency = True
check_hoc = True

if check_registration:
if not self._is_saliency_registered:
raise ValueError("No explanation module was registered, user should at least call register_saliency()"
" once with proper explanation instances")
if not self._is_saliency_registered and not self._is_hoc_registered:
raise ValueError("No explanation module was registered, user should at least call register_saliency() "
"or register_hierarchical_occlusion() once with proper arguments")
if self._is_uncertainty_registered and not self._is_saliency_registered:
raise ValueError("Function register_uncertainty() is invoked but register_saliency() is not.")

if check_data_n_network or check_saliency:
if check_data_n_network or check_saliency or check_hoc:
self._verify_data()

if check_data_n_network:
@@ -658,6 +850,18 @@ class ImageClassificationRunner:

return ms.Tensor(batch_labels, ms.int32)

def _save_manifest(self):
"""Save manifest.json underneath datafile directory."""
if self._manifest is None:
raise RuntimeError("Manifest not yet be initialized.")
path_tokens = [self._summary_dir,
self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp)]
abs_dir_path = self._create_subdir(*path_tokens)
save_path = os.path.join(abs_dir_path, self._MANIFEST_FILENAME)
with open(save_path, 'w') as file:
json.dump(self._manifest, file, indent=4)
os.chmod(save_path, self._FILE_MODE)

def _save_original_image(self, sample_id, image):
"""Save an image to summary directory."""
id_dirname = self._get_sample_dirname(sample_id)
@@ -720,7 +924,7 @@ class ImageClassificationRunner:
return None

@classmethod
def _spaced_print(cls, message, *args, **kwargs):
def _spaced_print(cls, message):
"""Spaced message printing."""
# workaround to print logs starting new line in case line width mismatch.
print(cls._SPACER.format(message))

+ 15
- 0
mindspore/explainer/explanation/_counterfactual/__init__.py View File

@@ -0,0 +1,15 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Counterfactual modules."""

+ 979
- 0
mindspore/explainer/explanation/_counterfactual/hierarchical_occlusion.py View File

@@ -0,0 +1,979 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Hierarchical occlusion edit tree searcher."""
from enum import Enum
import copy
import re
import math

import numpy as np
from scipy.ndimage import gaussian_filter

from mindspore import nn
from mindspore import Tensor
from mindspore.ops import Squeeze
from mindspore.train._utils import check_value_type


AUTO_LAYER_MAX = 3 # maximum number of layer by auto settings
AUTO_WIN_SIZE_MIN = 28 # minimum window size by auto settings
AUTO_WIN_SIZE_DIV = 2 # denominator of windows size calculations by auto settings
AUTO_STRIDE_DIV = 5 # denominator of stride calculations by auto settings
AUTO_MASK_GAUSSIAN_RADIUS_DIV = 25 # denominator of gaussian mask radius calculations by auto settings
DEFAULT_THRESHOLD = 0.5 # default target prediction threshold
DEFAULT_BATCH_SIZE = 64 # default batch size for batch inference search
MASK_GAUSSIAN_RE = r'^gaussian:(\d+)$' # gaussian mask string pattern

# minimum length of input images' short side with auto settings
AUTO_IMAGE_SHORT_SIDE_MIN = AUTO_WIN_SIZE_MIN * AUTO_WIN_SIZE_DIV


def is_valid_str_mask(mask):
"""Check if it is a valid string mask."""
check_value_type('mask', mask, str)
match = re.match(MASK_GAUSSIAN_RE, mask)
return match and int(match.group(1)) > 0


def compile_mask(mask, image):
"""Compile mask to a ready to use object."""
if mask is None:
return compile_str_mask(auto_str_mask(image), image)
check_value_type('mask', mask, (str, tuple, float, np.ndarray))
if isinstance(mask, str):
return compile_str_mask(mask, image)

if isinstance(mask, tuple):
_check_iterable_type('mask', mask, tuple, float)
elif isinstance(mask, np.ndarray):
if len(image.shape) == 4 and len(mask.shape) == 3:
mask = np.expand_dims(mask, axis=0)
elif len(image.shape) == 3 and len(mask.shape) == 4 and mask.shape[0] == 1:
mask = mask.squeeze(0)
if image.shape != mask.shape:
raise ValueError("Image and mask is not match in shape.")
return mask


def auto_str_mask(image):
"""Generate auto string mask for the image."""
check_value_type('image', image, np.ndarray)
short_side = np.min(image.shape[-2:])
radius = int(round(short_side/AUTO_MASK_GAUSSIAN_RADIUS_DIV))
if radius == 0:
raise ValueError(f"Input image's short side:{short_side} is too small for auto mask, "
f"at least {AUTO_MASK_GAUSSIAN_RADIUS_DIV}pixels is required.")
return f'gaussian:{radius}'


def compile_str_mask(mask, image):
"""Concert string mask to numpy.ndarray."""
check_value_type('mask', mask, str)
check_value_type('image', image, np.ndarray)
match = re.match(MASK_GAUSSIAN_RE, mask)
if match:
radius = int(match.group(1))
if radius > 0:
sigma = [0] * len(image.shape)
sigma[-2] = radius
sigma[-1] = radius
return gaussian_filter(image, sigma=sigma, mode='nearest')
raise ValueError(f"Invalid string mask: '{mask}'.")


class EditStep:
"""
Edit step that describes a box region, also represents an edit tree.

Args:
layer (int): Layer number, -1 is root layer, 0 or above is normal edit layer.
box (tuple[int, int, int, int]): Tuple of x, y, width, height.
"""
def __init__(self, layer, box):
self.layer = layer
self.box = box
self.network_output = 0
self.step_change = 0
self.children = None

@property
def x(self):
"""X-coordinate of the box."""
return self.box[0]

@property
def y(self):
"""Y-coordinate of the box."""
return self.box[1]

@property
def width(self):
"""Width of the box."""
return self.box[2]

@property
def height(self):
"""Height of the box."""
return self.box[3]

@property
def is_leaf(self):
"""Returns True if no child edit step."""
return not self.children

@property
def leaf_steps(self):
"""Returns all leaf edit steps in the tree."""
if self.is_leaf:
return [self]
steps = []
for child in self.children:
steps.extend(child.leaf_steps)
return steps

@property
def max_layer(self):
"""Maximum layer number in the edit tree."""
if self.is_leaf:
return self.layer
layer = self.layer
for child in self.children:
child_max_layer = child.max_layer
if child_max_layer > layer:
layer = child_max_layer
return layer

def add_child(self, child):
"""Add a child edit step."""
if self.children is None:
self.children = [child]
else:
self.children.append(child)

def remove_all_children(self):
"""Remove all child steps."""
self.children = None

def get_layer_or_leaf_steps(self, layer):
"""Get all edit steps of the layer and all leaf edit steps above the layer."""
if self.layer == layer or (self.layer < layer and self.is_leaf):
return [self]
steps = []
if self.layer < layer and self.children:
for child in self.children:
steps.extend(child.get_layer_or_leaf_steps(layer))
return steps

def get_layer_steps(self, layer):
"""Get all edit steps of the layer."""
if self.layer == layer:
return [self]
steps = []
if self.layer < layer and self.children:
for child in self.children:
steps.extend(child.get_layer_steps(layer))
return steps

@classmethod
def apply(cls,
image,
mask,
edit_steps,
by_masking=False,
inplace=False):
"""
Apply edit steps.

Args:
image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
tuple[float, float, float]: RGB solid color mask,
float: Grey scale solid color mask.
numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
edit_steps (list[EditStep], optional): Edit steps to be applied.
by_masking (bool): Whether it is masking mode.
inplace (bool): Whether the modification is going to take place in the input image tensor. False to
construct a new image tensor as result.

Returns:
numpy.ndarray, the result image tensor.

Raises:
TypeError: Be raised for any argument or data type problem.
ValueError: Be raised for any argument or data value problem.
"""
if by_masking:
return cls.apply_masking(image, mask, edit_steps, inplace)
return cls.apply_unmasking(image, mask, edit_steps, inplace)

@staticmethod
def apply_masking(image,
mask,
edit_steps,
inplace=False):
"""
Apply edit steps in masking mode.

Args:
image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
tuple[float, float, float]: RGB solid color mask,
float: Grey scale solid color mask.
numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
edit_steps (list[EditStep], optional): Edit steps to be applied.
inplace (bool): Whether the modification is going to take place in the input image tensor. False to
construct a new image tensor as result.

Returns:
numpy.ndarray, the result image tensor.

Raises:
TypeError: Be raised for any argument or data type problem.
ValueError: Be raised for any argument or data value problem.
"""
check_value_type('image', image, np.ndarray)
check_value_type('mask', mask, (str, tuple, float, np.ndarray))
if isinstance(mask, tuple):
_check_iterable_type('mask', mask, tuple, float)

if edit_steps is not None:
_check_iterable_type('edit_steps', edit_steps, (tuple, list), EditStep)

mask = compile_mask(mask, image)

if inplace:
background = image
else:
background = np.copy(image)

if not edit_steps:
return background

for step in edit_steps:

x_max = step.x + step.width
y_max = step.y + step.height

if x_max > background.shape[-1]:
x_max = background.shape[-1]

if y_max > background.shape[-2]:
y_max = background.shape[-2]

if x_max <= step.x or y_max <= step.y:
continue

if isinstance(mask, np.ndarray):
background[..., step.y:y_max, step.x:x_max] = mask[..., step.y:y_max, step.x:x_max]
else:
if isinstance(mask, (int, float)):
mask = (mask, mask, mask)
for c in range(3):
background[..., c, step.y:y_max, step.x:x_max] = mask[c]
return background

@staticmethod
def apply_unmasking(image,
mask,
edit_steps,
inplace=False):
"""
Apply edit steps in unmasking mode.

Args:
image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
tuple[float, float, float]: RGB solid color mask,
float: Grey scale solid color mask.
numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
edit_steps (list[EditStep]): Edit steps to be applied.
inplace (bool): Whether the modification is going to take place in the input mask tensor. False to
construct a new image tensor as result.

Returns:
numpy.ndarray, the result image tensor.

Raises:
TypeError: Be raised for any argument or data type problem.
ValueError: Be raised for any argument or data value problem.
"""
check_value_type('image', image, np.ndarray)
check_value_type('mask', mask, (str, tuple, float, np.ndarray))
if isinstance(mask, tuple):
_check_iterable_type('mask', mask, tuple, float)

if edit_steps is not None:
_check_iterable_type('edit_steps', edit_steps, (tuple, list), EditStep)

mask = compile_mask(mask, image)

if isinstance(mask, np.ndarray):
if inplace:
background = mask
else:
background = np.copy(mask)
else:
if inplace:
raise ValueError('Inplace cannot be True when mask is not a numpy.ndarray')

background = np.zeros_like(image)
if isinstance(mask, (int, float)):
background.fill(mask)
else:
for c in range(3):
background[..., c, :, :] = mask[c]

if not edit_steps:
return background

for step in edit_steps:

x_max = step.x + step.width
y_max = step.y + step.height

if x_max > background.shape[-1]:
x_max = background.shape[-1]

if y_max > background.shape[-2]:
y_max = background.shape[-2]

if x_max <= step.x or y_max <= step.y:
continue

background[..., step.y:y_max, step.x:x_max] = image[..., step.y:y_max, step.x:x_max]

return background


class NoValidResultError(RuntimeError):
"""Error for no edit step layer's network output meet the threshold."""


class OriginalOutputError(RuntimeError):
"""Error for network output of the original image is not strictly larger than the threshold."""


class Searcher:
"""
Edit step searcher.

Args:
network (Cell): Image tensor in CHW or NCHW(N=1) format.
win_sizes (Union(list[int], optional): Moving square window size (length of side) of layers,
None means by auto calcuation.
strides (Union(list[int], optional): Stride of layers, None means by auto calcuation.
threshold (float): Threshold network output value of the target class.
by_masking (bool): Whether it is masking mode.

Examples:
>>> from mindspore import nn
>>> from mindspore.explainer.explanation._counterfactual.hierarchical_occlusion import Searcher, EditStep
>>>
>>> from user_defined import load_network, load_sample_image
>>>
>>>
>>> network = nn.SequentialCell([load_network(), nn.Sigmoid()])
>>>
>>> # single image in CHW or NCHW(N=1) numpy.ndarray tensor, typical dimension is 224x224
>>> image = load_sample_image()
>>> # target class index
>>> class_idx = 5
>>>
>>> # by default, maximum 3 search layers, auto calculate window sizes and strides
>>> searcher = Searcher(network)
>>>
>>> edit_tree, layer_outputs = searcher.search(image, class_idx)
>>> # get the outcome image of the deepest layer in CHW(or NCHW(N=1) if input image is NCHW) format
>>> outcome = EditStep.apply(image, searcher.compiled_mask, edit_tree.leaf_steps)
"""

def __init__(self,
network,
win_sizes=None,
strides=None,
threshold=DEFAULT_THRESHOLD,
by_masking=False):

check_value_type('network', network, nn.Cell)

if win_sizes is not None:
_check_iterable_type('win_sizes', win_sizes, list, int)
if not win_sizes:
raise ValueError('Argument win_sizes is empty.')

for i in range(1, len(win_sizes)):
if win_sizes[i] >= win_sizes[i-1]:
raise ValueError('Argument win_sizes is not strictly descending.')

if win_sizes[-1] <= 0:
raise ValueError('Argument win_sizes has non-positive number.')
elif strides is not None:
raise ValueError('Argument win_sizes cannot be None if strides is not None.')

if strides is not None:
_check_iterable_type('strides', strides, list, int)
for i in range(1, len(strides)):
if strides[i] >= strides[i-1]:
raise ValueError('Argument win_sizes is not strictly descending.')

if strides[-1] <= 0:
raise ValueError('Argument strides has non-positive number.')

if len(strides) != len(win_sizes):
raise ValueError('Length of strides and win_sizes is not equal.')
elif win_sizes is not None:
raise ValueError('Argument strides cannot be None if win_sizes is not None.')

self._network = copy.deepcopy(network)
self._compiled_mask = None
self._threshold = threshold
self._win_sizes = copy.copy(win_sizes) if win_sizes else None
self._strides = copy.copy(strides) if strides else None
self._by_masking = by_masking

@property
def network(self):
"""Get the network."""
return self._network

@property
def by_masking(self):
"""Check if it is masking mode."""
return self._by_masking

@property
def threshold(self):
"""The network output threshold to stop the search."""
return self._threshold

@property
def win_sizes(self):
"""Windows sizes in pixels."""
return self._win_sizes

@property
def strides(self):
"""Strides in pixels."""
return self._strides

@property
def compiled_mask(self):
"""The compiled mask after a successful search() call."""
return self._compiled_mask

def search(self, image, class_idx, mask=None):
"""
Search smallest sufficient/destruction region on an image.

Args:
image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
class_idx (int): Target class index.
mask (Union[str, tuple[float, float, float], float], optional): The mask, type can be
str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
tuple[float, float, float]: RGB solid color mask,
float: Grey scale solid color mask.
None: By auto calculation.

Returns:
tuple[EditStep, list[float]], the root edit step and network output of each layer after applied the
layer steps.

Raise:
TypeError: Be raised for any argument or data type problem.
ValueError: Be raised for any argument or data value problem.
NoValidResultError: Be raised if no valid result was found.
OriginalOutputError: Be raised if network output of the original image is not strictly larger than
the threshold.
"""
check_value_type('image', image, (Tensor, np.ndarray))

if isinstance(image, Tensor):
image = image.asnumpy()

if len(image.shape) == 4:
if image.shape[0] != 1:
raise ValueError("Argument image's batch size is not 1.")
elif len(image.shape) == 3:
image = np.expand_dims(image, axis=0)
else:
raise ValueError("Argument image is not in CHW or NCHW(N=1) format.")

check_value_type('class_idx', class_idx, int)

if class_idx < 0:
raise ValueError("Argument class_idx is less then zero.")

self._compiled_mask = compile_mask(mask, image)

short_side = np.min(image.shape[-2:])
if self._win_sizes is None:
win_sizes, strides = self._auto_win_sizes_strides(short_side)
else:
win_sizes, strides = self._win_sizes, self._strides

if short_side <= win_sizes[0]:
raise ValueError(f"Input image's short side is shorter then or "
f"equals to the first window size:{win_sizes[0]}.")

self._network.set_train(False)

# the search result will be store as a edit tree that attached to the root step.
root_step = EditStep(-1, (0, 0, image.shape[-1], image.shape[-2]))
root_job = _SearchJob(by_masking=self._by_masking,
class_idx=class_idx,
win_sizes=win_sizes,
strides=strides,
layer=0,
search_field=root_step.box,
pre_edit_steps=None,
parent_step=root_step)
self._process_root_job(image, root_job)

# the leaf layer's network output may not meet the threshold,
# we have to cutoff the unqualified layers
layer_count = root_step.max_layer + 1
if layer_count == 0:
raise NoValidResultError("No edit step layer was found.")

# gather the network output of each layer
layer_outputs = [None] * layer_count
for layer in range(layer_count):
steps = root_step.get_layer_or_leaf_steps(layer)
if not steps:
continue
masked_image = EditStep.apply(image, self._compiled_mask, steps, by_masking=self._by_masking)
output = self._network(Tensor(masked_image))
output = output[0, class_idx].asnumpy().item()
layer_outputs[layer] = output

# determine which layer we have to cutoff
cutoff_layer = None
for layer in reversed(range(layer_count)):
if layer_outputs[layer] is not None and self._is_threshold_met(layer_outputs[layer]):
cutoff_layer = layer
break

if cutoff_layer is None or root_step.is_leaf:
raise NoValidResultError(f"No edit step layer's network output meet the threshold: {self._threshold}.")

# cutoff the layer by removing all children of the layer's steps.
steps = root_step.get_layer_steps(cutoff_layer)
for step in steps:
step.remove_all_children()
layer_outputs = layer_outputs[:cutoff_layer + 1]

return root_step, layer_outputs

def _process_root_job(self, sample_input, root_job):
"""
Process job queue.

Args:
sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format.
root_job (_SearchJob): Root search job.
"""
job_queue = [root_job]
while job_queue:
job = job_queue.pop(0)
sub_job_queue = []
job_edit_steps, stop_reason = self._process_job(job, sample_input, sub_job_queue)

if stop_reason in (self._StopReason.THRESHOLD_MET, self._StopReason.STEP_CHANGE_MET):
for step in job_edit_steps:
job.parent_step.add_child(step)
job_queue.extend(sub_job_queue)

def _process_job(self, job, sample_input, job_queue):
"""
Process a job.

Args:
job (_SearchJob): Search job to be processed.
sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format.
job_queue (list[_SearchJob]): Job queue.

Returns:
tuple[list[EditStep], _StopReason], result edit stop and the stop reason.
"""
edit_steps = []

# make the network output with the original image is strictly larger than the threshold
if job.layer == 0:
original_output = self._network(Tensor(sample_input))[0, job.class_idx].asnumpy().item()
if original_output <= self._threshold:
raise OriginalOutputError(f'The original output is not strictly larger the threshold: '
f'{self._threshold}')

# applying the pre-edit steps from the parent steps
if job.pre_edit_steps:
# use the latest leaf steps to increase the accuracy
leaf_steps = []
for step in job.pre_edit_steps:
leaf_steps.extend(step.leaf_steps)
pre_edit_steps = leaf_steps
else:
pre_edit_steps = None
workpiece = EditStep.apply(sample_input,
self._compiled_mask,
pre_edit_steps,
self._by_masking)

job.on_start(sample_input, workpiece, self._compiled_mask, self._network)
start_output = self._network(Tensor(workpiece))[0, job.class_idx].asnumpy().item()
last_output = start_output

# greedy search loop
while True:

if self._is_threshold_met(last_output):
return edit_steps, self._StopReason.THRESHOLD_MET

try:
best_edit = job.find_best_edit()
except _NoNewStepError:
return edit_steps, self._StopReason.NO_NEW_STEP
except _RepeatedStepError:
return edit_steps, self._StopReason.REPEATED_STEP

best_edit.step_change = best_edit.network_output - last_output

if job.layer < job.layer_count - 1 and self._is_greedy(best_edit.step_change):
# create net layer search job if new edit step is valid and not yet reaching
# the final layer
if job.pre_edit_steps:
pre_edit_steps = list(job.pre_edit_steps)
pre_edit_steps.extend(edit_steps)
else:
pre_edit_steps = list(edit_steps)

sub_job = job.create_sub_job(best_edit, pre_edit_steps)
job_queue.append(sub_job)

edit_steps.append(best_edit)

if job.layer > 0:
# stop if the step change meet the parent step change only after layer 0
change = best_edit.network_output - start_output
if self._is_step_change_met(job.parent_step.step_change, change):
return edit_steps, self._StopReason.STEP_CHANGE_MET

last_output = best_edit.network_output

def _is_threshold_met(self, network_output):
"""Check if the threshold was met."""
if self._by_masking:
return network_output <= self._threshold
return network_output >= self._threshold

def _is_step_change_met(self, target, step_change):
"""Check if the change target was met."""
if self._by_masking:
return step_change <= target
return step_change >= target

def _is_greedy(self, step_change):
"""Check if it is a greedy step."""
if self._by_masking:
return step_change < 0
return step_change > 0

@classmethod
def _auto_win_sizes_strides(cls, short_side):
"""
Calculate auto window sizes and strides.

Args:
short_side (int): Length of search space.

Returns:
tuple[list[int], list[int]], window sizes and strides.
"""
win_sizes = []
strides = []
cur_len = int(short_side/AUTO_WIN_SIZE_DIV)
while len(win_sizes) < AUTO_LAYER_MAX and cur_len >= AUTO_WIN_SIZE_MIN:
stride = int(cur_len/AUTO_STRIDE_DIV)
if stride <= 0:
break
win_sizes.append(cur_len)
strides.append(stride)
cur_len = int(cur_len/AUTO_WIN_SIZE_DIV)
if not win_sizes:
raise ValueError(f"Image's short side is less then {AUTO_IMAGE_SHORT_SIDE_MIN}, "
f"unable to calculates auto settings.")
return win_sizes, strides

class _StopReason(Enum):
"""Stop reason of search job."""
THRESHOLD_MET = 0 # threshold was met.
STEP_CHANGE_MET = 1 # parent step change was met.
NO_NEW_STEP = 2 # no new step was found.
REPEATED_STEP = 3 # repeated step was found.


def _check_iterable_type(arg_name, arg_value, container_type, elem_types):
"""Concert iterable argument data type."""
check_value_type(arg_name, arg_value, container_type)
for elem in arg_value:
check_value_type(arg_name + ' element', elem, elem_types)


class _NoNewStepError(Exception):
"""Error for no new step was found."""


class _RepeatedStepError(Exception):
"""Error for repeated step was found."""


class _SearchJob:
"""
Search job.

Args:
by_masking (bool): Whether it is masking mode.
class_idx (int): Target class index.
win_sizes (list[int]): Moving square window size (length of side) of layers.
strides (list[int]): Strides of layers.
layer (int): Layer number.
search_field (tuple[int, int, int, int]): Search field in x, y, width, height format.
pre_edit_steps (list[EditStep], optional): Edit steps to be applied before searching.
parent_step (EditStep): Parent edit step.
batch_size (int): Batch size of batched inferences.
"""

def __init__(self,
by_masking,
class_idx,
win_sizes,
strides,
layer,
search_field,
pre_edit_steps,
parent_step,
batch_size=DEFAULT_BATCH_SIZE):

if layer >= len(win_sizes):
raise ValueError('Layer is larger then number of window sizes.')

self.by_masking = by_masking
self.class_idx = class_idx
self.win_sizes = win_sizes
self.strides = strides
self.layer = layer
self.search_field = search_field
self.pre_edit_steps = pre_edit_steps
self.parent_step = parent_step
self.batch_size = batch_size
self.network = None
self.mask = None
self.original_input = None

self._workpiece = None
self._found_best_edits = None
self._found_uvs = None
self._u_pixels = None
self._v_pixels = None

@property
def layer_count(self):
"""Number of layers."""
return len(self.win_sizes)

def on_start(self, original_input, workpiece, mask, network):
"""
Notification of the start of the search job.

Args:
original_input (numpy.ndarray): The original image tensor in CHW or NCHW(N=1) format.
workpiece (numpy.ndarray): The intermediate image tensor in CHW or NCHW(N=1) format.
mask (Union[tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
tuple[float, float, float]: RGB solid color mask,
float: Grey scale solid color mask.
numpy.ndarray: Image mask, has same format of original_input.
network (nn.Cell): Classification network.
"""
self.original_input = original_input
self.mask = mask
self.network = network

self._workpiece = workpiece
self._found_best_edits = []
self._found_uvs = []
self._u_pixels = self._calc_uv_pixels(self.search_field[0], self.search_field[2])
self._v_pixels = self._calc_uv_pixels(self.search_field[1], self.search_field[3])

def create_sub_job(self, parent_step, pre_edit_steps):
"""Create next layer search job."""
return self.__class__(by_masking=self.by_masking,
class_idx=self.class_idx,
win_sizes=self.win_sizes,
strides=self.strides,
layer=self.layer + 1,
search_field=copy.copy(parent_step.box),
pre_edit_steps=pre_edit_steps,
parent_step=parent_step,
batch_size=self.batch_size)

def find_best_edit(self):
"""
Find the next best edit step.

Returns:
EditStep, the next best edit step.
"""
workpiece = self._workpiece
if len(workpiece.shape) == 3:
workpiece = np.expand_dims(workpiece, axis=0)

# generate input tensors with shifted masked/unmasked region and pack into a batch
squeeze = Squeeze()
best_new_workpiece = None
best_output = None
best_edit = None
best_uv = None
batch = np.repeat(workpiece, repeats=self.batch_size, axis=0)
batch_uvs = []
batch_steps = []
batch_i = 0
win_size = self.win_sizes[self.layer]
for u, x in enumerate(self._u_pixels):
for v, y in enumerate(self._v_pixels):
if (u, v) in self._found_uvs:
continue

edit_step = EditStep(self.layer, (x, y, win_size, win_size))

if self.by_masking:
EditStep.apply(batch[batch_i],
self.mask,
[edit_step],
self.by_masking,
inplace=True)
else:
EditStep.apply(self.original_input,
batch[batch_i],
[edit_step],
self.by_masking,
inplace=True)

batch_i += 1
batch_uvs.append((u, v))
batch_steps.append(edit_step)
if batch_i == self.batch_size:
# the batch is full, inference and empty it
batch_output = self.network(Tensor(batch))
batch_output = batch_output[:, self.class_idx]
if len(batch_output.shape) > 1:
batch_output = squeeze(batch_output)
if self.by_masking:
batch_best_i = np.argmin(batch_output.asnumpy())
else:
batch_best_i = np.argmax(batch_output.asnumpy())
batch_best_output = batch_output[int(batch_best_i)].asnumpy().item()

if best_output is None or self._is_output0_better(batch_best_output, best_output):
best_output = batch_best_output
best_uv = batch_uvs[batch_best_i]
best_edit = batch_steps[batch_best_i]
best_new_workpiece = batch[batch_best_i]

batch = np.repeat(workpiece, repeats=self.batch_size, axis=0)
batch_uvs = []
batch_i = 0

if batch_i > 0:
# don't forget the last half full batch
batch_output = self.network(Tensor(batch))
batch_output = batch_output[:, self.class_idx]
if len(batch_output.shape) > 1:
batch_output = squeeze(batch_output)
if self.by_masking:
batch_best_i = np.argmin(batch_output.asnumpy()[:batch_i, ...])
else:
batch_best_i = np.argmax(batch_output.asnumpy()[:batch_i, ...])

batch_best_output = batch_output[int(batch_best_i)].asnumpy().item()
if best_output is None or self._is_output0_better(batch_best_output, best_output):
best_output = batch_best_output
best_uv = batch_uvs[batch_best_i]
best_edit = batch_steps[batch_best_i]
best_new_workpiece = batch[batch_best_i]

if best_edit is None:
raise _NoNewStepError

if best_uv in self._found_uvs:
raise _RepeatedStepError

self._found_uvs.append(best_uv)
self._found_best_edits.append(best_edit)
best_edit.network_output = best_output

# continue on the best workpiece in the next function call
self._workpiece = best_new_workpiece

return best_edit

def _is_output0_better(self, output0, output1):
"""Check if the network output0 is better."""
if self.by_masking:
return output0 < output1
return output0 > output1

def _calc_uv_pixels(self, begin, length):
"""
Calculate the pixel coordinate of shifts.

Args:
begin (int): The beginning pixel coordinate of search field.
length (int): The length of search field.

Returns:
list[int], pixel coordinate of shifts.
"""
win_size = self.win_sizes[self.layer]
stride = self.strides[self.layer]
shift_count = self._calc_shift_count(length, win_size, stride)
pixels = [0] * shift_count
for i in range(shift_count):
if i == shift_count - 1:
pixels[i] = begin + length - win_size
else:
pixels[i] = begin + i*stride
return pixels

@staticmethod
def _calc_shift_count(length, win_size, stride):
"""
Calculate the number of shifts in search field.

Args:
length (int): The length of search field.
win_size (int): The length of sides of moving window.
stride (int): The stride.

Returns:
int, number of shifts.
"""
if length <= win_size or win_size < stride or stride <= 0:
raise ValueError("Invalid length, win_size or stride.")
count = int(math.ceil((length - win_size)/stride))
if (count - 1)*stride + win_size < length:
return count + 1
return count

+ 2
- 2
mindspore/nn/probability/toolbox/uncertainty_evaluation.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -84,7 +84,7 @@ class UncertaintyEvaluation:
def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1,
epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False):
self.epi_model = model
self.epi_model = deepcopy(model)
self.ale_model = deepcopy(model)
self.epi_train_dataset = train_dataset
self.ale_train_dataset = train_dataset


Loading…
Cancel
Save