From d6f4421a1d2605e6a5b44abc2f430bd6390607cb Mon Sep 17 00:00:00 2001
From: Jiaqi
Date: Sat, 9 Jan 2021 09:46:15 +0800
Subject: [PATCH] perplexity

---
 mindspore/nn/metrics/__init__.py           |   3 +
 mindspore/nn/metrics/perplexity.py         | 114 +++++++++++++++++++++
 tests/ut/python/metrics/test_perplexity.py |  65 ++++++++++++
 3 files changed, 182 insertions(+)
 create mode 100644 mindspore/nn/metrics/perplexity.py
 create mode 100644 tests/ut/python/metrics/test_perplexity.py

diff --git a/mindspore/nn/metrics/__init__.py b/mindspore/nn/metrics/__init__.py
index 211c201ce8..0abf7de28f 100755
--- a/mindspore/nn/metrics/__init__.py
+++ b/mindspore/nn/metrics/__init__.py
@@ -35,6 +35,7 @@ from .root_mean_square_surface_distance import RootMeanSquareDistance
 from .bleu_score import BleuScore
 from .cosine_similarity import CosineSimilarity
 from .occlusion_sensitivity import OcclusionSensitivity
+from .perplexity import Perplexity
 
 __all__ = [
     "names",
@@ -59,6 +60,7 @@ __all__ = [
     "Loss",
     "MeanSurfaceDistance",
     "RootMeanSquareDistance",
+    "Perplexity",
 ]
 
 __factory__ = {
@@ -82,6 +84,7 @@ __factory__ = {
     'loss': Loss,
     'mean_surface_distance': MeanSurfaceDistance,
     'root_mean_square_distance': RootMeanSquareDistance,
+    'perplexity': Perplexity,
 }
diff --git a/mindspore/nn/metrics/perplexity.py b/mindspore/nn/metrics/perplexity.py
new file mode 100644
index 0000000000..c1c0c3726c
--- /dev/null
+++ b/mindspore/nn/metrics/perplexity.py
@@ -0,0 +1,114 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Perplexity"""
import math
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric


class Perplexity(Metric):
    r"""
    Computes perplexity. Perplexity is a measure of how well a probability distribution or a model predicts a
    sample. A low perplexity indicates that the model predicts the sample well. The metric is defined as follows:

    .. math::
        b^{-\frac{1}{N} \sum_{i=1}^N \log_b q(x_i)}
        = \exp \Big(-\frac{1}{N} \sum_{i=1}^N \log q(x_i)\Big)

    Args:
        ignore_label (int): Index of an invalid label to be ignored when counting. If set to `None`, all entries
            are counted. Default: None.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
        >>> y = Tensor(np.array([1, 0, 1]))
        >>> metric = Perplexity(ignore_label=None)
        >>> metric.clear()
        >>> metric.update(x, y)
        >>> perplexity = metric.eval()
        >>> print(perplexity)
        2.231443166940565
    """

    def __init__(self, ignore_label=None):
        super(Perplexity, self).__init__()

        if ignore_label is None:
            self.ignore_label = ignore_label
        else:
            self.ignore_label = validator.check_value_type("ignore_label", ignore_label, [int])
        self.clear()

    def clear(self):
        """Clears the internal evaluation result."""
        self._sum_metric = 0.0
        self._num_inst = 0

    def update(self, *inputs):
        """
        Updates the internal evaluation result with :math:`preds` and :math:`labels`.

        Args:
            inputs: Input `preds` and `labels`. `preds` and `labels` are a Tensor, list or numpy.ndarray.
                `preds` is the predicted values, `labels` is the label of the data.
                The shape of `preds` is :math:`(N, C)` and the shape of `labels` is :math:`(N,)`.

        Raises:
            ValueError: If the number of the inputs is not 2.
            RuntimeError: If the lengths or shapes of `preds` and `labels` do not match.
        """
        if len(inputs) != 2:
            raise ValueError('Perplexity needs 2 inputs (preds, labels), but got {}.'.format(len(inputs)))

        preds = [self._convert_data(inputs[0])]
        labels = [self._convert_data(inputs[1])]

        if len(preds) != len(labels):
            raise RuntimeError('preds and labels should have the same length, but the length of preds is {}, '
                               'the length of labels is {}.'.format(len(preds), len(labels)))

        loss = 0.
        num = 0
        for label, pred in zip(labels, preds):
            if label.size != pred.size / pred.shape[-1]:
                raise RuntimeError("shape mismatch: the number of labels should equal the number of predictions, "
                                   "but got label shape {} and pred shape {}.".format(label.shape, pred.shape))
            # Gather the predicted probability assigned to each true label.
            label = label.reshape((label.size,))
            label_expand = label.astype(int)
            label_expand = np.expand_dims(label_expand, axis=1)
            first_indices = np.arange(label_expand.shape[0])[:, None]
            pred = np.squeeze(pred[first_indices, label_expand])
            if self.ignore_label is not None:
                # Ignored entries contribute probability 1 (zero loss) and are
                # removed from the sample count.
                ignore = (label == self.ignore_label).astype(pred.dtype)
                num -= np.sum(ignore)
                pred = pred * (1 - ignore) + ignore
            # Accumulate the negative log-likelihood; clip to avoid log(0).
            loss -= np.sum(np.log(np.maximum(1e-10, pred)))
            num += pred.size
        self._sum_metric += loss
        self._num_inst += num

    def eval(self):
        r"""
        Returns the current evaluation result.

        Returns:
            float, the computed result.

        Raises:
            RuntimeError: If the sample size is 0.
        """
        if self._num_inst == 0:
            raise RuntimeError('Perplexity cannot be calculated, because the number of samples is 0.')

        return math.exp(self._sum_metric / self._num_inst)
diff --git a/tests/ut/python/metrics/test_perplexity.py b/tests/ut/python/metrics/test_perplexity.py
new file mode 100644
index 0000000000..605531b1bc
--- /dev/null
+++ b/tests/ut/python/metrics/test_perplexity.py
@@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test_perplexity"""

import math
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.nn.metrics import get_metric_fn, Perplexity


def test_perplexity():
    """test_perplexity"""
    x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
    y = Tensor(np.array([1, 0, 1]))
    metric = get_metric_fn('perplexity')
    metric.clear()
    metric.update(x, y)
    perplexity = metric.eval()

    assert math.isclose(perplexity, 2.231443166940565, abs_tol=0.001)


def test_perplexity_update1():
    """update() raises ValueError when given only one input."""
    x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
    metric = Perplexity()
    metric.clear()

    with pytest.raises(ValueError):
        metric.update(x)


def test_perplexity_update2():
    """update() raises RuntimeError when preds and labels shapes mismatch."""
    x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
    y = Tensor(np.array([1, 0]))
    metric = Perplexity()
    metric.clear()

    with pytest.raises(RuntimeError):
        metric.update(x, y)


def test_perplexity_init():
    """Perplexity raises TypeError when ignore_label is not an int."""
    with pytest.raises(TypeError):
        Perplexity(ignore_label='abc')


def test_perplexity_runtime():
    """eval() raises RuntimeError when no samples have been accumulated."""
    metric = Perplexity()
    metric.clear()

    with pytest.raises(RuntimeError):
        metric.eval()
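
-- 
A quick sanity check, outside the patch itself: the expected value 2.231443166940565 in the docstring and the unit test can be reproduced with plain NumPy by computing exp(mean negative log-likelihood) over the probabilities the model assigns to the true labels. The following is a minimal standalone sketch of the same gather-and-average logic used in Perplexity.update()/eval(), assuming no ignored labels:

    # Sanity-check sketch (not part of the patch): exp(mean NLL) with NumPy.
    import math
    import numpy as np

    preds = np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])
    labels = np.array([1, 0, 1])

    # Probability assigned to each true label: [0.5, 0.3, 0.6].
    picked = preds[np.arange(labels.size), labels]
    # Clip to avoid log(0), as the metric does with np.maximum(1e-10, pred).
    nll = -np.sum(np.log(np.maximum(1e-10, picked)))
    print(math.exp(nll / labels.size))  # 2.231443166940565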