#!/usr/bin/env python3
# coding: utf-8
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""build module"""
import os
import json
from functools import reduce
import logging
import akg
from akg import tvm
from akg.tvm import _api_internal
from akg.topi.cuda.injective_single_kernel import schedule_injective
import topi
from akg.global_configs import get_dump_ir_flag


def should_enable_atomic_add(kernel_info):
    for op in kernel_info["op_desc"]:
        if not op["attr"]:
            continue
        for attr in op["attr"]:
            if attr["name"] == "enable_atomic_add" and attr["value"]:
                return True
    return False


class Graph():
    def __init__(self, output):
        self.tensors = set(output)
        self.ops = []
        self.output_name = output
        self.input_name = []
        self.input = []
        self.core_num = 0
        self.output = []
        self.op_name = 'Fused'


class Liveness():
    def __init__(self):
        self.start = -1
        self.end = -1
        self.is_reduce = False

    def __str__(self):
        return "live_" + str(self.start) + "_" + str(self.end) + "_" + str(self.is_reduce)

    def __repr__(self):
        return "live_" + str(self.start) + "_" + str(self.end) + "_" + str(self.is_reduce)


def liveness_analysis(desc_d, req_map):
    req_liveness = dict((k, Liveness()) for k in req_map.keys())
    idx = len(desc_d['op_desc'])
    for i in range(len(desc_d['op_desc']) - 1, -1, -1):
        idx -= 1
        op_info = desc_d['op_desc'][i]
        for out_desc in op_info['output_desc']:
            out_name = out_desc['tensor_name']
            if out_name in req_liveness:
                if is_reduce(op_info['name']):
                    req_liveness[out_name].is_reduce = True
                if req_liveness[out_name].end == -1:
                    req_liveness[out_name].end = idx
                    req_liveness[out_name].start = idx
                else:
                    req_liveness[out_name].start = idx
        for input_desc in op_info['input_desc']:
            for sub_input_desc in input_desc:
                inp_name = sub_input_desc['tensor_name']
                if inp_name in req_liveness and req_liveness[inp_name].end == -1:
                    req_liveness[inp_name].end = idx
                if inp_name in req_liveness and req_liveness[inp_name].end > -1:
                    req_liveness[inp_name].start = idx
    # sort req_liveness by Liveness.end.
    sort_req_liveness = dict(sorted(req_liveness.items(), key=lambda x: x[1].end, reverse=True))
    return sort_req_liveness


def is_reduce(tensor_name):
    return tensor_name.startswith('Reduce')


def shared_memory_optimization(desc_d, req_map, outputs):
    sort_req_liveness = liveness_analysis(desc_d, req_map)
    sort_req_buf = list(sort_req_liveness.keys())
    alloc_map = dict()
    reuse_map = dict()
    reverse_reuse_map = dict()
    for i in range(len(sort_req_liveness)):
        reuse = False
        find_conflict = False
        ### TODO: this check exists because of where the initialization statement of a reduce computation is placed.
        if sort_req_liveness[sort_req_buf[i]].is_reduce:
            alloc_map[sort_req_buf[i]] = ['ALLOC', req_map[sort_req_buf[i]]]
            continue
        for j in range(len(sort_req_liveness) - 1, i, -1):
            # whether reusable.
            # rule1: the buffer's live range starts at or after the reused buffer's live range ends.
            if sort_req_liveness[sort_req_buf[i]].start >= sort_req_liveness[sort_req_buf[j]].end:
                # rule2: sizes are compatible.
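                # (sizes come from req_map, so a candidate can host any buffer that is not larger;
                #  buffers that are graph outputs are never reused as scratch space.)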
                if req_map[sort_req_buf[i]] <= req_map[sort_req_buf[j]] and sort_req_buf[j] not in outputs:
                    # rule3: make sure the candidate reused buffer is not used by another conflicting variable.
                    for item in reverse_reuse_map.get(sort_req_buf[j], []):
                        if (sort_req_liveness[item].end >= sort_req_liveness[sort_req_buf[i]].end) or (sort_req_liveness[item].end >= sort_req_liveness[sort_req_buf[i]].start):
                            find_conflict = True
                            break
                    if not find_conflict:
                        if sort_req_buf[j] not in reverse_reuse_map:
                            reverse_reuse_map[sort_req_buf[j]] = [sort_req_buf[i]]
                        else:
                            reverse_reuse_map[sort_req_buf[j]].append(sort_req_buf[i])
                        # rule4: prefer to reuse a buffer of the same size.
                        if req_map[sort_req_buf[i]] == req_map[sort_req_buf[j]]:
                            reuse_map[sort_req_buf[i]] = [sort_req_buf[j], req_map[sort_req_buf[i]]]
                            reuse = True
                            break
                        else:
                            reuse_map[sort_req_buf[i]] = [sort_req_buf[j], req_map[sort_req_buf[i]]]
                            reuse = True
        if not reuse:
            alloc_map[sort_req_buf[i]] = ['ALLOC', req_map[sort_req_buf[i]]]
    return alloc_map, reuse_map


def is_tensor(op_info):
    return 'value' not in op_info


def parse_merged_json(desc_d, stitch_tensor_name, input_tensor_name, output_tensor_name):
    '''
    Parse the merged json to get the sub graphs split by stitch nodes and the
    input-output relationship of the merged graph.

    Args:
        desc_d (dict): The dict of compute description.
        stitch_tensor_name (list[string]): The list of stitch node tensors. Stitch nodes are regarded
            as the edges of sub graphs. The smallest possible number of sub graphs is the length of
            stitch_tensor_name + 1.
        input_tensor_name (list[string]): The list of input tensors.
        output_tensor_name (list[string]): The list of output tensors. An output tensor is regarded
            either as an inter_output_tensor or as a final_output_tensor. The difference between the
            two kinds is the out-degree: a final_output_tensor has zero out-degree in the merged
            graph, otherwise it is an inter_output_tensor.

    Returns:
        extra_subgraph_output (dict): The dict of extra output tensors for each sub graph.
        final_output_list (list[string]): The list of final output tensors whose sub graph does not
            include a stitch node.
        final_output_within_graph (list[string]): The list of final output tensors whose sub graph
            also includes a stitch node.
    '''
    # Initialize the number of sub graphs as the smallest possible number.
    # The number of sub graphs might increase based on the graph structure.
    sub_graph_length = len(stitch_tensor_name)
    sub_graph_node = [set() for _ in range(sub_graph_length)]
    # use a dict to save the extra outputs of each sub graph.
    extra_subgraph_output = dict(zip(stitch_tensor_name, [[] for _ in range(sub_graph_length)]))
    in_out_dict = {}
    inter_output_list = set()
    final_output_list = set()
    final_output_within_graph = []
    idx = 0
    final_output_graph = False
    for i in range(len(desc_d['op_desc']) - 1, -1, -1):
        op_info = desc_d['op_desc'][i]
        for out_desc in op_info['output_desc']:
            # switch to the next subgraph when a stitch node is found.
            if out_desc['tensor_name'] in stitch_tensor_name:
                idx += 1
                cur_stitch_node = out_desc['tensor_name']
                # when the current subgraph already contains a final output and a stitch node is
                # encountered, increase the number of subgraphs.
                if final_output_graph:
                    final_output_list.add(cur_final_node)
                    final_output_within_graph.remove(cur_final_node)
                    sub_graph_length += 1
                    sub_graph_node += [set()]
                    final_output_graph = False
            # out_desc not in in_out_dict means its out-degree is zero.
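            # (in_out_dict maps a tensor to the consumers seen so far in the reverse traversal,
            #  so a missing entry means no consumer has used it, i.e. a final output.)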
            if out_desc['tensor_name'] not in in_out_dict:
                final_output_graph = True
                cur_final_node = out_desc['tensor_name']
                final_output_within_graph.append(cur_final_node)
            sub_graph_node[idx].add(out_desc['tensor_name'])
            for input_desc in op_info['input_desc']:
                for sub_input_desc in input_desc:
                    sub_graph_node[idx].add(sub_input_desc['tensor_name'])
                    tmp_name = sub_input_desc['tensor_name']
                    if tmp_name in output_tensor_name:
                        inter_output_list.add(sub_input_desc['tensor_name'])
                    for subgraph in sub_graph_node[0: idx]:
                        extra_output = is_tensor(sub_input_desc) and tmp_name not in stitch_tensor_name and tmp_name not in input_tensor_name
                        used_by_other_sg = tmp_name in subgraph
                        used_as_output = tmp_name in output_tensor_name
                        extra_output = extra_output and (used_by_other_sg or used_as_output)
                        if extra_output and cur_stitch_node and not final_output_graph:
                            extra_subgraph_output[cur_stitch_node].insert(0, tmp_name)
                            break
                    if sub_input_desc['tensor_name'] not in in_out_dict:
                        in_out_dict[sub_input_desc['tensor_name']] = [out_desc['tensor_name']]
                    else:
                        in_out_dict[sub_input_desc['tensor_name']].append(out_desc['tensor_name'])
    return extra_subgraph_output, list(final_output_list), final_output_within_graph


def collect_subgraph_info(desc_d, sub_stitch_graphs, req_map, input_tensor_name,
                          output_tensor_name, stitch_node_list):
    inplace_assign_map = {}
    fake_output_list = []
    # traverse desc_d in reverse topological order.
    for i in range(len(desc_d['op_desc']) - 1, -1, -1):
        op_info = desc_d['op_desc'][i]
        if op_info['name'] == "InplaceAssign":
            inplace_assign_map[op_info['output_desc'][0]['tensor_name']] = op_info['input_desc'][0][0]['tensor_name']
            if op_info['attr'][0]['name'] == 'fake_output' and op_info['attr'][0]['value'] == 1:
                fake_output_list.append(op_info['output_desc'][0]['tensor_name'])
        for sg in sub_stitch_graphs:
            added_output = []
            for out_desc in op_info['output_desc']:
                out_tensor_name = out_desc['tensor_name']
                if out_tensor_name in sg.tensors:
                    sg.ops.append(op_info)
                    if out_tensor_name in req_map:
                        if out_desc['shape']:
                            req_map[out_tensor_name] = reduce(lambda x, y: x * y, out_desc['shape'])
                        else:
                            req_map[out_tensor_name] = 1
                    if out_tensor_name in sg.output_name and out_tensor_name not in added_output:
                        sg.output.append(out_desc)
                        added_output.append(out_tensor_name)
                    for input_desc in op_info['input_desc']:
                        for sub_input_desc in input_desc:
                            if is_tensor(sub_input_desc):
                                input_name = sub_input_desc['tensor_name']
                                if input_name in output_tensor_name and input_name not in added_output:
                                    sg.output.insert(0, sub_input_desc)
                                    added_output.append(input_name)
                                if input_name in input_tensor_name and input_name not in sg.input_name:
                                    sg.input_name.append(sub_input_desc['tensor_name'])
                                    sg.input.append([sub_input_desc])
                                # stop expanding the subgraph when a stitch node is encountered.
                                if input_name not in stitch_node_list:
                                    sg.tensors.add(sub_input_desc['tensor_name'])
                                # add an extra input into the subgraph.
                                elif input_name not in sg.output_name and input_name not in sg.input_name:
                                    sg.input_name.append(input_name)
                                    sg.input.append([sub_input_desc])
    return sub_stitch_graphs, inplace_assign_map, fake_output_list


def sub_graph_info(sub_graph, desc_d):
    # gather info for the sub graph.
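    # The sub graph is serialized with the same top-level layout as the merged json:
    # graph ids, op name, tensor descriptors and the target process.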
    op_json_str = {}
    op_json_str['composite'] = True
    op_json_str['composite_graph'] = desc_d['composite_graph']
    op_json_str['id'] = desc_d['id']
    op_json_str['op'] = sub_graph.op_name
    op_json_str['input_desc'] = sub_graph.input
    op_json_str['op_desc'] = sub_graph.ops
    op_json_str['output_desc'] = sub_graph.output
    op_json_str['platform'] = "AKG"
    op_json_str['process'] = desc_d['process']
    if 'sub_block_size' in desc_d['buffer_stitch']:
        op_json_str['blocksize'] = desc_d['buffer_stitch']['sub_block_size']
    json_str = json.dumps(op_json_str)
    return json_str


def stitch_json_split(desc_d):
    """
    Split sub graphs from the merged json file.

    'buffer_stitch' is used to store the stitch info from graph kernel.

    Args:
        desc_d: dict of compute description

    Returns:
        List of split json strings, one per sub graph.
        List of original input tensor names.
        List of output tensor names.
        Dicts of shared-memory allocation, reuse and clean-op info.
    """
    stitch_jsons = []
    input_tensor_name = [tensor[0]['tensor_name'] for tensor in desc_d['input_desc']]
    output_tensor_name = [tensor['tensor_name'] for tensor in desc_d['output_desc']]
    stitch_node = desc_d['buffer_stitch']['stitch_op']
    stitch_node_name = [node for stitchnode in stitch_node for node in stitchnode]
    extra_subgraph_output, final_output_list, final_output_within_graph = parse_merged_json(
        desc_d, stitch_node_name, input_tensor_name, output_tensor_name)

    # traverse extra_subgraph_output to save the extra outputs into the subgraphs.
    stitch_node = []
    extra_list = []
    for item in extra_subgraph_output:
        cur_list = [item]
        for node in extra_subgraph_output[item]:
            if node not in extra_list:
                extra_list.append(node)
                cur_list.append(node)
        stitch_node.append(cur_list)
    stitch_node_name = [node for stitchnode in stitch_node for node in stitchnode]

    # initialize req_map
    req_op_size = [0] * len(stitch_node_name)
    req_map = dict(zip(stitch_node_name, req_op_size))
    # add the final outputs within a subgraph into the last initialized stitch sub graph.
    stitch_node = stitch_node[:-1] + [stitch_node[-1] + final_output_within_graph]
    # add the final outputs into stitch_op.
    stitch_node += [[op] for op in final_output_list if op not in stitch_node_name]
    stitch_node_list = [node for stitchnode in stitch_node for node in stitchnode]

    # Each output tensor can only be parsed as an output once in all subgraphs.
    # All tensors in stitch_node_list will be put into output_name.
    # Save the other output tensors which are not in stitch_node_name for the output collection of subgraphs.
    complement_output = [tensor for tensor in output_tensor_name if tensor not in stitch_node_list]
    # initialize sub_stitch_graphs.
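    # (one Graph per stitch group; its output names seed the backward traversal in collect_subgraph_info.)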
    sub_stitch_graphs = []
    for i, stitch_op in enumerate(stitch_node):
        sub_stitch_graphs.append(Graph(stitch_op))

    sub_stitch_graphs, inplace_assign_map, fake_output_list = collect_subgraph_info(
        desc_d, sub_stitch_graphs, req_map, input_tensor_name, complement_output, stitch_node_list)

    # reverse the op order to generate topologically ordered subgraphs
    for i, sg in enumerate(sub_stitch_graphs):
        sg.ops = list(reversed(sg.ops))
        sg.op_name = desc_d['op']
        stitch_json_str = sub_graph_info(sg, desc_d)

        if os.getenv(get_dump_ir_flag()) == "on":
            if not os.path.exists("stitch_info"):
                try:
                    os.mkdir("stitch_info")
                except OSError as err:
                    # 17, OSError: [Errno 17] File exists
                    if err.errno == 17:
                        pass
                    else:
                        raise err
            with open('stitch_info/' + sg.op_name + '_stitch_' + str(i + 1) + '.json', 'w+') as f:
                f.write(stitch_json_str)
            with open('stitch_info/' + sg.op_name + '_stitch.json', 'w+') as f:
                f.write(json.dumps(desc_d))
        stitch_jsons.append(stitch_json_str)

    clean_op_list = [fake_op for fake_op in fake_output_list if fake_op in stitch_node_name]
    # add the fake outputs into output_tensor_name
    output_tensor_name += clean_op_list
    # the start nodes for the dominance tree are final_output_list + final_output_within_graph.
    start_node = final_output_list + final_output_within_graph
    alloc_map, reuse_map = shared_memory_optimization(desc_d, req_map, output_tensor_name)

    # remove the fake outputs from alloc_map / reuse_map and store them into clean_op_map
    clean_op_map = dict()
    for fake_op in clean_op_list:
        clean_info = alloc_map[fake_op] if fake_op in alloc_map else reuse_map[fake_op]
        clean_op_map[inplace_assign_map[fake_op]] = clean_info
        if fake_op in alloc_map:
            alloc_map.pop(fake_op)
        else:
            reuse_map.pop(fake_op)
    if not alloc_map:
        alloc_map['EMPTY'] = []
    if not clean_op_map:
        clean_op_map['EMPTY'] = []
    if not reuse_map:
        reuse_map['EMPTY'] = []
    return stitch_jsons, input_tensor_name, output_tensor_name, alloc_map, reuse_map, clean_op_map


def parallel_json_split(desc_d):
    """
    Split the merged json into single-graph jsons.

    Args:
        desc_d : dict of compute description

    Returns:
        List of subgraph json strings.
        List of input tensor names.
        List of output tensor names.
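
    Note:
        Each subgraph is seeded from an output list in desc_d['parallel_fusion']['sub_graph'] and
        carries its own core number taken from desc_d['parallel_fusion']['core_num'].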
""" op_jsons = [] # get some basic info to init subgraph composite_graph_id = desc_d['composite_graph'] composite_id = desc_d['id'] final_output_name = desc_d['parallel_fusion']['sub_graph'] sub_graphs = [] for i in range(len(final_output_name)): sub_graphs.append(Graph(final_output_name[i])) # traversal desc_d by reverse topological order to construct subgraph for i in range(len(desc_d['op_desc']) - 1, -1, -1): op_info = desc_d['op_desc'][i] for g in sub_graphs: for j in range(len(op_info['output_desc'])): if op_info['output_desc'][j]['tensor_name'] in g.tensors: g.ops.append(op_info) for input_info in op_info['input_desc']: for sub_input_info in input_info: g.tensors.add(sub_input_info['tensor_name']) # get subgraph original input if desc_d['input_desc']: for op_input in desc_d['input_desc']: for g in sub_graphs: if op_input[0]['tensor_name'] in g.tensors: g.input.append(op_input) # get subgraph original output for op_output in desc_d['output_desc']: for g in sub_graphs: if op_output['tensor_name'] in g.tensors: g.output.append(op_output) # get subgraph core num info core_num_info = desc_d['parallel_fusion']['core_num'] for idx in range(len(sub_graphs)): g = sub_graphs[idx] g.core_num = core_num_info[idx] # reverse ops order to generate a topology order subgraph for g in sub_graphs: g.ops = list(reversed(g.ops)) g.op_name = desc_d['op'] # get the original input of all subgraphs in order # suppose all original json input_args info satisfies this order input_tensor_names = [tensor[0]['tensor_name'] for tensor in desc_d['input_desc']] if desc_d['input_desc'] else [] output_tensor_names = [tensor['tensor_name'] for tensor in desc_d['output_desc']] if desc_d['output_desc'] else [] # construct subgraph json info op_result = [] for g in sub_graphs: op_json_str = {} op_json_str['composite'] = True op_json_str['composite_graph'] = composite_graph_id op_json_str['id'] = composite_id op_json_str['op'] = g.op_name op_json_str['input_desc'] = g.input op_json_str['op_desc'] = g.ops op_json_str['output_desc'] = g.output op_json_str['core_num'] = g.core_num op_json_str['platform'] = "AKG" op_json_str['process'] = desc_d['process'] op_result.append(op_json_str) # all sub json info saved in op_jsons list for idx in range(len(op_result)): single_op = op_result[idx] json_str = json.dumps(single_op, indent=4) op_jsons.append(json_str) return op_jsons, input_tensor_names, output_tensor_names def generate_trait(desc): """ generate trait of kernel description """ def generate_compute_trait(): tensor_idx = {} counter = 0 traits = [] if desc['input_desc'] is not None: for in_desc in desc['input_desc']: tensor_idx[in_desc[0]['tensor_name']] = counter counter += 1 traits = [str(len(desc['input_desc']))] for op in desc['op_desc'] if desc['op_desc'] is not None else []: input_idx = [] for input_desc in op['input_desc']: if input_desc[0].get('value', None) is None: input_idx.append(counter - tensor_idx[input_desc[0]['tensor_name']]) input_idx.sort() input_idx_str = ''.join([str(i) for i in input_idx]) op_trait = op['name'] + input_idx_str if op['name'] == "MatMul": for attr in op['attr']: if attr['name'] == "transpose_a": transpose_a = str(int(attr['value'])) if attr['name'] == "transpose_b": transpose_b = str(int(attr['value'])) op_trait += '_' + transpose_a + '_' + transpose_b traits.append(op_trait) tensor_idx[op['output_desc'][0]['tensor_name']] = counter counter += 1 output_idx = [] for out_desc in desc['output_desc'] if desc['output_desc'] is not None else []: 
            output_idx.append(tensor_idx[out_desc['tensor_name']])
        output_idx.sort()
        traits.append(''.join([str(i) for i in output_idx]))
        return '.'.join(traits)

    def append_trait(traits, data):
        if traits and traits[-1].rstrip('-') == data:
            traits[-1] += '-'
        else:
            traits.append(data)

    def generate_shape_trait():
        traits = []
        for in_desc in desc['input_desc'] if desc['input_desc'] is not None else []:
            shape_s = '_'.join([str(i) for i in in_desc[0]['shape']])
            append_trait(traits, shape_s)
        for out_desc in desc['output_desc'] if desc['output_desc'] is not None else []:
            shape_s = '_'.join([str(i) for i in out_desc['shape']])
            append_trait(traits, shape_s)
        return '.'.join(traits)

    def generate_dtype_trait():
        traits = []
        for in_desc in desc['input_desc'] if desc['input_desc'] is not None else []:
            dtype = in_desc[0]['data_type']
            append_trait(traits, dtype)
        for out_desc in desc['output_desc'] if desc['output_desc'] is not None else []:
            dtype = out_desc['data_type']
            append_trait(traits, dtype)
        return '.'.join(traits)

    compute = generate_compute_trait()
    shape = generate_shape_trait()
    dtype = generate_dtype_trait()
    return compute, shape, dtype


def read_repo_file(repo_file):
    with open(repo_file, 'r') as f:
        repo = json.loads(f.read())
    return repo


def _get_repository_file_path(file):
    pwd = os.path.dirname(os.path.abspath(__file__))
    path = pwd + "/" + file
    if not os.path.exists(path):
        path = pwd + "/../config/" + file
        if not os.path.exists(path):
            raise FileNotFoundError("Can not find {} in directory {} and {}".format(file, pwd, pwd + "/../config"))
    return path


def _set_compute_attrs(desc_d_in, attr):
    desc_d = desc_d_in
    for i, op in enumerate(desc_d.get('op_desc')):
        if op.get('name') == "MatMul" and attr.get('bypass') not in (None, ''):
            desc_d['op_desc'][i]['attr'].append({'data_type': 'int32', 'name': 'bypass', 'value': attr['bypass']})
    desc_s = json.dumps(desc_d)
    return desc_d, desc_s


def _pragma_rmselfdep(kernel_info):
    for op in kernel_info["op_desc"]:
        if op['name'] == "MatMul":
            return False
    return True


def _enable_auto_inline(kernel_info):
    for op in kernel_info["op_desc"]:
        # For the MatMul/BatchMatMul with bias, the inline is necessary
        if op['name'] in ["MatMul", "BatchMatMul"]:
            return True
    # For the Ascend, turn 'enable_auto_inline' off for composite op by default.
    return False


def _build_to_module(desc_s_in, desc_d_in, attr=None, use_repo=True):
    """
    build kernel with compute description in json format

    Args:
        desc_s_in : str of compute description
        desc_d_in : dict of compute description
        attr : dict of build attributes

    Returns:
        Module.
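
    Note:
        When use_repo is True, missing build attributes are filled from the tiling repository
        (MS_GRAPH_KERNEL_TILING if set, otherwise repository.json), looked up by the
        (compute, shape, dtype) traits of the kernel.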
""" if os.getenv('MS_GRAPH_KERNEL_TILING'): repository = read_repo_file(str(os.getenv('MS_GRAPH_KERNEL_TILING'))) else: file_path = _get_repository_file_path("repository.json") repository = read_repo_file(file_path) def get_repo(keys, default=None): repo = repository for key in keys: repo = repo.get(key) if not repo: return default return repo if attr is None: attr = {'dim': ''} desc_d = desc_d_in desc_s = desc_s_in attr["pragma_rmselfdep"] = _pragma_rmselfdep(desc_d) attr["enable_auto_inline"] = _enable_auto_inline(desc_d) if use_repo: compute, shape, dtype = generate_trait(desc_d) repo_attr = get_repo([compute, shape, dtype, 'metadata', 'attrs'], {}) if not repo_attr: repo_attr = get_repo([compute, 'metadata', 'attrs'], {}) for a in repo_attr: if not attr.get(a): attr[a] = repo_attr[a] if attr.get('dim') in (None, ''): tiling = get_repo([compute, shape, dtype, 'dim']) if tiling: attr['dim'] = tiling elif 'online_tuning' in attr: from akg.auto_tune.composite_tuner import tune_composite best_config = tune_composite(desc_s_in, tune_level=attr["online_tuning"], repo_path=_get_repository_file_path("repository.json"), skip_exist=True) attr.update(best_config) desc_d, desc_s = _set_compute_attrs(desc_d, attr) if 'parallel_fusion' in desc_d or 'buffer_stitch' in desc_d: return _build_json_list_to_module(desc_d, attr, True, 'cce') func = tvm.get_global_func("composite_with_json") return func(desc_s, attr, True) def _reducemax_pattern(kernel_info): for op in kernel_info['op_desc']: if op['name'] == 'ReduceMax': input_shape = op['input_desc'][0][0]['shape'] batch_size = input_shape[0] reduce_size = batch_size * input_shape[1] * input_shape[2] return (True, reduce_size) return (False, 0) def _is_batchmatmul(kernel_info): for op in kernel_info['op_desc']: if op['name'] == 'BatchMatMul': return True return False def _set_tiling_attrs(out_shape, attrs): axis_len = len(out_shape) if axis_len < 3: return attrs if all(map(lambda x:x == 1, [out_shape[x] for x in range(axis_len - 2)])): return attrs if attrs.get('bind_block') in (None, ''): i = 0 while out_shape[i] == 1: i += 1 block_y = out_shape[i] block_x = out_shape[i + 1] if i < axis_len - 3 else 1 attrs['bind_block'] = str(block_x) + ' ' + str(block_y) if attrs.get('dim') in (None, ''): batch_axis = 0 for i in range(axis_len - 2): if out_shape[i] != 1: batch_axis += 1 dim_list = [0, 0, 64, 64, 0, 0, 64, 64, 0, 0, 64, 4] dim_list = [0, 0, 1, 1] * batch_axis + dim_list i = 0 while i < (len(dim_list) // 4): dim_list[i * 4 + 1] = i i += 1 attrs['dim'] = ' '.join(str(x) for x in dim_list) return attrs def _set_reducemax_attrs(desc_d, attrs): if _reducemax_pattern(desc_d)[0]: attrs['enable_tile_c0'] = True elem_per_thread = 4 blockdim_x = 64 blockdim_y = 16 griddim_x = 1 griddim_y = _reducemax_pattern(desc_d)[1] / (blockdim_y * elem_per_thread) attrs['dim'] = ' 0 0 128 64 0 1 128 128' attrs['bind_block'] = str(griddim_x) + ' ' + str(griddim_y) attrs['bind_thread'] = str(blockdim_x) + ' ' + str(blockdim_y) return attrs def _json_need_split(desc_d, attrs): block_jsons = [] input_tensor_name = [] output_tensor_name = [] attrs_list = [] alloc_map_list = [] reuse_map_list = [] clean_op_map_list = [] if 'parallel_fusion' in desc_d: block_jsons, input_tensor_name, output_tensor_name = parallel_json_split(desc_d) if desc_d["parallel_fusion"]["fusion_type"] == "block_pipeline_fusion": attrs["pipeline_groups"] = desc_d["parallel_fusion"]['type_info'] for i, _ in enumerate(block_jsons): if 'buffer_stitch' in block_jsons[i]: stitch_jsons, _, _, alloc_map, reuse_map, 
                    clean_op_map = stitch_json_split(block_desc)
                block_jsons[i] = stitch_jsons
                cur_attrs = _set_reducemax_attrs(block_desc, attrs.copy())
            else:
                alloc_map, reuse_map, clean_op_map = dict(), dict(), dict()
                cur_attrs = attrs.copy()
                cur_attrs["enable_atomic_add"] = should_enable_atomic_add(json.loads(block_jsons[i]))
            attrs_list.append(cur_attrs)
            alloc_map_list.append(alloc_map)
            reuse_map_list.append(reuse_map)
            clean_op_map_list.append(clean_op_map)
    elif 'buffer_stitch' in desc_d:
        stitch_jsons, input_tensor_name, output_tensor_name, alloc_map, reuse_map, \
            clean_op_map = stitch_json_split(desc_d)
        block_jsons.append(stitch_jsons)
        attrs = _set_reducemax_attrs(desc_d, attrs)
        attrs_list.append(attrs)
        alloc_map_list.append(alloc_map)
        reuse_map_list.append(reuse_map)
        clean_op_map_list.append(clean_op_map)
    return block_jsons, input_tensor_name, output_tensor_name, attrs_list, \
        alloc_map_list, reuse_map_list, clean_op_map_list


def _build_json_list_to_module(desc_d, attrs, poly, target):
    func = tvm.get_global_func("composite_with_json_list")
    block_jsons, input_tensor_name, output_tensor_name, attrs_list, alloc_map_list, reuse_map_list, \
        clean_op_map_list = _json_need_split(desc_d, attrs)
    return func(block_jsons, input_tensor_name, output_tensor_name, alloc_map_list, reuse_map_list,
                clean_op_map_list, attrs_list, poly, target)


def _build_to_module_gpu(desc_s, desc_d, attrs=None, poly=False):
    """
    build kernel with compute description in json format

    Args:
        desc_s : str of compute description
        desc_d : dict of compute description
        attrs : dict of build attributes

    Returns:
        Module.
    """
    if os.getenv('MS_GRAPH_KERNEL_TILING'):
        repository_gpu = read_repo_file(str(os.getenv('MS_GRAPH_KERNEL_TILING')))
    elif 'buffer_stitch' in desc_d:
        repository_gpu = {}
    else:
        file_path = _get_repository_file_path("repository_gpu.json")
        repository_gpu = read_repo_file(file_path)

    def get_repo(keys, default=None):
        repo = repository_gpu
        for key in keys:
            repo = repo.get(key)
            if not repo:
                return default
        return repo

    if attrs is None:
        attrs = {'dim': ''}
    compute, shape, dtype = generate_trait(desc_d)
    batchmatmul = _is_batchmatmul(desc_d)
    if batchmatmul:
        shape = "any_shape"
    repo_attr = get_repo([compute, shape, dtype, 'metadata', 'attrs'], {})
    if repo_attr and batchmatmul:
        repo_attr = _set_tiling_attrs(desc_d['output_desc'][0]['shape'], repo_attr)
    if not repo_attr:
        repo_attr = get_repo([compute, 'metadata', 'attrs'], {})
    for a in repo_attr:
        if not attrs.get(a):
            attrs[a] = repo_attr[a]
    attr_list = ['dim', 'bind_block', 'bind_thread']
    for item in attr_list:
        if attrs.get(item) in (None, ''):
            value = get_repo([compute, shape, dtype, item])
            if value:
                attrs[item] = value
    if 'parallel_fusion' in desc_d or 'buffer_stitch' in desc_d:
        return _build_json_list_to_module(desc_d, attrs, poly, 'cuda')
    func = tvm.get_global_func("composite_with_json")
    return func(desc_s, attrs, poly)


def _build(desc_s, desc_d, attrs=None, poly=True, use_repo=True):
    if attrs is None:
        attrs = dict()
    backend = desc_d['process']
    if "enable_atomic_add" not in attrs.keys():
        attrs["enable_atomic_add"] = should_enable_atomic_add(desc_d)
        if not poly:
            attrs["enable_atomic_add"] = False
    if backend == 'cuda':
        if poly:
            attrs["enable_akg_reduce_lib"] = True
        return _build_to_module_gpu(desc_s, desc_d, attrs, poly)
    else:
        return _build_to_module(desc_s, desc_d, attrs, use_repo)


def build(kernel_desc, attrs=None, poly=True, use_repo=True):
    """
    build kernel with compute description in json format

    Args:
        kernel_desc : str or dict of compute description
        attrs : dict of build attributes

    Returns:
        Module.
    """
    if isinstance(kernel_desc, str):
        desc_s = kernel_desc
        desc_d = json.loads(kernel_desc)
    else:
        assert isinstance(kernel_desc, dict)
        desc_s = json.dumps(kernel_desc)
        desc_d = kernel_desc
    return _build(desc_s, desc_d, attrs, poly, use_repo)


def get_tiling_space(kernel_desc, level=1, attr=None):
    """
    get tiling space of composite kernel

    Args:
        kernel_desc : str of compute description
        level : info level
        attr : dict of build attributes

    Returns:
        dict of tiling space info.
    """
    if attr is None:
        attr = {}
    attr['help_tiling'] = level
    attr['tuning'] = 'on'
    if 'enable_auto_inline' not in attr:
        attr['enable_auto_inline'] = False
    attr['pragma_reschedule'] = 1
    func = tvm.get_global_func('composite_lower')
    ret = func(kernel_desc, attr)
    spaces = {}
    spaces['index'] = ret.index_table.asnumpy().tolist()
    spaces['c1_range'] = ret.c1_tile_range_table.asnumpy().tolist()
    spaces['c0_range'] = ret.c0_tile_range_table.asnumpy().tolist()
    spaces['c1_mod'] = ret.c1_tile_mod_table.asnumpy().tolist()
    spaces['c0_mod'] = ret.c0_tile_mod_table.asnumpy().tolist()
    if level >= 2:
        spaces['tuning_space'] = ret.tiling_candidate.asnumpy().tolist()
    return spaces


@tvm.register_func("akg_build_gpu_module")
def build_cuda(outputs, args, sch_name, kernel_name, attrs=False, poly=False, binds=None):
    s = select_cuda_scheduler(outputs, sch_name, poly)
    if attrs:
        attrs_t = dict(attrs.items())
    else:
        attrs_t = None
    dump_ir = os.getenv(get_dump_ir_flag()) == "on"
    with tvm.build_config(dump_pass_ir=dump_ir):
        mod = akg.build(s, list(args), "cuda", name=kernel_name, binds=binds, attrs=attrs_t,
                        polyhedral=bool(poly))
    return mod


@tvm.register_func("select_cuda_scheduler")
def select_cuda_scheduler(outputs, sch_name, poly=False, grid_dims=0, block_dims=0, buffer_stitch=False):
    scheduler = {
        "injective": topi.cuda.injective_single_kernel.schedule_injective,
        "reduce": topi.cuda.reduce_opt.schedule_reduce,
    }
    with tvm.target.cuda():
        if bool(poly):
            s = akg.tvm.create_schedule([x.op for x in list(outputs)])
        else:
            if grid_dims and block_dims and sch_name == "injective":
                s = scheduler[sch_name](outputs, grid_dims, block_dims, buffer_stitch=buffer_stitch)
            else:
                s = scheduler[sch_name](outputs, grid_dims, block_dims)
    return s
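

# Illustrative usage sketch (not part of the module API): `build` accepts either the json
# string or the parsed dict of a composite kernel description, e.g.
#
#     with open("fused_kernel.json") as f:      # hypothetical file name
#         mod = build(f.read(), attrs={"dim": ""}, poly=True)
#
# Attributes missing from `attrs` (tiling 'dim', 'bind_block', 'bind_thread', ...) are filled
# from the repository files before the kernel is lowered and compiled.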