You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_metric_factory.py 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """test_metric_factory"""
  16. import math
  17. import numpy as np
  18. from mindspore import Tensor
  19. from mindspore.nn.metrics import get_metric_fn
  20. from mindspore.nn.metrics.metric import rearrange_inputs
  21. def test_classification_accuracy():
  22. x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
  23. y = Tensor(np.array([1, 0, 1]))
  24. metric = get_metric_fn('accuracy', eval_type='classification')
  25. metric.clear()
  26. metric.update(x, y)
  27. accuracy = metric.eval()
  28. assert math.isclose(accuracy, 2 / 3)
  29. def test_classification_accuracy_by_alias():
  30. x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
  31. y = Tensor(np.array([1, 0, 1]))
  32. metric = get_metric_fn('acc', eval_type='classification')
  33. metric.clear()
  34. metric.update(x, y)
  35. accuracy = metric.eval()
  36. assert math.isclose(accuracy, 2 / 3)
  37. def test_classification_precision():
  38. x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
  39. y = Tensor(np.array([1, 0, 1]))
  40. metric = get_metric_fn('precision', eval_type='classification')
  41. metric.clear()
  42. metric.update(x, y)
  43. precision = metric.eval()
  44. assert np.equal(precision, np.array([0.5, 1])).all()
  45. class RearrangeInputsDemo:
  46. def __init__(self):
  47. self._indexes = None
  48. @property
  49. def indexes(self):
  50. return getattr(self, '_indexes', None)
  51. def set_indexes(self, indexes):
  52. self._indexes = indexes
  53. return self
  54. @rearrange_inputs
  55. def update(self, *inputs):
  56. return inputs
  57. def test_rearrange_inputs_without_arrange():
  58. mini_decorator = RearrangeInputsDemo()
  59. outs = mini_decorator.update(5, 9)
  60. assert outs == (5, 9)
  61. def test_rearrange_inputs_with_arrange():
  62. mini_decorator = RearrangeInputsDemo().set_indexes([1, 0])
  63. outs = mini_decorator.update(5, 9)
  64. assert outs == (9, 5)
  65. def test_rearrange_inputs_with_multi_inputs():
  66. mini_decorator = RearrangeInputsDemo().set_indexes([1, 3])
  67. outs = mini_decorator.update(0, 9, 0, 5)
  68. assert outs == (9, 5)