|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- import ast
- import inspect
-
- from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
- from mindspore.ops import Add
- from mindspore.rewrite import ScopedValue, ValueType, NodeType
- from mindspore.rewrite import Node as NodeApi
- from mindspore.rewrite.symbol_tree import SymbolTree
- from mindspore.rewrite.node import Node
-
-
# Fixture network for the SymbolTree tests: a conv -> bn -> relu1 -> relu2 -> relu3 chain.
# NOTE: keep docstrings (and any other extra statements) out of this class. create_symbol_tree()
# parses this class's source with ``ast`` and indexes into the result by position (e.g. the statements
# of ``construct``), and the tests assert the exact statement count of ``construct``.
# ``#`` comments are safe: ``ast`` drops them.
class Network(Cell):
    def __init__(self):
        super().__init__()
        # Sub-cells, referenced by attribute name from construct() and from create_symbol_tree().
        self.conv = Conv2d(16, 16, 3)
        self.bn = BatchNorm2d(16)
        self.relu1 = ReLU()
        self.relu2 = ReLU()
        self.relu3 = ReLU()

    def construct(self, x):
        # Exactly six statements: five sequential ``x = self.<cell>(x)`` calls, then the return.
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu1(x)
        x = self.relu2(x)
        x = self.relu3(x)
        return x
-
-
def create_symbol_tree():
    """
    Build by hand the SymbolTree for ``Network``.

    The source of ``Network`` is parsed with ``ast``. The class, its ``__init__`` and its ``construct``
    are registered on a new ``SymbolTree``. The following are then appended in source order:

    - an input node ``x``;
    - one call node per ``x = self.<cell>(x)`` statement of ``construct``;
    - the output node.

    Returns:
        tuple: ``(stree, bn_node, relu1_node, relu2_node)``. These are the tree and the three nodes the
        tests rewire, each as returned by ``SymbolTree.append_origin_field``.
    """
    net = Network()
    ast_module = ast.parse(inspect.getsource(type(net)))
    assert isinstance(ast_module, ast.Module)
    # Locate the class and its methods by type/name rather than by index. A docstring or an extra
    # method on Network therefore cannot shift what is picked up here.
    ast_class = next((stmt for stmt in ast_module.body if isinstance(stmt, ast.ClassDef)), None)
    assert isinstance(ast_class, ast.ClassDef)
    methods = {stmt.name: stmt for stmt in ast_class.body if isinstance(stmt, ast.FunctionDef)}
    ast_init_func = methods.get("__init__")
    assert isinstance(ast_init_func, ast.FunctionDef)
    ast_construct_func = methods.get("construct")
    assert isinstance(ast_construct_func, ast.FunctionDef)
    # construct() consists of five ``x = self.<name>(x)`` calls followed by ``return x``.
    call_names = ("conv", "bn", "relu1", "relu2", "relu3")
    assert len(ast_construct_func.body) == len(call_names) + 1
    ast_return = ast_construct_func.body[len(call_names)]

    stree = SymbolTree(net, ast_module)
    stree.set_class_ast(ast_class)
    stree.set_init_func_ast(ast_init_func)
    stree.set_ast_root(ast_construct_func)
    stree.append_input_node("x")
    appended = {}
    for name, ast_call in zip(call_names, ast_construct_func.body):
        # Targets and args each get a freshly created list per node, exactly as the previous
        # one-call-per-statement version did.
        node = Node.create_call_buildin_op(getattr(net, name), ast_call, [ScopedValue.create_naming_value("x")],
                                           ScopedValue.create_naming_value(name, "self"),
                                           [ScopedValue.create_naming_value("x")], {}, name)
        # Always keep the node handed back by the tree, for every call node alike.
        appended[name] = stree.append_origin_field(node)
    stree.append_origin_field(Node.create_output_node(ast_return, ["x"]))
    return stree, appended["bn"], appended["relu1"], appended["relu2"]
-
-
def test_insert_node():
    """
    Feature: Python api insert_node of SymbolTree of Rewrite.
    Description: Insert a new Add node in front of relu2 with insert_node and verify the tree bookkeeping.
    Expectation: Success.
    """
    stree, _, relu1, relu2 = create_symbol_tree()
    root_ast: ast.FunctionDef = getattr(stree, "_root_ast")
    topo_mgr = getattr(stree, "_topo_mgr")
    target_provider = getattr(topo_mgr, "_target_provider")
    target_consumer = getattr(topo_mgr, "_target_consumer")
    provider_count = len(target_provider)
    consumer_count = len(target_consumer)

    # Initial state: 7 nodes / 6 statements, and relu1's only output is relu2's only argument.
    assert len(stree.nodes()) == 7
    assert len(root_ast.body) == 6
    relu1_targets = relu1.get_targets()
    assert len(relu1_targets) == 1
    relu2_args = list(relu2.get_normalized_args().values())
    assert len(relu2_args) == 1
    assert relu1_targets[0] == relu2_args[0]

    # Build Add(x, 1). The 'new_conv' label is only a name; the inserted op is an Add.
    add_args = [ScopedValue.create_naming_value('x'), ScopedValue.create_variable_value(1)]
    added = Node.create_call_buildin_op(Add(), None, ['x'], 'new_conv', add_args, {}, 'new_conv')
    added = stree.insert_node(stree.before(relu2), added)

    # Node count grows by one.
    assert len(stree.nodes()) == 8

    # relu2 still consumes relu1's output; the new node takes a name and an int constant.
    relu2_args = list(relu2.get_normalized_args().values())
    assert len(relu2_args) == 1
    assert relu1.get_targets()[0] == relu2_args[0]
    added_args = list(added.get_normalized_args().values())
    assert len(added_args) == 2
    assert added_args[0] == ScopedValue.create_naming_value('x')
    assert added_args[1].type == ValueType.IntValue

    # Exactly one new provider entry, pointing at the new node's first output.
    assert len(target_provider) == provider_count + 1
    added_targets = added.get_targets()
    assert len(added_targets) == 1
    provider_entry = target_provider.get(added_targets[0])
    assert provider_entry[0] == added
    assert provider_entry[1] == 0

    # Exactly one new consumer entry, keyed by the constant argument.
    assert len(target_consumer) == consumer_count + 1
    assert target_consumer.get(added_args[1]) is not None

    # Graph inputs: relu2 unchanged, the new node is fed by the Input node.
    relu2_inputs = relu2.get_inputs()
    assert len(relu2_inputs) == 1
    assert relu2_inputs[0] == relu1
    added_inputs = added.get_inputs()
    assert len(added_inputs) == 1
    assert added_inputs[0].get_node_type() == NodeType.Input

    # AST: an assignment whose call has a Name and a Constant argument, inserted into construct.
    added_ast = added.get_ast()
    assert isinstance(added_ast, ast.Assign)
    call_args = added_ast.value.args
    assert isinstance(call_args, list)
    assert len(call_args) == 2
    assert isinstance(call_args[0], ast.Name)
    assert isinstance(call_args[1], ast.Constant)
    assert len(root_ast.body) == 7
-
-
def test_set_node_arg():
    """
    Feature: Python api set_node_arg of SymbolTree of Rewrite.
    Description: Point relu2's first argument at bn's output with set_node_arg and verify the rewired topology.
    Expectation: Success.
    """
    stree, bn, relu1, relu2 = create_symbol_tree()

    def expect_users(node, *expected):
        # Each entry from get_node_users has the using node as its first element.
        users = stree.get_node_users(node)
        assert len(users) == len(expected)
        for entry, want in zip(users, expected):
            assert entry[0] == want

    def expect_inputs(node, *expected):
        inputs = stree.get_node_inputs(node)
        assert len(inputs) == len(expected)
        for got, want in zip(inputs, expected):
            assert got == want

    def single_arg(node):
        values = list(node.get_normalized_args().values())
        assert len(values) == 1
        return values[0]

    assert len(stree.nodes()) == 7
    bn_targets = bn.get_targets()
    assert len(bn_targets) == 1
    bn_output = bn_targets[0]

    # Before: bn -> relu1 -> relu2.
    expect_users(bn, relu1)
    expect_inputs(relu1, bn)
    expect_users(relu1, relu2)
    expect_inputs(relu2, relu1)
    assert len(relu1.get_targets()) == 1
    assert relu1.get_targets()[0] == single_arg(relu2)

    stree.set_node_arg(relu2, 0, bn_output)

    # After: bn feeds both relu1 and relu2, and relu1 has no users left.
    expect_users(bn, relu1, relu2)
    expect_inputs(relu1, bn)
    expect_users(relu1)
    expect_inputs(relu2, bn)
    assert len(relu1.get_targets()) == 1
    assert bn_output == single_arg(relu2)

    # relu2's call statement now reads bn's output variable.
    relu2_ast = relu2.get_ast()
    assert isinstance(relu2_ast, ast.Assign)
    call_args = relu2_ast.value.args
    assert isinstance(call_args, list)
    assert len(call_args) == 1
    assert isinstance(call_args[0], ast.Name)
    assert call_args[0].id == bn_output.value
-
-
def test_set_node_arg_by_node():
    """
    Feature: Python api set_node_arg_by_node of SymbolTree of Rewrite.
    Description: Make bn the source of relu2's first argument with set_node_arg_by_node and verify the topology.
    Expectation: Success.
    """
    stree, bn, relu1, relu2 = create_symbol_tree()

    def expect_users(node, *expected):
        # Each entry from get_node_users has the using node as its first element.
        users = stree.get_node_users(node)
        assert len(users) == len(expected)
        for entry, want in zip(users, expected):
            assert entry[0] == want

    def expect_inputs(node, *expected):
        inputs = stree.get_node_inputs(node)
        assert len(inputs) == len(expected)
        for got, want in zip(inputs, expected):
            assert got == want

    def single_arg(node):
        values = list(node.get_normalized_args().values())
        assert len(values) == 1
        return values[0]

    assert len(stree.nodes()) == 7
    bn_targets = bn.get_targets()
    assert len(bn_targets) == 1
    bn_output = bn_targets[0]

    # Before: bn -> relu1 -> relu2.
    expect_users(bn, relu1)
    expect_inputs(relu1, bn)
    expect_users(relu1, relu2)
    expect_inputs(relu2, relu1)
    assert len(relu1.get_targets()) == 1
    assert relu1.get_targets()[0] == single_arg(relu2)

    stree.set_node_arg_by_node(relu2, 0, bn)

    # After: bn feeds both relu1 and relu2, and relu1 has no users left.
    expect_users(bn, relu1, relu2)
    expect_inputs(relu1, bn)
    expect_users(relu1)
    expect_inputs(relu2, bn)
    assert len(relu1.get_targets()) == 1
    assert bn_output == single_arg(relu2)

    # relu2's call statement now reads bn's output variable.
    relu2_ast = relu2.get_ast()
    assert isinstance(relu2_ast, ast.Assign)
    call_args = relu2_ast.value.args
    assert isinstance(call_args, list)
    assert len(call_args) == 1
    assert isinstance(call_args[0], ast.Name)
    assert call_args[0].id == bn_output.value
-
-
def test_erase_succeed():
    """
    Feature: Python api erase_node of SymbolTree of Rewrite.
    Description: Detach relu1 from its user, then remove it from the SymbolTree with erase_node.
    Expectation: Success.
    """
    stree, bn, relu1, relu2 = create_symbol_tree()
    root_ast: ast.FunctionDef = getattr(stree, "_root_ast")
    target_provider = getattr(getattr(stree, "_topo_mgr"), "_target_provider")
    provider_count = len(target_provider)
    assert len(stree.nodes()) == 7
    assert len(root_ast.body) == 6

    # Route relu2 around relu1 first: erasing a node that still has users is rejected
    # (see test_erase_failed).
    stree.set_node_arg_by_node(relu2, 0, bn)
    stree.erase_node(relu1)

    # One node, one provider entry and one construct statement fewer.
    assert len(stree.nodes()) == 6
    assert len(target_provider) == provider_count - 1
    assert len(root_ast.body) == 5
-
-
def test_erase_failed():
    """
    Feature: Python api erase_node of SymbolTree of Rewrite.
    Description: Try to erase relu1 while relu2 still consumes its output.
    Expectation: erase_node raises RuntimeError.
    """
    stree, _, relu1, _ = create_symbol_tree()
    try:
        stree.erase_node(relu1)
    except RuntimeError:
        rejected = True
    else:
        rejected = False
    assert rejected
-
-
def test_replace_one_to_one():
    """
    Feature: Python api replace of SymbolTree of Rewrite.
    Description: Replace relu1 with a single new Conv2d node by calling replace.
    Expectation: Success.
    """
    stree, bn, relu1, relu2 = create_symbol_tree()
    root_ast: ast.FunctionDef = getattr(stree, "_root_ast")
    assert len(root_ast.body) == 6
    assert len(stree.nodes()) == 7

    def expect_users(node, *expected):
        # Each entry from get_node_users has the using node as its first element.
        users = stree.get_node_users(node)
        assert len(users) == len(expected)
        for entry, want in zip(users, expected):
            assert entry[0] == want

    def expect_inputs(node, *expected):
        inputs = stree.get_node_inputs(node)
        assert len(inputs) == len(expected)
        for got, want in zip(inputs, expected):
            assert got == want

    def single_arg(node):
        values = list(node.get_normalized_args().values())
        assert len(values) == 1
        return values[0]

    # Replacement: a fresh Conv2d reading bn's output and producing "new_conv".
    conv_node = NodeApi.create_call_cell(Conv2d(16, 16, 5), [ScopedValue.create_naming_value("new_conv")],
                                         bn.get_targets()).get_handler()
    conv_node = stree.replace(relu1, [conv_node])
    assert len(stree.nodes()) == 7

    # The construct statement at relu1's old position now calls the new cell.
    assert len(root_ast.body) == 6
    replaced_stmt: ast.Assign = root_ast.body[2]
    callee: ast.Attribute = replaced_stmt.value.func
    assert callee.attr == conv_node.get_name()

    # Topology: bn -> conv_node -> relu2.
    expect_users(bn, conv_node)
    expect_inputs(conv_node, bn)
    expect_users(conv_node, relu2)
    expect_inputs(relu2, conv_node)

    # Data edges follow the same chain.
    assert len(bn.get_targets()) == 1
    assert bn.get_targets()[0] == single_arg(conv_node)
    assert len(conv_node.get_targets()) == 1
    assert conv_node.get_targets()[0] == single_arg(relu2)
-
-
def test_replace_one_to_multi():
    """
    Feature: Python api replace of SymbolTree of Rewrite.
    Description: Replace relu1 with a two-node tree (Conv2d followed by ReLU) by calling replace.
    Expectation: Success.
    """
    stree, bn, relu1, relu2 = create_symbol_tree()
    root_ast: ast.FunctionDef = getattr(stree, "_root_ast")
    assert len(root_ast.body) == 6
    assert len(stree.nodes()) == 7

    def expect_users(node, *expected):
        # Each entry from get_node_users has the using node as its first element.
        users = stree.get_node_users(node)
        assert len(users) == len(expected)
        for entry, want in zip(users, expected):
            assert entry[0] == want

    def expect_inputs(node, *expected):
        inputs = stree.get_node_inputs(node)
        assert len(inputs) == len(expected)
        for got, want in zip(inputs, expected):
            assert got == want

    def single_arg(node):
        values = list(node.get_normalized_args().values())
        assert len(values) == 1
        return values[0]

    def callee_name(stmt_index):
        stmt: ast.Assign = root_ast.body[stmt_index]
        func: ast.Attribute = stmt.value.func
        return func.attr

    # Replacement tree: Conv2d reads bn's output; ReLU reads the conv's output.
    conv_node = NodeApi.create_call_cell(Conv2d(16, 16, 5), [ScopedValue.create_naming_value("new_conv")],
                                         bn.get_targets()).get_handler()
    relu_node = NodeApi.create_call_cell(ReLU(), [ScopedValue.create_naming_value("new_relu")],
                                         conv_node.get_targets()).get_handler()
    # The tree's output node (the ReLU) is listed first. The node returned by replace() is used below as
    # the new ReLU, and the conv node is recovered through its inputs.
    relu_node = stree.replace(relu1, [relu_node, conv_node])
    conv_node = relu_node.get_inputs()[0]

    assert len(stree.nodes()) == 8

    # AST: the conv call, then the relu call, occupy relu1's old slot.
    assert len(root_ast.body) == 7
    assert callee_name(2) == conv_node.get_name()
    assert callee_name(3) == relu_node.get_name()

    # Topology: bn -> conv_node -> relu_node -> relu2.
    expect_users(bn, conv_node)
    expect_inputs(conv_node, bn)
    expect_users(conv_node, relu_node)
    expect_inputs(relu_node, conv_node)
    expect_users(relu_node, relu2)
    expect_inputs(relu2, relu_node)

    # Data edges follow the same chain.
    assert len(bn.get_targets()) == 1
    assert bn.get_targets()[0] == single_arg(conv_node)

    assert len(conv_node.get_targets()) == 1
    assert conv_node.get_targets()[0] == single_arg(relu_node)

    assert len(relu_node.get_targets()) == 1
    assert relu_node.get_targets()[0] == single_arg(relu2)
|