You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cont_break.py 4.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_cont_break """
  16. import numpy as np
  17. import pytest
  18. from mindspore import Tensor, Model, context
  19. from mindspore.nn import Cell
  20. def run_test(netclass, count, dev):
  21. context.set_context(mode=context.GRAPH_MODE, device_target=dev)
  22. net = netclass()
  23. model = Model(net)
  24. for _ in range(count):
  25. input_np = np.random.randn(2, 3).astype(np.float32)
  26. input_ms = Tensor(input_np)
  27. output_np = net.construct(input_np) # run python
  28. output_ms = model.predict(input_ms) # run graph
  29. np.testing.assert_array_almost_equal(output_np, output_ms.asnumpy(), decimal=3)
class for_loop_with_break(Cell):
    """Cell exercising ``break`` inside a ``for`` loop over ``range``."""

    def __init__(self):
        super().__init__()

    def construct(self, x):
        # Iterations i = 0..5 double x. At i = 6 x is tripled and the loop
        # exits, so the result is x * 2**6 * 3 (= 192 * x).
        for i in range(8):
            if i > 5:
                # Augmented assignment. On the NumPy path x is already a fresh
                # array produced by an earlier ``x * 2``.
                x *= 3
                break
            x = x * 2
        return x
class for_loop_with_continue(Cell):
    """Cell exercising ``continue`` inside a ``for`` loop over ``range``."""

    def __init__(self):
        super().__init__()

    def construct(self, x):
        # Iterations i = 0..5 double x. Iterations i = 6 and 7 triple x and skip
        # the doubling via ``continue``, so the result is x * 2**6 * 3**2
        # (= 576 * x).
        for i in range(8):
            if i > 5:
                x *= 3
                continue
            x = x * 2
        return x
class for_loop_with_cont_break(Cell):
    """Cell mixing ``continue`` and ``break`` in the same ``for`` loop."""

    def __init__(self):
        super().__init__()

    def construct(self, x):
        # i = 0..2: skipped via ``continue`` (x unchanged).
        # i = 3..5: x doubled.
        # i = 6: x tripled, then ``break``.
        # Result: x * 2**3 * 3 (= 24 * x).
        for i in range(8):
            if i < 3:
                # Rebinding the loop variable does not affect iteration: the
                # ``for`` rebinds i from range(8) on the next pass.
                # NOTE(review): presumably kept on purpose to exercise
                # assignment to a loop variable in graph mode — confirm.
                i *= 2
                continue
            if i > 5:
                x *= 3
                break
            x = x * 2
        return x
class for_nested_loop_with_break(Cell):
    """Cell exercising ``break`` in the inner of two nested ``for`` loops."""

    def __init__(self):
        super().__init__()

    def construct(self, x):
        # Each outer pass: j = 0..3 multiply x by 1.5, and j = 4 doubles x and
        # breaks only the inner loop. That is a factor of 1.5**4 * 2 (= 10.125)
        # per outer pass, so the result is x * 10.125**3.
        for _ in range(3):
            for j in range(5):
                if j > 3:
                    x *= 2
                    break
                x = x * 1.5
        return x
class while_with_break(Cell):
    """Cell exercising ``break`` inside a counter-driven ``while`` loop."""

    def __init__(self):
        super().__init__()

    def construct(self, x):
        # i = 0..3 multiply x by 1.5. At i = 4 x is doubled and the loop exits
        # before the counter reaches the ``i < 5`` bound.
        # Result: x * 1.5**4 * 2 (= 10.125 * x).
        i = 0
        while i < 5:
            if i > 3:
                x *= 2
                break
            x = x * 1.5
            i += 1
        return x
class while_with_continue(Cell):
    """Cell exercising ``continue`` inside a counter-driven ``while`` loop."""

    def __init__(self):
        super().__init__()

    def construct(self, x):
        # i = 0..3 multiply x by 1.5. At i = 4 x is doubled and ``continue``
        # skips the 1.5 step. Both paths increment i, so the loop terminates
        # once i reaches 5. Result: x * 1.5**4 * 2 (= 10.125 * x).
        i = 0
        while i < 5:
            if i > 3:
                x *= 2
                # The counter must be advanced before ``continue``, or the
                # loop would never terminate.
                i += 1
                continue
            x = x * 1.5
            i += 1
        return x
class while_for_nested(Cell):
    """Cell nesting a ``for`` loop with ``break`` inside a ``while`` loop with ``continue``."""

    def __init__(self):
        super().__init__()

    def construct(self, x):
        # i = 0..3 multiply x by 1.5.
        # At i = 4 the inner loop doubles x for j = 0 and 1, then breaks at
        # j = 2. ``continue`` then skips the 1.5 step.
        # Result: x * 1.5**4 * 2**2 (= 20.25 * x).
        i = 0
        while i < 5:
            if i > 3:
                for j in range(3):
                    if j > 1:
                        # Leaves only the inner ``for``; the ``while`` carries on.
                        break
                    x *= 2
                i += 1
                continue
            x = x * 1.5
            i += 1
        return x
class pass_branch(Cell):
    """Cell whose ``while`` loop contains an ``if`` branch consisting only of ``pass``."""

    def __init__(self):
        super().__init__()

    def construct(self, x):
        # i = 0..3 take the ``else`` branch (x multiplied by 1.5); i = 4 takes the
        # empty ``pass`` branch. Result: x * 1.5**4 (= 5.0625 * x).
        i = 0
        while i < 5:
            if i > 3:
                pass
            else:
                x = x * 1.5
            i += 1
        return x
  127. @pytest.mark.level0
  128. @pytest.mark.platform_x86_cpu
  129. @pytest.mark.env_onecard
  130. def test_cont_break():
  131. count = 20
  132. dev = 'CPU'
  133. run_test(for_loop_with_break, count, dev)
  134. run_test(for_loop_with_continue, count, dev)
  135. run_test(for_loop_with_cont_break, count, dev)
  136. run_test(for_nested_loop_with_break, count, dev)
  137. run_test(while_with_break, count, dev)
  138. run_test(while_with_continue, count, dev)
  139. run_test(while_for_nested, count, dev)
  140. run_test(pass_branch, count, dev)