You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_pattern_engine.py 26 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import ast
  16. from collections import OrderedDict
  17. from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
  18. from mindspore.ops import Add, AddN
  19. from mindspore.rewrite import ScopedValue, Node, SymbolTree
  20. from mindspore.rewrite import PatternEngine, PatternNode, Replacement, VarNode
  21. def test_tree_pattern_match():
  22. """
  23. Feature: Python api PatternEngine.
  24. Description: Construct a tree PatternEngine and apply it on a SymbolTree, check SymbolTree after PatternEngine
  25. applied.
  26. Expectation: Success.
  27. """
  28. assert True
  29. def test_leak_pattern_match():
  30. """
  31. Feature: Python api PatternEngine.
  32. Description: Construct a leaked tree PatternEngine and apply it on a SymbolTree, check SymbolTree after
  33. PatternEngine applied.
  34. Expectation: Failure.
  35. """
  36. assert True
class ChainNetwork(Cell):
    # Linear test network: conv -> bn -> relu1 -> relu2 -> relu3.
    # NOTE: tests in this file build a SymbolTree from this cell, count the statements of
    # `construct` and look nodes up by attribute name, so keep the code itself unchanged
    # (and avoid docstrings here, which would add statements to the parsed bodies).
    def __init__(self):
        super().__init__()
        self.conv = Conv2d(16, 16, 3)
        self.bn = BatchNorm2d(16)
        self.relu1 = ReLU()
        self.relu2 = ReLU()
        self.relu3 = ReLU()

    def construct(self, x):
        # Five chained calls plus the return: six statements in the function body.
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu1(x)
        x = self.relu2(x)
        x = self.relu3(x)
        return x
def test_one_to_one_pattern():
    """
    Feature: Python api PatternEngine.
    Description: Replace the single BatchNorm2d node of ChainNetwork with one new Conv2d node (target 'x1')
    through a one-to-one PatternEngine, then check node counts, topological edges, source-code order and
    argument edges around the new node.
    Expectation: Success.
    """
    class BnReplacement(Replacement):
        # Replaces a matched BatchNorm2d node with a single Conv2d node.
        def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
            # A one-element list pattern is reported as a chain pattern.
            assert is_chain_pattern
            assert pattern.type() == BatchNorm2d
            bn_node: Node = matched.get(pattern.name())
            assert bn_node is not None
            conv = Conv2d(16, 16, 3)
            # Reuse the bn node's args/kwargs so the new conv is wired to the same inputs.
            conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs())
            return [conv_node]

    class BnReplace(PatternEngine):
        # Chain pattern consisting of a single BatchNorm2d node.
        def __init__(self):
            super().__init__([BatchNorm2d], BnReplacement())

    net = ChainNetwork()
    stree = SymbolTree.create(net)
    conv = stree.get_node("conv")
    bn = stree.get_node("bn")
    relu1 = stree.get_node("relu1")
    # Private root AST of `construct`; used to count statements before and after the rewrite.
    construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
    assert conv is not None
    assert bn is not None
    assert relu1 is not None
    # construct has 5 calls + return; presumably the 7 nodes are input x + 5 calls + return.
    assert len(construct_ast.body) == 6
    assert len(stree.nodes()) == 7
    bn_replace = BnReplace()
    bn_replace.apply(stree)
    # One-to-one replacement: statement and node counts are unchanged.
    assert len(construct_ast.body) == 6
    assert len(stree.nodes()) == 7
    conv = stree.get_node("conv")
    bn = stree.get_node("bn")
    relu1 = stree.get_node("relu1")
    new_conv = stree.get_node("x1")
    assert conv is not None
    assert bn is None
    assert relu1 is not None
    assert new_conv is not None
    # check conv topological order
    assert len(conv.get_users()) == 1
    assert conv.get_users()[0] == new_conv
    # check new_conv topological order
    assert len(new_conv.get_inputs()) == 1
    assert new_conv.get_inputs()[0] == conv
    assert len(new_conv.get_users()) == 1
    assert new_conv.get_users()[0] == relu1
    # check source code order (private _prev/_next links of the node handlers)
    assert getattr(conv.get_handler(), "_next") == new_conv.get_handler()
    assert getattr(new_conv.get_handler(), "_next") == relu1.get_handler()
    assert getattr(relu1.get_handler(), "_prev") == new_conv.get_handler()
    assert getattr(new_conv.get_handler(), "_prev") == conv.get_handler()
    # check arg edge: each producer's target is the consumer's argument
    assert len(conv.get_targets()) == 1
    assert len(new_conv.get_args()) == 1
    assert conv.get_targets()[0] == new_conv.get_args()[0]
    assert len(new_conv.get_targets()) == 1
    assert len(relu1.get_args()) == 1
    assert new_conv.get_targets()[0] == relu1.get_args()[0]
def test_one_to_multi_chain_pattern():
    """
    Feature: Python api PatternEngine.
    Description: Replace the BatchNorm2d node of ChainNetwork with two chained Conv2d nodes (targets 'x1' and
    'x2') through a one-to-multi chain PatternEngine, then check node counts, topological edges, source-code
    order and argument edges around the new nodes.
    Expectation: Success.
    """
    class BnReplacement(Replacement):
        # Replaces a matched BatchNorm2d node with two chained Conv2d nodes.
        def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
            assert is_chain_pattern
            assert pattern.type() == BatchNorm2d
            bn_node: Node = matched.get(pattern.name())
            assert bn_node is not None
            # The replacement must keep targets unique among the nodes it returns.
            # The replacement must also set args and kwargs so they express the intended topology.
            conv1 = Conv2d(16, 16, 3)
            # First conv takes over the bn node's inputs.
            conv_node1 = Node.create_call_cell(conv1, ['x1'], bn_node.get_args(), bn_node.get_kwargs())
            conv2 = Conv2d(16, 16, 5)
            # Second conv consumes the first conv's target 'x1'.
            conv_node2 = Node.create_call_cell(conv2, ['x2'], [ScopedValue.create_naming_value('x1')])
            return [conv_node1, conv_node2]

    class BnReplace(PatternEngine):
        # Chain pattern consisting of a single BatchNorm2d node.
        def __init__(self):
            super().__init__([BatchNorm2d], BnReplacement())

    net = ChainNetwork()
    stree = SymbolTree.create(net)
    conv = stree.get_node("conv")
    bn = stree.get_node("bn")
    relu1 = stree.get_node("relu1")
    # Private root AST of `construct`; used to count statements before and after the rewrite.
    construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
    assert conv is not None
    assert bn is not None
    assert relu1 is not None
    assert len(construct_ast.body) == 6
    assert len(stree.nodes()) == 7
    bn_replace = BnReplace()
    bn_replace.apply(stree)
    # One node replaced by two: one extra statement and one extra node.
    assert len(construct_ast.body) == 7
    assert len(stree.nodes()) == 8
    conv = stree.get_node("conv")
    bn = stree.get_node("bn")
    relu1 = stree.get_node("relu1")
    new_conv1 = stree.get_node("x1")
    new_conv2 = stree.get_node("x2")
    assert conv is not None
    assert bn is None
    assert relu1 is not None
    assert new_conv1 is not None
    assert new_conv2 is not None
    # check conv topological order
    assert len(conv.get_users()) == 1
    assert conv.get_users()[0] == new_conv1
    # check new_conv1 topological order
    assert len(new_conv1.get_inputs()) == 1
    assert new_conv1.get_inputs()[0] == conv
    assert len(new_conv1.get_users()) == 1
    assert new_conv1.get_users()[0] == new_conv2
    # check new_conv2 topological order
    assert len(new_conv2.get_inputs()) == 1
    assert new_conv2.get_inputs()[0] == new_conv1
    assert len(new_conv2.get_users()) == 1
    assert new_conv2.get_users()[0] == relu1
    # check source code order (private _prev/_next links of the node handlers)
    assert getattr(conv.get_handler(), "_next") == new_conv1.get_handler()
    assert getattr(new_conv1.get_handler(), "_next") == new_conv2.get_handler()
    assert getattr(new_conv2.get_handler(), "_next") == relu1.get_handler()
    assert getattr(relu1.get_handler(), "_prev") == new_conv2.get_handler()
    assert getattr(new_conv2.get_handler(), "_prev") == new_conv1.get_handler()
    assert getattr(new_conv1.get_handler(), "_prev") == conv.get_handler()
    # check arg edge: each producer's target is the consumer's argument
    assert len(conv.get_targets()) == 1
    assert len(new_conv1.get_args()) == 1
    assert conv.get_targets()[0] == new_conv1.get_args()[0]
    assert len(new_conv1.get_targets()) == 1
    assert len(new_conv2.get_args()) == 1
    assert new_conv1.get_targets()[0] == new_conv2.get_args()[0]
    assert len(new_conv2.get_targets()) == 1
    assert len(relu1.get_args()) == 1
    assert new_conv2.get_targets()[0] == relu1.get_args()[0]
class TreeNetwork(Cell):
    # Test network with two diamond-shaped regions:
    #   conv1(x), conv2(x) -> add -> relu, then relu -> relu1, relu2 -> add.
    # Note that the same `self.add` cell is called twice in construct.
    # NOTE: tests build a SymbolTree from this cell and count construct's statements,
    # so keep the code itself unchanged (no docstrings either).
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(16, 16, 3)
        self.conv2 = Conv2d(16, 16, 5)
        self.add = Add()
        self.relu = ReLU()
        self.relu1 = ReLU()
        self.relu2 = ReLU()

    def construct(self, x):
        # Seven calls plus the return: eight statements in the function body.
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x = self.add(x1, x2)
        x = self.relu(x)
        x1 = self.relu1(x)
        x2 = self.relu2(x)
        x = self.add(x1, x2)
        return x
def test_tree_pattern():
    """
    Feature: Python api PatternEngine.
    Description: Construct a chain PatternEngine matching Add -> ReLU, replace the first matched pair in
    TreeNetwork with four nodes (Add -> two ReLUs -> Add), then check node counts, topological edges,
    source-code order and argument edges of the SymbolTree after PatternEngine applied.
    Expectation: Success.
    """
    class AddReluReplacement(Replacement):
        # Replaces a matched Add -> ReLU pair with Add -> (ReLU, ReLU) -> Add.
        def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
            assert is_chain_pattern
            # `pattern` is the tail of the chain (ReLU); its single input is the Add pattern node.
            assert pattern.type() == ReLU
            relu_node: Node = matched.get(pattern.name())
            assert relu_node is not None
            assert len(pattern.get_inputs()) == 1
            add_pattern = pattern.get_inputs()[0]
            assert add_pattern.type() == Add
            add_node: Node = matched.get(add_pattern.name())
            assert add_node is not None
            assert not add_pattern.get_inputs()
            # NOTE(review): per the original note, the matched add_node cannot be reused here,
            # so fresh Add/ReLU cells are created for every replacement node.
            new_add1 = Add()
            # First Add takes over the matched add node's inputs.
            new_add1_node = Node.create_call_cell(new_add1, ['new_add_1'], add_node.get_args(), add_node.get_kwargs())
            new_relu1 = ReLU()
            new_relu1_node = Node.create_call_cell(new_relu1, ['new_relu_1'],
                                                   [ScopedValue.create_naming_value('new_add_1')])
            new_relu2 = ReLU()
            new_relu2_node = Node.create_call_cell(new_relu2, ['new_relu_2'],
                                                   [ScopedValue.create_naming_value('new_add_1')])
            new_add2 = Add()
            new_add2_node = Node.create_call_cell(new_add2, ['new_add_2'],
                                                  [ScopedValue.create_naming_value('new_relu_1'),
                                                   ScopedValue.create_naming_value('new_relu_2')])
            return [new_add1_node, new_relu1_node, new_relu2_node, new_add2_node]

    class AddReluPattern(PatternEngine):
        # Chain pattern Add -> ReLU.
        def __init__(self):
            super().__init__([Add, ReLU], AddReluReplacement())

    net = TreeNetwork()
    stree = SymbolTree.create(net)
    conv1 = stree.get_node("conv1")
    conv2 = stree.get_node("conv2")
    add = stree.get_node("add")
    relu = stree.get_node("relu")
    relu1 = stree.get_node("relu1")
    relu2 = stree.get_node("relu2")
    assert conv1 is not None
    assert conv2 is not None
    assert add is not None
    assert relu is not None
    assert relu1 is not None
    assert relu2 is not None
    # Private root AST of `construct`; used to count statements before and after the rewrite.
    construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
    assert len(construct_ast.body) == 8
    assert len(stree.nodes()) == 9
    add_relu_pattern = AddReluPattern()
    add_relu_pattern.apply(stree)
    # Two nodes replaced by four: two extra statements and two extra nodes.
    assert len(construct_ast.body) == 10
    assert len(stree.nodes()) == 11
    conv1 = stree.get_node("conv1")
    conv2 = stree.get_node("conv2")
    add = stree.get_node("add")
    relu = stree.get_node("relu")
    relu1 = stree.get_node("relu1")
    relu2 = stree.get_node("relu2")
    # New nodes are looked up by the names SymbolTree gives them, which differ from the
    # targets passed to create_call_cell (e.g. the node consuming conv1/conv2 is 'new_add').
    new_add = stree.get_node("new_add")
    new_relu = stree.get_node("new_relu")
    new_relu_1 = stree.get_node("new_relu_1")
    new_add_1 = stree.get_node("new_add_1")
    assert conv1 is not None
    assert conv2 is not None
    assert add is None
    assert relu is None
    assert relu1 is not None
    assert relu2 is not None
    assert new_add is not None
    assert new_relu is not None
    assert new_relu_1 is not None
    assert new_add_1 is not None
    # check conv1 topological order
    assert len(conv1.get_users()) == 1
    assert conv1.get_users()[0] == new_add
    # check conv2 topological order
    assert len(conv2.get_users()) == 1
    assert conv2.get_users()[0] == new_add
    # check new_add topological order
    assert len(new_add.get_inputs()) == 2
    assert new_add.get_inputs()[0] == conv1
    assert new_add.get_inputs()[1] == conv2
    assert len(new_add.get_users()) == 2
    assert new_add.get_users()[0] == new_relu
    assert new_add.get_users()[1] == new_relu_1
    # check new_relu topological order
    assert len(new_relu.get_inputs()) == 1
    assert new_relu.get_inputs()[0] == new_add
    assert len(new_relu.get_users()) == 1
    assert new_relu.get_users()[0] == new_add_1
    # check new_relu_1 topological order
    assert len(new_relu_1.get_inputs()) == 1
    assert new_relu_1.get_inputs()[0] == new_add
    assert len(new_relu_1.get_users()) == 1
    assert new_relu_1.get_users()[0] == new_add_1
    # check new_add_1 topological order
    assert len(new_add_1.get_inputs()) == 2
    assert new_add_1.get_inputs()[0] == new_relu_1
    assert new_add_1.get_inputs()[1] == new_relu
    assert len(new_add_1.get_users()) == 2
    assert new_add_1.get_users()[0] == relu1
    assert new_add_1.get_users()[1] == relu2
    # check source code order (private _prev/_next links of the node handlers)
    assert getattr(conv1.get_handler(), "_next") == conv2.get_handler()
    assert getattr(conv2.get_handler(), "_next") == new_add.get_handler()
    assert getattr(new_add.get_handler(), "_next") == new_relu.get_handler()
    assert getattr(new_relu.get_handler(), "_next") == new_relu_1.get_handler()
    assert getattr(new_relu_1.get_handler(), "_next") == new_add_1.get_handler()
    assert getattr(new_add_1.get_handler(), "_next") == relu1.get_handler()
    assert getattr(relu1.get_handler(), "_prev") == new_add_1.get_handler()
    assert getattr(new_add_1.get_handler(), "_prev") == new_relu_1.get_handler()
    assert getattr(new_relu_1.get_handler(), "_prev") == new_relu.get_handler()
    assert getattr(new_relu.get_handler(), "_prev") == new_add.get_handler()
    assert getattr(new_add.get_handler(), "_prev") == conv2.get_handler()
    assert getattr(conv2.get_handler(), "_prev") == conv1.get_handler()
    # check arg edge: each producer's target is the consumer's argument
    assert len(conv1.get_targets()) == 1
    assert len(conv2.get_targets()) == 1
    assert len(new_add.get_args()) == 2
    assert conv1.get_targets()[0] == new_add.get_args()[0]
    assert conv2.get_targets()[0] == new_add.get_args()[1]
    assert len(new_add.get_targets()) == 1
    assert len(new_relu.get_args()) == 1
    assert len(new_relu_1.get_args()) == 1
    assert new_add.get_targets()[0] == new_relu.get_args()[0]
    assert new_add.get_targets()[0] == new_relu_1.get_args()[0]
    assert len(new_relu.get_targets()) == 1
    assert len(new_relu_1.get_targets()) == 1
    assert len(new_add_1.get_args()) == 2
    assert new_relu.get_targets()[0] == new_add_1.get_args()[1]
    assert new_relu_1.get_targets()[0] == new_add_1.get_args()[0]
    assert len(new_add_1.get_targets()) == 1
    assert len(relu1.get_args()) == 1
    assert len(relu2.get_args()) == 1
    assert new_add_1.get_targets()[0] == relu1.get_args()[0]
    assert new_add_1.get_targets()[0] == relu2.get_args()[0]
class TreeNetwork2(Cell):
    # Three-input test network: conv1(x) and conv2(y) both feed add1 and add2;
    # add1 additionally takes z, add2 additionally takes add1's result, and relu consumes add2.
    # NOTE: tests build a SymbolTree from this cell and count construct's statements,
    # so keep the code itself unchanged (no docstrings either).
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(16, 16, 1)
        self.conv2 = Conv2d(16, 16, 3)
        self.add1 = AddN()
        self.add2 = AddN()
        self.relu = ReLU()

    def construct(self, x, y, z):
        # Five calls plus the return: six statements in the function body.
        x = self.conv1(x)
        y = self.conv2(y)
        z = self.add1(x, y, z)
        z = self.add2(x, y, z)
        z = self.relu(z)
        return z
  366. class MultiInputPattern(PatternEngine):
  367. class MultiInputReplacement(Replacement):
  368. def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
  369. assert not is_chain_pattern
  370. assert pattern.type() == AddN
  371. addn2_node: Node = matched.get(pattern.name())
  372. assert addn2_node is not None
  373. assert len(pattern.get_inputs()) == 3
  374. conv1_pn = pattern.get_inputs()[0]
  375. conv2_pn = pattern.get_inputs()[1]
  376. addn1_pn = pattern.get_inputs()[2]
  377. assert conv1_pn.type() == Conv2d
  378. assert conv2_pn.type() == Conv2d
  379. assert addn1_pn.type() == AddN
  380. conv1_node: Node = matched.get(conv1_pn.name())
  381. conv2_node: Node = matched.get(conv2_pn.name())
  382. addn1_node: Node = matched.get(addn1_pn.name())
  383. assert conv1_node is not None
  384. assert conv2_node is not None
  385. assert addn1_node is not None
  386. assert len(conv1_node.get_inputs()) == 1
  387. assert len(conv2_node.get_inputs()) == 1
  388. assert len(addn1_node.get_inputs()) == 3
  389. arg1 = conv1_node.get_args()[0]
  390. arg2 = conv2_node.get_args()[0]
  391. arg3 = addn1_node.get_args()[2]
  392. # can not use add_node here
  393. new_add1 = Add()
  394. new_add1_node = Node.create_call_cell(new_add1, ['new_add1'], [arg1, arg2])
  395. new_add2 = Add()
  396. new_add2_node = Node.create_call_cell(new_add2, ['new_add2'], [ScopedValue.create_naming_value('new_add1'),
  397. arg3])
  398. return [new_add1_node, new_add2_node]
  399. def __init__(self):
  400. conv1_pn = PatternNode("conv1", Conv2d)
  401. conv2_pn = PatternNode("conv2", Conv2d)
  402. addn1_pn = PatternNode("addn1", AddN)
  403. addn2_pn = PatternNode("addn2", AddN)
  404. conv1_pn.set_inputs([VarNode()])
  405. conv2_pn.set_inputs([VarNode()])
  406. addn1_pn.set_inputs([conv1_pn, conv2_pn, VarNode()])
  407. addn2_pn.set_inputs([conv1_pn, conv2_pn, addn1_pn])
  408. super().__init__(addn2_pn, MultiInputPattern.MultiInputReplacement())
def test_multi_input_to_multi_pattern_tree_pattern():
    """
    Feature: Python api PatternEngine.
    Description: Apply the tree-shaped MultiInputPattern (conv1, conv2, addn1, addn2) to the three-input
    TreeNetwork2, replacing the matched four nodes with two Add nodes, then check node counts, topological
    edges, source-code order and argument edges of the SymbolTree after PatternEngine applied.
    Expectation: Success.
    """
    net = TreeNetwork2()
    stree = SymbolTree.create(net)
    conv1 = stree.get_node("conv1")
    conv2 = stree.get_node("conv2")
    add1 = stree.get_node("add1")
    add2 = stree.get_node("add2")
    relu = stree.get_node("relu")
    assert conv1 is not None
    assert conv2 is not None
    assert add1 is not None
    assert add2 is not None
    assert relu is not None
    # Private root AST of `construct`; used to count statements before and after the rewrite.
    construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
    # presumably the 9 nodes are 3 inputs + 5 calls + return
    assert len(construct_ast.body) == 6
    assert len(stree.nodes()) == 9
    multi_input_pattern = MultiInputPattern()
    multi_input_pattern.apply(stree)
    # Four nodes replaced by two: two fewer statements and two fewer nodes.
    assert len(construct_ast.body) == 4
    assert len(stree.nodes()) == 7
    conv1 = stree.get_node("conv1")
    conv2 = stree.get_node("conv2")
    add1 = stree.get_node("add1")
    add2 = stree.get_node("add2")
    relu = stree.get_node("relu")
    new_add1 = stree.get_node("new_add1")
    new_add2 = stree.get_node("new_add2")
    inputx = stree.get_node("input_x")
    inputy = stree.get_node("input_y")
    inputz = stree.get_node("input_z")
    assert conv1 is None
    assert conv2 is None
    assert add1 is None
    assert add2 is None
    assert relu is not None
    assert new_add1 is not None
    assert new_add2 is not None
    assert inputx is not None
    assert inputy is not None
    assert inputz is not None
    # check inputx topological order
    assert len(inputx.get_users()) == 1
    assert inputx.get_users()[0] == new_add1
    # check inputy topological order
    assert len(inputy.get_users()) == 1
    assert inputy.get_users()[0] == new_add1
    # check inputz topological order
    assert len(inputz.get_users()) == 1
    assert inputz.get_users()[0] == new_add2
    # check new_add1 topological order
    assert len(new_add1.get_inputs()) == 2
    assert new_add1.get_inputs()[0] == inputx
    assert new_add1.get_inputs()[1] == inputy
    assert len(new_add1.get_users()) == 1
    assert new_add1.get_users()[0] == new_add2
    # check new_add2 topological order
    assert len(new_add2.get_inputs()) == 2
    assert new_add2.get_inputs()[0] == new_add1
    assert new_add2.get_inputs()[1] == inputz
    assert len(new_add2.get_users()) == 1
    assert new_add2.get_users()[0] == relu
    # check relu topological order
    assert len(relu.get_inputs()) == 1
    assert relu.get_inputs()[0] == new_add2
    # check source code order: the new nodes directly follow the last input node
    assert getattr(inputz.get_handler(), "_next") == new_add1.get_handler()
    assert getattr(new_add1.get_handler(), "_next") == new_add2.get_handler()
    assert getattr(new_add2.get_handler(), "_next") == relu.get_handler()
    assert getattr(relu.get_handler(), "_prev") == new_add2.get_handler()
    assert getattr(new_add2.get_handler(), "_prev") == new_add1.get_handler()
    assert getattr(new_add1.get_handler(), "_prev") == inputz.get_handler()
    # check arg edge: each producer's target is the consumer's argument
    assert len(inputx.get_targets()) == 1
    assert len(inputy.get_targets()) == 1
    assert len(new_add1.get_args()) == 2
    assert inputx.get_targets()[0] == new_add1.get_args()[0]
    assert inputy.get_targets()[0] == new_add1.get_args()[1]
    assert len(inputz.get_targets()) == 1
    assert len(new_add1.get_targets()) == 1
    assert len(new_add2.get_args()) == 2
    assert new_add1.get_targets()[0] == new_add2.get_args()[0]
    assert inputz.get_targets()[0] == new_add2.get_args()[1]
    assert len(new_add2.get_targets()) == 1
    assert len(relu.get_args()) == 1
    assert new_add2.get_targets()[0] == relu.get_args()[0]
class TreeNetwork3(Cell):
    # Single-input variant of TreeNetwork2: x feeds conv1, conv2 and add1, so the
    # conv/AddN sub-graph matched by MultiInputPattern has all its leaves on one input.
    # NOTE: tests build a SymbolTree from this cell and count construct's statements,
    # so keep the code itself unchanged (no docstrings either).
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(16, 16, 1)
        self.conv2 = Conv2d(16, 16, 3)
        self.add1 = AddN()
        self.add2 = AddN()
        self.relu = ReLU()

    def construct(self, x):
        # Five calls plus the return: six statements in the function body.
        y = self.conv1(x)
        z = self.conv2(x)
        x = self.add1(y, z, x)
        x = self.add2(y, z, x)
        x = self.relu(x)
        return x
def test_one_input_to_multi_pattern_tree_pattern():
    """
    Feature: Python api PatternEngine.
    Description: Apply the tree-shaped MultiInputPattern to the single-input TreeNetwork3, where one input
    feeds every leaf of the matched sub-graph, replacing the four matched nodes with two Add nodes, then check
    node counts, topological edges, source-code order and argument edges after PatternEngine applied.
    Expectation: Success.
    """
    net = TreeNetwork3()
    stree = SymbolTree.create(net)
    conv1 = stree.get_node("conv1")
    conv2 = stree.get_node("conv2")
    add1 = stree.get_node("add1")
    add2 = stree.get_node("add2")
    relu = stree.get_node("relu")
    assert conv1 is not None
    assert conv2 is not None
    assert add1 is not None
    assert add2 is not None
    assert relu is not None
    # Private root AST of `construct`; used to count statements before and after the rewrite.
    construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
    # presumably the 7 nodes are input x + 5 calls + return
    assert len(construct_ast.body) == 6
    assert len(stree.nodes()) == 7
    multi_input_pattern = MultiInputPattern()
    multi_input_pattern.apply(stree)
    # Four nodes replaced by two: two fewer statements and two fewer nodes.
    assert len(construct_ast.body) == 4
    assert len(stree.nodes()) == 5
    conv1 = stree.get_node("conv1")
    conv2 = stree.get_node("conv2")
    add1 = stree.get_node("add1")
    add2 = stree.get_node("add2")
    relu = stree.get_node("relu")
    new_add1 = stree.get_node("new_add1")
    new_add2 = stree.get_node("new_add2")
    inputx = stree.get_node("input_x")
    assert conv1 is None
    assert conv2 is None
    assert add1 is None
    assert add2 is None
    assert relu is not None
    assert new_add1 is not None
    assert new_add2 is not None
    assert inputx is not None
    # check inputx topological order: the single input now feeds both new Add nodes
    assert len(inputx.get_users()) == 2
    assert inputx.get_users()[0] == new_add1
    assert inputx.get_users()[1] == new_add2
    # check new_add1 topological order: both of its inputs are input_x
    assert len(new_add1.get_inputs()) == 2
    assert new_add1.get_inputs()[0] == inputx
    assert new_add1.get_inputs()[1] == inputx
    assert len(new_add1.get_users()) == 1
    assert new_add1.get_users()[0] == new_add2
    # check new_add2 topological order
    assert len(new_add2.get_inputs()) == 2
    assert new_add2.get_inputs()[0] == new_add1
    assert new_add2.get_inputs()[1] == inputx
    assert len(new_add2.get_users()) == 1
    assert new_add2.get_users()[0] == relu
    # check relu topological order
    assert len(relu.get_inputs()) == 1
    assert relu.get_inputs()[0] == new_add2
    # check source code order: the new nodes directly follow the input node
    assert getattr(inputx.get_handler(), "_next") == new_add1.get_handler()
    assert getattr(new_add1.get_handler(), "_next") == new_add2.get_handler()
    assert getattr(new_add2.get_handler(), "_next") == relu.get_handler()
    assert getattr(relu.get_handler(), "_prev") == new_add2.get_handler()
    assert getattr(new_add2.get_handler(), "_prev") == new_add1.get_handler()
    assert getattr(new_add1.get_handler(), "_prev") == inputx.get_handler()
    # check arg edge: each producer's target is the consumer's argument
    assert len(inputx.get_targets()) == 1
    assert len(new_add1.get_args()) == 2
    assert inputx.get_targets()[0] == new_add1.get_args()[0]
    assert inputx.get_targets()[0] == new_add1.get_args()[1]
    # NOTE(review): redundant with the identical check above (likely left over from the inputz
    # check in test_multi_input_to_multi_pattern_tree_pattern).
    assert len(inputx.get_targets()) == 1
    assert len(new_add1.get_targets()) == 1
    assert len(new_add2.get_args()) == 2
    assert new_add1.get_targets()[0] == new_add2.get_args()[0]
    assert inputx.get_targets()[0] == new_add2.get_args()[1]
    assert len(new_add2.get_targets()) == 1
    assert len(relu.get_args()) == 1
    assert new_add2.get_targets()[0] == relu.get_args()[0]