
!29317 [MS][LITE] model pool: define predict task queue

Merge pull request !29317 from yefeng/202_model_pool_api
i-robot committed 4 years ago · commit 546f054d11
7 changed files with 186 additions and 84 deletions
1. mindspore/lite/src/cxx_api/model/model_parallel_runner.cc (+6 -2)
2. mindspore/lite/src/cxx_api/model/model_pool.cc (+15 -55)
3. mindspore/lite/src/cxx_api/model/model_pool.h (+4 -8)
4. mindspore/lite/src/cxx_api/model/model_thread.cc (+35 -10)
5. mindspore/lite/src/cxx_api/model/model_thread.h (+5 -9)
6. mindspore/lite/src/cxx_api/model/predict_task_queue.cc (+62 -0)
7. mindspore/lite/src/cxx_api/model/predict_task_queue.h (+59 -0)
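Taken together, the change replaces the ad-hoc queue and condition variables inside ModelPool with a shared PredictTaskQueue singleton: ModelPool::Predict becomes the producer, and each pooled ModelThread runs a consumer loop. An editor's outline of the new flow, condensed from the diffs below (not verbatim code):

  // Caller thread                           // Worker thread (ModelThread::Run)
  ModelPool::Predict(inputs, outputs):       while (!IsPredictTaskDone()):
    outputs->clear()                           task = GetPreDictTask()      // blocks on task_push_cond_
    PushPredictTask(task)   ---------------->  Predict(*task->inputs, task->outputs, ...)
    WaitUntilPredictActive(outputs)            ActiveTask()                 // signals task_pop_cond_
      // blocks until outputs filled  <-------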

mindspore/lite/src/cxx_api/model/model_parallel_runner.cc (+6 -2)

@@ -30,8 +30,12 @@ Status ModelParallelRunner::Init(const std::string &model_path, const std::strin
 }
 
 std::vector<MSTensor> ModelParallelRunner::GetInputs() {
-  std::vector<MSTensor> model_inputs = {};
-  return model_inputs;
+  auto inputs = ModelPool::GetInstance()->GetInputs();
+  if (inputs.empty()) {
+    MS_LOG(ERROR) << "model pool input is empty.";
+    return {};
+  }
+  return inputs;
 }
 
 Status ModelParallelRunner::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
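For orientation, a minimal caller of this API after the change might look like the sketch below. The header path, model file name, and config file name are assumptions for illustration; only Init, GetInputs, and Predict are taken from the diff, and error handling is abbreviated.

  #include <vector>
  #include "include/api/model_parallel_runner.h"  // assumed header location

  int RunOneInference() {
    mindspore::ModelParallelRunner runner;
    // Init builds the pool (one ModelThread per model instance).
    if (runner.Init("model.ms", "pool_config.cfg") != mindspore::kSuccess) {
      return -1;
    }
    // GetInputs now forwards to ModelPool::GetInputs() instead of returning
    // an empty vector (the stub removed by this diff).
    auto inputs = runner.GetInputs();
    // ... fill input tensor data here ...
    std::vector<mindspore::MSTensor> outputs;
    // Predict packages the request as a PredictTask and blocks until a
    // pooled worker fills `outputs`.
    return runner.Predict(inputs, &outputs) == mindspore::kSuccess ? 0 : -1;
  }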


mindspore/lite/src/cxx_api/model/model_pool.cc (+15 -55)

@@ -133,83 +133,43 @@ ModelPoolContex ModelPool::CreateModelContext(const std::string &config_path) {
   return model_pool_context;
 }
 
-void ModelPool::Run(std::shared_ptr<ModelThread> model) {
-  while (!model_pool_task_done_) {
-    std::unique_lock<std::mutex> data_lock(mtx_model_queue_);
-    while (model_data_queue_.empty() && !model_pool_task_done_) {
-      cv_in_data_.wait(data_lock);
-    }
-    if (model_pool_task_done_) {
-      cv_in_data_.notify_all();
-      break;
-    }
-    auto &model_data = model_data_queue_.front();
-    model_data_queue_.pop();
-    auto inputs = model_data->inputs;
-    auto *outputs = model_data->outputs;
-    auto before = model_data->before;
-    auto after = model_data->after;
-    cv_in_data_.notify_one();
-    data_lock.unlock();
-    auto status = model->Predict(*inputs, outputs, before, after);
-    if (status != kSuccess) {
-      MS_LOG(ERROR) << "model predict failed.";
-      return;
-    }
-    auto output_size = outputs->size();
-    for (size_t i = 0; i < output_size; i++) {
-      auto copy_tensor =
-        mindspore::MSTensor::CreateTensor(outputs->at(i).Name(), outputs->at(i).DataType(), outputs->at(i).Shape(),
-                                          outputs->at(i).MutableData(), outputs->at(i).DataSize());
-      outputs->erase(outputs->begin());
-      outputs->push_back(*copy_tensor);
-    }
-    cv_in_data_.notify_one();
-    cv_out_data_.notify_all();
+std::vector<MSTensor> ModelPool::GetInputs() {
+  if (model_inputs_.empty()) {
+    MS_LOG(ERROR) << "model input is empty.";
+    return {};
   }
+  return model_inputs_;
 }
 
 Status ModelPool::Init(const std::string &model_path, const std::string &config_path, const Key &dec_key,
                        const std::string &dec_mode) {
   auto model_pool_context = CreateModelContext(config_path);
   for (size_t i = 0; i < num_models_; i++) {
-    auto model = std::make_shared<ModelThread>();
-    auto status = model->Init(model_path, model_pool_context[i], dec_key, dec_mode);
-    model_thread_vec_.push_back(std::thread(&ModelPool::Run, this, model));
+    auto model_thread = std::make_shared<ModelThread>();
+    auto status = model_thread->Init(model_path, model_pool_context[i], dec_key, dec_mode);
+    if (model_inputs_.empty()) {
+      model_inputs_ = model_thread->GetInputs();
+    }
+    model_thread_vec_.push_back(std::thread(&ModelThread::Run, model_thread));
   }
   return kSuccess;
 }
 
 Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                           const MSKernelCallBack &before, const MSKernelCallBack &after) {
-  {
-    std::unique_lock<std::mutex> data_lock(mtx_data_queue_);
-    auto model_data = std::make_shared<ModelData>();
-    model_data->inputs = &inputs;
-    model_data->outputs = outputs;
-    model_data->before = before;
-    model_data->after = after;
-    model_data_queue_.push(model_data);
-    cv_in_data_.notify_one();
-  }
-  {
-    std::unique_lock<std::mutex> result_loack(mtx_data_queue_);
-    while (outputs->empty()) {
-      cv_out_data_.wait(result_loack);
-    }
-  }
+  outputs->clear();
+  auto predict_task = std::make_shared<PredictTask>(&inputs, outputs, before, after);
+  PredictTaskQueue::GetInstance()->PushPredictTask(predict_task);
+  PredictTaskQueue::GetInstance()->WaitUntilPredictActive(outputs);
   return kSuccess;
 }
 
 ModelPool::~ModelPool() {
-  model_pool_task_done_ = true;
-  cv_in_data_.notify_all();
   for (auto &th : model_thread_vec_) {
     if (th.joinable()) {
      th.join();
    }
  }
-  cv_in_data_.notify_one();
 }
 }  // namespace mindspore
 #endif
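Note the role of outputs->clear() in the new Predict: WaitUntilPredictActive waits while outputs->empty(), so an empty vector is the "result not ready" sentinel, and the worker's write followed by ActiveTask() is what releases the caller. The same rendezvous, distilled to standard-library primitives (names here are illustrative, not from the commit):

  #include <condition_variable>
  #include <mutex>
  #include <vector>

  std::mutex mtx;                   // stands in for mtx_predict_task_
  std::condition_variable done_cv;  // stands in for task_pop_cond_

  // Caller side: the emptiness of *outputs doubles as the wait predicate.
  void WaitForResult(std::vector<int> *outputs) {
    std::unique_lock<std::mutex> lock(mtx);
    done_cv.wait(lock, [outputs] { return !outputs->empty(); });
  }

  // Worker side: publish the result under the lock, then wake the caller,
  // mirroring ModelThread::Run() followed by PredictTaskQueue::ActiveTask().
  void PublishResult(std::vector<int> *outputs, int value) {
    {
      std::lock_guard<std::mutex> lock(mtx);
      outputs->push_back(value);
    }
    done_cv.notify_all();
  }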

mindspore/lite/src/cxx_api/model/model_pool.h (+4 -8)

@@ -25,6 +25,7 @@
 #include "include/api/status.h"
 #include "include/api/context.h"
 #include "src/cxx_api/model/model_thread.h"
+#include "src/cxx_api/model/predict_task_queue.h"
 namespace mindspore {
 class ModelPool {
  public:
@@ -34,6 +35,8 @@ class ModelPool {
   Status Init(const std::string &model_path, const std::string &config_path, const Key &dec_key = {},
               const std::string &dec_mode = kDecModeAesGcm);
 
+  std::vector<MSTensor> GetInputs();
+
   Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                  const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
 
@@ -45,15 +48,8 @@ class ModelPool {
   void SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num);
   ModelPoolContex CreateModelContext(const std::string &config_path);
 
-  std::mutex mtx_data_queue_;
-  std::mutex mtx_model_queue_;
-  std::condition_variable cv_out_data_;
-  std::condition_variable cv_in_data_;
-  std::condition_variable cv_model_;
-
   std::vector<std::thread> model_thread_vec_;
-  std::queue<std::shared_ptr<ModelData>> model_data_queue_;
-  bool model_pool_task_done_ = false;
+  std::vector<MSTensor> model_inputs_;
   size_t num_models_ = 5;
 };
 }  // namespace mindspore


mindspore/lite/src/cxx_api/model/model_thread.cc (+35 -10)

@@ -18,6 +18,33 @@
 #include "src/common/log.h"
 #include "src/common/utils.h"
 namespace mindspore {
+void ModelThread::Run() {
+  while (!PredictTaskQueue::GetInstance()->IsPredictTaskDone()) {
+    auto task = PredictTaskQueue::GetInstance()->GetPreDictTask();
+    if (task == nullptr) {
+      break;
+    }
+    auto inputs = task->inputs;
+    auto *outputs = task->outputs;
+    auto before = task->before;
+    auto after = task->after;
+    auto status = Predict(*inputs, outputs, before, after);
+    if (status != kSuccess) {
+      MS_LOG(ERROR) << "model predict failed.";
+      return;
+    }
+    auto output_size = outputs->size();
+    for (size_t i = 0; i < output_size; i++) {
+      auto copy_tensor =
+        mindspore::MSTensor::CreateTensor(outputs->at(i).Name(), outputs->at(i).DataType(), outputs->at(i).Shape(),
+                                          outputs->at(i).MutableData(), outputs->at(i).DataSize());
+      outputs->erase(outputs->begin());
+      outputs->push_back(*copy_tensor);
+    }
+    PredictTaskQueue::GetInstance()->ActiveTask();
+  }
+}
+
 Status ModelThread::Init(const std::string &model_path, const std::shared_ptr<Context> &model_context,
                          const Key &dec_key, const std::string &dec_mode) {
   model_ = std::make_shared<Model>();
@@ -30,14 +57,13 @@ Status ModelThread::Init(const std::string &model_path, const std::shared_ptr<Co
   return kSuccess;
 }
 
-Status ModelThread::ModelRun(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
-                             const MSKernelCallBack &before, const MSKernelCallBack &after) {
-  auto status = model_->Predict(inputs, outputs, before, after);
-  if (status != kSuccess) {
-    MS_LOG(ERROR) << "model predict failed.";
-    return status;
+std::vector<MSTensor> ModelThread::GetInputs() {
+  if (model_ == nullptr) {
+    MS_LOG(ERROR) << "model is nullptr in ModelThread.";
+    return {};
   }
-  return kSuccess;
+  auto inputs = model_->GetInputs();
+  return inputs;
 }
 
 std::pair<std::vector<std::vector<int64_t>>, bool> ModelThread::GetModelResize(
@@ -73,10 +99,9 @@ Status ModelThread::Predict(const std::vector<MSTensor> &inputs, std::vector<MST
       return kLiteError;
     }
   }
-
-  auto status = ModelRun(inputs, outputs, before, after);
+  auto status = model_->Predict(inputs, outputs, before, after);
   if (status != kSuccess) {
-    MS_LOG(ERROR) << "model predict failed in ModelPool.";
+    MS_LOG(ERROR) << "model predict failed.";
     return status;
   }
   return kSuccess;
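The copy loop in the new ModelThread::Run deserves a note: each iteration erases the front output tensor and appends a copy built with MSTensor::CreateTensor, so after output_size iterations the vector holds the same tensors in the same order, but detached from buffers the pooled model may reuse on its next run (that rationale is an editor's reading, not stated in the commit). The rotation idiom in isolation, with plain values standing in for MSTensor:

  #include <vector>

  // Rotate a vector in place, replacing each element with an independent
  // copy; `int` stands in for MSTensor and plain copying for CreateTensor.
  void RotateWithCopies(std::vector<int> *outputs) {
    const size_t n = outputs->size();
    for (size_t i = 0; i < n; i++) {
      int copy = outputs->front();       // copy the oldest element
      outputs->erase(outputs->begin());  // drop the original slot
      outputs->push_back(copy);          // append the standalone copy
    }
    // Order is preserved: after n erase/push_back pairs, element i is a
    // copy of the original element i.
  }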


mindspore/lite/src/cxx_api/model/model_thread.h (+5 -9)

@@ -25,14 +25,9 @@
 #include <utility>
 #include <memory>
 #include "include/api/model.h"
+#include "src/cxx_api/model/predict_task_queue.h"
 namespace mindspore {
 using ModelPoolContex = std::vector<std::shared_ptr<Context>>;
-struct ModelData {
-  const std::vector<MSTensor> *inputs;
-  std::vector<MSTensor> *outputs;
-  MSKernelCallBack before;
-  MSKernelCallBack after;
-};
 
 class ModelThread {
  public:
@@ -44,16 +39,17 @@
   Status Init(const std::string &model_path, const std::shared_ptr<Context> &model_context, const Key &dec_key = {},
               const std::string &dec_mode = kDecModeAesGcm);
 
+  std::vector<MSTensor> GetInputs();
+
   Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                  const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
 
+  void Run();
+
  private:
   std::pair<std::vector<std::vector<int64_t>>, bool> GetModelResize(const std::vector<MSTensor> &model_inputs,
                                                                     const std::vector<MSTensor> &inputs);
 
-  Status ModelRun(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
-                  const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
-
  private:
   std::shared_ptr<mindspore::Model> model_ = nullptr;
   std::mutex mtx_model_;


mindspore/lite/src/cxx_api/model/predict_task_queue.cc (+62 -0)

@@ -0,0 +1,62 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/cxx_api/model/predict_task_queue.h"
+namespace mindspore {
+PredictTaskQueue::~PredictTaskQueue() {
+  predict_task_done_ = true;
+  task_push_cond_.notify_all();
+}
+
+void PredictTaskQueue::WaitUntilPredictActive(std::vector<MSTensor> *outputs) {
+  std::unique_lock<std::mutex> result_lock(mtx_predict_task_);
+  while (outputs->empty()) {
+    task_pop_cond_.wait(result_lock);
+  }
+  return;
+}
+
+void PredictTaskQueue::ActiveTask() {
+  task_push_cond_.notify_all();
+  task_pop_cond_.notify_all();
+}
+
+PredictTaskQueue *PredictTaskQueue::GetInstance() {
+  static PredictTaskQueue instance;
+  return &instance;
+}
+
+void PredictTaskQueue::PushPredictTask(std::shared_ptr<PredictTask> task) {
+  std::unique_lock<std::mutex> data_lock(mtx_predict_task_);
+  predict_task_.push(task);
+  task_push_cond_.notify_all();
+}
+
+std::shared_ptr<PredictTask> PredictTaskQueue::GetPreDictTask() {
+  std::unique_lock<std::mutex> task_lock(mtx_model_queue_);
+  while (predict_task_.empty() && !predict_task_done_) {
+    task_push_cond_.wait(task_lock);
+  }
+  if (predict_task_done_) {
+    task_push_cond_.notify_all();
+    return nullptr;
+  }
+  auto predict_task = predict_task_.front();
+  predict_task_.pop();
+  task_push_cond_.notify_all();
+  return predict_task;
+}
+}  // namespace mindspore
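predict_task_queue.cc is a textbook single-queue producer-consumer built on mutexes and two condition variables (task_push_cond_ for consumers waiting on tasks, task_pop_cond_ for callers waiting on results), with predict_task_done_ as the shutdown flag. A self-contained generic version of the same consumer-side pattern, for reference (an illustrative sketch, not part of the commit):

  #include <condition_variable>
  #include <mutex>
  #include <queue>
  #include <utility>

  template <typename T>
  class BlockingQueue {
   public:
    void Push(T item) {
      {
        std::lock_guard<std::mutex> lock(mtx_);
        queue_.push(std::move(item));
      }
      not_empty_.notify_one();
    }

    // Blocks until an item arrives or Shutdown() is called; returns false
    // on shutdown, mirroring GetPreDictTask() returning nullptr.
    bool Pop(T *out) {
      std::unique_lock<std::mutex> lock(mtx_);
      not_empty_.wait(lock, [this] { return !queue_.empty() || done_; });
      if (done_ && queue_.empty()) return false;
      *out = std::move(queue_.front());
      queue_.pop();
      return true;
    }

    // Wakes all blocked consumers, as the PredictTaskQueue destructor does.
    void Shutdown() {
      {
        std::lock_guard<std::mutex> lock(mtx_);
        done_ = true;
      }
      not_empty_.notify_all();
    }

   private:
    std::queue<T> queue_;
    std::mutex mtx_;
    std::condition_variable not_empty_;
    bool done_ = false;
  };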

mindspore/lite/src/cxx_api/model/predict_task_queue.h (+59 -0)

@@ -0,0 +1,59 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_CXX_API_PREDICT_TASK_QUEUE_H_
+#define MINDSPORE_LITE_SRC_CXX_API_PREDICT_TASK_QUEUE_H_
+
+#include <queue>
+#include <mutex>
+#include <memory>
+#include <vector>
+#include <condition_variable>
+#include "include/api/types.h"
+#include "include/api/status.h"
+namespace mindspore {
+struct PredictTask {
+  PredictTask(const std::vector<MSTensor> *in, std::vector<MSTensor> *out, MSKernelCallBack before,
+              MSKernelCallBack after)
+      : inputs(in), outputs(out), before(before), after(after) {}
+  const std::vector<MSTensor> *inputs;
+  std::vector<MSTensor> *outputs;
+  MSKernelCallBack before;
+  MSKernelCallBack after;
+};
+
+class PredictTaskQueue {
+ public:
+  static PredictTaskQueue *GetInstance();
+  ~PredictTaskQueue();
+
+  void PushPredictTask(std::shared_ptr<PredictTask> task);
+  void WaitUntilPredictActive(std::vector<MSTensor> *outputs);
+  std::shared_ptr<PredictTask> GetPreDictTask();
+  void ActiveTask();
+  bool IsPredictTaskDone() { return predict_task_done_; }
+
+ private:
+  PredictTaskQueue() = default;
+  std::queue<std::shared_ptr<PredictTask>> predict_task_;
+
+  std::mutex mtx_predict_task_;
+  std::mutex mtx_model_queue_;
+  std::condition_variable task_pop_cond_;
+  std::condition_variable task_push_cond_;
+  bool predict_task_done_ = false;
+};
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_SRC_CXX_API_PREDICT_TASK_QUEUE_H_
