You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

build_module.py 37 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """build module"""
  17. import os
  18. import json
  19. from functools import reduce
  20. import logging
  21. import akg
  22. from akg import tvm
  23. from akg.tvm import _api_internal
  24. from akg.topi.cuda.injective_single_kernel import schedule_injective
  25. import topi
  26. from akg.global_configs import get_dump_ir_flag
  27. def should_enable_atomic_add(kernel_info):
  28. for op in kernel_info["op_desc"]:
  29. if not op["attr"]:
  30. continue
  31. for attr in op["attr"]:
  32. if attr["name"] == "enable_atomic_add" and attr["value"]:
  33. return True
  34. return False
  35. class Graph():
  36. def __init__(self, output):
  37. self.tensors = set(output)
  38. self.ops = []
  39. self.output_name = output
  40. self.input_name = []
  41. self.input = []
  42. self.core_num = 0
  43. self.output = []
  44. self.op_name = 'Fused'
  45. class Liveness():
  46. def __init__(self):
  47. self.start = -1
  48. self.end = -1
  49. self.is_reduce = False
  50. def __str__(self):
  51. return "live_" + str(self.start) + "_" + str(self.end) + "_" + str(self.is_reduce)
  52. def __repr__(self):
  53. return "live_" + str(self.start) + "_" + str(self.end) + "_" + str(self.is_reduce)
def liveness_analysis(desc_d, req_map):
    """Compute the live interval of each requested stitch buffer.

    Sweeps the op list in reverse topological order; for every tensor in
    req_map, 'end' is set at its last use/definition seen and 'start' is
    pushed earlier each time the tensor appears again.

    Args:
        desc_d (dict): compute description of the merged graph.
        req_map (dict): tensor_name -> required size; only keys are used here.

    Returns:
        dict: tensor_name -> Liveness, sorted by Liveness.end descending.
    """
    req_liveness = dict((k, Liveness()) for k in req_map.keys())
    # NOTE(review): idx mirrors the loop variable i exactly (both walk
    # len-1 .. 0); kept as-is to preserve the original behavior.
    idx = len(desc_d['op_desc'])
    for i in range(len(desc_d['op_desc']) - 1, -1, -1):
        idx -= 1
        op_info = desc_d['op_desc'][i]
        for out_desc in op_info['output_desc']:
            out_name = out_desc['tensor_name']
            if out_name in req_liveness:
                # is_reduce is checked against the op name here, despite the
                # helper's parameter being called tensor_name.
                if is_reduce(op_info['name']):
                    req_liveness[out_name].is_reduce = True
                if req_liveness[out_name].end == -1:
                    # First (reverse-order) sighting: interval collapses to idx.
                    req_liveness[out_name].end = idx
                    req_liveness[out_name].start = idx
                else:
                    # Seen before: extend the interval's start earlier.
                    req_liveness[out_name].start = idx
        for input_desc in op_info['input_desc']:
            for sub_input_desc in input_desc:
                inp_name = sub_input_desc['tensor_name']
                # A use with no recorded end yet marks the end of the interval.
                if inp_name in req_liveness and req_liveness[inp_name].end == -1:
                    req_liveness[inp_name].end = idx
                # Any use pushes the start of the interval earlier.
                if inp_name in req_liveness and req_liveness[inp_name].end > -1:
                    req_liveness[inp_name].start = idx
    # sort req_liveness by Liveness.end, descending.
    sort_req_liveness = dict(sorted(req_liveness.items(), key=lambda x: x[1].end, reverse=True))
    return sort_req_liveness
  80. def is_reduce(tensor_name):
  81. return tensor_name.startswith('Reduce')
def shared_memory_optimization(desc_d, req_map, outputs):
    """Plan shared-memory buffers for stitch tensors, reusing dead ones.

    Buffers are visited in order of decreasing liveness end. A buffer may
    reuse an earlier-dying buffer when (rule1) its live range starts after
    the candidate's ends, (rule2) it fits in the candidate and the candidate
    is not a graph output, and (rule3) no other buffer already mapped onto
    the candidate conflicts with its live range. Same-size candidates are
    preferred (rule4).

    Args:
        desc_d (dict): compute description of the merged graph.
        req_map (dict): tensor_name -> required element count.
        outputs (list): tensor names that are graph outputs (never reused).

    Returns:
        tuple(dict, dict): alloc_map (tensor -> ['ALLOC', size]) for fresh
        allocations, and reuse_map (tensor -> [reused_tensor, size]).
    """
    sort_req_liveness = liveness_analysis(desc_d, req_map)
    sort_req_buf = list(sort_req_liveness.keys())
    alloc_map = dict()
    reuse_map = dict()
    # reverse_reuse_map: candidate buffer -> buffers already mapped onto it.
    reverse_reuse_map = dict()
    for i in range(len(sort_req_liveness)):
        reuse = False
        find_conflit = False
        ### TODO: the check is used due to the initialization clause position of reduce computation.
        if sort_req_liveness[sort_req_buf[i]].is_reduce:
            # Reduce outputs always get a fresh allocation.
            alloc_map[sort_req_buf[i]] = ['ALLOC', req_map[sort_req_buf[i]]]
            continue
        for j in range(len(sort_req_liveness) - 1, i, -1):
            # whether reusable.
            # rule1: the buffer's start is >= the reused buffer's end.
            if sort_req_liveness[sort_req_buf[i]].start >= sort_req_liveness[sort_req_buf[j]].end:
                # rule2: sizes are compatible and candidate is not an output.
                if req_map[sort_req_buf[i]] <= req_map[sort_req_buf[j]] and sort_req_buf[j] not in outputs:
                    # rule3: make sure the candidate reused buffer is not being used by another conflicting variable.
                    for item in reverse_reuse_map.get(sort_req_buf[j], []):
                        if (sort_req_liveness[item].end >= sort_req_liveness[sort_req_buf[i]].end) or (sort_req_liveness[item].end >= sort_req_liveness[sort_req_buf[i]].start):
                            find_conflit = True
                            break
                    if not find_conflit:
                        if sort_req_buf[j] not in reverse_reuse_map:
                            reverse_reuse_map[sort_req_buf[j]] = [sort_req_buf[i]]
                        else:
                            reverse_reuse_map[sort_req_buf[j]].append(sort_req_buf[i])
                        # rule4: prefer to reuse a buffer with the same size.
                        if req_map[sort_req_buf[i]] == req_map[sort_req_buf[j]]:
                            reuse_map[sort_req_buf[i]] = [sort_req_buf[j], req_map[sort_req_buf[i]]]
                            reuse = True
                            break
                        else:
                            # Different size: record the reuse but keep scanning
                            # for a same-size candidate (may be overwritten).
                            reuse_map[sort_req_buf[i]] = [sort_req_buf[j], req_map[sort_req_buf[i]]]
                            reuse = True
        if not reuse:
            alloc_map[sort_req_buf[i]] = ['ALLOC', req_map[sort_req_buf[i]]]
    return alloc_map, reuse_map
  122. def is_tensor(op_info):
  123. return 'value' not in op_info
def parse_merged_json(desc_d, stitch_tensor_name, input_tensor_name, output_tensor_name):
    '''
    Parse merged json to get subgraphs split by stitch nodes and the input-output relationship of the merged graph.
    Args:
        desc_d (dict): The dict of compute description.
        stitch_tensor_name (list[string]): The list of stitch node tensors.
            stitch nodes are regarded as edges of sub_graphs. The smallest number of sub_graph is the length of
            stitch_tensor_name + 1.
        input_tensor_name (list[string]): The list of input tensors.
        output_tensor_name (list[string]): The list of output tensors.
            output tensors would be regarded as inter_output_tensor and final_output_tensor. The main difference
            of the two kinds of tensors is whether out-degree is zero, in which final_output_tensor is the tensor
            with zero out-degree in merged graph and otherwise, it is inter_output_tensor.
    Returns:
        extra_subgraph_output (dict): The dict of extra output tensors for each sub_graph.
        final_output_list (list[string]): The list of final output tensors.
            output tensors in this list are final_output_tensor and the subgraph they belong to doesn't
            include stitch nodes.
        final_output_within_graph (list[string]): The list of final output tensors.
            output tensors in this list are final_output_tensor and the subgraph they belong to also includes
            stitch node.
    '''
    # Initialize sub_graph number as the smallest possible number of sub graphs.
    # The number of sub graphs might increase based on graph structure.
    sub_graph_length = len(stitch_tensor_name)
    sub_graph_node = [set() for _ in range(sub_graph_length)]
    # use dict to save extra outputs for each sub_graph.
    extra_subgraph_output = dict(zip(stitch_tensor_name, [[] for _ in range(sub_graph_length)]))
    # in_out_dict: tensor -> list of consumer-op output tensors (reverse edges).
    in_out_dict = {}
    inter_output_list = set()
    final_output_list = set()
    final_output_within_graph = []
    idx = 0
    final_output_graph = False
    # Traverse ops in reverse topological order.
    for i in range(len(desc_d['op_desc']) - 1, -1, -1):
        op_info = desc_d['op_desc'][i]
        for out_desc in op_info['output_desc']:
            # switch to next subgraph if a stitch node is found.
            if out_desc['tensor_name'] in stitch_tensor_name:
                idx += 1
                cur_stitch_node = out_desc['tensor_name']
                # when the current subgraph contains a final output and encounters a stitch node, increase the number of subgraphs.
                if final_output_graph:
                    final_output_list.add(cur_final_node)
                    final_output_within_graph.remove(cur_final_node)
                    sub_graph_length += 1
                    sub_graph_node += [set()]
                    final_output_graph = False
            # out_desc not in in_out_dict means out-degree is zero.
            if out_desc['tensor_name'] not in in_out_dict:
                final_output_graph = True
                cur_final_node = out_desc['tensor_name']
                final_output_within_graph.append(cur_final_node)
            sub_graph_node[idx].add(out_desc['tensor_name'])
            for input_desc in op_info['input_desc']:
                for sub_input_desc in input_desc:
                    sub_graph_node[idx].add(sub_input_desc['tensor_name'])
                    tmp_name = sub_input_desc['tensor_name']
                    if tmp_name in output_tensor_name:
                        inter_output_list.add(sub_input_desc['tensor_name'])
                    # Check earlier subgraphs: a tensor also used by them (or
                    # declared as a graph output) becomes an extra output of the
                    # current stitch subgraph.
                    # NOTE(review): cur_stitch_node / cur_final_node are only
                    # bound once a stitch node / zero-out-degree output has been
                    # seen — presumably guaranteed by graph structure; verify.
                    for subgraph in sub_graph_node[0: idx]:
                        extra_output = is_tensor(sub_input_desc) and tmp_name not in stitch_tensor_name and tmp_name not in input_tensor_name
                        used_by_other_sg = tmp_name in subgraph
                        used_as_output = tmp_name in output_tensor_name
                        extra_output = extra_output and (used_by_other_sg or used_as_output)
                        if extra_output and cur_stitch_node and not final_output_graph:
                            extra_subgraph_output[cur_stitch_node].insert(0, tmp_name)
                            break
                    if sub_input_desc['tensor_name'] not in in_out_dict:
                        in_out_dict[sub_input_desc['tensor_name']] = [out_desc['tensor_name']]
                    else:
                        in_out_dict[sub_input_desc['tensor_name']].append(out_desc['tensor_name'])
    return extra_subgraph_output, list(final_output_list), final_output_within_graph
def collect_subgraph_info(desc_d, sub_stitch_graphs, req_map, input_tensor_name, output_tensor_name, stitch_node_list):
    """Populate each stitch sub-graph with its ops, inputs and outputs.

    Walks the merged op list in reverse topological order, assigning every op
    whose output already belongs to a sub-graph's tensor set, expanding that
    set through the op's inputs (stopping at stitch nodes), and recording
    buffer sizes for stitch tensors in req_map.

    Args:
        desc_d (dict): compute description of the merged graph.
        sub_stitch_graphs (list[Graph]): sub-graphs pre-seeded with their output names.
        req_map (dict): tensor_name -> size; filled in here from output shapes.
        input_tensor_name (list): original graph input tensor names.
        output_tensor_name (list): output tensors not already covered by stitch nodes.
        stitch_node_list (list): all stitch tensors (expansion boundaries).

    Returns:
        tuple: (sub_stitch_graphs, inplace_assign_map, fake_output_list) where
        inplace_assign_map maps an InplaceAssign output to its first input and
        fake_output_list lists outputs flagged with the 'fake_output' attr.
    """
    inplace_assign_map = {}
    fake_output_list = []
    # traverse desc_d in reverse topological order.
    for i in range(len(desc_d['op_desc']) - 1, -1, -1):
        op_info = desc_d['op_desc'][i]
        if (op_info['name'] == "InplaceAssign"):
            inplace_assign_map[op_info['output_desc'][0]['tensor_name']] = op_info['input_desc'][0][0]['tensor_name']
            # NOTE(review): only the first attr entry is inspected here —
            # presumably 'fake_output' is always attr[0] for InplaceAssign.
            if (op_info['attr'][0]['name'] == 'fake_output' and op_info['attr'][0]['value'] == 1):
                fake_output_list.append(op_info['output_desc'][0]['tensor_name'])
        for sg in sub_stitch_graphs:
            # Tracks outputs already recorded for this op/sub-graph pair so
            # each tensor is emitted as an output at most once.
            added_output = []
            for out_desc in op_info['output_desc']:
                out_tensor_name = out_desc['tensor_name']
                if out_tensor_name in sg.tensors:
                    sg.ops.append(op_info)
                    if out_tensor_name in req_map:
                        # Required size is the element count of the output shape.
                        if out_desc['shape']:
                            req_map[out_tensor_name] = reduce(lambda x, y: x * y, out_desc['shape'])
                        else:
                            req_map[out_tensor_name] = 1
                    if out_tensor_name in sg.output_name and out_tensor_name not in added_output:
                        sg.output.append(out_desc)
                        added_output.append(out_tensor_name)
                    for input_desc in op_info['input_desc']:
                        for sub_input_desc in input_desc:
                            if is_tensor(sub_input_desc):
                                input_name = sub_input_desc['tensor_name']
                                if input_name in output_tensor_name and input_name not in added_output:
                                    sg.output.insert(0, sub_input_desc)
                                    added_output.append(input_name)
                                if input_name in input_tensor_name and input_name not in sg.input_name:
                                    sg.input_name.append(sub_input_desc['tensor_name'])
                                    sg.input.append([sub_input_desc])
                                # stop expanding the subgraph when a stitch node is encountered.
                                if input_name not in stitch_node_list:
                                    sg.tensors.add(sub_input_desc['tensor_name'])
                                # add an extra input into the subgraph.
                                elif input_name not in sg.output_name and input_name not in sg.input_name:
                                    sg.input_name.append(input_name)
                                    sg.input.append([sub_input_desc])
    return sub_stitch_graphs, inplace_assign_map, fake_output_list
  239. def sub_graph_info(sub_graph, desc_d):
  240. # gather info for sub graph.
  241. op_json_str = {}
  242. op_json_str['composite'] = True
  243. op_json_str['composite_graph'] = desc_d['composite_graph']
  244. op_json_str['id'] = desc_d['id']
  245. op_json_str['op'] = sub_graph.op_name
  246. op_json_str['input_desc'] = sub_graph.input
  247. op_json_str['op_desc'] = sub_graph.ops
  248. op_json_str['output_desc'] = sub_graph.output
  249. op_json_str['platform'] = "AKG"
  250. op_json_str['process'] = desc_d['process']
  251. if 'sub_block_size' in desc_d['buffer_stitch']:
  252. op_json_str['blocksize'] = desc_d['buffer_stitch']['sub_block_size']
  253. json_str = json.dumps(op_json_str)
  254. return json_str
def stitch_json_split(desc_d):
    """
    Split sub graphs from a merged json file.
    Uses 'buffer_stitch' to read stitch info produced by graph kernel.
    Args:
        desc_d: dict of compute description
    Returns:
        List of split json strings (one per sub-graph).
        List of original input tensor names.
        List of output tensor names (including fake outputs).
        Dict of fresh allocations (tensor -> ['ALLOC', size]).
        Dict of buffer reuses (tensor -> [reused_tensor, size]).
        Dict of clean-op info for fake outputs.
    """
    stitch_jsons = []
    input_tensor_name = [tensor[0]['tensor_name'] for tensor in desc_d['input_desc']]
    output_tensor_name = [tensor['tensor_name'] for tensor in desc_d['output_desc']]
    stitch_node = desc_d['buffer_stitch']['stitch_op']
    stitch_node_name = [node for stitchnode in stitch_node for node in stitchnode]
    extra_subgraph_output, final_output_list, final_output_within_graph = parse_merged_json(desc_d, stitch_node_name, input_tensor_name, output_tensor_name)
    # traverse extra_subgraph_output to save extra outputs into each subgraph.
    stitch_node = []
    extra_list = []
    for item in extra_subgraph_output:
        cur_list = [item]
        for node in extra_subgraph_output[item]:
            # Each extra output is attached to at most one subgraph.
            if node not in extra_list:
                extra_list.append(node)
                cur_list.append(node)
        stitch_node.append(cur_list)
    stitch_node_name = [node for stitchnode in stitch_node for node in stitchnode]
    # initialize req_map (sizes are filled in by collect_subgraph_info).
    req_op_size = [0] * len(stitch_node_name)
    req_map = dict(zip(stitch_node_name, req_op_size))
    # add final outputs within a subgraph into the last initialized stitch sub_graph.
    stitch_node = stitch_node[:-1] + [stitch_node[-1] + final_output_within_graph]
    # add final outputs into stitch_op.
    stitch_node += [[op] for op in final_output_list if op not in stitch_node_name]
    stitch_node_list = [node for stitchnode in stitch_node for node in stitchnode]
    # each output tensor can only be parsed as output once in all subgraphs.
    # All tensors in stitch_node_list will be put into output_name.
    # Save other output tensors which are not in stitch_node_name for the output collection of subgraphs.
    complement_output = [tensor for tensor in output_tensor_name if tensor not in stitch_node_list]
    # initialize sub_stitch_graphs.
    sub_stitch_graphs = []
    for i, stitch_op in enumerate(stitch_node):
        sub_stitch_graphs.append(Graph(stitch_op))
    sub_stitch_graphs, inplace_assign_map, fake_output_list = collect_subgraph_info(desc_d, sub_stitch_graphs, req_map, input_tensor_name, complement_output, stitch_node_list)
    # reverse op order to generate topological subgraphs.
    for i, sg in enumerate(sub_stitch_graphs):
        sg.ops = list(reversed(sg.ops))
        sg.op_name = desc_d['op']
        stitch_json_str = sub_graph_info(sg, desc_d)
        # Optionally dump each subgraph json (and the merged json) for debugging.
        if (os.getenv(get_dump_ir_flag()) == "on"):
            if not os.path.exists("stitch_info"):
                try:
                    os.mkdir("stitch_info")
                except OSError as err:
                    # 17, OSError: [Errno 17] File exists — benign race with
                    # another process creating the directory.
                    if err.errno == 17:
                        pass
                    else:
                        raise err
            with open('stitch_info/' + sg.op_name + '_stitch_' + str(i + 1) + '.json', 'w+') as f:
                f.write(stitch_json_str)
            with open('stitch_info/' + sg.op_name + '_stitch.json', 'w+') as f:
                f.write(json.dumps(desc_d))
        stitch_jsons.append(stitch_json_str)
    clean_op_list = [fake_op for fake_op in fake_output_list if fake_op in stitch_node_name]
    # add fake outputs into output_tensor_name.
    output_tensor_name += clean_op_list
    # start node for the dominance tree is final_output_list + final_output_within_graph.
    # NOTE(review): start_node is computed but never used below — dead code?
    start_node = final_output_list + final_output_within_graph
    alloc_map, reuse_map = shared_memory_optimization(desc_d, req_map, output_tensor_name)
    # remove fake outputs from alloc_map/reuse_map and store them into clean_op_map.
    clean_op_map = dict()
    for fake_op in clean_op_list:
        clean_info = alloc_map[fake_op] if fake_op in alloc_map else reuse_map[fake_op]
        clean_op_map[inplace_assign_map[fake_op]] = clean_info
        alloc_map.pop(fake_op) if fake_op in alloc_map else reuse_map.pop(fake_op)
    # Placeholders so downstream consumers always see non-empty maps.
    if not alloc_map:
        alloc_map['EMPTY'] = []
    if not clean_op_map:
        clean_op_map['EMPTY'] = []
    if not reuse_map:
        reuse_map['EMPTY'] = []
    return stitch_jsons, input_tensor_name, output_tensor_name, alloc_map, reuse_map, clean_op_map
def parallel_json_split(desc_d):
    """
    Split a merged json into single-graph jsons for parallel fusion.
    Args:
        desc_d : dict of compute description
    Returns:
        List of subgraph json strings.
        List of input tensor names.
        List of output tensor names.
    """
    op_jsons = []
    # get some basic info to init subgraphs.
    composite_graph_id = desc_d['composite_graph']
    composite_id = desc_d['id']
    # One sub-graph per entry of 'parallel_fusion'/'sub_graph'.
    final_output_name = desc_d['parallel_fusion']['sub_graph']
    sub_graphs = []
    for i in range(len(final_output_name)):
        sub_graphs.append(Graph(final_output_name[i]))
    # traverse desc_d in reverse topological order to construct subgraphs:
    # an op belongs to a subgraph when one of its outputs is already a
    # known tensor of that subgraph; its inputs then join the tensor set.
    for i in range(len(desc_d['op_desc']) - 1, -1, -1):
        op_info = desc_d['op_desc'][i]
        for g in sub_graphs:
            for j in range(len(op_info['output_desc'])):
                if op_info['output_desc'][j]['tensor_name'] in g.tensors:
                    g.ops.append(op_info)
                    for input_info in op_info['input_desc']:
                        for sub_input_info in input_info:
                            g.tensors.add(sub_input_info['tensor_name'])
    # get subgraph original inputs.
    if desc_d['input_desc']:
        for op_input in desc_d['input_desc']:
            for g in sub_graphs:
                if op_input[0]['tensor_name'] in g.tensors:
                    g.input.append(op_input)
    # get subgraph original outputs.
    for op_output in desc_d['output_desc']:
        for g in sub_graphs:
            if op_output['tensor_name'] in g.tensors:
                g.output.append(op_output)
    # get subgraph core num info (parallel to sub_graphs by index).
    core_num_info = desc_d['parallel_fusion']['core_num']
    for idx in range(len(sub_graphs)):
        g = sub_graphs[idx]
        g.core_num = core_num_info[idx]
    # reverse op order to generate topologically ordered subgraphs.
    for g in sub_graphs:
        g.ops = list(reversed(g.ops))
        g.op_name = desc_d['op']
    # get the original inputs of all subgraphs in order.
    # assumes the original json input_args info satisfies this order.
    input_tensor_names = [tensor[0]['tensor_name'] for tensor in desc_d['input_desc']] if desc_d['input_desc'] else []
    output_tensor_names = [tensor['tensor_name'] for tensor in desc_d['output_desc']] if desc_d['output_desc'] else []
    # construct subgraph json info.
    op_result = []
    for g in sub_graphs:
        op_json_str = {}
        op_json_str['composite'] = True
        op_json_str['composite_graph'] = composite_graph_id
        op_json_str['id'] = composite_id
        op_json_str['op'] = g.op_name
        op_json_str['input_desc'] = g.input
        op_json_str['op_desc'] = g.ops
        op_json_str['output_desc'] = g.output
        op_json_str['core_num'] = g.core_num
        op_json_str['platform'] = "AKG"
        op_json_str['process'] = desc_d['process']
        op_result.append(op_json_str)
    # all sub json info saved in the op_jsons list.
    for idx in range(len(op_result)):
        single_op = op_result[idx]
        json_str = json.dumps(single_op, indent=4)
        op_jsons.append(json_str)
    return op_jsons, input_tensor_names, output_tensor_names
def generate_trait(desc):
    """Generate the (compute, shape, dtype) trait strings of a kernel description.

    The three traits together identify a kernel in the tiling repository:
    'compute' encodes graph topology, 'shape' the input/output shapes and
    'dtype' the input/output data types.

    Args:
        desc (dict): kernel compute description (parsed json).

    Returns:
        tuple(str, str, str): compute, shape and dtype traits.
    """
    def generate_compute_trait():
        # tensor_idx: tensor_name -> definition index (inputs first, then ops).
        tensor_idx = {}
        counter = 0
        traits = []
        if desc['input_desc'] is not None:
            for in_desc in desc['input_desc']:
                tensor_idx[in_desc[0]['tensor_name']] = counter
                counter += 1
            # First trait element is the number of inputs.
            traits = [str(len(desc['input_desc']))]
        for op in desc['op_desc'] if desc['op_desc'] is not None else []:
            input_idx = []
            for input_desc in op['input_desc']:
                # Inline constants (carrying a 'value') are skipped.
                if input_desc[0].get('value', None) is None:
                    # Encode each input as its distance from the current index.
                    input_idx.append(counter - tensor_idx[input_desc[0]['tensor_name']])
            input_idx.sort()
            input_idx_str = ''.join([str(i) for i in input_idx])
            op_trait = op['name'] + input_idx_str
            if op['name'] == "MatMul":
                # NOTE(review): transpose_a / transpose_b stay unbound when the
                # attrs are missing, which would raise UnboundLocalError below —
                # presumably MatMul always carries both; verify.
                for attr in op['attr']:
                    if attr['name'] == "transpose_a":
                        transpose_a = str(int(attr['value']))
                    if attr['name'] == "transpose_b":
                        transpose_b = str(int(attr['value']))
                op_trait += '_' + transpose_a + '_' + transpose_b
            traits.append(op_trait)
            # The op's first output defines the next tensor index.
            tensor_idx[op['output_desc'][0]['tensor_name']] = counter
            counter += 1
        output_idx = []
        for out_desc in desc['output_desc'] if desc['output_desc'] is not None else []:
            output_idx.append(tensor_idx[out_desc['tensor_name']])
        output_idx.sort()
        traits.append(''.join([str(i) for i in output_idx]))
        return '.'.join(traits)
    def append_trait(traits, data):
        # Compress runs of identical entries by appending '-' markers.
        if traits and traits[-1].rstrip('-') == data:
            traits[-1] += '-'
        else:
            traits.append(data)
    def generate_shape_trait():
        traits = []
        for in_desc in desc['input_desc'] if desc['input_desc'] is not None else []:
            shape_s = '_'.join([str(i) for i in in_desc[0]['shape']])
            append_trait(traits, shape_s)
        for out_desc in desc['output_desc'] if desc['output_desc'] is not None else []:
            shape_s = '_'.join([str(i) for i in out_desc['shape']])
            append_trait(traits, shape_s)
        return '.'.join(traits)
    def generate_dtype_trait():
        traits = []
        for in_desc in desc['input_desc'] if desc['input_desc'] is not None else []:
            dtype = in_desc[0]['data_type']
            append_trait(traits, dtype)
        for out_desc in desc['output_desc'] if desc['output_desc'] is not None else []:
            dtype = out_desc['data_type']
            append_trait(traits, dtype)
        return '.'.join(traits)
    compute = generate_compute_trait()
    shape = generate_shape_trait()
    dtype = generate_dtype_trait()
    return compute, shape, dtype
  474. def read_repo_file(repo_file):
  475. with open(repo_file, 'r') as f:
  476. repo = json.loads(f.read())
  477. return repo
  478. def _get_repository_file_path(file):
  479. pwd = os.path.dirname(os.path.abspath(__file__))
  480. path = pwd + "/" + file
  481. if not os.path.exists(path):
  482. path = pwd + "/../config/" + file
  483. if not os.path.exists(path):
  484. raise FileNotFoundError("Can not find {} in directory {} and {}".format(file, pwd, pwd + "/../config"))
  485. return path
  486. def _set_compute_attrs(desc_d_in, attr):
  487. desc_d = desc_d_in
  488. for i, op in enumerate(desc_d.get('op_desc')):
  489. if op.get('name') == "MatMul" and attr.get('bypass') not in (None, ''):
  490. desc_d['op_desc'][i]['attr'].append({'data_type': 'int32', 'name': 'bypass', 'value': attr['bypass']})
  491. desc_s = json.dumps(desc_d)
  492. return desc_d, desc_s
  493. def _pragma_rmselfdep(kernel_info):
  494. for op in kernel_info["op_desc"]:
  495. if op['name'] == "MatMul":
  496. return False
  497. return True
  498. def _enable_auto_inline(kernel_info):
  499. for op in kernel_info["op_desc"]:
  500. # For the MatMul/BatchMatMul with bias, the inline is necessary
  501. if op['name'] in ["MatMul", "BatchMatMul"]:
  502. return True
  503. # For the Ascend, turn 'enable_auto_inline' off for composite op by default.
  504. return False
def _build_to_module(desc_s_in, desc_d_in, attr=None, use_repo=True):
    """
    Build a kernel from a compute description in json format (Ascend path).

    Args:
        desc_s_in : str of compute description
        desc_d_in : dict of compute description
        attr : dict of build attributes
        use_repo : whether to look up tiling attributes in the repository

    Returns:
        Module.
    """
    # Repository location can be overridden via MS_GRAPH_KERNEL_TILING.
    if os.getenv('MS_GRAPH_KERNEL_TILING'):
        repository = read_repo_file(str(os.getenv('MS_GRAPH_KERNEL_TILING')))
    else:
        file_path = _get_repository_file_path("repository.json")
        repository = read_repo_file(file_path)
    def get_repo(keys, default=None):
        # Walk nested dicts by key path; any missing/falsy level yields default.
        repo = repository
        for key in keys:
            repo = repo.get(key)
            if not repo:
                return default
        return repo
    if attr is None:
        attr = {'dim': ''}
    desc_d = desc_d_in
    desc_s = desc_s_in
    attr["pragma_rmselfdep"] = _pragma_rmselfdep(desc_d)
    attr["enable_auto_inline"] = _enable_auto_inline(desc_d)
    if use_repo:
        compute, shape, dtype = generate_trait(desc_d)
        # Prefer the exact (compute, shape, dtype) entry; fall back to
        # a compute-only entry.
        repo_attr = get_repo([compute, shape, dtype, 'metadata', 'attrs'], {})
        if not repo_attr:
            repo_attr = get_repo([compute, 'metadata', 'attrs'], {})
        for a in repo_attr:
            # Caller-provided attrs win over repository defaults.
            if not attr.get(a):
                attr[a] = repo_attr[a]
        if attr.get('dim') in (None, ''):
            tiling = get_repo([compute, shape, dtype, 'dim'])
            if tiling:
                attr['dim'] = tiling
            elif 'online_tuning' in attr:
                # No repo tiling found: fall back to online tuning.
                from akg.auto_tune.composite_tuner import tune_composite
                best_config = tune_composite(desc_s_in,
                                             tune_level=attr["online_tuning"],
                                             repo_path=_get_repository_file_path("repository.json"),
                                             skip_exist=True)
                attr.update(best_config)
    desc_d, desc_s = _set_compute_attrs(desc_d, attr)
    # Fused/stitched graphs go through the json-list build path.
    if 'parallel_fusion' in desc_d or 'buffer_stitch' in desc_d:
        return _build_json_list_to_module(desc_d, attr, True, 'cce')
    func = tvm.get_global_func("composite_with_json")
    return func(desc_s, attr, True)
  557. def _reducemax_pattern(kernel_info):
  558. for op in kernel_info['op_desc']:
  559. if op['name'] == 'ReduceMax':
  560. input_shape = op['input_desc'][0][0]['shape']
  561. batch_size = input_shape[0]
  562. reduce_size = batch_size * input_shape[1] * input_shape[2]
  563. return (True, reduce_size)
  564. return (False, 0)
  565. def _is_batchmatmul(kernel_info):
  566. for op in kernel_info['op_desc']:
  567. if op['name'] == 'BatchMatMul':
  568. return True
  569. return False
  570. def _set_tiling_attrs(out_shape, attrs):
  571. axis_len = len(out_shape)
  572. if axis_len < 3:
  573. return attrs
  574. if all(map(lambda x:x == 1, [out_shape[x] for x in range(axis_len - 2)])):
  575. return attrs
  576. if attrs.get('bind_block') in (None, ''):
  577. i = 0
  578. while out_shape[i] == 1:
  579. i += 1
  580. block_y = out_shape[i]
  581. block_x = out_shape[i + 1] if i < axis_len - 3 else 1
  582. attrs['bind_block'] = str(block_x) + ' ' + str(block_y)
  583. if attrs.get('dim') in (None, ''):
  584. batch_axis = 0
  585. for i in range(axis_len - 2):
  586. if out_shape[i] != 1:
  587. batch_axis += 1
  588. dim_list = [0, 0, 64, 64, 0, 0, 64, 64, 0, 0, 64, 4]
  589. dim_list = [0, 0, 1, 1] * batch_axis + dim_list
  590. i = 0
  591. while i < (len(dim_list) // 4):
  592. dim_list[i * 4 + 1] = i
  593. i += 1
  594. attrs['dim'] = ' '.join(str(x) for x in dim_list)
  595. return attrs
  596. def _set_reducemax_attrs(desc_d, attrs):
  597. if _reducemax_pattern(desc_d)[0]:
  598. attrs['enable_tile_c0'] = True
  599. elem_per_thread = 4
  600. blockdim_x = 64
  601. blockdim_y = 16
  602. griddim_x = 1
  603. griddim_y = _reducemax_pattern(desc_d)[1] / (blockdim_y * elem_per_thread)
  604. attrs['dim'] = ' 0 0 128 64 0 1 128 128'
  605. attrs['bind_block'] = str(griddim_x) + ' ' + str(griddim_y)
  606. attrs['bind_thread'] = str(blockdim_x) + ' ' + str(blockdim_y)
  607. return attrs
def _json_need_split(desc_d, attrs):
    """Split a merged description for parallel fusion and/or buffer stitching.

    Returns, as parallel lists (one entry per resulting block): the block
    json strings, per-block attrs, and per-block alloc/reuse/clean maps,
    plus the overall input and output tensor names. When desc_d has neither
    'parallel_fusion' nor 'buffer_stitch', all lists come back empty.
    """
    block_jsons = []
    input_tensor_name = []
    output_tensor_name = []
    attrs_list = []
    alloc_map_list = []
    reuse_map_list = []
    clean_op_map_list = []
    if 'parallel_fusion' in desc_d:
        block_jsons, input_tensor_name, output_tensor_name = parallel_json_split(desc_d)
        if desc_d["parallel_fusion"]["fusion_type"] == "block_pipeline_fusion":
            attrs["pipeline_groups"] = desc_d["parallel_fusion"]['type_info']
        for i, _ in enumerate(block_jsons):
            # Substring test: block_jsons[i] is a json string here.
            if 'buffer_stitch' in block_jsons[i]:
                # NOTE(review): stitch_json_split expects a dict but receives
                # the json *string* block_jsons[i], and json.loads is applied
                # to the returned *list* stitch_jsons — this branch looks
                # broken as written; verify against upstream before relying on it.
                stitch_jsons, _, _, alloc_map, reuse_map, clean_op_map = stitch_json_split(block_jsons[i])
                block_jsons[i] = stitch_jsons
                cur_attrs = _set_reducemax_attrs(json.loads(stitch_jsons), attrs.copy())
            else:
                alloc_map, reuse_map, clean_op_map = dict(), dict(), dict()
                cur_attrs = attrs.copy()
                cur_attrs["enable_atomic_add"] = should_enable_atomic_add(json.loads(block_jsons[i]))
            attrs_list.append(cur_attrs)
            alloc_map_list.append(alloc_map)
            reuse_map_list.append(reuse_map)
            clean_op_map_list.append(clean_op_map)
    elif 'buffer_stitch' in desc_d:
        stitch_jsons, input_tensor_name, output_tensor_name, alloc_map, reuse_map, clean_op_map = stitch_json_split(desc_d)
        block_jsons.append(stitch_jsons)
        attrs = _set_reducemax_attrs(desc_d, attrs)
        attrs_list.append(attrs)
        alloc_map_list.append(alloc_map)
        reuse_map_list.append(reuse_map)
        clean_op_map_list.append(clean_op_map)
    return block_jsons, input_tensor_name, output_tensor_name, attrs_list, alloc_map_list, reuse_map_list, clean_op_map_list
  642. def _build_json_list_to_module(desc_d, attrs, poly, target):
  643. func = tvm.get_global_func("composite_with_json_list")
  644. block_jsons, input_tensor_name, output_tensor_name, attrs_list, alloc_map_list, reuse_map_list, \
  645. clean_op_map_list = _json_need_split(desc_d, attrs)
  646. return func(block_jsons, input_tensor_name, output_tensor_name, alloc_map_list, reuse_map_list, \
  647. clean_op_map_list, attrs_list, poly, target)
  648. def _build_to_module_gpu(desc_s, desc_d, attrs=None, poly=False):
  649. """
  650. build kernel with compute description in json format
  651. Args:
  652. desc_s : str of compute description
  653. desc_d : dict of compute description
  654. attrs : dict of build attributes
  655. Returns:
  656. Module.
  657. """
  658. if os.getenv('MS_GRAPH_KERNEL_TILING'):
  659. repository_gpu = read_repo_file(str(os.getenv('MS_GRAPH_KERNEL_TILING')))
  660. elif 'buffer_stitch' in desc_d:
  661. repository_gpu = {}
  662. else:
  663. file_path = _get_repository_file_path("repository_gpu.json")
  664. repository_gpu = read_repo_file(file_path)
  665. def get_repo(keys, default=None):
  666. repo = repository_gpu
  667. for key in keys:
  668. repo = repo.get(key)
  669. if not repo:
  670. return default
  671. return repo
  672. if attrs is None:
  673. attrs = {'dim': ''}
  674. compute, shape, dtype = generate_trait(desc_d)
  675. batchmatmul = _is_batchmatmul(desc_d)
  676. if batchmatmul:
  677. shape = "any_shape"
  678. repo_attr = get_repo([compute, shape, dtype, 'metadata', 'attrs'], {})
  679. if repo_attr and batchmatmul:
  680. repo_attr = _set_tiling_attrs(desc_d['output_desc'][0]['shape'], repo_attr)
  681. if not repo_attr:
  682. repo_attr = get_repo([compute, 'metadata', 'attrs'], {})
  683. for a in repo_attr:
  684. if not attrs.get(a):
  685. attrs[a] = repo_attr[a]
  686. attr_list = ['dim', 'bind_block', 'bind_thread']
  687. for item in attr_list:
  688. if attrs.get(item) in (None, ''):
  689. value = get_repo([compute, shape, dtype, item])
  690. if value:
  691. attrs[item] = value
  692. if 'parallel_fusion' in desc_d or 'buffer_stitch' in desc_d:
  693. return _build_json_list_to_module(desc_d, attrs, poly, 'cuda')
  694. func = tvm.get_global_func("composite_with_json")
  695. return func(desc_s, attrs, poly)
  696. def _build(desc_s, desc_d, attrs=None, poly=True, use_repo=True):
  697. if attrs is None:
  698. attrs = dict()
  699. backend = desc_d['process']
  700. if "enable_atomic_add" not in attrs.keys():
  701. attrs["enable_atomic_add"] = should_enable_atomic_add(desc_d)
  702. if not poly:
  703. attrs["enable_atomic_add"] = False
  704. if backend == 'cuda':
  705. if poly:
  706. attrs["enable_akg_reduce_lib"] = True
  707. return _build_to_module_gpu(desc_s, desc_d, attrs, poly)
  708. else:
  709. return _build_to_module(desc_s, desc_d, attrs, use_repo)
  710. def build(kernel_desc, attrs=None, poly=True, use_repo=True):
  711. """
  712. build kernel with compute description in json format
  713. Args:
  714. kernel_desc : str or dict of compute description
  715. attrs : dict of build attributes
  716. Returns:
  717. Module.
  718. """
  719. if isinstance(kernel_desc, str):
  720. desc_s = kernel_desc
  721. desc_d = json.loads(kernel_desc)
  722. else:
  723. assert isinstance(kernel_desc, dict)
  724. desc_s = json.dumps(kernel_desc)
  725. desc_d = kernel_desc
  726. return _build(desc_s, desc_d, attrs, poly, use_repo)
  727. def get_tiling_space(kernel_desc, level=1, attr=None):
  728. """
  729. get tiling space of composite kernel
  730. Args:
  731. kernel_desc : str of compute description
  732. level : info level
  733. attr : dict of build attributes
  734. Returns:
  735. Module.
  736. """
  737. if attr is None:
  738. attr = {}
  739. attr['help_tiling'] = level
  740. attr['tuning'] = 'on'
  741. if 'enable_auto_inline' not in attr:
  742. attr['enable_auto_inline'] = False
  743. attr['pragma_reschedule'] = 1
  744. func = tvm.get_global_func('composite_lower')
  745. ret = func(kernel_desc, attr)
  746. spaces = {}
  747. spaces['index'] = ret.index_table.asnumpy().tolist()
  748. spaces['c1_range'] = ret.c1_tile_range_table.asnumpy().tolist()
  749. spaces['c0_range'] = ret.c0_tile_range_table.asnumpy().tolist()
  750. spaces['c1_mod'] = ret.c1_tile_mod_table.asnumpy().tolist()
  751. spaces['c0_mod'] = ret.c0_tile_mod_table.asnumpy().tolist()
  752. if level >= 2:
  753. spaces['tuning_space'] = ret.tiling_candidate.asnumpy().tolist()
  754. return spaces
  755. @tvm.register_func("akg_build_gpu_module")
  756. def build_cuda(outputs, args, sch_name, kernel_name, attrs = False, poly = False, binds = None):
  757. s = select_cuda_scheduler(outputs, sch_name, poly)
  758. if attrs:
  759. attrs_t = dict(attrs.items())
  760. else:
  761. attrs_t = None
  762. dump_ir = os.getenv(get_dump_ir_flag()) == "on"
  763. with tvm.build_config(dump_pass_ir = dump_ir):
  764. mod = akg.build(s, list(args), "cuda", name = kernel_name, binds = binds, attrs = attrs_t, polyhedral=bool(poly))
  765. return mod
  766. @tvm.register_func("select_cuda_scheduler")
  767. def select_cuda_scheduler(outputs, sch_name, poly = False, grid_dims=0, block_dims=0, buffer_stitch=False):
  768. scheduler = {
  769. "injective" : topi.cuda.injective_single_kernel.schedule_injective,
  770. "reduce" : topi.cuda.reduce_opt.schedule_reduce,
  771. }
  772. with tvm.target.cuda():
  773. if bool(poly):
  774. s = akg.tvm.create_schedule([x.op for x in list(outputs)])
  775. else:
  776. if grid_dims and block_dims and sch_name == "injective":
  777. s = scheduler[sch_name](outputs, grid_dims, block_dims, buffer_stitch=buffer_stitch)
  778. else:
  779. s = scheduler[sch_name](outputs, grid_dims, block_dims)
  780. return s