You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_multigraph_sink.py 4.7 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_multigraph_sink """
  16. import pytest
  17. import mindspore.context as context
  18. from mindspore.common import dtype as mstype
  19. from mindspore.common import ms_function
  20. from mindspore.common.tensor import Tensor
  21. def setup_module():
  22. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  23. c1 = Tensor([2], mstype.int32)
  24. c2 = Tensor([14], mstype.int32)
  25. c3 = Tensor([1], mstype.int32)
  26. c4 = Tensor([0], mstype.int32)
  27. c5 = Tensor([14], mstype.int32)
# Graph-mode test kernel: a single if/else followed by straight-line code.
# Returns x + 1 + 3 when x < y, otherwise x + 2 + 3.
# With x=2, y=14 the true branch is taken: 2 + 1 + 3 -> 6.
@ms_function
def simple_if(x, y):
    if x < y:
        x = x + 1
    else:
        x = x + 2
    # Runs after either branch, so both paths merge here.
    x = x + 3
    return x
# Graph-mode test kernel: two sequential (not nested) if statements with no
# else branches. The second condition sees the x produced by the first.
# With x=2, y=14: 2 < 14 -> x=3; 14 > 3 -> x=5; then +3 -> 8.
@ms_function
def if_by_if(x, y):
    if x < y:
        x = x + 1
    if y > x:
        x = x + 2
    x = x + 3
    return x
# Graph-mode test kernel: an if nested inside another if, accumulating into
# `out`, which starts from the module-level zero tensor c4.
# NOTE(review): the incoming `z` argument is never read. It is overwritten
# with c4 + c4 (zero) before its first use.
# With x=2, y=14: z=0 -> z=2, out=2; x=5; out=2+5 -> 7.
@ms_function
def if_in_if(x, y, z):
    out = c4
    if x < y:
        z = c4 + c4
        if z < y:
            z = z + 2
            out = out + z
        x = x + 3
    # Executed on every path, so the final x is always folded into out.
    out = out + x
    return out
# Graph-mode test kernel: one while loop whose bound (y + 4) is computed
# inside the graph before the loop starts.
# With x=2, y=14: the bound is 18, the loop exits at x=18, then +3 -> 21.
@ms_function
def simple_while(x, y):
    y = y + 4
    while x < y:
        x = x + 1
    x = x + 3
    return x
# Graph-mode test kernel: two sequential while loops. The second is bounded
# by the module-level constant c5 (14) and also advances x on each pass.
# With x=2, y=14, z=1: the first loop leaves x=14; the second runs 13 times
# (z goes 1 -> 14), giving x=27; the final +1 -> 28.
@ms_function
def while_by_while(x, y, z):
    while x < y:
        x = x + 1
    while z < c5:
        z = z + 1
        x = x + 1
    x = x + 1
    return x
# Graph-mode test kernel: a while loop nested inside another, adding every
# intermediate value of the inner counter z to `out` (which starts at c4, zero).
# NOTE(review): the incoming `z` argument is never read. It is reset to
# c4 + c4 (zero) at the top of every outer iteration.
# With x=2, y=14: 12 outer iterations each add 1 + 2 + ... + 14 = 105,
# then the final x (14) is added: 12 * 105 + 14 -> 1274.
@ms_function
def while_in_while(x, y, z):
    out = c4
    while x < y:
        z = c4 + c4
        while z < y:
            z = z + 1
            # Inside the inner loop: accumulates every step, not just the last.
            out = out + z
        x = x + 1
    out = out + x
    return out
# Graph-mode test kernel: inside an outer while loop, two sequential inner
# while loops. Each counts a fresh zero up to c2 (14) and adds its final
# value to `out`.
# NOTE(review): the incoming `y` and `z` arguments are never read. Both are
# reset to c4 + c4 (zero) before each use.
# With x=2: 12 outer iterations * (14 + 14) + final x (14) -> 350.
@ms_function
def while_by_while_in_while(x, y, z):
    out = c4
    while x < c2:
        y = c4 + c4
        while y < c2:
            y = y + 1
        # After the loop, so only the final y (== c2) is accumulated.
        out = out + y
        z = c4 + c4
        while z < c2:
            z = z + 1
        out = out + z
        x = x + 1
    out = out + x
    return out
# Graph-mode test kernel: three levels of nested while loops, all bounded
# by c2 (14). The innermost loop's final z is added to `out` once per middle
# iteration. The middle loop's final y is added once per outer iteration.
# NOTE(review): the incoming `y` and `z` arguments are never read. Both are
# reset to c4 + c4 (zero) before use.
# With x=2: 12 outer iterations * (14 * 14 + 14) + final x (14) -> 2534.
@ms_function
def while_in_while_in_while(x, y, z):
    out = c4
    while x < c2:
        y = c4 + c4
        while y < c2:
            y = y + 1
            z = c4 + c4
            while z < c2:
                z = z + 1
            # Inside the middle loop, after the innermost loop finishes.
            out = out + z
        out = out + y
        x = x + 1
    out = out + x
    return out
  112. @pytest.mark.level0
  113. @pytest.mark.platform_x86_ascend_training
  114. @pytest.mark.platform_arm_ascend_training
  115. @pytest.mark.env_onecard
  116. def test_simple_if():
  117. output = simple_if(c1, c2)
  118. expect = Tensor([6], mstype.int32)
  119. assert output == expect
  120. def test_if_by_if():
  121. output = if_by_if(c1, c2)
  122. expect = Tensor([8], mstype.int32)
  123. assert output == expect
  124. @pytest.mark.level0
  125. @pytest.mark.platform_x86_ascend_training
  126. @pytest.mark.platform_arm_ascend_training
  127. @pytest.mark.env_onecard
  128. def test_if_in_if():
  129. output = if_in_if(c1, c2, c3)
  130. expect = Tensor([7], mstype.int32)
  131. assert output == expect
  132. @pytest.mark.level0
  133. @pytest.mark.platform_x86_ascend_training
  134. @pytest.mark.platform_arm_ascend_training
  135. @pytest.mark.env_onecard
  136. def test_simple_while():
  137. output = simple_while(c1, c2)
  138. expect = Tensor([21], mstype.int32)
  139. assert output == expect
  140. @pytest.mark.level0
  141. @pytest.mark.platform_x86_ascend_training
  142. @pytest.mark.platform_arm_ascend_training
  143. @pytest.mark.env_onecard
  144. def test_while_by_while():
  145. output = while_by_while(c1, c2, c3)
  146. expect = Tensor([28], mstype.int32)
  147. assert output == expect
  148. @pytest.mark.level0
  149. @pytest.mark.platform_x86_ascend_training
  150. @pytest.mark.platform_arm_ascend_training
  151. @pytest.mark.env_onecard
  152. def test_while_in_while():
  153. output = while_in_while(c1, c2, c3)
  154. expect = Tensor([1274], mstype.int32)
  155. assert output == expect
  156. @pytest.mark.level0
  157. @pytest.mark.platform_x86_ascend_training
  158. @pytest.mark.platform_arm_ascend_training
  159. @pytest.mark.env_onecard
  160. def test_while_by_while_in_while():
  161. output = while_by_while_in_while(c1, c2, c3)
  162. expect = Tensor([350], mstype.int32)
  163. assert output == expect
  164. @pytest.mark.level0
  165. @pytest.mark.platform_x86_ascend_training
  166. @pytest.mark.platform_arm_ascend_training
  167. @pytest.mark.env_onecard
  168. def test_while_in_while_in_while():
  169. output = while_in_while_in_while(c1, c2, c3)
  170. expect = Tensor([2534], mstype.int32)
  171. assert output == expect