You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cont_break.py 4.2 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_cont_break """
  16. import pytest
  17. import numpy as np
  18. from mindspore.nn import Cell
  19. from mindspore import Tensor, Model, context
  20. def run_test(netclass, count, dev):
  21. context.set_context(mode=context.GRAPH_MODE, device_target=dev)
  22. net = netclass()
  23. model = Model(net)
  24. for _ in range(count):
  25. input_np = np.random.randn(2, 3).astype(np.float32)
  26. input_ms = Tensor(input_np)
  27. output_np = net.construct(input_np) # run python
  28. output_ms = model.predict(input_ms) # run graph
  29. np.testing.assert_array_almost_equal(output_np, output_ms.asnumpy(), decimal=3)
  30. class for_loop_with_break(Cell):
  31. def __init__(self):
  32. super().__init__()
  33. def construct(self, x):
  34. for i in range(8):
  35. if i > 5:
  36. x *= 3
  37. break
  38. x = x * 2
  39. pass
  40. return x
  41. class for_loop_with_continue(Cell):
  42. def __init__(self):
  43. super().__init__()
  44. def construct(self, x):
  45. for i in range(8):
  46. if i > 5:
  47. x *= 3
  48. continue
  49. x = x * 2
  50. return x
  51. class for_loop_with_cont_break(Cell):
  52. def __init__(self):
  53. super().__init__()
  54. def construct(self, x):
  55. for i in range(8):
  56. if i < 3:
  57. i *= 2
  58. continue
  59. if i > 5:
  60. x *= 3
  61. break
  62. x *= 2
  63. x = x * 2
  64. pass
  65. return x
  66. class for_nested_loop_with_break(Cell):
  67. def __init__(self):
  68. super().__init__()
  69. def construct(self, x):
  70. for i in range(3):
  71. for j in range(5):
  72. if j > 3:
  73. x *= 2
  74. break
  75. x = x * 1.5
  76. return x
  77. class while_with_break(Cell):
  78. def __init__(self):
  79. super().__init__()
  80. def construct(self, x):
  81. i = 0
  82. while i < 5:
  83. if i > 3:
  84. x *= 2
  85. break
  86. x = x * 1.5
  87. i += 1
  88. return x
  89. class while_with_continue(Cell):
  90. def __init__(self):
  91. super().__init__()
  92. def construct(self, x):
  93. i = 0
  94. while i < 5:
  95. if i > 3:
  96. x *= 2
  97. i += 1
  98. continue
  99. x = x * 1.5
  100. i += 1
  101. return x
  102. class while_for_nested(Cell):
  103. def __init__(self):
  104. super().__init__()
  105. def construct(self, x):
  106. i = 0
  107. while i < 5:
  108. if i > 3:
  109. for j in range(3):
  110. if j > 1:
  111. break
  112. x *= 2
  113. i += 1
  114. continue
  115. x = x * 1.5
  116. i += 1
  117. return x
  118. class pass_branch(Cell):
  119. def __init__(self):
  120. super().__init__()
  121. def construct(self, x):
  122. i = 0
  123. while i < 5:
  124. if i > 3:
  125. pass
  126. else:
  127. x = x * 1.5
  128. i += 1
  129. return x
  130. @pytest.mark.level0
  131. @pytest.mark.platform_x86_cpu
  132. @pytest.mark.env_onecard
  133. def test_cont_break():
  134. count = 20
  135. dev = 'CPU'
  136. run_test(for_loop_with_break, count, dev)
  137. run_test(for_loop_with_continue, count, dev)
  138. run_test(for_loop_with_cont_break, count, dev)
  139. run_test(for_nested_loop_with_break, count, dev)
  140. run_test(while_with_break, count, dev)
  141. run_test(while_with_continue, count, dev)
  142. run_test(while_for_nested, count, dev)
  143. run_test(pass_branch, count, dev)