You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_argminwithvalue_op.py 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.context as context
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore.ops import operations as P
  # Configure MindSpore once for the whole module: GRAPH_MODE on the CPU backend.
  context.set_context(device_target="CPU", mode=context.GRAPH_MODE)
  class NetArgminWithValue(nn.Cell):
      """Minimal Cell wrapping the ``P.ArgMinWithValue`` primitive for testing.

      Args:
          axis (int): Axis along which the minimum is searched. Default: 0.
          keep_dims (bool): Forwarded to the primitive; when True the reduced
              axis is kept with length 1 (see the expectations in the tests
              below). Default: False.
      """

      def __init__(self, axis=0, keep_dims=False):
          super().__init__()
          self.argmin = P.ArgMinWithValue(keep_dims=keep_dims, axis=axis)

      # Returns whatever the primitive yields; the tests unpack it as an
      # (index, value) pair. Kept comment-only because this method's source
      # is parsed by the graph compiler.
      def construct(self, x):
          return self.argmin(x)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_argminwithvalue_fp32():
      """Check ArgMinWithValue on a float32 matrix over both axes, with and without keep_dims.

      Each case compares the index output and the value output against
      hand-computed expectations. Shapes are asserted explicitly because a bare
      ``==`` between arrays broadcasts: an output of shape (3,) would otherwise
      "match" an expected (1, 3), hiding a broken keep_dims.
      """
      x = np.array([[1., 20., 5.],
                    [67., 8., 9.],
                    [130., 24., 15.],
                    [-0.5, 25, 100]]).astype(np.float32)
      # (axis, keep_dims, expected indices, expected minimum values)
      cases = [
          (0, False, [3, 1, 0], [-0.5, 8., 5.]),
          (0, True, [[3, 1, 0]], [[-0.5, 8., 5.]]),
          (1, False, [0, 1, 2, 0], [1., 8., 15., -0.5]),
          (-1, True, [[0], [1], [2], [0]], [[1.], [8.], [15.], [-0.5]]),
      ]
      for axis, keep_dims, indices, values in cases:
          net = NetArgminWithValue(axis=axis, keep_dims=keep_dims)
          output0, output1 = net(Tensor(x))
          expect0 = np.array(indices).astype(np.int32)
          expect1 = np.array(values).astype(np.float32)
          out0 = output0.asnumpy()
          out1 = output1.asnumpy()
          # Shape checks first: they catch keep_dims mistakes that
          # broadcasting in the element-wise comparisons would mask.
          assert out0.shape == expect0.shape
          assert out1.shape == expect1.shape
          assert np.all(out0 == expect0)
          assert np.all(np.abs(out1 - expect1) < 1.0e-6)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_argminwithvalue_fp16():
      """Check ArgMinWithValue on a float16 matrix over both axes, with and without keep_dims.

      All expected minima (-0.5, 1, 5, 8, 15) are exactly representable in
      float16, so the tight 1e-6 tolerance is meaningful. Shapes are asserted
      explicitly because a bare ``==`` between arrays broadcasts and would
      otherwise hide a broken keep_dims.
      """
      x = np.array([[1., 20., 5.],
                    [67., 8., 9.],
                    [130., 24., 15.],
                    [-0.5, 25, 100]]).astype(np.float16)
      # (axis, keep_dims, expected indices, expected minimum values)
      cases = [
          (0, False, [3, 1, 0], [-0.5, 8., 5.]),
          (0, True, [[3, 1, 0]], [[-0.5, 8., 5.]]),
          (1, False, [0, 1, 2, 0], [1., 8., 15., -0.5]),
          (-1, True, [[0], [1], [2], [0]], [[1.], [8.], [15.], [-0.5]]),
      ]
      for axis, keep_dims, indices, values in cases:
          net = NetArgminWithValue(axis=axis, keep_dims=keep_dims)
          output0, output1 = net(Tensor(x))
          expect0 = np.array(indices).astype(np.int32)
          expect1 = np.array(values).astype(np.float16)
          out0 = output0.asnumpy()
          out1 = output1.asnumpy()
          # Shape checks first: they catch keep_dims mistakes that
          # broadcasting in the element-wise comparisons would mask.
          assert out0.shape == expect0.shape
          assert out1.shape == expect1.shape
          assert np.all(out0 == expect0)
          assert np.all(np.abs(out1 - expect1) < 1.0e-6)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_argminwithvalue_tensor():
      """Compare ArgMinWithValue along axis -2 of a random 4-D float16 tensor with NumPy.

      The original version drew the data and the sign of the scale factor
      from the unseeded global RNG. Each run therefore covered only one sign,
      and a failure could not be reproduced. A seeded local generator is used
      instead, and both signs are exercised on every run.
      """
      rng = np.random.RandomState(0)
      for prop in (100, -100):
          x = rng.randn(3, 4, 5, 6).astype(np.float16) * prop
          net = NetArgminWithValue(axis=-2, keep_dims=False)
          output0, output1 = net(Tensor(x))
          expect0 = np.argmin(x, axis=-2)
          expect1 = np.min(x, axis=-2).astype(np.float16)
          out0 = output0.asnumpy()
          out1 = output1.asnumpy()
          # Explicit shape checks: element-wise == would broadcast mismatches.
          assert out0.shape == expect0.shape
          assert out1.shape == expect1.shape
          assert np.all(out0 == expect0)
          assert np.all(np.abs(out1 - expect1) < 1.0e-6)