From: @xu-yfei Reviewed-by: Signed-off-by: tags/v1.1.0
| @@ -50,9 +50,15 @@ class MS_API Model { | |||
| Status GetInputsInfo(std::vector<Tensor> *tensor_list) const; | |||
| Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const; | |||
| static bool CheckModelSupport(const std::string& device_type, ModelType model_type); | |||
| private: | |||
| std::shared_ptr<ModelImpl> impl_; | |||
| }; | |||
| extern MS_API const char* kDeviceTypeAscendCL; | |||
| extern MS_API const char* kDeviceTypeAscendMS; | |||
| } // namespace api | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_INCLUDE_API_MODEL_H | |||
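The two exported device-type constants above are the keys callers pass to the public Model API. As a hedged sketch of probing backend availability (the kMindIR ModelType enumerator is assumed from include/api/types.h):

// Sketch only: ask the model factory whether the AscendMS backend was compiled in.
#include "include/api/model.h"

bool HasAscendMsBackend() {
  return mindspore::api::Model::CheckModelSupport(mindspore::api::kDeviceTypeAscendMS,
                                                  mindspore::api::ModelType::kMindIR);
}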
| @@ -213,30 +213,5 @@ std::string AscendInferenceSession::InputsInfo(const std::vector<ParameterPtr> & | |||
| return graph + " " + actual; | |||
| } | |||
| void AscendInferenceSession::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const { | |||
| MS_LOG(INFO) << "Start get model inputs, graph id : " << graph_id; | |||
| auto kernel_graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto kernel_graph_inputs = kernel_graph->inputs(); | |||
| vector<ParameterPtr> paras; | |||
| // find parameters of graph inputs | |||
| for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) { | |||
| if (!kernel_graph_inputs[i]->isa<Parameter>()) { | |||
| MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter."; | |||
| continue; | |||
| } | |||
| auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>(); | |||
| if (!AnfAlgo::IsParameterWeight(parameter)) { | |||
| vector<int64_t> input_shape; | |||
| auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0); | |||
| (void)std::transform(parameter_shape.begin(), parameter_shape.end(), std::back_inserter(input_shape), | |||
| [](const size_t dim) { return SizeToLong(dim); }); | |||
| auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter); | |||
| auto data_type = kernel_build_info->GetOutputDeviceType(0); | |||
| auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape); | |||
| inputs->push_back(ms_tensor); | |||
| } | |||
| } | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -44,7 +44,6 @@ class AscendInferenceSession : public AscendSession { | |||
| template <typename T> | |||
| std::string PrintInputShape(std::vector<T> shape) const; | |||
| std::string InputsInfo(const std::vector<ParameterPtr> &paras, const std::vector<tensor::TensorPtr> &inputs) const; | |||
| void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const override; | |||
| protected: | |||
| GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override; | |||
| @@ -370,7 +370,8 @@ Status MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<ten | |||
| Status MSInferSession::GetModelInputsInfo(uint32_t model_id, std::vector<inference::InferTensor> *tensor_list) const { | |||
| vector<tensor::TensorPtr> inputs; | |||
| session_impl_->GetModelInputsInfo(model_id, &inputs); | |||
| vector<std::string> input_names; | |||
| session_impl_->GetModelInputsInfo(model_id, &inputs, &input_names); | |||
| if (inputs.size() == 0) { | |||
| MS_LOG(ERROR) << "The model inputs is NULL"; | |||
| return FAILED; | |||
| @@ -34,6 +34,8 @@ | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "utils/utils.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "mindspore/core/base/base_ref_utils.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/worker.h" | |||
| #include "ps/util.h" | |||
| @@ -1089,6 +1091,61 @@ void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vecto | |||
| } | |||
| } | |||
| void SessionBasic::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs, | |||
| std::vector<std::string> *inputs_name) const { | |||
| MS_LOG(INFO) << "Start get model inputs, graph id : " << graph_id; | |||
| auto kernel_graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| MS_EXCEPTION_IF_NULL(inputs); | |||
| MS_EXCEPTION_IF_NULL(inputs_name); | |||
| auto kernel_graph_inputs = kernel_graph->inputs(); | |||
| vector<ParameterPtr> paras; | |||
| // find parameters of graph inputs | |||
| for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) { | |||
| if (!kernel_graph_inputs[i]->isa<Parameter>()) { | |||
| MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter."; | |||
| continue; | |||
| } | |||
| auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>(); | |||
| if (!AnfAlgo::IsParameterWeight(parameter)) { | |||
| vector<int64_t> input_shape; | |||
| auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0); | |||
| (void)std::transform(parameter_shape.begin(), parameter_shape.end(), std::back_inserter(input_shape), | |||
| [](const size_t dim) { return SizeToLong(dim); }); | |||
| auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter); | |||
| auto data_type = kernel_build_info->GetOutputDeviceType(0); | |||
| auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape); | |||
| inputs->push_back(ms_tensor); | |||
| inputs_name->push_back(parameter->name()); | |||
| } | |||
| } | |||
| } | |||
| void SessionBasic::GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs, | |||
| std::vector<std::string> *output_names) const { | |||
| std::vector<tensor::TensorPtr> inputs; | |||
| std::vector<std::string> input_names; | |||
| GetModelInputsInfo(graph_id, &inputs, &input_names); | |||
| auto kernel_graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| MS_EXCEPTION_IF_NULL(output_names); | |||
| VectorRef vector_outputs; | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node; | |||
| auto anf_outputs = kernel_graph->outputs(); | |||
| for (auto &item : anf_outputs) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]"; | |||
| vector_outputs.emplace_back(CreateNodeOutputTensors(item, kernel_graph, inputs, &tensor_to_node)); | |||
| } | |||
| *outputs = TransformVectorRefToMultiTensor(vector_outputs); | |||
| for (size_t i = 0; i < outputs->size(); i++) { | |||
| output_names->push_back("output" + std::to_string(i)); | |||
| } | |||
| } | |||
| void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) { | |||
| MS_EXCEPTION_IF_NULL(callback); | |||
| summary_callback_ = callback; | |||
| @@ -102,7 +102,10 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| std::string *error_msg) const { | |||
| return true; | |||
| } | |||
| virtual void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const {} | |||
| void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs, | |||
| std::vector<std::string> *inputs_name) const; | |||
| void GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs, | |||
| std::vector<std::string> *outputs_name) const; | |||
| std::vector<tensor::TensorPtr> GetInputNeedLockTensors(const GraphId &graph_id, | |||
| const std::vector<tensor::TensorPtr> &inputs); | |||
| // Get graph by graph id, if not exist return null ptr | |||
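For reference, the pair of queries added to SessionBasic above is meant to be called after a graph has been compiled; a minimal sketch, assuming session is a SessionBasic pointer and graph_id comes from CompileGraph:

// Illustrative only: fetch placeholder tensors plus parameter/output names for a compiled graph.
std::vector<tensor::TensorPtr> inputs, outputs;
std::vector<std::string> input_names, output_names;
session->GetModelInputsInfo(graph_id, &inputs, &input_names);
session->GetModelOutputsInfo(graph_id, &outputs, &output_names);
for (size_t i = 0; i < inputs.size(); ++i) {
  MS_LOG(INFO) << "Input " << input_names[i] << " dtype " << inputs[i]->data_type();
}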
| @@ -6,22 +6,25 @@ set(LOAD_ONNX_SRC | |||
| file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc") | |||
| if (ENABLE_ACL) | |||
| file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc") | |||
| file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc" "model/model_converter_utils/*.cc") | |||
| elseif (ENABLE_D) | |||
| file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc") | |||
| endif () | |||
| set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/cell.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc | |||
| ${API_ACL_SRC} | |||
| ${API_OPS_SRC} | |||
| ${LOAD_ONNX_SRC}) | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/cell.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc | |||
| ${API_MS_INFER_SRC} | |||
| ${API_ACL_SRC} | |||
| ${API_OPS_SRC} | |||
| ${LOAD_ONNX_SRC}) | |||
| add_library(mindspore_shared_lib SHARED ${MSLIB_SRC}) | |||
| set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}") | |||
| target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY} | |||
| -Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf) | |||
| -Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf) | |||
| if (ENABLE_CPU) | |||
| target_link_libraries(mindspore_shared_lib PRIVATE mindspore::dnnl mindspore::mkldnn) | |||
| @@ -58,5 +61,13 @@ if (ENABLE_ACL) | |||
| find_library(acl_runtime libruntime.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64) | |||
| find_library(ge_compiler libge_compiler.so ${ATC_DIR}/lib64 ${ATLAS_ATC_DIR}/lib64) | |||
| target_link_libraries(mindspore_shared_lib PRIVATE ${acl} ${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime} | |||
| ${ge_compiler} mindspore::jpeg_turbo) | |||
| ${ge_compiler} mindspore::jpeg_turbo) | |||
| endif () | |||
| # Before build inference | |||
| if (ENABLE_D) | |||
| find_library(adump_server libadump_server.a ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) | |||
| target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server}) | |||
| endif () | |||
| @@ -91,7 +91,6 @@ Status AclModel::InitEnv() { | |||
| MS_LOG(ERROR) << "DVPP init resource failed"; | |||
| return FAILED; | |||
| } | |||
| ModelConverter::RegAllOp(); | |||
| MS_LOG(INFO) << "Init acl success, device id " << device_id_; | |||
| init_flag_ = true; | |||
| @@ -24,6 +24,7 @@ | |||
| #include "backend/kernel_compiler/oplib/oplib.h" | |||
| #include "graph/model.h" | |||
| #include "cxx_api/model/model_converter_utils/multi_process.h" | |||
| namespace py = pybind11; | |||
| @@ -238,6 +239,131 @@ Buffer ModelConverter::ReadFile(const std::string &file) { | |||
| } | |||
| Buffer ModelConverter::LoadMindIR(const Buffer &model_data) { | |||
| if (Py_IsInitialized() == 0) { | |||
| MS_LOG_INFO << "Call LoadMindIRInner directly"; | |||
| return LoadMindIRInner(model_data); | |||
| } | |||
| MultiProcess multi_process; | |||
| Buffer buffer_ret; | |||
| auto parent_process = [&model_data, &buffer_ret](MultiProcess *multi_process) -> Status { | |||
| MS_EXCEPTION_IF_NULL(multi_process); | |||
| // send original model to child | |||
| auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize()); | |||
| if (!status.IsSuccess()) { | |||
| MS_LOG_ERROR << "Send original model to child process failed"; | |||
| return FAILED; | |||
| } | |||
| // receive convert model result from child | |||
| CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * { | |||
| buffer_ret.ResizeData(msg_len); | |||
| return reinterpret_cast<uint8_t *>(buffer_ret.MutableData()); | |||
| }; | |||
| status = multi_process->ReceiveMsg(call); | |||
| if (!status.IsSuccess()) { | |||
| MS_LOG_ERROR << "Receive result model from child process failed"; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| }; | |||
| auto child_process = [this](MultiProcess *multi_process) -> Status { | |||
| MS_EXCEPTION_IF_NULL(multi_process); | |||
| // receive original model from parent | |||
| Buffer model; | |||
| CreateBufferCall call = [&model](size_t msg_len) -> uint8_t * { | |||
| model.ResizeData(msg_len); | |||
| return reinterpret_cast<uint8_t *>(model.MutableData()); | |||
| }; | |||
| auto status = multi_process->ReceiveMsg(call); | |||
| if (!status.IsSuccess()) { | |||
| MS_LOG_ERROR << "Receive original model from parent process failed"; | |||
| return FAILED; | |||
| } | |||
| Buffer model_result = LoadMindIRInner(model); | |||
| if (model_result.DataSize() == 0) { | |||
| MS_LOG_ERROR << "Convert model from MindIR to OM failed"; | |||
| return FAILED; | |||
| } | |||
| // send result model to parent | |||
| status = multi_process->SendMsg(model_result.Data(), model_result.DataSize()); | |||
| if (!status.IsSuccess()) { | |||
| MS_LOG_ERROR << "Send result model to parent process failed"; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| }; | |||
| auto status = multi_process.MainProcess(parent_process, child_process); | |||
| if (!status.IsSuccess()) { | |||
| MS_LOG_ERROR << "Convert MindIR model to OM model failed"; | |||
| } else { | |||
| MS_LOG_INFO << "Convert MindIR model to OM model success"; | |||
| } | |||
| return buffer_ret; | |||
| } | |||
| Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) { | |||
| if (Py_IsInitialized() == 0) { | |||
| MS_LOG_INFO << "Call LoadAscendIRInner directly"; | |||
| return LoadAscendIRInner(model_data); | |||
| } | |||
| MultiProcess multi_process; | |||
| Buffer buffer_ret; | |||
| auto parent_process = [&model_data, &buffer_ret](MultiProcess *multi_process) -> Status { | |||
| MS_EXCEPTION_IF_NULL(multi_process); | |||
| // send original model to child | |||
| auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize()); | |||
| if (!status.IsSuccess()) { | |||
| MS_LOG_ERROR << "Send original model to child process failed"; | |||
| return FAILED; | |||
| } | |||
| // receive convert model result from child | |||
| CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * { | |||
| buffer_ret.ResizeData(msg_len); | |||
| return reinterpret_cast<uint8_t *>(buffer_ret.MutableData()); | |||
| }; | |||
| status = multi_process->ReceiveMsg(call); | |||
| if (!status.IsSuccess()) { | |||
| MS_LOG_ERROR << "Receive result model from child process failed"; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| }; | |||
| auto child_process = [this](MultiProcess *multi_process) -> Status { | |||
| MS_EXCEPTION_IF_NULL(multi_process); | |||
| // receive original model from parent | |||
| Buffer model; | |||
| CreateBufferCall call = [&model](size_t msg_len) -> uint8_t * { | |||
| model.ResizeData(msg_len); | |||
| return reinterpret_cast<uint8_t *>(model.MutableData()); | |||
| }; | |||
| auto status = multi_process->ReceiveMsg(call); | |||
| if (!status.IsSuccess()) { | |||
| MS_LOG_ERROR << "Receive original model from parent process failed"; | |||
| return FAILED; | |||
| } | |||
| Buffer model_result = LoadAscendIRInner(model); | |||
| if (model_result.DataSize() == 0) { | |||
| MS_LOG_ERROR << "Convert model from AIR to OM failed"; | |||
| return FAILED; | |||
| } | |||
| // send result model to parent | |||
| status = multi_process->SendMsg(model_result.Data(), model_result.DataSize()); | |||
| if (!status.IsSuccess()) { | |||
| MS_LOG_ERROR << "Send result model to parent process failed"; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| }; | |||
| auto status = multi_process.MainProcess(parent_process, child_process); | |||
| if (!status.IsSuccess()) { | |||
| MS_LOG_ERROR << "Convert AIR model to OM model failed"; | |||
| } else { | |||
| MS_LOG_INFO << "Convert AIR model to OM model success"; | |||
| } | |||
| return buffer_ret; | |||
| } | |||
| Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) { | |||
| RegAllOp(); | |||
| auto func_graph = ConvertMindIrToFuncGraph(model_data); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed."; | |||
| @@ -259,7 +385,8 @@ Buffer ModelConverter::LoadMindIR(const Buffer &model_data) { | |||
| return om_data; | |||
| } | |||
| Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) { | |||
| Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) { | |||
| RegAllOp(); | |||
| ge::Model load_model = ge::Model("loadmodel", "version2"); | |||
| ge::Status ret = | |||
| ge::Model::Load(reinterpret_cast<const uint8_t *>(model_data.Data()), model_data.DataSize(), load_model); | |||
| @@ -45,6 +45,9 @@ class ModelConverter { | |||
| transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph); | |||
| Buffer BuildAirModel(const transform::DfGraphPtr &graph, const std::map<std::string, std::string> &acl_options); | |||
| AclModelOptions *options_; | |||
| Buffer LoadMindIRInner(const Buffer &model_data); | |||
| Buffer LoadAscendIRInner(const Buffer &model_data); | |||
| }; | |||
| } // namespace mindspore::api | |||
| @@ -18,6 +18,9 @@ | |||
| #include "utils/utils.h" | |||
| namespace mindspore::api { | |||
| const char *kDeviceTypeAscendCL = "AscendCL"; | |||
| const char *kDeviceTypeAscendMS = "AscendMS"; | |||
| Status Model::LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options) { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| return impl_->LoadModel(model_data, type, options); | |||
| @@ -95,4 +98,9 @@ Model::Model(NetWork network, const std::string &device_type, uint32_t device_id | |||
| } | |||
| Model::~Model() {} | |||
| bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) { | |||
| return ModelFactory::Instance().CheckModelSupport(device_type, model_type); | |||
| } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,207 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "cxx_api/model/model_converter_utils/multi_process.h" | |||
| #include <unistd.h> | |||
| #include <sys/wait.h> | |||
| #include <algorithm> | |||
| #include <vector> | |||
| #include <thread> | |||
| #include "mindspore/core/utils/log_adapter.h" | |||
| #include "cxx_api/model/model_converter_utils/shared_memory.h" | |||
| namespace mindspore { | |||
| namespace api { | |||
| namespace { | |||
| uint64_t kSharedMemorySize = 100ull << 20; // 100 MB | |||
| } | |||
| MultiProcess::MultiProcess() = default; | |||
| MultiProcess::~MultiProcess() = default; | |||
| Status MultiProcess::MainProcess(ProcessFuncCall parent_process, ProcessFuncCall child_process) { | |||
| MS_EXCEPTION_IF_NULL(parent_process); | |||
| MS_EXCEPTION_IF_NULL(child_process); | |||
| Status ret; | |||
| memory_size_ = kSharedMemorySize; // 100 MB | |||
| SharedMemory shared_memory; | |||
| ret = shared_memory.Create(memory_size_); | |||
| if (!ret.IsSuccess()) { | |||
| MS_LOG_ERROR << "Create shared memory failed"; | |||
| return ret; | |||
| } | |||
| pid_t pid = fork(); | |||
| if (pid < 0) { | |||
| shared_memory.Destroy(); | |||
| MS_LOG_ERROR << "Fork process to convert model failed"; | |||
| return FAILED; | |||
| } | |||
| ret = shared_memory.Attach(); | |||
| if (!ret.IsSuccess()) { | |||
| MS_LOG_ERROR << "Process attach shared memory failed, pid " << pid; | |||
| return ret; | |||
| } | |||
| shmat_addr_ = shared_memory.GetSharedMemoryAddr(); | |||
| if (shmat_addr_ == nullptr) { | |||
| MS_LOG_ERROR << "Get shared memory failed"; | |||
| return ret; | |||
| } | |||
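// Shared-memory layout set up below: [parent MessageFlag][child MessageFlag][payload region used by SendMsg/ReceiveMsg]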
| shmat_data_addr_ = shmat_addr_ + sizeof(MessageFlag) * 2; | |||
| shmat_data_max_size_ = memory_size_ - (shmat_data_addr_ - shmat_addr_); | |||
| MS_LOG_INFO << "Shm addr " << (uint64_t)shmat_addr_; | |||
| if (pid == 0) { | |||
| ChildProcess(child_process); | |||
| shared_memory.Detach(); | |||
| MS_LOG_INFO << "Model converter: child process exit"; | |||
| exit(0); | |||
| } else { // parent process | |||
| ret = ParentProcess(parent_process); | |||
| shared_memory.Detach(); | |||
| int status; | |||
| wait(&status); | |||
| shared_memory.Destroy(); | |||
| } | |||
| return ret; | |||
| } | |||
| Status MultiProcess::ParentProcess(ProcessFuncCall parent_process) { | |||
| auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_); | |||
| auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag)); | |||
| send_msg_ = parent_msg; | |||
| receive_msg_ = child_msg; | |||
| std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this); | |||
| Status ret; | |||
| try { | |||
| ret = parent_process(this); | |||
| if (!ret.IsSuccess()) { | |||
| MS_LOG_ERROR << "Parent process process failed"; | |||
| } | |||
| } catch (const std::runtime_error &ex) { | |||
| MS_LOG_ERROR << "Catch parent process runtime error: " << ex.what(); | |||
| ret = FAILED; | |||
| } | |||
| stopped_ = true; | |||
| send_msg_->stop = true; | |||
| heartbeat_thread.join(); | |||
| return ret; | |||
| } | |||
| void MultiProcess::ChildProcess(ProcessFuncCall child_process) { | |||
| auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_); | |||
| auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag)); | |||
| send_msg_ = child_msg; | |||
| receive_msg_ = parent_msg; | |||
| std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this); | |||
| try { | |||
| auto ret = child_process(this); | |||
| if (!ret.IsSuccess()) { | |||
| MS_LOG_ERROR << "Child process process failed"; | |||
| } | |||
| } catch (const std::runtime_error &ex) { | |||
| MS_LOG_ERROR << "Catch child process runtime error: " << ex.what(); | |||
| } | |||
| stopped_ = true; | |||
| send_msg_->stop = true; | |||
| heartbeat_thread.join(); | |||
| } | |||
| Status MultiProcess::SendMsg(const void *buffer, uint64_t msg_len) { | |||
| MS_LOG_INFO << "Start to send message to peer process, msg len " << msg_len; | |||
| send_msg_->msg_total_len = msg_len; | |||
| uint64_t cur_offset = 0; | |||
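// Payloads larger than the shared buffer are streamed in chunks of at most shmat_data_max_size_ bytes; the receiver acknowledges each chunk via read_finish_flag.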
| while (msg_len > cur_offset) { | |||
| uint64_t sub_msg_len = std::min(msg_len - cur_offset, shmat_data_max_size_); | |||
| memcpy_s(shmat_data_addr_, shmat_data_max_size_, static_cast<const uint8_t *>(buffer) + cur_offset, sub_msg_len); | |||
| cur_offset += sub_msg_len; | |||
| send_msg_->msg_len = sub_msg_len; | |||
| send_msg_->read_finish_flag = false; | |||
| send_msg_->read_ready_flag = true; | |||
| MS_LOG_INFO << "Send start " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len; | |||
| while (!send_msg_->read_finish_flag && !peer_stopped_) { | |||
| usleep(1000); // 1ms | |||
| } | |||
| if (peer_stopped_) { | |||
| if (!send_msg_->read_finish_flag) { | |||
| return FAILED; | |||
| } | |||
| break; | |||
| } | |||
| MS_LOG_INFO << "Send end " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len; | |||
| } | |||
| MS_LOG_INFO << "End to send message to peer process, msg len " << msg_len; | |||
| return SUCCESS; | |||
| } | |||
| Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) { | |||
| uint64_t cur_offset = 0; | |||
| uint8_t *msg_buffer = nullptr; | |||
| uint64_t msg_len = 0; | |||
| do { | |||
| MS_LOG_INFO << "Receive start from " << cur_offset; | |||
| while (!receive_msg_->read_ready_flag && !peer_stopped_) { | |||
| usleep(1000); // 1ms | |||
| } | |||
| if (peer_stopped_) { | |||
| return FAILED; | |||
| } | |||
| if (msg_buffer == nullptr) { | |||
| msg_len = receive_msg_->msg_total_len; | |||
| msg_buffer = create_buffer_call(msg_len); | |||
| } | |||
| memcpy_s(msg_buffer + cur_offset, msg_len - cur_offset, shmat_data_addr_, receive_msg_->msg_len); | |||
| cur_offset += receive_msg_->msg_len; | |||
| receive_msg_->read_ready_flag = false; | |||
| receive_msg_->read_finish_flag = true; | |||
| MS_LOG_INFO << "Receive end, current length " << cur_offset << ", total length " << msg_len << std::endl; | |||
| } while (msg_len > cur_offset); | |||
| return SUCCESS; | |||
| } | |||
| void MultiProcess::HeartbeatThreadFunc(MultiProcess *multi_process) { multi_process->HeartbeatThreadFuncInner(); } | |||
| void MultiProcess::HeartbeatThreadFuncInner() { | |||
| uint64_t last_beat_cnt = 0; | |||
| uint64_t repeat_cnt = 0; | |||
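// Each side bumps its own heartbeat counter every 100 ms and watches the peer's; roughly 3 s without progress (or an explicit stop flag) marks the peer as stopped.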
| while (!stopped_) { | |||
| if (receive_msg_->stop) { | |||
| peer_stopped_ = true; | |||
| MS_LOG_WARNING << "Peer stopped"; | |||
| break; | |||
| } | |||
| uint64_t heartbeat_gap = receive_msg_->heartbeat - last_beat_cnt; | |||
| if (heartbeat_gap > 0 && heartbeat_gap < 1024) { | |||
| last_beat_cnt = receive_msg_->heartbeat; | |||
| repeat_cnt = 0; | |||
| } else { | |||
| repeat_cnt++; | |||
| if (repeat_cnt > 30) { // 30*100ms = 3s no reply | |||
| peer_stopped_ = true; | |||
| MS_LOG_WARNING << "Peer stopped"; | |||
| break; | |||
| } | |||
| } | |||
| send_msg_->heartbeat += 1; | |||
| usleep(100000); // sleep 100 ms | |||
| } | |||
| } | |||
| } // namespace api | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H | |||
| #define MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H | |||
| #include <iostream> | |||
| #include <functional> | |||
| #include "include/api/status.h" | |||
| namespace mindspore { | |||
| namespace api { | |||
| struct MessageFlag { | |||
| uint64_t heartbeat = 0; | |||
| uint64_t stop = false; | |||
| uint64_t msg_len = 0; | |||
| uint64_t msg_total_len = 0; | |||
| uint64_t read_ready_flag = false; | |||
| uint64_t read_finish_flag = false; | |||
| }; | |||
| class MultiProcess; | |||
| using ProcessFuncCall = std::function<Status(MultiProcess *multi_process)>; | |||
| using CreateBufferCall = std::function<uint8_t *(size_t msg_len)>; | |||
| class MultiProcess { | |||
| public: | |||
| MultiProcess(); | |||
| ~MultiProcess(); | |||
| Status MainProcess(ProcessFuncCall parent_process, ProcessFuncCall child_process); | |||
| Status SendMsg(const void *buffer, uint64_t msg_len); | |||
| Status ReceiveMsg(CreateBufferCall create_buffer_call); | |||
| private: | |||
| uint8_t *shmat_addr_ = nullptr; | |||
| uint8_t *shmat_data_addr_ = nullptr; | |||
| uint64_t shmat_data_max_size_ = 0; | |||
| uint64_t memory_size_ = 0; | |||
| bool peer_stopped_ = false; | |||
| bool stopped_ = false; | |||
| MessageFlag *send_msg_ = nullptr; | |||
| MessageFlag *receive_msg_ = nullptr; | |||
| static void HeartbeatThreadFunc(MultiProcess *multi_process); | |||
| void HeartbeatThreadFuncInner(); | |||
| Status ParentProcess(ProcessFuncCall parent_process); | |||
| void ChildProcess(ProcessFuncCall child_process); | |||
| }; | |||
| } // namespace api | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H | |||
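A minimal usage sketch of this helper, mirroring the pattern in model_converter.cc above (payload contents are placeholders):

// Sketch: exchange one message from parent to child across the fork.
#include <vector>
#include "cxx_api/model/model_converter_utils/multi_process.h"

namespace mindspore::api {
Status ExchangeOnce() {
  MultiProcess mp;
  auto parent = [](MultiProcess *p) -> Status {
    const char req[] = "request";  // placeholder payload
    return p->SendMsg(req, sizeof(req));
  };
  auto child = [](MultiProcess *p) -> Status {
    std::vector<uint8_t> buf;
    CreateBufferCall alloc = [&buf](size_t len) -> uint8_t * {
      buf.resize(len);
      return buf.data();
    };
    return p->ReceiveMsg(alloc);  // copies the payload out of shared memory
  };
  return mp.MainProcess(parent, child);
}
}  // namespace mindspore::api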
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "cxx_api/model/model_converter_utils/shared_memory.h" | |||
| #include <sys/shm.h> | |||
| #include <sys/stat.h> | |||
| #include <string> | |||
| #include "mindspore/core/utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace api { | |||
| Status SharedMemory::Create(uint64_t memory_size) { | |||
| auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP; | |||
| shm_id_ = shmget(IPC_PRIVATE, memory_size, IPC_CREAT | IPC_EXCL | access_mode); | |||
| if (shm_id_ == -1) { | |||
| MS_LOG_ERROR << "Shared memory creation failed. Errno " + std::to_string(errno); | |||
| return FAILED; | |||
| } | |||
| MS_LOG_INFO << "shmget success, shm id " << shm_id_; | |||
| return SUCCESS; | |||
| } | |||
| Status SharedMemory::Attach() { | |||
| void *shmat_addr = shmat(shm_id_, nullptr, 0); | |||
| if (shmat_addr == reinterpret_cast<void *>(-1)) { | |||
| MS_LOG_ERROR << "Shared memory attach failed. Errno " + std::to_string(errno); | |||
| return FAILED; | |||
| } | |||
| shmat_addr_ = reinterpret_cast<uint8_t *>(shmat_addr); | |||
| return SUCCESS; | |||
| } | |||
| void SharedMemory::Detach() { | |||
| if (shmat_addr_) { | |||
| auto err = shmdt(shmat_addr_); | |||
| if (err == -1) { | |||
| MS_LOG_ERROR << "Shared memory detach failed. Errno " + std::to_string(errno); | |||
| return; | |||
| } | |||
| } | |||
| shmat_addr_ = nullptr; | |||
| } | |||
| void SharedMemory::Destroy() { | |||
| // Remove the shared memory and never mind about the return code. | |||
| auto err = shmctl(shm_id_, IPC_RMID, nullptr); | |||
| if (err == -1) { | |||
| std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id_); | |||
| errMsg += ". Errno :" + std::to_string(errno); | |||
| errMsg += "\nPlesae remove it manually using ipcrm -m command"; | |||
| MS_LOG_ERROR << errMsg; | |||
| } | |||
| } | |||
| } // namespace api | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H | |||
| #define MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H | |||
| #include <iostream> | |||
| #include "include/api/status.h" | |||
| namespace mindspore { | |||
| namespace api { | |||
| class SharedMemory { | |||
| public: | |||
| Status Create(uint64_t memory_size); | |||
| Status Attach(); | |||
| void Detach(); | |||
| void Destroy(); | |||
| uint8_t *GetSharedMemoryAddr() { return shmat_addr_; } | |||
| private: | |||
| int shm_id_ = -1; | |||
| uint8_t *shmat_addr_ = nullptr; | |||
| }; | |||
| } // namespace api | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H | |||
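The lifecycle of this wrapper, as driven by MultiProcess::MainProcess above, is roughly the following (sketch, status checks elided):

#include <unistd.h>
#include "cxx_api/model/model_converter_utils/shared_memory.h"

// Sketch: create before fork, attach after fork, destroy once from the parent.
void SharedMemoryLifecycle() {
  mindspore::api::SharedMemory shm;
  shm.Create(100ull << 20);                // shmget(IPC_PRIVATE, 100 MB, ...)
  pid_t pid = fork();
  shm.Attach();                            // shmat() maps the segment in parent and child
  uint8_t *base = shm.GetSharedMemoryAddr();
  (void)base;                              // ... exchange data through base ...
  shm.Detach();                            // shmdt()
  if (pid != 0) shm.Destroy();             // shmctl(IPC_RMID) after the child exits
}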
| @@ -70,6 +70,12 @@ class ModelFactory { | |||
| return nullptr; | |||
| } | |||
| bool CheckModelSupport(const std::string &device_type, ModelType /*model_type*/) { | |||
| return std::any_of( | |||
| model_creators_.begin(), model_creators_.end(), | |||
| [&device_type](const std::pair<std::string, ModelCreator> &item) { return item.first == device_type; }); | |||
| } | |||
| private: | |||
| ModelFactory() = default; | |||
| ~ModelFactory() = default; | |||
| @@ -86,7 +92,7 @@ class ModelRegistrar { | |||
| #define API_REG_MODEL(DEVICE_NAME, MODEL_CLASS) \ | |||
| static const ModelRegistrar g_api_model_registrar__##DEVICE_NAME##_##_reg( \ | |||
| #DEVICE_NAME, [](uint32_t device_id) { return std::make_shared<MODEL_CLASS>(device_id); }); | |||
| kDeviceType##DEVICE_NAME, [](uint32_t device_id) { return std::make_shared<MODEL_CLASS>(device_id); }); | |||
| } // namespace mindspore::api | |||
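With this change API_REG_MODEL looks up the registry key through a kDeviceType&lt;NAME&gt; constant instead of the stringified macro argument, so the AscendMS registration in ms_model.h below expands roughly to:

// Approximate expansion of API_REG_MODEL(AscendMS, MsModel):
static const ModelRegistrar g_api_model_registrar__AscendMS__reg(
  kDeviceTypeAscendMS, [](uint32_t device_id) { return std::make_shared<MsModel>(device_id); });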
| @@ -0,0 +1,418 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "cxx_api/model/ms/ms_model.h" | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <fstream> | |||
| #include "utils/load_onnx/anf_converter.h" | |||
| #include "backend/session/session_basic.h" | |||
| #include "backend/session/session_factory.h" | |||
| #include "backend/session/executor_manager.h" | |||
| #include "base/base_ref_utils.h" | |||
| #include "backend/kernel_compiler/oplib/oplib.h" | |||
| #include "utils/context/context_extends.h" | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/embed.h" | |||
| #ifdef ENABLE_D | |||
| #include "utils/ms_context.h" | |||
| #endif | |||
| using std::string; | |||
| using std::vector; | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace api { | |||
| MsModel::MsModel(uint32_t device_id) : device_id_(device_id) {} | |||
| MsModel::~MsModel() = default; | |||
| TypeId TransInferDataType2TypeId(DataType data_type) { | |||
| const std::map<api::DataType, TypeId> type2id_map{ | |||
| {api::kMsUnknown, TypeId::kNumberTypeBegin}, {api::kMsBool, TypeId::kNumberTypeBool}, | |||
| {api::kMsInt8, TypeId::kNumberTypeInt8}, {api::kMsUint8, TypeId::kNumberTypeUInt8}, | |||
| {api::kMsInt16, TypeId::kNumberTypeInt16}, {api::kMsUint16, TypeId::kNumberTypeUInt16}, | |||
| {api::kMsInt32, TypeId::kNumberTypeInt32}, {api::kMsUint32, TypeId::kNumberTypeUInt32}, | |||
| {api::kMsInt64, TypeId::kNumberTypeInt64}, {api::kMsUint64, TypeId::kNumberTypeUInt64}, | |||
| {api::kMsFloat16, TypeId::kNumberTypeFloat16}, {api::kMsFloat32, TypeId::kNumberTypeFloat32}, | |||
| {api::kMsFloat64, TypeId::kNumberTypeFloat64}, | |||
| }; | |||
| auto it = type2id_map.find(data_type); | |||
| if (it == type2id_map.end()) { | |||
| MS_LOG_WARNING << "Unsupported MSI data type " << data_type; | |||
| return TypeId::kNumberTypeBegin; | |||
| } else { | |||
| return it->second; | |||
| } | |||
| } | |||
| DataType TransTypeId2InferDataType(TypeId type_id) { | |||
| const std::map<TypeId, api::DataType> id2type_map{ | |||
| {TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool}, | |||
| {TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8}, | |||
| {TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16}, | |||
| {TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32}, | |||
| {TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64}, | |||
| {TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16}, | |||
| {TypeId::kNumberTypeFloat32, api::kMsFloat32}, | |||
| }; | |||
| auto it = id2type_map.find(type_id); | |||
| if (it == id2type_map.end()) { | |||
| MS_LOG_WARNING << "Unsupported data id " << type_id; | |||
| return api::kMsUnknown; | |||
| } else { | |||
| return it->second; | |||
| } | |||
| } | |||
| Buffer MsModel::ReadFile(const std::string &file) { | |||
| if (file.empty()) { | |||
| MS_LOG(ERROR) << "file is nullptr"; | |||
| return Buffer(); | |||
| } | |||
| std::ifstream ifs(file); | |||
| if (!ifs.good()) { | |||
| MS_LOG(ERROR) << "file: " << file << " is not exist"; | |||
| return Buffer(); | |||
| } | |||
| if (!ifs.is_open()) { | |||
| MS_LOG(ERROR) << "file: " << file << "open failed"; | |||
| return Buffer(); | |||
| } | |||
| ifs.seekg(0, std::ios::end); | |||
| size_t size = ifs.tellg(); | |||
| Buffer buffer; | |||
| buffer.ResizeData(size); | |||
| ifs.seekg(0, std::ios::beg); | |||
| ifs.read(static_cast<char *>(buffer.MutableData()), size); | |||
| ifs.close(); | |||
| return buffer; | |||
| } | |||
| Status MsModel::LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options) { | |||
| auto status = InitEnv({}); | |||
| if (status != SUCCESS) { | |||
| MS_LOG(ERROR) << "Init env failed"; | |||
| return FAILED; | |||
| } | |||
| std::shared_ptr<FuncGraph> anf_graph; | |||
| try { | |||
| anf_graph = | |||
| lite::AnfConverter::RunAnfConverter(static_cast<const char *>(model_data.Data()), model_data.DataSize()); | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Inference LoadModel failed"; | |||
| return FAILED; | |||
| } | |||
| Status ret = CompileGraph(anf_graph); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Compile graph model failed"; | |||
| return FAILED; | |||
| } | |||
| session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_); | |||
| session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_); | |||
| if (inputs_.empty() || inputs_.size() != input_names_.size()) { | |||
| MS_LOG_ERROR << "Get model inputs info failed"; | |||
| return FAILED; | |||
| } | |||
| if (outputs_.empty() || outputs_.size() != output_names_.size()) { | |||
| MS_LOG_ERROR << "Get model outputs info failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Load model success"; | |||
| #ifdef ENABLE_D | |||
| // set d context | |||
| rtError_t rt_ret = rtCtxGetCurrent(&context_); | |||
| if (rt_ret != RT_ERROR_NONE || context_ == nullptr) { | |||
| MS_LOG(ERROR) << "the ascend device context is null"; | |||
| return FAILED; | |||
| } | |||
| #endif | |||
| load_flag_ = true; | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::LoadModel(const std::string &file_name, ModelType type, | |||
| const std::map<std::string, std::string> &options) { | |||
| auto graphBuf = ReadFile(file_name); | |||
| if (graphBuf.DataSize() == 0) { | |||
| MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str(); | |||
| return FAILED; | |||
| } | |||
| auto status = LoadModel(graphBuf, type, options); | |||
| if (status != SUCCESS) { | |||
| MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::UnloadModel() { | |||
| if (!load_flag_) { | |||
| MS_LOG_ERROR << "Model has not been loaded"; | |||
| return FAILED; | |||
| } | |||
| FinalizeEnv(); | |||
| load_flag_ = false; | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::Train(const DataSet &, std::map<std::string, Buffer> *) { | |||
| MS_LOG(ERROR) << "Unsupported feature."; | |||
| return FAILED; | |||
| } | |||
| Status MsModel::Eval(const DataSet &, std::map<std::string, Buffer> *) { | |||
| MS_LOG(ERROR) << "Unsupported feature."; | |||
| return FAILED; | |||
| } | |||
| Status MsModel::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| if (!load_flag_) { | |||
| MS_LOG(ERROR) << "No model is loaded, predict failed."; | |||
| return FAILED; | |||
| } | |||
| if (inputs.size() != inputs_.size()) { | |||
| MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size(); | |||
| return INVALID_INPUTS; | |||
| } | |||
| std::vector<Buffer> request; | |||
| std::vector<Buffer> reply; | |||
| for (size_t i = 0; i < inputs_.size(); ++i) { | |||
| const auto &input_name = input_names_[i]; | |||
| auto iter = inputs.find(input_name); | |||
| if (iter == inputs.end()) { | |||
| MS_LOG(ERROR) << "Model missing input " << input_name; | |||
| return INVALID_INPUTS; | |||
| } | |||
| if (iter->second.DataSize() != inputs_[i]->Size()) { | |||
| MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count " | |||
| << iter->second.DataSize(); | |||
| return INVALID_INPUTS; | |||
| } | |||
| request.push_back(iter->second); | |||
| } | |||
| if (ExecuteModel(request, &reply) != SUCCESS) { | |||
| MS_LOG(ERROR) << "Execute Model Failed"; | |||
| return FAILED; | |||
| } | |||
| if (outputs_.size() != reply.size()) { | |||
| MS_LOG(ERROR) << "Predict output size " << reply.size() << " not match output size got from model info " | |||
| << outputs_.size(); | |||
| return FAILED; | |||
| } | |||
| outputs->clear(); | |||
| for (size_t i = 0; i < reply.size(); i++) { | |||
| outputs->emplace(output_names_[i], reply[i]); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) { | |||
| MS_EXCEPTION_IF_NULL(reply); | |||
| #ifdef ENABLE_D | |||
| if (context_ == nullptr) { | |||
| MS_LOG(ERROR) << "rtCtx is nullptr"; | |||
| return FAILED; | |||
| } | |||
| rtError_t rt_ret = rtCtxSetCurrent(context_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "set Ascend rtCtx failed"; | |||
| return FAILED; | |||
| } | |||
| #endif | |||
| vector<tensor::TensorPtr> inputs; | |||
| for (size_t i = 0; i < request.size(); i++) { | |||
| auto &item = request[i]; | |||
| auto input = inputs_[i]; | |||
| if (input->Size() != item.DataSize()) { | |||
| MS_LOG(ERROR) << "Predict input " << i << " data size " << item.DataSize() << " not match model input data size " | |||
| << input->Size(); | |||
| return FAILED; | |||
| } | |||
| auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize()); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Tensor copy failed"; | |||
| return FAILED; | |||
| } | |||
| inputs.push_back(input); | |||
| } | |||
| vector<tensor::TensorPtr> outputs = RunGraph(inputs); | |||
| if (outputs.empty()) { | |||
| MS_LOG(ERROR) << "Execute Model Failed"; | |||
| return FAILED; | |||
| } | |||
| reply->clear(); | |||
| std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply), | |||
| [](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); }); | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::FinalizeEnv() { | |||
| MS_LOG_INFO << "Start finalize env"; | |||
| py::gil_scoped_acquire acquire; | |||
| session::ExecutorManager::Instance().Clear(); | |||
| device::KernelRuntimeManager::Instance().ClearRuntimeResource(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| if (ms_context == nullptr) { | |||
| MS_LOG(ERROR) << "Get Context failed!"; | |||
| return FAILED; | |||
| } | |||
| if (!context::CloseTsd(ms_context)) { | |||
| MS_LOG(ERROR) << "Inference CloseTsd failed!"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG_INFO << "End finalize env"; | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<FuncGraph> MsModel::LoadModel(const char *model_buf, size_t size, const std::string &device) { | |||
| MS_EXCEPTION_IF_NULL(model_buf); | |||
| try { | |||
| auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); | |||
| return anf_graph; | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Inference LoadModel failed: " << e.what(); | |||
| return nullptr; | |||
| } | |||
| } | |||
| void MsModel::RegAllOp() { | |||
| static std::mutex init_mutex; | |||
| static bool Initialized = false; | |||
| std::lock_guard<std::mutex> lock(init_mutex); | |||
| if (Initialized) { | |||
| return; | |||
| } | |||
| Initialized = true; | |||
| auto ms_context_instance = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context_instance); | |||
| ms_context_instance->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| try { | |||
| std::shared_ptr<py::scoped_interpreter> guard; | |||
| if (Py_IsInitialized() == 0) { | |||
| guard = std::make_shared<py::scoped_interpreter>(); | |||
| } | |||
| py::module c_expression = py::module::import("mindspore._c_expression"); | |||
| size_t ops_info_long = c_expression.attr("OpInfoLoaderPy")().attr("get_all_ops_info")().cast<size_t>(); | |||
| auto all_ops_info = reinterpret_cast<std::vector<kernel::OpInfo *> *>(ops_info_long); | |||
| for (auto op_info : *all_ops_info) { | |||
| kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info)); | |||
| } | |||
| all_ops_info->clear(); | |||
| delete all_ops_info; | |||
| } catch (const std::runtime_error &ex) { | |||
| MS_LOG_EXCEPTION << ex.what(); | |||
| } | |||
| } | |||
| Status MsModel::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) { | |||
| MS_ASSERT(session_impl_ != nullptr); | |||
| try { | |||
| graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); | |||
| py::gil_scoped_release gil_release; | |||
| return SUCCESS; | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Inference CompileGraph failed: " << e.what(); | |||
| return FAILED; | |||
| } | |||
| } | |||
| std::vector<tensor::TensorPtr> MsModel::RunGraph(const std::vector<tensor::TensorPtr> &inputs) { | |||
| try { | |||
| VectorRef outputs; | |||
| session_impl_->RunGraph(graph_id_, inputs, &outputs); | |||
| return TransformVectorRefToMultiTensor(outputs); | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Inference Rungraph failed: " << e.what(); | |||
| return std::vector<tensor::TensorPtr>(); | |||
| } | |||
| } | |||
| Status MsModel::InitEnv(const std::unordered_map<std::string, std::string> &other_options) { | |||
| RegAllOp(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| if (ms_context == nullptr) { | |||
| MS_LOG(ERROR) << "Get Context failed!"; | |||
| return FAILED; | |||
| } | |||
| ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_); | |||
| ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice); | |||
| if (!context::OpenTsd(ms_context)) { | |||
| MS_LOG(ERROR) << "Session init OpenTsd failed!"; | |||
| return FAILED; | |||
| } | |||
| session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice); | |||
| if (session_impl_ == nullptr) { | |||
| MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice | |||
| << " is available."; | |||
| return FAILED; | |||
| } | |||
| session_impl_->Init(device_id_); | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const { | |||
| MS_ASSERT(session_impl_ != nullptr); | |||
| std::string error_msg; | |||
| if (!session_impl_->CheckModelInputs(graph_id, inputs, &error_msg)) { | |||
| return Status(INVALID_INPUTS, error_msg); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::GetInputsInfo(std::vector<Tensor> *tensor_list) const { | |||
| MS_EXCEPTION_IF_NULL(tensor_list); | |||
| tensor_list->clear(); | |||
| for (size_t i = 0; i < inputs_.size(); i++) { | |||
| auto &tensor = inputs_[i]; | |||
| Tensor infer_tensor; | |||
| infer_tensor.SetName(input_names_[i]); | |||
| infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type())); | |||
| infer_tensor.SetShape(tensor->shape()); | |||
| tensor_list->push_back(infer_tensor); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MsModel::GetOutputsInfo(std::vector<Tensor> *tensor_list) const { | |||
| MS_EXCEPTION_IF_NULL(tensor_list); | |||
| tensor_list->clear(); | |||
| for (size_t i = 0; i < outputs_.size(); i++) { | |||
| auto &tensor = outputs_[i]; | |||
| Tensor infer_tensor; | |||
| infer_tensor.SetName(output_names_[i]); | |||
| infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type())); | |||
| infer_tensor.SetShape(tensor->shape()); | |||
| tensor_list->push_back(infer_tensor); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace api | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,85 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_SESSION_SESSION_H | |||
| #define MINDSPORE_CCSRC_SESSION_SESSION_H | |||
| #include <vector> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <map> | |||
| #include "backend/session/session_basic.h" | |||
| #include "ir/anf.h" | |||
| #include "include/api/status.h" | |||
| #include "cxx_api/model/model_impl.h" | |||
| #ifdef ENABLE_D | |||
| #include "runtime/context.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace api { | |||
| class MsModel : public ModelImpl { | |||
| public: | |||
| explicit MsModel(uint32_t device_id); | |||
| ~MsModel(); | |||
| Status LoadModel(const Buffer &model_data, ModelType type, | |||
| const std::map<std::string, std::string> &options) override; | |||
| Status LoadModel(const std::string &file_name, ModelType type, | |||
| const std::map<std::string, std::string> &options) override; | |||
| Status UnloadModel() override; | |||
| Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override; | |||
| Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override; | |||
| Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) override; | |||
| Status GetInputsInfo(std::vector<Tensor> *tensor_list) const override; | |||
| Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const override; | |||
| Status InitEnv(const std::unordered_map<std::string, std::string> &other_options); | |||
| Status FinalizeEnv(); | |||
| private: | |||
| std::shared_ptr<session::SessionBasic> session_impl_ = nullptr; | |||
| uint32_t graph_id_; | |||
| std::string device_type_; | |||
| int32_t device_id_ = 0; | |||
| #ifdef ENABLE_D | |||
| rtContext_t context_ = nullptr; | |||
| #endif | |||
| std::vector<tensor::TensorPtr> inputs_; | |||
| std::vector<tensor::TensorPtr> outputs_; | |||
| std::vector<std::string> input_names_; | |||
| std::vector<std::string> output_names_; | |||
| bool load_flag_ = false; | |||
| std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device); | |||
| Buffer ReadFile(const std::string &file); | |||
| static void RegAllOp(); | |||
| Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr); | |||
| Status CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const; | |||
| std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs); | |||
| Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs); | |||
| }; | |||
| API_REG_MODEL(AscendMS, MsModel); | |||
| } // namespace api | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_SESSION_SESSION_H | |||
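Taken together, the new backend can be exercised through the public C++ API; a hedged end-to-end sketch (the two-argument Model constructor, the file-name LoadModel overload, Predict, and the kMindIR enumerator are assumed from include/api, and the model path is hypothetical):

#include <map>
#include <string>
#include <vector>
#include "include/api/model.h"

int main() {
  using namespace mindspore::api;
  Model model(kDeviceTypeAscendMS, 0);                        // device-type constant exported in model.cc above
  if (model.LoadModel("/path/to/net.mindir", ModelType::kMindIR, {}) != SUCCESS) {
    return -1;
  }
  std::vector<Tensor> input_infos;
  model.GetInputsInfo(&input_infos);                          // names/dtypes/shapes filled by MsModel::GetInputsInfo
  std::map<std::string, Buffer> inputs;                       // key: tensor name, value: host buffer
  // ... fill one Buffer per entry of input_infos ...
  std::map<std::string, Buffer> outputs;
  if (model.Predict(inputs, &outputs) != SUCCESS) {
    return -1;
  }
  return 0;
}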