You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_split.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """Test split"""
  16. import model
  17. from model import model as estimate
  18. from model import graph_split as split
  19. from mindspore.ops import _constants as Constants
  20. def get_nodes(sp, ops):
  21. """Get nodes"""
  22. if isinstance(ops[0], str):
  23. new_ops = []
  24. for t in ops:
  25. for op in sp.graph.ops:
  26. if op.output.name == t:
  27. new_ops.append(op)
  28. break
  29. else:
  30. print("ERROR: not found op: ", t)
  31. ops = new_ops
  32. return [sp.nodes[sp.graph.ops.index(op)] for op in ops]
  33. def first_connected(sp, space):
  34. for cand in space:
  35. nodes = [sp.nodes[i] for i in cand[0]]
  36. graphs = sp.resolve_connnected_graphs(nodes)
  37. if len(graphs) != 1:
  38. print("connect check faied: ", nodes)
  39. return False
  40. return True
  41. def split_format(sp, cand):
  42. names = []
  43. for ids in cand:
  44. ops = []
  45. for i in ids:
  46. ops.append(sp.graph.ops[i].output.name)
  47. names.append(','.join(ops))
  48. return '|'.join(names)
  49. def graph_1():
  50. ''' ring, no succ_dep, no prev '''
  51. gb = model.GraphBuilder()
  52. with gb.graph_scope("main"):
  53. a = gb.tensor([10240, 16], "float32", name="a")
  54. b = gb.emit("Abs", a, 'b')
  55. c = gb.emit("Abs", b, 'c')
  56. d = gb.emit("Abs", c, 'd')
  57. gb.emit('Add', [b, d], 'e')
  58. return gb.get()[0]
  59. def graph_2():
  60. ''' ring, succ_dep, no prev '''
  61. gb = model.GraphBuilder()
  62. with gb.graph_scope("main"):
  63. a0 = gb.tensor([10240, 16], "float32", name="a0")
  64. a = gb.emit("Abs", a0, 'a')
  65. b = gb.emit("Abs", a, 'b')
  66. c = gb.emit("Abs", a, 'c')
  67. d = gb.emit("Abs", b, 'd')
  68. e = gb.emit('Add', [c, d], 'e')
  69. gb.emit("Abs", e, 'f')
  70. return gb.get()[0]
  71. def graph_3():
  72. ''' no ring, 1 sibling node '''
  73. gb = model.GraphBuilder()
  74. with gb.graph_scope("main"):
  75. a0 = gb.tensor([10240, 16], "float32", name="a0")
  76. a1 = gb.tensor([10240, 16], "float32", name="a1")
  77. b = gb.emit("Abs", a0, 'b')
  78. c = gb.emit("Abs", a1, 'c')
  79. d = gb.emit("Abs", b, 'd')
  80. e = gb.emit('Add', [c, d], 'e')
  81. gb.emit("Abs", e, 'f')
  82. return gb.get()[0]
  83. def graph_4():
  84. ''' no ring, 2 sibling nodes in 1 step '''
  85. gb = model.GraphBuilder()
  86. with gb.graph_scope("main"):
  87. a0 = gb.tensor([10240, 16], "float32", name="a0")
  88. a1 = gb.tensor([10240, 16], "float32", name="a1")
  89. b = gb.emit("Abs", a0, 'b')
  90. c = gb.emit("Abs", b, 'c')
  91. d = gb.emit("Abs", a1, 'd')
  92. e = gb.emit("Abs", d, 'e')
  93. f = gb.emit('Add', [c, e], 'f')
  94. gb.emit('Abs', f, 'g')
  95. h = gb.emit("Abs", d, 'h')
  96. i = gb.emit('Add', [c, h], 'i')
  97. gb.emit("Abs", i, 'j')
  98. return gb.get()[0]
  99. def graph_5():
  100. ''' no ring, 2 sibling step '''
  101. gb = model.GraphBuilder()
  102. with gb.graph_scope("main") as g:
  103. a0 = gb.tensor([10240, 16], "float32", name="a0")
  104. a1 = gb.tensor([10240, 16], "float32", name="a1")
  105. a2 = gb.tensor([10240, 16], "float32", name="a2")
  106. a = gb.emit("Abs", a0, 'a')
  107. b = gb.emit("Abs", a1, 'b')
  108. c = gb.emit("Abs", b, 'c')
  109. d = gb.emit('Add', [a, c], 'd')
  110. gb.emit("Abs", d, 'e')
  111. f = gb.emit("Abs", a2, 'f')
  112. g = gb.emit('Add', [c, f], 'g')
  113. gb.emit("Abs", g, 'h')
  114. return gb.get()[0]
  115. def graph_6():
  116. ''' no ring, tree down '''
  117. gb = model.GraphBuilder()
  118. with gb.graph_scope("main"):
  119. a0 = gb.tensor([10240, 16], "float32", name="a0")
  120. a = gb.emit("Abs", a0, 'a')
  121. b = gb.emit("Abs", a, 'b')
  122. gb.emit("Abs", b, 'd')
  123. gb.emit("Abs", b, 'e')
  124. c = gb.emit("Abs", a, 'c')
  125. gb.emit("Abs", c, 'f')
  126. gb.emit("Abs", c, 'g')
  127. return gb.get()[0]
  128. def graph_pat_1():
  129. ''' split by reduce '''
  130. gb = model.GraphBuilder()
  131. with gb.graph_scope("main"):
  132. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  133. a = gb.emit("Abs", a0, 'a')
  134. b = gb.emit("Abs", a, 'b')
  135. c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
  136. d = gb.emit("Sqrt", c, 'd')
  137. gb.emit("Sqrt", d, 'f')
  138. return gb.get()[0]
  139. def graph_pat_2():
  140. ''' multi output '''
  141. gb = model.GraphBuilder()
  142. with gb.graph_scope("main"):
  143. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  144. a = gb.emit("Abs", a0, 'a')
  145. b = gb.emit("Abs", a, 'b')
  146. gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
  147. gb.emit("ReduceSum", b, 'e', attrs={'reduce_axis': (1,)})
  148. return gb.get()[0]
  149. def graph_pat_3():
  150. ''' two reduce '''
  151. gb = model.GraphBuilder()
  152. with gb.graph_scope("main"):
  153. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  154. a = gb.emit("Abs", a0, 'a')
  155. b = gb.emit("Abs", a, 'b')
  156. c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
  157. d = gb.emit("Abs", c, 'd')
  158. gb.emit("ReduceSum", d, 'e', attrs={'reduce_axis': (1,)})
  159. return gb.get()[0]
  160. def graph_pat_4():
  161. ''' elewise + broadcast '''
  162. gb = model.GraphBuilder()
  163. with gb.graph_scope("main"):
  164. a0 = gb.tensor([1, 1024], "float32", name="a0")
  165. a2 = gb.tensor([1014, 1024], "float32", name="a2")
  166. a = gb.emit("Abs", a0, 'a')
  167. b = gb.emit("Abs", a, 'b')
  168. c = gb.emit("Abs", b, 'c')
  169. d = gb.emit("Abs", c, 'd')
  170. e = gb.emit("Abs", d, 'e')
  171. f = gb.emit("Abs", e, 'f')
  172. g0 = gb.emit("Abs", a2, 'g0')
  173. # g0 = gb.emit("Abs", g0, 'g0')
  174. # g0 = gb.emit("Abs", g0, 'g0')
  175. # g0 = gb.emit("Abs", g0, 'g0')
  176. # g0 = gb.emit("Abs", g0, 'g0')
  177. # g0 = gb.emit("Abs", g0, 'g0')
  178. # g0 = gb.emit("Abs", g0, 'g0')
  179. g0 = gb.emit("Abs", g0, 'g0')
  180. g1 = gb.emit('Add', [f, g0], 'g1')
  181. g2 = gb.emit("Abs", g1, 'g2')
  182. g3 = gb.emit("Abs", g2, 'g3')
  183. g4 = gb.emit("Abs", g3, 'g4')
  184. gb.emit("Abs", g4, 'g5')
  185. return gb.get()[0]
  186. def graph_pat_5():
  187. ''' reduce + reshape '''
  188. gb = model.GraphBuilder()
  189. with gb.graph_scope("main"):
  190. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  191. a = gb.emit("Abs", a0, 'a')
  192. b = gb.emit("Abs", a, 'b')
  193. c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
  194. d = gb.emit("Abs", c, 'd')
  195. e = gb.tensor([512, 2048], "float32", name="e")
  196. gb.op("Reshape", e, [d])
  197. return gb.get()[0]
  198. def graph_pat_6():
  199. ''' dimond '''
  200. gb = model.GraphBuilder()
  201. with gb.graph_scope("main"):
  202. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  203. a = gb.emit("Abs", a0, 'a')
  204. b = gb.emit("Abs", a, 'b')
  205. c = gb.emit("Abs", a, 'c')
  206. gb.emit("Add", [b, c], 'd')
  207. gb.emit("Abs", c, 'f') # broke dimond
  208. return gb.get()[0]
  209. def graph_pat_7():
  210. ''' buddy of control op '''
  211. gb = model.GraphBuilder()
  212. with gb.graph_scope("main"):
  213. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  214. a1 = gb.tensor([1024, 1024], "float32", name="a1")
  215. a = gb.emit("Abs", a0, 'a')
  216. b = gb.emit("Abs", a1, 'b')
  217. c = gb.emit(Constants.kMakeTuple, [a, b], 'c')
  218. d = gb.tensor([1024, 1024], "float32", name="d")
  219. gb.op("AddN", d, [c])
  220. gb.emit("Abs", d, 'f')
  221. graph = gb.get()[0]
  222. estimate.AddControlBuddy().visit_graph(graph)
  223. return graph
  224. def graph_pat_8():
  225. ''' reduce + reshape '''
  226. gb = model.GraphBuilder()
  227. with gb.graph_scope("main"):
  228. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  229. a = gb.emit("Abs", a0, 'a')
  230. b = gb.emit("Abs", a, 'b')
  231. #c = gb.emit("Abs", b, 'b')
  232. c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
  233. gb.emit("Add", [b, c], 'd')
  234. return gb.get()[0]
  235. def graph_pat_9():
  236. ''' scalar '''
  237. gb = model.GraphBuilder()
  238. with gb.graph_scope("main"):
  239. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  240. a1 = gb.tensor([1], "float32", name="a1")
  241. a = gb.emit("Maximum", a1, 'a')
  242. b = gb.emit("Mul", [a, a1], 'b')
  243. gb.emit('Mul', [b, a0], 'c')
  244. return gb.get()[0]
  245. def graph_mo_1():
  246. gb = model.GraphBuilder()
  247. with gb.graph_scope("main"):
  248. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  249. a = gb.emit("Abs", a0, 'a')
  250. gb.emit("Abs", a, 'b')
  251. gb.emit("Abs", a, 'c')
  252. return gb.get()[0]
  253. def graph_mo_2():
  254. gb = model.GraphBuilder()
  255. with gb.graph_scope("main") as g:
  256. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  257. a = gb.emit("Abs", a0, 'a')
  258. b = gb.emit("Abs", a, 'b')
  259. c = gb.emit("Abs", b, 'c')
  260. g.set_output(b, c)
  261. return gb.get()[0]
  262. def graph_mo_3():
  263. ''' two reduce '''
  264. gb = model.GraphBuilder()
  265. with gb.graph_scope("main") as g:
  266. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  267. a = gb.emit("Abs", a0, 'a')
  268. b = gb.emit("Abs", a, 'b')
  269. c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
  270. g.set_output(b, c)
  271. return gb.get()[0]
  272. def graph_mo_4():
  273. ''' two reduce '''
  274. gb = model.GraphBuilder()
  275. with gb.graph_scope("main") as g:
  276. a0 = gb.tensor([1024, 1024], "float32", name="a0")
  277. a = gb.emit("Abs", a0, 'a')
  278. b = gb.emit("Abs", a, 'b')
  279. c = gb.emit("ReduceSum", a, 'c', attrs={'reduce_axis': (1,)})
  280. g.set_output(b, c)
  281. return gb.get()[0]
  282. def test_binary_split():
  283. """Test binary split"""
  284. def _test(graph, expected_space_size):
  285. print("********* test on graph : {} *************".format(graph.name))
  286. sp = split.GraphSpliter(graph)
  287. nodes = get_nodes(sp, graph.ops)
  288. space = sp.binary_split(nodes)
  289. for i, s in enumerate(space):
  290. print('{}: {}'.format(i, split_format(sp, s)))
  291. assert len(space) == expected_space_size
  292. assert first_connected(sp, space)
  293. _test(graph_1(), 3)
  294. _test(graph_2(), 7)
  295. _test(graph_3(), 4)
  296. _test(graph_4(), 17)
  297. _test(graph_5(), 11)
  298. _test(graph_6(), 24)
  299. def test_resolve_connnected_graphs():
  300. """Test resolve connected graphs"""
  301. graph = graph_5()
  302. sp = split.GraphSpliter(graph)
  303. n1 = get_nodes(sp, ['a', 'd', 'b', 'c'])
  304. graphs = sp.resolve_connnected_graphs(n1)
  305. print(graphs)
  306. assert len(graphs) == 1
  307. n2 = get_nodes(sp, ['a', 'd', 'e', 'f', 'g'])
  308. graphs = sp.resolve_connnected_graphs(n2)
  309. print(graphs)
  310. assert len(graphs) == 2
  311. n3 = get_nodes(sp, ['a', 'b', 'f'])
  312. graphs = sp.resolve_connnected_graphs(n3)
  313. print(graphs)
  314. assert len(graphs) == 3
  315. def test_split():
  316. """Test split"""
  317. def _print_cost(name, c):
  318. print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s" %
  319. (name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type()))
  320. def _test(graph):
  321. print("********* test on graph : {} *************".format(graph.name))
  322. sp = split.GraphSpliter(graph)
  323. subgraphs = sp.split(False)
  324. print('----- main graph -------')
  325. print(graph)
  326. for i, g in enumerate(subgraphs):
  327. print(' -------- subgraph {} -------'.format(i))
  328. print(g)
  329. print("--------- cost ------------")
  330. cost, _ = model.estimate(graph)
  331. _print_cost("main graph", cost)
  332. fc, sub_costs = model.estimate(subgraphs)
  333. _print_cost("Subgraphs:", fc)
  334. for i, cost in enumerate(sub_costs):
  335. _print_cost(" |_%d:\t" % (i), cost)
  336. _test(graph_5())
  337. # _test(graph_4())
  338. def test_estimate():
  339. """Test estimate"""
  340. graph = graph_5()
  341. e = estimate.Estimator(graph)
  342. e.estimate()
  343. print(e.iter_space)
  344. def test_pattern_split():
  345. """Test pattern split"""
  346. def _test(graph, expect_n=0):
  347. print("************* main graph **************")
  348. print(graph)
  349. subgraphs = split.GraphSplitByPatternV2(graph).split()
  350. for i, g in enumerate(subgraphs):
  351. print(' -------- subgraph {} -------'.format(i))
  352. print(g)
  353. if expect_n > 0:
  354. assert len(subgraphs) == expect_n
  355. # _test(graph_1(), 1)
  356. # _test(graph_pat_1(), 2)
  357. # _test(graph_pat_2())
  358. # _test(graph_pat_3())
  359. # _test(graph_pat_4())
  360. # _test(graph_pat_5())
  361. # _test(graph_pat_6())
  362. # _test(graph_pat_7())
  363. # _test(graph_pat_8())
  364. # _test(graph_pat_9())
  365. # _test(graph_mo_1())
  366. # _test(graph_mo_2())
  367. # _test(graph_mo_3())
  368. _test(graph_mo_4())
  369. def main():
  370. # test_binary_split()
  371. # test_resolve_connnected_graphs()
  372. # test_split()
  373. # test_estimate()
  374. test_pattern_split()
  375. if __name__ == '__main__':
  376. main()