You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_while_mindir.py 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. import pytest
  17. import mindspore.nn as nn
  18. from mindspore import context, ms_function
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.train.serialization import export, load
  class SingleWhileNet(nn.Cell):
      """Cell with a single `while` loop, used as a MindIR export/load fixture.

      ``construct`` adds 1 to ``x``, keeps incrementing ``x`` while ``x < y``,
      then adds ``2 * x`` to ``y`` and returns ``y``.
      """

      def construct(self, x, y):
          # NOTE(review): indentation was reconstructed from a flattened copy of
          # this file; confirm that `y += 2 * x` sits after the loop, not inside it.
          x += 1
          while x < y:
              x += 1
          y += 2 * x
          return y
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.env_onecard
  def test_single_while():
      """Export SingleWhileNet to MindIR, reload it, and compare the outputs.

      The exported ``while_net.mindir`` file is removed afterwards. Other tests
      write to the same path, so no artifact should be left in the working
      directory.
      """
      context.set_context(mode=context.GRAPH_MODE)
      network = SingleWhileNet()
      x = Tensor(np.array([1]).astype(np.float32))
      y = Tensor(np.array([2]).astype(np.float32))
      origin_out = network(x, y)
      file_name = "while_net"
      # The test expects export() to write `file_name` + ".mindir".
      mindir_name = file_name + ".mindir"
      try:
          export(network, x, y, file_name=file_name, file_format='MINDIR')
          assert os.path.exists(mindir_name)
          graph = load(mindir_name)
          loaded_net = nn.GraphCell(graph)
          outputs_after_load = loaded_net(x, y)
          # Compare values explicitly instead of relying on the truthiness of an
          # element-wise Tensor comparison.
          assert np.array_equal(origin_out.asnumpy(), outputs_after_load.asnumpy())
      finally:
          # Guarded so a failed export does not mask the real error with
          # FileNotFoundError.
          if os.path.exists(mindir_name):
              os.remove(mindir_name)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.env_onecard
  def test_ms_function_while():
      """Export SingleWhileNet, then run the loaded graph via ms_function in PyNative mode.

      The execution mode that was active before the test is restored afterwards,
      and the exported ``while_net.mindir`` file is removed. This keeps the
      PyNative switch and the file from leaking into other tests.
      """
      previous_mode = context.get_context("mode")
      context.set_context(mode=context.GRAPH_MODE)
      file_name = "while_net"
      # The test expects export() to write `file_name` + ".mindir".
      mindir_name = file_name + ".mindir"
      try:
          network = SingleWhileNet()
          x = Tensor(np.array([1]).astype(np.float32))
          y = Tensor(np.array([2]).astype(np.float32))
          origin_out = network(x, y)
          export(network, x, y, file_name=file_name, file_format='MINDIR')
          assert os.path.exists(mindir_name)
          graph = load(mindir_name)
          loaded_net = nn.GraphCell(graph)
          # Switch to PyNative mode; the loaded graph is then run through the
          # @ms_function-decorated wrapper below.
          context.set_context(mode=context.PYNATIVE_MODE)

          @ms_function
          def run_graph(inp_x, inp_y):
              outputs = loaded_net(inp_x, inp_y)
              return outputs

          outputs_after_load = run_graph(x, y)
          # Compare values explicitly instead of relying on the truthiness of an
          # element-wise Tensor comparison.
          assert np.array_equal(origin_out.asnumpy(), outputs_after_load.asnumpy())
      finally:
          # The mode set via set_context outlives this function; put it back.
          context.set_context(mode=previous_mode)
          if os.path.exists(mindir_name):
              os.remove(mindir_name)
  class SingleWhileInlineNet(nn.Cell):
      """Cell with a single `while` loop, used as a MindIR export/load fixture.

      ``construct`` adds 1 to ``x``, keeps incrementing ``x`` while ``x < y``,
      then adds ``x`` to ``y`` and returns ``y``.
      """

      def construct(self, x, y):
          # NOTE(review): indentation was reconstructed from a flattened copy of
          # this file; confirm that `y += x` sits after the loop, not inside it.
          x += 1
          while x < y:
              x += 1
          y += x
          return y
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.env_onecard
  def test_single_while_inline_export():
      """Check that SingleWhileInlineNet can be exported to a MindIR file.

      The exported ``while_inline_net.mindir`` file is removed afterwards so the
      test leaves no artifact in the working directory.
      """
      context.set_context(mode=context.GRAPH_MODE)
      network = SingleWhileInlineNet()
      x = Tensor(np.array([1]).astype(np.float32))
      y = Tensor(np.array([2]).astype(np.float32))
      file_name = "while_inline_net"
      # The test expects export() to write `file_name` + ".mindir".
      mindir_name = file_name + ".mindir"
      try:
          export(network, x, y, file_name=file_name, file_format='MINDIR')
          assert os.path.exists(mindir_name)
      finally:
          if os.path.exists(mindir_name):
              os.remove(mindir_name)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.env_onecard
  def test_single_while_inline_load():
      """Export SingleWhileInlineNet and check that the MindIR file loads back.

      The network is exported within this test, so it does not depend on
      test_single_while_inline_export having run. The exported file is removed
      afterwards.
      """
      context.set_context(mode=context.GRAPH_MODE)
      network = SingleWhileInlineNet()
      x = Tensor(np.array([1]).astype(np.float32))
      y = Tensor(np.array([2]).astype(np.float32))
      file_name = "while_inline_net"
      # The test expects export() to write `file_name` + ".mindir".
      mindir_name = file_name + ".mindir"
      try:
          export(network, x, y, file_name=file_name, file_format='MINDIR')
          assert os.path.exists(mindir_name)
          graph = load(mindir_name)
          # The original only checked that load() did not raise; also require
          # that it produced a graph object.
          assert graph is not None
      finally:
          if os.path.exists(mindir_name):
              os.remove(mindir_name)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.env_onecard
  def test_single_while_inline():
      """Export SingleWhileInlineNet to MindIR, reload it, and compare the outputs.

      The exported ``while_inline_net.mindir`` file is removed afterwards. Other
      tests write to the same path, so no artifact should be left behind.
      """
      context.set_context(mode=context.GRAPH_MODE)
      network = SingleWhileInlineNet()
      x = Tensor(np.array([1]).astype(np.float32))
      y = Tensor(np.array([2]).astype(np.float32))
      origin_out = network(x, y)
      file_name = "while_inline_net"
      # The test expects export() to write `file_name` + ".mindir".
      mindir_name = file_name + ".mindir"
      try:
          export(network, x, y, file_name=file_name, file_format='MINDIR')
          assert os.path.exists(mindir_name)
          graph = load(mindir_name)
          loaded_net = nn.GraphCell(graph)
          outputs_after_load = loaded_net(x, y)
          # Compare values explicitly instead of relying on the truthiness of an
          # element-wise Tensor comparison.
          assert np.array_equal(origin_out.asnumpy(), outputs_after_load.asnumpy())
      finally:
          # Guarded so a failed export does not mask the real error with
          # FileNotFoundError.
          if os.path.exists(mindir_name):
              os.remove(mindir_name)