|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Runner."""
- import os
- import re
- import traceback
- from time import time
- from typing import Tuple, List, Optional
-
- import numpy as np
- from PIL import Image
- from scipy.stats import beta
-
- import mindspore as ms
- import mindspore.dataset as ds
- from mindspore import log
- from mindspore.nn import Softmax, Cell
- from mindspore.nn.probability.toolbox import UncertaintyEvaluation
- from mindspore.ops.operations import ExpandDims
- from mindspore.train._utils import check_value_type
- from mindspore.train.summary._summary_adapter import _convert_image_format
- from mindspore.train.summary.summary_record import SummaryRecord
- from mindspore.train.summary_pb2 import Explain
- from .benchmark import Localization
- from .explanation import RISE
- from .benchmark._attribution.metric import AttributionMetric, LabelSensitiveMetric, LabelAgnosticMetric
- from .explanation._attribution.attribution import Attribution
-
- # datafile directory names
- _DATAFILE_DIRNAME_PREFIX = "_explain_"
- _ORIGINAL_IMAGE_DIRNAME = "origin_images"
- _HEATMAP_DIRNAME = "heatmap"
- # max. no. of sample per directory
- _SAMPLE_PER_DIR = 1000
-
- _EXPAND_DIMS = ExpandDims()
- _SEED = 58 # set a seed to fix the iterating order of the dataset
-
-
- def _normalize(img_np):
- """Normalize the numpy image to the range of [0, 1]. """
- max_ = img_np.max()
- min_ = img_np.min()
- normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
- return normed
-
-
- def _np_to_image(img_np, mode):
- """Convert numpy array to PIL image."""
- return Image.fromarray(np.uint8(img_np * 255), mode=mode)
-
-
- def _calc_prob_interval(volume, probs, prob_vars):
- """Compute the confidence interval of probability."""
- if not isinstance(probs, np.ndarray):
- probs = np.asarray(probs)
- if not isinstance(prob_vars, np.ndarray):
- prob_vars = np.asarray(prob_vars)
- one_minus_probs = 1 - probs
- alpha_coef = (np.square(probs) * one_minus_probs / prob_vars) - probs
- beta_coef = alpha_coef * one_minus_probs / probs
- intervals = beta.interval(volume, alpha_coef, beta_coef)
-
- # avoid invalid result due to extreme small value of prob_vars
- lows = []
- highs = []
- for i, low in enumerate(intervals[0]):
- high = intervals[1][i]
- if prob_vars[i] <= 0 or \
- not np.isfinite(low) or low > probs[i] or \
- not np.isfinite(high) or high < probs[i]:
- low = probs[i]
- high = probs[i]
- lows.append(low)
- highs.append(high)
-
- return lows, highs
-
-
- def _get_id_dirname(sample_id: int):
- """Get the name of parent directory of the image id."""
- return str(int(sample_id / _SAMPLE_PER_DIR) * _SAMPLE_PER_DIR)
-
-
- def _extract_timestamp(filename: str):
- """Extract timestamp from summary filename."""
- matched = re.search(r"summary\.(\d+)", filename)
- if matched:
- return int(matched.group(1))
- return None
-
-
- class ExplainRunner:
- """
- A high-level API for users to generate and store results of the explanation methods and the evaluation methods.
-
- After generating results with the explanation methods and the evaluation methods, the results will be written into
- a specified file with `mindspore.summary.SummaryRecord`. The stored content can be viewed using MindInsight.
-
- Update in 2020.11: Adjust the storage structure and format of the data. Summary files generated by previous version
- will be deprecated and will not be supported in MindInsight of current version.
-
- Args:
- summary_dir (str, optional): The directory path to save the summary files which store the generated results.
- Default: "./"
-
- Examples:
- >>> from mindspore.explainer import ExplainRunner
- >>> # init a runner with a specified directory
- >>> summary_dir = "summary_dir"
- >>> runner = ExplainRunner(summary_dir)
- """
-
- def __init__(self, summary_dir: Optional[str] = "./"):
- check_value_type("summary_dir", summary_dir, str)
- self._summary_dir = summary_dir
- self._count = 0
- self._classes = None
- self._model = None
- self._uncertainty = None
- self._summary_timestamp = None
-
- def run(self,
- dataset: Tuple,
- explainers: List,
- benchmarkers: Optional[List] = None,
- uncertainty: Optional[UncertaintyEvaluation] = None,
- activation_fn: Optional[Cell] = Softmax()):
- """
- Genereates results and writes results into the summary files in `summary_dir` specified during the object
- initialization.
-
- Args:
- dataset (tuple): A tuple that contains `mindspore.dataset` object for iteration and its labels.
-
- - dataset[0]: A `mindspore.dataset` object to provide data to explain.
- - dataset[1]: A list of string that specifies the label names of the dataset.
-
- explainers (list[Explanation]): A list of explanation objects to generate attribution results. Explanation
- object is an instance initialized with the explanation methods in module
- `mindspore.explainer.explanation`.
- benchmarkers (list[Benchmark], optional): A list of benchmark objects to generate evaluation results.
- Default: None
- uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference
- uncertainty of samples.
- activation_fn (Cell, optional): The activation layer that transforms the output of the network to
- label probability distribution :math:`P(y|x)`. Default: Softmax().
-
- Examples:
- >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
- >>> from mindspore.nn import Sigmoid
- >>> # obtain dataset object
- >>> dataset = get_dataset()
- >>> classes = ["cat", "dog", ...]
- >>> # load checkpoint to a network, e.g. resnet50
- >>> param_dict = load_checkpoint("checkpoint.ckpt")
- >>> net = resnet50(len(classes))
- >>> load_parama_into_net(net, param_dict)
- >>> gbp = GuidedBackprop(net)
- >>> gradient = Gradient(net)
- >>> runner = ExplainRunner("./")
- >>> explainers = [gbp, gradient]
- >>> runner.run((dataset, classes), explainers, activation_fn=Sigmoid())
- """
-
- check_value_type("dataset", dataset, tuple)
- if len(dataset) != 2:
- raise ValueError("Argument `dataset` should be a tuple with length = 2.")
-
- dataset, classes = dataset
- self._verify_data_form(dataset, benchmarkers)
- self._classes = classes
-
- check_value_type("explainers", explainers, list)
- if not explainers:
- raise ValueError("Argument `explainers` must be a non-empty list")
-
- for exp in explainers:
- if not isinstance(exp, Attribution):
- raise TypeError("Argument `explainers` should be a list of objects of classes in "
- "`mindspore.explainer.explanation`.")
- if benchmarkers is not None:
- check_value_type("benchmarkers", benchmarkers, list)
- for bench in benchmarkers:
- if not isinstance(bench, AttributionMetric):
- raise TypeError("Argument `benchmarkers` should be a list of objects of classes in explanation"
- "`mindspore.explainer.benchmark`.")
- check_value_type("activation_fn", activation_fn, Cell)
-
- self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn])
- next_element = dataset.create_tuple_iterator().get_next()
- inputs, _, _ = self._unpack_next_element(next_element)
- prop_test = self._model(inputs)
- check_value_type("output of model im explainer", prop_test, ms.Tensor)
- if prop_test.shape[1] != len(self._classes):
- raise ValueError("The dimension of model output does not match the length of dataset classes. Please "
- "check dataset classes or the black-box model in the explainer again.")
-
- if uncertainty is not None:
- check_value_type("uncertainty", uncertainty, UncertaintyEvaluation)
- prop_var_test = uncertainty.eval_epistemic_uncertainty(inputs)
- check_value_type("output of uncertainty", prop_var_test, np.ndarray)
- if prop_var_test.shape[1] != len(self._classes):
- raise ValueError("The dimension of uncertainty output does not match the length of dataset classes"
- "classes. Please check dataset classes or the black-box model in the explainer again.")
- self._uncertainty = uncertainty
- else:
- self._uncertainty = None
-
- with SummaryRecord(self._summary_dir) as summary:
- spacer = '{:120}\r'
- print("Start running and writing......")
- begin = time()
- print("Start writing metadata......")
-
- self._summary_timestamp = _extract_timestamp(summary.event_file_name)
- if self._summary_timestamp is None:
- raise RuntimeError("Cannot extract timestamp from summary filename!"
- " It should contains a timestamp of 10 digits.")
-
- explain = Explain()
- explain.metadata.label.extend(classes)
- exp_names = [exp.__class__.__name__ for exp in explainers]
- explain.metadata.explain_method.extend(exp_names)
- if benchmarkers is not None:
- bench_names = [bench.__class__.__name__ for bench in benchmarkers]
- explain.metadata.benchmark_method.extend(bench_names)
-
- summary.add_value("explainer", "metadata", explain)
- summary.record(1)
-
- print("Finish writing metadata.")
-
- now = time()
- print("Start running and writing inference data.....")
- imageid_labels = self._run_inference(dataset, summary)
- print(spacer.format("Finish running and writing inference data. "
- "Time elapsed: {:.3f} s".format(time() - now)))
-
- if benchmarkers is None or not benchmarkers:
- for exp in explainers:
- start = time()
- print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
- self._count = 0
- ds.config.set_seed(_SEED)
- for idx, next_element in enumerate(dataset):
- now = time()
- self._run_exp_step(next_element, exp, imageid_labels, summary)
- print(spacer.format("Finish writing {}-th explanation data for {}. Time elapsed: "
- "{:.3f} s".format(idx, exp.__class__.__name__, time() - now)), end='')
- print(spacer.format(
- "Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format(
- exp.__class__.__name__, time() - start)))
- else:
- for exp in explainers:
- explain = Explain()
- for bench in benchmarkers:
- bench.reset()
- print(f"Start running and writing explanation and "
- f"benchmark data for {exp.__class__.__name__}......")
- self._count = 0
- start = time()
- ds.config.set_seed(_SEED)
- for idx, next_element in enumerate(dataset):
- now = time()
- saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary)
- print(spacer.format(
- "Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
- idx, exp.__class__.__name__, time() - now)), end='')
- for bench in benchmarkers:
- now = time()
- self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
- print(spacer.format(
- "Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
- idx, bench.__class__.__name__, exp.__class__.__name__, time() - now)), end='')
-
- for bench in benchmarkers:
- benchmark = explain.benchmark.add()
- benchmark.explain_method = exp.__class__.__name__
- benchmark.benchmark_method = bench.__class__.__name__
-
- benchmark.total_score = bench.performance
- if isinstance(bench, LabelSensitiveMetric):
- benchmark.label_score.extend(bench.class_performances)
-
- print(spacer.format("Finish running and writing explanation and benchmark data for {}. "
- "Time elapsed: {:.3f} s".format(exp.__class__.__name__, time() - start)))
- summary.add_value('explainer', 'benchmark', explain)
- summary.record(1)
- print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin))
-
- @staticmethod
- def _verify_data_form(dataset, benchmarkers):
- """
- Verify the validity of dataset.
-
- Args:
- dataset (`ds`): the user parsed dataset.
- benchmarkers (list[`AttributionMetric`]): the user parsed benchmarkers.
- """
- next_element = dataset.create_tuple_iterator().get_next()
-
- if len(next_element) not in [1, 2, 3]:
- raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
- " as columns.")
-
- if len(next_element) == 3:
- inputs, labels, bboxes = next_element
- if bboxes.shape[-1] != 4:
- raise ValueError("The third element of dataset should be bounding boxes with shape of "
- "[batch_size, num_ground_truth, 4].")
- else:
- if benchmarkers is not None:
- if True in [isinstance(bench, Localization) for bench in benchmarkers]:
- raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
-
- if len(next_element) == 2:
- inputs, labels = next_element
- if len(next_element) == 1:
- inputs = next_element[0]
-
- if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
- raise ValueError(
- "Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
- inputs.shape))
- if len(inputs.shape) == 3:
- log.warning(
- "Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
- " dimension as batch data.".format(inputs.shape))
-
- if len(next_element) > 1:
- if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
- raise ValueError(
- "Labels shape {} is unrecognizable: labels should not have more than two dimensions"
- " with length greater than 1.".format(labels.shape))
-
- def _transform_data(self, inputs, labels, bboxes, ifbbox):
- """
- Transform the data from one iteration of dataset to a unifying form for the follow-up operations.
-
- Args:
- inputs (Tensor): the image data
- labels (Tensor): the labels
- bboxes (Tensor): the boudnding boxes data
- ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t label
- id will be returned. If False, the returned bboxes is the the parsed bboxes.
-
- Returns:
- inputs (Tensor): the image data, unified to a 4D Tensor.
- labels (List[List[int]]): the ground truth labels.
- bboxes (Union[List[Dict], None, Tensor]): the bounding boxes
- """
- inputs = ms.Tensor(inputs, ms.float32)
- if len(inputs.shape) == 3:
- inputs = _EXPAND_DIMS(inputs, 0)
- if isinstance(labels, ms.Tensor):
- labels = ms.Tensor(labels, ms.int32)
- labels = _EXPAND_DIMS(labels, 0)
- if isinstance(bboxes, ms.Tensor):
- bboxes = ms.Tensor(bboxes, ms.int32)
- bboxes = _EXPAND_DIMS(bboxes, 0)
-
- input_len = len(inputs)
- if bboxes is not None and ifbbox:
- bboxes = ms.Tensor(bboxes, ms.int32)
- masks_lst = []
- labels = labels.asnumpy().reshape([input_len, -1])
- bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
- for idx, label in enumerate(labels):
- height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
- masks = {}
- for j, label_item in enumerate(label):
- target = int(label_item)
- if -1 < target < len(self._classes):
- if target not in masks:
- mask = np.zeros((1, 1, height, width))
- else:
- mask = masks[target]
- x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
- mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
- masks[target] = mask
-
- masks_lst.append(masks)
- bboxes = masks_lst
-
- labels = ms.Tensor(labels, ms.int32)
- if len(labels.shape) == 1:
- labels_lst = [[int(i)] for i in labels.asnumpy()]
- else:
- labels = labels.asnumpy().reshape([input_len, -1])
- labels_lst = []
- for item in labels:
- labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._classes))))
- labels = labels_lst
- return inputs, labels, bboxes
-
- def _unpack_next_element(self, next_element, ifbbox=False):
- """
- Unpack a single iteration of dataset.
-
- Args:
- next_element (Tuple): a single element iterated from dataset object.
- ifbbox (bool): whether to preprocess bboxes in self._transform_data.
-
- Returns:
- Tuple, a unified Tuple contains image_data, labels, and bounding boxes.
- """
- if len(next_element) == 3:
- inputs, labels, bboxes = next_element
- elif len(next_element) == 2:
- inputs, labels = next_element
- bboxes = None
- else:
- inputs = next_element[0]
- labels = [[] for _ in inputs]
- bboxes = None
- inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
- return inputs, labels, bboxes
-
- @staticmethod
- def _make_label_batch(labels):
- """
- Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max
- length of all the rows in labels.
-
- Args:
- labels (List[List]): the union labels of a data batch.
-
- Returns:
- 2D Tensor.
- """
-
- max_len = max([len(label) for label in labels])
- batch_labels = np.zeros((len(labels), max_len))
-
- for idx, _ in enumerate(batch_labels):
- length = len(labels[idx])
- batch_labels[idx, :length] = np.array(labels[idx])
-
- return ms.Tensor(batch_labels, ms.int32)
-
- def _run_inference(self, dataset, summary, threshold=0.5):
- """
- Run inference for the dataset and write the inference related data into summary.
-
- Args:
- dataset (`ds`): the parsed dataset
- summary (`SummaryRecord`): the summary object to store the data
- threshold (float): the threshold for prediction.
-
- Returns:
- imageid_labels (dict): a dict that maps image_id and the union of its ground truth and predicted labels.
- """
- spacer = '{:120}\r'
- imageid_labels = {}
- ds.config.set_seed(_SEED)
- self._count = 0
- for j, next_element in enumerate(dataset):
- now = time()
- inputs, labels, _ = self._unpack_next_element(next_element)
- prob = self._model(inputs).asnumpy()
- if self._uncertainty is not None:
- prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
- prob_sd = np.sqrt(prob_var)
- else:
- prob_var = prob_sd = None
-
- for idx, inp in enumerate(inputs):
- gt_labels = labels[idx]
- gt_probs = [float(prob[idx][i]) for i in gt_labels]
-
- data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
- original_image = _np_to_image(_normalize(data_np), mode='RGB')
- original_image_path = self._save_original_image(self._count, original_image)
-
- predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
- predicted_probs = [float(prob[idx][i]) for i in predicted_labels]
-
- has_uncertainty = False
- gt_prob_sds = gt_prob_itl95_lows = gt_prob_itl95_his = None
- predicted_prob_sds = predicted_prob_itl95_lows = predicted_prob_itl95_his = None
- if prob_var is not None:
- gt_prob_sds = [float(prob_sd[idx][i]) for i in gt_labels]
- predicted_prob_sds = [float(prob_sd[idx][i]) for i in predicted_labels]
- try:
- gt_prob_itl95_lows, gt_prob_itl95_his = \
- _calc_prob_interval(0.95, gt_probs, [float(prob_var[idx][i]) for i in gt_labels])
- predicted_prob_itl95_lows, predicted_prob_itl95_his = \
- _calc_prob_interval(0.95, predicted_probs, [float(prob_var[idx][i])
- for i in predicted_labels])
- has_uncertainty = True
- except ValueError:
- log.error(traceback.format_exc())
- log.error("Error on calculating uncertainty")
-
- union_labs = list(set(gt_labels + predicted_labels))
- imageid_labels[str(self._count)] = union_labs
-
- explain = Explain()
- explain.sample_id = self._count
- explain.image_path = original_image_path
- summary.add_value("explainer", "sample", explain)
-
- explain = Explain()
- explain.sample_id = self._count
- explain.ground_truth_label.extend(gt_labels)
- explain.inference.ground_truth_prob.extend(gt_probs)
- explain.inference.predicted_label.extend(predicted_labels)
- explain.inference.predicted_prob.extend(predicted_probs)
-
- if has_uncertainty:
- explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
- explain.inference.ground_truth_prob_itl95_low.extend(gt_prob_itl95_lows)
- explain.inference.ground_truth_prob_itl95_hi.extend(gt_prob_itl95_his)
-
- explain.inference.predicted_prob_sd.extend(predicted_prob_sds)
- explain.inference.predicted_prob_itl95_low.extend(predicted_prob_itl95_lows)
- explain.inference.predicted_prob_itl95_hi.extend(predicted_prob_itl95_his)
-
- summary.add_value("explainer", "inference", explain)
-
- summary.record(1)
-
- self._count += 1
- print(spacer.format("Finish running and writing {}-th batch inference data."
- " Time elapsed: {:.3f} s".format(j, time() - now)),
- end='')
- return imageid_labels
-
- def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
- """
- Run the explanation for each step and write explanation results into summary.
-
- Args:
- next_element (Tuple): data of one step
- explainer (_Attribution): an Attribution object to generate saliency maps.
- imageid_labels (dict): a dict that maps the image_id and its union labels.
- summary (SummaryRecord): the summary object to store the data
-
- Returns:
- List of dict that maps label to its corresponding saliency map.
- """
- inputs, labels, _ = self._unpack_next_element(next_element)
- count = self._count
- unions = []
- for _ in range(len(labels)):
- unions_labels = imageid_labels[str(count)]
- unions.append(unions_labels)
- count += 1
-
- batch_unions = self._make_label_batch(unions)
- saliency_dict_lst = []
-
- if isinstance(explainer, RISE):
- batch_saliency_full = explainer(inputs, batch_unions)
- else:
- batch_saliency_full = []
- for i in range(len(batch_unions[0])):
- batch_saliency = explainer(inputs, batch_unions[:, i])
- batch_saliency_full.append(batch_saliency)
- concat = ms.ops.operations.Concat(1)
- batch_saliency_full = concat(tuple(batch_saliency_full))
-
- for idx, union in enumerate(unions):
- saliency_dict = {}
- explain = Explain()
- explain.sample_id = self._count
- for k, lab in enumerate(union):
- saliency = batch_saliency_full[idx:idx + 1, k:k + 1]
- saliency_dict[lab] = saliency
-
- saliency_np = _normalize(saliency.asnumpy().squeeze())
- saliency_image = _np_to_image(saliency_np, mode='L')
- heatmap_path = self._save_heatmap(explainer.__class__.__name__, lab, self._count, saliency_image)
-
- explanation = explain.explanation.add()
- explanation.explain_method = explainer.__class__.__name__
- explanation.heatmap_path = heatmap_path
- explanation.label = lab
-
- summary.add_value("explainer", "explanation", explain)
- summary.record(1)
-
- self._count += 1
- saliency_dict_lst.append(saliency_dict)
- return saliency_dict_lst
-
- def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
- """
- Run the explanation and evaluation for each step and write explanation results into summary.
-
- Args:
- next_element (Tuple): Data of one step
- explainer (`_Attribution`): An Attribution object to generate saliency maps.
- """
- inputs, labels, _ = self._unpack_next_element(next_element)
- for idx, inp in enumerate(inputs):
- inp = _EXPAND_DIMS(inp, 0)
- saliency_dict = saliency_dict_lst[idx]
- for label, saliency in saliency_dict.items():
- if isinstance(benchmarker, Localization):
- _, _, bboxes = self._unpack_next_element(next_element, True)
- if label in labels[idx]:
- res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
- saliency=saliency)
- if np.any(res == np.nan):
- res = np.zeros_like(res)
- benchmarker.aggregate(res, label)
- elif isinstance(benchmarker, LabelSensitiveMetric):
- res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
- if np.any(res == np.nan):
- res = np.zeros_like(res)
- benchmarker.aggregate(res, label)
- elif isinstance(benchmarker, LabelAgnosticMetric):
- res = benchmarker.evaluate(explainer, inp)
- if np.any(res == np.nan):
- res = np.zeros_like(res)
- benchmarker.aggregate(res)
- else:
- raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
- 'receive {}'.format(type(benchmarker)))
-
- def _save_original_image(self, sample_id: int, image):
- """Save an image to summary directory."""
- id_dirname = _get_id_dirname(sample_id)
- relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
- _ORIGINAL_IMAGE_DIRNAME,
- id_dirname)
- os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
- relative_path = os.path.join(relative_dir, f"{sample_id}.jpg")
- save_path = os.path.join(self._summary_dir, relative_path)
- with open(save_path, "wb") as file:
- image.save(file)
- return relative_path
-
- def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image):
- """Save heatmap image to summary directory."""
- id_dirname = _get_id_dirname(sample_id)
- relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
- _HEATMAP_DIRNAME,
- explain_method,
- id_dirname)
- os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
- relative_path = os.path.join(relative_dir, f"{sample_id}_{class_id}.jpg")
- save_path = os.path.join(self._summary_dir, relative_path)
- with open(save_path, "wb") as file:
- image.save(file, optimize=True)
- return relative_path
|