You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_graph_fallback.py 4.1 kB

(line-number gutter: lines 1–136)
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test graph fallback """
  16. import pytest
  17. import numpy as np
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, ms_function, context
  20. import mindspore.common.dtype as mstype
# Run every Cell and ms_function in this module in graph (compiled) mode;
# the graph fallback tests below depend on this mode being active.
context.set_context(mode=context.GRAPH_MODE)
class ControlNet(nn.Cell):
    """Cell that creates Tensors inside `construct` and branches on them.

    `construct` builds a = 4 and b = 5 as int32 Tensors, then returns
    a + b when a + b > x and a - b otherwise.
    """

    def inner_function_1(self, a, b):
        # Branch taken when a + b > x: returns the sum.
        return a + b

    def inner_function_2(self, a, b):
        # Branch taken otherwise: returns the difference.
        return a - b

    def construct(self, x):
        # x: value compared against a + b (= 9); the test in this file passes
        # a 0-d int32 Tensor. Returns 9 when 9 > x, else -1.
        # The Tensors are created inside construct on purpose: defining
        # Tensors here is the graph fallback behavior under test.
        a = Tensor(np.array(4), mstype.int32)
        b = Tensor(np.array(5), mstype.int32)
        if a + b > x:
            return self.inner_function_1(a, b)
        return self.inner_function_2(a, b)
  33. @pytest.mark.level0
  34. @pytest.mark.platform_x86_gpu_training
  35. @pytest.mark.platform_arm_ascend_training
  36. @pytest.mark.platform_x86_ascend_training
  37. @pytest.mark.env_onecard
  38. def test_fallback_control_sink_tensor():
  39. """
  40. Feature: Fallback feature: support define Tensor in Class construct.
  41. Description: Fallback feature: support define Tensor in Class construct.
  42. Expectation: Fallback feature: support define Tensor in Class construct.
  43. """
  44. x = Tensor(np.array(1), mstype.int32)
  45. net = ControlNet()
  46. output = net(x)
  47. output_expect = Tensor(9, mstype.int32)
  48. assert output == output_expect
  49. @pytest.mark.level0
  50. @pytest.mark.platform_x86_gpu_training
  51. @pytest.mark.platform_arm_ascend_training
  52. @pytest.mark.platform_x86_ascend_training
  53. @pytest.mark.env_onecard
  54. def test_np_tensor_list():
  55. """
  56. Feature: Fallback feature
  57. Description: support Basic method of Tensor list.
  58. Expectation: No exception.
  59. """
  60. @ms_function
  61. def np_tensor_list():
  62. a = Tensor(np.array(4), mstype.int32)
  63. b = Tensor(np.array(5), mstype.int32)
  64. c = Tensor(np.array(6), mstype.int32)
  65. tensor_list = [a, b]
  66. for tensor in tensor_list:
  67. print(tensor)
  68. tensor_list.append(tensor_list[-1] + c)
  69. return tensor_list
  70. tensor_list = np_tensor_list()
  71. print("tensor_list:", tensor_list)
  72. assert len(tensor_list) == 3
  73. @pytest.mark.level0
  74. @pytest.mark.platform_x86_gpu_training
  75. @pytest.mark.platform_arm_ascend_training
  76. @pytest.mark.platform_x86_ascend_training
  77. @pytest.mark.env_onecard
  78. def test_list_count():
  79. """
  80. Feature: Fallback feature
  81. Description: support attr/method of builtin type.
  82. Expectation: No exception.
  83. """
  84. @ms_function
  85. def list_count():
  86. x = list([1, 2, 3])
  87. res = x.count(1)
  88. return res
  89. assert list_count() == 1
  90. @pytest.mark.level0
  91. @pytest.mark.platform_x86_gpu_training
  92. @pytest.mark.platform_arm_ascend_training
  93. @pytest.mark.platform_x86_ascend_training
  94. @pytest.mark.env_onecard
  95. def test_list_append():
  96. """
  97. Feature: Fallback feature
  98. Description: support attr/method of builtin type.
  99. Expectation: No exception.
  100. """
  101. @ms_function
  102. def list_append():
  103. x = list([1, 2, 3])
  104. x.append(4)
  105. return Tensor(x)
  106. assert np.all(list_append().asnumpy() == np.array([1, 2, 3, 4]))
  107. @pytest.mark.skip(reason='Not support graph fallback feature yet')
  108. def test_list_insert():
  109. """
  110. Feature: Fallback feature
  111. Description: support attr/method of builtin type.
  112. Expectation: No exception.
  113. """
  114. @ms_function
  115. def list_insert():
  116. x = list([1, 3, 4])
  117. x.insert(1, 2)
  118. return Tensor(x)
  119. assert np.all(list_insert().asnumpy() == np.array([1, 2, 3, 4]))