Browse Source

clean code

fix sample count bug

change file permission

close fd

parens

add lixiahohui33 as approver
tags/v1.3.0
unknown 4 years ago
parent
commit
38f87a86f3
3 changed files with 230 additions and 205 deletions
  1. +1
    -0
      mindspore/explainer/OWNERS
  2. +201
    -198
      mindspore/explainer/_image_classification_runner.py
  3. +28
    -7
      mindspore/explainer/explanation/_counterfactual/hierarchical_occlusion.py

+ 1
- 0
mindspore/explainer/OWNERS View File

@@ -3,3 +3,4 @@ approvers:
- wangyue01
- wenkai_dist
- lilongfei15
- lixiaohui33

+ 201
- 198
mindspore/explainer/_image_classification_runner.py View File

@@ -59,8 +59,8 @@ def _np_to_image(img_np, mode):
return Image.fromarray(np.uint8(img_np * 255), mode=mode)


class _VerifyFlag:
"""Verification flags of dataset and settings of ImageClassificationRunner."""
class _Verifier:
"""Verification of dataset and settings of ImageClassificationRunner."""
ALL = 0xFFFFFFFF
REGISTRATION = 1
DATA_N_NETWORK = 1 << 1
@@ -68,8 +68,168 @@ class _VerifyFlag:
HOC = 1 << 3
ENVIRONMENT = 1 << 4

def _verify(self, flags):
"""
Verify datasets and settings.

Args:
flags (int): Verification flags, use bitwise or '|' to combine multiple flags.
Possible bitwise flags are shown as follow.

- ALL: Verify everything.
- REGISTRATION: Verify explainer module registration.
- DATA_N_NETWORK: Verify dataset and network.
- SALIENCY: Verify saliency related settings.
- HOC: Verify HOC related settings.
- ENVIRONMENT: Verify the runtime environment.

Raises:
ValueError: Be raised for any data or settings' value problem.
TypeError: Be raised for any data or settings' type problem.
RuntimeError: Be raised for any runtime problem.
"""
if flags & self.ENVIRONMENT:
device_target = context.get_context('device_target')
if device_target not in ("Ascend", "GPU"):
raise RuntimeError(f"Unsupported device_target: '{device_target}', "
f"only 'Ascend' or 'GPU' is supported. "
f"Please call context.set_context(device_target='Ascend') or "
f"context.set_context(device_target='GPU').")
if flags & (self.ENVIRONMENT | self.SALIENCY):
if self._is_saliency_registered:
mode = context.get_context('mode')
if mode != context.PYNATIVE_MODE:
raise RuntimeError("Context mode: GRAPH_MODE is not supported, "
"please call context.set_context(mode=context.PYNATIVE_MODE).")

if flags & self.REGISTRATION:
if self._is_uncertainty_registered and not self._is_saliency_registered:
raise ValueError("Function register_uncertainty() is called but register_saliency() is not.")
if not self._is_saliency_registered and not self._is_hoc_registered:
raise ValueError(
"No explanation module was registered, user should at least call register_saliency() "
"or register_hierarchical_occlusion() once with proper arguments.")

if flags & (self.DATA_N_NETWORK | self.SALIENCY | self.HOC):
self._verify_data()

if flags & self.DATA_N_NETWORK:
self._verify_network()

if flags & self.SALIENCY:
self._verify_saliency()

class ImageClassificationRunner:
def _verify_labels(self):
"""Verify labels."""
label_set = set()
if not self._labels:
raise ValueError(f"The label list provided is empty.")
for i, label in enumerate(self._labels):
if label.strip() == "":
raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is "
f"no empty label.")
if label in label_set:
raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.")
label_set.add(label)

def _verify_ds_inputs_shape(self, sample, inputs, labels):
"""Verify a dataset sample's input shape."""

if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
raise ValueError(
"Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
inputs.shape))
if len(inputs.shape) == 3:
log.warning(
"Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
" dimension as batch data.".format(inputs.shape))
if len(sample) > 1:
if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
raise ValueError(
"Labels shape {} is unrecognizable: outputs should not have more than two dimensions"
" with length greater than 1.".format(labels.shape))

if self._is_hoc_registered:
if inputs.shape[-3] != 3:
raise ValueError(
"Hierarchical occlusion is registered, images must be in 3 channels format, but "
"{} channel(s) is(are) encountered.".format(inputs.shape[-3]))
short_side = min(inputs.shape[-2:])
if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN:
raise ValueError(
"Hierarchical occlusion is registered, images' short side must be equals to or greater then "
"{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side))

def _verify_ds_sample(self, sample):
"""Verify a dataset sample."""
if len(sample) not in [1, 2, 3]:
raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
" as columns.")

if len(sample) == 3:
inputs, labels, bboxes = sample
if bboxes.shape[-1] != 4:
raise ValueError("The third element of dataset should be bounding boxes with shape of "
"[batch_size, num_ground_truth, 4].")
else:
if self._benchmarkers is not None:
if any([isinstance(bench, Localization) for bench in self._benchmarkers]):
raise ValueError("The dataset must provide bboxes if Localization is to be computed.")

if len(sample) == 2:
inputs, labels = sample
if len(sample) == 1:
inputs = sample[0]

self._verify_ds_inputs_shape(sample, inputs, labels)

def _verify_data(self):
"""Verify dataset and labels."""
self._verify_labels()

try:
sample = next(self._dataset.create_tuple_iterator())
except StopIteration:
raise ValueError("The dataset provided is empty.")

self._verify_ds_sample(sample)

def _verify_network(self):
"""Verify the network."""
next_element = next(self._dataset.create_tuple_iterator())
inputs, _, _ = self._unpack_next_element(next_element)
prop_test = self._full_network(inputs)
check_value_type("output of network in explainer", prop_test, ms.Tensor)
if prop_test.shape[1] != len(self._labels):
raise ValueError("The dimension of network output does not match the no. of classes. Please "
"check labels or the network in the explainer again.")

def _verify_saliency(self):
"""Verify the saliency settings."""
if self._explainers:
explainer_classes = []
for explainer in self._explainers:
if explainer.__class__ in explainer_classes:
raise ValueError(f"Repeated {explainer.__class__.__name__} explainer! "
"Please make sure all explainers' class is distinct.")
if explainer.network is not self._network:
raise ValueError(f"The network of {explainer.__class__.__name__} explainer is different "
"instance from network of runner. Please make sure they are the same "
"instance.")
explainer_classes.append(explainer.__class__)
if self._benchmarkers:
benchmarker_classes = []
for benchmarker in self._benchmarkers:
if benchmarker.__class__ in benchmarker_classes:
raise ValueError(f"Repeated {benchmarker.__class__.__name__} benchmarker! "
"Please make sure all benchmarkers' class is distinct.")
if isinstance(benchmarker, LabelSensitiveMetric) and benchmarker.num_labels != len(self._labels):
raise ValueError(f"The num_labels of {benchmarker.__class__.__name__} benchmarker is different "
"from no. of labels of runner. Please make them are the same.")
benchmarker_classes.append(benchmarker.__class__)


class ImageClassificationRunner(_Verifier):
"""
A high-level API for users to generate and store results of the explanation methods and the evaluation methods.

@@ -132,9 +292,9 @@ class ImageClassificationRunner:
# printing spacer
_SPACER = "{:120}\r"
# datafile directory's permission
_DIR_MODE = 0o750
_DIR_MODE = 0o700
# datafile's permission
_FILE_MODE = 0o600
_FILE_MODE = 0o400

def __init__(self,
summary_dir,
@@ -169,7 +329,7 @@ class ImageClassificationRunner:
self._full_network = SequentialCell([self._network, activation_fn])
self._full_network.set_train(False)

self._verify(_VerifyFlag.DATA_N_NETWORK | _VerifyFlag.ENVIRONMENT)
self._verify(_Verifier.DATA_N_NETWORK | _Verifier.ENVIRONMENT)

def register_saliency(self,
explainers,
@@ -211,7 +371,7 @@ class ImageClassificationRunner:
self._benchmarkers = benchmarkers

try:
self._verify(_VerifyFlag.SALIENCY | _VerifyFlag.ENVIRONMENT)
self._verify(_Verifier.SALIENCY | _Verifier.ENVIRONMENT)
except (ValueError, TypeError):
self._explainers = None
self._benchmarkers = None
@@ -235,7 +395,7 @@ class ImageClassificationRunner:
self._hoc_searcher = hoc.Searcher(self._full_network)

try:
self._verify(_VerifyFlag.HOC | _VerifyFlag.ENVIRONMENT)
self._verify(_Verifier.HOC | _Verifier.ENVIRONMENT)
except ValueError:
self._hoc_searcher = None
raise
@@ -274,7 +434,7 @@ class ImageClassificationRunner:
TypeError: Be raised for any data or settings' type problem.
RuntimeError: Be raised for any runtime problem.
"""
self._verify(_VerifyFlag.ALL)
self._verify(_Verifier.ALL)
self._manifest = {"saliency_map": False,
"benchmark": False,
"uncertainty": False,
@@ -358,25 +518,24 @@ class ImageClassificationRunner:
sample_id_labels = {}
self._sample_index = 0
ds.config.set_seed(self._DATASET_SEED)
for j, next_element in enumerate(self._dataset):
for j, batch in enumerate(self._dataset):
now = time()
self._run_sample(summary, next_element, sample_id_labels, threshold)
self._sample_index += 1
self._infer_batch(summary, batch, sample_id_labels, threshold)
self._spaced_print("Finish running and writing {}-th batch inference data."
" Time elapsed: {:.3f} s".format(j, time() - now))
return sample_id_labels

def _run_sample(self, summary, next_element, sample_id_labels, threshold):
def _infer_batch(self, summary, batch, sample_id_labels, threshold):
"""
Run inference for a sample.
Infer a batch.

Args:
summary (SummaryRecord): The summary object to store the data.
next_element (tuple): The next dataset sample.
batch (tuple): The next dataset sample.
sample_id_labels (dict): The sample id to labels dictionary.
threshold (float): The threshold for prediction.
"""
inputs, labels, _ = self._unpack_next_element(next_element)
inputs, labels, _ = self._unpack_next_element(batch)
prob = self._full_network(inputs).asnumpy()

if self._uncertainty is not None:
@@ -436,6 +595,8 @@ class ImageClassificationRunner:
if self._is_hoc_registered:
self._run_hoc(summary, self._sample_index, inputs[idx], prob[idx])

self._sample_index += 1

def _run_explainer(self, summary, sample_id_labels, explainer):
"""
Run the explainer.
@@ -689,166 +850,6 @@ class ImageClassificationRunner:
sds[i] = 0
return itl_lows, itl_his, sds

def _verify(self, flags):
"""
Verify datasets and settings.

Args:
flags (int): Verification flags, use bitwise or '|' to combine multiple flags.
Possible bitwise flags are shown as follow.

- _VerifyFlag.ALL: Verify everything.
- _VerifyFlag.REGISTRATION: Verify explainer module registration.
- _VerifyFlag.DATA_N_NETWORK: Verify dataset and network.
- _VerifyFlag.SALIENCY: Verify saliency related settings.
- _VerifyFlag.HOC: Verify HOC related settings.
- _VerifyFlag.ENVIRONMENT: Verify the runtime environment.

Raises:
ValueError: Be raised for any data or settings' value problem.
TypeError: Be raised for any data or settings' type problem.
RuntimeError: Be raised for any runtime problem.
"""
if flags & _VerifyFlag.ENVIRONMENT:
device_target = context.get_context('device_target')
if device_target not in ("Ascend", "GPU"):
raise RuntimeError(f"Unsupported device_target: '{device_target}', "
f"only 'Ascend' or 'GPU' is supported. "
f"Please call context.set_context(device_target='Ascend') or "
f"context.set_context(device_target='GPU').")
if flags & (_VerifyFlag.ENVIRONMENT | _VerifyFlag.SALIENCY):
if self._is_saliency_registered:
mode = context.get_context('mode')
if mode != context.PYNATIVE_MODE:
raise RuntimeError("Context mode: GRAPH_MODE is not supported, "
"please call context.set_context(mode=context.PYNATIVE_MODE).")

if flags & _VerifyFlag.REGISTRATION:
if self._is_uncertainty_registered and not self._is_saliency_registered:
raise ValueError("Function register_uncertainty() is called but register_saliency() is not.")
if not self._is_saliency_registered and not self._is_hoc_registered:
raise ValueError(
"No explanation module was registered, user should at least call register_saliency() "
"or register_hierarchical_occlusion() once with proper arguments.")

if flags & (_VerifyFlag.DATA_N_NETWORK | _VerifyFlag.SALIENCY | _VerifyFlag.HOC):
self._verify_data()

if flags & _VerifyFlag.DATA_N_NETWORK:
self._verify_network()

if flags & _VerifyFlag.SALIENCY:
self._verify_saliency()

def _verify_labels(self):
"""Verify labels."""
label_set = set()
if not self._labels:
raise ValueError(f"The label list provided is empty.")
for i, label in enumerate(self._labels):
if label.strip() == "":
raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is "
f"no empty label.")
if label in label_set:
raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.")
label_set.add(label)

def _verify_ds_inputs_shape(self, sample, inputs, labels):
"""Verify a dataset sample's input shape."""

if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
raise ValueError(
"Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
inputs.shape))
if len(inputs.shape) == 3:
log.warning(
"Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
" dimension as batch data.".format(inputs.shape))
if len(sample) > 1:
if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
raise ValueError(
"Labels shape {} is unrecognizable: outputs should not have more than two dimensions"
" with length greater than 1.".format(labels.shape))

if self._is_hoc_registered:
if inputs.shape[-3] != 3:
raise ValueError(
"Hierarchical occlusion is registered, images must be in 3 channels format, but "
"{} channel(s) is(are) encountered.".format(inputs.shape[-3]))
short_side = min(inputs.shape[-2:])
if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN:
raise ValueError(
"Hierarchical occlusion is registered, images' short side must be equals to or greater then "
"{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side))

def _verify_ds_sample(self, sample):
"""Verify a dataset sample."""
if len(sample) not in [1, 2, 3]:
raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
" as columns.")

if len(sample) == 3:
inputs, labels, bboxes = sample
if bboxes.shape[-1] != 4:
raise ValueError("The third element of dataset should be bounding boxes with shape of "
"[batch_size, num_ground_truth, 4].")
else:
if self._benchmarkers is not None:
if any([isinstance(bench, Localization) for bench in self._benchmarkers]):
raise ValueError("The dataset must provide bboxes if Localization is to be computed.")

if len(sample) == 2:
inputs, labels = sample
if len(sample) == 1:
inputs = sample[0]

self._verify_ds_inputs_shape(sample, inputs, labels)

def _verify_data(self):
"""Verify dataset and labels."""
self._verify_labels()

try:
sample = next(self._dataset.create_tuple_iterator())
except StopIteration:
raise ValueError("The dataset provided is empty.")

self._verify_ds_sample(sample)

def _verify_network(self):
"""Verify the network."""
next_element = next(self._dataset.create_tuple_iterator())
inputs, _, _ = self._unpack_next_element(next_element)
prop_test = self._full_network(inputs)
check_value_type("output of network in explainer", prop_test, ms.Tensor)
if prop_test.shape[1] != len(self._labels):
raise ValueError("The dimension of network output does not match the no. of classes. Please "
"check labels or the network in the explainer again.")

def _verify_saliency(self):
"""Verify the saliency settings."""
if self._explainers:
explainer_classes = []
for explainer in self._explainers:
if explainer.__class__ in explainer_classes:
raise ValueError(f"Repeated {explainer.__class__.__name__} explainer! "
"Please make sure all explainers' class is distinct.")
if explainer.network is not self._network:
raise ValueError(f"The network of {explainer.__class__.__name__} explainer is different "
"instance from network of runner. Please make sure they are the same "
"instance.")
explainer_classes.append(explainer.__class__)
if self._benchmarkers:
benchmarker_classes = []
for benchmarker in self._benchmarkers:
if benchmarker.__class__ in benchmarker_classes:
raise ValueError(f"Repeated {benchmarker.__class__.__name__} benchmarker! "
"Please make sure all benchmarkers' class is distinct.")
if isinstance(benchmarker, LabelSensitiveMetric) and benchmarker.num_labels != len(self._labels):
raise ValueError(f"The num_labels of {benchmarker.__class__.__name__} benchmarker is different "
"from no. of labels of runner. Please make them are the same.")
benchmarker_classes.append(benchmarker.__class__)

def _transform_bboxes(self, inputs, labels, bboxes, ifbbox):
"""
Transform the bounding boxes.
@@ -863,26 +864,28 @@ class ImageClassificationRunner:
bboxes (Union[list[dict], None, Tensor]): the bounding boxes
"""
input_len = len(inputs)
if bboxes is not None and ifbbox:
bboxes = ms.Tensor(bboxes, ms.int32)
masks_lst = []
labels = labels.asnumpy().reshape([input_len, -1])
bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
for idx, label in enumerate(labels):
height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
masks = {}
for j, label_item in enumerate(label):
target = int(label_item)
if -1 < target < len(self._labels):
if target not in masks:
mask = np.zeros((1, 1, height, width))
else:
mask = masks[target]
x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
masks[target] = mask
masks_lst.append(masks)
bboxes = masks_lst
if bboxes is None or not ifbbox:
return bboxes
bboxes = ms.Tensor(bboxes, ms.int32)
masks_lst = []
labels = labels.asnumpy().reshape([input_len, -1])
bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
for idx, label in enumerate(labels):
height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
masks = {}
for j, label_item in enumerate(label):
target = int(label_item)
if not -1 < target < len(self._labels):
continue
if target not in masks:
mask = np.zeros((1, 1, height, width))
else:
mask = masks[target]
x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
masks[target] = mask
masks_lst.append(masks)
bboxes = masks_lst
return bboxes

def _transform_data(self, inputs, labels, bboxes, ifbbox):
@@ -976,7 +979,7 @@ class ImageClassificationRunner:
self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp)]
abs_dir_path = self._create_subdir(*path_tokens)
save_path = os.path.join(abs_dir_path, self._MANIFEST_FILENAME)
fd = os.open(save_path, os.O_WRONLY | os.O_CREAT)
fd = os.open(save_path, os.O_WRONLY | os.O_CREAT, mode=self._FILE_MODE)
file = os.fdopen(fd, "w")
try:
json.dump(self._manifest, file, indent=4)
@@ -984,7 +987,7 @@ class ImageClassificationRunner:
log.error(f"Failed to save manifest as {save_path}!")
raise
finally:
file.close()
os.close(fd)
os.chmod(save_path, self._FILE_MODE)

def _save_original_image(self, sample_id, image):


+ 28
- 7
mindspore/explainer/explanation/_counterfactual/hierarchical_occlusion.py View File

@@ -608,20 +608,21 @@ class Searcher:
job.parent_step.add_child(step)
job_queue.extend(sub_job_queue)

def _process_job(self, job, sample_input, job_queue):
def _prepare_job(self, job, sample_input):
"""
Process a job.
Prepare a job for process.

Args:
job (_SearchJob): Search job to be processed.
sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format.
job_queue (list[_SearchJob]): Job queue.

Returns:
tuple[list[EditStep], _StopReason], result edit stop and the stop reason.
"""
edit_steps = []
numpy.ndarray, the image tensor workpiece.

Raise:
OriginalOutputError: Be raised if network output of the original image is not strictly larger than the
threshold.
"""
# make the network output with the original image is strictly larger than the threshold
if job.layer == 0:
original_output = self._network(Tensor(sample_input))[0, job.class_idx].asnumpy().item()
@@ -644,9 +645,29 @@ class Searcher:
self._by_masking)

job.on_start(sample_input, workpiece, self._compiled_mask, self._network)
return workpiece

def _process_job(self, job, sample_input, job_queue):
"""
Process a job.

Args:
job (_SearchJob): Search job to be processed.
sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format.
job_queue (list[_SearchJob]): Job queue.

Returns:
tuple[list[EditStep], _StopReason], result edit stop and the stop reason.

Raise:
OriginalOutputError: Be raised if network output of the original image is not strictly larger than the
threshold.
"""
workpiece = self._prepare_job(job, sample_input)

start_output = self._network(Tensor(workpiece))[0, job.class_idx].asnumpy().item()
last_output = start_output

edit_steps = []
# greedy search loop
while True:



Loading…
Cancel
Save