|
|
|
@@ -59,8 +59,8 @@ def _np_to_image(img_np, mode): |
|
|
|
return Image.fromarray(np.uint8(img_np * 255), mode=mode) |
|
|
|
|
|
|
|
|
|
|
|
class _VerifyFlag: |
|
|
|
"""Verification flags of dataset and settings of ImageClassificationRunner.""" |
|
|
|
class _Verifier: |
|
|
|
"""Verification of dataset and settings of ImageClassificationRunner.""" |
|
|
|
ALL = 0xFFFFFFFF |
|
|
|
REGISTRATION = 1 |
|
|
|
DATA_N_NETWORK = 1 << 1 |
|
|
|
@@ -68,8 +68,168 @@ class _VerifyFlag: |
|
|
|
HOC = 1 << 3 |
|
|
|
ENVIRONMENT = 1 << 4 |
|
|
|
|
|
|
|
def _verify(self, flags): |
|
|
|
""" |
|
|
|
Verify datasets and settings. |
|
|
|
|
|
|
|
Args: |
|
|
|
flags (int): Verification flags, use bitwise or '|' to combine multiple flags. |
|
|
|
Possible bitwise flags are shown as follow. |
|
|
|
|
|
|
|
- ALL: Verify everything. |
|
|
|
- REGISTRATION: Verify explainer module registration. |
|
|
|
- DATA_N_NETWORK: Verify dataset and network. |
|
|
|
- SALIENCY: Verify saliency related settings. |
|
|
|
- HOC: Verify HOC related settings. |
|
|
|
- ENVIRONMENT: Verify the runtime environment. |
|
|
|
|
|
|
|
Raises: |
|
|
|
ValueError: Be raised for any data or settings' value problem. |
|
|
|
TypeError: Be raised for any data or settings' type problem. |
|
|
|
RuntimeError: Be raised for any runtime problem. |
|
|
|
""" |
|
|
|
if flags & self.ENVIRONMENT: |
|
|
|
device_target = context.get_context('device_target') |
|
|
|
if device_target not in ("Ascend", "GPU"): |
|
|
|
raise RuntimeError(f"Unsupported device_target: '{device_target}', " |
|
|
|
f"only 'Ascend' or 'GPU' is supported. " |
|
|
|
f"Please call context.set_context(device_target='Ascend') or " |
|
|
|
f"context.set_context(device_target='GPU').") |
|
|
|
if flags & (self.ENVIRONMENT | self.SALIENCY): |
|
|
|
if self._is_saliency_registered: |
|
|
|
mode = context.get_context('mode') |
|
|
|
if mode != context.PYNATIVE_MODE: |
|
|
|
raise RuntimeError("Context mode: GRAPH_MODE is not supported, " |
|
|
|
"please call context.set_context(mode=context.PYNATIVE_MODE).") |
|
|
|
|
|
|
|
if flags & self.REGISTRATION: |
|
|
|
if self._is_uncertainty_registered and not self._is_saliency_registered: |
|
|
|
raise ValueError("Function register_uncertainty() is called but register_saliency() is not.") |
|
|
|
if not self._is_saliency_registered and not self._is_hoc_registered: |
|
|
|
raise ValueError( |
|
|
|
"No explanation module was registered, user should at least call register_saliency() " |
|
|
|
"or register_hierarchical_occlusion() once with proper arguments.") |
|
|
|
|
|
|
|
if flags & (self.DATA_N_NETWORK | self.SALIENCY | self.HOC): |
|
|
|
self._verify_data() |
|
|
|
|
|
|
|
if flags & self.DATA_N_NETWORK: |
|
|
|
self._verify_network() |
|
|
|
|
|
|
|
if flags & self.SALIENCY: |
|
|
|
self._verify_saliency() |
|
|
|
|
|
|
|
class ImageClassificationRunner: |
|
|
|
def _verify_labels(self): |
|
|
|
"""Verify labels.""" |
|
|
|
label_set = set() |
|
|
|
if not self._labels: |
|
|
|
raise ValueError(f"The label list provided is empty.") |
|
|
|
for i, label in enumerate(self._labels): |
|
|
|
if label.strip() == "": |
|
|
|
raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is " |
|
|
|
f"no empty label.") |
|
|
|
if label in label_set: |
|
|
|
raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.") |
|
|
|
label_set.add(label) |
|
|
|
|
|
|
|
def _verify_ds_inputs_shape(self, sample, inputs, labels): |
|
|
|
"""Verify a dataset sample's input shape.""" |
|
|
|
|
|
|
|
if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]: |
|
|
|
raise ValueError( |
|
|
|
"Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format( |
|
|
|
inputs.shape)) |
|
|
|
if len(inputs.shape) == 3: |
|
|
|
log.warning( |
|
|
|
"Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th" |
|
|
|
" dimension as batch data.".format(inputs.shape)) |
|
|
|
if len(sample) > 1: |
|
|
|
if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1: |
|
|
|
raise ValueError( |
|
|
|
"Labels shape {} is unrecognizable: outputs should not have more than two dimensions" |
|
|
|
" with length greater than 1.".format(labels.shape)) |
|
|
|
|
|
|
|
if self._is_hoc_registered: |
|
|
|
if inputs.shape[-3] != 3: |
|
|
|
raise ValueError( |
|
|
|
"Hierarchical occlusion is registered, images must be in 3 channels format, but " |
|
|
|
"{} channel(s) is(are) encountered.".format(inputs.shape[-3])) |
|
|
|
short_side = min(inputs.shape[-2:]) |
|
|
|
if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN: |
|
|
|
raise ValueError( |
|
|
|
"Hierarchical occlusion is registered, images' short side must be equals to or greater then " |
|
|
|
"{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side)) |
|
|
|
|
|
|
|
def _verify_ds_sample(self, sample): |
|
|
|
"""Verify a dataset sample.""" |
|
|
|
if len(sample) not in [1, 2, 3]: |
|
|
|
raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]" |
|
|
|
" as columns.") |
|
|
|
|
|
|
|
if len(sample) == 3: |
|
|
|
inputs, labels, bboxes = sample |
|
|
|
if bboxes.shape[-1] != 4: |
|
|
|
raise ValueError("The third element of dataset should be bounding boxes with shape of " |
|
|
|
"[batch_size, num_ground_truth, 4].") |
|
|
|
else: |
|
|
|
if self._benchmarkers is not None: |
|
|
|
if any([isinstance(bench, Localization) for bench in self._benchmarkers]): |
|
|
|
raise ValueError("The dataset must provide bboxes if Localization is to be computed.") |
|
|
|
|
|
|
|
if len(sample) == 2: |
|
|
|
inputs, labels = sample |
|
|
|
if len(sample) == 1: |
|
|
|
inputs = sample[0] |
|
|
|
|
|
|
|
self._verify_ds_inputs_shape(sample, inputs, labels) |
|
|
|
|
|
|
|
def _verify_data(self): |
|
|
|
"""Verify dataset and labels.""" |
|
|
|
self._verify_labels() |
|
|
|
|
|
|
|
try: |
|
|
|
sample = next(self._dataset.create_tuple_iterator()) |
|
|
|
except StopIteration: |
|
|
|
raise ValueError("The dataset provided is empty.") |
|
|
|
|
|
|
|
self._verify_ds_sample(sample) |
|
|
|
|
|
|
|
def _verify_network(self): |
|
|
|
"""Verify the network.""" |
|
|
|
next_element = next(self._dataset.create_tuple_iterator()) |
|
|
|
inputs, _, _ = self._unpack_next_element(next_element) |
|
|
|
prop_test = self._full_network(inputs) |
|
|
|
check_value_type("output of network in explainer", prop_test, ms.Tensor) |
|
|
|
if prop_test.shape[1] != len(self._labels): |
|
|
|
raise ValueError("The dimension of network output does not match the no. of classes. Please " |
|
|
|
"check labels or the network in the explainer again.") |
|
|
|
|
|
|
|
def _verify_saliency(self): |
|
|
|
"""Verify the saliency settings.""" |
|
|
|
if self._explainers: |
|
|
|
explainer_classes = [] |
|
|
|
for explainer in self._explainers: |
|
|
|
if explainer.__class__ in explainer_classes: |
|
|
|
raise ValueError(f"Repeated {explainer.__class__.__name__} explainer! " |
|
|
|
"Please make sure all explainers' class is distinct.") |
|
|
|
if explainer.network is not self._network: |
|
|
|
raise ValueError(f"The network of {explainer.__class__.__name__} explainer is different " |
|
|
|
"instance from network of runner. Please make sure they are the same " |
|
|
|
"instance.") |
|
|
|
explainer_classes.append(explainer.__class__) |
|
|
|
if self._benchmarkers: |
|
|
|
benchmarker_classes = [] |
|
|
|
for benchmarker in self._benchmarkers: |
|
|
|
if benchmarker.__class__ in benchmarker_classes: |
|
|
|
raise ValueError(f"Repeated {benchmarker.__class__.__name__} benchmarker! " |
|
|
|
"Please make sure all benchmarkers' class is distinct.") |
|
|
|
if isinstance(benchmarker, LabelSensitiveMetric) and benchmarker.num_labels != len(self._labels): |
|
|
|
raise ValueError(f"The num_labels of {benchmarker.__class__.__name__} benchmarker is different " |
|
|
|
"from no. of labels of runner. Please make them are the same.") |
|
|
|
benchmarker_classes.append(benchmarker.__class__) |
|
|
|
|
|
|
|
|
|
|
|
class ImageClassificationRunner(_Verifier): |
|
|
|
""" |
|
|
|
A high-level API for users to generate and store results of the explanation methods and the evaluation methods. |
|
|
|
|
|
|
|
@@ -132,9 +292,9 @@ class ImageClassificationRunner: |
|
|
|
# printing spacer |
|
|
|
_SPACER = "{:120}\r" |
|
|
|
# datafile directory's permission |
|
|
|
_DIR_MODE = 0o750 |
|
|
|
_DIR_MODE = 0o700 |
|
|
|
# datafile's permission |
|
|
|
_FILE_MODE = 0o600 |
|
|
|
_FILE_MODE = 0o400 |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
summary_dir, |
|
|
|
@@ -169,7 +329,7 @@ class ImageClassificationRunner: |
|
|
|
self._full_network = SequentialCell([self._network, activation_fn]) |
|
|
|
self._full_network.set_train(False) |
|
|
|
|
|
|
|
self._verify(_VerifyFlag.DATA_N_NETWORK | _VerifyFlag.ENVIRONMENT) |
|
|
|
self._verify(_Verifier.DATA_N_NETWORK | _Verifier.ENVIRONMENT) |
|
|
|
|
|
|
|
def register_saliency(self, |
|
|
|
explainers, |
|
|
|
@@ -211,7 +371,7 @@ class ImageClassificationRunner: |
|
|
|
self._benchmarkers = benchmarkers |
|
|
|
|
|
|
|
try: |
|
|
|
self._verify(_VerifyFlag.SALIENCY | _VerifyFlag.ENVIRONMENT) |
|
|
|
self._verify(_Verifier.SALIENCY | _Verifier.ENVIRONMENT) |
|
|
|
except (ValueError, TypeError): |
|
|
|
self._explainers = None |
|
|
|
self._benchmarkers = None |
|
|
|
@@ -235,7 +395,7 @@ class ImageClassificationRunner: |
|
|
|
self._hoc_searcher = hoc.Searcher(self._full_network) |
|
|
|
|
|
|
|
try: |
|
|
|
self._verify(_VerifyFlag.HOC | _VerifyFlag.ENVIRONMENT) |
|
|
|
self._verify(_Verifier.HOC | _Verifier.ENVIRONMENT) |
|
|
|
except ValueError: |
|
|
|
self._hoc_searcher = None |
|
|
|
raise |
|
|
|
@@ -274,7 +434,7 @@ class ImageClassificationRunner: |
|
|
|
TypeError: Be raised for any data or settings' type problem. |
|
|
|
RuntimeError: Be raised for any runtime problem. |
|
|
|
""" |
|
|
|
self._verify(_VerifyFlag.ALL) |
|
|
|
self._verify(_Verifier.ALL) |
|
|
|
self._manifest = {"saliency_map": False, |
|
|
|
"benchmark": False, |
|
|
|
"uncertainty": False, |
|
|
|
@@ -358,25 +518,24 @@ class ImageClassificationRunner: |
|
|
|
sample_id_labels = {} |
|
|
|
self._sample_index = 0 |
|
|
|
ds.config.set_seed(self._DATASET_SEED) |
|
|
|
for j, next_element in enumerate(self._dataset): |
|
|
|
for j, batch in enumerate(self._dataset): |
|
|
|
now = time() |
|
|
|
self._run_sample(summary, next_element, sample_id_labels, threshold) |
|
|
|
self._sample_index += 1 |
|
|
|
self._infer_batch(summary, batch, sample_id_labels, threshold) |
|
|
|
self._spaced_print("Finish running and writing {}-th batch inference data." |
|
|
|
" Time elapsed: {:.3f} s".format(j, time() - now)) |
|
|
|
return sample_id_labels |
|
|
|
|
|
|
|
def _run_sample(self, summary, next_element, sample_id_labels, threshold): |
|
|
|
def _infer_batch(self, summary, batch, sample_id_labels, threshold): |
|
|
|
""" |
|
|
|
Run inference for a sample. |
|
|
|
Infer a batch. |
|
|
|
|
|
|
|
Args: |
|
|
|
summary (SummaryRecord): The summary object to store the data. |
|
|
|
next_element (tuple): The next dataset sample. |
|
|
|
batch (tuple): The next dataset sample. |
|
|
|
sample_id_labels (dict): The sample id to labels dictionary. |
|
|
|
threshold (float): The threshold for prediction. |
|
|
|
""" |
|
|
|
inputs, labels, _ = self._unpack_next_element(next_element) |
|
|
|
inputs, labels, _ = self._unpack_next_element(batch) |
|
|
|
prob = self._full_network(inputs).asnumpy() |
|
|
|
|
|
|
|
if self._uncertainty is not None: |
|
|
|
@@ -436,6 +595,8 @@ class ImageClassificationRunner: |
|
|
|
if self._is_hoc_registered: |
|
|
|
self._run_hoc(summary, self._sample_index, inputs[idx], prob[idx]) |
|
|
|
|
|
|
|
self._sample_index += 1 |
|
|
|
|
|
|
|
def _run_explainer(self, summary, sample_id_labels, explainer): |
|
|
|
""" |
|
|
|
Run the explainer. |
|
|
|
@@ -689,166 +850,6 @@ class ImageClassificationRunner: |
|
|
|
sds[i] = 0 |
|
|
|
return itl_lows, itl_his, sds |
|
|
|
|
|
|
|
def _verify(self, flags): |
|
|
|
""" |
|
|
|
Verify datasets and settings. |
|
|
|
|
|
|
|
Args: |
|
|
|
flags (int): Verification flags, use bitwise or '|' to combine multiple flags. |
|
|
|
Possible bitwise flags are shown as follow. |
|
|
|
|
|
|
|
- _VerifyFlag.ALL: Verify everything. |
|
|
|
- _VerifyFlag.REGISTRATION: Verify explainer module registration. |
|
|
|
- _VerifyFlag.DATA_N_NETWORK: Verify dataset and network. |
|
|
|
- _VerifyFlag.SALIENCY: Verify saliency related settings. |
|
|
|
- _VerifyFlag.HOC: Verify HOC related settings. |
|
|
|
- _VerifyFlag.ENVIRONMENT: Verify the runtime environment. |
|
|
|
|
|
|
|
Raises: |
|
|
|
ValueError: Be raised for any data or settings' value problem. |
|
|
|
TypeError: Be raised for any data or settings' type problem. |
|
|
|
RuntimeError: Be raised for any runtime problem. |
|
|
|
""" |
|
|
|
if flags & _VerifyFlag.ENVIRONMENT: |
|
|
|
device_target = context.get_context('device_target') |
|
|
|
if device_target not in ("Ascend", "GPU"): |
|
|
|
raise RuntimeError(f"Unsupported device_target: '{device_target}', " |
|
|
|
f"only 'Ascend' or 'GPU' is supported. " |
|
|
|
f"Please call context.set_context(device_target='Ascend') or " |
|
|
|
f"context.set_context(device_target='GPU').") |
|
|
|
if flags & (_VerifyFlag.ENVIRONMENT | _VerifyFlag.SALIENCY): |
|
|
|
if self._is_saliency_registered: |
|
|
|
mode = context.get_context('mode') |
|
|
|
if mode != context.PYNATIVE_MODE: |
|
|
|
raise RuntimeError("Context mode: GRAPH_MODE is not supported, " |
|
|
|
"please call context.set_context(mode=context.PYNATIVE_MODE).") |
|
|
|
|
|
|
|
if flags & _VerifyFlag.REGISTRATION: |
|
|
|
if self._is_uncertainty_registered and not self._is_saliency_registered: |
|
|
|
raise ValueError("Function register_uncertainty() is called but register_saliency() is not.") |
|
|
|
if not self._is_saliency_registered and not self._is_hoc_registered: |
|
|
|
raise ValueError( |
|
|
|
"No explanation module was registered, user should at least call register_saliency() " |
|
|
|
"or register_hierarchical_occlusion() once with proper arguments.") |
|
|
|
|
|
|
|
if flags & (_VerifyFlag.DATA_N_NETWORK | _VerifyFlag.SALIENCY | _VerifyFlag.HOC): |
|
|
|
self._verify_data() |
|
|
|
|
|
|
|
if flags & _VerifyFlag.DATA_N_NETWORK: |
|
|
|
self._verify_network() |
|
|
|
|
|
|
|
if flags & _VerifyFlag.SALIENCY: |
|
|
|
self._verify_saliency() |
|
|
|
|
|
|
|
def _verify_labels(self): |
|
|
|
"""Verify labels.""" |
|
|
|
label_set = set() |
|
|
|
if not self._labels: |
|
|
|
raise ValueError(f"The label list provided is empty.") |
|
|
|
for i, label in enumerate(self._labels): |
|
|
|
if label.strip() == "": |
|
|
|
raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is " |
|
|
|
f"no empty label.") |
|
|
|
if label in label_set: |
|
|
|
raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.") |
|
|
|
label_set.add(label) |
|
|
|
|
|
|
|
def _verify_ds_inputs_shape(self, sample, inputs, labels): |
|
|
|
"""Verify a dataset sample's input shape.""" |
|
|
|
|
|
|
|
if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]: |
|
|
|
raise ValueError( |
|
|
|
"Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format( |
|
|
|
inputs.shape)) |
|
|
|
if len(inputs.shape) == 3: |
|
|
|
log.warning( |
|
|
|
"Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th" |
|
|
|
" dimension as batch data.".format(inputs.shape)) |
|
|
|
if len(sample) > 1: |
|
|
|
if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1: |
|
|
|
raise ValueError( |
|
|
|
"Labels shape {} is unrecognizable: outputs should not have more than two dimensions" |
|
|
|
" with length greater than 1.".format(labels.shape)) |
|
|
|
|
|
|
|
if self._is_hoc_registered: |
|
|
|
if inputs.shape[-3] != 3: |
|
|
|
raise ValueError( |
|
|
|
"Hierarchical occlusion is registered, images must be in 3 channels format, but " |
|
|
|
"{} channel(s) is(are) encountered.".format(inputs.shape[-3])) |
|
|
|
short_side = min(inputs.shape[-2:]) |
|
|
|
if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN: |
|
|
|
raise ValueError( |
|
|
|
"Hierarchical occlusion is registered, images' short side must be equals to or greater then " |
|
|
|
"{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side)) |
|
|
|
|
|
|
|
def _verify_ds_sample(self, sample): |
|
|
|
"""Verify a dataset sample.""" |
|
|
|
if len(sample) not in [1, 2, 3]: |
|
|
|
raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]" |
|
|
|
" as columns.") |
|
|
|
|
|
|
|
if len(sample) == 3: |
|
|
|
inputs, labels, bboxes = sample |
|
|
|
if bboxes.shape[-1] != 4: |
|
|
|
raise ValueError("The third element of dataset should be bounding boxes with shape of " |
|
|
|
"[batch_size, num_ground_truth, 4].") |
|
|
|
else: |
|
|
|
if self._benchmarkers is not None: |
|
|
|
if any([isinstance(bench, Localization) for bench in self._benchmarkers]): |
|
|
|
raise ValueError("The dataset must provide bboxes if Localization is to be computed.") |
|
|
|
|
|
|
|
if len(sample) == 2: |
|
|
|
inputs, labels = sample |
|
|
|
if len(sample) == 1: |
|
|
|
inputs = sample[0] |
|
|
|
|
|
|
|
self._verify_ds_inputs_shape(sample, inputs, labels) |
|
|
|
|
|
|
|
def _verify_data(self): |
|
|
|
"""Verify dataset and labels.""" |
|
|
|
self._verify_labels() |
|
|
|
|
|
|
|
try: |
|
|
|
sample = next(self._dataset.create_tuple_iterator()) |
|
|
|
except StopIteration: |
|
|
|
raise ValueError("The dataset provided is empty.") |
|
|
|
|
|
|
|
self._verify_ds_sample(sample) |
|
|
|
|
|
|
|
def _verify_network(self): |
|
|
|
"""Verify the network.""" |
|
|
|
next_element = next(self._dataset.create_tuple_iterator()) |
|
|
|
inputs, _, _ = self._unpack_next_element(next_element) |
|
|
|
prop_test = self._full_network(inputs) |
|
|
|
check_value_type("output of network in explainer", prop_test, ms.Tensor) |
|
|
|
if prop_test.shape[1] != len(self._labels): |
|
|
|
raise ValueError("The dimension of network output does not match the no. of classes. Please " |
|
|
|
"check labels or the network in the explainer again.") |
|
|
|
|
|
|
|
def _verify_saliency(self): |
|
|
|
"""Verify the saliency settings.""" |
|
|
|
if self._explainers: |
|
|
|
explainer_classes = [] |
|
|
|
for explainer in self._explainers: |
|
|
|
if explainer.__class__ in explainer_classes: |
|
|
|
raise ValueError(f"Repeated {explainer.__class__.__name__} explainer! " |
|
|
|
"Please make sure all explainers' class is distinct.") |
|
|
|
if explainer.network is not self._network: |
|
|
|
raise ValueError(f"The network of {explainer.__class__.__name__} explainer is different " |
|
|
|
"instance from network of runner. Please make sure they are the same " |
|
|
|
"instance.") |
|
|
|
explainer_classes.append(explainer.__class__) |
|
|
|
if self._benchmarkers: |
|
|
|
benchmarker_classes = [] |
|
|
|
for benchmarker in self._benchmarkers: |
|
|
|
if benchmarker.__class__ in benchmarker_classes: |
|
|
|
raise ValueError(f"Repeated {benchmarker.__class__.__name__} benchmarker! " |
|
|
|
"Please make sure all benchmarkers' class is distinct.") |
|
|
|
if isinstance(benchmarker, LabelSensitiveMetric) and benchmarker.num_labels != len(self._labels): |
|
|
|
raise ValueError(f"The num_labels of {benchmarker.__class__.__name__} benchmarker is different " |
|
|
|
"from no. of labels of runner. Please make them are the same.") |
|
|
|
benchmarker_classes.append(benchmarker.__class__) |
|
|
|
|
|
|
|
def _transform_bboxes(self, inputs, labels, bboxes, ifbbox): |
|
|
|
""" |
|
|
|
Transform the bounding boxes. |
|
|
|
@@ -863,26 +864,28 @@ class ImageClassificationRunner: |
|
|
|
bboxes (Union[list[dict], None, Tensor]): the bounding boxes |
|
|
|
""" |
|
|
|
input_len = len(inputs) |
|
|
|
if bboxes is not None and ifbbox: |
|
|
|
bboxes = ms.Tensor(bboxes, ms.int32) |
|
|
|
masks_lst = [] |
|
|
|
labels = labels.asnumpy().reshape([input_len, -1]) |
|
|
|
bboxes = bboxes.asnumpy().reshape([input_len, -1, 4]) |
|
|
|
for idx, label in enumerate(labels): |
|
|
|
height, width = inputs[idx].shape[-2], inputs[idx].shape[-1] |
|
|
|
masks = {} |
|
|
|
for j, label_item in enumerate(label): |
|
|
|
target = int(label_item) |
|
|
|
if -1 < target < len(self._labels): |
|
|
|
if target not in masks: |
|
|
|
mask = np.zeros((1, 1, height, width)) |
|
|
|
else: |
|
|
|
mask = masks[target] |
|
|
|
x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int) |
|
|
|
mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1 |
|
|
|
masks[target] = mask |
|
|
|
masks_lst.append(masks) |
|
|
|
bboxes = masks_lst |
|
|
|
if bboxes is None or not ifbbox: |
|
|
|
return bboxes |
|
|
|
bboxes = ms.Tensor(bboxes, ms.int32) |
|
|
|
masks_lst = [] |
|
|
|
labels = labels.asnumpy().reshape([input_len, -1]) |
|
|
|
bboxes = bboxes.asnumpy().reshape([input_len, -1, 4]) |
|
|
|
for idx, label in enumerate(labels): |
|
|
|
height, width = inputs[idx].shape[-2], inputs[idx].shape[-1] |
|
|
|
masks = {} |
|
|
|
for j, label_item in enumerate(label): |
|
|
|
target = int(label_item) |
|
|
|
if not -1 < target < len(self._labels): |
|
|
|
continue |
|
|
|
if target not in masks: |
|
|
|
mask = np.zeros((1, 1, height, width)) |
|
|
|
else: |
|
|
|
mask = masks[target] |
|
|
|
x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int) |
|
|
|
mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1 |
|
|
|
masks[target] = mask |
|
|
|
masks_lst.append(masks) |
|
|
|
bboxes = masks_lst |
|
|
|
return bboxes |
|
|
|
|
|
|
|
def _transform_data(self, inputs, labels, bboxes, ifbbox): |
|
|
|
@@ -976,7 +979,7 @@ class ImageClassificationRunner: |
|
|
|
self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp)] |
|
|
|
abs_dir_path = self._create_subdir(*path_tokens) |
|
|
|
save_path = os.path.join(abs_dir_path, self._MANIFEST_FILENAME) |
|
|
|
fd = os.open(save_path, os.O_WRONLY | os.O_CREAT) |
|
|
|
fd = os.open(save_path, os.O_WRONLY | os.O_CREAT, mode=self._FILE_MODE) |
|
|
|
file = os.fdopen(fd, "w") |
|
|
|
try: |
|
|
|
json.dump(self._manifest, file, indent=4) |
|
|
|
@@ -984,7 +987,7 @@ class ImageClassificationRunner: |
|
|
|
log.error(f"Failed to save manifest as {save_path}!") |
|
|
|
raise |
|
|
|
finally: |
|
|
|
file.close() |
|
|
|
os.close(fd) |
|
|
|
os.chmod(save_path, self._FILE_MODE) |
|
|
|
|
|
|
|
def _save_original_image(self, sample_id, image): |
|
|
|
|