| @@ -562,6 +562,7 @@ def _enable_auto_inline(kernel_info): | |||
| # For the Ascend, turn 'enable_auto_inline' off for composite op by default. | |||
| return False | |||
| def _build_to_module(desc_s_in, desc_d_in, attr=None, use_repo=True): | |||
| """ | |||
| build kernel with compute description in json format | |||
| @@ -578,8 +579,49 @@ def _build_to_module(desc_s_in, desc_d_in, attr=None, use_repo=True): | |||
| else: | |||
| file_path = _get_repository_file_path("repository.json") | |||
| repository = read_repo_file(file_path) | |||
| def get_repo(keys, default=None): | |||
| repo = repository | |||
| def get_matmul_desc(desc): | |||
| for op in desc['op_desc'] if desc['op_desc'] is not None else []: | |||
| if op['name'] in ["MatMul", "BatchMatMul"]: | |||
| return op | |||
| return None | |||
| def get_attr(repo): | |||
| if not isinstance(repo, dict) or len(repo) == 0: | |||
| return {} | |||
| for key, value in repo.items(): | |||
| if key == "attrs": | |||
| return value | |||
| else: | |||
| return get_attr(value) | |||
| def get_vague_repo(vague_key, vague_shape, repo, default=None, fuzzy_level=0): | |||
| for compute, s_repo in repo.items(): | |||
| if compute.find(vague_key) != -1: | |||
| for shape, t_repo in s_repo.items(): | |||
| if fuzzy_level == 0 and shape.startswith(vague_shape): | |||
| return get_attr(t_repo) | |||
| if fuzzy_level == 1 and shape.find(vague_shape) != -1: | |||
| return get_attr(t_repo) | |||
| return default | |||
| def get_input_shape_str(desc): | |||
| def append_trait(traits, data): | |||
| if traits and traits[-1].rstrip('-') == data: | |||
| traits[-1] += '-' | |||
| else: | |||
| traits.append(data) | |||
| inputs = desc['input_desc'] | |||
| res = [] | |||
| for input_desc in inputs: | |||
| shape = input_desc[0]['shape'] | |||
| shape_str = '_'.join([str(i) for i in shape]) | |||
| append_trait(res, shape_str) | |||
| return '.'.join(res) | |||
| def get_repo(keys, repo, default=None): | |||
| for key in keys: | |||
| repo = repo.get(key) | |||
| if not repo: | |||
| @@ -593,14 +635,20 @@ def _build_to_module(desc_s_in, desc_d_in, attr=None, use_repo=True): | |||
| attr["enable_auto_inline"] = _enable_auto_inline(desc_d) | |||
| if use_repo: | |||
| compute, shape, dtype = generate_trait(desc_d) | |||
| repo_attr = get_repo([compute, shape, dtype, 'metadata', 'attrs'], {}) | |||
| repo_attr = get_repo([compute, shape, dtype, 'metadata', 'attrs'], repository, {}) | |||
| if not repo_attr: | |||
| repo_attr = get_repo([compute, 'metadata', 'attrs'], {}) | |||
| matmul_desc = get_matmul_desc(desc_d) | |||
| if matmul_desc != None: | |||
| input_shape_str = get_input_shape_str(matmul_desc) | |||
| repo_attr = get_vague_repo("MatMul", input_shape_str, repository, {}) | |||
| print("use fuzzy attr: " ,repo_attr) | |||
| else: | |||
| repo_attr = get_repo([compute, 'metadata', 'attrs'], repository, {}) | |||
| for a in repo_attr: | |||
| if not attr.get(a): | |||
| attr[a] = repo_attr[a] | |||
| if attr.get('dim') in (None, ''): | |||
| tiling = get_repo([compute, shape, dtype, 'dim']) | |||
| tiling = get_repo([compute, shape, dtype, 'dim'], repository) | |||
| if tiling: | |||
| attr['dim'] = tiling | |||
| elif 'online_tuning' in attr: | |||