|
|
|
@@ -41,6 +41,7 @@ PLATFORM_FLAG = ["ascend310", "ascend910", "Hi3796CV300ES", "ascend710", "ascend |
|
|
|
|
|
|
|
class TbeTuner: |
|
|
|
"""tbe tuner for ga tune or rl tune""" |
|
|
|
|
|
|
|
def __init__(self, offline_tune, tune_mode): |
|
|
|
self.offline_tune = offline_tune |
|
|
|
self.tune_init = False |
|
|
|
@@ -286,6 +287,7 @@ class TbeTuner: |
|
|
|
converted_json = single_to_fusion(json.dumps(json_info), tune_mode="RL") |
|
|
|
op_type = json_info['op_info']['name'] |
|
|
|
kernel_name = json_info['op_info']['kernel_name'] |
|
|
|
full_name = json_info['op_info']['full_name'] |
|
|
|
tune_mode = "RL" |
|
|
|
set_current_op_name(kernel_name) |
|
|
|
# todo build with build_single_op_from_c |
|
|
|
@@ -307,25 +309,13 @@ class TbeTuner: |
|
|
|
job_type = RL_ONLINE |
|
|
|
graph_id = 0 |
|
|
|
l1size = 0 # todo need to verify |
|
|
|
ret = dispatch_single_tune_task(graph_id, task_id, l1size, base_kernel, kernel_name, op_module_name, |
|
|
|
ret = dispatch_single_tune_task(graph_id, task_id, l1size, base_kernel, kernel_name, full_name, |
|
|
|
op_module_name + "@" + op_module_name, op_type, op_type, op_args) |
|
|
|
|
|
|
|
self.module_list[op_module_name] = 1 |
|
|
|
self.fusion_need_sync += 1 |
|
|
|
return ret, job_type, json.dumps(compile_info) |
|
|
|
|
|
|
|
def get_op_module_names(self, json_info): |
|
|
|
""" |
|
|
|
Get op module names from op info json |
|
|
|
:param json_info: op's info |
|
|
|
:return: op module names |
|
|
|
""" |
|
|
|
op_module_name = "" |
|
|
|
for op in json_info["fusion_op"]["op_list"]: |
|
|
|
if "module_name" in op: |
|
|
|
op_module_name = op_module_name + op["module_name"] + "," |
|
|
|
return op_module_name[:-1] |
|
|
|
|
|
|
|
def fusion_rl_tune(self, task_id, json_info): |
|
|
|
""" |
|
|
|
RL tune for fusion op |
|
|
|
@@ -336,6 +326,7 @@ class TbeTuner: |
|
|
|
if 'fusion_op' not in json_info or not json_info['fusion_op']: |
|
|
|
raise ValueError("Json string Errors, key:fusion_op not found.") |
|
|
|
kernel_name = json_info["fusion_op"]["fusion_op_name"] |
|
|
|
full_name = json_info["fusion_op"]["full_name"] |
|
|
|
set_current_op_name(kernel_name) |
|
|
|
converted_json = fusion_to_fusion(json.dumps(json_info), tune_mode="RL") |
|
|
|
job_type = RL_COMPILE |
|
|
|
@@ -355,8 +346,7 @@ class TbeTuner: |
|
|
|
job_type = RL_ONLINE |
|
|
|
graph_id = 0 |
|
|
|
l1size = 0 |
|
|
|
op_model_name = self.get_op_module_names(json_info) |
|
|
|
ret = dispatch_fusion_tune_task(graph_id, task_id, l1size, base_kernel, kernel_name, op_model_name, |
|
|
|
ret = dispatch_fusion_tune_task(graph_id, task_id, l1size, base_kernel, kernel_name, full_name, |
|
|
|
converted_json) |
|
|
|
return ret, job_type |
|
|
|
|
|
|
|
|