You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

build_module.py 35 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """build module"""
  17. import os
  18. import json
  19. from functools import reduce
  20. import logging
  21. import akg
  22. from akg import tvm
  23. from akg.tvm import _api_internal
  24. from akg.topi.cuda.injective_single_kernel import schedule_injective
  25. import topi
  26. def should_enable_atomic_add(kernel_info):
  27. for op in kernel_info["op_desc"]:
  28. if not op["attr"]:
  29. continue
  30. for attr in op["attr"]:
  31. if attr["name"] == "enable_atomic_add" and attr["value"]:
  32. return True
  33. return False
  34. class Graph():
  35. def __init__(self, output):
  36. self.tensors = set(output)
  37. self.ops = []
  38. self.output_name = output
  39. self.input_name = []
  40. self.input = []
  41. self.core_num = 0
  42. self.output = []
  43. self.op_name = 'Fused'
  44. class Liveness():
  45. def __init__(self):
  46. self.start = -1
  47. self.end = -1
  48. self.is_reduce = False
  49. def liveness_analysis(desc_d, req_map):
  50. req_liveness = dict((k, Liveness()) for k in req_map.keys())
  51. idx = len(desc_d['op_desc'])
  52. for i in range(len(desc_d['op_desc']) - 1, -1, -1):
  53. idx -= 1
  54. op_info = desc_d['op_desc'][i]
  55. for out_desc in op_info['output_desc']:
  56. out_name = out_desc['tensor_name']
  57. if out_name in req_liveness:
  58. if is_reduce(op_info['name']):
  59. req_liveness[out_name].is_reduce = True
  60. if req_liveness[out_name].end == -1:
  61. req_liveness[out_name].end = idx
  62. req_liveness[out_name].start = idx
  63. else:
  64. req_liveness[out_name].start = idx
  65. for input_desc in op_info['input_desc']:
  66. for sub_input_desc in input_desc:
  67. inp_name = sub_input_desc['tensor_name']
  68. if inp_name in req_liveness and req_liveness[inp_name].end == -1:
  69. req_liveness[inp_name].end = idx
  70. if inp_name in req_liveness and req_liveness[inp_name].end > -1:
  71. req_liveness[inp_name].start = idx
  72. # sort req_liveness by Liveness.end.
  73. sort_req_liveness = dict(sorted(req_liveness.items(), key=lambda x: x[1].end, reverse=True))
  74. return sort_req_liveness
  75. def is_reduce(tensor_name):
  76. return tensor_name.startswith('Reduce')
def shared_memory_optimization(desc_d, req_map):
    """Plan shared-memory buffers for stitch tensors, reusing dead buffers.

    Args:
        desc_d (dict): compute description; its 'op_desc' list is analyzed.
        req_map (dict): tensor name -> required buffer size (element count).

    Returns:
        tuple(dict, dict):
            alloc_map: tensor name -> ['ALLOC', size] for freshly allocated buffers.
            reuse_map: tensor name -> [reused tensor name, size] for shared buffers.
    """
    # Tensors ordered by descending death point (Liveness.end).
    sort_req_liveness = liveness_analysis(desc_d, req_map)
    sort_req_buf = list(sort_req_liveness.keys())
    alloc_map = dict()
    reuse_map = dict()
    # candidate buffer name -> tensors already mapped onto it (for conflict checks)
    reverse_reuse_map = dict()
    for i in range(len(sort_req_liveness)):
        reuse = False
        find_conflit = False
        ### TODO: the check is used due to the initialization clause position of reduce computation.
        # Reduce outputs always get a fresh buffer.
        if sort_req_liveness[sort_req_buf[i]].is_reduce:
            alloc_map[sort_req_buf[i]] = ['ALLOC', req_map[sort_req_buf[i]]]
            continue
        # Scan candidates from the earliest-dying tensor backwards.
        for j in range(len(sort_req_liveness) - 1, i, -1):
            # whether reuseable.
            # rule1: one buffer start larger equal to the reused buffer end.
            if sort_req_liveness[sort_req_buf[i]].start >= sort_req_liveness[sort_req_buf[j]].end:
                # rule2: sizes are compatiable.
                if req_map[sort_req_buf[i]] <= req_map[sort_req_buf[j]]:
                    # rule3: make sure the candidate reused buffer is not using by other conflict variable.
                    for item in reverse_reuse_map.get(sort_req_buf[j], []):
                        if (sort_req_liveness[item].end >= sort_req_liveness[sort_req_buf[i]].end) or (sort_req_liveness[item].end >= sort_req_liveness[sort_req_buf[i]].start):
                            find_conflit = True
                            break
                    if not find_conflit:
                        if sort_req_buf[j] not in reverse_reuse_map:
                            reverse_reuse_map[sort_req_buf[j]] = [sort_req_buf[i]]
                        else:
                            reverse_reuse_map[sort_req_buf[j]].append(sort_req_buf[i])
                        # rule4: prefer to reuse buffer with same size.
                        if req_map[sort_req_buf[i]] == req_map[sort_req_buf[j]]:
                            reuse_map[sort_req_buf[i]] = [sort_req_buf[j], req_map[sort_req_buf[i]]]
                            reuse = True
                            break
                        else:
                            # NOTE(review): no break here — the scan keeps looking for a
                            # same-size buffer and may overwrite this entry, while the
                            # tensor stays registered in reverse_reuse_map for the
                            # skipped candidate. Presumably intentional; worth confirming.
                            reuse_map[sort_req_buf[i]] = [sort_req_buf[j], req_map[sort_req_buf[i]]]
                            reuse = True
        if not reuse:
            alloc_map[sort_req_buf[i]] = ['ALLOC', req_map[sort_req_buf[i]]]
    return alloc_map, reuse_map
  117. def is_tensor(op_info):
  118. return 'value' not in op_info
def parse_merged_json(desc_d, stitch_tensor_name, input_tensor_name, output_tensor_name):
    '''
    Parse merged json to get subgraph splitted by stitch nodes and input-output relationship of merged graph.
    Args:
        desc_d (dict): The dict of compute description.
        stitch_tensor_name (list[string]): The list of stitch node tensors.
            stitch nodes are regarded as edges of sub_graphs. The smallest number of sub_graph is the length of
            stitch_tensor_name + 1.
        input_tensor_name (list[string]): The list of input tensors.
        output_tensor_name (list[string]): The list of output tensors.
            output tensors would be regarded as inter_output_tensor and final_output_tensor. The main difference
            of the two kinds of tensors is whether out-degree is zero, in which final_output_tensor is the tensor
            with zero out-degree in merged graph and otherwise, it is inter_output_tensor.
    Returns:
        extra_subgraph_output (dict): The dict of extra output tensors for each sub_graph.
        final_output_list (list[string]): The list of final output tensors.
            output tensors in this list are final_output_tensor and the subgraph they belong to doesn't
            include stitch nodes.
        final_output_within_graph (list[string]): The list of final output tensors.
            output tensors in this list are final_output_tensor and the subgraph they belong to also includes
            stitch node.
    '''
    # Initialize sub_graph number as the smallest possible number of sub graph.
    # sub graphs number might increase based on graph structure.
    sub_graph_length = len(stitch_tensor_name)
    sub_graph_node = [set() for _ in range(sub_graph_length)]
    # use dict to save extra outputs for each sub_graph.
    extra_subgraph_output = dict(zip(stitch_tensor_name, [[] for _ in range(sub_graph_length)]))
    in_out_dict = {}             # tensor name -> consumer output names (reverse edges seen so far)
    inter_output_list = set()    # NOTE(review): collected but never returned — looks vestigial
    final_output_list = set()
    final_output_within_graph = []
    idx = 0                      # index of the sub-graph currently being filled
    final_output_graph = False   # True while the current sub-graph holds a zero-out-degree output
    for i in range(len(desc_d['op_desc']) - 1, -1, -1):
        op_info = desc_d['op_desc'][i]
        for out_desc in op_info['output_desc']:
            # switch to next subgraph if find stitch node.
            if out_desc['tensor_name'] in stitch_tensor_name:
                idx += 1
                # NOTE(review): idx can reach len(sub_graph_node) when every stitch
                # node is seen and no extra sub-graph was appended — confirm
                # sub_graph_node[idx] below cannot go out of range.
                cur_stitch_node = out_desc['tensor_name']
                # when current subgraph concludes final output and encounters with stitch node, increase number of subgraph.
                if final_output_graph:
                    final_output_list.add(cur_final_node)
                    final_output_within_graph.remove(cur_final_node)
                    sub_graph_length += 1
                    sub_graph_node += [set()]
                    final_output_graph = False
            # out_desc not in in_out_dict means out-degree is zero.
            if out_desc['tensor_name'] not in in_out_dict:
                final_output_graph = True
                cur_final_node = out_desc['tensor_name']
                final_output_within_graph.append(cur_final_node)
            sub_graph_node[idx].add(out_desc['tensor_name'])
        for input_desc in op_info['input_desc']:
            for sub_input_desc in input_desc:
                sub_graph_node[idx].add(sub_input_desc['tensor_name'])
                tmp_name = sub_input_desc['tensor_name']
                if tmp_name in output_tensor_name:
                    inter_output_list.add(sub_input_desc['tensor_name'])
                # A tensor already used by an earlier (later-in-topology) sub-graph
                # must be exported as an extra output of the current stitch group.
                for subgraph in sub_graph_node[0: idx]:
                    extra_output = is_tensor(sub_input_desc) and tmp_name not in stitch_tensor_name and tmp_name not in input_tensor_name
                    used_by_other_sg = tmp_name in subgraph
                    used_as_output = tmp_name in output_tensor_name
                    extra_output = extra_output and (used_by_other_sg or used_as_output)
                    if extra_output and cur_stitch_node and not final_output_graph:
                        extra_subgraph_output[cur_stitch_node].insert(0, tmp_name)
                        break
                if sub_input_desc['tensor_name'] not in in_out_dict:
                    in_out_dict[sub_input_desc['tensor_name']] = [out_desc['tensor_name']]
                else:
                    in_out_dict[sub_input_desc['tensor_name']].append(out_desc['tensor_name'])
    return extra_subgraph_output, list(final_output_list), final_output_within_graph
def collect_subgraph_info(desc_d, sub_stitch_graphs, req_map, input_tensor_name, output_tensor_name, stitch_node_list):
    """Expand each stitch sub-graph with its ops, inputs and outputs.

    Args:
        desc_d (dict): compute description whose 'op_desc' is traversed.
        sub_stitch_graphs (list[Graph]): graphs seeded with their output tensors.
        req_map (dict): tensor name -> size; filled in with element counts here.
        input_tensor_name (list[string]): original inputs of the merged graph.
        output_tensor_name (list[string]): outputs not already covered by stitch nodes.
        stitch_node_list (list[string]): all stitch/output tensor names; expansion stops there.

    Returns:
        tuple: (sub_stitch_graphs, inplace_assign_map, fake_output_list) where
        inplace_assign_map maps InplaceAssign outputs to their first input and
        fake_output_list holds outputs flagged with the 'fake_output' attr.
    """
    inplace_assign_map = {}
    fake_output_list = []
    # traversal desc_d by reverse topologically order.
    for i in range(len(desc_d['op_desc']) - 1, -1, -1):
        op_info = desc_d['op_desc'][i]
        if (op_info['name'] == "InplaceAssign"):
            inplace_assign_map[op_info['output_desc'][0]['tensor_name']] = op_info['input_desc'][0][0]['tensor_name']
            if (op_info['attr'][0]['name'] == 'fake_output' and op_info['attr'][0]['value'] == 1):
                fake_output_list.append(op_info['output_desc'][0]['tensor_name'])
        for sg in sub_stitch_graphs:
            # guards against registering the same tensor as output twice for this op
            added_output = []
            for out_desc in op_info['output_desc']:
                out_tensor_name = out_desc['tensor_name']
                if out_tensor_name in sg.tensors:
                    sg.ops.append(op_info)
                    if out_tensor_name in req_map:
                        # buffer size = product of the shape dims (1 for scalars)
                        if out_desc['shape']:
                            req_map[out_tensor_name] = reduce(lambda x, y: x * y, out_desc['shape'])
                        else:
                            req_map[out_tensor_name] = 1
                    if out_tensor_name in sg.output_name and out_tensor_name not in added_output:
                        sg.output.append(out_desc)
                        added_output.append(out_tensor_name)
                    for input_desc in op_info['input_desc']:
                        for sub_input_desc in input_desc:
                            if is_tensor(sub_input_desc):
                                input_name = sub_input_desc['tensor_name']
                                if input_name in output_tensor_name and input_name not in added_output:
                                    sg.output.insert(0, sub_input_desc)
                                    added_output.append(input_name)
                                if input_name in input_tensor_name and input_name not in sg.input_name:
                                    sg.input_name.append(sub_input_desc['tensor_name'])
                                    sg.input.append([sub_input_desc])
                                # stop expand subgraph when encounter with stitch node.
                                if input_name not in stitch_node_list:
                                    sg.tensors.add(sub_input_desc['tensor_name'])
                                # add extra input into subgraph.
                                elif input_name not in sg.output_name and input_name not in sg.input_name:
                                    sg.input_name.append(input_name)
                                    sg.input.append([sub_input_desc])
    return sub_stitch_graphs, inplace_assign_map, fake_output_list
  234. def sub_graph_info(sub_graph, desc_d):
  235. # gather info for sub graph.
  236. op_json_str = {}
  237. op_json_str['composite'] = True
  238. op_json_str['composite_graph'] = desc_d['composite_graph']
  239. op_json_str['id'] = desc_d['id']
  240. op_json_str['op'] = sub_graph.op_name
  241. op_json_str['input_desc'] = sub_graph.input
  242. op_json_str['op_desc'] = sub_graph.ops
  243. op_json_str['output_desc'] = sub_graph.output
  244. op_json_str['platform'] = "AKG"
  245. op_json_str['process'] = desc_d['process']
  246. if 'sub_block_size' in desc_d['buffer_stitch']:
  247. op_json_str['blocksize'] = desc_d['buffer_stitch']['sub_block_size']
  248. json_str = json.dumps(op_json_str)
  249. return json_str
  250. def stitch_json_split(desc_d):
  251. """
  252. split sub graph from merged json file.
  253. Using 'buffer_stitch' to store stitch info from graph kernel.
  254. Args:
  255. desc_d: dict of compute description
  256. Returns:
  257. List of spilted json info.
  258. List of original input.
  259. Dict of dominance info.
  260. """
  261. stitch_jsons = []
  262. input_tensor_name = [tensor[0]['tensor_name'] for tensor in desc_d['input_desc']]
  263. output_tensor_name = [tensor['tensor_name'] for tensor in desc_d['output_desc']]
  264. stitch_node = desc_d['buffer_stitch']['stitch_op']
  265. stitch_node_name = [node for stitchnode in stitch_node for node in stitchnode]
  266. extra_subgraph_output, final_output_list, final_output_within_graph = parse_merged_json(desc_d, stitch_node_name, input_tensor_name, output_tensor_name)
  267. # traverse extra_subgraph_output to save extra output into subgraph.
  268. stitch_node = []
  269. extra_list = []
  270. for item in extra_subgraph_output:
  271. cur_list = [item]
  272. for node in extra_subgraph_output[item]:
  273. if node not in extra_list:
  274. extra_list.append(node)
  275. cur_list.append(node)
  276. stitch_node.append(cur_list)
  277. stitch_node = [[item] + extra_subgraph_output[item] for item in extra_subgraph_output]
  278. stitch_node_name = [node for stitchnode in stitch_node for node in stitchnode]
  279. # initialize req_map
  280. req_op_size = [0] * len(stitch_node_name)
  281. req_map = dict(zip(stitch_node_name, req_op_size))
  282. # add final output within subgraph into the last initialized stitch sub_graph.
  283. stitch_node = stitch_node[:-1] + [stitch_node[-1] + final_output_within_graph]
  284. # add final output into stitch_op.
  285. stitch_node += [[op] for op in final_output_list if op not in stitch_node_name]
  286. stitch_node_list = [node for stitchnode in stitch_node for node in stitchnode]
  287. # each output tensor can only be parsed as output once in all subgraphs.
  288. # All tensors in stitch_node_list will be put into output_name.
  289. # Save other output tensors which are not in stitch_node_name for the output collection of subgraphs.
  290. complement_output = [tensor for tensor in output_tensor_name if tensor not in stitch_node_list]
  291. # initialize sub_stitch_graphs.
  292. sub_stitch_graphs = []
  293. for i, stitch_op in enumerate(stitch_node):
  294. sub_stitch_graphs.append(Graph(stitch_op))
  295. sub_stitch_graphs, inplace_assign_map, fake_output_list = collect_subgraph_info(desc_d, sub_stitch_graphs, req_map, input_tensor_name, complement_output, stitch_node_list)
  296. # reverse op order to generate topological subgraph
  297. for i, sg in enumerate(sub_stitch_graphs):
  298. sg.ops = list(reversed(sg.ops))
  299. sg.op_name = desc_d['op']
  300. stitch_json_str = sub_graph_info(sg, desc_d)
  301. if (os.getenv('MS_AKG_DUMP_IR') == "on"):
  302. if not os.path.exists("stitch_info"):
  303. try:
  304. os.mkdir("stitch_info")
  305. except OSError as err:
  306. # 17, OSError: [Errno 17] File exists
  307. if err.errno == 17:
  308. pass
  309. else:
  310. raise err
  311. with open('stitch_info/' + sg.op_name + '_stitch_' + str(i + 1) + '.json', 'w+') as f:
  312. f.write(stitch_json_str)
  313. stitch_jsons.append(stitch_json_str)
  314. clean_op_list = [fake_op for fake_op in fake_output_list if fake_op in stitch_node_name]
  315. # add fake outputs into output_tensor_name
  316. output_tensor_name += clean_op_list
  317. # start node for dominance tree is final_output_list + final_output_within_graph.
  318. start_node = final_output_list + final_output_within_graph
  319. alloc_map, reuse_map = shared_memory_optimization(desc_d, req_map)
  320. # remove fake output from alloc_map and store them into clean_op_map
  321. clean_op_map = dict()
  322. for fake_op in clean_op_list:
  323. clean_info = alloc_map[fake_op] if fake_op in alloc_map else reuse_map[fake_op]
  324. clean_op_map[inplace_assign_map[fake_op]] = clean_info
  325. alloc_map.pop(fake_op) if fake_op in alloc_map else reuse_map.pop(fake_op)
  326. if not alloc_map:
  327. alloc_map['EMPTY'] = []
  328. if not clean_op_map:
  329. clean_op_map['EMPTY'] = []
  330. if not reuse_map:
  331. reuse_map['EMPTY'] = []
  332. return stitch_jsons, input_tensor_name, output_tensor_name, alloc_map, reuse_map, clean_op_map
  333. def parallel_json_split(desc_d):
  334. """
  335. spilt merge_json to single graph json.
  336. Args:
  337. desc_d : dict of compute desciption
  338. Returns:
  339. List of subgraph json.
  340. List of input names.
  341. Dict of output names.
  342. """
  343. op_jsons = []
  344. # get some basic info to init subgraph
  345. composite_graph_id = desc_d['composite_graph']
  346. composite_id = desc_d['id']
  347. final_output_name = desc_d['parallel_fusion']['sub_graph']
  348. sub_graphs = []
  349. for i in range(len(final_output_name)):
  350. sub_graphs.append(Graph(final_output_name[i]))
  351. # traversal desc_d by reverse topological order to construct subgraph
  352. for i in range(len(desc_d['op_desc']) - 1, -1, -1):
  353. op_info = desc_d['op_desc'][i]
  354. for g in sub_graphs:
  355. for j in range(len(op_info['output_desc'])):
  356. if op_info['output_desc'][j]['tensor_name'] in g.tensors:
  357. g.ops.append(op_info)
  358. for input_info in op_info['input_desc']:
  359. for sub_input_info in input_info:
  360. g.tensors.add(sub_input_info['tensor_name'])
  361. # get subgraph original input
  362. if desc_d['input_desc']:
  363. for op_input in desc_d['input_desc']:
  364. for g in sub_graphs:
  365. if op_input[0]['tensor_name'] in g.tensors:
  366. g.input.append(op_input)
  367. # get subgraph original output
  368. for op_output in desc_d['output_desc']:
  369. for g in sub_graphs:
  370. if op_output['tensor_name'] in g.tensors:
  371. g.output.append(op_output)
  372. # get subgraph core num info
  373. core_num_info = desc_d['parallel_fusion']['core_num']
  374. for idx in range(len(sub_graphs)):
  375. g = sub_graphs[idx]
  376. g.core_num = core_num_info[idx]
  377. # reverse ops order to generate a topology order subgraph
  378. for g in sub_graphs:
  379. g.ops = list(reversed(g.ops))
  380. g.op_name = desc_d['op']
  381. # get the original input of all subgraphs in order
  382. # suppose all original json input_args info satisfies this order
  383. input_tensor_names = [tensor[0]['tensor_name'] for tensor in desc_d['input_desc']] if desc_d['input_desc'] else []
  384. output_tensor_names = [tensor['tensor_name'] for tensor in desc_d['output_desc']] if desc_d['output_desc'] else []
  385. # construct subgraph json info
  386. op_result = []
  387. for g in sub_graphs:
  388. op_json_str = {}
  389. op_json_str['composite'] = True
  390. op_json_str['composite_graph'] = composite_graph_id
  391. op_json_str['id'] = composite_id
  392. op_json_str['op'] = g.op_name
  393. op_json_str['input_desc'] = g.input
  394. op_json_str['op_desc'] = g.ops
  395. op_json_str['output_desc'] = g.output
  396. op_json_str['core_num'] = g.core_num
  397. op_json_str['platform'] = "AKG"
  398. op_json_str['process'] = desc_d['process']
  399. op_result.append(op_json_str)
  400. # all sub json info saved in op_jsons list
  401. for idx in range(len(op_result)):
  402. single_op = op_result[idx]
  403. json_str = json.dumps(single_op, indent=4)
  404. op_jsons.append(json_str)
  405. return op_jsons, input_tensor_names, output_tensor_names
def generate_trait(desc):
    """ generate trait of kernel description """
    def generate_compute_trait():
        # Encode graph topology: input count, then per-op "<name><sorted input distances>",
        # then sorted output tensor indices, all joined with '.'.
        tensor_idx = {}
        counter = 0
        traits = []
        if desc['input_desc'] is not None:
            for in_desc in desc['input_desc']:
                tensor_idx[in_desc[0]['tensor_name']] = counter
                counter += 1
            traits = [str(len(desc['input_desc']))]
        for op in desc['op_desc'] if desc['op_desc'] is not None else []:
            input_idx = []
            for input_desc in op['input_desc']:
                # scalar 'value' entries are constants, not tensors — skip them
                if input_desc[0].get('value', None) is None:
                    input_idx.append(counter - tensor_idx[input_desc[0]['tensor_name']])
            input_idx.sort()
            input_idx_str = ''.join([str(i) for i in input_idx])
            op_trait = op['name'] + input_idx_str
            if op['name'] == "MatMul":
                # NOTE(review): transpose_a / transpose_b stay unbound if a MatMul op
                # lacks these attrs (NameError) — presumably they are always present;
                # TODO confirm against the kernel description schema.
                for attr in op['attr']:
                    if attr['name'] == "transpose_a":
                        transpose_a = str(int(attr['value']))
                    if attr['name'] == "transpose_b":
                        transpose_b = str(int(attr['value']))
                op_trait += '_' + transpose_a + '_' + transpose_b
            traits.append(op_trait)
            tensor_idx[op['output_desc'][0]['tensor_name']] = counter
            counter += 1
        output_idx = []
        for out_desc in desc['output_desc'] if desc['output_desc'] is not None else []:
            output_idx.append(tensor_idx[out_desc['tensor_name']])
        output_idx.sort()
        traits.append(''.join([str(i) for i in output_idx]))
        return '.'.join(traits)
    def append_trait(traits, data):
        # Collapse consecutive identical entries by appending '-' to the last one.
        if traits and traits[-1].rstrip('-') == data:
            traits[-1] += '-'
        else:
            traits.append(data)
    def generate_shape_trait():
        # '.'-joined shape strings of all inputs then outputs, duplicates collapsed.
        traits = []
        for in_desc in desc['input_desc'] if desc['input_desc'] is not None else []:
            shape_s = '_'.join([str(i) for i in in_desc[0]['shape']])
            append_trait(traits, shape_s)
        for out_desc in desc['output_desc'] if desc['output_desc'] is not None else []:
            shape_s = '_'.join([str(i) for i in out_desc['shape']])
            append_trait(traits, shape_s)
        return '.'.join(traits)
    def generate_dtype_trait():
        # '.'-joined dtypes of all inputs then outputs, duplicates collapsed.
        traits = []
        for in_desc in desc['input_desc'] if desc['input_desc'] is not None else []:
            dtype = in_desc[0]['data_type']
            append_trait(traits, dtype)
        for out_desc in desc['output_desc'] if desc['output_desc'] is not None else []:
            dtype = out_desc['data_type']
            append_trait(traits, dtype)
        return '.'.join(traits)
    compute = generate_compute_trait()
    shape = generate_shape_trait()
    dtype = generate_dtype_trait()
    return compute, shape, dtype
  468. def read_repo_file(repo_file):
  469. with open(repo_file, 'r') as f:
  470. repo = json.loads(f.read())
  471. return repo
  472. def _get_repository_file_path(file):
  473. pwd = os.path.dirname(os.path.abspath(__file__))
  474. path = pwd + "/" + file
  475. if not os.path.exists(path):
  476. path = pwd + "/../config/" + file
  477. if not os.path.exists(path):
  478. raise FileNotFoundError("Can not find {} in directory {} and {}".format(file, pwd, pwd + "/../config"))
  479. return path
def _build_to_func(desc_s, desc_d, attr=None, use_repo=True):
    """
    build kernel with compute description in json format
    Args:
        desc_s : str of compute description
        desc_d : dict of compute description
        attr : dict of build attributes
        use_repo : bool, look up tiling/attrs in the repository file
    Returns:
        Module.
    """
    # Repository source: env override wins over the packaged repository.json.
    if os.getenv('MS_GRAPH_KERNEL_TILING'):
        repository = read_repo_file(str(os.getenv('MS_GRAPH_KERNEL_TILING')))
    else:
        file_path = _get_repository_file_path("repository.json")
        repository = read_repo_file(file_path)
    def get_repo(keys, default=None):
        # Walk nested dicts along `keys`; return `default` on any missing/falsy step.
        repo = repository
        for key in keys:
            repo = repo.get(key)
            if not repo:
                return default
        return repo
    if attr is None:
        attr = {'dim': ''}
    # turn 'enable_auto_inline' off for composite op by default.
    if 'enable_auto_inline' not in attr:
        attr['enable_auto_inline'] = False
    if use_repo:
        compute, shape, dtype = generate_trait(desc_d)
        # exact (compute, shape, dtype) match first, compute-only fallback second
        repo_attr = get_repo([compute, shape, dtype, 'metadata', 'attrs'], {})
        if not repo_attr:
            repo_attr = get_repo([compute, 'metadata', 'attrs'], {})
        for a in repo_attr:
            if not attr.get(a):
                attr[a] = repo_attr[a]
        if attr.get('dim') in (None, ''):
            tiling = get_repo([compute, shape, dtype, 'dim'])
            if tiling:
                attr['dim'] = tiling
    # split-and-build path for merged (parallel fusion / buffer stitch) kernels
    if 'parallel_fusion' in desc_d or 'buffer_stitch' in desc_d:
        return _build_json_list_func(desc_d, attr, True, 'cce')
    func = tvm.get_global_func("composite_with_json_to_func")
    return func(desc_s, attr)
  523. def _reducemax_pattern(kernel_info):
  524. for op in kernel_info['op_desc']:
  525. if op['name'] == 'ReduceMax':
  526. input_shape = op['input_desc'][0][0]['shape']
  527. batch_size = input_shape[0]
  528. reduce_size = batch_size * input_shape[1] * input_shape[2]
  529. return (True, reduce_size)
  530. return (False, 0)
  531. def _is_batchmatmul(kernel_info):
  532. for op in kernel_info['op_desc']:
  533. if op['name'] == 'BatchMatMul':
  534. return True
  535. return False
  536. def _set_tiling_attrs(out_shape, attrs):
  537. axis_len = len(out_shape)
  538. if axis_len < 3:
  539. return attrs
  540. if all(map(lambda x:x == 1, [out_shape[x] for x in range(axis_len - 2)])):
  541. return attrs
  542. if attrs.get('bind_block') in (None, ''):
  543. i = 0
  544. while out_shape[i] == 1:
  545. i += 1
  546. block_y = out_shape[i]
  547. block_x = out_shape[i + 1] if i < axis_len - 3 else 1
  548. attrs['bind_block'] = str(block_x) + ' ' + str(block_y)
  549. if attrs.get('dim') in (None, ''):
  550. batch_axis = 0
  551. for i in range(axis_len - 2):
  552. if out_shape[i] != 1:
  553. batch_axis += 1
  554. dim_list = [0, 0, 64, 64, 0, 0, 64, 64, 0, 0, 64, 4]
  555. dim_list = [0, 0, 1, 1] * batch_axis + dim_list
  556. i = 0
  557. while i < (len(dim_list) // 4):
  558. dim_list[i * 4 + 1] = i
  559. i += 1
  560. attrs['dim'] = ' '.join(str(x) for x in dim_list)
  561. return attrs
  562. def _set_reducemax_attrs(desc_d, attrs):
  563. if _reducemax_pattern(desc_d)[0]:
  564. attrs['enable_tile_c0'] = True
  565. elem_per_thread = 4
  566. blockdim_x = 64
  567. blockdim_y = 16
  568. griddim_x = 1
  569. griddim_y = _reducemax_pattern(desc_d)[1] / (blockdim_y * elem_per_thread)
  570. attrs['dim'] = ' 0 0 128 64 0 1 128 128'
  571. attrs['bind_block'] = str(griddim_x) + ' ' + str(griddim_y)
  572. attrs['bind_thread'] = str(blockdim_x) + ' ' + str(blockdim_y)
  573. return attrs
  574. def _json_need_split(desc_d, attrs):
  575. block_jsons = []
  576. input_tensor_name = []
  577. output_tensor_name = []
  578. attrs_list = []
  579. alloc_map_list = []
  580. reuse_map_list = []
  581. clean_op_map_list = []
  582. if 'parallel_fusion' in desc_d:
  583. block_jsons, input_tensor_name, output_tensor_name = parallel_json_split(desc_d)
  584. if desc_d["parallel_fusion"]["fusion_type"] == "block_pipeline_fusion":
  585. attrs["pipeline_groups"] = desc_d["parallel_fusion"]['type_info']
  586. for i, _ in enumerate(block_jsons):
  587. if 'buffer_stitch' in block_jsons[i]:
  588. stitch_jsons, _, _, alloc_map, reuse_map, clean_op_map = stitch_json_split(block_jsons[i])
  589. block_jsons[i] = stitch_jsons
  590. cur_attrs = _set_reducemax_attrs(json.loads(stitch_jsons), attrs.copy())
  591. else:
  592. alloc_map, reuse_map, clean_op_map = dict(), dict(), dict()
  593. cur_attrs = attrs.copy()
  594. cur_attrs["enable_atomic_add"] = should_enable_atomic_add(json.loads(block_jsons[i]))
  595. attrs_list.append(cur_attrs)
  596. alloc_map_list.append(alloc_map)
  597. reuse_map_list.append(reuse_map)
  598. clean_op_map_list.append(clean_op_map)
  599. elif 'buffer_stitch' in desc_d:
  600. stitch_jsons, input_tensor_name, output_tensor_name, alloc_map, reuse_map, clean_op_map = stitch_json_split(desc_d)
  601. block_jsons.append(stitch_jsons)
  602. attrs = _set_reducemax_attrs(desc_d, attrs)
  603. attrs_list.append(attrs)
  604. alloc_map_list.append(alloc_map)
  605. reuse_map_list.append(reuse_map)
  606. clean_op_map_list.append(clean_op_map)
  607. return block_jsons, input_tensor_name, output_tensor_name, attrs_list, alloc_map_list, reuse_map_list, clean_op_map_list
  608. def _build_json_list_func(desc_d, attrs, poly, target):
  609. func = tvm.get_global_func("composite_with_json_list")
  610. block_jsons, input_tensor_name, output_tensor_name, attrs_list, alloc_map_list, reuse_map_list, \
  611. clean_op_map_list = _json_need_split(desc_d, attrs)
  612. return func(block_jsons, input_tensor_name, output_tensor_name, alloc_map_list, reuse_map_list, \
  613. clean_op_map_list, attrs_list, poly, target)
def _build_to_gpu_func(desc_s, desc_d, attrs=None, poly=False):
    """
    build kernel with compute description in json format
    Args:
        desc_s : str of compute description
        desc_d : dict of compute description
        attrs : dict of build attributes
        poly : bool, use the polyhedral scheduling path
    Returns:
        Module.
    """
    # Tiling repository: env override > empty for buffer-stitch kernels > packaged file.
    if os.getenv('MS_GRAPH_KERNEL_TILING'):
        repository_gpu = read_repo_file(str(os.getenv('MS_GRAPH_KERNEL_TILING')))
    elif 'buffer_stitch' in desc_d:
        repository_gpu = {}
    else:
        file_path = _get_repository_file_path("repository_gpu.json")
        repository_gpu = read_repo_file(file_path)
    def get_repo(keys, default=None):
        # Walk nested dicts along `keys`; return `default` on any missing/falsy step.
        repo = repository_gpu
        for key in keys:
            repo = repo.get(key)
            if not repo:
                return default
        return repo
    if attrs is None:
        attrs = {'dim': ''}
    compute, shape, dtype = generate_trait(desc_d)
    batchmatmul = _is_batchmatmul(desc_d)
    # BatchMatMul kernels share one repository entry regardless of shape.
    if batchmatmul:
        shape = "any_shape"
    repo_attr = get_repo([compute, shape, dtype, 'metadata', 'attrs'], {})
    if repo_attr and batchmatmul:
        repo_attr = _set_tiling_attrs(desc_d['output_desc'][0]['shape'], repo_attr)
    if not repo_attr:
        # fall back to the compute-only repository entry
        repo_attr = get_repo([compute, 'metadata', 'attrs'], {})
    for a in repo_attr:
        if not attrs.get(a):
            attrs[a] = repo_attr[a]
    # per-attribute tiling lookups for still-unset attrs
    attr_list = ['dim', 'bind_block', 'bind_thread']
    for item in attr_list:
        if attrs.get(item) in (None, ''):
            value = get_repo([compute, shape, dtype, item])
            if value:
                attrs[item] = value
    # split-and-build path for merged (parallel fusion / buffer stitch) kernels
    if 'parallel_fusion' in desc_d or 'buffer_stitch' in desc_d:
        return _build_json_list_func(desc_d, attrs, poly, 'cuda')
    func = tvm.get_global_func("composite_with_json")
    return func(desc_s, attrs, poly)
  662. def _build(desc_s, desc_d, attrs=None, poly=False, use_repo=True):
  663. if desc_d['process'] == 'cuda':
  664. return _build_to_gpu_func(desc_s, desc_d, attrs, poly)
  665. rst = _build_to_func(desc_s, desc_d, attrs, use_repo)
  666. return _api_internal._BuildToModule(rst)
  667. def build(kernel_desc, attrs=None, poly=False, use_repo=True):
  668. """
  669. build kernel with compute description in json format
  670. Args:
  671. kernel_desc : str or dict of compute description
  672. attrs : dict of build attributes
  673. Returns:
  674. Module.
  675. """
  676. if isinstance(kernel_desc, str):
  677. desc_s = kernel_desc
  678. desc_d = json.loads(kernel_desc)
  679. else:
  680. assert isinstance(kernel_desc, dict)
  681. desc_s = json.dumps(kernel_desc)
  682. desc_d = kernel_desc
  683. return _build(desc_s, desc_d, attrs, poly, use_repo)
  684. def get_tiling_space(kernel_desc, level=1, attr=None):
  685. """
  686. get tiling space of composite kernel
  687. Args:
  688. kernel_desc : str of compute description
  689. level : info level
  690. attr : dict of build attributes
  691. Returns:
  692. Module.
  693. """
  694. if attr is None:
  695. attr = {}
  696. attr['help_tiling'] = level
  697. attr['tuning'] = 'on'
  698. if 'enable_auto_inline' not in attr:
  699. attr['enable_auto_inline'] = False
  700. attr['pragma_reschedule'] = 1
  701. func = tvm.get_global_func('composite_lower')
  702. ret = func(kernel_desc, attr)
  703. spaces = {}
  704. spaces['index'] = ret.index_table.asnumpy().tolist()
  705. spaces['c1_range'] = ret.c1_tile_range_table.asnumpy().tolist()
  706. spaces['c0_range'] = ret.c0_tile_range_table.asnumpy().tolist()
  707. spaces['c1_mod'] = ret.c1_tile_mod_table.asnumpy().tolist()
  708. spaces['c0_mod'] = ret.c0_tile_mod_table.asnumpy().tolist()
  709. if level >= 2:
  710. spaces['tuning_space'] = ret.tiling_candidate.asnumpy().tolist()
  711. return spaces
  712. @tvm.register_func("akg_build_gpu_module")
  713. def build_cuda(outputs, args, sch_name, kernel_name, attrs = False, poly = False, binds = None):
  714. s = select_cuda_scheduler(outputs, sch_name, poly)
  715. if attrs:
  716. attrs_t = dict(attrs.items())
  717. else:
  718. attrs_t = None
  719. dump_ir = os.getenv('MS_AKG_DUMP_IR') == "on"
  720. with tvm.build_config(dump_pass_ir = dump_ir):
  721. mod = akg.build(s, list(args), "cuda", name = kernel_name, binds = binds, attrs = attrs_t, polyhedral=bool(poly))
  722. return mod
  723. @tvm.register_func("select_cuda_scheduler")
  724. def select_cuda_scheduler(outputs, sch_name, poly = False, grid_dims=0, block_dims=0, buffer_stitch=False):
  725. scheduler = {
  726. "injective" : topi.cuda.injective_single_kernel.schedule_injective,
  727. "reduce" : topi.cuda.reduce_opt.schedule_reduce,
  728. }
  729. with tvm.target.cuda():
  730. if bool(poly):
  731. s = akg.tvm.create_schedule([x.op for x in list(outputs)])
  732. else:
  733. if grid_dims and block_dims and sch_name == "injective":
  734. s = scheduler[sch_name](outputs, grid_dims, block_dims, buffer_stitch=buffer_stitch)
  735. else:
  736. s = scheduler[sch_name](outputs, grid_dims, block_dims)
  737. return s

AKG(Auto Kernel Generator)对深度神经网络中的算子进行优化,并提供特定模式下的算子自动融合功能。AKG与MindSpore的图算融合功能协同工作,可提升在不同硬件后端上运行网络的性能。