You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_recompute.py 7.4 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.context as context
  18. from mindspore import Tensor
  19. from mindspore.common import dtype as mstype
  20. from mindspore.nn import Cell
  21. import mindspore.ops.operations as P
  22. #{cast} would be recompute and fused
  23. class Net1(Cell):
  24. def __init__(self):
  25. super(Net1, self).__init__()
  26. self.cast = P.Cast()
  27. self.sum = P.ReduceSum(keep_dims=False)
  28. def construct(self, x):
  29. cast_res = self.cast(x, mstype.float32)
  30. sum1_res = self.sum(cast_res, (0,))
  31. sum2_res = self.sum(cast_res, (1,))
  32. return sum1_res, sum2_res
  33. #{sqrt} would be recompute on Ascend
  34. class Net2(Cell):
  35. def __init__(self):
  36. super(Net2, self).__init__()
  37. self.sqrt = P.Sqrt()
  38. self.sum = P.ReduceSum(keep_dims=True)
  39. self.add = P.Add()
  40. self.neg = P.Neg()
  41. def construct(self, x0, x1):
  42. sqrt_res = self.sqrt(x0)
  43. neg_res = self.neg(sqrt_res)
  44. add_res = self.add(x1, sqrt_res)
  45. sum_res = self.sum(add_res, (0,))
  46. return neg_res, sum_res
  47. #{sqrt} would be recompute
  48. class Net3(Cell):
  49. def __init__(self):
  50. super(Net3, self).__init__()
  51. self.sqrt = P.Sqrt()
  52. self.add = P.Add()
  53. self.neg = P.Neg()
  54. def construct(self, x0, x1):
  55. sqrt_res = self.sqrt(x0)
  56. neg_res = self.neg(sqrt_res)
  57. add_res = self.add(x1, sqrt_res)
  58. return neg_res, add_res
  59. #{sqrt neg} would be recompute
  60. class Net4(Cell):
  61. def __init__(self):
  62. super(Net4, self).__init__()
  63. self.sqrt = P.Sqrt()
  64. self.neg = P.Neg()
  65. self.sum = P.ReduceSum(keep_dims=False)
  66. def construct(self, x):
  67. sqrt_res = self.sqrt(x)
  68. neg_res = self.neg(sqrt_res)
  69. sum1_res = self.sum(neg_res, (0,))
  70. sum2_res = self.sum(neg_res, (1,))
  71. return sum1_res, sum2_res
  72. #{sqrt} would be recompute
  73. class Net5(Cell):
  74. def __init__(self):
  75. super(Net5, self).__init__()
  76. self.sqrt = P.Sqrt()
  77. self.add = P.Add()
  78. def construct(self, x0, x1, x2):
  79. sqrt_res = self.sqrt(x0)
  80. add1_res = self.add(sqrt_res, x1)
  81. add2_res = self.add(sqrt_res, x2)
  82. return add1_res, add2_res
  83. def test_basic1(net):
  84. def get_output(i0, net, enable_graph_kernel=False):
  85. context.set_context(enable_graph_kernel=enable_graph_kernel)
  86. net_obj = net()
  87. output = net_obj(i0)
  88. return output
  89. i0 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16))
  90. expect = get_output(i0, net, False)
  91. output = get_output(i0, net, True)
  92. expect0_np = expect[0].asnumpy().copy()
  93. output0_np = output[0].asnumpy().copy()
  94. expect1_np = expect[1].asnumpy().copy()
  95. output1_np = output[1].asnumpy().copy()
  96. assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3)
  97. assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3)
  98. def test_basic2(net):
  99. def get_output(i0, i1, net, enable_graph_kernel=False):
  100. context.set_context(enable_graph_kernel=enable_graph_kernel)
  101. net_obj = net()
  102. output = net_obj(i0, i1)
  103. return output
  104. i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float32))
  105. i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float32))
  106. expect = get_output(i0, i1, net, False)
  107. output = get_output(i0, i1, net, True)
  108. expect0_np = expect[0].asnumpy().copy()
  109. output0_np = output[0].asnumpy().copy()
  110. expect1_np = expect[1].asnumpy().copy()
  111. output1_np = output[1].asnumpy().copy()
  112. assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3)
  113. assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3)
  114. def test_basic3(net):
  115. def get_output(i0, i1, i2, net, enable_graph_kernel=False):
  116. context.set_context(enable_graph_kernel=enable_graph_kernel)
  117. net_obj = net()
  118. output = net_obj(i0, i1, i2)
  119. return output
  120. i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float16))
  121. i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16))
  122. i2 = Tensor(np.random.uniform(1, 2, [2048, 1024]).astype(np.float16))
  123. expect = get_output(i0, i1, i2, net, False)
  124. output = get_output(i0, i1, i2, net, True)
  125. expect0_np = expect[0].asnumpy().copy()
  126. output0_np = output[0].asnumpy().copy()
  127. expect1_np = expect[1].asnumpy().copy()
  128. output1_np = output[1].asnumpy().copy()
  129. assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3)
  130. assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3)
  131. @pytest.mark.level0
  132. @pytest.mark.platform_x86_gpu_training
  133. @pytest.mark.env_onecard
  134. def test_gpu_1():
  135. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  136. test_basic1(Net1)
  137. @pytest.mark.level0
  138. @pytest.mark.platform_x86_gpu_training
  139. @pytest.mark.env_onecard
  140. def test_gpu_2():
  141. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  142. test_basic2(Net2)
  143. @pytest.mark.level0
  144. @pytest.mark.platform_x86_gpu_training
  145. @pytest.mark.env_onecard
  146. def test_gpu_3():
  147. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  148. test_basic2(Net3)
  149. @pytest.mark.level0
  150. @pytest.mark.platform_x86_gpu_training
  151. @pytest.mark.env_onecard
  152. def test_gpu_4():
  153. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  154. test_basic1(Net4)
  155. @pytest.mark.level0
  156. @pytest.mark.platform_x86_gpu_training
  157. @pytest.mark.env_onecard
  158. def test_gpu_5():
  159. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  160. test_basic3(Net5)
  161. @pytest.mark.level0
  162. @pytest.mark.platform_arm_ascend_training
  163. @pytest.mark.platform_x86_ascend_training
  164. @pytest.mark.env_onecard
  165. def test_ascend_1():
  166. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  167. test_basic1(Net1)
  168. @pytest.mark.level0
  169. @pytest.mark.platform_arm_ascend_training
  170. @pytest.mark.platform_x86_ascend_training
  171. @pytest.mark.env_onecard
  172. def test_ascend_2():
  173. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  174. test_basic2(Net2)
  175. @pytest.mark.level0
  176. @pytest.mark.platform_arm_ascend_training
  177. @pytest.mark.platform_x86_ascend_training
  178. @pytest.mark.env_onecard
  179. def test_ascend_3():
  180. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  181. test_basic2(Net3)
  182. @pytest.mark.level0
  183. @pytest.mark.platform_arm_ascend_training
  184. @pytest.mark.platform_x86_ascend_training
  185. @pytest.mark.env_onecard
  186. def test_ascend_4():
  187. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  188. test_basic1(Net4)
  189. @pytest.mark.level0
  190. @pytest.mark.platform_arm_ascend_training
  191. @pytest.mark.platform_x86_ascend_training
  192. @pytest.mark.env_onecard
  193. def test_ascend_5():
  194. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  195. test_basic3(Net5)