You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_node.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import ast
  16. from mindspore.nn import Cell
  17. from mindspore.rewrite import ScopedValue
  18. from mindspore.rewrite.node import Node
  class FakeCell(Cell):
      """Stub cell whose construct takes two positional inputs and one keyword input."""

      def construct(self, input1, input2, cool_boy=None):
          # The tests in this file rely on these parameter names when checking
          # argument normalization; the body simply combines the inputs.
          partial = input1 + input2
          return partial + cool_boy
  class FakeCell2(Cell):
      """Stub cell: four positional params, *args, keyword-only `f` and **kwargs."""

      def construct(self, a, b, d, e, *args, f=6, **kwargs):
          # Parameter names/kinds are what the tests exercise; the body just
          # adds every input together in signature order.
          fixed = a + b + d + e
          return fixed + sum(args) + f + sum(kwargs.values())
  class FakeCell3(Cell):
      """Stub cell: two positional params, *args, keyword-only `f`/`h` and **kwargs."""

      def construct(self, a, b, *args, f=6, h=7, **kwargs):
          # Parameter names/kinds are what the tests exercise; the body just
          # adds every input together.
          head = a + b + f + h
          return head + sum(args) + sum(kwargs.values())
  def test_create_by_cell():
      """
      Feature: Python api create_call_buildin_op of Node of Rewrite.
      Description: Call create_call_buildin_op with a cell taking two positional inputs and one keyword input,
          check the normalized arguments and the generated ast, then replace a positional argument with
          set_arg and check that the normalized arguments, the ast and get_args/get_kwargs reflect it.
      Expectation: Success.
      """
      node = Node.create_call_buildin_op(
          FakeCell(), None, ['x'], 'new_conv',
          [ScopedValue.create_naming_value('x'), ScopedValue.create_variable_value(1)],
          {"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
      # Two positional arguments and one keyword argument were supplied.
      assert node._args_num == 2
      assert node._kwargs_num == 1
      # Normalized keys follow the parameter order of FakeCell.construct.
      assert node._normalized_args_keys == ["input1", "input2", "cool_boy"]
      assert node._normalized_args == {
          "input1": ScopedValue.create_naming_value('x'),
          "input2": ScopedValue.create_variable_value(1),
          "cool_boy": ScopedValue.create_naming_value('Naroto')
      }

      # The generated statement is an assignment whose value is the call.
      ast_node: ast.Assign = node.get_ast()
      assign_value: ast.Call = ast_node.value
      args_ast = assign_value.args
      keywords_ast = assign_value.keywords
      assert len(args_ast) == 2
      assert len(keywords_ast) == 1
      assert keywords_ast[0].arg == "cool_boy"
      # A naming value is emitted as ast.Name, a variable value as ast.Constant.
      assert isinstance(args_ast[0], ast.Name)
      assert args_ast[0].id == "x"
      assert isinstance(args_ast[1], ast.Constant)
      assert args_ast[1].value == 1
      keyword_value = keywords_ast[0].value
      assert isinstance(keyword_value, ast.Name)
      assert keyword_value.id == "Naroto"

      # Replace the second positional argument (index 1).
      node.set_arg(ScopedValue.create_variable_value(2), 1)
      assert isinstance(node.get_normalized_args().get("input2"), ScopedValue)
      assert node.get_normalized_args().get("input2").value == 2
      updated_call: ast.Call = node.get_ast().value
      updated_args_ast = updated_call.args
      # Check the node type before reading `.value` so a wrong node fails as an
      # assertion rather than an AttributeError; the argument count must not change.
      assert len(updated_args_ast) == 2
      assert isinstance(updated_args_ast[1], ast.Constant)
      assert updated_args_ast[1].value == 2
      assert node.get_args() == [ScopedValue.create_naming_value('x'), ScopedValue.create_variable_value(2)]
      assert node.get_kwargs() == {"cool_boy": ScopedValue.create_naming_value('Naroto')}
  def test_create_by_cell2():
      """
      Feature: Python api create_call_buildin_op of Node of Rewrite.
      Description: Call create_call_buildin_op with more positional arguments than named parameters plus an
          extra keyword, and check that the surplus positionals are normalized as args_<position>.
      Expectation: Success.
      """
      # Six distinct ScopedValue objects, all naming 'x'.
      positional = [ScopedValue.create_naming_value('x') for _ in range(6)]
      keywords = {"cool_boy": ScopedValue.create_naming_value('Naroto')}
      node = Node.create_call_buildin_op(FakeCell2(), None, ['x'], 'new_conv', positional, keywords, 'new_conv')

      # a, b, d, e bind by name; the two extra positionals (positions 4 and 5) land in *args.
      bound_names = ("a", "b", "d", "e", "args_4", "args_5")
      expected = {name: ScopedValue.create_naming_value('x') for name in bound_names}
      expected["cool_boy"] = ScopedValue.create_naming_value('Naroto')
      assert node.get_normalized_args() == expected
  def test_create_by_cell3():
      """
      Feature: Python api create_call_buildin_op of Node of Rewrite.
      Description: Call create_call_buildin_op on a cell with *args and keyword-only parameters, passing the
          keyword-only arguments out of signature order plus an extra keyword, and check the normalized
          arguments.
      Expectation: Success.
      """
      # `h` and `f` receive constants, so they are variable values. A naming value
      # stands for an identifier (it is emitted as ast.Name), which an int is not.
      node = Node.create_call_buildin_op(
          FakeCell3(), None, ['x'], 'new_conv',
          [ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"),
           ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x")],
          {"h": ScopedValue.create_variable_value(1),
           "f": ScopedValue.create_variable_value(2),
           "cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
      # a, b bind by name; the two extra positionals (positions 2 and 3) land in *args.
      assert node.get_normalized_args() == {
          "a": ScopedValue.create_naming_value('x'),
          "b": ScopedValue.create_naming_value('x'),
          "args_2": ScopedValue.create_naming_value('x'),
          "args_3": ScopedValue.create_naming_value('x'),
          "f": ScopedValue.create_variable_value(2),
          "h": ScopedValue.create_variable_value(1),
          "cool_boy": ScopedValue.create_naming_value('Naroto'),
      }