| @@ -35,6 +35,7 @@ set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/graph/graph.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/graph/graph_data.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model/model_impl.cc | |||
| ${API_MS_INFER_SRC} | |||
| ${API_ACL_SRC} | |||
| ${API_OPS_SRC} | |||
| @@ -156,32 +156,6 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s | |||
| return kSuccess; | |||
| } | |||
| Status AclModel::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| if (graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid data, graph_ is null."; | |||
| return kMCFailed; | |||
| } | |||
| if (graph_cell_ == nullptr) { | |||
| MS_LOG(WARNING) << "Model has not been built, it will be built with default options"; | |||
| Status ret = Build(); | |||
| if (ret != kSuccess) { | |||
| MS_LOG(ERROR) << "Build model failed."; | |||
| return ret; | |||
| } | |||
| } | |||
| MS_EXCEPTION_IF_NULL(graph_cell_); | |||
| Status ret = graph_cell_->Run(inputs, outputs); | |||
| if (ret != kSuccess) { | |||
| MS_LOG(ERROR) << "Run graph failed."; | |||
| return ret; | |||
| } | |||
| return kSuccess; | |||
| } | |||
| std::vector<MSTensor> AclModel::GetInputs() { | |||
| MS_EXCEPTION_IF_NULL(graph_cell_); | |||
| return graph_cell_->GetInputs(); | |||
| @@ -40,13 +40,10 @@ class AclModel : public ModelImpl { | |||
| Status Build() override; | |||
| Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) override; | |||
| Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override; | |||
| std::vector<MSTensor> GetInputs() override; | |||
| std::vector<MSTensor> GetOutputs() override; | |||
| private: | |||
| std::shared_ptr<GraphCell> graph_cell_; | |||
| ModelConverter model_converter_; | |||
| std::unique_ptr<AclModelOptions> options_; | |||
| std::map<std::string, std::shared_ptr<Graph>> dynamic_size_graph_map_; | |||
| @@ -208,86 +208,6 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) { | |||
| return buffer_ret; | |||
| } | |||
| Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) { | |||
| MultiProcess multi_process; | |||
| Buffer buffer_ret; | |||
| auto parent_process = [&model_data, &buffer_ret](MultiProcess *multi_process) -> Status { | |||
| MS_EXCEPTION_IF_NULL(multi_process); | |||
| // send original model to child | |||
| auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize()); | |||
| if (status != kSuccess) { | |||
| MS_LOG_ERROR << "Send original model to child process failed"; | |||
| return status; | |||
| } | |||
| // receive convert model result from child | |||
| CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * { | |||
| buffer_ret.ResizeData(msg_len); | |||
| return reinterpret_cast<uint8_t *>(buffer_ret.MutableData()); | |||
| }; | |||
| status = multi_process->ReceiveMsg(call); | |||
| if (status != kSuccess) { | |||
| MS_LOG_ERROR << "Receive result model from child process failed"; | |||
| return status; | |||
| } | |||
| return kSuccess; | |||
| }; | |||
| auto child_process = [this](MultiProcess *multi_process) -> Status { | |||
| MS_EXCEPTION_IF_NULL(multi_process); | |||
| // receive original model from parent | |||
| Buffer model; | |||
| CreateBufferCall call = [&model](size_t msg_len) -> uint8_t * { | |||
| model.ResizeData(msg_len); | |||
| return reinterpret_cast<uint8_t *>(model.MutableData()); | |||
| }; | |||
| auto status = multi_process->ReceiveMsg(call); | |||
| if (status != kSuccess) { | |||
| MS_LOG_ERROR << "Receive original model from parent process failed"; | |||
| return status; | |||
| } | |||
| Buffer model_result = LoadAscendIRInner(model); | |||
| if (model_result.DataSize() == 0) { | |||
| MS_LOG_ERROR << "Convert model from AIR to OM failed"; | |||
| return kMCFailed; | |||
| } | |||
| // send result model to parent | |||
| status = multi_process->SendMsg(model_result.Data(), model_result.DataSize()); | |||
| if (status != kSuccess) { | |||
| MS_LOG_ERROR << "Send result model to parent process failed"; | |||
| return status; | |||
| } | |||
| return kSuccess; | |||
| }; | |||
| auto status = multi_process.MainProcess(parent_process, child_process); | |||
| if (status != kSuccess) { | |||
| MS_LOG_ERROR << "Convert AIR model to OM model failed"; | |||
| } else { | |||
| MS_LOG_INFO << "Convert AIR model to OM model success"; | |||
| } | |||
| return buffer_ret; | |||
| } | |||
| Buffer ModelConverter::LoadMindIRInner(const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed."; | |||
| return Buffer(); | |||
| } | |||
| auto df_graph = ConvertFuncGraphToAIR(func_graph); | |||
| if (df_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed."; | |||
| return Buffer(); | |||
| } | |||
| std::map<std::string, std::string> init_options; | |||
| std::map<std::string, std::string> build_options; | |||
| if (options_ != nullptr) { | |||
| std::tie(init_options, build_options) = options_->GenAclOptions(); | |||
| } | |||
| auto om_data = BuildAirModel(df_graph, init_options, build_options); | |||
| return om_data; | |||
| } | |||
| Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) { | |||
| ge::Model load_model = ge::Model("loadmodel", "version2"); | |||
| ge::Status ret = | |||
| @@ -33,7 +33,6 @@ class ModelConverter { | |||
| ModelConverter() : options_(nullptr) {} | |||
| Buffer LoadMindIR(const FuncGraphPtr &func_graph); | |||
| Buffer LoadAscendIR(const Buffer &model_data); | |||
| void set_options(AclModelOptions *options) { options_ = options; } | |||
| @@ -43,7 +42,6 @@ class ModelConverter { | |||
| const std::map<std::string, std::string> &build_options); | |||
| AclModelOptions *options_; | |||
| Buffer LoadMindIRInner(const FuncGraphPtr &func_graph); | |||
| Buffer LoadAscendIRInner(const Buffer &model_data); | |||
| }; | |||
| } // namespace mindspore | |||
| @@ -59,7 +59,8 @@ Status MultiProcess::MainProcess(ProcessFuncCall parent_process, ProcessFuncCall | |||
| MS_LOG_ERROR << "Get shared memory failed"; | |||
| return ret; | |||
| } | |||
| shmat_data_addr_ = shmat_addr_ + sizeof(MessageFlag) * 2; | |||
| constexpr size_t kMsgStructNum = 2; | |||
| shmat_data_addr_ = shmat_addr_ + sizeof(MessageFlag) * kMsgStructNum; | |||
| shmat_data_max_size_ = memory_size_ - (shmat_data_addr_ - shmat_addr_); | |||
| MS_LOG_INFO << "Shm addr " << (uint64_t)shmat_addr_; | |||
| if (pid == 0) { | |||
| @@ -192,6 +193,7 @@ Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) { | |||
| void MultiProcess::HeartbeatThreadFunc(MultiProcess *multi_process) { multi_process->HeartbeatThreadFuncInner(); } | |||
| void MultiProcess::HeartbeatThreadFuncInner() { | |||
| constexpr uint64_t kOvertime = 1024; | |||
| uint64_t last_beat_cnt = 0; | |||
| uint64_t repeat_cnt = 0; | |||
| while (!stopped_) { | |||
| @@ -201,7 +203,7 @@ void MultiProcess::HeartbeatThreadFuncInner() { | |||
| break; | |||
| } | |||
| uint64_t heartbeat_gap = receive_msg_->heartbeat - last_beat_cnt; | |||
| if (heartbeat_gap > 0 && heartbeat_gap < 1024) { | |||
| if (heartbeat_gap > 0 && heartbeat_gap < kOvertime) { | |||
| last_beat_cnt = receive_msg_->heartbeat; | |||
| repeat_cnt = 0; | |||
| } else { | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "cxx_api/model/model_impl.h" | |||
| namespace mindspore { | |||
| Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| if (graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid data, graph_ is null."; | |||
| return kMCFailed; | |||
| } | |||
| if (graph_cell_ == nullptr) { | |||
| MS_LOG(WARNING) << "Model has not been built, it will be built with default options"; | |||
| Status ret = Build(); | |||
| if (ret != kSuccess) { | |||
| MS_LOG(ERROR) << "Build model failed."; | |||
| return ret; | |||
| } | |||
| } | |||
| MS_EXCEPTION_IF_NULL(graph_cell_); | |||
| Status ret = graph_cell_->Run(inputs, outputs); | |||
| if (ret != kSuccess) { | |||
| MS_LOG(ERROR) << "Run graph failed."; | |||
| return ret; | |||
| } | |||
| return kSuccess; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -37,7 +37,7 @@ class ModelImpl { | |||
| virtual Status Build() = 0; | |||
| virtual Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) = 0; | |||
| virtual Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0; | |||
| virtual Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs); | |||
| virtual std::vector<MSTensor> GetInputs() = 0; | |||
| virtual std::vector<MSTensor> GetOutputs() = 0; | |||
| @@ -58,8 +58,9 @@ class ModelImpl { | |||
| return graph_data->GetFuncGraph(); | |||
| } | |||
| std::shared_ptr<Graph> graph_; | |||
| std::shared_ptr<Context> model_context_; | |||
| std::shared_ptr<Graph> graph_ = nullptr; | |||
| std::shared_ptr<GraphCell> graph_cell_ = nullptr; | |||
| std::shared_ptr<Context> model_context_ = nullptr; | |||
| private: | |||
| friend class Model; | |||
| @@ -137,32 +137,6 @@ Status MsModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<st | |||
| return kSuccess; | |||
| } | |||
| Status MsModel::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| if (graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid data, graph_ is null."; | |||
| return kMCFailed; | |||
| } | |||
| if (graph_cell_ == nullptr) { | |||
| MS_LOG(INFO) << "Model has not been built, it will be built with default options"; | |||
| Status ret = Build(); | |||
| if (ret != kSuccess) { | |||
| MS_LOG(ERROR) << "Build model failed."; | |||
| return ret; | |||
| } | |||
| } | |||
| MS_EXCEPTION_IF_NULL(graph_cell_); | |||
| Status ret = graph_cell_->Run(inputs, outputs); | |||
| if (ret != kSuccess) { | |||
| MS_LOG(ERROR) << "Run graph failed."; | |||
| return ret; | |||
| } | |||
| return kSuccess; | |||
| } | |||
| std::vector<MSTensor> MsModel::GetInputs() { | |||
| MS_EXCEPTION_IF_NULL(graph_cell_); | |||
| return graph_cell_->GetInputs(); | |||
| @@ -41,8 +41,6 @@ class MsModel : public ModelImpl { | |||
| Status Build() override; | |||
| Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) override; | |||
| Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override; | |||
| std::vector<MSTensor> GetInputs() override; | |||
| std::vector<MSTensor> GetOutputs() override; | |||
| @@ -50,7 +48,6 @@ class MsModel : public ModelImpl { | |||
| std::shared_ptr<GraphCell> GenerateGraphCell(const std::vector<std::vector<int64_t>> &dims); | |||
| uint32_t GetDeviceID() const; | |||
| std::shared_ptr<GraphCell> graph_cell_; | |||
| std::map<std::string, std::shared_ptr<GraphCell>> dynamic_size_graph_map_; | |||
| }; | |||
| } // namespace mindspore | |||
| @@ -422,7 +422,7 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) { | |||
| FuncGraphManagerPtr manager_ptr = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager_ptr); | |||
| MapPrimTypeFuncGraph prim_graphs; | |||
| auto get_prim_graph = [&](const PrimitivePtr &prim, const AbstractFunctionPtr &type) { | |||
| auto get_prim_graph = [&prim_graphs](const PrimitivePtr &prim, const AbstractFunctionPtr &type) { | |||
| PrimTypePair prim_type = std::make_pair(prim, type); | |||
| if (prim_graphs.end() == prim_graphs.find(prim_type)) { | |||
| FuncGraphPtr g = std::make_shared<FuncGraph>(); | |||