Browse Source

fuzzy matching of repo json

pull/78/head
wYann 5 years ago
parent
commit
97e667ed7b
1 changed files with 53 additions and 5 deletions
  1. +53
    -5
      python/akg/composite/build_module.py

+ 53
- 5
python/akg/composite/build_module.py View File

@@ -562,6 +562,7 @@ def _enable_auto_inline(kernel_info):
# For the Ascend, turn 'enable_auto_inline' off for composite op by default.
return False


def _build_to_module(desc_s_in, desc_d_in, attr=None, use_repo=True):
"""
build kernel with compute description in json format
@@ -578,8 +579,49 @@ def _build_to_module(desc_s_in, desc_d_in, attr=None, use_repo=True):
else:
file_path = _get_repository_file_path("repository.json")
repository = read_repo_file(file_path)
def get_repo(keys, default=None):
repo = repository

def get_matmul_desc(desc):
for op in desc['op_desc'] if desc['op_desc'] is not None else []:
if op['name'] in ["MatMul", "BatchMatMul"]:
return op
return None
def get_attr(repo):
if not isinstance(repo, dict) or len(repo) == 0:
return {}
for key, value in repo.items():
if key == "attrs":
return value
else:
return get_attr(value)
def get_vague_repo(vague_key, vague_shape, repo, default=None, fuzzy_level=0):
for compute, s_repo in repo.items():
if compute.find(vague_key) != -1:
for shape, t_repo in s_repo.items():
if fuzzy_level == 0 and shape.startswith(vague_shape):
return get_attr(t_repo)
if fuzzy_level == 1 and shape.find(vague_shape) != -1:
return get_attr(t_repo)
return default
def get_input_shape_str(desc):

def append_trait(traits, data):
if traits and traits[-1].rstrip('-') == data:
traits[-1] += '-'
else:
traits.append(data)
inputs = desc['input_desc']
res = []
for input_desc in inputs:
shape = input_desc[0]['shape']
shape_str = '_'.join([str(i) for i in shape])
append_trait(res, shape_str)
return '.'.join(res)

def get_repo(keys, repo, default=None):
for key in keys:
repo = repo.get(key)
if not repo:
@@ -593,14 +635,20 @@ def _build_to_module(desc_s_in, desc_d_in, attr=None, use_repo=True):
attr["enable_auto_inline"] = _enable_auto_inline(desc_d)
if use_repo:
compute, shape, dtype = generate_trait(desc_d)
repo_attr = get_repo([compute, shape, dtype, 'metadata', 'attrs'], {})
repo_attr = get_repo([compute, shape, dtype, 'metadata', 'attrs'], repository, {})
if not repo_attr:
repo_attr = get_repo([compute, 'metadata', 'attrs'], {})
matmul_desc = get_matmul_desc(desc_d)
if matmul_desc != None:
input_shape_str = get_input_shape_str(matmul_desc)
repo_attr = get_vague_repo("MatMul", input_shape_str, repository, {})
print("use fuzzy attr: " ,repo_attr)
else:
repo_attr = get_repo([compute, 'metadata', 'attrs'], repository, {})
for a in repo_attr:
if not attr.get(a):
attr[a] = repo_attr[a]
if attr.get('dim') in (None, ''):
tiling = get_repo([compute, shape, dtype, 'dim'])
tiling = get_repo([compute, shape, dtype, 'dim'], repository)
if tiling:
attr['dim'] = tiling
elif 'online_tuning' in attr:


Loading…
Cancel
Save