You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_split.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """Test split"""
  16. import model
  17. from model import model as estimate
  18. from model import graph_split as split
  19. def get_nodes(sp, ops):
  20. """Get nodes"""
  21. if isinstance(ops[0], str):
  22. new_ops = []
  23. for t in ops:
  24. for op in sp.graph.ops:
  25. if op.output.name == t:
  26. new_ops.append(op)
  27. break
  28. else:
  29. print("ERROR: not found op: ", t)
  30. ops = new_ops
  31. return [sp.nodes[sp.graph.ops.index(op)] for op in ops]
  32. def first_connected(sp, space):
  33. for cand in space:
  34. nodes = [sp.nodes[i] for i in cand[0]]
  35. graphs = sp.resolve_connnected_graphs(nodes)
  36. if len(graphs) != 1:
  37. print("connect check failed: ", nodes)
  38. return False
  39. return True
  40. def split_format(sp, cand):
  41. names = []
  42. for ids in cand:
  43. ops = []
  44. for i in ids:
  45. ops.append(sp.graph.ops[i].output.name)
  46. names.append(','.join(ops))
  47. return '|'.join(names)
  48. def graph_1():
  49. ''' ring, no succ_dep, no prev '''
  50. gb = model.GraphBuilder()
  51. with gb.graph_scope("main"):
  52. a = gb.tensor([10240, 16], "float32", name="a")
  53. b = gb.emit("Abs", a, 'b')
  54. c = gb.emit("Abs", b, 'c')
  55. d = gb.emit("Abs", c, 'd')
  56. gb.emit('Add', [b, d], 'e')
  57. return gb.get()[0]
  58. def graph_2():
  59. ''' ring, succ_dep, no prev '''
  60. gb = model.GraphBuilder()
  61. with gb.graph_scope("main"):
  62. a0 = gb.tensor([10240, 16], "float32", name="a0")
  63. a = gb.emit("Abs", a0, 'a')
  64. b = gb.emit("Abs", a, 'b')
  65. c = gb.emit("Abs", a, 'c')
  66. d = gb.emit("Abs", b, 'd')
  67. e = gb.emit('Add', [c, d], 'e')
  68. gb.emit("Abs", e, 'f')
  69. return gb.get()[0]
  70. def graph_3():
  71. ''' no ring, 1 sibling node '''
  72. gb = model.GraphBuilder()
  73. with gb.graph_scope("main"):
  74. a0 = gb.tensor([10240, 16], "float32", name="a0")
  75. a1 = gb.tensor([10240, 16], "float32", name="a1")
  76. b = gb.emit("Abs", a0, 'b')
  77. c = gb.emit("Abs", a1, 'c')
  78. d = gb.emit("Abs", b, 'd')
  79. e = gb.emit('Add', [c, d], 'e')
  80. gb.emit("Abs", e, 'f')
  81. return gb.get()[0]
  82. def graph_4():
  83. ''' no ring, 2 sibling nodes in 1 step '''
  84. gb = model.GraphBuilder()
  85. with gb.graph_scope("main"):
  86. a0 = gb.tensor([10240, 16], "float32", name="a0")
  87. a1 = gb.tensor([10240, 16], "float32", name="a1")
  88. b = gb.emit("Abs", a0, 'b')
  89. c = gb.emit("Abs", b, 'c')
  90. d = gb.emit("Abs", a1, 'd')
  91. e = gb.emit("Abs", d, 'e')
  92. f = gb.emit('Add', [c, e], 'f')
  93. gb.emit('Abs', f, 'g')
  94. h = gb.emit("Abs", d, 'h')
  95. i = gb.emit('Add', [c, h], 'i')
  96. gb.emit("Abs", i, 'j')
  97. return gb.get()[0]
  98. def graph_5():
  99. ''' no ring, 2 sibling step '''
  100. gb = model.GraphBuilder()
  101. with gb.graph_scope("main") as g:
  102. a0 = gb.tensor([10240, 16], "float32", name="a0")
  103. a1 = gb.tensor([10240, 16], "float32", name="a1")
  104. a2 = gb.tensor([10240, 16], "float32", name="a2")
  105. a = gb.emit("Abs", a0, 'a')
  106. b = gb.emit("Abs", a1, 'b')
  107. c = gb.emit("Abs", b, 'c')
  108. d = gb.emit('Add', [a, c], 'd')
  109. gb.emit("Abs", d, 'e')
  110. f = gb.emit("Abs", a2, 'f')
  111. g = gb.emit('Add', [c, f], 'g')
  112. gb.emit("Abs", g, 'h')
  113. return gb.get()[0]
  114. def graph_6():
  115. ''' no ring, tree down '''
  116. gb = model.GraphBuilder()
  117. with gb.graph_scope("main"):
  118. a0 = gb.tensor([10240, 16], "float32", name="a0")
  119. a = gb.emit("Abs", a0, 'a')
  120. b = gb.emit("Abs", a, 'b')
  121. gb.emit("Abs", b, 'd')
  122. gb.emit("Abs", b, 'e')
  123. c = gb.emit("Abs", a, 'c')
  124. gb.emit("Abs", c, 'f')
  125. gb.emit("Abs", c, 'g')
  126. return gb.get()[0]
  127. def graph_pat_1():
  128. ''' split by reduce '''
  129. gb = model.GraphBuilder()
  130. with gb.graph_scope("main"):
  131. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  132. a = gb.emit("Abs", a0, 'a')
  133. b = gb.emit("Abs", a, 'b')
  134. c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
  135. d = gb.emit("Sqrt", c, 'd')
  136. gb.emit("Sqrt", d, 'f')
  137. return gb.get()[0]
  138. def graph_pat_2():
  139. ''' multi output '''
  140. gb = model.GraphBuilder()
  141. with gb.graph_scope("main"):
  142. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  143. a = gb.emit("Abs", a0, 'a')
  144. b = gb.emit("Abs", a, 'b')
  145. gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
  146. gb.emit("ReduceSum", b, 'e', attrs={'reduce_axis': (1,)})
  147. return gb.get()[0]
  148. def graph_pat_3():
  149. ''' two reduce '''
  150. gb = model.GraphBuilder()
  151. with gb.graph_scope("main"):
  152. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  153. a = gb.emit("Abs", a0, 'a')
  154. b = gb.emit("Abs", a, 'b')
  155. c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
  156. d = gb.emit("Abs", c, 'd')
  157. gb.emit("ReduceSum", d, 'e', attrs={'reduce_axis': (1,)})
  158. return gb.get()[0]
  159. def graph_pat_4():
  160. ''' elewise + broadcast '''
  161. gb = model.GraphBuilder()
  162. with gb.graph_scope("main"):
  163. a0 = gb.tensor([1, 1024], "float32", name="a0")
  164. a2 = gb.tensor([1014, 1024], "float32", name="a2")
  165. a = gb.emit("Abs", a0, 'a')
  166. b = gb.emit("Abs", a, 'b')
  167. c = gb.emit("Abs", b, 'c')
  168. d = gb.emit("Abs", c, 'd')
  169. e = gb.emit("Abs", d, 'e')
  170. f = gb.emit("Abs", e, 'f')
  171. g0 = gb.emit("Abs", a2, 'g0')
  172. # g0 = gb.emit("Abs", g0, 'g0')
  173. # g0 = gb.emit("Abs", g0, 'g0')
  174. # g0 = gb.emit("Abs", g0, 'g0')
  175. # g0 = gb.emit("Abs", g0, 'g0')
  176. # g0 = gb.emit("Abs", g0, 'g0')
  177. # g0 = gb.emit("Abs", g0, 'g0')
  178. g0 = gb.emit("Abs", g0, 'g0')
  179. g1 = gb.emit('Add', [f, g0], 'g1')
  180. g2 = gb.emit("Abs", g1, 'g2')
  181. g3 = gb.emit("Abs", g2, 'g3')
  182. g4 = gb.emit("Abs", g3, 'g4')
  183. gb.emit("Abs", g4, 'g5')
  184. return gb.get()[0]
  185. def graph_pat_5():
  186. ''' reduce + reshape '''
  187. gb = model.GraphBuilder()
  188. with gb.graph_scope("main"):
  189. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  190. a = gb.emit("Abs", a0, 'a')
  191. b = gb.emit("Abs", a, 'b')
  192. c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
  193. d = gb.emit("Abs", c, 'd')
  194. e = gb.tensor([512, 2048], "float32", name="e")
  195. gb.op("Reshape", e, [d])
  196. return gb.get()[0]
  197. def graph_pat_6():
  198. ''' dimond '''
  199. gb = model.GraphBuilder()
  200. with gb.graph_scope("main"):
  201. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  202. a = gb.emit("Abs", a0, 'a')
  203. b = gb.emit("Abs", a, 'b')
  204. c = gb.emit("Abs", a, 'c')
  205. gb.emit("Add", [b, c], 'd')
  206. gb.emit("Abs", c, 'f') # broke dimond
  207. return gb.get()[0]
  208. def graph_pat_7():
  209. ''' buddy of control op '''
  210. gb = model.GraphBuilder()
  211. with gb.graph_scope("main"):
  212. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  213. a1 = gb.tensor([1024, 1024], "float32", name="a1")
  214. a = gb.emit("Abs", a0, 'a')
  215. b = gb.emit("Abs", a1, 'b')
  216. c = gb.emit("MakeTuple", [a, b], 'c')
  217. d = gb.tensor([1024, 1024], "float32", name="d")
  218. gb.op("AddN", d, [c])
  219. gb.emit("Abs", d, 'f')
  220. graph = gb.get()[0]
  221. estimate.AddControlBuddy().visit_graph(graph)
  222. return graph
  223. def graph_pat_8():
  224. ''' reduce + reshape '''
  225. gb = model.GraphBuilder()
  226. with gb.graph_scope("main"):
  227. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  228. a = gb.emit("Abs", a0, 'a')
  229. b = gb.emit("Abs", a, 'b')
  230. #c = gb.emit("Abs", b, 'b')
  231. c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
  232. gb.emit("Add", [b, c], 'd')
  233. return gb.get()[0]
  234. def graph_pat_9():
  235. ''' scalar '''
  236. gb = model.GraphBuilder()
  237. with gb.graph_scope("main"):
  238. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  239. a1 = gb.tensor([1], "float32", name="a1")
  240. a = gb.emit("Maximum", a1, 'a')
  241. b = gb.emit("Mul", [a, a1], 'b')
  242. gb.emit('Mul', [b, a0], 'c')
  243. return gb.get()[0]
  244. def graph_mo_1():
  245. gb = model.GraphBuilder()
  246. with gb.graph_scope("main"):
  247. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  248. a = gb.emit("Abs", a0, 'a')
  249. gb.emit("Abs", a, 'b')
  250. gb.emit("Abs", a, 'c')
  251. return gb.get()[0]
  252. def graph_mo_2():
  253. gb = model.GraphBuilder()
  254. with gb.graph_scope("main") as g:
  255. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  256. a = gb.emit("Abs", a0, 'a')
  257. b = gb.emit("Abs", a, 'b')
  258. c = gb.emit("Abs", b, 'c')
  259. g.set_output(b, c)
  260. return gb.get()[0]
  261. def graph_mo_3():
  262. ''' two reduce '''
  263. gb = model.GraphBuilder()
  264. with gb.graph_scope("main") as g:
  265. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  266. a = gb.emit("Abs", a0, 'a')
  267. b = gb.emit("Abs", a, 'b')
  268. c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
  269. g.set_output(b, c)
  270. return gb.get()[0]
  271. def graph_mo_4():
  272. ''' two reduce '''
  273. gb = model.GraphBuilder()
  274. with gb.graph_scope("main") as g:
  275. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  276. a = gb.emit("Abs", a0, 'a')
  277. b = gb.emit("Abs", a, 'b')
  278. c = gb.emit("ReduceSum", a, 'c', attrs={'reduce_axis': (1,)})
  279. g.set_output(b, c)
  280. return gb.get()[0]
  281. def test_binary_split():
  282. """Test binary split"""
  283. def _test(graph, expected_space_size):
  284. print("********* test on graph : {} *************".format(graph.name))
  285. sp = split.GraphSpliter(graph)
  286. nodes = get_nodes(sp, graph.ops)
  287. space = sp.binary_split(nodes)
  288. for i, s in enumerate(space):
  289. print('{}: {}'.format(i, split_format(sp, s)))
  290. assert len(space) == expected_space_size
  291. assert first_connected(sp, space)
  292. _test(graph_1(), 3)
  293. _test(graph_2(), 7)
  294. _test(graph_3(), 4)
  295. _test(graph_4(), 17)
  296. _test(graph_5(), 11)
  297. _test(graph_6(), 24)
  298. def test_resolve_connnected_graphs():
  299. """Test resolve connected graphs"""
  300. graph = graph_5()
  301. sp = split.GraphSpliter(graph)
  302. n1 = get_nodes(sp, ['a', 'd', 'b', 'c'])
  303. graphs = sp.resolve_connnected_graphs(n1)
  304. print(graphs)
  305. assert len(graphs) == 1
  306. n2 = get_nodes(sp, ['a', 'd', 'e', 'f', 'g'])
  307. graphs = sp.resolve_connnected_graphs(n2)
  308. print(graphs)
  309. assert len(graphs) == 2
  310. n3 = get_nodes(sp, ['a', 'b', 'f'])
  311. graphs = sp.resolve_connnected_graphs(n3)
  312. print(graphs)
  313. assert len(graphs) == 3
  314. def test_split():
  315. """Test split"""
  316. def _print_cost(name, c):
  317. print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s" %
  318. (name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type()))
  319. def _test(graph):
  320. print("********* test on graph : {} *************".format(graph.name))
  321. sp = split.GraphSpliter(graph)
  322. subgraphs = sp.split(False)
  323. print('----- main graph -------')
  324. print(graph)
  325. for i, g in enumerate(subgraphs):
  326. print(' -------- subgraph {} -------'.format(i))
  327. print(g)
  328. print("--------- cost ------------")
  329. cost, _ = model.estimate(graph)
  330. _print_cost("main graph", cost)
  331. fc, sub_costs = model.estimate(subgraphs)
  332. _print_cost("Subgraphs:", fc)
  333. for i, cost in enumerate(sub_costs):
  334. _print_cost(" |_%d:\t" % (i), cost)
  335. _test(graph_5())
  336. # _test(graph_4())
  337. def test_estimate():
  338. """Test estimate"""
  339. graph = graph_5()
  340. e = estimate.Estimator(graph)
  341. e.estimate()
  342. print(e.iter_space)
  343. def test_pattern_split():
  344. """Test pattern split"""
  345. def _test(graph, expect_n=0):
  346. print("************* main graph **************")
  347. print(graph)
  348. subgraphs = split.GraphSplitByPatternV2(graph).split()
  349. for i, g in enumerate(subgraphs):
  350. print(' -------- subgraph {} -------'.format(i))
  351. print(g)
  352. if expect_n > 0:
  353. assert len(subgraphs) == expect_n
  354. # _test(graph_1(), 1)
  355. # _test(graph_pat_1(), 2)
  356. # _test(graph_pat_2())
  357. # _test(graph_pat_3())
  358. # _test(graph_pat_4())
  359. # _test(graph_pat_5())
  360. # _test(graph_pat_6())
  361. # _test(graph_pat_7())
  362. # _test(graph_pat_8())
  363. # _test(graph_pat_9())
  364. # _test(graph_mo_1())
  365. # _test(graph_mo_2())
  366. # _test(graph_mo_3())
  367. _test(graph_mo_4())
  368. def main():
  369. # test_binary_split()
  370. # test_resolve_connnected_graphs()
  371. # test_split()
  372. # test_estimate()
  373. test_pattern_split()
  374. if __name__ == '__main__':
  375. main()