| @@ -19,6 +19,8 @@ import sys | |||
| from te.platform.cce_conf import te_set_version | |||
| from te.platform.fusion_util import fusion_op | |||
| import te | |||
| sys.path.append(os.path.abspath(os.path.dirname(__file__))) | |||
| # pylint: disable=wrong-import-position | |||
| from tbe_common import check_kernel_info, get_args, get_built_in_impl_path | |||
| build_in_impl_path = get_built_in_impl_path() | |||
| @@ -50,13 +52,14 @@ def _replace_range(args): | |||
| range_item[index] = None | |||
| def build_op(build_type, json_str): | |||
| def build_op(build_type, json_str, tune_mode=None): | |||
| """ | |||
| call op functions with function name and input args json_str | |||
| Args: | |||
| build_type : op function name | |||
| json_str (str): op function input args | |||
| tune_mode (str): if use auto_tune | |||
| Raises: | |||
| Exception: If specific keyword is not found. | |||
| @@ -93,8 +96,10 @@ def build_op(build_type, json_str): | |||
| else: | |||
| if is_dynamic_shape: | |||
| op_module = __import__("impl.dynamic." + op_name, globals(), locals(), [op_name], 0) | |||
| op_module_name = "impl.dynamic." + op_name | |||
| else: | |||
| op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0) | |||
| op_module_name = "impl." + op_name | |||
| # get function | |||
| if build_type == op_build: | |||
| if custom_flag: | |||
| @@ -111,9 +116,14 @@ def build_op(build_type, json_str): | |||
| if is_dynamic_shape: | |||
| with te.op.dynamic(): | |||
| op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) | |||
| if tune_mode is not None: | |||
| return (te.op.get_compile_info()), (inputs_args, outputs_args, attrs_args), op_module_name | |||
| return te.op.get_compile_info() | |||
| else: | |||
| return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) | |||
| res = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) | |||
| if tune_mode is not None: | |||
| return res, (inputs_args, outputs_args, attrs_args), op_module_name | |||
| return res | |||
| except Exception as e: | |||
| raise RuntimeError(e) | |||
| @@ -149,7 +159,7 @@ def compile_with_json(json_str): | |||
| if "fusion_op" in json_info: | |||
| ret = compile_fusion_op(json_str) | |||
| else: | |||
| ret = build_op(op_build, json_str) | |||
| ret = build_op(op_build, json_str, None) | |||
| return ret | |||
| @@ -326,16 +326,16 @@ class TbeProcess: | |||
| self.__running_tune_tasks.append(task_id) | |||
| if tune_mode == RL_TUNE: | |||
| ret, job_type = self.__tuner.rl_tune(task_id, op_json) | |||
| ret, job_type, compile_info = self.__tuner.rl_tune(task_id, op_json) | |||
| if job_type is RL_OFFLINE or job_type is RL_ONLINE: | |||
| if not ret: | |||
| # offline and online hit will return false | |||
| res = task_id, "Success", "Success" | |||
| res = task_id, "Success", compile_info | |||
| self.__finish_tune_task.append(res) | |||
| self.__running_tune_tasks.remove(task_id) | |||
| elif job_type is RL_COMPILE: | |||
| if not ret: | |||
| res = task_id, "Fail", "Fail" | |||
| res = task_id, "Fail", compile_info | |||
| self.__finish_tune_task.append(res) | |||
| self.__running_tune_tasks.remove(task_id) | |||
| elif tune_mode == GA_TUNE: | |||
| @@ -384,13 +384,14 @@ class TbeProcess: | |||
| for item in ret: | |||
| task_id = item['task_id'] | |||
| status_code = item['status_code'] | |||
| compile_info = item["op_res"] if "op_res" in item else "{}" | |||
| res = None | |||
| if status_code == 0: | |||
| res = task_id, "Success", "Success" | |||
| res = task_id, "Success", compile_info | |||
| else: | |||
| self.__failed_tune_task.append(task_id) | |||
| log.info("task_id:{}, json:{}".format(task_id, self.__task_info[task_id])) | |||
| res = task_id, "Failed", "Failed" | |||
| res = task_id, "Failed", compile_info | |||
| self.__finish_tune_task.append(res) | |||
| self.__running_tune_tasks.remove(task_id) | |||
| ret = self.__finish_tune_task.pop() | |||
| @@ -27,13 +27,14 @@ import auto_tune | |||
| from schedule_search.rl_online_tune import rl_tune_init, dispatch_fusion_tune_task, dispatch_single_tune_task, \ | |||
| rl_tune_deinit | |||
| from mindspore import log | |||
| from .tbe_common import get_args | |||
| from .compiler import build_op | |||
| from .re_construct_json import single_to_fusion, fusion_to_fusion | |||
| TE_LOG_LEVEL = ["DEBUG", "INFO", "WARNING", "ERROR"] | |||
| RL_COMPILE = "RL_COMPILE" | |||
| RL_OFFLINE = "RL_OFFLINE" | |||
| RL_ONLINE = "RL_ONLINE" | |||
| OP_BUILD = "compile" | |||
| PLATFORM_FLAG = ["ascend310", "ascend910", "Hi3796CV300ES", "ascend710", "ascend610", "Hi3796CV300CS", "SD3403"] | |||
| @@ -285,27 +286,20 @@ class TbeTuner: | |||
| converted_json = single_to_fusion(json.dumps(json_info), tune_mode="RL") | |||
| op_type = json_info['op_info']['name'] | |||
| kernel_name = json_info['op_info']['kernel_name'] | |||
| op_module = __import__("impl." + op_type, globals(), locals(), [op_type], 0) | |||
| op_module_name = "impl." + op_type | |||
| py_fn_name = json_info['op_info']['name'] | |||
| op_func = getattr(op_module, py_fn_name, None) | |||
| tune_mode = "RL" | |||
| set_current_op_name(kernel_name) | |||
| inputs_args = get_args(json_info['op_info'], 'inputs') | |||
| outputs_args = get_args(json_info['op_info'], 'outputs') | |||
| attrs_args = get_args(json_info['op_info'], 'attrs') | |||
| op_args = inputs_args, outputs_args, attrs_args | |||
| # todo build with build_single_op_from_c | |||
| base_kernel = './kernel_meta/' + kernel_name + '.o' | |||
| job_type = RL_COMPILE | |||
| compile_info = "{}" | |||
| try: | |||
| op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) | |||
| compile_info, op_args, op_module_name = build_op(OP_BUILD, json.dumps(json_info), tune_mode) | |||
| # pylint: disable=broad-except | |||
| except Exception: | |||
| exc_type, exc_value, _ = sys.exc_info() | |||
| log.error( | |||
| "exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc())) | |||
| return False, job_type | |||
| return False, job_type, compile_info | |||
| if self.offline_tune: | |||
| job_type = RL_OFFLINE | |||
| dump_fusion_json(converted_json, self.offline_dump_path) | |||
| @@ -318,7 +312,7 @@ class TbeTuner: | |||
| self.module_list[op_module_name] = 1 | |||
| self.fusion_need_sync += 1 | |||
| return ret, job_type | |||
| return ret, job_type, json.dumps(compile_info) | |||
| def get_op_module_names(self, json_info): | |||
| """ | |||