You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_multigraph_sink.py 2.9 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
"""Control-flow tests for multi-graph sink (test_multigraph_sink).

Each ``@ms_function``-decorated function below builds one control-flow
pattern: a simple if, two ifs in sequence, an if nested in an if, a simple
while, two whiles in sequence, and a while nested in a while. Each matching
``test_*`` function calls it with the module constants (c1, c2, c3) and
compares the result with a hand-computed expected Tensor. ``setup_module``
selects PyNative mode on the Ascend device target.
"""
  16. import mindspore.context as context
  17. from mindspore.common import dtype as mstype
  18. from mindspore.common import ms_function
  19. from mindspore.common.tensor import Tensor
  20. def setup_module(module):
  21. context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target="Ascend")
  22. context.set_context(device_id=0)
  23. c1 = Tensor([2], mstype.int32)
  24. c2 = Tensor([14], mstype.int32)
  25. c3 = Tensor([1], mstype.int32)
  26. c4 = Tensor([0], mstype.int32)
  27. c5 = Tensor([14], mstype.int32)
# A single if/else followed by one statement that runs after either branch.
# z is unused; it only keeps the (x, y, z) signature shared by every graph here.
# This body is compiled by @ms_function and its control-flow shape is what the
# test checks, so keep the statement structure exactly as written.
# With the test inputs (c1, c2, c3) = (2, 14, 1) this returns 6.
@ms_function
def simple_if(x, y, z):
    if x < y:
        x = x + 1
    else:
        x = x + 2
    # Executed after whichever branch was taken.
    x = x + 3
    return x
# Two independent if-statements in sequence, neither with an else branch.
# The second condition reads the x produced by the first one.
# z is unused; it only keeps the (x, y, z) signature shared by every graph here.
# This body is compiled by @ms_function and its control-flow shape is what the
# test checks, so keep the statement structure exactly as written.
# With the test inputs (2, 14, 1) this returns 8.
@ms_function
def if_by_if(x, y, z):
    if x < y:
        x = x + 1
    if y > x:
        x = x + 2
    x = x + 3
    return x
# An if nested inside another if, accumulating into `out`.
# `out` starts from the module constant c4. The z parameter is overwritten
# (z = c4 + c4) before it is read, so the caller's z does not affect the result.
# This body is compiled by @ms_function and its control-flow shape is what the
# test checks, so keep the statement structure exactly as written.
# With the test inputs (2, 14, 1) this returns 7.
@ms_function
def if_in_if(x, y, z):
    out = c4
    if x < y:
        z = c4 + c4
        # Inner branch: only reached when the outer condition held.
        if z < y:
            z = z + 2
            out = out + z
        x = x + 3
    out = out + x
    return out
# A single while loop followed by one statement: y is raised by 4, x is
# incremented until it is no longer below y, then 3 is added.
# z is unused; it only keeps the (x, y, z) signature shared by every graph here.
# This body is compiled by @ms_function and its control-flow shape is what the
# test checks, so keep the statement structure exactly as written.
# With the test inputs (2, 14, 1) this returns 21.
@ms_function
def simple_while(x, y, z):
    y = y + 4
    while x < y:
        x = x + 1
    x = x + 3
    return x
# Two while loops in sequence. The first advances x up to y. The second counts
# z up to the module constant c5 and also advances x once per iteration.
# A final increment follows both loops.
# This body is compiled by @ms_function and its control-flow shape is what the
# test checks, so keep the statement structure exactly as written.
# With the test inputs (2, 14, 1) this returns 28.
@ms_function
def while_by_while(x, y, z):
    while x < y:
        x = x + 1
    while z < c5:
        z = z + 1
        x = x + 1
    x = x + 1
    return x
# A while loop nested inside another while loop, accumulating into `out`.
# On every outer iteration z is reset to c4 + c4 (overwriting the z parameter),
# and the inner loop adds each new z value to `out` while z is below y.
# The final x is added to `out` after the outer loop.
# This body is compiled by @ms_function and its control-flow shape is what the
# test checks, so keep the statement structure exactly as written.
# With the test inputs (2, 14, 1) this returns 1274.
@ms_function
def while_in_while(x, y, z):
    out = c4
    while x < y:
        z = c4 + c4
        while z < y:
            z = z + 1
            out = out + z
        x = x + 1
    out = out + x
    return out
  82. def test_simple_if():
  83. output = simple_if(c1, c2, c3)
  84. expect = Tensor([6], mstype.int32)
  85. assert output == expect
  86. def test_if_by_if():
  87. output = if_by_if(c1, c2, c3)
  88. expect = Tensor([8], mstype.int32)
  89. assert output == expect
  90. def test_if_in_if():
  91. output = if_in_if(c1, c2, c3)
  92. expect = Tensor([7], mstype.int32)
  93. assert output == expect
  94. def test_simple_while():
  95. output = simple_while(c1, c2, c3)
  96. expect = Tensor([21], mstype.int32)
  97. assert output == expect
  98. def test_while_by_while():
  99. output = while_by_while(c1, c2, c3)
  100. expect = Tensor([28], mstype.int32)
  101. assert output == expect
  102. def test_while_in_while():
  103. output = while_in_while(c1, c2, c3)
  104. expect = Tensor([1274], mstype.int32)
  105. assert output == expect