You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_minimum_op.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import pytest
  16. from mindspore.ops import operations as P
  17. from mindspore.ops import composite as C
  18. from mindspore.nn import Cell
  19. from mindspore.common.tensor import Tensor
  20. import mindspore.common.dtype as mstype
  21. import mindspore.context as context
  22. import numpy as np
  23. class MinimumNet(Cell):
  24. def __init__(self):
  25. super(MinimumNet, self).__init__()
  26. self.min = P.Minimum()
  27. def construct(self, x1, x2):
  28. x = self.min(x1, x2)
  29. return x
  30. class Grad(Cell):
  31. def __init__(self, network):
  32. super(Grad, self).__init__()
  33. self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
  34. self.network = network
  35. def construct(self, x1, x2, sens):
  36. gout = self.grad(self.network)(x1, x2, sens)
  37. return gout
  38. @pytest.mark.level0
  39. @pytest.mark.platform_x86_gpu_training
  40. @pytest.mark.env_onecard
  41. def test_nobroadcast():
  42. context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target='GPU')
  43. x1_np = np.random.rand(3, 4).astype(np.float32)
  44. x2_np = np.random.rand(3, 4).astype(np.float32)
  45. dy_np = np.random.rand(3, 4).astype(np.float32)
  46. net = Grad(MinimumNet())
  47. output_ms = net(Tensor(x1_np), Tensor(x2_np), Tensor(dy_np))
  48. output0_np = np.where(x1_np < x2_np, dy_np, 0)
  49. output1_np = np.where(x1_np < x2_np, 0, dy_np)
  50. assert np.allclose(output_ms[0].asnumpy(), output0_np)
  51. assert np.allclose(output_ms[1].asnumpy(), output1_np)
  52. @pytest.mark.level0
  53. @pytest.mark.platform_x86_gpu_training
  54. @pytest.mark.env_onecard
  55. def test_broadcast():
  56. context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target='GPU')
  57. x1_np = np.array([[[[0.659578 ],
  58. [0.49113268],
  59. [0.75909054],
  60. [0.71681815],
  61. [0.30421826]]],
  62. [[[0.30322495],
  63. [0.02858258],
  64. [0.06398096],
  65. [0.09519596],
  66. [0.12498625]]],
  67. [[[0.7347768 ],
  68. [0.166469 ],
  69. [0.328553 ],
  70. [0.54908437],
  71. [0.23673844]]]]).astype(np.float32)
  72. x2_np = np.array([[[[0.9154968, 0.29014662, 0.6492294, 0.39918253, 0.1648203, 0.00861965]],
  73. [[0.996885, 0.24152198, 0.3601213, 0.51664376, 0.7933056, 0.84706444]],
  74. [[0.75606346, 0.974512, 0.3939527, 0.69697475, 0.83400667, 0.6348955 ]],
  75. [[0.68492866, 0.24609096, 0.4924665, 0.22500521, 0.38474053, 0.5586104 ]]]]).astype(np.float32)
  76. dy_np = np.array([[[[0.42891738, 0.03434946, 0.06192983, 0.21216309, 0.37450036, 0.6619524 ],
  77. [0.8583447, 0.5765161, 0.1468952, 0.9975385, 0.6908136, 0.4903796 ],
  78. [0.68952006, 0.39336833, 0.9049695, 0.66886294, 0.2338471, 0.913618 ],
  79. [0.0428149, 0.6243054, 0.8519898, 0.12088962, 0.9735885, 0.45661286],
  80. [0.41563734, 0.41607043, 0.4754915, 0.32207987, 0.33823156, 0.47422352]],
  81. [[0.64478457, 0.22430937, 0.7682554, 0.46082005, 0.8938723, 0.20490853],
  82. [0.44393885, 0.08278944, 0.4734108, 0.5543551, 0.39428464, 0.44424313],
  83. [0.12612297, 0.76566416, 0.71133816, 0.81280327, 0.20583127, 0.54058075],
  84. [0.41341263, 0.48118508, 0.00401995, 0.37259838, 0.05435474, 0.5240658 ],
  85. [0.4081956, 0.48718935, 0.9132831, 0.67969185, 0.0119757, 0.8328054 ]],
  86. [[0.91695577, 0.95370644, 0.263782, 0.7477626, 0.6448147, 0.8080634 ],
  87. [0.15576603, 0.9104615, 0.3778708, 0.6912833, 0.2092224, 0.67462957],
  88. [0.7087075, 0.7888326, 0.4672294, 0.98221505, 0.25210258, 0.98920417],
  89. [0.7466197, 0.22702982, 0.01991269, 0.6846591, 0.7515228, 0.5890395 ],
  90. [0.04531088, 0.21740614, 0.8406235, 0.36480767, 0.37733936, 0.02914464]],
  91. [[0.33069974, 0.5497569, 0.9896345, 0.4167176, 0.78057563, 0.04659131],
  92. [0.7747768, 0.21427679, 0.29893255, 0.7706969, 0.9755185, 0.42388415],
  93. [0.3910244, 0.39381978, 0.37065396, 0.15558061, 0.05012341, 0.15870963],
  94. [0.17791101, 0.47219893, 0.13899496, 0.32323205, 0.3628809, 0.02580585],
  95. [0.30274773, 0.62890774, 0.11024303, 0.6980051, 0.35346958, 0.062852 ]]],
  96. [[[0.6925081, 0.74668753, 0.80145043, 0.06598313, 0.665123, 0.15073007],
  97. [0.11784806, 0.6385372, 0.5228278, 0.5349848, 0.84671104, 0.8096436 ],
  98. [0.09516156, 0.63298017, 0.52382874, 0.36734378, 0.66497755, 0.6019127 ],
  99. [0.46438488, 0.0194377, 0.9388292, 0.7286089, 0.29178405, 0.11872514],
  100. [0.22101837, 0.6164887, 0.6139798, 0.11711904, 0.6227745, 0.09701069]],
  101. [[0.80480653, 0.90034056, 0.8633447, 0.97415197, 0.08309154, 0.8446033 ],
  102. [0.9473769, 0.791024, 0.26339203, 0.01155075, 0.2673186, 0.7116369 ],
  103. [0.9687511, 0.24281934, 0.37777108, 0.09802654, 0.2421312, 0.87095344],
  104. [0.6311381, 0.23368953, 0.0998995, 0.4364419, 0.9187446, 0.5043872 ],
  105. [0.35226053, 0.09357589, 0.41317305, 0.85930043, 0.16249318, 0.5478765 ]],
  106. [[0.14338651, 0.24859418, 0.4246941, 0.73034066, 0.47172204, 0.8717199 ],
  107. [0.05415315, 0.78556925, 0.99214983, 0.7415298, 0.673708, 0.87817156],
  108. [0.616975, 0.42843062, 0.05179814, 0.1566958, 0.04536059, 0.70166487],
  109. [0.15493333, 0.776598, 0.4361967, 0.40253627, 0.89210516, 0.8144414 ],
  110. [0.04816005, 0.29696834, 0.4586605, 0.3419852, 0.5595613, 0.74093205]],
  111. [[0.1388035, 0.9168704, 0.64287645, 0.83864623, 0.48026922, 0.78323376],
  112. [0.12724937, 0.83034366, 0.42557436, 0.50578654, 0.25630295, 0.15349793],
  113. [0.27256685, 0.04547984, 0.5385756, 0.39270344, 0.7661698, 0.23722854],
  114. [0.24620503, 0.25431684, 0.71564585, 0.01161419, 0.846467, 0.7043044 ],
  115. [0.63272387, 0.11857849, 0.3772076, 0.16758402, 0.46743023, 0.05919575]]],
  116. [[[0.18827082, 0.8912264, 0.6841404, 0.74436826, 0.9582085, 0.1083683 ],
  117. [0.60695344, 0.09742349, 0.25074378, 0.87940735, 0.21116392, 0.39418384],
  118. [0.744686, 0.35679692, 0.01308284, 0.45166633, 0.68166, 0.8634658 ],
  119. [0.7331758, 0.21113694, 0.3935488, 0.87934476, 0.70728546, 0.09309767],
  120. [0.12128611, 0.93696386, 0.81177396, 0.85402405, 0.5827289, 0.9776509 ]],
  121. [[0.54069614, 0.66651285, 0.10646132, 0.17342485, 0.88795924, 0.03551182],
  122. [0.25531697, 0.87946486, 0.74267226, 0.89230734, 0.95171434, 0.94697934],
  123. [0.3708397, 0.507355, 0.97099817, 0.4918163, 0.17212386, 0.5008048 ],
  124. [0.62530744, 0.25210327, 0.73966664, 0.71555346, 0.82484317, 0.6094874 ],
  125. [0.4589691, 0.1386695, 0.27448782, 0.20373994, 0.27805242, 0.23292768]],
  126. [[0.7414099, 0.2270226, 0.90431255, 0.47035843, 0.9581062, 0.5359226 ],
  127. [0.79603523, 0.45549425, 0.80858237, 0.7705133, 0.017761, 0.98001194],
  128. [0.06013146, 0.99240226, 0.33515573, 0.04110833, 0.41470334, 0.7130743 ],
  129. [0.5687417, 0.5788611, 0.00722461, 0.6603336, 0.3420471, 0.75181854],
  130. [0.4699261, 0.51390815, 0.343182, 0.81498754, 0.8942413, 0.46532857]],
  131. [[0.4589523, 0.5534698, 0.2825786, 0.8205943, 0.78258514, 0.43154418],
  132. [0.27020997, 0.01667354, 0.60871965, 0.90670526, 0.3208025, 0.96995634],
  133. [0.85337156, 0.9711295, 0.1381724, 0.53670496, 0.7347996, 0.73380876],
  134. [0.6137464, 0.54751194, 0.9037335, 0.23134394, 0.61411524, 0.26583543],
  135. [0.70770144, 0.01813207, 0.24718016, 0.70329237, 0.7062925, 0.14399007]]]]).astype(np.float32)
  136. expect_dx1 = np.array([[[[ 5.7664223],
  137. [ 6.981018 ],
  138. [ 2.6029902],
  139. [ 2.7598202],
  140. [ 6.763105 ]]],
  141. [[[10.06558 ],
  142. [12.077246 ],
  143. [ 9.338394 ],
  144. [11.52271 ],
  145. [ 8.889048 ]]],
  146. [[[ 3.5789769],
  147. [13.424448 ],
  148. [ 8.732746 ],
  149. [ 6.9677467],
  150. [ 9.635765 ]]]]).astype(np.float32)
  151. expect_dx2 = np.array([[[[0. , 4.250458 , 2.5030296 , 3.623167 , 6.4171505 , 7.2115746 ]],
  152. [[0. , 4.367449 , 2.803152 , 2.5352 , 0. , 0. ]],
  153. [[0.7087075 , 0. , 2.040332 , 2.1372325 , 0. , 2.9222295 ]],
  154. [[1.0278877 , 5.247942 , 2.6855955 , 5.494814 , 3.5657988 , 0.66265094]]]]).astype(np.float32)
  155. net = Grad(MinimumNet())
  156. output_ms = net(Tensor(x1_np), Tensor(x2_np), Tensor(dy_np))
  157. assert np.allclose(output_ms[0].asnumpy(), expect_dx1)
  158. assert np.allclose(output_ms[1].asnumpy(), expect_dx2)
  159. @pytest.mark.level0
  160. @pytest.mark.platform_x86_gpu_training
  161. @pytest.mark.env_onecard
  162. def test_broadcast_diff_dims():
  163. context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target='GPU')
  164. x1_np = np.array([[[0.275478, 0.48933202, 0.71846116],
  165. [0.9803821, 0.57205725, 0.28511533]],
  166. [[0.61111903, 0.9671023, 0.70624334],
  167. [0.53730786, 0.90413177, 0.94349676]]]).astype(np.float32)
  168. x2_np = np.array([[0.01045662, 0.82126397, 0.6365063 ],
  169. [0.9900942, 0.6584232, 0.98537433]]).astype(np.float32)
  170. dy_np = np.array([[[0.3897645, 0.61152864, 0.33675498],
  171. [0.5303635, 0.84893036, 0.4959739 ]],
  172. [[0.5391046, 0.8443047, 0.4174708 ],
  173. [0.57513475, 0.9225578, 0.46760973]]]).astype(np.float32)
  174. expect_dx1 = np.array([[[0. , 0.61152864, 0. ],
  175. [0.5303635 , 0.84893036, 0.4959739 ]],
  176. [[0. , 0. , 0. ],
  177. [0.57513475, 0. , 0.46760973]]]).astype(np.float32)
  178. expect_dx2 = np.array([[0.92886907, 0.8443047 , 0.7542258 ],
  179. [0. , 0.9225578 , 0. ]]).astype(np.float32)
  180. net = Grad(MinimumNet())
  181. output_ms = net(Tensor(x1_np), Tensor(x2_np), Tensor(dy_np))
  182. assert np.allclose(output_ms[0].asnumpy(), expect_dx1)
  183. assert np.allclose(output_ms[1].asnumpy(), expect_dx2)