|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- import ast
- from collections import OrderedDict
-
- from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
- from mindspore.ops import Add, AddN
- from mindspore.rewrite import ScopedValue, Node, SymbolTree
- from mindspore.rewrite import PatternEngine, PatternNode, Replacement, VarNode
-
-
def test_tree_pattern_match():
    """
    Feature: Python api PatternEngine.
    Description: Placeholder for building a tree-shaped PatternEngine, applying it to a SymbolTree and checking the
                 resulting SymbolTree. No such check is implemented yet, so the body is a trivially true assertion.
    Expectation: Success.
    """
    assert True
-
-
def test_leak_pattern_match():
    """
    Feature: Python api PatternEngine.
    Description: Placeholder for building a leaked tree PatternEngine, applying it to a SymbolTree and checking the
                 resulting SymbolTree. The scenario is not implemented yet, so the body is a trivially true assertion.
    Expectation: Failure.
    """
    assert True
-
-
# Sequential test network: conv -> bn -> relu1 -> relu2 -> relu3.
#
# NOTE: tests in this file build a SymbolTree from this class and assert that the body of its `_root_ast` holds
# exactly 6 statements, which matches the 6 statements of `construct` below. Documentation is therefore kept in `#`
# comments: a docstring is an extra statement in a function body, so adding one to `construct` would presumably
# break those counts -- confirm against SymbolTree before changing the shape of this class.
class ChainNetwork(Cell):
    def __init__(self):
        super().__init__()
        self.conv = Conv2d(16, 16, 3)
        self.bn = BatchNorm2d(16)
        self.relu1 = ReLU()
        self.relu2 = ReLU()
        self.relu3 = ReLU()

    def construct(self, x):
        # `x` is rebound on every line, so each layer consumes the output of the previous one.
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu1(x)
        x = self.relu2(x)
        x = self.relu3(x)
        return x
-
-
def test_one_to_one_pattern():
    """
    Feature: Python api PatternEngine.
    Description: Apply a PatternEngine that swaps the single BatchNorm2d node of ChainNetwork for one new Conv2d node,
                 then verify node lookup, topological links, source order and argument edges of the SymbolTree.
    Expectation: Success.
    """

    class BnReplacement(Replacement):
        # Produces one Conv2d call node targeting 'x1' that reuses the args and kwargs of the matched BatchNorm2d.
        def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
            assert is_chain_pattern
            assert pattern.type() == BatchNorm2d
            matched_bn: Node = matched.get(pattern.name())
            assert matched_bn is not None

            replacement = Node.create_call_cell(Conv2d(16, 16, 3), ['x1'],
                                                matched_bn.get_args(), matched_bn.get_kwargs())
            return [replacement]

    class BnReplace(PatternEngine):
        # Pattern consisting of a single BatchNorm2d, rewritten by BnReplacement.
        def __init__(self):
            super().__init__([BatchNorm2d], BnReplacement())

    stree = SymbolTree.create(ChainNetwork())
    before = [stree.get_node(name) for name in ("conv", "bn", "relu1")]
    construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
    for original_node in before:
        assert original_node is not None
    assert len(construct_ast.body) == 6
    assert len(stree.nodes()) == 7

    BnReplace().apply(stree)

    # One node removed and one node inserted: both counts stay the same.
    assert len(construct_ast.body) == 6
    assert len(stree.nodes()) == 7
    after = {name: stree.get_node(name) for name in ("conv", "bn", "relu1", "x1")}
    conv, relu1, new_conv = after["conv"], after["relu1"], after["x1"]
    assert conv is not None
    assert after["bn"] is None
    assert relu1 is not None
    assert new_conv is not None

    # Expected data flow and source order after the rewrite.
    chain = [conv, new_conv, relu1]

    # check topological order: every node but the last has exactly one user (its successor); every node but
    # the first has exactly one input (its predecessor)
    for pos, node in enumerate(chain[:-1]):
        if pos:
            producers = node.get_inputs()
            assert len(producers) == 1
            assert producers[0] == chain[pos - 1]
        consumers = node.get_users()
        assert len(consumers) == 1
        assert consumers[0] == chain[pos + 1]

    # check source code order, first walking forward through `_next`, then backward through `_prev`
    forward = [node.get_handler() for node in chain]
    for earlier, later in zip(forward, forward[1:]):
        assert getattr(earlier, "_next") == later
    backward = forward[::-1]
    for later, earlier in zip(backward, backward[1:]):
        assert getattr(later, "_prev") == earlier

    # check arg edge: the single target of each node is the single argument of its successor
    for producer, consumer in zip(chain, chain[1:]):
        produced = producer.get_targets()
        consumed = consumer.get_args()
        assert len(produced) == 1
        assert len(consumed) == 1
        assert produced[0] == consumed[0]
-
-
def test_one_to_multi_chain_pattern():
    """
    Feature: Python api PatternEngine.
    Description: Apply a PatternEngine that swaps the single BatchNorm2d node of ChainNetwork for two chained Conv2d
                 nodes, then verify node lookup, topological links, source order and argument edges of the SymbolTree.
    Expectation: Success.
    """

    class BnReplacement(Replacement):
        # Produces two chained Conv2d call nodes: 'x1' takes over the matched node's args, 'x2' consumes 'x1'.
        def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
            assert is_chain_pattern
            assert pattern.type() == BatchNorm2d
            matched_bn: Node = matched.get(pattern.name())
            assert matched_bn is not None

            # A Replacement must keep every target unique within its result, and must wire args and kwargs
            # so that they reflect the intended topological relation between the returned nodes.
            first = Node.create_call_cell(Conv2d(16, 16, 3), ['x1'], matched_bn.get_args(), matched_bn.get_kwargs())
            second = Node.create_call_cell(Conv2d(16, 16, 5), ['x2'], [ScopedValue.create_naming_value('x1')])
            return [first, second]

    class BnReplace(PatternEngine):
        # Pattern consisting of a single BatchNorm2d, rewritten by BnReplacement.
        def __init__(self):
            super().__init__([BatchNorm2d], BnReplacement())

    stree = SymbolTree.create(ChainNetwork())
    before = [stree.get_node(name) for name in ("conv", "bn", "relu1")]
    construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
    for original_node in before:
        assert original_node is not None
    assert len(construct_ast.body) == 6
    assert len(stree.nodes()) == 7

    BnReplace().apply(stree)

    # One node removed and two nodes inserted: both counts grow by one.
    assert len(construct_ast.body) == 7
    assert len(stree.nodes()) == 8
    after = {name: stree.get_node(name) for name in ("conv", "bn", "relu1", "x1", "x2")}
    conv, relu1 = after["conv"], after["relu1"]
    new_conv1, new_conv2 = after["x1"], after["x2"]
    assert conv is not None
    assert after["bn"] is None
    assert relu1 is not None
    assert new_conv1 is not None
    assert new_conv2 is not None

    # Expected data flow and source order after the rewrite.
    chain = [conv, new_conv1, new_conv2, relu1]

    # check topological order: every node but the last has exactly one user (its successor); every node but
    # the first has exactly one input (its predecessor)
    for pos, node in enumerate(chain[:-1]):
        if pos:
            producers = node.get_inputs()
            assert len(producers) == 1
            assert producers[0] == chain[pos - 1]
        consumers = node.get_users()
        assert len(consumers) == 1
        assert consumers[0] == chain[pos + 1]

    # check source code order, first walking forward through `_next`, then backward through `_prev`
    forward = [node.get_handler() for node in chain]
    for earlier, later in zip(forward, forward[1:]):
        assert getattr(earlier, "_next") == later
    backward = forward[::-1]
    for later, earlier in zip(backward, backward[1:]):
        assert getattr(later, "_prev") == earlier

    # check arg edge: the single target of each node is the single argument of its successor
    for producer, consumer in zip(chain, chain[1:]):
        produced = producer.get_targets()
        consumed = consumer.get_args()
        assert len(produced) == 1
        assert len(consumed) == 1
        assert produced[0] == consumed[0]
-
-
# Branching test network: two convolutions over the same input are added and passed through `relu`; the result
# is then fed to `relu1` and `relu2`, whose outputs are added again with the same `self.add` operator.
#
# NOTE: a test in this file builds a SymbolTree from this class and asserts that the body of its `_root_ast` holds
# exactly 8 statements, which matches the 8 statements of `construct` below. Documentation is therefore kept in `#`
# comments: a docstring is an extra statement in a function body, so adding one to `construct` would presumably
# break that count -- confirm against SymbolTree before changing the shape of this class.
class TreeNetwork(Cell):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(16, 16, 3)
        self.conv2 = Conv2d(16, 16, 5)
        self.add = Add()
        self.relu = ReLU()
        self.relu1 = ReLU()
        self.relu2 = ReLU()

    def construct(self, x):
        # First diamond: both convolutions read the original input `x`.
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x = self.add(x1, x2)
        x = self.relu(x)
        # Second diamond: `x1` and `x2` are rebound to the two relu branches of the new `x`.
        x1 = self.relu1(x)
        x2 = self.relu2(x)
        x = self.add(x1, x2)
        return x
-
-
def test_tree_pattern():
    """
    Feature: Python api PatternEngine.
    Description: Apply an [Add, ReLU] PatternEngine to TreeNetwork whose replacement emits four nodes (an add, two
                 relus fed by it, and an add joining them), then verify node lookup, topological links, source order
                 and argument edges of the SymbolTree.
    Expectation: Success.
    """

    class AddReluReplacement(Replacement):
        # Expands a matched Add -> ReLU pair into Add -> (ReLU, ReLU) -> Add.
        def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
            assert is_chain_pattern
            assert pattern.type() == ReLU
            matched_relu: Node = matched.get(pattern.name())
            assert matched_relu is not None
            relu_pattern_inputs = pattern.get_inputs()
            assert len(relu_pattern_inputs) == 1
            add_pattern = relu_pattern_inputs[0]
            assert add_pattern.type() == Add
            matched_add: Node = matched.get(add_pattern.name())
            assert matched_add is not None
            assert not add_pattern.get_inputs()

            def by_name(target):
                # Argument that refers to the value produced under `target`.
                return ScopedValue.create_naming_value(target)

            # The matched Add node itself is not returned; only its args and kwargs are reused by a new node.
            head_add = Node.create_call_cell(Add(), ['new_add_1'], matched_add.get_args(), matched_add.get_kwargs())
            left_relu = Node.create_call_cell(ReLU(), ['new_relu_1'], [by_name('new_add_1')])
            right_relu = Node.create_call_cell(ReLU(), ['new_relu_2'], [by_name('new_add_1')])
            tail_add = Node.create_call_cell(Add(), ['new_add_2'], [by_name('new_relu_1'), by_name('new_relu_2')])
            return [head_add, left_relu, right_relu, tail_add]

    class AddReluPattern(PatternEngine):
        # Pattern of an Add followed by a ReLU, rewritten by AddReluReplacement.
        def __init__(self):
            super().__init__([Add, ReLU], AddReluReplacement())

    stree = SymbolTree.create(TreeNetwork())
    for name in ("conv1", "conv2", "add", "relu", "relu1", "relu2"):
        assert stree.get_node(name) is not None
    construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
    assert len(construct_ast.body) == 8
    assert len(stree.nodes()) == 9

    AddReluPattern().apply(stree)

    # Two nodes removed and four nodes inserted: both counts grow by two.
    assert len(construct_ast.body) == 10
    assert len(stree.nodes()) == 11
    # NOTE(review): the lookup names below differ from the targets handed to create_call_cell ('new_add_1',
    # 'new_relu_1', 'new_relu_2', 'new_add_2'); presumably the SymbolTree renames nodes on insertion -- confirm.
    after = {name: stree.get_node(name) for name in (
        "conv1", "conv2", "add", "relu", "relu1", "relu2", "new_add", "new_relu", "new_relu_1", "new_add_1")}
    for name, node in after.items():
        if name in ("add", "relu"):
            assert node is None
        else:
            assert node is not None
    conv1, conv2 = after["conv1"], after["conv2"]
    relu1, relu2 = after["relu1"], after["relu2"]
    new_add, new_relu = after["new_add"], after["new_relu"]
    new_relu_1, new_add_1 = after["new_relu_1"], after["new_add_1"]

    # check topological order: (node, expected inputs or None to skip, expected users)
    topology = (
        (conv1, None, [new_add]),
        (conv2, None, [new_add]),
        (new_add, [conv1, conv2], [new_relu, new_relu_1]),
        (new_relu, [new_add], [new_add_1]),
        (new_relu_1, [new_add], [new_add_1]),
        (new_add_1, [new_relu_1, new_relu], [relu1, relu2]),
    )
    for node, want_inputs, want_users in topology:
        if want_inputs is not None:
            got_inputs = node.get_inputs()
            assert len(got_inputs) == len(want_inputs)
            for got, want in zip(got_inputs, want_inputs):
                assert got == want
        got_users = node.get_users()
        assert len(got_users) == len(want_users)
        for got, want in zip(got_users, want_users):
            assert got == want

    # check source code order, first walking forward through `_next`, then backward through `_prev`
    source_order = [conv1, conv2, new_add, new_relu, new_relu_1, new_add_1, relu1]
    forward = [node.get_handler() for node in source_order]
    for earlier, later in zip(forward, forward[1:]):
        assert getattr(earlier, "_next") == later
    backward = forward[::-1]
    for later, earlier in zip(backward, backward[1:]):
        assert getattr(later, "_prev") == earlier

    def feeds(producer, consumer, arg_idx):
        # Whether the first target of `producer` equals the `arg_idx`-th argument of `consumer`.
        return producer.get_targets()[0] == consumer.get_args()[arg_idx]

    # check arg edge: conv1, conv2 -> new_add
    for values, size in ((conv1.get_targets(), 1), (conv2.get_targets(), 1), (new_add.get_args(), 2)):
        assert len(values) == size
    assert feeds(conv1, new_add, 0)
    assert feeds(conv2, new_add, 1)

    # check arg edge: new_add -> new_relu, new_relu_1
    for values, size in ((new_add.get_targets(), 1), (new_relu.get_args(), 1), (new_relu_1.get_args(), 1)):
        assert len(values) == size
    assert feeds(new_add, new_relu, 0)
    assert feeds(new_add, new_relu_1, 0)

    # check arg edge: new_relu, new_relu_1 -> new_add_1
    for values, size in ((new_relu.get_targets(), 1), (new_relu_1.get_targets(), 1), (new_add_1.get_args(), 2)):
        assert len(values) == size
    assert feeds(new_relu, new_add_1, 1)
    assert feeds(new_relu_1, new_add_1, 0)

    # check arg edge: new_add_1 -> relu1, relu2
    for values, size in ((new_add_1.get_targets(), 1), (relu1.get_args(), 1), (relu2.get_args(), 1)):
        assert len(values) == size
    assert feeds(new_add_1, relu1, 0)
    assert feeds(new_add_1, relu2, 0)
-
-
# Three-input test network: `x` and `y` each go through a convolution, and `z` is combined with both convolution
# outputs by two consecutive AddN operators before a final relu.
#
# NOTE: a test in this file builds a SymbolTree from this class and asserts that the body of its `_root_ast` holds
# exactly 6 statements, which matches the 6 statements of `construct` below. Documentation is therefore kept in `#`
# comments: a docstring is an extra statement in a function body, so adding one to `construct` would presumably
# break that count -- confirm against SymbolTree before changing the shape of this class.
class TreeNetwork2(Cell):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(16, 16, 1)
        self.conv2 = Conv2d(16, 16, 3)
        self.add1 = AddN()
        self.add2 = AddN()
        self.relu = ReLU()

    def construct(self, x, y, z):
        x = self.conv1(x)
        y = self.conv2(y)
        # Both AddN calls take the same conv outputs; the second one also consumes the result of the first.
        z = self.add1(x, y, z)
        z = self.add2(x, y, z)
        z = self.relu(z)
        return z
-
-
class MultiInputPattern(PatternEngine):
    # Pattern rooted at an AddN whose inputs are two Conv2d nodes and a second AddN; that second AddN takes the
    # same two Conv2d nodes plus a free input, and each Conv2d takes a free input. Matches are rewritten by
    # MultiInputReplacement into two chained Add nodes.
    class MultiInputReplacement(Replacement):
        # Builds Add('new_add1') over the two conv inputs and Add('new_add2') over 'new_add1' and the third
        # argument of the inner AddN.
        def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
            assert not is_chain_pattern
            assert pattern.type() == AddN
            root_addn: Node = matched.get(pattern.name())
            assert root_addn is not None
            pattern_inputs = pattern.get_inputs()
            assert len(pattern_inputs) == 3
            conv1_pn, conv2_pn, addn1_pn = pattern_inputs[0], pattern_inputs[1], pattern_inputs[2]
            for sub_pattern, expected_type in ((conv1_pn, Conv2d), (conv2_pn, Conv2d), (addn1_pn, AddN)):
                assert sub_pattern.type() == expected_type
            conv1_node: Node = matched.get(conv1_pn.name())
            conv2_node: Node = matched.get(conv2_pn.name())
            addn1_node: Node = matched.get(addn1_pn.name())
            for matched_node in (conv1_node, conv2_node, addn1_node):
                assert matched_node is not None
            for matched_node, fan_in in ((conv1_node, 1), (conv2_node, 1), (addn1_node, 3)):
                assert len(matched_node.get_inputs()) == fan_in
            first_arg = conv1_node.get_args()[0]
            second_arg = conv2_node.get_args()[0]
            third_arg = addn1_node.get_args()[2]

            # None of the matched nodes is returned; the new nodes only reuse their arguments.
            inner_add = Node.create_call_cell(Add(), ['new_add1'], [first_arg, second_arg])
            outer_add = Node.create_call_cell(Add(), ['new_add2'],
                                              [ScopedValue.create_naming_value('new_add1'), third_arg])
            return [inner_add, outer_add]

    def __init__(self):
        first_conv = PatternNode("conv1", Conv2d)
        second_conv = PatternNode("conv2", Conv2d)
        inner_addn = PatternNode("addn1", AddN)
        root_addn = PatternNode("addn2", AddN)
        for conv_pattern in (first_conv, second_conv):
            conv_pattern.set_inputs([VarNode()])
        inner_addn.set_inputs([first_conv, second_conv, VarNode()])
        root_addn.set_inputs([first_conv, second_conv, inner_addn])
        super().__init__(root_addn, MultiInputPattern.MultiInputReplacement())
-
-
def test_multi_input_to_multi_pattern_tree_pattern():
    """
    Feature: Python api PatternEngine.
    Description: Apply MultiInputPattern to TreeNetwork2, whose matched subgraph is fed by three distinct network
                 inputs, then verify node lookup, topological links, source order and argument edges of the
                 SymbolTree.
    Expectation: Success.
    """

    stree = SymbolTree.create(TreeNetwork2())
    for name in ("conv1", "conv2", "add1", "add2", "relu"):
        assert stree.get_node(name) is not None
    construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
    assert len(construct_ast.body) == 6
    assert len(stree.nodes()) == 9

    MultiInputPattern().apply(stree)

    # Four nodes removed and two nodes inserted: both counts shrink by two.
    assert len(construct_ast.body) == 4
    assert len(stree.nodes()) == 7
    removed_names = ("conv1", "conv2", "add1", "add2")
    kept_names = ("relu", "new_add1", "new_add2", "input_x", "input_y", "input_z")
    after = {name: stree.get_node(name) for name in removed_names + kept_names}
    for name in removed_names:
        assert after[name] is None
    for name in kept_names:
        assert after[name] is not None
    relu, new_add1, new_add2 = after["relu"], after["new_add1"], after["new_add2"]
    inputx, inputy, inputz = after["input_x"], after["input_y"], after["input_z"]

    # check topological order: (node, expected inputs or None to skip, expected users or None to skip)
    topology = (
        (inputx, None, [new_add1]),
        (inputy, None, [new_add1]),
        (inputz, None, [new_add2]),
        (new_add1, [inputx, inputy], [new_add2]),
        (new_add2, [new_add1, inputz], [relu]),
        (relu, [new_add2], None),
    )
    for node, want_inputs, want_users in topology:
        if want_inputs is not None:
            got_inputs = node.get_inputs()
            assert len(got_inputs) == len(want_inputs)
            for got, want in zip(got_inputs, want_inputs):
                assert got == want
        if want_users is not None:
            got_users = node.get_users()
            assert len(got_users) == len(want_users)
            for got, want in zip(got_users, want_users):
                assert got == want

    # check source code order, first walking forward through `_next`, then backward through `_prev`
    source_order = [inputz, new_add1, new_add2, relu]
    forward = [node.get_handler() for node in source_order]
    for earlier, later in zip(forward, forward[1:]):
        assert getattr(earlier, "_next") == later
    backward = forward[::-1]
    for later, earlier in zip(backward, backward[1:]):
        assert getattr(later, "_prev") == earlier

    def feeds(producer, consumer, arg_idx):
        # Whether the first target of `producer` equals the `arg_idx`-th argument of `consumer`.
        return producer.get_targets()[0] == consumer.get_args()[arg_idx]

    # check arg edge: input_x, input_y -> new_add1
    for values, size in ((inputx.get_targets(), 1), (inputy.get_targets(), 1), (new_add1.get_args(), 2)):
        assert len(values) == size
    assert feeds(inputx, new_add1, 0)
    assert feeds(inputy, new_add1, 1)

    # check arg edge: new_add1, input_z -> new_add2
    for values, size in ((inputz.get_targets(), 1), (new_add1.get_targets(), 1), (new_add2.get_args(), 2)):
        assert len(values) == size
    assert feeds(new_add1, new_add2, 0)
    assert feeds(inputz, new_add2, 1)

    # check arg edge: new_add2 -> relu
    for values, size in ((new_add2.get_targets(), 1), (relu.get_args(), 1)):
        assert len(values) == size
    assert feeds(new_add2, relu, 0)
-
-
# Single-input test network: the same input `x` feeds both convolutions and is also the third operand of the
# first AddN; a second AddN combines the two convolution outputs with the first AddN's result before a final relu.
#
# NOTE: a test in this file builds a SymbolTree from this class and asserts that the body of its `_root_ast` holds
# exactly 6 statements, which matches the 6 statements of `construct` below. Documentation is therefore kept in `#`
# comments: a docstring is an extra statement in a function body, so adding one to `construct` would presumably
# break that count -- confirm against SymbolTree before changing the shape of this class.
class TreeNetwork3(Cell):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(16, 16, 1)
        self.conv2 = Conv2d(16, 16, 3)
        self.add1 = AddN()
        self.add2 = AddN()
        self.relu = ReLU()

    def construct(self, x):
        y = self.conv1(x)
        z = self.conv2(x)
        # `x` is rebound here: from this line on it holds the AddN result rather than the network input.
        x = self.add1(y, z, x)
        x = self.add2(y, z, x)
        x = self.relu(x)
        return x
-
-
def test_one_input_to_multi_pattern_tree_pattern():
    """
    Feature: Python api PatternEngine.
    Description: Apply MultiInputPattern to TreeNetwork3, whose matched subgraph is fed by one single network input,
                 then verify node lookup, topological links, source order and argument edges of the SymbolTree.
    Expectation: Success.
    """

    stree = SymbolTree.create(TreeNetwork3())
    for name in ("conv1", "conv2", "add1", "add2", "relu"):
        assert stree.get_node(name) is not None
    construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
    assert len(construct_ast.body) == 6
    assert len(stree.nodes()) == 7

    MultiInputPattern().apply(stree)

    # Four nodes removed and two nodes inserted: both counts shrink by two.
    assert len(construct_ast.body) == 4
    assert len(stree.nodes()) == 5
    removed_names = ("conv1", "conv2", "add1", "add2")
    kept_names = ("relu", "new_add1", "new_add2", "input_x")
    after = {name: stree.get_node(name) for name in removed_names + kept_names}
    for name in removed_names:
        assert after[name] is None
    for name in kept_names:
        assert after[name] is not None
    relu, new_add1, new_add2, inputx = after["relu"], after["new_add1"], after["new_add2"], after["input_x"]

    # check topological order: (node, expected inputs or None to skip, expected users or None to skip)
    topology = (
        (inputx, None, [new_add1, new_add2]),
        (new_add1, [inputx, inputx], [new_add2]),
        (new_add2, [new_add1, inputx], [relu]),
        (relu, [new_add2], None),
    )
    for node, want_inputs, want_users in topology:
        if want_inputs is not None:
            got_inputs = node.get_inputs()
            assert len(got_inputs) == len(want_inputs)
            for got, want in zip(got_inputs, want_inputs):
                assert got == want
        if want_users is not None:
            got_users = node.get_users()
            assert len(got_users) == len(want_users)
            for got, want in zip(got_users, want_users):
                assert got == want

    # check source code order, first walking forward through `_next`, then backward through `_prev`
    source_order = [inputx, new_add1, new_add2, relu]
    forward = [node.get_handler() for node in source_order]
    for earlier, later in zip(forward, forward[1:]):
        assert getattr(earlier, "_next") == later
    backward = forward[::-1]
    for later, earlier in zip(backward, backward[1:]):
        assert getattr(later, "_prev") == earlier

    def feeds(producer, consumer, arg_idx):
        # Whether the first target of `producer` equals the `arg_idx`-th argument of `consumer`.
        return producer.get_targets()[0] == consumer.get_args()[arg_idx]

    # check arg edge: input_x supplies both arguments of new_add1
    for values, size in ((inputx.get_targets(), 1), (new_add1.get_args(), 2)):
        assert len(values) == size
    assert feeds(inputx, new_add1, 0)
    assert feeds(inputx, new_add1, 1)

    # check arg edge: new_add1, input_x -> new_add2
    for values, size in ((inputx.get_targets(), 1), (new_add1.get_targets(), 1), (new_add2.get_args(), 2)):
        assert len(values) == size
    assert feeds(new_add1, new_add2, 0)
    assert feeds(inputx, new_add2, 1)

    # check arg edge: new_add2 -> relu
    for values, size in ((new_add2.get_targets(), 1), (relu.get_args(), 1)):
        assert len(values) == size
    assert feeds(new_add2, relu, 0)
|