You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_parameter_ms_function.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import mindspore as ms
  17. from mindspore import context, Tensor, ms_function
  18. from mindspore.common.parameter import Parameter
  19. from mindspore.common import ParameterTuple
# Select MindSpore graph mode for the whole module. This runs once at import
# time, before any test below executes.
context.set_context(mode=context.GRAPH_MODE)
  21. @pytest.mark.level1
  22. @pytest.mark.platform_arm_ascend_training
  23. @pytest.mark.platform_x86_ascend_training
  24. @pytest.mark.env_onecard
  25. def test_parameter_ms_function_1():
  26. """
  27. Feature: Check the names of parameters.
  28. Description: Check the name of parameter in ms_function.
  29. Expectation: No exception.
  30. """
  31. param_a = Parameter(Tensor([1], ms.float32), name="name_a")
  32. param_b = Parameter(Tensor([2], ms.float32), name="name_a")
  33. @ms_function
  34. def test_parameter_ms_function():
  35. return param_a + param_b
  36. with pytest.raises(RuntimeError, match="its name 'name_a' already exists."):
  37. res = test_parameter_ms_function()
  38. assert res == 3
  39. @pytest.mark.level1
  40. @pytest.mark.platform_arm_ascend_training
  41. @pytest.mark.platform_x86_ascend_training
  42. @pytest.mark.env_onecard
  43. def test_parameter_ms_function_2():
  44. """
  45. Feature: Check the names of parameters.
  46. Description: Check the name of parameter in ms_function.
  47. Expectation: No exception.
  48. """
  49. param_a = Parameter(Tensor([1], ms.float32), name="name_a")
  50. param_b = param_a
  51. @ms_function
  52. def test_parameter_ms_function():
  53. return param_a + param_b
  54. res = test_parameter_ms_function()
  55. assert res == 2
  56. @pytest.mark.level1
  57. @pytest.mark.platform_arm_ascend_training
  58. @pytest.mark.platform_x86_ascend_training
  59. @pytest.mark.env_onecard
  60. def test_parameter_ms_function_3():
  61. """
  62. Feature: Check the names of parameters.
  63. Description: Check the name of parameter in ms_function.
  64. Expectation: No exception.
  65. """
  66. param_a = Parameter(Tensor([1], ms.float32))
  67. param_b = Parameter(Tensor([2], ms.float32))
  68. @ms_function
  69. def test_parameter_ms_function():
  70. return param_a + param_b
  71. with pytest.raises(RuntimeError, match="its name 'Parameter' already exists."):
  72. res = test_parameter_ms_function()
  73. assert res == 3
  74. @pytest.mark.level1
  75. @pytest.mark.platform_arm_ascend_training
  76. @pytest.mark.platform_x86_ascend_training
  77. @pytest.mark.env_onecard
  78. def test_parameter_ms_function_4():
  79. """
  80. Feature: Check the names of parameters.
  81. Description: Check the name of parameter in ms_function.
  82. Expectation: No exception.
  83. """
  84. with pytest.raises(ValueError, match="its name 'name_a' already exists."):
  85. param_a = ParameterTuple((Parameter(Tensor([1], ms.float32), name="name_a"),
  86. Parameter(Tensor([2], ms.float32), name="name_a")))
  87. @ms_function
  88. def test_parameter_ms_function():
  89. return param_a[0] + param_a[1]
  90. res = test_parameter_ms_function()
  91. assert res == 3
  92. @pytest.mark.level1
  93. @pytest.mark.platform_arm_ascend_training
  94. @pytest.mark.platform_x86_ascend_training
  95. @pytest.mark.env_onecard
  96. def test_parameter_ms_function_5():
  97. """
  98. Feature: Check the names of parameters.
  99. Description: Check the name of parameter in ms_function.
  100. Expectation: No exception.
  101. """
  102. with pytest.raises(ValueError, match="its name 'Parameter' already exists."):
  103. param_a = ParameterTuple((Parameter(Tensor([1], ms.float32)), Parameter(Tensor([2], ms.float32))))
  104. @ms_function
  105. def test_parameter_ms_function():
  106. return param_a[0] + param_a[1]
  107. res = test_parameter_ms_function()
  108. assert res == 3