From: @liubuyu Reviewed-by: @zhoufeng54,@jjfeing Signed-off-by: @jjfeing tags/v1.2.0-rc1
| @@ -106,23 +106,7 @@ class TbeProcess: | |||||
| def __init__(self): | def __init__(self): | ||||
| self.__processe_num = multiprocessing.cpu_count() | self.__processe_num = multiprocessing.cpu_count() | ||||
| # max_processes_num: Set the maximum number of concurrent processes for compiler | |||||
| self.max_processes_num = 24 | |||||
| process_num = os.getenv("MS_BUILD_PROCESS_NUM") | |||||
| if process_num is None: | |||||
| self.max_processes_num = 24 | |||||
| logger.info(f"Using default compile process num {self.max_processes_num}") | |||||
| elif process_num.isdigit(): | |||||
| if int(process_num) in range(1, 25): | |||||
| self.max_processes_num = int(process_num) | |||||
| logger.info(f"Using custom compile process num {self.max_processes_num}") | |||||
| else: | |||||
| raise EnvironmentError( | |||||
| f"Env ERROR, [MS_BUILD_PROCESS_NUM] should be in range(1, 25), but: {process_num}") | |||||
| elif not process_num.isdigit(): | |||||
| raise EnvironmentError(f"Env ERROR, [MS_BUILD_PROCESS_NUM] should be a digit, but: {process_num}") | |||||
| if self.__processe_num > self.max_processes_num: | |||||
| self.__processe_num = self.max_processes_num | |||||
| self.default_num = 24 | |||||
| self.__pool = None | self.__pool = None | ||||
| self.__next_task_id = 1 | self.__next_task_id = 1 | ||||
| self.__running_tasks = [] | self.__running_tasks = [] | ||||
| @@ -133,6 +117,27 @@ class TbeProcess: | |||||
| self.__pool.join() | self.__pool.join() | ||||
| del self.__pool | del self.__pool | ||||
| def init_process_num(self): | |||||
| """ | |||||
| init compile process num | |||||
| :return: str Success or other string info | |||||
| """ | |||||
| # max_processes_num: Set the maximum number of concurrent processes for compiler | |||||
| process_num = os.getenv("MS_BUILD_PROCESS_NUM") | |||||
| res = "Success" | |||||
| if process_num is None: | |||||
| logger.info(f"Using default compile process num {self.default_num}") | |||||
| elif process_num.isdigit(): | |||||
| if int(process_num) in range(1, 25): | |||||
| self.default_num = int(process_num) | |||||
| logger.info(f"Using custom compile process num {self.default_num}") | |||||
| else: | |||||
| res = "TBEException",\ | |||||
| "ERROR: [MS_BUILD_PROCESS_NUM] should be in range(1, 25), but got : " + str(process_num) | |||||
| elif not process_num.isdigit(): | |||||
| res = "TBEException", "ERROR: [MS_BUILD_PROCESS_NUM] type should be a int num, but got :" + process_num | |||||
| return res | |||||
| def exit(self): | def exit(self): | ||||
| if self.__pool is not None: | if self.__pool is not None: | ||||
| self.__pool.terminate() | self.__pool.terminate() | ||||
| @@ -149,6 +154,8 @@ class TbeProcess: | |||||
| Returns: | Returns: | ||||
| int, task id(>0). -1 if error | int, task id(>0). -1 if error | ||||
| """ | """ | ||||
| if self.__processe_num > self.default_num: | |||||
| self.__processe_num = self.default_num | |||||
| task_id = self.__next_task_id | task_id = self.__next_task_id | ||||
| self.__next_task_id = self.__next_task_id + 1 | self.__next_task_id = self.__next_task_id + 1 | ||||
| if self.__pool is None: | if self.__pool is None: | ||||
| @@ -24,6 +24,9 @@ class TbeBuilder: | |||||
| def __init__(self): | def __init__(self): | ||||
| self.tbe_builder = create_tbe_parallel_process() | self.tbe_builder = create_tbe_parallel_process() | ||||
| def create(self): | |||||
| return self.tbe_builder.init_process_num() | |||||
| def start(self, json): | def start(self, json): | ||||
| return self.tbe_builder.start_compile_op(json) | return self.tbe_builder.start_compile_op(json) | ||||
| @@ -69,7 +72,10 @@ class AscendMessager(Messager): | |||||
| Reference protocol between them at PR#3821 and PR#3935 | Reference protocol between them at PR#3821 and PR#3935 | ||||
| """ | """ | ||||
| arg = self.get_message() | arg = self.get_message() | ||||
| if arg == 'TBE/START': | |||||
| if arg == 'TBE/PRE': | |||||
| ans = self.tbe_builder.create() | |||||
| self.send_res(ans) | |||||
| elif arg == 'TBE/START': | |||||
| self.send_ack() | self.send_ack() | ||||
| json = self.get_message() | json = self.get_message() | ||||
| res = self.tbe_builder.start(json) | res = self.tbe_builder.start(json) | ||||
| @@ -19,15 +19,30 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| inline static bool init_flag = false; | |||||
| void ReplaceStr(std::string *dest, const std::string &replace, char new_char) { | void ReplaceStr(std::string *dest, const std::string &replace, char new_char) { | ||||
| std::string::size_type start = 0; | std::string::size_type start = 0; | ||||
| while ((start = (*dest).find(replace, start)) != std::string::npos) { | while ((start = (*dest).find(replace, start)) != std::string::npos) { | ||||
| (*dest).replace(start, replace.size(), 1, new_char); | (*dest).replace(start, replace.size(), 1, new_char); | ||||
| start++; // Replaced 1 charactor. | |||||
| start++; // Replaced 1 character. | |||||
| } | } | ||||
| } | } | ||||
| bool AscendKernelBuildClient::TbePre() { | |||||
| auto res = SendRequest(kTbePre); | |||||
| if (res != kSuccess) { | |||||
| MS_LOG(EXCEPTION) << "PRE failed, res: " << res; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| int AscendKernelBuildClient::TbeStart(const std::string &json) { | int AscendKernelBuildClient::TbeStart(const std::string &json) { | ||||
| if (!init_flag) { | |||||
| if (!TbePre()) { | |||||
| MS_LOG(EXCEPTION) << "START failed"; | |||||
| } | |||||
| init_flag = true; | |||||
| } | |||||
| // Start compiling.. | // Start compiling.. | ||||
| auto res = SendRequest(kTbeStart); | auto res = SendRequest(kTbeStart); | ||||
| if (res != kAck) { | if (res != kAck) { | ||||
| @@ -53,7 +68,7 @@ bool AscendKernelBuildClient::TbeWait(int *task_id, std::string *task_result, st | |||||
| } | } | ||||
| // Request task id. | // Request task id. | ||||
| *task_id = std::stoi(SendRequest(kContinue)); | *task_id = std::stoi(SendRequest(kContinue)); | ||||
| // Requst task result. | |||||
| // Request task result. | |||||
| *task_result = SendRequest(kContinue); | *task_result = SendRequest(kContinue); | ||||
| // Request prebuild result. | // Request prebuild result. | ||||
| *pre_build_result = SendRequest(kContinue); | *pre_build_result = SendRequest(kContinue); | ||||
| @@ -193,6 +193,7 @@ class AscendKernelBuildClient : public KernelBuildClient { | |||||
| // Send building request to server | // Send building request to server | ||||
| constexpr inline static auto kContinue = "CONTINUE"; // More transactions to be continued | constexpr inline static auto kContinue = "CONTINUE"; // More transactions to be continued | ||||
| constexpr inline static auto kTbePre = "TBE/PRE"; | |||||
| constexpr inline static auto kTbeStart = "TBE/START"; | constexpr inline static auto kTbeStart = "TBE/START"; | ||||
| constexpr inline static auto kTbeWait = "TBE/WAIT"; | constexpr inline static auto kTbeWait = "TBE/WAIT"; | ||||
| constexpr inline static auto kTbeReset = "TBE/RESET"; | constexpr inline static auto kTbeReset = "TBE/RESET"; | ||||
| @@ -238,6 +239,7 @@ class AscendKernelBuildClient : public KernelBuildClient { | |||||
| AscendKernelBuildClient &operator=(AscendKernelBuildClient &&) = delete; | AscendKernelBuildClient &operator=(AscendKernelBuildClient &&) = delete; | ||||
| private: | private: | ||||
| bool TbePre(); | |||||
| AscendKernelBuildClient() { Open(); } | AscendKernelBuildClient() { Open(); } | ||||
| ~AscendKernelBuildClient() override { Close(); } | ~AscendKernelBuildClient() override { Close(); } | ||||
| }; | }; | ||||