You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_multigraph_sink.py 3.3 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_multigraph_sink """
  16. import mindspore.context as context
  17. from mindspore.common import dtype as mstype
  18. from mindspore.common import ms_function
  19. from mindspore.common.tensor import Tensor
  20. def setup_module(module):
  21. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  22. c1 = Tensor([2], mstype.int32)
  23. c2 = Tensor([14], mstype.int32)
  24. c3 = Tensor([1], mstype.int32)
  25. c4 = Tensor([0], mstype.int32)
  26. c5 = Tensor([14], mstype.int32)
# Graph-compiled single if/else followed by a shared tail statement.
# Returns x + 1 + 3 when x < y, otherwise x + 2 + 3. `z` is accepted but unused.
# The control-flow shape below is what the graph compiler is being tested on;
# keep it as written.
@ms_function
def simple_if(x, y, z):
    if x < y:
        x = x + 1
    else:
        x = x + 2
    # Executed on both paths, after the branches merge.
    x = x + 3
    return x
# Graph-compiled pair of sequential (non-nested) ifs. The second condition
# reads the x already updated by the first if. `z` is accepted but unused.
# The control-flow shape below is what the graph compiler is being tested on;
# keep it as written.
@ms_function
def if_by_if(x, y, z):
    if x < y:
        x = x + 1
    # Compares against the possibly-incremented x from the first branch.
    if y > x:
        x = x + 2
    x = x + 3
    return x
# Graph-compiled if nested inside an if. Accumulates into `out` (seeded from
# the module constant c4 = 0). The incoming `z` is overwritten with c4 + c4
# before it is read, so its value does not affect the result.
# The control-flow shape below is what the graph compiler is being tested on;
# keep it as written.
@ms_function
def if_in_if(x, y, z):
    out = c4
    if x < y:
        # Reset z from the zero constant rather than using the argument.
        z = c4 + c4
        if z < y:
            z = z + 2
            out = out + z
        x = x + 3
    # Always runs, whether or not the outer branch was taken.
    out = out + x
    return out
# Graph-compiled single while loop. Raises the bound y by 4, counts x up to
# that bound, then adds 3. `z` is accepted but unused.
# The control-flow shape below is what the graph compiler is being tested on;
# keep it as written.
@ms_function
def simple_while(x, y, z):
    y = y + 4
    while x < y:
        x = x + 1
    # Post-loop statement; runs once after the loop exits.
    x = x + 3
    return x
# Graph-compiled pair of sequential (non-nested) while loops.
# The first loop counts x up to y. The second loop counts z up to the module
# constant c5 and adds 1 to x per iteration. A final +1 follows both loops.
# The control-flow shape below is what the graph compiler is being tested on;
# keep it as written.
@ms_function
def while_by_while(x, y, z):
    while x < y:
        x = x + 1
    # Second loop is bounded by the constant c5, not by y.
    while z < c5:
        z = z + 1
        x = x + 1
    x = x + 1
    return x
# Graph-compiled while loop nested inside a while loop.
# For every outer step (x -> x + 1 while x < y), the inner loop restarts z
# from c4 + c4 (0) and adds each z = 1 .. y to `out`. The incoming `z` is
# overwritten before use. The final x is added once the outer loop exits.
# The control-flow shape below is what the graph compiler is being tested on;
# keep it as written.
@ms_function
def while_in_while(x, y, z):
    out = c4
    while x < y:
        # Reset the inner counter from the zero constant on every outer step.
        z = c4 + c4
        while z < y:
            z = z + 1
            out = out + z
        x = x + 1
    out = out + x
    return out
  81. def test_simple_if():
  82. output = simple_if(c1, c2, c3)
  83. expect = Tensor([6], mstype.int32)
  84. assert output == expect
  85. def test_if_by_if():
  86. output = if_by_if(c1, c2, c3)
  87. expect = Tensor([8], mstype.int32)
  88. assert output == expect
  89. def test_if_in_if():
  90. output = if_in_if(c1, c2, c3)
  91. expect = Tensor([7], mstype.int32)
  92. assert output == expect
  93. def test_simple_while():
  94. output = simple_while(c1, c2, c3)
  95. expect = Tensor([21], mstype.int32)
  96. assert output == expect
  97. def test_while_by_while():
  98. output = while_by_while(c1, c2, c3)
  99. expect = Tensor([28], mstype.int32)
  100. assert output == expect
  101. def test_while_in_while():
  102. output = while_in_while(c1, c2, c3)
  103. expect = Tensor([1274], mstype.int32)
  104. assert output == expect
# Graph-compiled outer while loop containing two sequential inner while loops.
# Each outer step (x -> x + 1 while x < c2) runs two inner loops in turn.
# Each inner loop counts a counter (y, then z) from c4 + c4 (0) up to c2, and
# only the counter's final value is added to `out` after that loop. The
# incoming y and z are overwritten before use. The final x is added after the
# outer loop exits.
# The control-flow shape below is what the graph compiler is being tested on;
# keep it as written.
@ms_function
def while_by_while_in_while(x, y, z):
    out = c4
    while x < c2:
        y = c4 + c4
        while y < c2:
            y = y + 1
        # Added once per outer step, after the first inner loop finishes.
        out = out + y
        z = c4 + c4
        while z < c2:
            z = z + 1
        # Added once per outer step, after the second inner loop finishes.
        out = out + z
        x = x + 1
    out = out + x
    return out
  120. def test_while_by_while_in_while():
  121. output = while_by_while_in_while(c1, c2, c3)
  122. expect = Tensor([350], mstype.int32)
  123. assert output == expect