diff --git a/mindspore/lite/src/cxx_api/model/model_parallel_runner.cc b/mindspore/lite/src/cxx_api/model/model_parallel_runner.cc index b51b9f383e..a42d0fd3b4 100644 --- a/mindspore/lite/src/cxx_api/model/model_parallel_runner.cc +++ b/mindspore/lite/src/cxx_api/model/model_parallel_runner.cc @@ -30,8 +30,12 @@ Status ModelParallelRunner::Init(const std::string &model_path, const std::strin } std::vector ModelParallelRunner::GetInputs() { - std::vector model_inputs = {}; - return model_inputs; + auto inputs = ModelPool::GetInstance()->GetInputs(); + if (inputs.empty()) { + MS_LOG(ERROR) << "model pool input is empty."; + return {}; + } + return inputs; } Status ModelParallelRunner::Predict(const std::vector &inputs, std::vector *outputs, diff --git a/mindspore/lite/src/cxx_api/model/model_pool.cc b/mindspore/lite/src/cxx_api/model/model_pool.cc index 457fc7b501..c34e2172da 100644 --- a/mindspore/lite/src/cxx_api/model/model_pool.cc +++ b/mindspore/lite/src/cxx_api/model/model_pool.cc @@ -133,83 +133,43 @@ ModelPoolContex ModelPool::CreateModelContext(const std::string &config_path) { return model_pool_context; } -void ModelPool::Run(std::shared_ptr model) { - while (!model_pool_task_done_) { - std::unique_lock data_lock(mtx_model_queue_); - while (model_data_queue_.empty() && !model_pool_task_done_) { - cv_in_data_.wait(data_lock); - } - if (model_pool_task_done_) { - cv_in_data_.notify_all(); - break; - } - auto &model_data = model_data_queue_.front(); - model_data_queue_.pop(); - auto inputs = model_data->inputs; - auto *outputs = model_data->outputs; - auto before = model_data->before; - auto after = model_data->after; - cv_in_data_.notify_one(); - data_lock.unlock(); - auto status = model->Predict(*inputs, outputs, before, after); - if (status != kSuccess) { - MS_LOG(ERROR) << "model predict failed."; - return; - } - auto output_size = outputs->size(); - for (size_t i = 0; i < output_size; i++) { - auto copy_tensor = - mindspore::MSTensor::CreateTensor(outputs->at(i).Name(), outputs->at(i).DataType(), outputs->at(i).Shape(), - outputs->at(i).MutableData(), outputs->at(i).DataSize()); - outputs->erase(outputs->begin()); - outputs->push_back(*copy_tensor); - } - cv_in_data_.notify_one(); - cv_out_data_.notify_all(); +std::vector ModelPool::GetInputs() { + if (model_inputs_.empty()) { + MS_LOG(ERROR) << "model input is empty."; + return {}; } + return model_inputs_; } Status ModelPool::Init(const std::string &model_path, const std::string &config_path, const Key &dec_key, const std::string &dec_mode) { auto model_pool_context = CreateModelContext(config_path); for (size_t i = 0; i < num_models_; i++) { - auto model = std::make_shared(); - auto status = model->Init(model_path, model_pool_context[i], dec_key, dec_mode); - model_thread_vec_.push_back(std::thread(&ModelPool::Run, this, model)); + auto model_thread = std::make_shared(); + auto status = model_thread->Init(model_path, model_pool_context[i], dec_key, dec_mode); + if (model_inputs_.empty()) { + model_inputs_ = model_thread->GetInputs(); + } + model_thread_vec_.push_back(std::thread(&ModelThread::Run, model_thread)); } return kSuccess; } Status ModelPool::Predict(const std::vector &inputs, std::vector *outputs, const MSKernelCallBack &before, const MSKernelCallBack &after) { - { - std::unique_lock data_lock(mtx_data_queue_); - auto model_data = std::make_shared(); - model_data->inputs = &inputs; - model_data->outputs = outputs; - model_data->before = before; - model_data->after = after; - model_data_queue_.push(model_data); - cv_in_data_.notify_one(); - } - { - std::unique_lock result_loack(mtx_data_queue_); - while (outputs->empty()) { - cv_out_data_.wait(result_loack); - } - } + outputs->clear(); + auto predict_task = std::make_shared(&inputs, outputs, before, after); + PredictTaskQueue::GetInstance()->PushPredictTask(predict_task); + PredictTaskQueue::GetInstance()->WaitUntilPredictActive(outputs); return kSuccess; } ModelPool::~ModelPool() { - model_pool_task_done_ = true; - cv_in_data_.notify_all(); for (auto &th : model_thread_vec_) { if (th.joinable()) { th.join(); } } - cv_in_data_.notify_one(); } } // namespace mindspore #endif diff --git a/mindspore/lite/src/cxx_api/model/model_pool.h b/mindspore/lite/src/cxx_api/model/model_pool.h index 7b3c87274a..874afbfd35 100644 --- a/mindspore/lite/src/cxx_api/model/model_pool.h +++ b/mindspore/lite/src/cxx_api/model/model_pool.h @@ -25,6 +25,7 @@ #include "include/api/status.h" #include "include/api/context.h" #include "src/cxx_api/model/model_thread.h" +#include "src/cxx_api/model/predict_task_queue.h" namespace mindspore { class ModelPool { public: @@ -34,6 +35,8 @@ class ModelPool { Status Init(const std::string &model_path, const std::string &config_path, const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm); + std::vector GetInputs(); + Status Predict(const std::vector &inputs, std::vector *outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr); @@ -45,15 +48,8 @@ class ModelPool { void SetBindStrategy(std::vector> *all_model_bind_list, int thread_num); ModelPoolContex CreateModelContext(const std::string &config_path); - std::mutex mtx_data_queue_; - std::mutex mtx_model_queue_; - std::condition_variable cv_out_data_; - std::condition_variable cv_in_data_; - std::condition_variable cv_model_; - std::vector model_thread_vec_; - std::queue> model_data_queue_; - bool model_pool_task_done_ = false; + std::vector model_inputs_; size_t num_models_ = 5; }; } // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/model/model_thread.cc b/mindspore/lite/src/cxx_api/model/model_thread.cc index b60953ea43..f44761d3e4 100644 --- a/mindspore/lite/src/cxx_api/model/model_thread.cc +++ b/mindspore/lite/src/cxx_api/model/model_thread.cc @@ -18,6 +18,33 @@ #include "src/common/log.h" #include "src/common/utils.h" namespace mindspore { +void ModelThread::Run() { + while (!PredictTaskQueue::GetInstance()->IsPredictTaskDone()) { + auto task = PredictTaskQueue::GetInstance()->GetPreDictTask(); + if (task == nullptr) { + break; + } + auto inputs = task->inputs; + auto *outputs = task->outputs; + auto before = task->before; + auto after = task->after; + auto status = Predict(*inputs, outputs, before, after); + if (status != kSuccess) { + MS_LOG(ERROR) << "model predict failed."; + return; + } + auto output_size = outputs->size(); + for (size_t i = 0; i < output_size; i++) { + auto copy_tensor = + mindspore::MSTensor::CreateTensor(outputs->at(i).Name(), outputs->at(i).DataType(), outputs->at(i).Shape(), + outputs->at(i).MutableData(), outputs->at(i).DataSize()); + outputs->erase(outputs->begin()); + outputs->push_back(*copy_tensor); + } + PredictTaskQueue::GetInstance()->ActiveTask(); + } +} + Status ModelThread::Init(const std::string &model_path, const std::shared_ptr &model_context, const Key &dec_key, const std::string &dec_mode) { model_ = std::make_shared(); @@ -30,14 +57,13 @@ Status ModelThread::Init(const std::string &model_path, const std::shared_ptr &inputs, std::vector *outputs, - const MSKernelCallBack &before, const MSKernelCallBack &after) { - auto status = model_->Predict(inputs, outputs, before, after); - if (status != kSuccess) { - MS_LOG(ERROR) << "model predict failed."; - return status; +std::vector ModelThread::GetInputs() { + if (model_ == nullptr) { + MS_LOG(ERROR) << "model is nullptr in ModelThread."; + return {}; } - return kSuccess; + auto inputs = model_->GetInputs(); + return inputs; } std::pair>, bool> ModelThread::GetModelResize( @@ -73,10 +99,9 @@ Status ModelThread::Predict(const std::vector &inputs, std::vectorPredict(inputs, outputs, before, after); if (status != kSuccess) { - MS_LOG(ERROR) << "model predict failed in ModelPool."; + MS_LOG(ERROR) << "model predict failed."; return status; } return kSuccess; diff --git a/mindspore/lite/src/cxx_api/model/model_thread.h b/mindspore/lite/src/cxx_api/model/model_thread.h index 78b13367de..2ae65a7927 100644 --- a/mindspore/lite/src/cxx_api/model/model_thread.h +++ b/mindspore/lite/src/cxx_api/model/model_thread.h @@ -25,14 +25,9 @@ #include #include #include "include/api/model.h" +#include "src/cxx_api/model/predict_task_queue.h" namespace mindspore { using ModelPoolContex = std::vector>; -struct ModelData { - const std::vector *inputs; - std::vector *outputs; - MSKernelCallBack before; - MSKernelCallBack after; -}; class ModelThread { public: @@ -44,16 +39,17 @@ class ModelThread { Status Init(const std::string &model_path, const std::shared_ptr &model_context, const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm); + std::vector GetInputs(); + Status Predict(const std::vector &inputs, std::vector *outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr); + void Run(); + private: std::pair>, bool> GetModelResize(const std::vector &model_inputs, const std::vector &inputs); - Status ModelRun(const std::vector &inputs, std::vector *outputs, - const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr); - private: std::shared_ptr model_ = nullptr; std::mutex mtx_model_; diff --git a/mindspore/lite/src/cxx_api/model/predict_task_queue.cc b/mindspore/lite/src/cxx_api/model/predict_task_queue.cc new file mode 100644 index 0000000000..1b2bec2255 --- /dev/null +++ b/mindspore/lite/src/cxx_api/model/predict_task_queue.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/cxx_api/model/predict_task_queue.h" +namespace mindspore { +PredictTaskQueue::~PredictTaskQueue() { + predict_task_done_ = true; + task_push_cond_.notify_all(); +} + +void PredictTaskQueue::WaitUntilPredictActive(std::vector *outputs) { + std::unique_lock result_lock(mtx_predict_task_); + while (outputs->empty()) { + task_pop_cond_.wait(result_lock); + } + return; +} + +void PredictTaskQueue::ActiveTask() { + task_push_cond_.notify_all(); + task_pop_cond_.notify_all(); +} + +PredictTaskQueue *PredictTaskQueue::GetInstance() { + static PredictTaskQueue instance; + return &instance; +} + +void PredictTaskQueue::PushPredictTask(std::shared_ptr task) { + std::unique_lock data_lock(mtx_predict_task_); + predict_task_.push(task); + task_push_cond_.notify_all(); +} + +std::shared_ptr PredictTaskQueue::GetPreDictTask() { + std::unique_lock task_lock(mtx_model_queue_); + while (predict_task_.empty() && !predict_task_done_) { + task_push_cond_.wait(task_lock); + } + if (predict_task_done_) { + task_push_cond_.notify_all(); + return nullptr; + } + auto predict_task = predict_task_.front(); + predict_task_.pop(); + task_push_cond_.notify_all(); + return predict_task; +} +} // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/model/predict_task_queue.h b/mindspore/lite/src/cxx_api/model/predict_task_queue.h new file mode 100644 index 0000000000..8dfe117c0c --- /dev/null +++ b/mindspore/lite/src/cxx_api/model/predict_task_queue.h @@ -0,0 +1,59 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_CXX_API_PREDICT_TASK_QUEUE_H_ +#define MINDSPORE_LITE_SRC_CXX_API_PREDICT_TASK_QUEUE_H_ + +#include +#include +#include +#include +#include +#include "include/api/types.h" +#include "include/api/status.h" +namespace mindspore { +struct PredictTask { + PredictTask(const std::vector *in, std::vector *out, MSKernelCallBack before, + MSKernelCallBack after) + : inputs(in), outputs(out), before(before), after(after) {} + const std::vector *inputs; + std::vector *outputs; + MSKernelCallBack before; + MSKernelCallBack after; +}; + +class PredictTaskQueue { + public: + static PredictTaskQueue *GetInstance(); + ~PredictTaskQueue(); + + void PushPredictTask(std::shared_ptr task); + void WaitUntilPredictActive(std::vector *outputs); + std::shared_ptr GetPreDictTask(); + void ActiveTask(); + bool IsPredictTaskDone() { return predict_task_done_; } + + private: + PredictTaskQueue() = default; + std::queue> predict_task_; + + std::mutex mtx_predict_task_; + std::mutex mtx_model_queue_; + std::condition_variable task_pop_cond_; + std::condition_variable task_push_cond_; + bool predict_task_done_ = false; +}; +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_CXX_API_PREDICT_TASK_QUEUE_H_