Browse Source

tune support dynamic shape

tags/v1.2.0-rc1
liubuyu 5 years ago
parent
commit
bc18078a0e
3 changed files with 26 additions and 21 deletions
  1. +13
    -3
      mindspore/_extends/parallel_compile/tbe_compiler/compiler.py
  2. +6
    -5
      mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py
  3. +7
    -13
      mindspore/_extends/parallel_compile/tbe_compiler/tuner.py

+ 13
- 3
mindspore/_extends/parallel_compile/tbe_compiler/compiler.py View File

@@ -19,6 +19,8 @@ import sys
from te.platform.cce_conf import te_set_version from te.platform.cce_conf import te_set_version
from te.platform.fusion_util import fusion_op from te.platform.fusion_util import fusion_op
import te import te
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
# pylint: disable=wrong-import-position
from tbe_common import check_kernel_info, get_args, get_built_in_impl_path from tbe_common import check_kernel_info, get_args, get_built_in_impl_path


build_in_impl_path = get_built_in_impl_path() build_in_impl_path = get_built_in_impl_path()
@@ -50,13 +52,14 @@ def _replace_range(args):
range_item[index] = None range_item[index] = None




def build_op(build_type, json_str):
def build_op(build_type, json_str, tune_mode=None):
""" """
call op functions with function name and input args json_str call op functions with function name and input args json_str


Args: Args:
build_type : op function name build_type : op function name
json_str (str): op function input args json_str (str): op function input args
tune_mode (str): if use auto_tune


Raises: Raises:
Exception: If specific keyword is not found. Exception: If specific keyword is not found.
@@ -93,8 +96,10 @@ def build_op(build_type, json_str):
else: else:
if is_dynamic_shape: if is_dynamic_shape:
op_module = __import__("impl.dynamic." + op_name, globals(), locals(), [op_name], 0) op_module = __import__("impl.dynamic." + op_name, globals(), locals(), [op_name], 0)
op_module_name = "impl.dynamic." + op_name
else: else:
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0) op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
op_module_name = "impl." + op_name
# get function # get function
if build_type == op_build: if build_type == op_build:
if custom_flag: if custom_flag:
@@ -111,9 +116,14 @@ def build_op(build_type, json_str):
if is_dynamic_shape: if is_dynamic_shape:
with te.op.dynamic(): with te.op.dynamic():
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
if tune_mode is not None:
return (te.op.get_compile_info()), (inputs_args, outputs_args, attrs_args), op_module_name
return te.op.get_compile_info() return te.op.get_compile_info()
else: else:
return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
res = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
if tune_mode is not None:
return res, (inputs_args, outputs_args, attrs_args), op_module_name
return res


except Exception as e: except Exception as e:
raise RuntimeError(e) raise RuntimeError(e)
@@ -149,7 +159,7 @@ def compile_with_json(json_str):
if "fusion_op" in json_info: if "fusion_op" in json_info:
ret = compile_fusion_op(json_str) ret = compile_fusion_op(json_str)
else: else:
ret = build_op(op_build, json_str)
ret = build_op(op_build, json_str, None)
return ret return ret






+ 6
- 5
mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py View File

@@ -326,16 +326,16 @@ class TbeProcess:
self.__running_tune_tasks.append(task_id) self.__running_tune_tasks.append(task_id)


if tune_mode == RL_TUNE: if tune_mode == RL_TUNE:
ret, job_type = self.__tuner.rl_tune(task_id, op_json)
ret, job_type, compile_info = self.__tuner.rl_tune(task_id, op_json)
if job_type is RL_OFFLINE or job_type is RL_ONLINE: if job_type is RL_OFFLINE or job_type is RL_ONLINE:
if not ret: if not ret:
# offline and online hit will return false # offline and online hit will return false
res = task_id, "Success", "Success"
res = task_id, "Success", compile_info
self.__finish_tune_task.append(res) self.__finish_tune_task.append(res)
self.__running_tune_tasks.remove(task_id) self.__running_tune_tasks.remove(task_id)
elif job_type is RL_COMPILE: elif job_type is RL_COMPILE:
if not ret: if not ret:
res = task_id, "Fail", "Fail"
res = task_id, "Fail", compile_info
self.__finish_tune_task.append(res) self.__finish_tune_task.append(res)
self.__running_tune_tasks.remove(task_id) self.__running_tune_tasks.remove(task_id)
elif tune_mode == GA_TUNE: elif tune_mode == GA_TUNE:
@@ -384,13 +384,14 @@ class TbeProcess:
for item in ret: for item in ret:
task_id = item['task_id'] task_id = item['task_id']
status_code = item['status_code'] status_code = item['status_code']
compile_info = item["op_res"] if "op_res" in item else "{}"
res = None res = None
if status_code == 0: if status_code == 0:
res = task_id, "Success", "Success"
res = task_id, "Success", compile_info
else: else:
self.__failed_tune_task.append(task_id) self.__failed_tune_task.append(task_id)
log.info("task_id:{}, json:{}".format(task_id, self.__task_info[task_id])) log.info("task_id:{}, json:{}".format(task_id, self.__task_info[task_id]))
res = task_id, "Failed", "Failed"
res = task_id, "Failed", compile_info
self.__finish_tune_task.append(res) self.__finish_tune_task.append(res)
self.__running_tune_tasks.remove(task_id) self.__running_tune_tasks.remove(task_id)
ret = self.__finish_tune_task.pop() ret = self.__finish_tune_task.pop()


+ 7
- 13
mindspore/_extends/parallel_compile/tbe_compiler/tuner.py View File

@@ -27,13 +27,14 @@ import auto_tune
from schedule_search.rl_online_tune import rl_tune_init, dispatch_fusion_tune_task, dispatch_single_tune_task, \ from schedule_search.rl_online_tune import rl_tune_init, dispatch_fusion_tune_task, dispatch_single_tune_task, \
rl_tune_deinit rl_tune_deinit
from mindspore import log from mindspore import log
from .tbe_common import get_args
from .compiler import build_op
from .re_construct_json import single_to_fusion, fusion_to_fusion from .re_construct_json import single_to_fusion, fusion_to_fusion


TE_LOG_LEVEL = ["DEBUG", "INFO", "WARNING", "ERROR"] TE_LOG_LEVEL = ["DEBUG", "INFO", "WARNING", "ERROR"]
RL_COMPILE = "RL_COMPILE" RL_COMPILE = "RL_COMPILE"
RL_OFFLINE = "RL_OFFLINE" RL_OFFLINE = "RL_OFFLINE"
RL_ONLINE = "RL_ONLINE" RL_ONLINE = "RL_ONLINE"
OP_BUILD = "compile"


PLATFORM_FLAG = ["ascend310", "ascend910", "Hi3796CV300ES", "ascend710", "ascend610", "Hi3796CV300CS", "SD3403"] PLATFORM_FLAG = ["ascend310", "ascend910", "Hi3796CV300ES", "ascend710", "ascend610", "Hi3796CV300CS", "SD3403"]


@@ -285,27 +286,20 @@ class TbeTuner:
converted_json = single_to_fusion(json.dumps(json_info), tune_mode="RL") converted_json = single_to_fusion(json.dumps(json_info), tune_mode="RL")
op_type = json_info['op_info']['name'] op_type = json_info['op_info']['name']
kernel_name = json_info['op_info']['kernel_name'] kernel_name = json_info['op_info']['kernel_name']
op_module = __import__("impl." + op_type, globals(), locals(), [op_type], 0)
op_module_name = "impl." + op_type
py_fn_name = json_info['op_info']['name']
op_func = getattr(op_module, py_fn_name, None)

tune_mode = "RL"
set_current_op_name(kernel_name) set_current_op_name(kernel_name)
inputs_args = get_args(json_info['op_info'], 'inputs')
outputs_args = get_args(json_info['op_info'], 'outputs')
attrs_args = get_args(json_info['op_info'], 'attrs')
op_args = inputs_args, outputs_args, attrs_args
# todo build with build_single_op_from_c # todo build with build_single_op_from_c
base_kernel = './kernel_meta/' + kernel_name + '.o' base_kernel = './kernel_meta/' + kernel_name + '.o'
job_type = RL_COMPILE job_type = RL_COMPILE
compile_info = "{}"
try: try:
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
compile_info, op_args, op_module_name = build_op(OP_BUILD, json.dumps(json_info), tune_mode)
# pylint: disable=broad-except # pylint: disable=broad-except
except Exception: except Exception:
exc_type, exc_value, _ = sys.exc_info() exc_type, exc_value, _ = sys.exc_info()
log.error( log.error(
"exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc())) "exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc()))
return False, job_type
return False, job_type, compile_info
if self.offline_tune: if self.offline_tune:
job_type = RL_OFFLINE job_type = RL_OFFLINE
dump_fusion_json(converted_json, self.offline_dump_path) dump_fusion_json(converted_json, self.offline_dump_path)
@@ -318,7 +312,7 @@ class TbeTuner:


self.module_list[op_module_name] = 1 self.module_list[op_module_name] = 1
self.fusion_need_sync += 1 self.fusion_need_sync += 1
return ret, job_type
return ret, job_type, json.dumps(compile_info)


def get_op_module_names(self, json_info): def get_op_module_names(self, json_info):
""" """


Loading…
Cancel
Save