Browse Source

!11309 Add Hierarchical Occlusion Counterfactual

From: @ngtony
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
910772cea8
5 changed files with 1240 additions and 29 deletions
  1. +19
    -6
      mindspore/ccsrc/utils/summary.proto
  2. +225
    -21
      mindspore/explainer/_image_classification_runner.py
  3. +15
    -0
      mindspore/explainer/explanation/_counterfactual/__init__.py
  4. +979
    -0
      mindspore/explainer/explanation/_counterfactual/hierarchical_occlusion.py
  5. +2
    -2
      mindspore/nn/probability/toolbox/uncertainty_evaluation.py

+ 19
- 6
mindspore/ccsrc/utils/summary.proto View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -121,7 +121,7 @@ message Explain {
optional string explain_method = 1;
optional int32 label = 2;
optional string heatmap_path = 3;
}
}

message Benchmark{
optional string benchmark_method = 1;
@@ -131,10 +131,21 @@ message Explain {
}

message Metadata{
repeated string label = 1;
repeated string explain_method = 2;
repeated string benchmark_method = 3;
}
repeated string label = 1;
repeated string explain_method = 2;
repeated string benchmark_method = 3;
}

message HocLayer {
optional float prob = 1;
repeated int32 box = 2; // List of repeated x, y, w, h
}

message Hoc {
optional int32 label = 1;
optional string mask = 2;
repeated HocLayer layer = 3;
}

optional int32 sample_id = 1;
optional string image_path = 2; // The Metadata and image path must have one fill in
@@ -146,4 +157,6 @@ message Explain {

optional Metadata metadata = 7;
optional string status = 8; // enum value: run, end

repeated Hoc hoc = 9; // hierarchical occlusion counterfactual
}

+ 225
- 21
mindspore/explainer/_image_classification_runner.py View File

@@ -15,9 +15,11 @@
"""Image Classification Runner."""
import os
import re
import json
from time import time

import numpy as np
from scipy.stats import beta
from PIL import Image

import mindspore as ms
@@ -30,10 +32,15 @@ from mindspore.train._utils import check_value_type
from mindspore.train.summary._summary_adapter import _convert_image_format
from mindspore.train.summary.summary_record import SummaryRecord
from mindspore.train.summary_pb2 import Explain
from .benchmark import Localization
from .explanation import RISE
from .benchmark._attribution.metric import AttributionMetric, LabelSensitiveMetric, LabelAgnosticMetric
from .explanation._attribution.attribution import Attribution
from mindspore.nn.probability.toolbox.uncertainty_evaluation import UncertaintyEvaluation
from mindspore.explainer.benchmark import Localization
from mindspore.explainer.benchmark._attribution.metric import AttributionMetric
from mindspore.explainer.benchmark._attribution.metric import LabelSensitiveMetric
from mindspore.explainer.benchmark._attribution.metric import LabelAgnosticMetric
from mindspore.explainer.explanation import RISE
from mindspore.explainer.explanation._attribution.attribution import Attribution
from mindspore.explainer.explanation._counterfactual import hierarchical_occlusion as hoc


_EXPAND_DIMS = ExpandDims()

@@ -97,6 +104,8 @@ class ImageClassificationRunner:
_DATAFILE_DIRNAME_PREFIX = "_explain_"
_ORIGINAL_IMAGE_DIRNAME = "origin_images"
_HEATMAP_DIRNAME = "heatmap"
# specfial filenames
_MANIFEST_FILENAME = "manifest.json"
# max. no. of sample per directory
_SAMPLE_PER_DIR = 1000
# seed for fixing the iterating order of the dataset
@@ -132,11 +141,15 @@ class ImageClassificationRunner:
self._network = network
self._explainers = None
self._benchmarkers = None
self._uncertainty = None
self._hoc_searcher = None
self._summary_timestamp = None
self._sample_index = -1

self._full_network = SequentialCell([self._network, activation_fn])

self._manifest = None

self._verify_data_n_settings(check_data_n_network=True)

def register_saliency(self,
@@ -185,6 +198,50 @@ class ImageClassificationRunner:
self._benchmarkers = None
raise

def register_hierarchical_occlusion(self):
"""
Register hierarchical occlusion instances.

Notes:
Input images are required to be in 3 channels formats and the length of side short must be equals to or
greater than 56 pixels.

Raises:
ValueError: Be raised for any data or settings' value problem.
RuntimeError: Be raised if the function was called already.
"""
if self._hoc_searcher is not None:
raise RuntimeError("Function register_hierarchical_occlusion() was invoked already.")

self._hoc_searcher = hoc.Searcher(self._full_network)

try:
self._verify_data_n_settings(check_hoc=True)
except ValueError:
self._hoc_searcher = None
raise

def register_uncertainty(self):
"""
Register uncertainty instance to compute the epistemic uncertainty base on the Bayes' theorem.

Notes:
Please refer to the documentation of mindspore.nn.probability.toolbox.uncertainty_evaluation for the
details. The actual output is standard deviation of the classification predictions and the corresponding
95% confidence intervals. Users have to invoke register_saliency() as well for the uncertainty results are
going to be shown on the saliency map page in MindInsight.

Raises:
RuntimeError: Be raised if the function was called already.
"""
if self._uncertainty is not None:
raise RuntimeError("Function register_uncertainty() was invoked already.")

self._uncertainty = UncertaintyEvaluation(model=self._full_network,
train_dataset=None,
task_type='classification',
num_classes=len(self._labels))

def run(self):
"""
Run the explain job and save the result as a summary in summary_dir.
@@ -198,7 +255,10 @@ class ImageClassificationRunner:
RuntimeError: Be raised for any runtime problem.
"""
self._verify_data_n_settings(check_all=True)

self._manifest = {"saliency_map": False,
"benchmark": False,
"uncertainty": False,
"hierarchical_occlusion": False}
with SummaryRecord(self._summary_dir, raise_exception=True) as summary:
print("Start running and writing......")
begin = time()
@@ -214,13 +274,25 @@ class ImageClassificationRunner:
if self._is_saliency_registered:
self._run_saliency(summary, imageid_labels)

self._save_manifest()

print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin))

@property
def _is_hoc_registered(self):
"""Check if HOC module is registered."""
return self._hoc_searcher is not None

@property
def _is_saliency_registered(self):
"""Check if saliency module is registered."""
return bool(self._explainers)

@property
def _is_uncertainty_registered(self):
"""Check if uncertainty module is registered."""
return self._uncertainty is not None

def _save_metadata(self, summary):
"""Save metadata of the explain job to summary."""
print("Start writing metadata......")
@@ -245,12 +317,13 @@ class ImageClassificationRunner:
Run inference for the dataset and write the inference related data into summary.

Args:
summary (SummaryRecord): The summary object to store the data
summary (SummaryRecord): The summary object to store the data.
threshold (float): The threshold for prediction.

Returns:
dict, The map of sample d to the union of its ground truth and predicted labels.
"""
has_uncertainty_rec = False
sample_id_labels = {}
self._sample_index = 0
ds.config.set_seed(self._DATASET_SEED)
@@ -259,10 +332,20 @@ class ImageClassificationRunner:
inputs, labels, _ = self._unpack_next_element(next_element)
prob = self._full_network(inputs).asnumpy()

if self._uncertainty is not None:
prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
else:
prob_var = None

for idx, inp in enumerate(inputs):
gt_labels = labels[idx]
gt_probs = [float(prob[idx][i]) for i in gt_labels]

if prob_var is not None:
gt_prob_vars = [float(prob_var[idx][i]) for i in gt_labels]
gt_itl_lows, gt_itl_his, gt_prob_sds = \
self._calc_beta_intervals(gt_probs, gt_prob_vars)

data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
original_image = _np_to_image(_normalize(data_np), mode='RGB')
original_image_path = self._save_original_image(self._sample_index, original_image)
@@ -270,6 +353,11 @@ class ImageClassificationRunner:
predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
predicted_probs = [float(prob[idx][i]) for i in predicted_labels]

if prob_var is not None:
predicted_prob_vars = [float(prob_var[idx][i]) for i in predicted_labels]
predicted_itl_lows, predicted_itl_his, predicted_prob_sds = \
self._calc_beta_intervals(predicted_probs, predicted_prob_vars)

union_labs = list(set(gt_labels + predicted_labels))
sample_id_labels[str(self._sample_index)] = union_labs

@@ -285,14 +373,29 @@ class ImageClassificationRunner:
explain.inference.predicted_label.extend(predicted_labels)
explain.inference.predicted_prob.extend(predicted_probs)

summary.add_value("explainer", "inference", explain)
if prob_var is not None:
explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
explain.inference.ground_truth_prob_itl95_low.extend(gt_itl_lows)
explain.inference.ground_truth_prob_itl95_hi.extend(gt_itl_his)
explain.inference.predicted_prob_sd.extend(predicted_prob_sds)
explain.inference.predicted_prob_itl95_low.extend(predicted_itl_lows)
explain.inference.predicted_prob_itl95_hi.extend(predicted_itl_his)

has_uncertainty_rec = True

summary.add_value("explainer", "inference", explain)
summary.record(1)

if self._is_hoc_registered:
self._run_hoc(summary, self._sample_index, inputs[idx], prob[idx])

self._sample_index += 1
self._spaced_print("Finish running and writing {}-th batch inference data."
" Time elapsed: {:.3f} s".format(j, time() - now),
end='')
" Time elapsed: {:.3f} s".format(j, time() - now))

if has_uncertainty_rec:
self._manifest["uncertainty"] = True

return sample_id_labels

def _run_saliency(self, summary, sample_id_labels):
@@ -306,10 +409,10 @@ class ImageClassificationRunner:
for idx, next_element in enumerate(self._dataset):
now = time()
self._spaced_print("Start running {}-th explanation data for {}......".format(
idx, exp.__class__.__name__), end='')
idx, exp.__class__.__name__))
self._run_exp_step(next_element, exp, sample_id_labels, summary)
self._spaced_print("Finish writing {}-th explanation data for {}. Time elapsed: "
"{:.3f} s".format(idx, exp.__class__.__name__, time() - now), end='')
"{:.3f} s".format(idx, exp.__class__.__name__, time() - now))
self._spaced_print(
"Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format(
exp.__class__.__name__, time() - start))
@@ -326,20 +429,20 @@ class ImageClassificationRunner:
for idx, next_element in enumerate(self._dataset):
now = time()
self._spaced_print("Start running {}-th explanation data for {}......".format(
idx, exp.__class__.__name__), end='')
idx, exp.__class__.__name__))
saliency_dict_lst = self._run_exp_step(next_element, exp, sample_id_labels, summary)
self._spaced_print(
"Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
idx, exp.__class__.__name__, time() - now), end='')
idx, exp.__class__.__name__, time() - now))
for bench in self._benchmarkers:
now = time()
self._spaced_print(
"Start running {}-th batch {} data for {}......".format(
idx, bench.__class__.__name__, exp.__class__.__name__), end='')
idx, bench.__class__.__name__, exp.__class__.__name__))
self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
self._spaced_print(
"Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
idx, bench.__class__.__name__, exp.__class__.__name__, time() - now), end='')
idx, bench.__class__.__name__, exp.__class__.__name__, time() - now))

for bench in self._benchmarkers:
benchmark = explain.benchmark.add()
@@ -355,6 +458,52 @@ class ImageClassificationRunner:
summary.add_value('explainer', 'benchmark', explain)
summary.record(1)

def _run_hoc(self, summary, sample_id, sample_input, prob):
"""
Run HOC search for a sample image, and then save the result to summary.

Args:
summary (SummaryRecord): The summary object to store the data.
sample_id (int): The sample ID.
sample_input (Union[Tensor, np.ndarray]): Sample image tensor in CHW or NCWH(N=1).
prob (Union[Tensor, np.ndarray]): List of sample's classification prediction output, HOC will run for
labels with prediction output strictly larger then HOC searcher's threshold(0.5 by default).
"""
if isinstance(sample_input, ms.Tensor):
sample_input = sample_input.asnumpy()
if len(sample_input.shape) == 3:
sample_input = np.expand_dims(sample_input, axis=0)
has_rec = False
explain = Explain()
explain.sample_id = sample_id
str_mask = hoc.auto_str_mask(sample_input)
compiled_mask = None
for label_idx, label_prob in enumerate(prob):
if label_prob > self._hoc_searcher.threshold:
if compiled_mask is None:
compiled_mask = hoc.compile_mask(str_mask, sample_input)
try:
edit_tree, layer_outputs = self._hoc_searcher.search(sample_input, label_idx, compiled_mask)
except hoc.NoValidResultError as ex:
log.error(f"HOC cannot find result for sample:{sample_id} error:{ex}")
continue
has_rec = True
hoc_rec = explain.hoc.add()
hoc_rec.label = label_idx
hoc_rec.mask = str_mask
layer_count = edit_tree.max_layer + 1
for layer in range(layer_count):
steps = edit_tree.get_layer_or_leaf_steps(layer)
layer_output = layer_outputs[layer]
hoc_layer = hoc_rec.layer.add()
hoc_layer.prob = layer_output
for step in steps:
hoc_layer.box.extend(list(step.box))
if has_rec:
summary.add_value("explainer", "hoc", explain)
summary.record(1)
self._manifest['hierarchical_occlusion'] = True

def _run_exp_step(self, next_element, explainer, sample_id_labels, summary):
"""
Run the explanation for each step and write explanation results into summary.
@@ -368,6 +517,7 @@ class ImageClassificationRunner:
Returns:
list, List of dict that maps label to its corresponding saliency map.
"""
has_saliency_rec = False
inputs, labels, _ = self._unpack_next_element(next_element)
sample_index = self._sample_index
unions = []
@@ -406,11 +556,17 @@ class ImageClassificationRunner:
explanation.heatmap_path = heatmap_path
explanation.label = lab

has_saliency_rec = True

summary.add_value("explainer", "explanation", explain)
summary.record(1)

self._sample_index += 1
saliency_dict_lst.append(saliency_dict)

if has_saliency_rec:
self._manifest['saliency_map'] = True

return saliency_dict_lst

def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
@@ -436,6 +592,26 @@ class ImageClassificationRunner:
else:
raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
'receive {}'.format(type(benchmarker)))
self._manifest['benchmark'] = True

@staticmethod
def _calc_beta_intervals(means, variances, prob=0.95):
"""Calculate confidence interval of beta distributions."""
if not isinstance(means, np.ndarray):
means = np.array(means)
if not isinstance(variances, np.ndarray):
variances = np.array(variances)
with np.errstate(divide='ignore'):
coef_a = ((means ** 2) * (1 - means) / variances) - means
coef_b = (coef_a * (1 - means)) / means
itl_lows, itl_his = beta.interval(prob, coef_a, coef_b)
sds = np.sqrt(variances)
for i in range(itl_lows.shape[0]):
if not np.isfinite(sds[i]) or not np.isfinite(itl_lows[i]) or not np.isfinite(itl_his[i]):
itl_lows[i] = means[i]
itl_his[i] = means[i]
sds[i] = 0
return itl_lows, itl_his, sds

def _verify_data(self):
"""Verify dataset and labels."""
@@ -474,6 +650,17 @@ class ImageClassificationRunner:
"Labels shape {} is unrecognizable: outputs should not have more than two dimensions"
" with length greater than 1.".format(labels.shape))

if self._is_hoc_registered:
if inputs.shape[-3] != 3:
raise ValueError(
"Hierarchical occlusion is registered, images must be in 3 channels format, but "
"{} channels is encountered.".format(inputs.shape[-3]))
short_side = min(inputs.shape[-2:])
if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN:
raise ValueError(
"Hierarchical occlusion is registered, images' short side must be equals to or greater then "
"{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side))

def _verify_network(self):
"""Verify the network."""
label_set = set()
@@ -521,7 +708,8 @@ class ImageClassificationRunner:
check_all=False,
check_registration=False,
check_data_n_network=False,
check_saliency=False):
check_saliency=False,
check_hoc=False):
"""
Verify the validity of dataset and other settings.

@@ -530,6 +718,7 @@ class ImageClassificationRunner:
check_registration (bool): Set it True for checking registrations, check if it is enough to invoke run().
check_data_n_network (bool): Set it True for checking data and network.
check_saliency (bool): Set it True for checking saliency related settings.
check_hoc (bool): Set it True for checking HOC related settings.

Raises:
ValueError: Be raised for any data or settings' value problem.
@@ -539,13 +728,16 @@ class ImageClassificationRunner:
check_registration = True
check_data_n_network = True
check_saliency = True
check_hoc = True

if check_registration:
if not self._is_saliency_registered:
raise ValueError("No explanation module was registered, user should at least call register_saliency()"
" once with proper explanation instances")
if not self._is_saliency_registered and not self._is_hoc_registered:
raise ValueError("No explanation module was registered, user should at least call register_saliency() "
"or register_hierarchical_occlusion() once with proper arguments")
if self._is_uncertainty_registered and not self._is_saliency_registered:
raise ValueError("Function register_uncertainty() is invoked but register_saliency() is not.")

if check_data_n_network or check_saliency:
if check_data_n_network or check_saliency or check_hoc:
self._verify_data()

if check_data_n_network:
@@ -658,6 +850,18 @@ class ImageClassificationRunner:

return ms.Tensor(batch_labels, ms.int32)

def _save_manifest(self):
"""Save manifest.json underneath datafile directory."""
if self._manifest is None:
raise RuntimeError("Manifest not yet be initialized.")
path_tokens = [self._summary_dir,
self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp)]
abs_dir_path = self._create_subdir(*path_tokens)
save_path = os.path.join(abs_dir_path, self._MANIFEST_FILENAME)
with open(save_path, 'w') as file:
json.dump(self._manifest, file, indent=4)
os.chmod(save_path, self._FILE_MODE)

def _save_original_image(self, sample_id, image):
"""Save an image to summary directory."""
id_dirname = self._get_sample_dirname(sample_id)
@@ -720,7 +924,7 @@ class ImageClassificationRunner:
return None

@classmethod
def _spaced_print(cls, message, *args, **kwargs):
def _spaced_print(cls, message):
"""Spaced message printing."""
# workaround to print logs starting new line in case line width mismatch.
print(cls._SPACER.format(message))

+ 15
- 0
mindspore/explainer/explanation/_counterfactual/__init__.py View File

@@ -0,0 +1,15 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Counterfactual modules."""

+ 979
- 0
mindspore/explainer/explanation/_counterfactual/hierarchical_occlusion.py View File

@@ -0,0 +1,979 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Hierarchical occlusion edit tree searcher."""
from enum import Enum
import copy
import re
import math

import numpy as np
from scipy.ndimage import gaussian_filter

from mindspore import nn
from mindspore import Tensor
from mindspore.ops import Squeeze
from mindspore.train._utils import check_value_type


AUTO_LAYER_MAX = 3 # maximum number of layer by auto settings
AUTO_WIN_SIZE_MIN = 28 # minimum window size by auto settings
AUTO_WIN_SIZE_DIV = 2 # denominator of windows size calculations by auto settings
AUTO_STRIDE_DIV = 5 # denominator of stride calculations by auto settings
AUTO_MASK_GAUSSIAN_RADIUS_DIV = 25 # denominator of gaussian mask radius calculations by auto settings
DEFAULT_THRESHOLD = 0.5 # default target prediction threshold
DEFAULT_BATCH_SIZE = 64 # default batch size for batch inference search
MASK_GAUSSIAN_RE = r'^gaussian:(\d+)$' # gaussian mask string pattern

# minimum length of input images' short side with auto settings
AUTO_IMAGE_SHORT_SIDE_MIN = AUTO_WIN_SIZE_MIN * AUTO_WIN_SIZE_DIV


def is_valid_str_mask(mask):
"""Check if it is a valid string mask."""
check_value_type('mask', mask, str)
match = re.match(MASK_GAUSSIAN_RE, mask)
return match and int(match.group(1)) > 0


def compile_mask(mask, image):
"""Compile mask to a ready to use object."""
if mask is None:
return compile_str_mask(auto_str_mask(image), image)
check_value_type('mask', mask, (str, tuple, float, np.ndarray))
if isinstance(mask, str):
return compile_str_mask(mask, image)

if isinstance(mask, tuple):
_check_iterable_type('mask', mask, tuple, float)
elif isinstance(mask, np.ndarray):
if len(image.shape) == 4 and len(mask.shape) == 3:
mask = np.expand_dims(mask, axis=0)
elif len(image.shape) == 3 and len(mask.shape) == 4 and mask.shape[0] == 1:
mask = mask.squeeze(0)
if image.shape != mask.shape:
raise ValueError("Image and mask is not match in shape.")
return mask


def auto_str_mask(image):
"""Generate auto string mask for the image."""
check_value_type('image', image, np.ndarray)
short_side = np.min(image.shape[-2:])
radius = int(round(short_side/AUTO_MASK_GAUSSIAN_RADIUS_DIV))
if radius == 0:
raise ValueError(f"Input image's short side:{short_side} is too small for auto mask, "
f"at least {AUTO_MASK_GAUSSIAN_RADIUS_DIV}pixels is required.")
return f'gaussian:{radius}'


def compile_str_mask(mask, image):
"""Concert string mask to numpy.ndarray."""
check_value_type('mask', mask, str)
check_value_type('image', image, np.ndarray)
match = re.match(MASK_GAUSSIAN_RE, mask)
if match:
radius = int(match.group(1))
if radius > 0:
sigma = [0] * len(image.shape)
sigma[-2] = radius
sigma[-1] = radius
return gaussian_filter(image, sigma=sigma, mode='nearest')
raise ValueError(f"Invalid string mask: '{mask}'.")


class EditStep:
"""
Edit step that describes a box region, also represents an edit tree.

Args:
layer (int): Layer number, -1 is root layer, 0 or above is normal edit layer.
box (tuple[int, int, int, int]): Tuple of x, y, width, height.
"""
def __init__(self, layer, box):
self.layer = layer
self.box = box
self.network_output = 0
self.step_change = 0
self.children = None

@property
def x(self):
"""X-coordinate of the box."""
return self.box[0]

@property
def y(self):
"""Y-coordinate of the box."""
return self.box[1]

@property
def width(self):
"""Width of the box."""
return self.box[2]

@property
def height(self):
"""Height of the box."""
return self.box[3]

@property
def is_leaf(self):
"""Returns True if no child edit step."""
return not self.children

@property
def leaf_steps(self):
"""Returns all leaf edit steps in the tree."""
if self.is_leaf:
return [self]
steps = []
for child in self.children:
steps.extend(child.leaf_steps)
return steps

@property
def max_layer(self):
"""Maximum layer number in the edit tree."""
if self.is_leaf:
return self.layer
layer = self.layer
for child in self.children:
child_max_layer = child.max_layer
if child_max_layer > layer:
layer = child_max_layer
return layer

def add_child(self, child):
"""Add a child edit step."""
if self.children is None:
self.children = [child]
else:
self.children.append(child)

def remove_all_children(self):
"""Remove all child steps."""
self.children = None

def get_layer_or_leaf_steps(self, layer):
"""Get all edit steps of the layer and all leaf edit steps above the layer."""
if self.layer == layer or (self.layer < layer and self.is_leaf):
return [self]
steps = []
if self.layer < layer and self.children:
for child in self.children:
steps.extend(child.get_layer_or_leaf_steps(layer))
return steps

def get_layer_steps(self, layer):
"""Get all edit steps of the layer."""
if self.layer == layer:
return [self]
steps = []
if self.layer < layer and self.children:
for child in self.children:
steps.extend(child.get_layer_steps(layer))
return steps

@classmethod
def apply(cls,
image,
mask,
edit_steps,
by_masking=False,
inplace=False):
"""
Apply edit steps.

Args:
image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
tuple[float, float, float]: RGB solid color mask,
float: Grey scale solid color mask.
numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
edit_steps (list[EditStep], optional): Edit steps to be applied.
by_masking (bool): Whether it is masking mode.
inplace (bool): Whether the modification is going to take place in the input image tensor. False to
construct a new image tensor as result.

Returns:
numpy.ndarray, the result image tensor.

Raises:
TypeError: Be raised for any argument or data type problem.
ValueError: Be raised for any argument or data value problem.
"""
if by_masking:
return cls.apply_masking(image, mask, edit_steps, inplace)
return cls.apply_unmasking(image, mask, edit_steps, inplace)

@staticmethod
def apply_masking(image,
mask,
edit_steps,
inplace=False):
"""
Apply edit steps in masking mode.

Args:
image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
tuple[float, float, float]: RGB solid color mask,
float: Grey scale solid color mask.
numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
edit_steps (list[EditStep], optional): Edit steps to be applied.
inplace (bool): Whether the modification is going to take place in the input image tensor. False to
construct a new image tensor as result.

Returns:
numpy.ndarray, the result image tensor.

Raises:
TypeError: Be raised for any argument or data type problem.
ValueError: Be raised for any argument or data value problem.
"""
check_value_type('image', image, np.ndarray)
check_value_type('mask', mask, (str, tuple, float, np.ndarray))
if isinstance(mask, tuple):
_check_iterable_type('mask', mask, tuple, float)

if edit_steps is not None:
_check_iterable_type('edit_steps', edit_steps, (tuple, list), EditStep)

mask = compile_mask(mask, image)

if inplace:
background = image
else:
background = np.copy(image)

if not edit_steps:
return background

for step in edit_steps:

x_max = step.x + step.width
y_max = step.y + step.height

if x_max > background.shape[-1]:
x_max = background.shape[-1]

if y_max > background.shape[-2]:
y_max = background.shape[-2]

if x_max <= step.x or y_max <= step.y:
continue

if isinstance(mask, np.ndarray):
background[..., step.y:y_max, step.x:x_max] = mask[..., step.y:y_max, step.x:x_max]
else:
if isinstance(mask, (int, float)):
mask = (mask, mask, mask)
for c in range(3):
background[..., c, step.y:y_max, step.x:x_max] = mask[c]
return background

@staticmethod
def apply_unmasking(image,
mask,
edit_steps,
inplace=False):
"""
Apply edit steps in unmasking mode.

Args:
image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
tuple[float, float, float]: RGB solid color mask,
float: Grey scale solid color mask.
numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
edit_steps (list[EditStep]): Edit steps to be applied.
inplace (bool): Whether the modification is going to take place in the input mask tensor. False to
construct a new image tensor as result.

Returns:
numpy.ndarray, the result image tensor.

Raises:
TypeError: Be raised for any argument or data type problem.
ValueError: Be raised for any argument or data value problem.
"""
check_value_type('image', image, np.ndarray)
check_value_type('mask', mask, (str, tuple, float, np.ndarray))
if isinstance(mask, tuple):
_check_iterable_type('mask', mask, tuple, float)

if edit_steps is not None:
_check_iterable_type('edit_steps', edit_steps, (tuple, list), EditStep)

mask = compile_mask(mask, image)

if isinstance(mask, np.ndarray):
if inplace:
background = mask
else:
background = np.copy(mask)
else:
if inplace:
raise ValueError('Inplace cannot be True when mask is not a numpy.ndarray')

background = np.zeros_like(image)
if isinstance(mask, (int, float)):
background.fill(mask)
else:
for c in range(3):
background[..., c, :, :] = mask[c]

if not edit_steps:
return background

for step in edit_steps:

x_max = step.x + step.width
y_max = step.y + step.height

if x_max > background.shape[-1]:
x_max = background.shape[-1]

if y_max > background.shape[-2]:
y_max = background.shape[-2]

if x_max <= step.x or y_max <= step.y:
continue

background[..., step.y:y_max, step.x:x_max] = image[..., step.y:y_max, step.x:x_max]

return background


class NoValidResultError(RuntimeError):
"""Error for no edit step layer's network output meet the threshold."""


class OriginalOutputError(RuntimeError):
"""Error for network output of the original image is not strictly larger than the threshold."""


class Searcher:
"""
Edit step searcher.

Args:
network (Cell): Image tensor in CHW or NCHW(N=1) format.
win_sizes (Union(list[int], optional): Moving square window size (length of side) of layers,
None means by auto calcuation.
strides (Union(list[int], optional): Stride of layers, None means by auto calcuation.
threshold (float): Threshold network output value of the target class.
by_masking (bool): Whether it is masking mode.

Examples:
>>> from mindspore import nn
>>> from mindspore.explainer.explanation._counterfactual.hierarchical_occlusion import Searcher, EditStep
>>>
>>> from user_defined import load_network, load_sample_image
>>>
>>>
>>> network = nn.SequentialCell([load_network(), nn.Sigmoid()])
>>>
>>> # single image in CHW or NCHW(N=1) numpy.ndarray tensor, typical dimension is 224x224
>>> image = load_sample_image()
>>> # target class index
>>> class_idx = 5
>>>
>>> # by default, maximum 3 search layers, auto calculate window sizes and strides
>>> searcher = Searcher(network)
>>>
>>> edit_tree, layer_outputs = searcher.search(image, class_idx)
>>> # get the outcome image of the deepest layer in CHW(or NCHW(N=1) if input image is NCHW) format
>>> outcome = EditStep.apply(image, searcher.compiled_mask, edit_tree.leaf_steps)
"""

def __init__(self,
network,
win_sizes=None,
strides=None,
threshold=DEFAULT_THRESHOLD,
by_masking=False):

check_value_type('network', network, nn.Cell)

if win_sizes is not None:
_check_iterable_type('win_sizes', win_sizes, list, int)
if not win_sizes:
raise ValueError('Argument win_sizes is empty.')

for i in range(1, len(win_sizes)):
if win_sizes[i] >= win_sizes[i-1]:
raise ValueError('Argument win_sizes is not strictly descending.')

if win_sizes[-1] <= 0:
raise ValueError('Argument win_sizes has non-positive number.')
elif strides is not None:
raise ValueError('Argument win_sizes cannot be None if strides is not None.')

if strides is not None:
_check_iterable_type('strides', strides, list, int)
for i in range(1, len(strides)):
if strides[i] >= strides[i-1]:
raise ValueError('Argument win_sizes is not strictly descending.')

if strides[-1] <= 0:
raise ValueError('Argument strides has non-positive number.')

if len(strides) != len(win_sizes):
raise ValueError('Length of strides and win_sizes is not equal.')
elif win_sizes is not None:
raise ValueError('Argument strides cannot be None if win_sizes is not None.')

self._network = copy.deepcopy(network)
self._compiled_mask = None
self._threshold = threshold
self._win_sizes = copy.copy(win_sizes) if win_sizes else None
self._strides = copy.copy(strides) if strides else None
self._by_masking = by_masking

@property
def network(self):
"""Get the network."""
return self._network

@property
def by_masking(self):
"""Check if it is masking mode."""
return self._by_masking

@property
def threshold(self):
"""The network output threshold to stop the search."""
return self._threshold

@property
def win_sizes(self):
"""Windows sizes in pixels."""
return self._win_sizes

@property
def strides(self):
"""Strides in pixels."""
return self._strides

@property
def compiled_mask(self):
"""The compiled mask after a successful search() call."""
return self._compiled_mask

def search(self, image, class_idx, mask=None):
"""
Search smallest sufficient/destruction region on an image.

Args:
image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
class_idx (int): Target class index.
mask (Union[str, tuple[float, float, float], float], optional): The mask, type can be
str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
tuple[float, float, float]: RGB solid color mask,
float: Grey scale solid color mask.
None: By auto calculation.

Returns:
tuple[EditStep, list[float]], the root edit step and network output of each layer after applied the
layer steps.

Raise:
TypeError: Be raised for any argument or data type problem.
ValueError: Be raised for any argument or data value problem.
NoValidResultError: Be raised if no valid result was found.
OriginalOutputError: Be raised if network output of the original image is not strictly larger than
the threshold.
"""
check_value_type('image', image, (Tensor, np.ndarray))

if isinstance(image, Tensor):
image = image.asnumpy()

if len(image.shape) == 4:
if image.shape[0] != 1:
raise ValueError("Argument image's batch size is not 1.")
elif len(image.shape) == 3:
image = np.expand_dims(image, axis=0)
else:
raise ValueError("Argument image is not in CHW or NCHW(N=1) format.")

check_value_type('class_idx', class_idx, int)

if class_idx < 0:
raise ValueError("Argument class_idx is less then zero.")

self._compiled_mask = compile_mask(mask, image)

short_side = np.min(image.shape[-2:])
if self._win_sizes is None:
win_sizes, strides = self._auto_win_sizes_strides(short_side)
else:
win_sizes, strides = self._win_sizes, self._strides

if short_side <= win_sizes[0]:
raise ValueError(f"Input image's short side is shorter then or "
f"equals to the first window size:{win_sizes[0]}.")

self._network.set_train(False)

# the search result will be store as a edit tree that attached to the root step.
root_step = EditStep(-1, (0, 0, image.shape[-1], image.shape[-2]))
root_job = _SearchJob(by_masking=self._by_masking,
class_idx=class_idx,
win_sizes=win_sizes,
strides=strides,
layer=0,
search_field=root_step.box,
pre_edit_steps=None,
parent_step=root_step)
self._process_root_job(image, root_job)

# the leaf layer's network output may not meet the threshold,
# we have to cutoff the unqualified layers
layer_count = root_step.max_layer + 1
if layer_count == 0:
raise NoValidResultError("No edit step layer was found.")

# gather the network output of each layer
layer_outputs = [None] * layer_count
for layer in range(layer_count):
steps = root_step.get_layer_or_leaf_steps(layer)
if not steps:
continue
masked_image = EditStep.apply(image, self._compiled_mask, steps, by_masking=self._by_masking)
output = self._network(Tensor(masked_image))
output = output[0, class_idx].asnumpy().item()
layer_outputs[layer] = output

# determine which layer we have to cutoff
cutoff_layer = None
for layer in reversed(range(layer_count)):
if layer_outputs[layer] is not None and self._is_threshold_met(layer_outputs[layer]):
cutoff_layer = layer
break

if cutoff_layer is None or root_step.is_leaf:
raise NoValidResultError(f"No edit step layer's network output meet the threshold: {self._threshold}.")

# cutoff the layer by removing all children of the layer's steps.
steps = root_step.get_layer_steps(cutoff_layer)
for step in steps:
step.remove_all_children()
layer_outputs = layer_outputs[:cutoff_layer + 1]

return root_step, layer_outputs

def _process_root_job(self, sample_input, root_job):
"""
Process job queue.

Args:
sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format.
root_job (_SearchJob): Root search job.
"""
job_queue = [root_job]
while job_queue:
job = job_queue.pop(0)
sub_job_queue = []
job_edit_steps, stop_reason = self._process_job(job, sample_input, sub_job_queue)

if stop_reason in (self._StopReason.THRESHOLD_MET, self._StopReason.STEP_CHANGE_MET):
for step in job_edit_steps:
job.parent_step.add_child(step)
job_queue.extend(sub_job_queue)

def _process_job(self, job, sample_input, job_queue):
"""
Process a job.

Args:
job (_SearchJob): Search job to be processed.
sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format.
job_queue (list[_SearchJob]): Job queue.

Returns:
tuple[list[EditStep], _StopReason], result edit stop and the stop reason.
"""
edit_steps = []

# make the network output with the original image is strictly larger than the threshold
if job.layer == 0:
original_output = self._network(Tensor(sample_input))[0, job.class_idx].asnumpy().item()
if original_output <= self._threshold:
raise OriginalOutputError(f'The original output is not strictly larger the threshold: '
f'{self._threshold}')

# applying the pre-edit steps from the parent steps
if job.pre_edit_steps:
# use the latest leaf steps to increase the accuracy
leaf_steps = []
for step in job.pre_edit_steps:
leaf_steps.extend(step.leaf_steps)
pre_edit_steps = leaf_steps
else:
pre_edit_steps = None
workpiece = EditStep.apply(sample_input,
self._compiled_mask,
pre_edit_steps,
self._by_masking)

job.on_start(sample_input, workpiece, self._compiled_mask, self._network)
start_output = self._network(Tensor(workpiece))[0, job.class_idx].asnumpy().item()
last_output = start_output

# greedy search loop
while True:

if self._is_threshold_met(last_output):
return edit_steps, self._StopReason.THRESHOLD_MET

try:
best_edit = job.find_best_edit()
except _NoNewStepError:
return edit_steps, self._StopReason.NO_NEW_STEP
except _RepeatedStepError:
return edit_steps, self._StopReason.REPEATED_STEP

best_edit.step_change = best_edit.network_output - last_output

if job.layer < job.layer_count - 1 and self._is_greedy(best_edit.step_change):
# create net layer search job if new edit step is valid and not yet reaching
# the final layer
if job.pre_edit_steps:
pre_edit_steps = list(job.pre_edit_steps)
pre_edit_steps.extend(edit_steps)
else:
pre_edit_steps = list(edit_steps)

sub_job = job.create_sub_job(best_edit, pre_edit_steps)
job_queue.append(sub_job)

edit_steps.append(best_edit)

if job.layer > 0:
# stop if the step change meet the parent step change only after layer 0
change = best_edit.network_output - start_output
if self._is_step_change_met(job.parent_step.step_change, change):
return edit_steps, self._StopReason.STEP_CHANGE_MET

last_output = best_edit.network_output

def _is_threshold_met(self, network_output):
"""Check if the threshold was met."""
if self._by_masking:
return network_output <= self._threshold
return network_output >= self._threshold

def _is_step_change_met(self, target, step_change):
"""Check if the change target was met."""
if self._by_masking:
return step_change <= target
return step_change >= target

def _is_greedy(self, step_change):
"""Check if it is a greedy step."""
if self._by_masking:
return step_change < 0
return step_change > 0

@classmethod
def _auto_win_sizes_strides(cls, short_side):
"""
Calculate auto window sizes and strides.

Args:
short_side (int): Length of search space.

Returns:
tuple[list[int], list[int]], window sizes and strides.
"""
win_sizes = []
strides = []
cur_len = int(short_side/AUTO_WIN_SIZE_DIV)
while len(win_sizes) < AUTO_LAYER_MAX and cur_len >= AUTO_WIN_SIZE_MIN:
stride = int(cur_len/AUTO_STRIDE_DIV)
if stride <= 0:
break
win_sizes.append(cur_len)
strides.append(stride)
cur_len = int(cur_len/AUTO_WIN_SIZE_DIV)
if not win_sizes:
raise ValueError(f"Image's short side is less then {AUTO_IMAGE_SHORT_SIDE_MIN}, "
f"unable to calculates auto settings.")
return win_sizes, strides

class _StopReason(Enum):
"""Stop reason of search job."""
THRESHOLD_MET = 0 # threshold was met.
STEP_CHANGE_MET = 1 # parent step change was met.
NO_NEW_STEP = 2 # no new step was found.
REPEATED_STEP = 3 # repeated step was found.


def _check_iterable_type(arg_name, arg_value, container_type, elem_types):
"""Concert iterable argument data type."""
check_value_type(arg_name, arg_value, container_type)
for elem in arg_value:
check_value_type(arg_name + ' element', elem, elem_types)


class _NoNewStepError(Exception):
"""Error for no new step was found."""


class _RepeatedStepError(Exception):
"""Error for repeated step was found."""


class _SearchJob:
"""
Search job.

Args:
by_masking (bool): Whether it is masking mode.
class_idx (int): Target class index.
win_sizes (list[int]): Moving square window size (length of side) of layers.
strides (list[int]): Strides of layers.
layer (int): Layer number.
search_field (tuple[int, int, int, int]): Search field in x, y, width, height format.
pre_edit_steps (list[EditStep], optional): Edit steps to be applied before searching.
parent_step (EditStep): Parent edit step.
batch_size (int): Batch size of batched inferences.
"""

def __init__(self,
by_masking,
class_idx,
win_sizes,
strides,
layer,
search_field,
pre_edit_steps,
parent_step,
batch_size=DEFAULT_BATCH_SIZE):

if layer >= len(win_sizes):
raise ValueError('Layer is larger then number of window sizes.')

self.by_masking = by_masking
self.class_idx = class_idx
self.win_sizes = win_sizes
self.strides = strides
self.layer = layer
self.search_field = search_field
self.pre_edit_steps = pre_edit_steps
self.parent_step = parent_step
self.batch_size = batch_size
self.network = None
self.mask = None
self.original_input = None

self._workpiece = None
self._found_best_edits = None
self._found_uvs = None
self._u_pixels = None
self._v_pixels = None

@property
def layer_count(self):
"""Number of layers."""
return len(self.win_sizes)

def on_start(self, original_input, workpiece, mask, network):
"""
Notification of the start of the search job.

Args:
original_input (numpy.ndarray): The original image tensor in CHW or NCHW(N=1) format.
workpiece (numpy.ndarray): The intermediate image tensor in CHW or NCHW(N=1) format.
mask (Union[tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
tuple[float, float, float]: RGB solid color mask,
float: Grey scale solid color mask.
numpy.ndarray: Image mask, has same format of original_input.
network (nn.Cell): Classification network.
"""
self.original_input = original_input
self.mask = mask
self.network = network

self._workpiece = workpiece
self._found_best_edits = []
self._found_uvs = []
self._u_pixels = self._calc_uv_pixels(self.search_field[0], self.search_field[2])
self._v_pixels = self._calc_uv_pixels(self.search_field[1], self.search_field[3])

def create_sub_job(self, parent_step, pre_edit_steps):
"""Create next layer search job."""
return self.__class__(by_masking=self.by_masking,
class_idx=self.class_idx,
win_sizes=self.win_sizes,
strides=self.strides,
layer=self.layer + 1,
search_field=copy.copy(parent_step.box),
pre_edit_steps=pre_edit_steps,
parent_step=parent_step,
batch_size=self.batch_size)

def find_best_edit(self):
"""
Find the next best edit step.

Returns:
EditStep, the next best edit step.
"""
workpiece = self._workpiece
if len(workpiece.shape) == 3:
workpiece = np.expand_dims(workpiece, axis=0)

# generate input tensors with shifted masked/unmasked region and pack into a batch
squeeze = Squeeze()
best_new_workpiece = None
best_output = None
best_edit = None
best_uv = None
batch = np.repeat(workpiece, repeats=self.batch_size, axis=0)
batch_uvs = []
batch_steps = []
batch_i = 0
win_size = self.win_sizes[self.layer]
for u, x in enumerate(self._u_pixels):
for v, y in enumerate(self._v_pixels):
if (u, v) in self._found_uvs:
continue

edit_step = EditStep(self.layer, (x, y, win_size, win_size))

if self.by_masking:
EditStep.apply(batch[batch_i],
self.mask,
[edit_step],
self.by_masking,
inplace=True)
else:
EditStep.apply(self.original_input,
batch[batch_i],
[edit_step],
self.by_masking,
inplace=True)

batch_i += 1
batch_uvs.append((u, v))
batch_steps.append(edit_step)
if batch_i == self.batch_size:
# the batch is full, inference and empty it
batch_output = self.network(Tensor(batch))
batch_output = batch_output[:, self.class_idx]
if len(batch_output.shape) > 1:
batch_output = squeeze(batch_output)
if self.by_masking:
batch_best_i = np.argmin(batch_output.asnumpy())
else:
batch_best_i = np.argmax(batch_output.asnumpy())
batch_best_output = batch_output[int(batch_best_i)].asnumpy().item()

if best_output is None or self._is_output0_better(batch_best_output, best_output):
best_output = batch_best_output
best_uv = batch_uvs[batch_best_i]
best_edit = batch_steps[batch_best_i]
best_new_workpiece = batch[batch_best_i]

batch = np.repeat(workpiece, repeats=self.batch_size, axis=0)
batch_uvs = []
batch_i = 0

if batch_i > 0:
# don't forget the last half full batch
batch_output = self.network(Tensor(batch))
batch_output = batch_output[:, self.class_idx]
if len(batch_output.shape) > 1:
batch_output = squeeze(batch_output)
if self.by_masking:
batch_best_i = np.argmin(batch_output.asnumpy()[:batch_i, ...])
else:
batch_best_i = np.argmax(batch_output.asnumpy()[:batch_i, ...])

batch_best_output = batch_output[int(batch_best_i)].asnumpy().item()
if best_output is None or self._is_output0_better(batch_best_output, best_output):
best_output = batch_best_output
best_uv = batch_uvs[batch_best_i]
best_edit = batch_steps[batch_best_i]
best_new_workpiece = batch[batch_best_i]

if best_edit is None:
raise _NoNewStepError

if best_uv in self._found_uvs:
raise _RepeatedStepError

self._found_uvs.append(best_uv)
self._found_best_edits.append(best_edit)
best_edit.network_output = best_output

# continue on the best workpiece in the next function call
self._workpiece = best_new_workpiece

return best_edit

def _is_output0_better(self, output0, output1):
"""Check if the network output0 is better."""
if self.by_masking:
return output0 < output1
return output0 > output1

def _calc_uv_pixels(self, begin, length):
"""
Calculate the pixel coordinate of shifts.

Args:
begin (int): The beginning pixel coordinate of search field.
length (int): The length of search field.

Returns:
list[int], pixel coordinate of shifts.
"""
win_size = self.win_sizes[self.layer]
stride = self.strides[self.layer]
shift_count = self._calc_shift_count(length, win_size, stride)
pixels = [0] * shift_count
for i in range(shift_count):
if i == shift_count - 1:
pixels[i] = begin + length - win_size
else:
pixels[i] = begin + i*stride
return pixels

@staticmethod
def _calc_shift_count(length, win_size, stride):
"""
Calculate the number of shifts in search field.

Args:
length (int): The length of search field.
win_size (int): The length of sides of moving window.
stride (int): The stride.

Returns:
int, number of shifts.
"""
if length <= win_size or win_size < stride or stride <= 0:
raise ValueError("Invalid length, win_size or stride.")
count = int(math.ceil((length - win_size)/stride))
if (count - 1)*stride + win_size < length:
return count + 1
return count

+ 2
- 2
mindspore/nn/probability/toolbox/uncertainty_evaluation.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -84,7 +84,7 @@ class UncertaintyEvaluation:
def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1,
epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False):
self.epi_model = model
self.epi_model = deepcopy(model)
self.ale_model = deepcopy(model)
self.epi_train_dataset = train_dataset
self.ale_train_dataset = train_dataset


Loading…
Cancel
Save