diff --git a/mindspore/_extends/remote/kernel_build_server.py b/mindspore/_extends/remote/kernel_build_server.py index c3c07beb4f..48ad7b3e96 100644 --- a/mindspore/_extends/remote/kernel_build_server.py +++ b/mindspore/_extends/remote/kernel_build_server.py @@ -14,52 +14,21 @@ # ============================================================================ """kernel build server""" import os -import sys import time -from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_process, op_select_format, check_supported -from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process - -class TbeBuilder: - """Tbe building wrapper""" - - def __init__(self): - self.tbe_builder = create_tbe_parallel_process() - - def start(self, json): - return self.tbe_builder.start_compile_op(json) - - def wait(self): - return self.tbe_builder.wait_one() - - def reset(self): - self.tbe_builder.reset_task_info() - - def exit(self): - self.tbe_builder.exit() - -class AkgBuilder: - """Akg building wrapper""" - - def __init__(self): - pass - - def create(self, process_num, waitime): - self.akg_builder = create_akg_parallel_process(process_num, waitime) - - def accept_json(self, json): - return self.akg_builder.accept_json(json) - - def compile(self): - return self.akg_builder.compile() class Messager: '''Messager''' - def __init__(self): - logger.info('[TRACE]', 'Messager init...') + def __init__(self, fdin, fdout): + self.fdin = fdin + self.fdout = fdout + self.fin = os.fdopen(fdin, "r") + self.fout = os.fdopen(fdout, "w") self.message = '' - self.tbe_builder = TbeBuilder() - self.akg_builder = AkgBuilder() + + def __del__(self): + os.close(self.fdin) + os.close(self.fdout) def get_message(self): """ @@ -72,7 +41,7 @@ class Messager: # Not read by input() anymore res = self.fin.readline() if not res: - logger.info('[TRACE]', "read ") + logger.debug('[TRACE]', "read nothing...") self.exit() if res[len(res) - 1] == '\n': res = res[0:len(res)-1] @@ -82,7 +51,7 @@ class Messager: self.exit() finally: pass - if self.message == '' or self.message == 'FIN': + if self.message == '' or self.message == 'FINISH': self.send_ack() self.exit() return self.message @@ -123,76 +92,6 @@ class Messager: else: self.send_res('ERR') - def handle(self): - """ - Communicate with remote - """ - arg = self.get_message() - if arg == 'TBE/START': - self.send_ack() - json = self.get_message() - res = self.tbe_builder.start(json) - self.send_res(res) - elif arg == 'TBE/WAIT': - self.send_ack() - task_id, res, pre = self.tbe_builder.wait() - logger.debug('[TRACE]', str(task_id) + '/' + str(res) + '/' + str(pre)) - if self.get_message() != 'CONT': - self.send_ack(False) - self.exit() - self.send_res(task_id) - if self.get_message() != 'CONT': - self.send_ack(False) - self.exit() - self.send_res(res) - if self.get_message() != 'CONT': - self.send_ack(False) - self.exit() - self.send_res(pre) - elif arg == 'TBE/RESET': - self.tbe_builder.reset() - self.send_ack() - elif arg == 'AKG/START': - self.send_ack() - process_num_str = self.get_message() - self.send_ack() - wait_time_str = self.get_message() - self.akg_builder.create(int(process_num_str), int(wait_time_str)) - self.send_ack() - elif arg == 'AKG/DATA': - self.send_ack() - while True: - req = self.get_message() - if req.startswith('{'): - self.akg_builder.accept_json(req) - self.send_ack() - elif req == 'AKG/WAIT': - res = self.akg_builder.compile() - self.send_res(res) - break - else: - self.send_ack(False) - break - elif arg == 'FORMAT': - self.send_ack() - json = self.get_message() - self.send_res(op_select_format(json)) - elif arg == 'SUPPORT': - self.send_ack() - json = self.get_message() - logger.debug('[SUPPORT]', json) - try: - res = check_supported(json) - except json.decoder.JSONDecodeError: - self.send_ack(False) - self.exit() - finally: - pass - self.send_res(res) - else: - self.send_ack(False) - self.exit() - def loop(self): """ Messaging loop @@ -200,20 +99,26 @@ class Messager: while True: self.handle() + def run(self): + self.loop() + + def handle(self): + """ + A interface communicates with remote. + + Note: + All subclasses should override this interface. + """ + raise NotImplementedError + def exit(self): - os.close(self.fdin) - os.close(self.fdout) - self.tbe_builder.reset() - self.tbe_builder.exit() - logger.info('[TRACE]', 'Messager Exit...') - exit() + """ + A interface handles the procedure before exit. - def run(self, fdin, fdout): - self.fdin = fdin - self.fdout = fdout - self.fin = os.fdopen(fdin, "r") - self.fout = os.fdopen(fdout, "w") - self.loop() + Note: + All subclasses should override this interface. + """ + raise NotImplementedError class Logger: """ @@ -265,9 +170,5 @@ class DummyLogger: logger = DummyLogger() -if __name__ == '__main__': - if len(sys.argv) != 3: - raise Exception('Incorrect argv: {}'.format(sys.argv)) - logger.debug('[TRACE]', 'argv: ' + str(sys.argv)) - messager = Messager() - messager.run(int(sys.argv[1]), int(sys.argv[2])) +def get_logger(): + return logger diff --git a/mindspore/_extends/remote/kernel_build_server_ascend.py b/mindspore/_extends/remote/kernel_build_server_ascend.py new file mode 100644 index 0000000000..e77beda2f3 --- /dev/null +++ b/mindspore/_extends/remote/kernel_build_server_ascend.py @@ -0,0 +1,148 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""kernel build server for ascend""" +import sys +from mindspore._extends.remote.kernel_build_server import Messager, get_logger +from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_process, op_select_format, check_supported +from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process + +class TbeBuilder: + """Tbe building wrapper""" + + def __init__(self): + self.tbe_builder = create_tbe_parallel_process() + + def start(self, json): + return self.tbe_builder.start_compile_op(json) + + def wait(self): + return self.tbe_builder.wait_one() + + def reset(self): + self.tbe_builder.reset_task_info() + + def exit(self): + self.tbe_builder.exit() + +class AkgBuilder: + """Akg building wrapper""" + + def __init__(self): + pass + + def create(self, process_num, waitime): + self.akg_builder = create_akg_parallel_process(process_num, waitime) + + def accept_json(self, json): + return self.akg_builder.accept_json(json) + + def compile(self): + return self.akg_builder.compile() + +class AscendMessager(Messager): + ''' + Ascend Messager + It works as a server, communicating with c++ client. + ''' + + def __init__(self, fdin, fdout): + super().__init__(fdin, fdout) + get_logger().info('[TRACE]', 'Ascend Messager init...') + self.tbe_builder = TbeBuilder() + self.akg_builder = AkgBuilder() + + def handle(self): + """ + Communicate with remote client. + Reference protocol between them at PR#3821 and PR#3935 + """ + arg = self.get_message() + if arg == 'TBE/START': + self.send_ack() + json = self.get_message() + res = self.tbe_builder.start(json) + self.send_res(res) + elif arg == 'TBE/WAIT': + self.send_ack() + task_id, res, pre = self.tbe_builder.wait() + get_logger().debug('[TRACE]', str(task_id) + '/' + str(res) + '/' + str(pre)) + if self.get_message() != 'CONTINUE': + self.send_ack(False) + self.exit() + self.send_res(task_id) + if self.get_message() != 'CONTINUE': + self.send_ack(False) + self.exit() + self.send_res(res) + if self.get_message() != 'CONTINUE': + self.send_ack(False) + self.exit() + self.send_res(pre) + elif arg == 'TBE/RESET': + self.tbe_builder.reset() + self.send_ack() + elif arg == 'AKG/START': + self.send_ack() + process_num_str = self.get_message() + self.send_ack() + wait_time_str = self.get_message() + self.akg_builder.create(int(process_num_str), int(wait_time_str)) + self.send_ack() + elif arg == 'AKG/DATA': + self.send_ack() + while True: + req = self.get_message() + if req.startswith('{'): + self.akg_builder.accept_json(req) + self.send_ack() + elif req == 'AKG/WAIT': + res = self.akg_builder.compile() + self.send_res(res) + break + else: + self.send_ack(False) + break + elif arg == 'FORMAT': + self.send_ack() + json = self.get_message() + self.send_res(op_select_format(json)) + elif arg == 'SUPPORT': + self.send_ack() + json = self.get_message() + get_logger().debug('[SUPPORT]', json) + try: + res = check_supported(json) + except json.decoder.JSONDecodeError: + self.send_ack(False) + self.exit() + finally: + pass + self.send_res(res) + else: + self.send_ack(False) + self.exit() + + def exit(self): + self.tbe_builder.reset() + self.tbe_builder.exit() + get_logger().info('[TRACE]', 'Ascend Messager Exit...') + exit() + +if __name__ == '__main__': + if len(sys.argv) != 3: + raise Exception('Incorrect argv: {}'.format(sys.argv)) + get_logger().debug('[TRACE]', 'argv: ' + str(sys.argv)) + messager = AscendMessager(int(sys.argv[1]), int(sys.argv[2])) + messager.run() diff --git a/mindspore/_extends/remote/kernel_build_server_gpu.py b/mindspore/_extends/remote/kernel_build_server_gpu.py new file mode 100644 index 0000000000..8bdf5805af --- /dev/null +++ b/mindspore/_extends/remote/kernel_build_server_gpu.py @@ -0,0 +1,63 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""kernel build server for gpu""" +import os +import sys +from mindspore._extends.remote.kernel_build_server import Messager, get_logger +from mindspore._extends.parallel_compile.akg_compiler.compiler import run_compiler as akg_compile_single + +class GpuMessager(Messager): + ''' + GPU Messager + It works as a server, communicating with c++ client. + ''' + + def __init__(self, fdin, fdout): + super().__init__(fdin, fdout) + get_logger().info('[TRACE]', 'GPU Messager init...') + + def handle(self): + """ + Communicate with remote client. + Reference protocol between them at PR#4063 + """ + arg = self.get_message() + if arg == 'AKG/PID': + self.send_res(os.getpid()) + elif arg == 'AKG/COMPILE': + self.send_ack() + json = self.get_message() + try: + akg_compile_single(json) + except ValueError: + self.send_ack(False) + self.exit() + finally: + pass + self.send_ack() + else: + self.send_ack(False) + self.exit() + + def exit(self): + get_logger().info('[TRACE]', 'GPU Messager Exit...') + exit() + +if __name__ == '__main__': + if len(sys.argv) != 3: + raise Exception('Incorrect argv: {}'.format(sys.argv)) + get_logger().debug('[TRACE]', 'argv: ' + str(sys.argv)) + messager = GpuMessager(int(sys.argv[1]), int(sys.argv[2])) + messager.run() diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc index 033dfcf3da..1dad3d4e57 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc @@ -15,7 +15,6 @@ */ #include "backend/kernel_compiler/akg/akg_kernel_build.h" -#include #include #include #include @@ -37,13 +36,10 @@ #include "utils/utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h" +#include "backend/session/kernel_build_client.h" namespace mindspore { namespace kernel { -constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200; -constexpr int32_t ARGS_SIZE = 1; -constexpr auto kCompileWithJsonFunc = "compilewithjson"; - // json key constexpr auto kOpDesc = "op_desc"; constexpr auto kInputDesc = "input_desc"; @@ -70,25 +66,6 @@ std::string Vector2Str(const std::vector &inputs) { } } // namespace -std::string AkgKernelBuild::PyObjectToStr(PyObject *const PyObj) { - char *pChar = nullptr; - std::string str_res; - if (PyObj == nullptr) { - MS_LOG(ERROR) << "Input parameter is nullptr."; - return str_res; - } - PyObject *strArgs = PyObject_Str(PyObj); - if (strArgs != nullptr) { - (void)PyArg_Parse(strArgs, "s", &pChar); - } - if (pChar == nullptr) { - MS_LOG(ERROR) << "pChar is nullptr."; - return str_res; - } - str_res = pChar; - return str_res; -} - std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, const std::pair &position) { if (node_json.count(tag) == 0) { @@ -528,32 +505,11 @@ KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNod return cached_kernel_pack; } - PyObject *pModule = nullptr; - PyObject *pFunc = nullptr; - PyObject *pArg = nullptr; - PyObject *pRes = nullptr; - - pModule = PyImport_ImportModule(kAkgModule); - if (pModule == nullptr) { - MS_LOG(ERROR) << "Failed to import [" << kAkgModule << "]."; - return nullptr; - } - - pFunc = PyObject_GetAttrString(pModule, kCompileWithJsonFunc); - pArg = PyTuple_New(ARGS_SIZE); - (void)PyTuple_SetItem(pArg, 0, Py_BuildValue("s", node_json.c_str())); - (void)alarm(AUTODIFF_COMPILE_OVERTIME); - pRes = PyEval_CallObject(pFunc, pArg); + auto res = GpuKernelBuildClient::Instance().AkgCompileSingle(node_json); (void)alarm(0); - if (pRes == nullptr) { - MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(pArg) << ")."; - return nullptr; - } - if (PyObject_IsTrue(pRes) != 1) { - MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(pArg) << ")."; + if (!res) { + MS_LOG(ERROR) << "Akg compile failed, json: " << node_json; return nullptr; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc index 8a7c02e790..57217c66b4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc @@ -324,15 +324,15 @@ bool AkgOpParallelBuild(const std::vector fusion_type_maps = { {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE}, }; -void KernelMeta::Initialize() { - kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/"; +void KernelMeta::Initialize(int pid) { + if (pid == -1) { + kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/"; + } else { + kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(pid) + "/"; + } // remove old kernel cache RemoveKernelCache(); diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h index a59b1bf387..d9ebfe3c4c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h @@ -40,7 +40,6 @@ constexpr auto kProcessorCuda = "cuda"; constexpr auto kJsonSuffix = ".json"; constexpr auto kInfoSuffix = ".info"; constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600; -constexpr auto kAkgModule = "akg.ms"; constexpr auto kArgDataformat = "data_format"; const std::vector support_devices = {"aicore", "aicpu", "cuda"}; @@ -54,7 +53,7 @@ using KernelMetaPtr = std::shared_ptr; class KernelMeta { public: KernelMeta() = default; - void Initialize(); + void Initialize(int pid); void RemoveKernelCache(); std::string Search(const std::string &kernel_name) const; bool Insert(const std::string &kernel_name, const std::string &kernel_json); diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc index 4c01167b0e..7a625268d3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc @@ -272,12 +272,12 @@ KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const s } int ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) { - return KernelBuildClient::Instance().TbeStart(kernel_json.dump()); + return AscendKernelBuildClient::Instance().TbeStart(kernel_json.dump()); } bool ParallelBuildManager::WaitOne(int *task_id, std::string *task_result, std::string *pre_build_result) { MS_EXCEPTION_IF_NULL(task_id); - return KernelBuildClient::Instance().TbeWait(task_id, task_result, pre_build_result); + return AscendKernelBuildClient::Instance().TbeWait(task_id, task_result, pre_build_result); } void ParallelBuildManager::ResetTaskInfo() { @@ -287,7 +287,7 @@ void ParallelBuildManager::ResetTaskInfo() { } task_map_.clear(); same_op_list_.clear(); - KernelBuildClient::Instance().TbeReset(); + AscendKernelBuildClient::Instance().TbeReset(); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc index 2c9bf68181..5635811425 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc @@ -312,7 +312,7 @@ bool TbeKernelSelect::TbeCheckSupported( if (!ret) { MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed."; } - ret = KernelBuildClient::Instance().CheckSupported(kernel_json.dump()); + ret = AscendKernelBuildClient::Instance().CheckSupported(kernel_json.dump()); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get()); return ret; } @@ -486,7 +486,7 @@ std::string TbeKernelSelect::OpSelectFormat() { if (!ret) { MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed."; } - res_json_str = KernelBuildClient::Instance().SelectFormat(kernel_json.dump()); + res_json_str = AscendKernelBuildClient::Instance().SelectFormat(kernel_json.dump()); if (res_json_str.empty()) { MS_LOG(EXCEPTION) << "op select format error."; } diff --git a/mindspore/ccsrc/backend/session/kernel_build_client.cc b/mindspore/ccsrc/backend/session/kernel_build_client.cc index 847d24096a..e12b6de9bf 100644 --- a/mindspore/ccsrc/backend/session/kernel_build_client.cc +++ b/mindspore/ccsrc/backend/session/kernel_build_client.cc @@ -29,7 +29,7 @@ void ReplaceStr(std::string *dest, const std::string &replace, char new_char) { } } -int KernelBuildClient::TbeStart(const std::string &json) { +int AscendKernelBuildClient::TbeStart(const std::string &json) { // Start compiling.. auto res = SendRequest(kTbeStart); if (res != kAck) { @@ -46,7 +46,7 @@ int KernelBuildClient::TbeStart(const std::string &json) { return std::stoi(res); } -bool KernelBuildClient::TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result) { +bool AscendKernelBuildClient::TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result) { // Start waiting.. auto res = SendRequest(kTbeWait); if (res != kAck) { @@ -54,15 +54,15 @@ bool KernelBuildClient::TbeWait(int *task_id, std::string *task_result, std::str return false; } // Request task id. - *task_id = std::stoi(SendRequest(kCont)); + *task_id = std::stoi(SendRequest(kContinue)); // Requst task result. - *task_result = SendRequest(kCont); + *task_result = SendRequest(kContinue); // Request prebuild result. - *pre_build_result = SendRequest(kCont); + *pre_build_result = SendRequest(kContinue); return true; } -void KernelBuildClient::TbeReset() { +void AscendKernelBuildClient::TbeReset() { // Start compiling.. auto res = SendRequest(kTbeReset); if (res != kAck) { @@ -70,7 +70,7 @@ void KernelBuildClient::TbeReset() { } } -bool KernelBuildClient::AkgStart(int process_num, int wait_time) { +bool AscendKernelBuildClient::AkgStart(int process_num, int wait_time) { // Start compiling.. auto res = SendRequest(kAkgStart); if (res != kAck) { @@ -92,7 +92,7 @@ bool KernelBuildClient::AkgStart(int process_num, int wait_time) { return true; } -bool KernelBuildClient::AkgSendData(const std::vector &jsons) { +bool AscendKernelBuildClient::AkgSendData(const std::vector &jsons) { auto res = SendRequest(kAkgData); if (res != kAck) { MS_LOG(ERROR) << "AKG/DATA failed, res: " << res; @@ -109,7 +109,7 @@ bool KernelBuildClient::AkgSendData(const std::vector &jsons) { } // Fetch the result of AKG compiling. -bool KernelBuildClient::AkgWait() { +bool AscendKernelBuildClient::AkgWait() { auto res = SendRequest(kAkgWait); if (res != kTrue) { MS_LOG(ERROR) << "AKG/WAIT failed, res: " << res; @@ -118,7 +118,7 @@ bool KernelBuildClient::AkgWait() { return true; } -std::string KernelBuildClient::SelectFormat(const std::string &json) { +std::string AscendKernelBuildClient::SelectFormat(const std::string &json) { // Start compiling.. auto res = SendRequest(kFormat); if (res != kAck) { @@ -134,7 +134,7 @@ std::string KernelBuildClient::SelectFormat(const std::string &json) { return res; } -bool KernelBuildClient::CheckSupported(const std::string &json) { +bool AscendKernelBuildClient::CheckSupported(const std::string &json) { // Checking support.. auto res = SendRequest(kSupport); if (res != kAck) { @@ -149,5 +149,29 @@ bool KernelBuildClient::CheckSupported(const std::string &json) { } return true; } + +int GpuKernelBuildClient::AkgGetPid() { + auto res = SendRequest(kAkgPid); + if (res == kErr) { + MS_LOG(ERROR) << "AKG/PID failed, res: " << res; + return -1; + } + return std::stoi(res); +} + +bool GpuKernelBuildClient::AkgCompileSingle(const std::string json) { + auto res = SendRequest(kAkgCompileOp); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/COMPILE failed, res: " << res; + return false; + } + // Send single json data. + res = SendRequest(json); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/COMPILE responds failed, res: " << res; + return false; + } + return true; +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/kernel_build_client.h b/mindspore/ccsrc/backend/session/kernel_build_client.h index 817e285197..f442c8846d 100644 --- a/mindspore/ccsrc/backend/session/kernel_build_client.h +++ b/mindspore/ccsrc/backend/session/kernel_build_client.h @@ -29,97 +29,37 @@ namespace mindspore { namespace kernel { void ReplaceStr(std::string *dest, const std::string &replace, char new_char); + +constexpr inline static int kBufferSize = 4096; +// The TAG as prefix of real command from remote. +constexpr inline static auto kTag = "[~]"; + class KernelBuildClient { public: - // Server configure - constexpr inline static auto kEnv = "python"; - constexpr inline static auto kGetPathScript = - "-c " - "\"" - "import pkgutil;" - "path = pkgutil" - ".get_loader(\\\"mindspore._extends.remote.kernel_build_server\\\")" // Server module name - ".get_filename();" - "print('[~]' + path)" - "\""; - + // Send Finish request to server + constexpr inline static auto kFinish = "FINISH"; // Receive the response from server constexpr inline static auto kAck = "ACK"; constexpr inline static auto kErr = "ERR"; - constexpr inline static auto kFailed = "-1"; - // Send Finish request to server - constexpr inline static auto kFin = "FIN"; - - // Send building request to server - constexpr inline static auto kTbeStart = "TBE/START"; - constexpr inline static auto kTbeWait = "TBE/WAIT"; - constexpr inline static auto kCont = "CONT"; - constexpr inline static auto kSuccess = "Success"; - constexpr inline static auto kTbeReset = "TBE/RESET"; - constexpr inline static auto kAkgStart = "AKG/START"; - constexpr inline static auto kAkgData = "AKG/DATA"; - constexpr inline static auto kAkgWait = "AKG/WAIT"; - - // Send server info. query to server - constexpr inline static auto kFormat = "FORMAT"; - constexpr inline static auto kSupport = "SUPPORT"; constexpr inline static auto kTrue = "True"; + constexpr inline static auto kSuccess = "Success"; // Revert \n, \r, [space]. constexpr inline static auto kLF = "[LF]"; constexpr inline static auto kCR = "[CR]"; constexpr inline static auto kSP = "[SP]"; - // The TAG as prefix of real command from remote. - constexpr inline static auto kTag = "[~]"; - - constexpr inline static int kBufferSize = 4096; constexpr inline static unsigned int kTimeOutSeconds = 350; - static KernelBuildClient &Instance() { - static KernelBuildClient instance; - return instance; - } - std::string GetScriptPath() { - std::string cmd = kEnv; - (void)cmd.append(1, ' ').append(kGetPathScript); - FILE *fpipe = popen(cmd.c_str(), "r"); - if (fpipe == nullptr) { - MS_LOG(EXCEPTION) << "popen failed, " << strerror(errno) << "(" << errno << ")"; - } - bool start = false; - std::string result; - char buf[kBufferSize]; - while (std::fgets(buf, sizeof(buf), fpipe) != nullptr) { - if (std::strncmp(buf, kTag, std::strlen(kTag)) == 0) { - start = true; - } - // Filter with 'kTAG' and '\n' - if (start) { - auto size = std::strlen(buf); - bool line_end = buf[size - 1] == '\n'; - result.append(buf, line_end ? size - 1 : size); - if (line_end) { - break; - } - } - } - pclose(fpipe); - const std::string py_suffix = ".py"; - if (result.empty() || result.rfind(py_suffix) != (result.length() - py_suffix.length())) { - MS_LOG(EXCEPTION) << "py file seems incorrect, result: {" << result << "}"; - } - result = result.substr(strlen(kTag)); - MS_LOG(DEBUG) << "result: " << result; - return result; - } + virtual std::string GetEnv() = 0; + virtual std::string GetScript() = 0; void Open() { if (!init_) { // Exception's thrown if open failed - if (dp_->Open({kEnv, GetScriptPath()}, true) != -1) { + if (dp_->Open({GetEnv(), GetScript()}, true) != -1) { dp_->SetTimeOutSeconds(kTimeOutSeconds); - dp_->SetTimeOutCallback([this]() { SendRequest(kFin); }); + dp_->SetTimeOutCallback([this]() { SendRequest(kFinish); }); init_ = true; } } @@ -164,6 +104,88 @@ class KernelBuildClient { return res; } + protected: + KernelBuildClient() : init_(false), dp_(std::make_shared()) {} + virtual ~KernelBuildClient() = default; + + private: + bool init_; + std::shared_ptr dp_; +}; + +static inline std::string GetScriptFilePath(const std::string cmd_env, const std::string &cmd_script) { + std::string cmd = cmd_env; + (void)cmd.append(1, ' ').append(cmd_script); + FILE *fpipe = popen(cmd.c_str(), "r"); + if (fpipe == nullptr) { + MS_LOG(EXCEPTION) << "popen failed, " << strerror(errno) << "(" << errno << ")"; + } + bool start = false; + std::string result; + char buf[kBufferSize]; + while (std::fgets(buf, sizeof(buf), fpipe) != nullptr) { + if (std::strncmp(buf, kTag, std::strlen(kTag)) == 0) { + start = true; + } + // Filter with 'kTAG' and '\n' + if (start) { + auto size = std::strlen(buf); + bool line_end = buf[size - 1] == '\n'; + result.append(buf, line_end ? size - 1 : size); + if (line_end) { + break; + } + } + } + pclose(fpipe); + const std::string py_suffix = ".py"; + if (result.empty() || result.rfind(py_suffix) != (result.length() - py_suffix.length())) { + MS_LOG(EXCEPTION) << "py file seems incorrect, result: {" << result << "}"; + } + result = result.substr(strlen(kTag)); + MS_LOG(DEBUG) << "result: " << result; + return result; +} + +class AscendKernelBuildClient : public KernelBuildClient { + public: + // Server configure + constexpr inline static auto kEnv = "python"; + constexpr inline static auto kGetPathScript = + "-c " + "\"" + "import pkgutil;" + "path = pkgutil" + ".get_loader(\\\"mindspore._extends.remote.kernel_build_server_ascend\\\")" // Server module name + ".get_filename();" + "print('[~]' + path)" + "\""; + + // Receive the response from server + constexpr inline static auto kFailed = "-1"; + + // Send building request to server + constexpr inline static auto kContinue = "CONTINUE"; // More transactions to be continued + constexpr inline static auto kTbeStart = "TBE/START"; + constexpr inline static auto kTbeWait = "TBE/WAIT"; + constexpr inline static auto kTbeReset = "TBE/RESET"; + constexpr inline static auto kAkgStart = "AKG/START"; + constexpr inline static auto kAkgData = "AKG/DATA"; + constexpr inline static auto kAkgWait = "AKG/WAIT"; + + // Send server info. query to server + constexpr inline static auto kFormat = "FORMAT"; + constexpr inline static auto kSupport = "SUPPORT"; + + static AscendKernelBuildClient &Instance() { + static AscendKernelBuildClient instance; + return instance; + } + + std::string GetEnv() override { return kEnv; } + + std::string GetScript() override { return GetScriptFilePath(kEnv, kGetPathScript); } + // Before building. std::string SelectFormat(const std::string &json); bool CheckSupported(const std::string &json); @@ -177,19 +199,60 @@ class KernelBuildClient { bool AkgStart(int process_num, int wait_time); bool AkgSendData(const std::vector &jsons); bool AkgWait(); + bool AkgCompileSingle(const std::string json); - KernelBuildClient(const KernelBuildClient &) = delete; - KernelBuildClient &operator=(const KernelBuildClient &) = delete; + AscendKernelBuildClient(const AscendKernelBuildClient &) = delete; + AscendKernelBuildClient &operator=(const AscendKernelBuildClient &) = delete; - KernelBuildClient(KernelBuildClient &&) = delete; - KernelBuildClient &operator=(KernelBuildClient &&) = delete; + AscendKernelBuildClient(AscendKernelBuildClient &&) = delete; + AscendKernelBuildClient &operator=(AscendKernelBuildClient &&) = delete; private: - KernelBuildClient() : init_(false), dp_(std::make_shared()) { Open(); } - ~KernelBuildClient() { Close(); } + AscendKernelBuildClient() { Open(); } + ~AscendKernelBuildClient() override { Close(); } +}; - bool init_; - std::shared_ptr dp_; +class GpuKernelBuildClient : public KernelBuildClient { + public: + // Server configure + constexpr inline static auto kEnv = "python"; + constexpr inline static auto kGetPathScript = + "-c " + "\"" + "import pkgutil;" + "path = pkgutil" + ".get_loader(\\\"mindspore._extends.remote.kernel_build_server_gpu\\\")" // Server module name + ".get_filename();" + "print('[~]' + path)" + "\""; + + // Send building request to server + constexpr inline static auto kAkgPid = "AKG/PID"; + constexpr inline static auto kAkgCompileOp = "AKG/COMPILE"; // Compile a single op + + static GpuKernelBuildClient &Instance() { + static GpuKernelBuildClient instance; + return instance; + } + + std::string GetEnv() override { return kEnv; } + + std::string GetScript() override { return GetScriptFilePath(kEnv, kGetPathScript); } + + // Fetch pid(pid_t) from remote. + int AkgGetPid(); + // Run AKG building. + bool AkgCompileSingle(const std::string json); + + GpuKernelBuildClient(const GpuKernelBuildClient &) = delete; + GpuKernelBuildClient &operator=(const GpuKernelBuildClient &) = delete; + + GpuKernelBuildClient(GpuKernelBuildClient &&) = delete; + GpuKernelBuildClient &operator=(GpuKernelBuildClient &&) = delete; + + private: + GpuKernelBuildClient() { Open(); } + ~GpuKernelBuildClient() override { Close(); } }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc index 9d88a205bc..e4b054cb5b 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc @@ -21,13 +21,16 @@ #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "frontend/operator/ops.h" #include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_build_client.h" + namespace mindspore { namespace device { namespace gpu { void GpuBuild(const KernelGraphPtr &kernel_graph) { kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance(); MS_EXCEPTION_IF_NULL(bin_map); - bin_map->Initialize(); + auto pid = mindspore::kernel::GpuKernelBuildClient::Instance().AkgGetPid(); + bin_map->Initialize(pid); MS_EXCEPTION_IF_NULL(kernel_graph); auto kernels = kernel_graph->execution_order(); for (const auto &kernel : kernels) {