You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_topk.py 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """test topk"""
  16. import math
  17. import numpy as np
  18. import pytest
  19. from mindspore import Tensor
  20. from mindspore.nn.metrics import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy
  21. def test_type_topk():
  22. with pytest.raises(TypeError):
  23. TopKCategoricalAccuracy(2.1)
  24. def test_value_topk():
  25. with pytest.raises(ValueError):
  26. TopKCategoricalAccuracy(-1)
  27. def test_input_topk():
  28. x = Tensor(np.array([[0.2, 0.5, 0.3, 0.6, 0.2],
  29. [0.3, 0.1, 0.5, 0.1, 0.],
  30. [0.9, 0.6, 0.2, 0.01, 0.3]]))
  31. topk = TopKCategoricalAccuracy(3)
  32. topk.clear()
  33. with pytest.raises(ValueError):
  34. topk.update(x)
  35. def test_topk():
  36. """test_topk"""
  37. x = Tensor(np.array([[0.2, 0.5, 0.3, 0.6, 0.2],
  38. [0.1, 0.35, 0.5, 0.2, 0.],
  39. [0.9, 0.6, 0.2, 0.01, 0.3]]))
  40. y = Tensor(np.array([2, 0, 1]))
  41. y2 = Tensor(np.array([[0, 0, 1, 0, 0],
  42. [1, 0, 0, 0, 0],
  43. [0, 1, 0, 0, 0]]))
  44. topk = TopKCategoricalAccuracy(3)
  45. topk.clear()
  46. topk.update(x, y)
  47. result = topk.eval()
  48. result2 = topk(x, y2)
  49. assert math.isclose(result, 2 / 3)
  50. assert math.isclose(result2, 2 / 3)
  51. def test_zero_topk():
  52. topk = TopKCategoricalAccuracy(3)
  53. topk.clear()
  54. with pytest.raises(RuntimeError):
  55. topk.eval()
  56. def test_top1():
  57. """test_top1"""
  58. x = Tensor(np.array([[0.2, 0.5, 0.2, 0.1, 0.],
  59. [0.1, 0.35, 0.25, 0.2, 0.1],
  60. [0.9, 0.1, 0, 0., 0]]))
  61. y = Tensor(np.array([2, 0, 0]))
  62. y2 = Tensor(np.array([[0, 0, 1, 0, 0],
  63. [1, 0, 0, 0, 0],
  64. [1, 0, 0, 0, 0]]))
  65. topk = Top1CategoricalAccuracy()
  66. topk.clear()
  67. topk.update(x, y)
  68. result = topk.eval()
  69. result2 = topk(x, y2)
  70. assert math.isclose(result, 1 / 3)
  71. assert math.isclose(result2, 1 / 3)
  72. def test_top5():
  73. """test_top5"""
  74. x = Tensor(np.array([[0.15, 0.4, 0.1, 0.05, 0., 0.2, 0.1],
  75. [0.1, 0.35, 0.25, 0.2, 0.1, 0., 0.],
  76. [0., 0.5, 0.2, 0.1, 0.1, 0.1, 0.]]))
  77. y = Tensor(np.array([2, 0, 0]))
  78. y2 = Tensor(np.array([[0, 0, 1, 0, 0],
  79. [1, 0, 0, 0, 0],
  80. [1, 0, 0, 0, 0]]))
  81. topk = Top5CategoricalAccuracy()
  82. topk.clear()
  83. topk.update(x, y)
  84. result = topk.eval()
  85. result2 = topk(x, y2)
  86. assert math.isclose(result, 2 / 3)
  87. assert math.isclose(result2, 2 / 3)