You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_while_param.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_cont_break """
  16. import numpy as np
  17. import mindspore as ms
  18. from mindspore import Tensor, context, nn, ms_function
  19. from mindspore.nn import Cell
  20. from mindspore.ops import operations as P
class WhileSubGraphParam(Cell):
    """Cell whose `construct` reassigns a Parameter attribute inside a `while` loop.

    NOTE(review): the statement forms in `construct` are the input that the
    graph-mode test exercises; keep them as written.
    """

    def __init__(self):
        super().__init__()
        # Scalar float32 Parameter named "update", initial value 1.
        self.update = ms.Parameter(Tensor(1, ms.float32), "update")

    def construct(self, x, y, z):
        """Loop while `x < y`, adding 1 to `self.update`, `out1` and `x` each pass.

        Args:
            x: loop counter start.
            y: loop bound (exclusive).
            z: initial value of `out1`.

        Returns:
            Tuple `(out1, self.update)` evaluated after the loop.
        """
        out1 = z
        while x < y:
            # Reassigning the Parameter attribute itself, not a local copy.
            self.update = self.update + 1
            out1 = out1 + 1
            x = x + 1
        return out1, self.update
  32. def test_while_loop_phi():
  33. context.set_context(mode=context.GRAPH_MODE)
  34. x = Tensor(0, ms.float32)
  35. y = Tensor(10, ms.float32)
  36. z = Tensor(100, ms.float32)
  37. net = WhileSubGraphParam()
  38. net(x, y, z)
class WhileSubGraphParam2(Cell):
    """Cell that copies a Parameter into a local and only mutates the local in a loop.

    NOTE(review): the statement forms in `construct` are the input that the
    graph-mode test exercises; keep them as written.
    """

    def __init__(self):
        super().__init__()
        # Scalar float32 Parameter named "update", initial value 1.
        self.update = ms.Parameter(Tensor(1, ms.float32), "update")

    def construct(self, x, y, z):
        """Loop while `x < y`, adding 1 to the local `i`, `out1` and `x` each pass.

        `self.update` is never reassigned here; only the local `i` changes.

        Args:
            x: loop counter start.
            y: loop bound (exclusive).
            z: initial value of `out1`.

        Returns:
            Tuple `(out1, self.update)` evaluated after the loop.
        """
        out1 = z
        # Local loop variable seeded from the Parameter.
        i = self.update
        while x < y:
            i = i + 1
            out1 = out1 + 1
            x = x + 1
        return out1, self.update
  51. def test_while_loop_phi_2():
  52. context.set_context(mode=context.GRAPH_MODE)
  53. x = Tensor(0, ms.float32)
  54. y = Tensor(10, ms.float32)
  55. z = Tensor(100, ms.float32)
  56. net = WhileSubGraphParam2()
  57. net(x, y, z)
class WhileSubGraphParam3(Cell):
    """Cell with two Parameters built from the same input tensor, updated in a loop.

    NOTE(review): the statement forms in `construct` are the input that the
    graph-mode test exercises; keep them as written.
    """

    def __init__(self, initial_input_x):
        """Create Parameters `X` and `Y`, both from `initial_input_x`.

        Args:
            initial_input_x: tensor used as the initial value of both Parameters.
        """
        super().__init__()
        self.initial_input_x = initial_input_x
        # NOTE(review): X and Y are built from the same tensor object; whether
        # they share storage is not established here — confirm if it matters.
        self.X = ms.Parameter(initial_input_x, name="parameter_x")
        self.Y = ms.Parameter(self.initial_input_x, name="parameter_y")

    def construct(self):
        """Run `self.X = self.X + self.Y` three times and return `self.X`.

        The loop counter `a` is a plain Python int, not a tensor.
        """
        a = 0
        while a < 3:
            self.X = self.X + self.Y
            a += 1
        return self.X
  70. def test_while_loop_phi_3():
  71. context.set_context(mode=context.GRAPH_MODE)
  72. x = Tensor(0, ms.float32)
  73. net = WhileSubGraphParam3(x)
  74. net()
class ControlMixedWhileIf(nn.Cell):
    """Cell mixing nested `while` loops, `if`/`else` branches and `Assign` side effects.

    `construct` is decorated with `ms_function`. That is presumably so it is
    compiled as a graph even when the caller runs in PYNATIVE_MODE, as
    `test_mixed_while_if` does; confirm against the MindSpore version in use.
    NOTE(review): the statement forms below are the compiler input under
    test; keep them as written.
    """

    def __init__(self):
        super().__init__()
        # Assign primitive used in construct to write c4 into self.var.
        self.assign = P.Assign()
        # Shape-[1] float32 Parameter that the Assign calls overwrite.
        self.var = ms.Parameter(ms.Tensor([1], ms.float32), name="var")

    @ms_function
    def construct(self, x, y, z, c2, c4):
        """Accumulate loop counters into `out` across two top-level `while` loops.

        Args:
            x: outer loop counter. The first loop runs it up to `c2`; the
                second runs it up to `2 * c2`.
            y: overwritten with `self.assign(self.var, c4)` before use.
            z: overwritten with `self.assign(self.var, c4)` before use.
            c2: loop bound.
            c4: value written into `self.var` by every Assign call.

        Returns:
            The accumulated `out`.
        """
        # NOTE(review): `out`, `y` and `z` take the value returned by Assign.
        # This is presumably the assigned value c4; confirm for this MindSpore version.
        out = self.assign(self.var, c4)
        # Phase 1: advance x up to c2.
        while x < c2:
            y = self.assign(self.var, c4)
            # Step y by 2 while 2 * y < c2, then by 1, until y or x reaches c2.
            while y < c2 and x < c2:
                if 2 * y < c2:
                    y = y + 2
                else:
                    y = y + 1
            out = out + y
            z = self.assign(self.var, c4)
            while z < c2:
                z = z + 1
            out = out + z
            x = x + 1
        out = out + x
        # Phase 2: advance x up to 2 * c2. x is incremented before the inner
        # loops, so the `x < c2` check below sees the already-advanced x.
        while x < 2 * c2:
            y = self.assign(self.var, c4)
            x = x + 1
            while y < c2:
                z = self.assign(self.var, c4)
                while z < c2:
                    z = z + 1
                if x < c2:
                    y = y - 1
                else:
                    y = y + 1
                out = out + z
            out = out + y
        out = out + x
        return out
  112. def test_mixed_while_if():
  113. context.set_context(mode=context.PYNATIVE_MODE)
  114. x = np.array(2).astype(np.int32)
  115. y = np.array(14).astype(np.int32)
  116. z = np.array(1).astype(np.int32)
  117. c2 = Tensor([14], ms.int32)
  118. c4 = Tensor([0], ms.int32)
  119. net = ControlMixedWhileIf()
  120. output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4)
  121. expect = np.array(3318).astype(np.int32)
  122. assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
  123. context.set_context(mode=context.GRAPH_MODE)