diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py b/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py index dc0f62a65b..5acdfc30ee 100755 --- a/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py @@ -19,6 +19,8 @@ import sys from te.platform.cce_conf import te_set_version from te.platform.fusion_util import fusion_op import te +sys.path.append(os.path.abspath(os.path.dirname(__file__))) +# pylint: disable=wrong-import-position from tbe_common import check_kernel_info, get_args, get_built_in_impl_path build_in_impl_path = get_built_in_impl_path() @@ -50,13 +52,14 @@ def _replace_range(args): range_item[index] = None -def build_op(build_type, json_str): +def build_op(build_type, json_str, tune_mode=None): """ call op functions with function name and input args json_str Args: build_type : op function name json_str (str): op function input args + tune_mode (str): if use auto_tune Raises: Exception: If specific keyword is not found. @@ -93,8 +96,10 @@ def build_op(build_type, json_str): else: if is_dynamic_shape: op_module = __import__("impl.dynamic." + op_name, globals(), locals(), [op_name], 0) + op_module_name = "impl.dynamic." + op_name else: op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0) + op_module_name = "impl." + op_name # get function if build_type == op_build: if custom_flag: @@ -111,9 +116,14 @@ def build_op(build_type, json_str): if is_dynamic_shape: with te.op.dynamic(): op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) + if tune_mode is not None: + return (te.op.get_compile_info()), (inputs_args, outputs_args, attrs_args), op_module_name return te.op.get_compile_info() else: - return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) + res = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) + if tune_mode is not None: + return res, (inputs_args, outputs_args, attrs_args), op_module_name + return res except Exception as e: raise RuntimeError(e) @@ -149,7 +159,7 @@ def compile_with_json(json_str): if "fusion_op" in json_info: ret = compile_fusion_op(json_str) else: - ret = build_op(op_build, json_str) + ret = build_op(op_build, json_str, None) return ret diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py b/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py index a59858c4a8..daaf4b2881 100644 --- a/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py @@ -326,16 +326,16 @@ class TbeProcess: self.__running_tune_tasks.append(task_id) if tune_mode == RL_TUNE: - ret, job_type = self.__tuner.rl_tune(task_id, op_json) + ret, job_type, compile_info = self.__tuner.rl_tune(task_id, op_json) if job_type is RL_OFFLINE or job_type is RL_ONLINE: if not ret: # offline and online hit will return false - res = task_id, "Success", "Success" + res = task_id, "Success", compile_info self.__finish_tune_task.append(res) self.__running_tune_tasks.remove(task_id) elif job_type is RL_COMPILE: if not ret: - res = task_id, "Fail", "Fail" + res = task_id, "Fail", compile_info self.__finish_tune_task.append(res) self.__running_tune_tasks.remove(task_id) elif tune_mode == GA_TUNE: @@ -384,13 +384,14 @@ class TbeProcess: for item in ret: task_id = item['task_id'] status_code = item['status_code'] + compile_info = item["op_res"] if "op_res" in item else "{}" res = None if status_code == 0: - res = task_id, "Success", "Success" + res = task_id, "Success", compile_info else: self.__failed_tune_task.append(task_id) log.info("task_id:{}, json:{}".format(task_id, self.__task_info[task_id])) - res = task_id, "Failed", "Failed" + res = task_id, "Failed", compile_info self.__finish_tune_task.append(res) self.__running_tune_tasks.remove(task_id) ret = self.__finish_tune_task.pop() diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/tuner.py b/mindspore/_extends/parallel_compile/tbe_compiler/tuner.py index 8e57a201db..1bb5980cd1 100644 --- a/mindspore/_extends/parallel_compile/tbe_compiler/tuner.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/tuner.py @@ -27,13 +27,14 @@ import auto_tune from schedule_search.rl_online_tune import rl_tune_init, dispatch_fusion_tune_task, dispatch_single_tune_task, \ rl_tune_deinit from mindspore import log -from .tbe_common import get_args +from .compiler import build_op from .re_construct_json import single_to_fusion, fusion_to_fusion TE_LOG_LEVEL = ["DEBUG", "INFO", "WARNING", "ERROR"] RL_COMPILE = "RL_COMPILE" RL_OFFLINE = "RL_OFFLINE" RL_ONLINE = "RL_ONLINE" +OP_BUILD = "compile" PLATFORM_FLAG = ["ascend310", "ascend910", "Hi3796CV300ES", "ascend710", "ascend610", "Hi3796CV300CS", "SD3403"] @@ -285,27 +286,20 @@ class TbeTuner: converted_json = single_to_fusion(json.dumps(json_info), tune_mode="RL") op_type = json_info['op_info']['name'] kernel_name = json_info['op_info']['kernel_name'] - op_module = __import__("impl." + op_type, globals(), locals(), [op_type], 0) - op_module_name = "impl." + op_type - py_fn_name = json_info['op_info']['name'] - op_func = getattr(op_module, py_fn_name, None) - + tune_mode = "RL" set_current_op_name(kernel_name) - inputs_args = get_args(json_info['op_info'], 'inputs') - outputs_args = get_args(json_info['op_info'], 'outputs') - attrs_args = get_args(json_info['op_info'], 'attrs') - op_args = inputs_args, outputs_args, attrs_args # todo build with build_single_op_from_c base_kernel = './kernel_meta/' + kernel_name + '.o' job_type = RL_COMPILE + compile_info = "{}" try: - op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) + compile_info, op_args, op_module_name = build_op(OP_BUILD, json.dumps(json_info), tune_mode) # pylint: disable=broad-except except Exception: exc_type, exc_value, _ = sys.exc_info() log.error( "exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc())) - return False, job_type + return False, job_type, compile_info if self.offline_tune: job_type = RL_OFFLINE dump_fusion_json(converted_json, self.offline_dump_path) @@ -318,7 +312,7 @@ class TbeTuner: self.module_list[op_module_name] = 1 self.fusion_need_sync += 1 - return ret, job_type + return ret, job_type, json.dumps(compile_info) def get_op_module_names(self, json_info): """