diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py b/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py index f935087be2..ffcc2766a0 100644 --- a/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py @@ -106,23 +106,7 @@ class TbeProcess: def __init__(self): self.__processe_num = multiprocessing.cpu_count() - # max_processes_num: Set the maximum number of concurrent processes for compiler - self.max_processes_num = 24 - process_num = os.getenv("MS_BUILD_PROCESS_NUM") - if process_num is None: - self.max_processes_num = 24 - logger.info(f"Using default compile process num {self.max_processes_num}") - elif process_num.isdigit(): - if int(process_num) in range(1, 25): - self.max_processes_num = int(process_num) - logger.info(f"Using custom compile process num {self.max_processes_num}") - else: - raise EnvironmentError( - f"Env ERROR, [MS_BUILD_PROCESS_NUM] should be in range(1, 25), but: {process_num}") - elif not process_num.isdigit(): - raise EnvironmentError(f"Env ERROR, [MS_BUILD_PROCESS_NUM] should be a digit, but: {process_num}") - if self.__processe_num > self.max_processes_num: - self.__processe_num = self.max_processes_num + self.default_num = 24 self.__pool = None self.__next_task_id = 1 self.__running_tasks = [] @@ -133,6 +117,27 @@ class TbeProcess: self.__pool.join() del self.__pool + def init_process_num(self): + """ + init compile process num + :return: str Success or other string info + """ + # max_processes_num: Set the maximum number of concurrent processes for compiler + process_num = os.getenv("MS_BUILD_PROCESS_NUM") + res = "Success" + if process_num is None: + logger.info(f"Using default compile process num {self.default_num}") + elif process_num.isdigit(): + if int(process_num) in range(1, 25): + self.default_num = int(process_num) + logger.info(f"Using custom compile process num {self.default_num}") + else: + res = "TBEException",\ + "ERROR: [MS_BUILD_PROCESS_NUM] should be in range(1, 25), but got : " + str(process_num) + elif not process_num.isdigit(): + res = "TBEException", "ERROR: [MS_BUILD_PROCESS_NUM] type should be a int num, but got :" + process_num + return res + def exit(self): if self.__pool is not None: self.__pool.terminate() @@ -149,6 +154,8 @@ class TbeProcess: Returns: int, task id(>0). -1 if error """ + if self.__processe_num > self.default_num: + self.__processe_num = self.default_num task_id = self.__next_task_id self.__next_task_id = self.__next_task_id + 1 if self.__pool is None: diff --git a/mindspore/_extends/remote/kernel_build_server_ascend.py b/mindspore/_extends/remote/kernel_build_server_ascend.py index e77beda2f3..60f6421ac6 100644 --- a/mindspore/_extends/remote/kernel_build_server_ascend.py +++ b/mindspore/_extends/remote/kernel_build_server_ascend.py @@ -24,6 +24,9 @@ class TbeBuilder: def __init__(self): self.tbe_builder = create_tbe_parallel_process() + def create(self): + return self.tbe_builder.init_process_num() + def start(self, json): return self.tbe_builder.start_compile_op(json) @@ -69,7 +72,10 @@ class AscendMessager(Messager): Reference protocol between them at PR#3821 and PR#3935 """ arg = self.get_message() - if arg == 'TBE/START': + if arg == 'TBE/PRE': + ans = self.tbe_builder.create() + self.send_res(ans) + elif arg == 'TBE/START': self.send_ack() json = self.get_message() res = self.tbe_builder.start(json) diff --git a/mindspore/ccsrc/backend/session/kernel_build_client.cc b/mindspore/ccsrc/backend/session/kernel_build_client.cc index db76b9f020..dbd06bfbf9 100644 --- a/mindspore/ccsrc/backend/session/kernel_build_client.cc +++ b/mindspore/ccsrc/backend/session/kernel_build_client.cc @@ -19,15 +19,30 @@ namespace mindspore { namespace kernel { +inline static bool init_flag = false; void ReplaceStr(std::string *dest, const std::string &replace, char new_char) { std::string::size_type start = 0; while ((start = (*dest).find(replace, start)) != std::string::npos) { (*dest).replace(start, replace.size(), 1, new_char); - start++; // Replaced 1 charactor. + start++; // Replaced 1 character. } } +bool AscendKernelBuildClient::TbePre() { + auto res = SendRequest(kTbePre); + if (res != kSuccess) { + MS_LOG(EXCEPTION) << "PRE failed, res: " << res; + } + return true; +} + int AscendKernelBuildClient::TbeStart(const std::string &json) { + if (!init_flag) { + if (!TbePre()) { + MS_LOG(EXCEPTION) << "START failed"; + } + init_flag = true; + } // Start compiling.. auto res = SendRequest(kTbeStart); if (res != kAck) { @@ -53,7 +68,7 @@ bool AscendKernelBuildClient::TbeWait(int *task_id, std::string *task_result, st } // Request task id. *task_id = std::stoi(SendRequest(kContinue)); - // Requst task result. + // Request task result. *task_result = SendRequest(kContinue); // Request prebuild result. *pre_build_result = SendRequest(kContinue); diff --git a/mindspore/ccsrc/backend/session/kernel_build_client.h b/mindspore/ccsrc/backend/session/kernel_build_client.h index 0dd6106dd3..e10932e6f4 100644 --- a/mindspore/ccsrc/backend/session/kernel_build_client.h +++ b/mindspore/ccsrc/backend/session/kernel_build_client.h @@ -193,6 +193,7 @@ class AscendKernelBuildClient : public KernelBuildClient { // Send building request to server constexpr inline static auto kContinue = "CONTINUE"; // More transactions to be continued + constexpr inline static auto kTbePre = "TBE/PRE"; constexpr inline static auto kTbeStart = "TBE/START"; constexpr inline static auto kTbeWait = "TBE/WAIT"; constexpr inline static auto kTbeReset = "TBE/RESET"; @@ -238,6 +239,7 @@ class AscendKernelBuildClient : public KernelBuildClient { AscendKernelBuildClient &operator=(AscendKernelBuildClient &&) = delete; private: + bool TbePre(); AscendKernelBuildClient() { Open(); } ~AscendKernelBuildClient() override { Close(); } };