You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_symbol_tree.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import ast
  16. import inspect
  17. from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
  18. from mindspore.ops import Add
  19. from mindspore.rewrite import ScopedValue, ValueType, NodeType
  20. from mindspore.rewrite import Node as NodeApi
  21. from mindspore.rewrite.symbol_tree import SymbolTree
  22. from mindspore.rewrite.node import Node
class Network(Cell):
    # Fixture Cell whose construct() is a straight chain: conv -> bn -> relu1 -> relu2 -> relu3.
    # NOTE: create_symbol_tree() parses this class's own source and picks AST nodes by position
    # (class body[0] = __init__, body[1] = construct, construct body[0..5] = five calls + return).
    # Keep documentation here as `#` comments only: a docstring or an extra statement would
    # shift those indices and break the fixture.
    def __init__(self):
        super().__init__()
        # Sub-cells that construct() calls as `self.<name>`; create_symbol_tree() also reads
        # them back (net.conv, net.bn, ...) when building the nodes.
        self.conv = Conv2d(16, 16, 3)
        self.bn = BatchNorm2d(16)
        self.relu1 = ReLU()
        self.relu2 = ReLU()
        self.relu3 = ReLU()
    def construct(self, x):
        # Six statements in total: one `x = self.<layer>(x)` per layer, then the return.
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu1(x)
        x = self.relu2(x)
        x = self.relu3(x)
        return x
  38. def create_symbol_tree():
  39. net = Network()
  40. source = inspect.getsource(type(net))
  41. ast_root = ast.parse(source)
  42. ast_module = ast_root
  43. assert isinstance(ast_root, ast.Module)
  44. ast_class = ast_module.body[0]
  45. assert isinstance(ast_class, ast.ClassDef)
  46. ast_init_func = ast_class.body[0]
  47. assert isinstance(ast_init_func, ast.FunctionDef)
  48. ast_construct_func = ast_class.body[1]
  49. assert isinstance(ast_construct_func, ast.FunctionDef)
  50. ast_conv = ast_construct_func.body[0]
  51. ast_bn = ast_construct_func.body[1]
  52. ast_relu1 = ast_construct_func.body[2]
  53. ast_relu2 = ast_construct_func.body[3]
  54. ast_relu3 = ast_construct_func.body[4]
  55. ast_return = ast_construct_func.body[5]
  56. stree = SymbolTree(net, ast_module)
  57. stree.set_class_ast(ast_class)
  58. stree.set_init_func_ast(ast_init_func)
  59. stree.set_ast_root(ast_construct_func)
  60. stree.append_input_node("x")
  61. conv_node = Node.create_call_buildin_op(net.conv, ast_conv, [ScopedValue.create_naming_value("x")],
  62. ScopedValue.create_naming_value("conv", "self"),
  63. [ScopedValue.create_naming_value("x")],
  64. {},
  65. "conv")
  66. stree.append_origin_field(conv_node)
  67. bn_node = Node.create_call_buildin_op(net.bn, ast_bn, [ScopedValue.create_naming_value("x")],
  68. ScopedValue.create_naming_value("bn", "self"),
  69. [ScopedValue.create_naming_value("x")], {},
  70. "bn")
  71. bn_node = stree.append_origin_field(bn_node)
  72. relu1_node = Node.create_call_buildin_op(net.relu1, ast_relu1, [ScopedValue.create_naming_value("x")],
  73. ScopedValue.create_naming_value("relu1", "self"),
  74. [ScopedValue.create_naming_value("x")],
  75. {}, "relu1")
  76. relu1_node = stree.append_origin_field(relu1_node)
  77. relu2_node = Node.create_call_buildin_op(net.relu2, ast_relu2, [ScopedValue.create_naming_value("x")],
  78. ScopedValue.create_naming_value("relu2", "self"),
  79. [ScopedValue.create_naming_value("x")],
  80. {}, "relu2")
  81. relu2_node = stree.append_origin_field(relu2_node)
  82. relu3_node = Node.create_call_buildin_op(net.relu3, ast_relu3, [ScopedValue.create_naming_value("x")],
  83. ScopedValue.create_naming_value("relu3", "self"),
  84. [ScopedValue.create_naming_value("x")],
  85. {}, "relu3")
  86. stree.append_origin_field(relu3_node)
  87. node_return = Node.create_output_node(ast_return, ["x"])
  88. stree.append_origin_field(node_return)
  89. return stree, bn_node, relu1_node, relu2_node
  90. def test_insert_node():
  91. """
  92. Feature: Python api insert_node of SymbolTree of Rewrite.
  93. Description: Call insert_node to insert a node into SymbolTree.
  94. Expectation: Success.
  95. """
  96. stree, _, relu1, relu2 = create_symbol_tree()
  97. construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
  98. providers = getattr(getattr(stree, "_topo_mgr"), "_target_provider")
  99. consumers = getattr(getattr(stree, "_topo_mgr"), "_target_consumer")
  100. providers_len = len(providers)
  101. consumers_len = len(consumers)
  102. assert len(stree.nodes()) == 7
  103. assert len(construct_ast.body) == 6
  104. assert len(relu1.get_targets()) == 1
  105. assert len(relu2.get_normalized_args().values()) == 1
  106. assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
  107. input1 = 1
  108. node = Node.create_call_buildin_op(Add(), None, ['x'], 'new_conv',
  109. [ScopedValue.create_naming_value('x'),
  110. ScopedValue.create_variable_value(input1)], {},
  111. 'new_conv')
  112. position = stree.before(relu2)
  113. node = stree.insert_node(position, node)
  114. # check nodes size
  115. assert len(stree.nodes()) == 8
  116. # check args
  117. assert len(relu2.get_normalized_args().values()) == 1
  118. assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
  119. assert len(node.get_normalized_args().values()) == 2
  120. assert list(node.get_normalized_args().values())[0] == ScopedValue.create_naming_value('x')
  121. assert list(node.get_normalized_args().values())[1].type == ValueType.IntValue
  122. # check provider
  123. assert len(providers) == providers_len + 1
  124. assert len(node.get_targets()) == 1
  125. assert providers.get(node.get_targets()[0])[0] == node
  126. assert providers.get(node.get_targets()[0])[1] == 0
  127. # check consumer
  128. assert len(consumers) == consumers_len + 1
  129. assert consumers.get(list(node.get_normalized_args().values())[1]) is not None
  130. # check inputs
  131. assert len(relu2.get_inputs()) == 1
  132. assert relu2.get_inputs()[0] == relu1
  133. assert len(node.get_inputs()) == 1
  134. assert node.get_inputs()[0].get_node_type() == NodeType.Input
  135. # check ast
  136. node_ast = node.get_ast()
  137. assert isinstance(node_ast, ast.Assign)
  138. args = node_ast.value.args
  139. assert isinstance(args, list)
  140. assert len(args) == 2
  141. assert isinstance(args[0], ast.Name)
  142. assert isinstance(args[1], ast.Constant)
  143. assert len(construct_ast.body) == 7
  144. def test_set_node_arg():
  145. """
  146. Feature: Python api set_node_arg of SymbolTree of Rewrite.
  147. Description: Call set_node_arg to change topological-order of a node.
  148. Expectation: Success.
  149. """
  150. stree, bn, relu1, relu2 = create_symbol_tree()
  151. assert len(stree.nodes()) == 7
  152. assert len(bn.get_targets()) == 1
  153. bn_output = bn.get_targets()[0]
  154. # check bn topological order
  155. assert len(stree.get_node_users(bn)) == 1
  156. assert stree.get_node_users(bn)[0][0] == relu1
  157. # check relu1 topological order
  158. assert len(stree.get_node_inputs(relu1)) == 1
  159. assert stree.get_node_inputs(relu1)[0] == bn
  160. assert len(stree.get_node_users(relu1)) == 1
  161. assert stree.get_node_users(relu1)[0][0] == relu2
  162. # check relu2 topological order
  163. assert len(stree.get_node_inputs(relu2)) == 1
  164. assert stree.get_node_inputs(relu2)[0] == relu1
  165. # check relu1 and relu2 edge
  166. assert len(relu1.get_targets()) == 1
  167. assert len(relu2.get_normalized_args().values()) == 1
  168. assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
  169. stree.set_node_arg(relu2, 0, bn_output)
  170. # check bn topological order
  171. assert len(stree.get_node_users(bn)) == 2
  172. assert stree.get_node_users(bn)[0][0] == relu1
  173. assert stree.get_node_users(bn)[1][0] == relu2
  174. # check relu1 topological order
  175. assert len(stree.get_node_inputs(relu1)) == 1
  176. assert stree.get_node_inputs(relu1)[0] == bn
  177. assert len(stree.get_node_users(relu1)) == 0
  178. # check relu2 topological order
  179. assert len(stree.get_node_inputs(relu2)) == 1
  180. assert stree.get_node_inputs(relu2)[0] == bn
  181. # check bn and relu2 edge
  182. assert len(relu1.get_targets()) == 1
  183. assert len(relu2.get_normalized_args().values()) == 1
  184. assert bn_output == list(relu2.get_normalized_args().values())[0]
  185. # check ast
  186. node_ast = relu2.get_ast()
  187. assert isinstance(node_ast, ast.Assign)
  188. args = node_ast.value.args
  189. assert isinstance(args, list)
  190. assert len(args) == 1
  191. assert isinstance(args[0], ast.Name)
  192. assert args[0].id == bn_output.value
  193. def test_set_node_arg_by_node():
  194. """
  195. Feature: Python api set_node_arg_by_node of SymbolTree of Rewrite.
  196. Description: Call set_node_arg_by_node to change topological-order of a node.
  197. Expectation: Success.
  198. """
  199. stree, bn, relu1, relu2 = create_symbol_tree()
  200. assert len(stree.nodes()) == 7
  201. assert len(bn.get_targets()) == 1
  202. bn_output = bn.get_targets()[0]
  203. # check bn topological order
  204. assert len(stree.get_node_users(bn)) == 1
  205. assert stree.get_node_users(bn)[0][0] == relu1
  206. # check relu1 topological order
  207. assert len(stree.get_node_inputs(relu1)) == 1
  208. assert stree.get_node_inputs(relu1)[0] == bn
  209. assert len(stree.get_node_users(relu1)) == 1
  210. assert stree.get_node_users(relu1)[0][0] == relu2
  211. # check relu2 topological order
  212. assert len(stree.get_node_inputs(relu2)) == 1
  213. assert stree.get_node_inputs(relu2)[0] == relu1
  214. # check relu1 and relu2 edge
  215. assert len(relu1.get_targets()) == 1
  216. assert len(relu2.get_normalized_args().values()) == 1
  217. assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
  218. stree.set_node_arg_by_node(relu2, 0, bn)
  219. # check bn topological order
  220. assert len(stree.get_node_users(bn)) == 2
  221. assert stree.get_node_users(bn)[0][0] == relu1
  222. assert stree.get_node_users(bn)[1][0] == relu2
  223. # check relu1 topological order
  224. assert len(stree.get_node_inputs(relu1)) == 1
  225. assert stree.get_node_inputs(relu1)[0] == bn
  226. assert len(stree.get_node_users(relu1)) == 0
  227. # check relu2 topological order
  228. assert len(stree.get_node_inputs(relu2)) == 1
  229. assert stree.get_node_inputs(relu2)[0] == bn
  230. # check bn and relu2 edge
  231. assert len(relu1.get_targets()) == 1
  232. assert len(relu2.get_normalized_args().values()) == 1
  233. assert bn_output == list(relu2.get_normalized_args().values())[0]
  234. # check ast
  235. node_ast = relu2.get_ast()
  236. assert isinstance(node_ast, ast.Assign)
  237. args = node_ast.value.args
  238. assert isinstance(args, list)
  239. assert len(args) == 1
  240. assert isinstance(args[0], ast.Name)
  241. assert args[0].id == bn_output.value
  242. def test_erase_succeed():
  243. """
  244. Feature: Python api erase_node of SymbolTree of Rewrite.
  245. Description: Call erase_node to erase a node from SymbolTree.
  246. Expectation: Success.
  247. """
  248. stree, bn, relu1, relu2 = create_symbol_tree()
  249. construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
  250. providers = getattr(getattr(stree, "_topo_mgr"), "_target_provider")
  251. providers_len = len(providers)
  252. assert len(stree.nodes()) == 7
  253. assert len(construct_ast.body) == 6
  254. stree.set_node_arg_by_node(relu2, 0, bn)
  255. stree.erase_node(relu1)
  256. assert len(stree.nodes()) == 6
  257. assert len(providers) == providers_len - 1
  258. assert len(construct_ast.body) == 5
  259. def test_erase_failed():
  260. """
  261. Feature: Python api erase_node of SymbolTree of Rewrite.
  262. Description: Call erase_node to erase a node from SymbolTree which is not isolated.
  263. Expectation: Failure.
  264. """
  265. stree, _, relu1, _ = create_symbol_tree()
  266. catched_error = False
  267. try:
  268. stree.erase_node(relu1)
  269. except RuntimeError:
  270. catched_error = True
  271. assert catched_error
  272. def test_replace_one_to_one():
  273. """
  274. Feature: Python api replace of SymbolTree of Rewrite.
  275. Description: Call replace to replace an origin node to a new node.
  276. Expectation: Success.
  277. """
  278. stree, bn, relu1, relu2 = create_symbol_tree()
  279. construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
  280. assert len(construct_ast.body) == 6
  281. assert len(stree.nodes()) == 7
  282. new_conv = Conv2d(16, 16, 5)
  283. new_conv_node = NodeApi.create_call_cell(new_conv, [ScopedValue.create_naming_value("new_conv")],
  284. bn.get_targets()).get_handler()
  285. new_conv_node = stree.replace(relu1, [new_conv_node])
  286. assert len(stree.nodes()) == 7
  287. # check ast
  288. assert len(construct_ast.body) == 6
  289. node_ast: ast.Assign = construct_ast.body[2]
  290. func_ast: ast.Attribute = node_ast.value.func
  291. assert func_ast.attr == new_conv_node.get_name()
  292. # check bn topological order
  293. assert len(stree.get_node_users(bn)) == 1
  294. assert stree.get_node_users(bn)[0][0] == new_conv_node
  295. # check new_conv_node topological order
  296. assert len(stree.get_node_inputs(new_conv_node)) == 1
  297. assert stree.get_node_inputs(new_conv_node)[0] == bn
  298. assert len(stree.get_node_users(new_conv_node)) == 1
  299. assert stree.get_node_users(new_conv_node)[0][0] == relu2
  300. # check relu2 topological order
  301. assert len(stree.get_node_inputs(relu2)) == 1
  302. assert stree.get_node_inputs(relu2)[0] == new_conv_node
  303. # check arg edge
  304. assert len(bn.get_targets()) == 1
  305. assert len(new_conv_node.get_normalized_args().values()) == 1
  306. assert bn.get_targets()[0] == list(new_conv_node.get_normalized_args().values())[0]
  307. assert len(new_conv_node.get_targets()) == 1
  308. assert len(relu2.get_normalized_args().values()) == 1
  309. assert new_conv_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
  310. def test_replace_one_to_multi():
  311. """
  312. Feature: Python api replace of SymbolTree of Rewrite.
  313. Description: Call replace to replace an origin node to a new node-tree.
  314. Expectation: Success.
  315. """
  316. stree, bn, relu1, relu2 = create_symbol_tree()
  317. construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
  318. assert len(construct_ast.body) == 6
  319. assert len(stree.nodes()) == 7
  320. new_conv_node = NodeApi.create_call_cell(Conv2d(16, 16, 5), [ScopedValue.create_naming_value("new_conv")],
  321. bn.get_targets()).get_handler()
  322. new_relu_node = NodeApi.create_call_cell(ReLU(), [ScopedValue.create_naming_value("new_relu")],
  323. new_conv_node.get_targets()).get_handler()
  324. new_relu_node = stree.replace(relu1, [new_relu_node, new_conv_node])
  325. new_conv_node = new_relu_node.get_inputs()[0]
  326. assert len(stree.nodes()) == 8
  327. # check ast
  328. assert len(construct_ast.body) == 7
  329. new_conv_ast: ast.Assign = construct_ast.body[2]
  330. new_conv_func_ast: ast.Attribute = new_conv_ast.value.func
  331. assert new_conv_func_ast.attr == new_conv_node.get_name()
  332. new_relu_ast: ast.Assign = construct_ast.body[3]
  333. new_relu_func_ast: ast.Attribute = new_relu_ast.value.func
  334. assert new_relu_func_ast.attr == new_relu_node.get_name()
  335. # check bn topological order
  336. assert len(stree.get_node_users(bn)) == 1
  337. assert stree.get_node_users(bn)[0][0] == new_conv_node
  338. # check new_conv_node topological order
  339. assert len(stree.get_node_inputs(new_conv_node)) == 1
  340. assert stree.get_node_inputs(new_conv_node)[0] == bn
  341. assert len(stree.get_node_users(new_conv_node)) == 1
  342. assert stree.get_node_users(new_conv_node)[0][0] == new_relu_node
  343. # check new_relu_node topological order
  344. assert len(stree.get_node_inputs(new_relu_node)) == 1
  345. assert stree.get_node_inputs(new_relu_node)[0] == new_conv_node
  346. assert len(stree.get_node_users(new_relu_node)) == 1
  347. assert stree.get_node_users(new_relu_node)[0][0] == relu2
  348. # check relu2 topological order
  349. assert len(stree.get_node_inputs(relu2)) == 1
  350. assert stree.get_node_inputs(relu2)[0] == new_relu_node
  351. # check arg edge
  352. assert len(bn.get_targets()) == 1
  353. assert len(new_conv_node.get_normalized_args().values()) == 1
  354. assert bn.get_targets()[0] == list(new_conv_node.get_normalized_args().values())[0]
  355. assert len(new_conv_node.get_targets()) == 1
  356. assert len(new_relu_node.get_normalized_args().values()) == 1
  357. assert new_conv_node.get_targets()[0] == list(new_relu_node.get_normalized_args().values())[0]
  358. assert len(new_relu_node.get_targets()) == 1
  359. assert len(relu2.get_normalized_args().values()) == 1
  360. assert new_relu_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0]