Previously, the functions op_select_format and check_supported raised exceptions directly on the tbe_process Python side. Those exceptions were not handled there; instead, an exception was raised on the C++ side to the frontend ME, which could cause conflicts when recycling resources between ME and the tbe_process Python interpreter. This change adds try...except handling to op_select_format and check_supported on the Python side and returns the exception string to the C++ side, so that the C++ side can raise an exception to the frontend ME, and ME will handle resource cleanup and exit. tags/v0.3.0-alpha
| @@ -0,0 +1,114 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """tbe process""" | |||||
| import sys | |||||
| import os | |||||
| from .common import get_args, get_build_in_impl_path, TBEException | |||||
| build_in_impl_path = get_build_in_impl_path() | |||||
def _op_select_format(kernel_info):
    """
    Call the op's `op_select_format` function to get the op supported format.

    The op module is imported from the built-in impl path, or, when
    kernel_info['impl_path'] resolves to an existing file, from the directory
    that contains that file (the file name without extension is then used as
    the module name).

    Args:
        kernel_info (dict): kernel info loaded from the json string.

    Returns:
        The value returned by the op's `op_select_format`, or "" when the op
        module does not provide that function.

    Raises:
        TBEException: wraps any Exception raised while importing the op
            module, collecting the arguments or calling the function.
    """
    try:
        # import module
        op_info = kernel_info['op_info']
        op_name = op_info['name']
        impl_path = build_in_impl_path
        custom_flag = False
        if kernel_info.get('impl_path') is not None:
            op_impl_path = os.path.realpath(kernel_info['impl_path'])
            if os.path.isfile(op_impl_path):
                # custom op: import it by file name from its own directory
                impl_path, file_name = os.path.split(op_impl_path)
                op_name, _ = os.path.splitext(file_name)
                custom_flag = True
        # avoid growing sys.path with duplicates on repeated calls
        if impl_path not in sys.path:
            sys.path.insert(0, impl_path)

        if custom_flag:
            op_module = __import__(op_name)
        else:
            op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)

        # get function
        op_func = getattr(op_module, "op_select_format", None)
        if op_func is None:
            return ""

        # call function
        inputs_args = get_args(op_info, 'inputs')
        outputs_args = get_args(op_info, 'outputs')
        attrs_args = get_args(op_info, 'attrs')
        kernel_name = op_info['kernel_name']
        ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
    except Exception as e:
        # keep the original exception attached as the cause for debugging
        raise TBEException(str(e)) from e

    return ret
def _check_supported(kernel_info):
    """
    Call the op's `check_supported` function to check whether the op is supported.

    The op module is imported from the built-in impl path, or, when
    kernel_info['impl_path'] resolves to an existing file, from the directory
    that contains that file (the file name without extension is then used as
    the module name).

    Args:
        kernel_info (dict): kernel info loaded from the json string.

    Returns:
        The value returned by the op's `check_supported` (documented by the
        original code as a bool check result), or "" when the op module does
        not provide that function.

    Raises:
        TBEException: wraps any Exception raised while importing the op
            module, collecting the arguments or calling the function.
    """
    try:
        # import module
        op_info = kernel_info['op_info']
        op_name = op_info['name']
        impl_path = build_in_impl_path
        custom_flag = False
        if kernel_info.get('impl_path') is not None:
            op_impl_path = os.path.realpath(kernel_info['impl_path'])
            if os.path.isfile(op_impl_path):
                # custom op: import it by file name from its own directory
                impl_path, file_name = os.path.split(op_impl_path)
                op_name, _ = os.path.splitext(file_name)
                custom_flag = True
        # avoid growing sys.path with duplicates on repeated calls
        if impl_path not in sys.path:
            sys.path.insert(0, impl_path)

        if custom_flag:
            op_module = __import__(op_name)
        else:
            op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)

        # get function
        op_func = getattr(op_module, "check_supported", None)
        if op_func is None:
            return ""

        # call function
        inputs_args = get_args(op_info, 'inputs')
        outputs_args = get_args(op_info, 'outputs')
        attrs_args = get_args(op_info, 'attrs')
        kernel_name = op_info['kernel_name']
        ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
    except Exception as e:
        # keep the original exception attached as the cause for debugging
        raise TBEException(str(e)) from e

    return ret
| @@ -19,10 +19,8 @@ import subprocess | |||||
| import sys | import sys | ||||
| import os | import os | ||||
| import json | import json | ||||
| from .common import check_kernel_info, get_args, get_build_in_impl_path | |||||
| build_in_impl_path = get_build_in_impl_path() | |||||
| from .common import check_kernel_info, TBEException | |||||
| from .helper import _op_select_format, _check_supported | |||||
| def create_tbe_parallel_compiler(): | def create_tbe_parallel_compiler(): | ||||
| """ | """ | ||||
| @@ -41,40 +39,17 @@ def op_select_format(op_json: str): | |||||
| op_json (str): json string of the op | op_json (str): json string of the op | ||||
| Returns: | Returns: | ||||
| op supported format | |||||
| op supported format or exception message | |||||
| """ | """ | ||||
| ret = "" | ret = "" | ||||
| kernel_info = json.loads(op_json) | |||||
| check_kernel_info(kernel_info) | |||||
| # import module | |||||
| op_name = kernel_info['op_info']['name'] | |||||
| impl_path = build_in_impl_path | |||||
| custom_flag = False | |||||
| if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None: | |||||
| op_impl_path = os.path.realpath(kernel_info['impl_path']) | |||||
| if os.path.isfile(op_impl_path): | |||||
| path, file_name = os.path.split(op_impl_path) | |||||
| op_name, _ = os.path.splitext(file_name) | |||||
| impl_path = path | |||||
| custom_flag = True | |||||
| sys.path.insert(0, impl_path) | |||||
| if custom_flag: | |||||
| op_module = __import__(op_name) | |||||
| else: | |||||
| op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0) | |||||
| # get function | |||||
| if not hasattr(op_module, "op_select_format"): | |||||
| return "" | |||||
| op_func = getattr(op_module, "op_select_format", None) | |||||
| # call function | |||||
| inputs_args = get_args(kernel_info['op_info'], 'inputs') | |||||
| outputs_args = get_args(kernel_info['op_info'], 'outputs') | |||||
| attrs_args = get_args(kernel_info['op_info'], 'attrs') | |||||
| kernel_name = kernel_info['op_info']['kernel_name'] | |||||
| ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) | |||||
| try: | |||||
| kernel_info = json.loads(op_json) | |||||
| check_kernel_info(kernel_info) | |||||
| ret = _op_select_format(kernel_info) | |||||
| except TBEException as e: | |||||
| return "TBEException: " + str(e) | |||||
| return ret | return ret | ||||
| @@ -86,40 +61,18 @@ def check_supported(op_json: str): | |||||
| op_json (str): json string of the op | op_json (str): json string of the op | ||||
| Returns: | Returns: | ||||
| true or false | |||||
| bool: check result, true or false | |||||
| str: exception message when catch an Exception | |||||
| """ | """ | ||||
| ret = "" | ret = "" | ||||
| kernel_info = json.loads(op_json) | |||||
| check_kernel_info(kernel_info) | |||||
| # import module | |||||
| op_name = kernel_info['op_info']['name'] | |||||
| impl_path = build_in_impl_path | |||||
| custom_flag = False | |||||
| if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None: | |||||
| op_impl_path = os.path.realpath(kernel_info['impl_path']) | |||||
| if os.path.isfile(op_impl_path): | |||||
| path, file_name = os.path.split(op_impl_path) | |||||
| op_name, _ = os.path.splitext(file_name) | |||||
| impl_path = path | |||||
| custom_flag = True | |||||
| sys.path.insert(0, impl_path) | |||||
| if custom_flag: | |||||
| op_module = __import__(op_name) | |||||
| else: | |||||
| op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0) | |||||
| # get function | |||||
| if not hasattr(op_module, "check_supported"): | |||||
| return "" | |||||
| op_func = getattr(op_module, "check_supported", None) | |||||
| # call function | |||||
| inputs_args = get_args(kernel_info['op_info'], 'inputs') | |||||
| outputs_args = get_args(kernel_info['op_info'], 'outputs') | |||||
| attrs_args = get_args(kernel_info['op_info'], 'attrs') | |||||
| kernel_name = kernel_info['op_info']['kernel_name'] | |||||
| ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) | |||||
| try: | |||||
| kernel_info = json.loads(op_json) | |||||
| check_kernel_info(kernel_info) | |||||
| ret = _check_supported(kernel_info) | |||||
| except TBEException as e: | |||||
| return "TBEException: " + str(e) | |||||
| return ret | return ret | ||||
| @@ -149,12 +102,12 @@ class CompilerPool: | |||||
| """compiler pool""" | """compiler pool""" | ||||
| def __init__(self): | def __init__(self): | ||||
| processes = multiprocessing.cpu_count() | |||||
| self.__processe_num = multiprocessing.cpu_count() | |||||
| # max_processes_num: Set the maximum number of concurrent processes for compiler | # max_processes_num: Set the maximum number of concurrent processes for compiler | ||||
| max_processes_num = 16 | max_processes_num = 16 | ||||
| if processes > max_processes_num: | |||||
| processes = max_processes_num | |||||
| self.__pool = multiprocessing.Pool(processes=processes) | |||||
| if self.__processe_num > max_processes_num: | |||||
| self.__processe_num = max_processes_num | |||||
| self.__pool = None | |||||
| self.__next_task_id = 1 | self.__next_task_id = 1 | ||||
| self.__running_tasks = [] | self.__running_tasks = [] | ||||
| @@ -165,11 +118,10 @@ class CompilerPool: | |||||
| del self.__pool | del self.__pool | ||||
| def exit(self): | def exit(self): | ||||
| return | |||||
| # self.__pool.terminate() | |||||
| # self.__pool.join() | |||||
| # if self.__pool is not None: | |||||
| # del self.__pool | |||||
| if self.__pool is not None: | |||||
| self.__pool.terminate() | |||||
| self.__pool.join() | |||||
| del self.__pool | |||||
| def start_compile_op(self, op_json): | def start_compile_op(self, op_json): | ||||
| """ | """ | ||||
| @@ -183,6 +135,8 @@ class CompilerPool: | |||||
| """ | """ | ||||
| task_id = self.__next_task_id | task_id = self.__next_task_id | ||||
| self.__next_task_id = self.__next_task_id + 1 | self.__next_task_id = self.__next_task_id + 1 | ||||
| if self.__pool is None: | |||||
| self.__pool = multiprocessing.Pool(processes=self.__processe_num) | |||||
| task_future = self.__pool.apply_async(func=run_compiler, args=(op_json,)) | task_future = self.__pool.apply_async(func=run_compiler, args=(op_json,)) | ||||
| self.__running_tasks.append((task_id, task_future)) | self.__running_tasks.append((task_id, task_future)) | ||||
| return task_id | return task_id | ||||
| @@ -98,7 +98,7 @@ void TbeAdapter::NormalizeFuncName(std::string *func_name) { | |||||
| *func_name = name_tmp; | *func_name = name_tmp; | ||||
| auto iter = tbe_func_adapter_map.find(*func_name); | auto iter = tbe_func_adapter_map.find(*func_name); | ||||
| if (iter != tbe_func_adapter_map.end()) { | if (iter != tbe_func_adapter_map.end()) { | ||||
| MS_LOG(INFO) << "map actual op fron me " << func_name << "to tbe op" << iter->second; | |||||
| MS_LOG(INFO) << "map actual op from me " << func_name << "to tbe op" << iter->second; | |||||
| *func_name = iter->second; | *func_name = iter->second; | ||||
| } | } | ||||
| } | } | ||||
| @@ -35,6 +35,8 @@ namespace kernel { | |||||
| constexpr auto kName = "name"; | constexpr auto kName = "name"; | ||||
| constexpr auto kDtype = "dtype"; | constexpr auto kDtype = "dtype"; | ||||
| constexpr auto kFormat = "format"; | constexpr auto kFormat = "format"; | ||||
| constexpr auto kPrefixInput = "input"; | |||||
| constexpr auto kPrefixOutput = "output"; | |||||
| const std::map<std::string, std::string> DYNAMIC_FORMAT_MAP = {{"NCHW", "DefaultFormat"}, | const std::map<std::string, std::string> DYNAMIC_FORMAT_MAP = {{"NCHW", "DefaultFormat"}, | ||||
| {"NHWC", "DefaultFormat"}, | {"NHWC", "DefaultFormat"}, | ||||
| {"ND", "DefaultFormat"}, | {"ND", "DefaultFormat"}, | ||||
| @@ -146,13 +148,13 @@ bool ParseDynamicFormatJson(const std::string &jsonStr, std::vector<std::shared_ | |||||
| if (!CheckJsonItemValidity(json_obj, key_name, keys)) { | if (!CheckJsonItemValidity(json_obj, key_name, keys)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (key_name.find("input", 0) != std::string::npos) { | |||||
| if (key_name.compare(0, strlen(kPrefixInput), kPrefixInput) == 0) { | |||||
| std::shared_ptr<OpIOInfo> input = std::make_shared<OpIOInfo>(); | std::shared_ptr<OpIOInfo> input = std::make_shared<OpIOInfo>(); | ||||
| MS_EXCEPTION_IF_NULL(input); | MS_EXCEPTION_IF_NULL(input); | ||||
| input->set_name(json_obj[key_name].at(kName)); | input->set_name(json_obj[key_name].at(kName)); | ||||
| ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), input); | ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), input); | ||||
| inputs->emplace_back(input); | inputs->emplace_back(input); | ||||
| } else if (key_name.find("output", 0) != std::string::npos) { | |||||
| } else if (key_name.compare(0, strlen(kPrefixOutput), kPrefixOutput) == 0) { | |||||
| std::shared_ptr<OpIOInfo> output = std::make_shared<OpIOInfo>(); | std::shared_ptr<OpIOInfo> output = std::make_shared<OpIOInfo>(); | ||||
| MS_EXCEPTION_IF_NULL(output); | MS_EXCEPTION_IF_NULL(output); | ||||
| output->set_name(json_obj[key_name].at(kName)); | output->set_name(json_obj[key_name].at(kName)); | ||||
| @@ -26,6 +26,7 @@ constexpr auto kTbeProcessModule = "mindspore._extends.parallel_compile.tbe_comp | |||||
| constexpr auto kCreateTbeParallelCompilerFunc = "create_tbe_parallel_compiler"; | constexpr auto kCreateTbeParallelCompilerFunc = "create_tbe_parallel_compiler"; | ||||
| constexpr auto kOpSelectFormatFunc = "op_select_format"; | constexpr auto kOpSelectFormatFunc = "op_select_format"; | ||||
| constexpr auto kCheckSupportedFunc = "check_supported"; | constexpr auto kCheckSupportedFunc = "check_supported"; | ||||
| constexpr auto kTBEException = "TBEException"; | |||||
| PyObject *TbePythonFuncs::pCreateTbeParallelCompilerFunc_ = nullptr; | PyObject *TbePythonFuncs::pCreateTbeParallelCompilerFunc_ = nullptr; | ||||
| PyObject *TbePythonFuncs::pTbeCompiler_ = nullptr; | PyObject *TbePythonFuncs::pTbeCompiler_ = nullptr; | ||||
| @@ -133,6 +134,10 @@ std::string TbePythonFuncs::OpSelectFormat(const nlohmann::json &kernel_json) { | |||||
| char *pstr = nullptr; | char *pstr = nullptr; | ||||
| (void)PyArg_Parse(pRet, "s", &pstr); | (void)PyArg_Parse(pRet, "s", &pstr); | ||||
| res_json_str = pstr; | res_json_str = pstr; | ||||
| if (res_json_str.compare(0, strlen(kTBEException), kTBEException) == 0) { | |||||
| MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc << "], " << res_json_str | |||||
| << " ,function args:" << PyObjectToStr(pArg); | |||||
| } | |||||
| return res_json_str; | return res_json_str; | ||||
| } | } | ||||
| @@ -167,7 +172,18 @@ bool TbePythonFuncs::CheckSupported(const nlohmann::json &kernel_json) { | |||||
| MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc | MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc | ||||
| << "], function args: " << PyObjectToStr(pArg); | << "], function args: " << PyObjectToStr(pArg); | ||||
| } | } | ||||
| ret = PyObject_IsTrue(pRes) != 0; | |||||
| if (PyBool_Check(pRes)) { | |||||
| ret = PyObject_IsTrue(pRes) != 0; | |||||
| } else { | |||||
| char *pstr = nullptr; | |||||
| (void)PyArg_Parse(pRes, "s", &pstr); | |||||
| std::string res_str = pstr; | |||||
| if (res_str.compare(0, strlen(kTBEException), kTBEException) == 0) { | |||||
| MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc << "], " << res_str | |||||
| << ", function args: " << PyObjectToStr(pArg); | |||||
| } | |||||
| } | |||||
| return ret; | return ret; | ||||
| } | } | ||||