From: @liubuyu Reviewed-by: @zhoufeng54, @jjfeing Signed-off-by: @jjfeing tags/v1.2.0-rc1
| @@ -106,23 +106,7 @@ class TbeProcess: | |||
| def __init__(self): | |||
| self.__processe_num = multiprocessing.cpu_count() | |||
| # max_processes_num: Set the maximum number of concurrent processes for compiler | |||
| self.max_processes_num = 24 | |||
| process_num = os.getenv("MS_BUILD_PROCESS_NUM") | |||
| if process_num is None: | |||
| self.max_processes_num = 24 | |||
| logger.info(f"Using default compile process num {self.max_processes_num}") | |||
| elif process_num.isdigit(): | |||
| if int(process_num) in range(1, 25): | |||
| self.max_processes_num = int(process_num) | |||
| logger.info(f"Using custom compile process num {self.max_processes_num}") | |||
| else: | |||
| raise EnvironmentError( | |||
| f"Env ERROR, [MS_BUILD_PROCESS_NUM] should be in range(1, 25), but: {process_num}") | |||
| elif not process_num.isdigit(): | |||
| raise EnvironmentError(f"Env ERROR, [MS_BUILD_PROCESS_NUM] should be a digit, but: {process_num}") | |||
| if self.__processe_num > self.max_processes_num: | |||
| self.__processe_num = self.max_processes_num | |||
| self.default_num = 24 | |||
| self.__pool = None | |||
| self.__next_task_id = 1 | |||
| self.__running_tasks = [] | |||
| @@ -133,6 +117,27 @@ class TbeProcess: | |||
| self.__pool.join() | |||
| del self.__pool | |||
def init_process_num(self):
    """
    Initialise the compile process number from the MS_BUILD_PROCESS_NUM environment variable.

    When the variable is unset, self.default_num is left untouched. When it holds a
    decimal integer in range(1, 25), self.default_num is overwritten with that value.
    Any other value leaves self.default_num untouched and is reported through the
    return value rather than by raising.

    :return: the str "Success" when the variable is unset or valid; otherwise the
             tuple ("TBEException", <error message>) describing the rejected value.
    """
    process_num = os.getenv("MS_BUILD_PROCESS_NUM")
    if process_num is None:
        logger.info(f"Using default compile process num {self.default_num}")
        return "Success"

    # isdecimal() rather than isdigit(): isdigit() also accepts characters such as
    # superscript digits ("\u00b2") that int() rejects with ValueError, which would
    # escape this method instead of being reported as an error result.
    if not process_num.isdecimal():
        return "TBEException", "ERROR: [MS_BUILD_PROCESS_NUM] type should be a int num, but got :" + process_num

    requested_num = int(process_num)
    if requested_num not in range(1, 25):
        return "TBEException", \
            "ERROR: [MS_BUILD_PROCESS_NUM] should be in range(1, 25), but got : " + str(process_num)

    self.default_num = requested_num
    logger.info(f"Using custom compile process num {self.default_num}")
    return "Success"
| def exit(self): | |||
| if self.__pool is not None: | |||
| self.__pool.terminate() | |||
| @@ -149,6 +154,8 @@ class TbeProcess: | |||
| Returns: | |||
| int, task id(>0). -1 if error | |||
| """ | |||
| if self.__processe_num > self.default_num: | |||
| self.__processe_num = self.default_num | |||
| task_id = self.__next_task_id | |||
| self.__next_task_id = self.__next_task_id + 1 | |||
| if self.__pool is None: | |||
| @@ -24,6 +24,9 @@ class TbeBuilder: | |||
def __init__(self):
    """Store the object returned by create_tbe_parallel_process() as self.tbe_builder."""
    self.tbe_builder = create_tbe_parallel_process()
def create(self):
    """
    Run init_process_num() on the wrapped self.tbe_builder object.

    :return: the value returned by self.tbe_builder.init_process_num(), unchanged.
    """
    init_result = self.tbe_builder.init_process_num()
    return init_result
| def start(self, json): | |||
| return self.tbe_builder.start_compile_op(json) | |||
| @@ -69,7 +72,10 @@ class AscendMessager(Messager): | |||
| Reference protocol between them at PR#3821 and PR#3935 | |||
| """ | |||
| arg = self.get_message() | |||
| if arg == 'TBE/START': | |||
| if arg == 'TBE/PRE': | |||
| ans = self.tbe_builder.create() | |||
| self.send_res(ans) | |||
| elif arg == 'TBE/START': | |||
| self.send_ack() | |||
| json = self.get_message() | |||
| res = self.tbe_builder.start(json) | |||
| @@ -19,15 +19,30 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| inline static bool init_flag = false; | |||
// Replaces, in place, every occurrence of `replace` inside `*dest` with the single
// character `new_char`. Matches are searched left to right and never overlap the
// character that was just written.
//
// dest:     string to modify; a null pointer is ignored.
// replace:  substring to look for; an empty pattern is ignored.
// new_char: character written in place of each match.
void ReplaceStr(std::string *dest, const std::string &replace, char new_char) {
  // An empty pattern matches at every position, so the loop below would keep
  // inserting characters and never terminate.
  if (dest == nullptr || replace.empty()) {
    return;
  }
  std::string::size_type pos = 0;
  while ((pos = dest->find(replace, pos)) != std::string::npos) {
    dest->replace(pos, replace.size(), 1, new_char);
    // Exactly one character was written, so advance by one; advancing further
    // would skip a match that starts immediately after it.
    ++pos;
  }
}
| bool AscendKernelBuildClient::TbePre() { | |||
| auto res = SendRequest(kTbePre); | |||
| if (res != kSuccess) { | |||
| MS_LOG(EXCEPTION) << "PRE failed, res: " << res; | |||
| } | |||
| return true; | |||
| } | |||
| int AscendKernelBuildClient::TbeStart(const std::string &json) { | |||
| if (!init_flag) { | |||
| if (!TbePre()) { | |||
| MS_LOG(EXCEPTION) << "START failed"; | |||
| } | |||
| init_flag = true; | |||
| } | |||
| // Start compiling.. | |||
| auto res = SendRequest(kTbeStart); | |||
| if (res != kAck) { | |||
| @@ -53,7 +68,7 @@ bool AscendKernelBuildClient::TbeWait(int *task_id, std::string *task_result, st | |||
| } | |||
| // Request task id. | |||
| *task_id = std::stoi(SendRequest(kContinue)); | |||
| // Requst task result. | |||
| // Request task result. | |||
| *task_result = SendRequest(kContinue); | |||
| // Request prebuild result. | |||
| *pre_build_result = SendRequest(kContinue); | |||
| @@ -193,6 +193,7 @@ class AscendKernelBuildClient : public KernelBuildClient { | |||
| // Send building request to server | |||
| constexpr inline static auto kContinue = "CONTINUE"; // More transactions to be continued | |||
| constexpr inline static auto kTbePre = "TBE/PRE"; | |||
| constexpr inline static auto kTbeStart = "TBE/START"; | |||
| constexpr inline static auto kTbeWait = "TBE/WAIT"; | |||
| constexpr inline static auto kTbeReset = "TBE/RESET"; | |||
| @@ -238,6 +239,7 @@ class AscendKernelBuildClient : public KernelBuildClient { | |||
| AscendKernelBuildClient &operator=(AscendKernelBuildClient &&) = delete; | |||
| private: | |||
| bool TbePre(); | |||
| AscendKernelBuildClient() { Open(); } | |||
| ~AscendKernelBuildClient() override { Close(); } | |||
| }; | |||