You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_comare.py 8.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.context as context
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore.ops import operations as P
  21. class LessNet(nn.Cell):
  22. def __init__(self):
  23. super(LessNet, self).__init__()
  24. self.ops = P.Less()
  25. def construct(self, x, y):
  26. return self.ops(x, y)
  27. class GreaterNet(nn.Cell):
  28. def __init__(self):
  29. super(GreaterNet, self).__init__()
  30. self.ops = P.Greater()
  31. def construct(self, x, y):
  32. return self.ops(x, y)
  33. class LessEqualNet(nn.Cell):
  34. def __init__(self):
  35. super(LessEqualNet, self).__init__()
  36. self.ops = P.LessEqual()
  37. def construct(self, x, y):
  38. return self.ops(x, y)
  39. class GreaterEqualNet(nn.Cell):
  40. def __init__(self):
  41. super(GreaterEqualNet, self).__init__()
  42. self.ops = P.GreaterEqual()
  43. def construct(self, x, y):
  44. return self.ops(x, y)
  45. def gen_data():
  46. # Generate data which contains broadcast scene and two inputs are expr.
  47. np.random.seed(0)
  48. x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
  49. y0_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float32)
  50. x1_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(np.float16)
  51. y1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16)
  52. x2_np = np.random.randint(1, 5, 1).astype(np.int32)
  53. y2_np = np.random.randint(1, 5, 1).astype(np.int32)
  54. x3_np = np.array(768).astype(np.float32)
  55. y3_np = np.array(3072.5).astype(np.float32)
  56. x0 = Tensor(x0_np)
  57. y0 = Tensor(y0_np)
  58. x1 = Tensor(x1_np)
  59. y1 = Tensor(y1_np)
  60. x2 = Tensor(x2_np)
  61. y2 = Tensor(y2_np)
  62. x3 = Tensor(x3_np)
  63. y3 = Tensor(y3_np)
  64. return x0, y0, x1, y1, x2, y2, x3, y3
  65. def get_less_net_output(x0, y0, x1, y1, x2, y2, x3, y3, enable_graph_kernel=False):
  66. context.set_context(enable_graph_kernel=enable_graph_kernel)
  67. net_less = LessNet()
  68. less_output_0 = net_less(x0, y0).asnumpy()
  69. less_output_1 = net_less(x1, y1).asnumpy()
  70. less_output_2 = net_less(x2, y2).asnumpy()
  71. less_output_3 = net_less(x3, y3).asnumpy()
  72. return less_output_0, less_output_1, less_output_2, less_output_3
  73. def get_greater_net_output(x0, y0, x1, y1, x2, y2, x3, y3, enable_graph_kernel=False):
  74. context.set_context(enable_graph_kernel=enable_graph_kernel)
  75. net_greater = GreaterNet()
  76. greater_output_0 = net_greater(x0, y0).asnumpy()
  77. greater_output_1 = net_greater(x1, y1).asnumpy()
  78. greater_output_2 = net_greater(x2, y2).asnumpy()
  79. greater_output_3 = net_greater(x3, y3).asnumpy()
  80. return greater_output_0, greater_output_1, greater_output_2, greater_output_3
  81. def get_less_equal_net_output(x0, y0, x1, y1, x2, y2, x3, y3, enable_graph_kernel=False):
  82. context.set_context(enable_graph_kernel=enable_graph_kernel)
  83. net_less_equal = LessEqualNet()
  84. less_equal_output_0 = net_less_equal(x0, y0).asnumpy()
  85. less_equal_output_1 = net_less_equal(x1, y1).asnumpy()
  86. less_equal_output_2 = net_less_equal(x2, y2).asnumpy()
  87. less_equal_output_3 = net_less_equal(x3, y3).asnumpy()
  88. return less_equal_output_0, less_equal_output_1, less_equal_output_2, less_equal_output_3
  89. def get_greater_equal_net_output(x0, y0, x1, y1, x2, y2, x3, y3, enable_graph_kernel=False):
  90. context.set_context(enable_graph_kernel=enable_graph_kernel)
  91. net_greater_equal = GreaterEqualNet()
  92. greter_equal_output_0 = net_greater_equal(x0, y0).asnumpy()
  93. greter_equal_output_1 = net_greater_equal(x1, y1).asnumpy()
  94. greter_equal_output_2 = net_greater_equal(x2, y2).asnumpy()
  95. greter_equal_output_3 = net_greater_equal(x3, y3).asnumpy()
  96. return greter_equal_output_0, greter_equal_output_1, greter_equal_output_2, greter_equal_output_3
  97. def test_less_net():
  98. x0, y0, x1, y1, x2, y2, x3, y3 = gen_data()
  99. out_gk_on_0, out_gk_on_1, out_gk_on_2, out_gk_on_3 = get_less_net_output(x0, y0, x1, y1, x2, y2, x3, y3, True)
  100. out_gk_off_0, out_gk_off_1, out_gk_off_2, out_gk_off_3 = get_less_net_output(
  101. x0, y0, x1, y1, x2, y2, x3, y3, False)
  102. assert np.all(out_gk_on_0 == out_gk_off_0)
  103. assert out_gk_on_0.shape == out_gk_off_0.shape
  104. assert np.all(out_gk_on_1 == out_gk_off_1)
  105. assert out_gk_on_1.shape == out_gk_off_1.shape
  106. assert np.all(out_gk_on_2 == out_gk_off_2)
  107. assert out_gk_on_2.shape == out_gk_off_2.shape
  108. assert np.all(out_gk_on_3 == out_gk_off_3)
  109. assert out_gk_on_3.shape == out_gk_off_3.shape
  110. def test_greater_net():
  111. x0, y0, x1, y1, x2, y2, x3, y3 = gen_data()
  112. out_gk_on_0, out_gk_on_1, out_gk_on_2, out_gk_on_3 = get_greater_net_output(x0, y0, x1, y1, x2, y2, x3, y3, True)
  113. out_gk_off_0, out_gk_off_1, out_gk_off_2, out_gk_off_3 = get_greater_net_output(
  114. x0, y0, x1, y1, x2, y2, x3, y3, False)
  115. assert np.all(out_gk_on_0 == out_gk_off_0)
  116. assert out_gk_on_0.shape == out_gk_off_0.shape
  117. assert np.all(out_gk_on_1 == out_gk_off_1)
  118. assert out_gk_on_1.shape == out_gk_off_1.shape
  119. assert np.all(out_gk_on_2 == out_gk_off_2)
  120. assert out_gk_on_2.shape == out_gk_off_2.shape
  121. assert np.all(out_gk_on_3 == out_gk_off_3)
  122. assert out_gk_on_3.shape == out_gk_off_3.shape
  123. def test_less_equal_net():
  124. x0, y0, x1, y1, x2, y2, x3, y3 = gen_data()
  125. out_gk_on_0, out_gk_on_1, out_gk_on_2, out_gk_on_3 = get_less_equal_net_output(
  126. x0, y0, x1, y1, x2, y2, x3, y3, True)
  127. out_gk_off_0, out_gk_off_1, out_gk_off_2, out_gk_off_3 = get_less_equal_net_output(
  128. x0, y0, x1, y1, x2, y2, x3, y3, False)
  129. assert np.all(out_gk_on_0 == out_gk_off_0)
  130. assert out_gk_on_0.shape == out_gk_off_0.shape
  131. assert np.all(out_gk_on_1 == out_gk_off_1)
  132. assert out_gk_on_1.shape == out_gk_off_1.shape
  133. assert np.all(out_gk_on_2 == out_gk_off_2)
  134. assert out_gk_on_2.shape == out_gk_off_2.shape
  135. assert np.all(out_gk_on_3 == out_gk_off_3)
  136. assert out_gk_on_3.shape == out_gk_off_3.shape
  137. def test_greater_equal_net():
  138. x0, y0, x1, y1, x2, y2, x3, y3 = gen_data()
  139. out_gk_on_0, out_gk_on_1, out_gk_on_2, out_gk_on_3 = get_greater_equal_net_output(
  140. x0, y0, x1, y1, x2, y2, x3, y3, True)
  141. out_gk_off_0, out_gk_off_1, out_gk_off_2, out_gk_off_3 = get_greater_equal_net_output(
  142. x0, y0, x1, y1, x2, y2, x3, y3, False)
  143. assert np.all(out_gk_on_0 == out_gk_off_0)
  144. assert out_gk_on_0.shape == out_gk_off_0.shape
  145. assert np.all(out_gk_on_1 == out_gk_off_1)
  146. assert out_gk_on_1.shape == out_gk_off_1.shape
  147. assert np.all(out_gk_on_2 == out_gk_off_2)
  148. assert out_gk_on_2.shape == out_gk_off_2.shape
  149. assert np.all(out_gk_on_3 == out_gk_off_3)
  150. assert out_gk_on_3.shape == out_gk_off_3.shape
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_less_gpu():
    """GPU entry point for the ``Less`` graph-kernel on/off consistency check.

    NOTE(review): the level0 / platform_x86_gpu_training / env_onecard marks are
    presumably used by CI to select single-card x86 GPU runs -- confirm.
    """
    # Put the global context into graph mode on GPU before running the shared check.
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    test_less_net()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_greater_gpu():
    """GPU entry point for the ``Greater`` graph-kernel on/off consistency check.

    NOTE(review): the level0 / platform_x86_gpu_training / env_onecard marks are
    presumably used by CI to select single-card x86 GPU runs -- confirm.
    """
    # Put the global context into graph mode on GPU before running the shared check.
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    test_greater_net()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_less_equal_gpu():
    """GPU entry point for the ``LessEqual`` graph-kernel on/off consistency check.

    NOTE(review): the level0 / platform_x86_gpu_training / env_onecard marks are
    presumably used by CI to select single-card x86 GPU runs -- confirm.
    """
    # Put the global context into graph mode on GPU before running the shared check.
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    test_less_equal_net()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_greater_equal_gpu():
    """GPU entry point for the ``GreaterEqual`` graph-kernel on/off consistency check.

    NOTE(review): the level0 / platform_x86_gpu_training / env_onecard marks are
    presumably used by CI to select single-card x86 GPU runs -- confirm.
    """
    # Put the global context into graph mode on GPU before running the shared check.
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    test_greater_equal_net()