| @@ -32,6 +32,9 @@ from .topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5Categori | |||
| from .loss import Loss | |||
| from .mean_surface_distance import MeanSurfaceDistance | |||
| from .root_mean_square_surface_distance import RootMeanSquareDistance | |||
| from .bleu_score import BleuScore | |||
| from .cosine_similarity import CosineSimilarity | |||
| from .occlusion_sensitivity import OcclusionSensitivity | |||
| __all__ = [ | |||
| "names", | |||
| @@ -43,6 +46,9 @@ __all__ = [ | |||
| "HausdorffDistance", | |||
| "Recall", | |||
| "Fbeta", | |||
| "BleuScore", | |||
| "CosineSimilarity", | |||
| "OcclusionSensitivity", | |||
| "F1", | |||
| "Dice", | |||
| "ROC", | |||
| @@ -64,6 +70,9 @@ __factory__ = { | |||
| 'dice': Dice, | |||
| 'roc': ROC, | |||
| 'auc': auc, | |||
| 'bleu_score': BleuScore, | |||
| 'cosine_similarity': CosineSimilarity, | |||
| 'occlusion_sensitivity': OcclusionSensitivity, | |||
| 'topk': TopKCategoricalAccuracy, | |||
| 'hausdorff_distance': HausdorffDistance, | |||
| 'top_1_accuracy': Top1CategoricalAccuracy, | |||
| @@ -1,103 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Evaluation.""" | |||
| import numpy as np | |||
| from .metric import Metric | |||
| _eval_types = {'classification', 'multilabel'} | |||
| class EvaluationBase(Metric): | |||
| """ | |||
| Base class of evaluation. | |||
| Note: | |||
| Please refer to the definition of class `Accuracy`. | |||
| Args: | |||
| eval_type (str): Type of evaluation must be in {'classification', 'multilabel'}. | |||
| Raises: | |||
| TypeError: If the input type is not classification or multilabel. | |||
| """ | |||
| def __init__(self, eval_type): | |||
| super(EvaluationBase, self).__init__() | |||
| if eval_type not in _eval_types: | |||
| raise TypeError('Type must be in {}, but got {}'.format(_eval_types, eval_type)) | |||
| self._type = eval_type | |||
| def _check_shape(self, y_pred, y): | |||
| """ | |||
| Checks the shapes of y_pred and y. | |||
| Args: | |||
| y_pred (Tensor): Predict array. | |||
| y (Tensor): Target array. | |||
| """ | |||
| if self._type == 'classification': | |||
| if y_pred.ndim != y.ndim + 1: | |||
| raise ValueError('Classification case, dims of y_pred equal dims of y add 1, ' | |||
| 'but got y_pred: {} dims and y: {} dims'.format(y_pred.ndim, y.ndim)) | |||
| if y.shape != (y_pred.shape[0],) + y_pred.shape[2:]: | |||
| raise ValueError('Classification case, y_pred shape and y shape can not match. ' | |||
| 'got y_pred shape is {} and y shape is {}'.format(y_pred.shape, y.shape)) | |||
| else: | |||
| if y_pred.ndim != y.ndim: | |||
| raise ValueError('{} case, dims of y_pred need equal with dims of y, but got y_pred: {} ' | |||
| 'dims and y: {} dims.'.format(self._type, y_pred.ndim, y.ndim)) | |||
| if y_pred.shape != y.shape: | |||
| raise ValueError('{} case, y_pred shape need equal with y shape, but got y_pred: {} and y: {}'. | |||
| format(self._type, y_pred.shape, y.shape)) | |||
| def _check_value(self, y_pred, y): | |||
| """ | |||
| Checks the values of y_pred and y. | |||
| Args: | |||
| y_pred (Tensor): Predict array. | |||
| y (Tensor): Target array. | |||
| """ | |||
| if self._type != 'classification' and not (np.equal(y_pred ** 2, y_pred).all() and np.equal(y ** 2, y).all()): | |||
| raise ValueError('For multilabel case, input value must be 1 or 0.') | |||
| def clear(self): | |||
| """ | |||
| A interface describes the behavior of clearing the internal evaluation result. | |||
| Note: | |||
| All subclasses must override this interface. | |||
| """ | |||
| raise NotImplementedError | |||
| def update(self, *inputs): | |||
| """ | |||
| A interface describes the behavior of updating the internal evaluation result. | |||
| Note: | |||
| All subclasses must override this interface. | |||
| Args: | |||
| inputs: The first item is predicted array and the second item is target array. | |||
| """ | |||
| raise NotImplementedError | |||
| def eval(self): | |||
| """ | |||
| A interface describes the behavior of computing the evaluation result. | |||
| Note: | |||
| All subclasses must override this interface. | |||
| """ | |||
| raise NotImplementedError | |||
| @@ -14,7 +14,7 @@ | |||
| # ============================================================================ | |||
| """Accuracy.""" | |||
| import numpy as np | |||
| from ._evaluation import EvaluationBase | |||
| from .metric import EvaluationBase | |||
| class Accuracy(EvaluationBase): | |||
| @@ -0,0 +1,149 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """BleuScore.""" | |||
| from collections import Counter | |||
| import numpy as np | |||
| from mindspore._checkparam import Validator as validator | |||
| from .metric import Metric | |||
| class BleuScore(Metric): | |||
| """ | |||
| Calculate BLEU score of machine translated text with one or more references. | |||
| Args: | |||
| n_gram (int): The n_gram value ranged from 1 to 4. Default: 4 | |||
| smooth (bool): Whether or not to apply smoothing. Default: False | |||
| Example: | |||
| >>> candidate_corpus = [['i', 'have', 'a', 'pen', 'on', 'my', 'desk']] | |||
| >>> reference_corpus = [[['i', 'have', 'a', 'pen', 'in', 'my', 'desk'], | |||
| >>> ['there', 'is', 'a', 'pen', 'on', 'the', 'desk']]] | |||
| >>> metric = BleuScore() | |||
| >>> metric.clear() | |||
| >>> metric.update(candidate_corpus, reference_corpus) | |||
| >>> bleu_score = metric.eval() | |||
| 0.5946035575013605 | |||
| """ | |||
| def __init__(self, n_gram=4, smooth=False): | |||
| super().__init__() | |||
| self.n_gram = validator.check_value_type("n_gram", n_gram, [int]) | |||
| if self.n_gram > 4 or self.n_gram < 1: | |||
| raise ValueError('The n_gram value ranged from 1 to 4, but got {}'.format(n_gram)) | |||
| self.smooth = validator.check_value_type("smooth", smooth, [bool]) | |||
| self.clear() | |||
| def clear(self): | |||
| """Clear the internal evaluation result.""" | |||
| self._numerator = np.zeros(self.n_gram) | |||
| self._denominator = np.zeros(self.n_gram) | |||
| self._precision_scores = np.zeros(self.n_gram) | |||
| self._c = 0.0 | |||
| self._r = 0.0 | |||
| self._trans_len = 0 | |||
| self._ref_len = 0 | |||
| self._is_update = False | |||
| def _count_ngram(self, ngram_input_list, n_gram): | |||
| """ | |||
| Counting how many times each word appears in a given text with ngram. | |||
| Args: | |||
| ngram_input_list (list): A list of translated text or reference texts. | |||
| n_gram (int): gram value ranged 1 to 4. | |||
| Return: | |||
| ngram_counter: a collections.Counter object of ngram. | |||
| """ | |||
| ngram_counter = Counter() | |||
| for i in range(1, n_gram + 1): | |||
| for j in range(len(ngram_input_list) - i + 1): | |||
| ngram_key = tuple(ngram_input_list[j:(i + j)]) | |||
| ngram_counter[ngram_key] += 1 | |||
| return ngram_counter | |||
| def update(self, *inputs): | |||
| """ | |||
| Updates the internal evaluation result with `candidate_corpus` and `reference_corpus`. | |||
| Args: | |||
| inputs: Input `candidate_corpus` and ``reference_corpus`. `candidate_corpus` and `reference_corpus` are a | |||
| list. The `candidate_corpus` is an iterable of machine translated corpus. The `reference_corpus` is | |||
| an iterable of iterables of reference corpus. | |||
| Raises: | |||
| ValueError: If the number of input is not 2. | |||
| """ | |||
| if len(inputs) != 2: | |||
| raise ValueError('The bleu_score need 2 inputs (candidate_corpus, reference_corpus), ' | |||
| 'but got {}'.format(len(inputs))) | |||
| candidate_corpus = inputs[0] | |||
| reference_corpus = inputs[1] | |||
| if len(candidate_corpus) != len(reference_corpus): | |||
| raise ValueError('translate_corpus and reference_corpus should be equal in length, ' | |||
| 'but got {} {}'.format(len(candidate_corpus), len(reference_corpus))) | |||
| for (candidate, references) in zip(candidate_corpus, reference_corpus): | |||
| self._c += len(candidate) | |||
| ref_len_list = [len(ref) for ref in references] | |||
| ref_len_diff = [abs(len(candidate) - x) for x in ref_len_list] | |||
| self._r += ref_len_list[ref_len_diff.index(min(ref_len_diff))] | |||
| translation_counter = self._count_ngram(candidate, self.n_gram) | |||
| reference_counter = Counter() | |||
| for ref in references: | |||
| reference_counter |= self._count_ngram(ref, self.n_gram) | |||
| ngram_counter_clip = translation_counter & reference_counter | |||
| for counter_clip in ngram_counter_clip: | |||
| self._numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip] | |||
| for counter in translation_counter: | |||
| self._denominator[len(counter) - 1] += translation_counter[counter] | |||
| self._trans_len = np.array(self._c) | |||
| self._ref_len = np.array(self._r) | |||
| self._is_update = True | |||
| def eval(self): | |||
| """ | |||
| Computes the bleu score. | |||
| Returns: | |||
| A numpy with bleu score. | |||
| """ | |||
| if self._is_update is False: | |||
| raise RuntimeError('Call the update method before calling eval.') | |||
| if min(self._numerator) == 0.0: | |||
| return np.array(0.0) | |||
| if self.smooth: | |||
| precision_scores = np.add(self._numerator, np.ones(self.n_gram)) / np.add(self._denominator, | |||
| np.ones(self.n_gram)) | |||
| else: | |||
| precision_scores = self._numerator / self._denominator | |||
| log_precision_scores = np.array([1.0 / self.n_gram] * self.n_gram) * np.log(precision_scores) | |||
| geometric_mean = np.exp(np.sum(log_precision_scores)) | |||
| brevity_penalty = np.array(1.0) if self._c > self._r else np.exp(1 - (self._ref_len / self._trans_len)) | |||
| bleu = brevity_penalty * geometric_mean | |||
| return bleu | |||
| @@ -0,0 +1,96 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """CosineSimilarity.""" | |||
| import numpy as np | |||
| from mindspore._checkparam import Validator as validator | |||
| from .metric import Metric | |||
| class CosineSimilarity(Metric): | |||
| """ | |||
| Computes representation similarity | |||
| Args: | |||
| similarity (str): 'dot' or 'cosine'. Default: 'cosine' | |||
| reduction (str): 'none', 'sum', 'mean' (all along dim -1). Default: 'none' | |||
| zero_diagonal (bool): if True, the diagonals are set to zero. Default: True | |||
| Return: | |||
| A square matrix (input1, input1) with the similarity scores between all elements. | |||
| If sum or mean are used, then returns (b, 1) with the reduced value for each row | |||
| Example: | |||
| >>> test_data = np.random.randn(4, 8) | |||
| >>> metric = CosineSimilarity() | |||
| >>> metric.clear() | |||
| >>> metric.update(test_data) | |||
| >>> square_matrix = metric.eval() | |||
| [[0. -0.14682831 0.19102288 -0.36204537] | |||
| ... | |||
| ] | |||
| """ | |||
| def __init__(self, similarity='cosine', reduction='none', zero_diagonal=True): | |||
| super().__init__() | |||
| similarity_list = ['dot', 'cosine'] | |||
| reduction_list = ['none', 'sum', 'mean'] | |||
| similarity = validator.check_value_type("similarity", similarity, [str]) | |||
| self.similarity = validator.check_string(similarity, similarity_list, "similarity") | |||
| reduction = validator.check_value_type("reduction", reduction, [str]) | |||
| self.reduction = validator.check_string(reduction, reduction_list, "reduction") | |||
| self.zero_diagonal = validator.check_value_type("zero_diagonal", zero_diagonal, [bool]) | |||
| self.clear() | |||
| def clear(self): | |||
| """Clears the internal evaluation result.""" | |||
| self.sqr_mtx_res = 0 | |||
| self._is_update = False | |||
| def update(self, *inputs): | |||
| """ | |||
| Updates the internal evaluation result with 'input1'. | |||
| Args: | |||
| inputs: input_data `input1`. The input_data is a `Tensor`or an array. | |||
| """ | |||
| input_data = self._convert_data(inputs[0]) | |||
| if self.similarity == 'cosine': | |||
| data = np.linalg.norm(input_data, ord=2, axis=1) | |||
| input_data = input_data / np.expand_dims(data, 1) | |||
| self.sqr_mtx_res = np.dot(input_data, input_data.transpose(1, 0)) | |||
| self._is_update = True | |||
| def eval(self): | |||
| """ | |||
| Computes the Cosine_Similarity square matrix. | |||
| Returns: | |||
| A square matrix. | |||
| """ | |||
| if not self._is_update: | |||
| raise RuntimeError('Call the update method before calling eval.') | |||
| if self.zero_diagonal: | |||
| np.fill_diagonal(self.sqr_mtx_res, 0) | |||
| if self.reduction == 'mean': | |||
| self.sqr_mtx_res = np.mean(self.sqr_mtx_res, axis=-1) | |||
| if self.reduction == 'sum': | |||
| self.sqr_mtx_res = np.sum(self.sqr_mtx_res, axis=-1) | |||
| return self.sqr_mtx_res | |||
| @@ -17,6 +17,8 @@ from abc import ABCMeta, abstractmethod | |||
| import numpy as np | |||
| from mindspore.common.tensor import Tensor | |||
| _eval_types = {'classification', 'multilabel'} | |||
| class Metric(metaclass=ABCMeta): | |||
| """ | |||
| @@ -140,3 +142,87 @@ class Metric(metaclass=ABCMeta): | |||
| inputs: A variable-length input argument list. | |||
| """ | |||
| raise NotImplementedError('Must define update function to use this base class') | |||
| class EvaluationBase(Metric): | |||
| """ | |||
| Base class of evaluation. | |||
| Note: | |||
| Please refer to the definition of class `Accuracy`. | |||
| Args: | |||
| eval_type (str): Type of evaluation must be in {'classification', 'multilabel'}. | |||
| Raises: | |||
| TypeError: If the input type is not classification or multilabel. | |||
| """ | |||
| def __init__(self, eval_type): | |||
| super(EvaluationBase, self).__init__() | |||
| if eval_type not in _eval_types: | |||
| raise TypeError('Type must be in {}, but got {}'.format(_eval_types, eval_type)) | |||
| self._type = eval_type | |||
| def _check_shape(self, y_pred, y): | |||
| """ | |||
| Checks the shapes of y_pred and y. | |||
| Args: | |||
| y_pred (Tensor): Predict array. | |||
| y (Tensor): Target array. | |||
| """ | |||
| if self._type == 'classification': | |||
| if y_pred.ndim != y.ndim + 1: | |||
| raise ValueError('Classification case, dims of y_pred equal dims of y add 1, ' | |||
| 'but got y_pred: {} dims and y: {} dims'.format(y_pred.ndim, y.ndim)) | |||
| if y.shape != (y_pred.shape[0],) + y_pred.shape[2:]: | |||
| raise ValueError('Classification case, y_pred shape and y shape can not match. ' | |||
| 'got y_pred shape is {} and y shape is {}'.format(y_pred.shape, y.shape)) | |||
| else: | |||
| if y_pred.ndim != y.ndim: | |||
| raise ValueError('{} case, dims of y_pred need equal with dims of y, but got y_pred: {} ' | |||
| 'dims and y: {} dims.'.format(self._type, y_pred.ndim, y.ndim)) | |||
| if y_pred.shape != y.shape: | |||
| raise ValueError('{} case, y_pred shape need equal with y shape, but got y_pred: {} and y: {}'. | |||
| format(self._type, y_pred.shape, y.shape)) | |||
| def _check_value(self, y_pred, y): | |||
| """ | |||
| Checks the values of y_pred and y. | |||
| Args: | |||
| y_pred (Tensor): Predict array. | |||
| y (Tensor): Target array. | |||
| """ | |||
| if self._type != 'classification' and not (np.equal(y_pred ** 2, y_pred).all() and np.equal(y ** 2, y).all()): | |||
| raise ValueError('For multilabel case, input value must be 1 or 0.') | |||
| def clear(self): | |||
| """ | |||
| A interface describes the behavior of clearing the internal evaluation result. | |||
| Note: | |||
| All subclasses must override this interface. | |||
| """ | |||
| raise NotImplementedError | |||
| def update(self, *inputs): | |||
| """ | |||
| A interface describes the behavior of updating the internal evaluation result. | |||
| Note: | |||
| All subclasses must override this interface. | |||
| Args: | |||
| inputs: The first item is predicted array and the second item is target array. | |||
| """ | |||
| raise NotImplementedError | |||
| def eval(self): | |||
| """ | |||
| A interface describes the behavior of computing the evaluation result. | |||
| Note: | |||
| All subclasses must override this interface. | |||
| """ | |||
| raise NotImplementedError | |||
| @@ -0,0 +1,196 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """OcclusionSensitivity.""" | |||
| from collections.abc import Sequence | |||
| import numpy as np | |||
| from mindspore import nn | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore._checkparam import Validator as validator | |||
| from .metric import Metric | |||
| try: | |||
| from tqdm import trange | |||
| except (ImportError, AttributeError): | |||
| trange = range | |||
| class OcclusionSensitivity(Metric): | |||
| """ | |||
| This function is used to calculate the occlusion sensitivity of the model for a given image. | |||
| Occlusion sensitivity refers to how the probability of a given prediction changes with the change of the occluded | |||
| part of the image. | |||
| For a given result, the output probability is the probability of a region. | |||
| The higher the value in the output image, the greater the decline of certainty, indicating that | |||
| the occluded area is more important in the decision-making process. | |||
| Args: | |||
| pad_val (float): What values need to be entered in the image when a part of the image is occluded. Default: 0.0. | |||
| margin (Union[int, Sequence]): Create a cuboid / cube around the voxel you want to occlude. Default: 2. | |||
| n_batch (int): number of images in a batch before inference. Default: 128. | |||
| b_box (Sequence): Bounding box on which to perform the analysis. The output image will also match in size. | |||
| There should be a minimum and maximum for all dimensions except batch: | |||
| ``[min1, max1, min2, max2,...]``. If no bounding box is supplied, this will be the same size | |||
| as the input image. If a bounding box is used, the output image will be cropped to this size. | |||
| Default: None. | |||
| Example: | |||
| >>> class DenseNet(nn.Cell): | |||
| >>> def init(self): | |||
| >>> super(DenseNet, self).init() | |||
| >>> w = np.array([[0.1, 0.8, 0.1, 0.1],[1, 1, 1, 1]]).astype(np.float32) | |||
| >>> b = np.array([0.3, 0.6]).astype(np.float32) | |||
| >>> self.dense = nn.Dense(4, 2, weight_init=Tensor(w), bias_init=Tensor(b)) | |||
| >>> | |||
| >>> def construct(self, x): | |||
| >>> return self.dense(x) | |||
| >>> | |||
| >>> model = DenseNet() | |||
| >>> test_data = np.array([[0.1, 0.2, 0.3, 0.4]]).astype(np.float32) | |||
| >>> label = np.array(1).astype(np.int32) | |||
| >>> metric = OcclusionSensitivity() | |||
| >>> metric.clear() | |||
| >>> metric.update(model, test_data, label) | |||
| >>> score = metric.eval() | |||
| [0.29999995 0.6 1 0.9] | |||
| """ | |||
| def __init__(self, pad_val=0.0, margin=2, n_batch=128, b_box=None): | |||
| super().__init__() | |||
| self.pad_val = validator.check_value_type("pad_val", pad_val, [float]) | |||
| self.margin = validator.check_value_type("margin", margin, [int, Sequence]) | |||
| self.n_batch = validator.check_value_type("n_batch", n_batch, [int]) | |||
| self.b_box = b_box if b_box is None else validator.check_value_type("b_box", b_box, [list]) | |||
| self.clear() | |||
| def clear(self): | |||
| """Clears the internal evaluation result.""" | |||
| self._baseline = 0 | |||
| self._sensitivity_im = 0 | |||
| self._is_update = False | |||
| def _check_input_bounding_box(self, b_box, im_shape): | |||
| """Check that the bounding box (if supplied) is as expected.""" | |||
| # If no bounding box has been supplied, set min and max to None | |||
| if b_box is None: | |||
| b_box_min = b_box_max = None | |||
| else: | |||
| if len(b_box) != 2 * len(im_shape): | |||
| raise ValueError("Bounding box should contain upper and lower for all dimensions (except batch number)") | |||
| b_box_min = np.array(b_box[::2]) | |||
| b_box_max = np.array(b_box[1::2]) | |||
| b_box_min[b_box_min < 0] = 0 | |||
| b_box_max[b_box_max < 0] = im_shape[b_box_max < 0] - 1 | |||
| if np.any(b_box_max >= im_shape): | |||
| raise ValueError("Max bounding box should be < image size for all values") | |||
| if np.any(b_box_min > b_box_max): | |||
| raise ValueError("Min bounding box should be <= max for all values") | |||
| return b_box_min, b_box_max | |||
| def _append_to_sensitivity_im(self, model, batch_images, batch_ids, sensitivity_im): | |||
| """For a given number of images, the probability of predicting a given label is obtained. Attach to previous | |||
| assessment.""" | |||
| batch_images = np.vstack(batch_images) | |||
| batch_ids = np.expand_dims(batch_ids, 1) | |||
| model_numpy = model(Tensor(batch_images)).asnumpy() | |||
| first_indices = np.arange(batch_ids.shape[0])[:, None] | |||
| scores = model_numpy[first_indices, batch_ids] | |||
| if sensitivity_im.size == 0: | |||
| return np.vstack(scores) | |||
| return np.vstack((sensitivity_im, scores)) | |||
| def update(self, *inputs): | |||
| """ | |||
| Updates input, including `model`, `y_pred` and `label`. | |||
| Inputs: | |||
| - **model** (nn.Cell) - classification model to use for inference. | |||
| - **y_pred** (Union[Tensor, list, np.ndarray]) - image to test. Should be tensor consisting of 1 batch, | |||
| can be 2- or 3D. | |||
| - **label** (Union[int, Tensor]) - classification label to check for changes (normally the true label, | |||
| but doesn't have to be | |||
| Raises: | |||
| ValueError: If the number of input is not 3. | |||
| """ | |||
| if len(inputs) != 3: | |||
| raise ValueError('occlusion_sensitivity need 3 inputs (model, y_pred, y), but got {}'.format(len(inputs))) | |||
| model = inputs[0] | |||
| y_pred = self._convert_data(inputs[1]) | |||
| label = self._convert_data(inputs[2]) | |||
| model = validator.check_value_type("model", model, [nn.Cell]) | |||
| if y_pred.shape[0] > 1: | |||
| raise RuntimeError("Expected batch size of 1.") | |||
| if isinstance(label, int): | |||
| label = np.array([[label]], dtype=int) | |||
| # If the label is a tensor, make sure there's only 1 element | |||
| elif np.prod(label.shape) != y_pred.shape[0]: | |||
| raise RuntimeError("Expected as many labels as batches.") | |||
| y_pred_shape = np.array(y_pred.shape[1:]) | |||
| b_box_min, b_box_max = self._check_input_bounding_box(self.b_box, y_pred_shape) | |||
| temp = model(Tensor(y_pred)).asnumpy() | |||
| self._baseline = temp[0, label].item() | |||
| batch_images = [] | |||
| batch_ids = [] | |||
| sensitivity_im = np.empty(0, dtype=float) | |||
| output_im_shape = y_pred_shape if self.b_box is None else b_box_max - b_box_min + 1 | |||
| num_required_predictions = np.prod(output_im_shape) | |||
| for i in trange(num_required_predictions): | |||
| idx = np.unravel_index(i, output_im_shape) | |||
| if b_box_min is not None: | |||
| idx += b_box_min | |||
| min_idx = [max(0, i - self.margin) for i in idx] | |||
| max_idx = [min(j, i + self.margin) for i, j in zip(idx, y_pred_shape)] | |||
| occlu_im = y_pred.copy() | |||
| occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = self.pad_val | |||
| batch_images.append(occlu_im) | |||
| batch_ids.append(label) | |||
| if len(batch_images) == self.n_batch or i == num_required_predictions - 1: | |||
| sensitivity_im = self._append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im) | |||
| batch_images = [] | |||
| batch_ids = [] | |||
| self._sensitivity_im = sensitivity_im.reshape(output_im_shape) | |||
| self._is_update = True | |||
| def eval(self): | |||
| """ | |||
| Computes the occlusion_sensitivity. | |||
| Returns: | |||
| A numpy ndarray. | |||
| """ | |||
| if not self._is_update: | |||
| raise RuntimeError('Call the update method before calling eval.') | |||
| sensitivity = self._baseline - np.squeeze(self._sensitivity_im) | |||
| return sensitivity | |||
| @@ -18,7 +18,7 @@ import sys | |||
| import numpy as np | |||
| from mindspore._checkparam import Validator as validator | |||
| from ._evaluation import EvaluationBase | |||
| from .metric import EvaluationBase | |||
| class Precision(EvaluationBase): | |||
| @@ -18,7 +18,7 @@ import sys | |||
| import numpy as np | |||
| from mindspore._checkparam import Validator as validator | |||
| from ._evaluation import EvaluationBase | |||
| from .metric import EvaluationBase | |||
| class Recall(EvaluationBase): | |||
| @@ -0,0 +1,73 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test_bleu_score""" | |||
| import math | |||
| import pytest | |||
| from mindspore.nn.metrics import BleuScore | |||
| def test_bleu_score(): | |||
| """test_bleu_score""" | |||
| candidate_corpus = [['i', 'have', 'a', 'pen', 'on', 'my', 'desk']] | |||
| reference_corpus = [[['i', 'have', 'a', 'pen', 'in', 'my', 'desk'], | |||
| ['there', 'is', 'a', 'pen', 'on', 'the', 'desk']]] | |||
| metric = BleuScore(n_gram=4, smooth=False) | |||
| metric.clear() | |||
| metric.update(candidate_corpus, reference_corpus) | |||
| bleu_score = metric.eval() | |||
| assert math.isclose(bleu_score, 0.5946035575013605, abs_tol=0.0001) | |||
| def test_bleu_score_update1(): | |||
| """test_bleu_score_update1""" | |||
| candidate_corpus = ['the cat is on the mat'.split()] | |||
| metric = BleuScore() | |||
| metric.clear() | |||
| with pytest.raises(ValueError): | |||
| metric.update(candidate_corpus) | |||
| def test_bleu_score_update2(): | |||
| """test_bleu_score_update2""" | |||
| candidate_corpus = [['the cat is on the mat'.split()], ['a cat is on the mat'.split()]] | |||
| reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] | |||
| metric = BleuScore() | |||
| metric.clear() | |||
| with pytest.raises(ValueError): | |||
| metric.update(candidate_corpus, reference_corpus) | |||
| def test_bleu_score_init1(): | |||
| """test_bleu_score_init1""" | |||
| with pytest.raises(TypeError): | |||
| BleuScore(n_gram="3") | |||
| def test_bleu_score_init2(): | |||
| """test_bleu_score_init2""" | |||
| with pytest.raises(TypeError): | |||
| BleuScore(smooth=5) | |||
| def test_bleu_score_runtime(): | |||
| """test_bleu_score_runtime""" | |||
| metric = BleuScore() | |||
| metric.clear() | |||
| with pytest.raises(RuntimeError): | |||
| metric.eval() | |||
| @@ -0,0 +1,95 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test cosine_similarity""" | |||
| import pytest | |||
| import numpy as np | |||
| from sklearn.metrics import pairwise | |||
| from mindspore.nn.metrics import CosineSimilarity | |||
| def test_cosine_similarity(): | |||
| """test_cosine_similarity""" | |||
| test_data = np.array([[5, 8, 3, 2], [5, 8, 3, 2], [4, 2, 3, 4]]) | |||
| metric = CosineSimilarity() | |||
| metric.clear() | |||
| metric.update(test_data) | |||
| square_matrix = metric.eval() | |||
| assert np.allclose(square_matrix, np.array([[0, 1, 0.78229315], [1, 0, 0.78229315], [0.78229315, 0.78229315, 0]])) | |||
| def test_cosine_similarity_compare(): | |||
| """test_cosine_similarity_compare""" | |||
| test_data = np.array([[5, 8, 3, 2], [5, 8, 3, 2], [4, 2, 3, 4]]) | |||
| metric = CosineSimilarity(similarity='cosine', reduction='none', zero_diagonal=False) | |||
| metric.clear() | |||
| metric.update(test_data) | |||
| ms_square_matrix = metric.eval() | |||
| def sklearn_cosine_similarity(test_data, similarity, reduction): | |||
| """sklearn_cosine_similarity""" | |||
| metric_func = {'cosine': pairwise.cosine_similarity, | |||
| 'dot': pairwise.linear_kernel}[similarity] | |||
| square_matrix = metric_func(test_data, test_data) | |||
| if reduction == 'mean': | |||
| return square_matrix.mean(axis=-1) | |||
| if reduction == 'sum': | |||
| return square_matrix.sum(axis=-1) | |||
| return square_matrix | |||
| sk_square_matrix = sklearn_cosine_similarity(test_data, similarity='cosine', reduction='none') | |||
| assert np.allclose(sk_square_matrix, ms_square_matrix) | |||
| def test_cosine_similarity_init1(): | |||
| """test_cosine_similarity_init1""" | |||
| with pytest.raises(ValueError): | |||
| CosineSimilarity(similarity="4") | |||
| def test_cosine_similarity_init2(): | |||
| """test_cosine_similarity_init2""" | |||
| with pytest.raises(TypeError): | |||
| CosineSimilarity(similarity=4) | |||
| def test_cosine_similarity_init3(): | |||
| """test_cosine_similarity_init3""" | |||
| with pytest.raises(TypeError): | |||
| CosineSimilarity(reduction=2) | |||
| def test_cosine_similarity_init4(): | |||
| """test_cosine_similarity_init4""" | |||
| with pytest.raises(ValueError): | |||
| CosineSimilarity(reduction="1") | |||
| def test_cosine_similarity_init5(): | |||
| """test_cosine_similarity_init5""" | |||
| with pytest.raises(TypeError): | |||
| CosineSimilarity(zero_diagonal=3) | |||
| def test_cosine_similarity_runtime(): | |||
| """test_cosine_similarity_runtime""" | |||
| metric = CosineSimilarity() | |||
| metric.clear() | |||
| with pytest.raises(RuntimeError): | |||
| metric.eval() | |||
| @@ -0,0 +1,77 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test_occlusion_sensitivity""" | |||
| import pytest | |||
| import numpy as np | |||
| from mindspore import nn | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.nn.metrics import OcclusionSensitivity | |||
| class DenseNet(nn.Cell): | |||
| def __init__(self): | |||
| super(DenseNet, self).__init__() | |||
| w = np.array([[0.1, 0.8, 0.1, 0.1], [1, 1, 1, 1]]).astype(np.float32) | |||
| b = np.array([0.3, 0.6]).astype(np.float32) | |||
| self.dense = nn.Dense(4, 2, weight_init=Tensor(w), bias_init=Tensor(b)) | |||
| def construct(self, x): | |||
| return self.dense(x) | |||
| model = DenseNet() | |||
| def test_occlusion_sensitivity(): | |||
| """test_occlusion_sensitivity""" | |||
| test_data = np.array([[0.1, 0.2, 0.3, 0.4]]).astype(np.float32) | |||
| label = np.array(1).astype(np.int32) | |||
| metric = OcclusionSensitivity() | |||
| metric.clear() | |||
| metric.update(model, test_data, label) | |||
| score = metric.eval() | |||
| assert np.allclose(score, np.array([0.2, 0.2, 0.2, 0.2])) | |||
| def test_occlusion_sensitivity_update1(): | |||
| """test_occlusion_sensitivity_update1""" | |||
| test_data = np.array([[5, 8], [3, 2], [4, 2]]) | |||
| metric = OcclusionSensitivity() | |||
| metric.clear() | |||
| with pytest.raises(ValueError): | |||
| metric.update(test_data) | |||
| def test_occlusion_sensitivity_init1(): | |||
| """test_occlusion_sensitivity_init1""" | |||
| with pytest.raises(TypeError): | |||
| OcclusionSensitivity(pad_val=False, margin=2, n_batch=128, b_box=None) | |||
| def test_occlusion_sensitivity_init2(): | |||
| """test_occlusion_sensitivity_init2""" | |||
| with pytest.raises(TypeError): | |||
| OcclusionSensitivity(pad_val=0.0, margin=True, n_batch=128, b_box=None) | |||
| def test_occlusion_sensitivity_runtime(): | |||
| """test_occlusion_sensitivity_runtime""" | |||
| metric = OcclusionSensitivity() | |||
| metric.clear() | |||
| with pytest.raises(RuntimeError): | |||
| metric.eval() | |||