From: @ling_qiao_min Reviewed-by: Signed-off-by:pull/14542/MERGE
| @@ -28,6 +28,7 @@ option(ENABLE_VERBOSE "" off) | |||||
| option(ENABLE_SSE "if x86_64 support SSE instruction set" off) | option(ENABLE_SSE "if x86_64 support SSE instruction set" off) | ||||
| option(ENABLE_AVX "if x86_64 support SSE instruction set" off) | option(ENABLE_AVX "if x86_64 support SSE instruction set" off) | ||||
| option(ENABLE_MINDRT "if support mindrt" on) | option(ENABLE_MINDRT "if support mindrt" on) | ||||
| option(SUBGRAPH_SPLIT "if support sub graph split" off) | |||||
| set(DIR_PREFIX mindspore-lite) | set(DIR_PREFIX mindspore-lite) | ||||
| set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION}) | set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION}) | ||||
| @@ -57,6 +58,9 @@ else() | |||||
| set(PROCESS_UNIT cpu) | set(PROCESS_UNIT cpu) | ||||
| endif() | endif() | ||||
| if(SUBGRAPH_SPLIT) | |||||
| add_compile_definitions(SUBGRAPH_SPLIT) | |||||
| endif() | |||||
| if(SUPPORT_NPU) | if(SUPPORT_NPU) | ||||
| set(DDK_PATH "$ENV{HWHIAI_DDK}/ddk/ai_ddk_lib") | set(DDK_PATH "$ENV{HWHIAI_DDK}/ddk/ai_ddk_lib") | ||||
| @@ -132,6 +132,7 @@ set(LITE_SRC | |||||
| ${LITE_DIR}/src/common/tensor_util.cc | ${LITE_DIR}/src/common/tensor_util.cc | ||||
| ${LITE_DIR}/src/runtime/infer_manager.cc | ${LITE_DIR}/src/runtime/infer_manager.cc | ||||
| ${LITE_DIR}/src/lite_model.cc | ${LITE_DIR}/src/lite_model.cc | ||||
| ${LITE_DIR}/src/sub_graph_split.cc | |||||
| ${LITE_DIR}/src/tensorlist.cc | ${LITE_DIR}/src/tensorlist.cc | ||||
| ${LITE_DIR}/src/tensor.cc | ${LITE_DIR}/src/tensor.cc | ||||
| ${LITE_DIR}/src/dequant.cc | ${LITE_DIR}/src/dequant.cc | ||||
| @@ -59,6 +59,7 @@ set(LITE_SRC | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc | ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc | ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc | ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_split.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc | ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc | ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc | ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc | ||||
| @@ -43,6 +43,16 @@ void LiteKernel::FreeWorkspace() { | |||||
| free(workspace_); | free(workspace_); | ||||
| workspace_ = nullptr; | workspace_ = nullptr; | ||||
| } | } | ||||
| int LiteKernel::DecOutTensorRefCount() { | |||||
| for (auto *tensor : this->out_tensors_) { | |||||
| tensor->set_ref_count(tensor->ref_count() - 1); | |||||
| if (0 >= tensor->ref_count()) { | |||||
| tensor->FreeData(); | |||||
| } | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| #endif | #endif | ||||
| bool LiteKernel::IsReady(const std::vector<lite::Tensor *> &scope_tensors) { | bool LiteKernel::IsReady(const std::vector<lite::Tensor *> &scope_tensors) { | ||||
| return std::all_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *in_tensor) { | return std::all_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *in_tensor) { | ||||
| @@ -66,16 +76,6 @@ void LiteKernel::InitOutTensorInitRefCount() { | |||||
| } | } | ||||
| } | } | ||||
| int LiteKernel::DecOutTensorRefCount() { | |||||
| for (auto *tensor : this->out_tensors_) { | |||||
| tensor->set_ref_count(tensor->ref_count() - 1); | |||||
| if (0 >= tensor->ref_count()) { | |||||
| tensor->FreeData(); | |||||
| } | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| int LiteKernel::FreeInWorkTensor() const { | int LiteKernel::FreeInWorkTensor() const { | ||||
| for (auto &in_tensor : this->in_tensors_) { | for (auto &in_tensor : this->in_tensors_) { | ||||
| MS_ASSERT(in_tensor != nullptr); | MS_ASSERT(in_tensor != nullptr); | ||||
| @@ -35,7 +35,16 @@ static constexpr int kPerTensor = 1; | |||||
| static constexpr size_t kPerBatch = 3; | static constexpr size_t kPerBatch = 3; | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| enum KERNEL_ARCH { kCPU, kGPU, kAPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU }; | |||||
| enum KERNEL_ARCH { | |||||
| kCPU, | |||||
| kGPU, | |||||
| kAPU, | |||||
| kNPU, | |||||
| kALL, /* Support GPU NPU CPU */ | |||||
| kKernelArch_MIN = kCPU, | |||||
| kKernelArch_MAX = kALL | |||||
| }; | |||||
| struct KernelKey { | struct KernelKey { | ||||
| KERNEL_ARCH arch; | KERNEL_ARCH arch; | ||||
| TypeId data_type; | TypeId data_type; | ||||
| @@ -161,8 +170,6 @@ class LiteKernel { | |||||
| virtual void InitOutTensorInitRefCount(); | virtual void InitOutTensorInitRefCount(); | ||||
| int DecOutTensorRefCount(); | |||||
| virtual int FreeInWorkTensor() const; | virtual int FreeInWorkTensor() const; | ||||
| KernelKey desc() const { return desc_; } | KernelKey desc() const { return desc_; } | ||||
| @@ -171,6 +178,8 @@ class LiteKernel { | |||||
| SubGraphType subgraph_type() const { return this->subgraph_type_; } | SubGraphType subgraph_type() const { return this->subgraph_type_; } | ||||
| const lite::InnerContext *context() const { return this->context_; } | |||||
| virtual std::string ToString() const; | virtual std::string ToString() const; | ||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| @@ -179,6 +188,7 @@ class LiteKernel { | |||||
| static void AllocWorkspace(size_t size); | static void AllocWorkspace(size_t size); | ||||
| static void FreeWorkspace(); | static void FreeWorkspace(); | ||||
| void *workspace() { return workspace_; } | void *workspace() { return workspace_; } | ||||
| int DecOutTensorRefCount(); | |||||
| #endif | #endif | ||||
| protected: | protected: | ||||
| @@ -32,7 +32,7 @@ int LiteOpActor::CompileArrow() { | |||||
| } | } | ||||
| } | } | ||||
| if (to_input_index == -1) { | if (to_input_index == -1) { | ||||
| break; | |||||
| continue; | |||||
| } | } | ||||
| auto id = out->name() + this->GetAID().Url(); | auto id = out->name() + this->GetAID().Url(); | ||||
| auto arrow = std::make_shared<OpArrow>(i, id, to_input_index); | auto arrow = std::make_shared<OpArrow>(i, id, to_input_index); | ||||
| @@ -41,12 +41,19 @@ int LiteOpActor::CompileArrow() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| output_op_arrows_.emplace_back(std::move(arrow)); | output_op_arrows_.emplace_back(std::move(arrow)); | ||||
| break; | |||||
| } | } | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void LiteOpActor::AsyncOutput(OpContext<Tensor> *context) { | |||||
| for (auto op_arrow : output_op_arrows_) { | |||||
| auto data = context->outputData_->at(op_arrow->from_output_index_); | |||||
| Async(op_arrow->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data, context); | |||||
| } | |||||
| return; | |||||
| } | |||||
| void LiteOpActor::SetOutputData(OpContext<Tensor> *context) { | void LiteOpActor::SetOutputData(OpContext<Tensor> *context) { | ||||
| auto size = context->outputData_->size(); | auto size = context->outputData_->size(); | ||||
| MS_ASSERT(size == context->results_->size()); | MS_ASSERT(size == context->results_->size()); | ||||
| @@ -50,6 +50,7 @@ class LiteOpActor : public OpActor<lite::Tensor> { | |||||
| return; | return; | ||||
| } | } | ||||
| input_op_datas_.erase(op_uuid); | input_op_datas_.erase(op_uuid); | ||||
| AsyncOutput(context); | |||||
| SetOutputData(context); | SetOutputData(context); | ||||
| } | } | ||||
| void Init() { | void Init() { | ||||
| @@ -83,6 +84,7 @@ class LiteOpActor : public OpActor<lite::Tensor> { | |||||
| private: | private: | ||||
| void SetOutputData(OpContext<Tensor> *context); | void SetOutputData(OpContext<Tensor> *context); | ||||
| void AsyncOutput(OpContext<Tensor> *context); | |||||
| kernel::LiteKernel *kernel_; | kernel::LiteKernel *kernel_; | ||||
| }; | }; | ||||
| @@ -108,6 +108,12 @@ void LiteModel::Free() { | |||||
| tensor_buf = nullptr; | tensor_buf = nullptr; | ||||
| } | } | ||||
| attr_tensor_bufs_.resize(0); | attr_tensor_bufs_.resize(0); | ||||
| for (auto &node_buf : node_bufs_) { | |||||
| free(node_buf); | |||||
| node_buf = nullptr; | |||||
| } | |||||
| node_bufs_.resize(0); | |||||
| } | } | ||||
| void LiteModel::Destroy() { | void LiteModel::Destroy() { | ||||
| @@ -192,6 +192,7 @@ class LiteModel : public Model { | |||||
| public: | public: | ||||
| size_t buf_size_ = 0; | size_t buf_size_ = 0; | ||||
| std::vector<char *> node_bufs_; | |||||
| protected: | protected: | ||||
| std::vector<char *> attr_tensor_bufs_; | std::vector<char *> attr_tensor_bufs_; | ||||
| @@ -399,8 +399,15 @@ int LiteSession::CompileGraph(Model *model) { | |||||
| #endif | #endif | ||||
| InitGraphInOutTensors(model); | InitGraphInOutTensors(model); | ||||
| ret = PrepareKernels(model); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Prepare kernels failed: " << ret; | |||||
| is_running_.store(false); | |||||
| return ret; | |||||
| } | |||||
| #ifdef ENABLE_MINDRT | #ifdef ENABLE_MINDRT | ||||
| if (context_->IsCpuEnabled() && !context_->IsGpuEnabled() && !context_->IsNpuEnabled() && kernels_.size() == 1) { | |||||
| if (kernels_.size() == 1) { | |||||
| executor_ = new (std::nothrow) MindrtExecutor(); | executor_ = new (std::nothrow) MindrtExecutor(); | ||||
| } else { | } else { | ||||
| executor_ = new (std::nothrow) Executor(); | executor_ = new (std::nothrow) Executor(); | ||||
| @@ -420,16 +427,10 @@ int LiteSession::CompileGraph(Model *model) { | |||||
| is_running_.store(false); | is_running_.store(false); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| ret = PrepareKernels(model); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Prepare kernels failed: " << ret; | |||||
| is_running_.store(false); | |||||
| return ret; | |||||
| } | |||||
| is_running_.store(false); | is_running_.store(false); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | |||||
| } // namespace lite | |||||
| int LiteSession::PrepareKernels(Model *model) { | int LiteSession::PrepareKernels(Model *model) { | ||||
| std::vector<kernel::LiteKernel *> all_kernels; | std::vector<kernel::LiteKernel *> all_kernels; | ||||
| @@ -102,10 +102,7 @@ int NPUExecutor::Run(const std::vector<Tensor *> &in_tensors, const std::vector< | |||||
| memcpy(npu_input_tensors_[i]->GetBuffer(), data, in_tensors[index]->Size()); | memcpy(npu_input_tensors_[i]->GetBuffer(), data, in_tensors[index]->Size()); | ||||
| inputs_visited[index] = true; | inputs_visited[index] = true; | ||||
| in_tensors[index]->set_ref_count(in_tensors[index]->ref_count() - 1); | |||||
| if (in_tensors[index]->ref_count() <= 0) { | |||||
| in_tensors[index]->FreeData(); | |||||
| } | |||||
| in_tensors[index]->DecRefCount(); | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| @@ -38,6 +38,7 @@ class SubGraphNpuKernel : public SubGraphKernel { | |||||
| const lite::InnerContext *ctx = nullptr, lite::NPUManager *npu_manager = nullptr) | const lite::InnerContext *ctx = nullptr, lite::NPUManager *npu_manager = nullptr) | ||||
| : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx), npu_manager_(npu_manager) { | : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx), npu_manager_(npu_manager) { | ||||
| subgraph_type_ = kNpuSubGraph; | subgraph_type_ = kNpuSubGraph; | ||||
| desc_.arch = kernel::KERNEL_ARCH::kNPU; | |||||
| } | } | ||||
| ~SubGraphNpuKernel() override; | ~SubGraphNpuKernel() override; | ||||
| @@ -70,7 +70,7 @@ class DefaultAllocator : public Allocator { | |||||
| std::multimap<size_t, MemBuf *> freeList_; | std::multimap<size_t, MemBuf *> freeList_; | ||||
| // 6 is empirical value | // 6 is empirical value | ||||
| int shiftFactor_ = 6; | int shiftFactor_ = 6; | ||||
| bool lockFlag_ = false; | |||||
| bool lockFlag_ = true; | |||||
| }; | }; | ||||
| constexpr int64_t MAX_MALLOC_SIZE = static_cast<size_t>(2000) * 1024 * 1024; | constexpr int64_t MAX_MALLOC_SIZE = static_cast<size_t>(2000) * 1024 * 1024; | ||||
| @@ -34,6 +34,7 @@ class OpenCLSubGraph : public SubGraphKernel { | |||||
| : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx) { | : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx) { | ||||
| ocl_runtime_ = ocl_runtime_wrap_.GetInstance(); | ocl_runtime_ = ocl_runtime_wrap_.GetInstance(); | ||||
| subgraph_type_ = kGpuSubGraph; | subgraph_type_ = kGpuSubGraph; | ||||
| desc_.arch = kernel::KERNEL_ARCH::kGPU; | |||||
| this->name_ = "GpuSubGraph"; | this->name_ = "GpuSubGraph"; | ||||
| nodes_set_.insert(nodes.begin(), nodes.end()); | nodes_set_.insert(nodes.begin(), nodes.end()); | ||||
| all_kernels_infer_done_ = std::all_of(nodes_.begin(), nodes_.end(), [](const kernel::LiteKernel *kernel) { | all_kernels_infer_done_ = std::all_of(nodes_.begin(), nodes_.end(), [](const kernel::LiteKernel *kernel) { | ||||
| @@ -30,6 +30,7 @@ | |||||
| #include "src/common/version_manager.h" | #include "src/common/version_manager.h" | ||||
| #include "src/common/prim_util.h" | #include "src/common/prim_util.h" | ||||
| #include "src/runtime/infer_manager.h" | #include "src/runtime/infer_manager.h" | ||||
| #include "src/sub_graph_split.h" | |||||
| #include "src/dequant.h" | #include "src/dequant.h" | ||||
| #include "nnacl/matmul_parameter.h" | #include "nnacl/matmul_parameter.h" | ||||
| #if GPU_OPENCL | #if GPU_OPENCL | ||||
| @@ -71,6 +72,12 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) { | |||||
| } | } | ||||
| this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_); | this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_); | ||||
| #ifdef SUBGRAPH_SPLIT | |||||
| auto search_sub_graph = SearchSubGraph(src_model_, this->graph_output_node_indexes_); | |||||
| search_sub_graph.SubGraphSplitByOutput(); | |||||
| #endif | |||||
| bool infer_shape_interrupt = false; | bool infer_shape_interrupt = false; | ||||
| auto ret = InferSubGraphShape(kMainSubGraphIndex, &infer_shape_interrupt); | auto ret = InferSubGraphShape(kMainSubGraphIndex, &infer_shape_interrupt); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| @@ -89,7 +96,11 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) { | |||||
| MS_LOG(ERROR) << "Schedule run pass failed."; | MS_LOG(ERROR) << "Schedule run pass failed."; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| ret = ConstructSubGraphs(dst_kernels); | |||||
| auto src_kernel = *dst_kernels; | |||||
| dst_kernels->clear(); | |||||
| std::map<const kernel::LiteKernel *, bool> is_kernel_finish; | |||||
| ret = ConstructSubGraphs(src_kernel, dst_kernels, &is_kernel_finish); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "ConstructSubGraphs failed."; | MS_LOG(ERROR) << "ConstructSubGraphs failed."; | ||||
| return ret; | return ret; | ||||
| @@ -473,6 +484,14 @@ kernel::LiteKernel *Scheduler::SchedulePartialToKernel(const lite::Model::Node * | |||||
| MS_LOG(ERROR) << "Schedule partial failed, name: " << src_node->name_; | MS_LOG(ERROR) << "Schedule partial failed, name: " << src_node->name_; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| FindAllInoutKernels(sub_kernels); | |||||
| ret = RunPass(&sub_kernels); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "SchedulePartialToKernel run pass failed."; | |||||
| return nullptr; | |||||
| } | |||||
| auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(sub_kernels.front()); | auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(sub_kernels.front()); | ||||
| auto subgraph = CreateSubGraphKernel(sub_kernels, &in_tensors, &out_tensors, cur_sub_graph_type); | auto subgraph = CreateSubGraphKernel(sub_kernels, &in_tensors, &out_tensors, cur_sub_graph_type); | ||||
| subgraph->set_name("subgraph_" + src_node->name_); | subgraph->set_name("subgraph_" + src_node->name_); | ||||
| @@ -602,35 +621,33 @@ std::vector<kernel::LiteKernel *> Scheduler::FindAllSubGraphKernels( | |||||
| return sub_kernels; | return sub_kernels; | ||||
| } | } | ||||
| int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) { | |||||
| auto old_kernels = *kernels; | |||||
| kernels->clear(); | |||||
| std::map<const kernel::LiteKernel *, bool> is_kernel_finish; | |||||
| for (auto kernel : old_kernels) { | |||||
| is_kernel_finish[kernel] = false; | |||||
| int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> src_kernel, | |||||
| std::vector<kernel::LiteKernel *> *dst_kernel, | |||||
| std::map<const kernel::LiteKernel *, bool> *is_kernel_finish) { | |||||
| for (auto kernel : src_kernel) { | |||||
| (*is_kernel_finish)[kernel] = false; | |||||
| } | } | ||||
| while (true) { | while (true) { | ||||
| auto head_kernel_iter = std::find_if(old_kernels.begin(), old_kernels.end(), [&](const kernel::LiteKernel *kernel) { | |||||
| auto head_kernel_iter = std::find_if(src_kernel.begin(), src_kernel.end(), [&](const kernel::LiteKernel *kernel) { | |||||
| auto kernel_inputs = kernel->in_kernels(); | auto kernel_inputs = kernel->in_kernels(); | ||||
| if (is_kernel_finish[kernel]) { | |||||
| if ((*is_kernel_finish)[kernel]) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| // when merge is removed, this if is removed automatically | // when merge is removed, this if is removed automatically | ||||
| if (kernel->Type() == schema::PrimitiveType_Merge) { | if (kernel->Type() == schema::PrimitiveType_Merge) { | ||||
| return MergeOpIsReady(kernel, is_kernel_finish); | |||||
| return MergeOpIsReady(kernel, (*is_kernel_finish)); | |||||
| } else { | } else { | ||||
| return std::all_of(kernel_inputs.begin(), kernel_inputs.end(), | return std::all_of(kernel_inputs.begin(), kernel_inputs.end(), | ||||
| [&](kernel::LiteKernel *kernel) { return is_kernel_finish[kernel]; }); | |||||
| [&](kernel::LiteKernel *kernel) { return (*is_kernel_finish)[kernel]; }); | |||||
| } | } | ||||
| }); | }); | ||||
| if (head_kernel_iter == old_kernels.end()) { | |||||
| if (head_kernel_iter == src_kernel.end()) { | |||||
| break; | break; | ||||
| } | } | ||||
| auto head_kernel = *head_kernel_iter; | auto head_kernel = *head_kernel_iter; | ||||
| if (head_kernel->subgraph_type() != kernel::kNotSubGraph) { | if (head_kernel->subgraph_type() != kernel::kNotSubGraph) { | ||||
| is_kernel_finish[head_kernel] = true; | |||||
| kernels->emplace_back(head_kernel); | |||||
| (*is_kernel_finish)[head_kernel] = true; | |||||
| dst_kernel->push_back(head_kernel); | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (head_kernel->desc().arch == mindspore::kernel::kAPU) { | if (head_kernel->desc().arch == mindspore::kernel::kAPU) { | ||||
| @@ -638,15 +655,15 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) { | |||||
| return RET_NOT_SUPPORT; | return RET_NOT_SUPPORT; | ||||
| } | } | ||||
| auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel); | auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel); | ||||
| auto sub_kernels = FindAllSubGraphKernels(head_kernel, &is_kernel_finish); | |||||
| auto sub_kernels = FindAllSubGraphKernels(head_kernel, is_kernel_finish); | |||||
| auto subgraph = CreateSubGraphKernel(sub_kernels, nullptr, nullptr, cur_sub_graph_type); | auto subgraph = CreateSubGraphKernel(sub_kernels, nullptr, nullptr, cur_sub_graph_type); | ||||
| if (subgraph == nullptr) { | if (subgraph == nullptr) { | ||||
| MS_LOG(ERROR) << "Create SubGraphKernel failed"; | MS_LOG(ERROR) << "Create SubGraphKernel failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| kernels->emplace_back(subgraph); | |||||
| dst_kernel->emplace_back(subgraph); | |||||
| } | } | ||||
| for (auto *subgraph : *kernels) { | |||||
| for (auto *subgraph : *dst_kernel) { | |||||
| auto ret = subgraph->Init(); | auto ret = subgraph->Init(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Init SubGraph failed: " << ret; | MS_LOG(ERROR) << "Init SubGraph failed: " << ret; | ||||
| @@ -840,6 +857,7 @@ int Scheduler::RunPass(std::vector<kernel::LiteKernel *> *dst_kernels) { | |||||
| npu_pass_manager_->AddPass(fusion_pass); | npu_pass_manager_->AddPass(fusion_pass); | ||||
| ret = npu_pass_manager_->Run(); | ret = npu_pass_manager_->Run(); | ||||
| npu_pass_manager_->Clear(); | |||||
| #endif | #endif | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -74,7 +74,8 @@ class Scheduler { | |||||
| static void FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels); | static void FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels); | ||||
| // vector<LiteKernel/SubGraphKernel> --> vector<SubGraphKernel> | // vector<LiteKernel/SubGraphKernel> --> vector<SubGraphKernel> | ||||
| int ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels); | |||||
| int ConstructSubGraphs(std::vector<kernel::LiteKernel *> src_kernel, std::vector<kernel::LiteKernel *> *dst_kernel, | |||||
| std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map); | |||||
| // create subgraph_kernel from a vector of kernel | // create subgraph_kernel from a vector of kernel | ||||
| kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels, | kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels, | ||||
| @@ -128,6 +128,7 @@ class CpuSubGraph : public SubGraphKernel { | |||||
| std::vector<LiteKernel *> nodes, const lite::InnerContext *ctx) | std::vector<LiteKernel *> nodes, const lite::InnerContext *ctx) | ||||
| : SubGraphKernel(inputs, outputs, std::move(in_kernels), std::move(out_kernels), std::move(nodes), ctx) { | : SubGraphKernel(inputs, outputs, std::move(in_kernels), std::move(out_kernels), std::move(nodes), ctx) { | ||||
| subgraph_type_ = kCpuFP32SubGraph; | subgraph_type_ = kCpuFP32SubGraph; | ||||
| desc_.arch = kernel::KERNEL_ARCH::kCPU; | |||||
| } | } | ||||
| ~CpuSubGraph() override { delete this->executor_; } | ~CpuSubGraph() override { delete this->executor_; } | ||||
| @@ -0,0 +1,269 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/sub_graph_split.h" | |||||
| #include <vector> | |||||
| #include <utility> | |||||
| #include "src/tensor.h" | |||||
| #include "schema/inner/ops_generated.h" | |||||
| #include "schema/inner/model_generated.h" | |||||
| namespace mindspore::lite { | |||||
| #ifdef SUBGRAPH_SPLIT | |||||
| const schema::Primitive *SearchSubGraph::CreatePartialPrimitive(int64_t subgraph_index) { | |||||
| flatbuffers::FlatBufferBuilder fbb(1024); | |||||
| auto val_offset = schema::CreatePartialFusion(fbb, subgraph_index); | |||||
| auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_PartialFusion, val_offset.o); | |||||
| fbb.Finish(prim_offset); | |||||
| auto tmp_buf = fbb.GetBufferPointer(); | |||||
| auto prim_buf = reinterpret_cast<char *>(malloc(fbb.GetSize())); | |||||
| memcpy(prim_buf, tmp_buf, fbb.GetSize()); | |||||
| auto primitive = flatbuffers::GetRoot<schema::Primitive>(prim_buf); | |||||
| fbb.Clear(); | |||||
| model_->node_bufs_.push_back(prim_buf); | |||||
| return std::move(primitive); | |||||
| } | |||||
| void SearchSubGraph::ConvertSubGraphToModel() { | |||||
| Model::SubGraph *main_graphs = model_->sub_graphs_.front(); | |||||
| for (Subgraph &subgraph : sub_graphs_) { | |||||
| if (subgraph.nodes_.empty()) { | |||||
| continue; | |||||
| } | |||||
| mindspore::kernel::KERNEL_ARCH device = subgraph.device_; | |||||
| int new_sub_index = model_->sub_graphs_.size(); | |||||
| int partial_index = model_->all_nodes_.size(); | |||||
| Model::SubGraph *new_sub_graph = new (std::nothrow) Model::SubGraph(); | |||||
| if (new_sub_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "New sub graph failed!"; | |||||
| return; | |||||
| } | |||||
| new_sub_graph->name_ = "Subgraph-split-" + std::to_string(new_sub_index); | |||||
| Model::Node *new_partial_node = new (std::nothrow) Model::Node(); | |||||
| if (new_partial_node == nullptr) { | |||||
| MS_LOG(ERROR) << "New partial node failed!"; | |||||
| return; | |||||
| } | |||||
| new_partial_node->name_ = "Partial-subgraph-split-" + std::to_string(new_sub_index); | |||||
| new_partial_node->node_type_ = mindspore::lite::NodeType_ValueNode; | |||||
| new_partial_node->primitive_ = CreatePartialPrimitive(new_sub_index); | |||||
| while (!subgraph.nodes_.empty()) { | |||||
| uint32_t node_index = subgraph.nodes_.front(); | |||||
| new_sub_graph->node_indices_.push_back(node_index); | |||||
| VectorErase(&main_graphs->node_indices_, node_index); | |||||
| VectorErase(&subgraph.nodes_, node_index); | |||||
| model_->all_nodes_[node_index]->device_type_ = device; | |||||
| } | |||||
| for (uint32_t head_index : subgraph.heads_) { | |||||
| Model::Node *head_node = model_->all_nodes_[head_index]; | |||||
| std::vector<uint32_t> inputs = head_node->input_indices_; | |||||
| for (auto input : inputs) { | |||||
| if (tensors_[input].type_ == CONST) { | |||||
| continue; | |||||
| } | |||||
| if (std::find(new_sub_graph->input_indices_.begin(), new_sub_graph->input_indices_.end(), input) != | |||||
| new_sub_graph->input_indices_.end()) { | |||||
| continue; | |||||
| } | |||||
| new_sub_graph->input_indices_.insert(new_sub_graph->input_indices_.end(), input); | |||||
| new_partial_node->input_indices_.insert(new_partial_node->input_indices_.end(), input); | |||||
| } | |||||
| } | |||||
| for (uint32_t end_index : subgraph.ends_) { | |||||
| Model::Node *end_node = model_->all_nodes_[end_index]; | |||||
| std::vector<uint32_t> outputs = end_node->output_indices_; | |||||
| new_sub_graph->output_indices_.insert(new_sub_graph->output_indices_.end(), outputs.begin(), outputs.end()); | |||||
| new_partial_node->output_indices_.insert(new_partial_node->output_indices_.end(), outputs.begin(), outputs.end()); | |||||
| } | |||||
| main_graphs->node_indices_.push_back(partial_index); | |||||
| model_->all_nodes_.push_back(std::move(new_partial_node)); | |||||
| model_->sub_graphs_.push_back(std::move(new_sub_graph)); | |||||
| } | |||||
| return; | |||||
| } | |||||
| bool SearchSubGraph::IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes) { | |||||
| std::vector<uint32_t> output_indexes = node_list_[node_index]->output_indices_; | |||||
| std::vector<uint32_t> output_nodes; | |||||
| for (uint32_t out_t : output_indexes) { | |||||
| std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_; | |||||
| output_nodes.insert(output_nodes.end(), cur_nodes.begin(), cur_nodes.end()); | |||||
| } | |||||
| for (uint32_t out_n : output_nodes) { | |||||
| if (find(ready_nodes.begin(), ready_nodes.end(), out_n) == ready_nodes.end()) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void SearchSubGraph::InsertNode(uint32_t index, Subgraph *subgraph) { | |||||
| if (subgraph->search_terminate_) { | |||||
| return; | |||||
| } | |||||
| Model::Node *node = node_list_[index]; | |||||
| if (node == nullptr) { | |||||
| return; | |||||
| } | |||||
| std::vector<uint32_t> input = node->input_indices_; | |||||
| /* remove const node */ | |||||
| for (int i = input.size() - 1; i >= 0; i--) { | |||||
| if (tensors_[input[i]].type_ == CONST) { | |||||
| input.erase(input.begin() + i); | |||||
| } | |||||
| } | |||||
| /* all node_input is graph_input */ | |||||
| for (size_t i = 0; i < input.size(); i++) { | |||||
| if (tensors_[input[i]].type_ != INPUT) { | |||||
| break; | |||||
| } | |||||
| subgraph->heads_.clear(); | |||||
| subgraph->ends_.clear(); | |||||
| subgraph->nodes_.clear(); | |||||
| subgraph->search_terminate_ = true; | |||||
| return; | |||||
| } | |||||
| /* split in graph */ | |||||
| if (IsNodeSubGraphHead(index, subgraph->nodes_)) { | |||||
| if (subgraph->nodes_.empty()) { | |||||
| subgraph->search_terminate_ = true; | |||||
| return; | |||||
| } | |||||
| subgraph->heads_.push_back(subgraph->nodes_.front()); | |||||
| return; | |||||
| } | |||||
| if (find(output_nodes_.begin(), output_nodes_.end(), index) != output_nodes_.end()) { | |||||
| subgraph->ends_.push_back(index); | |||||
| } | |||||
| /* node insert in current subgraph */ | |||||
| subgraph->nodes_.insert(subgraph->nodes_.begin(), index); | |||||
| node_list_[index] = nullptr; | |||||
| /* search for next node */ | |||||
| for (uint32_t in : input) { | |||||
| auto next_nodes = tensors_[in].out_nodes_; | |||||
| for (uint32_t next_node : next_nodes) { | |||||
| InsertNode(next_node, subgraph); | |||||
| } | |||||
| } | |||||
| return; | |||||
| } | |||||
| void SearchSubGraph::InitSearchSubGraph() { | |||||
| for (uint32_t out : output_nodes_) { | |||||
| Subgraph subgraph; | |||||
| InsertNode(out, &subgraph); | |||||
| sub_graphs_.push_back(std::move(subgraph)); | |||||
| } | |||||
| return; | |||||
| } | |||||
| void SearchSubGraph::InitSearchTensor() { | |||||
| tensors_.resize(model_->all_tensors_.size()); | |||||
| /* Set Tensor Type */ | |||||
| for (size_t i = 0; i < tensors_.size(); i++) { | |||||
| tensors_[i].type_ = NORMAL; | |||||
| mindspore::schema::Tensor *src_tensor = model_->all_tensors_[i]; | |||||
| auto category = TensorCategory(src_tensor); | |||||
| if (category == mindspore::lite::Tensor::Category::CONST_TENSOR || | |||||
| category == mindspore::lite::Tensor::Category::CONST_SCALAR) { | |||||
| tensors_[i].type_ = CONST; | |||||
| } | |||||
| } | |||||
| std::vector<uint32_t> graph_input = model_->sub_graphs_[0]->input_indices_; | |||||
| for (auto in : graph_input) { | |||||
| tensors_[in].type_ = INPUT; | |||||
| } | |||||
| /* Set Tensor In and out Node */ | |||||
| for (size_t index = 0; index < model_->all_nodes_.size(); index++) { | |||||
| Model::Node *node = model_->all_nodes_[index]; | |||||
| std::vector<uint32_t> input = node->input_indices_; | |||||
| for (uint32_t in : input) { | |||||
| tensors_[in].in_nodes_.push_back(index); | |||||
| } | |||||
| std::vector<uint32_t> output = node->output_indices_; | |||||
| for (uint32_t out : output) { | |||||
| tensors_[out].out_nodes_.push_back(index); | |||||
| } | |||||
| } | |||||
| return; | |||||
| } | |||||
| void SearchSubGraph::InitSubgraphDevice() { | |||||
| sub_graphs_[0].device_ = kernel::KERNEL_ARCH::kCPU; | |||||
| sub_graphs_[1].device_ = kernel::KERNEL_ARCH::kALL; | |||||
| } | |||||
| void SearchSubGraph::InitMainGraphDevice() { | |||||
| kernel::KERNEL_ARCH main_device = kernel::KERNEL_ARCH::kALL; | |||||
| Model::SubGraph *main_graph = model_->sub_graphs_.front(); | |||||
| for (uint32_t node_index : main_graph->node_indices_) { | |||||
| Model::Node *node = model_->all_nodes_[node_index]; | |||||
| node->device_type_ = main_device; | |||||
| } | |||||
| } | |||||
| void SearchSubGraph::SubgraphFusion() { | |||||
| Subgraph new_npu_sub; | |||||
| Subgraph &npu_sub1 = sub_graphs_[1]; | |||||
| Subgraph &npu_sub2 = sub_graphs_[2]; | |||||
| new_npu_sub.nodes_.insert(new_npu_sub.nodes_.end(), npu_sub1.nodes_.begin(), npu_sub1.nodes_.end()); | |||||
| new_npu_sub.nodes_.insert(new_npu_sub.nodes_.end(), npu_sub2.nodes_.begin(), npu_sub2.nodes_.end()); | |||||
| new_npu_sub.heads_.insert(new_npu_sub.heads_.end(), npu_sub1.heads_.begin(), npu_sub1.heads_.end()); | |||||
| new_npu_sub.heads_.insert(new_npu_sub.heads_.end(), npu_sub2.heads_.begin(), npu_sub2.heads_.end()); | |||||
| new_npu_sub.ends_.insert(new_npu_sub.ends_.end(), npu_sub1.ends_.begin(), npu_sub1.ends_.end()); | |||||
| new_npu_sub.ends_.insert(new_npu_sub.ends_.end(), npu_sub2.ends_.begin(), npu_sub2.ends_.end()); | |||||
| sub_graphs_.erase(sub_graphs_.begin() + 2); | |||||
| sub_graphs_.erase(sub_graphs_.begin() + 1); | |||||
| sub_graphs_.insert(sub_graphs_.end(), std::move(new_npu_sub)); | |||||
| return; | |||||
| } | |||||
| void SearchSubGraph::SubGraphSplitByOutput() { | |||||
| InitSearchTensor(); | |||||
| InitSearchSubGraph(); | |||||
| SubgraphFusion(); | |||||
| InitSubgraphDevice(); | |||||
| ConvertSubGraphToModel(); | |||||
| InitMainGraphDevice(); | |||||
| } | |||||
| #endif | |||||
| } // namespace mindspore::lite | |||||
| @@ -0,0 +1,78 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_SUB_GRAPH_SPLIT_H_ | |||||
| #define MINDSPORE_LITE_SRC_SUB_GRAPH_SPLIT_H_ | |||||
| #include <stack> | |||||
| #include <vector> | |||||
| #include "include/model.h" | |||||
| #include "src/lite_kernel.h" | |||||
| #include "src/lite_model.h" | |||||
| namespace mindspore::lite { | |||||
| #ifdef SUBGRAPH_SPLIT | |||||
| class SearchSubGraph { | |||||
| enum TensorType { NORMAL, CONST, INPUT }; | |||||
| struct Tensor { | |||||
| std::vector<uint32_t> in_nodes_; /* used current tensor as input */ | |||||
| std::vector<uint32_t> out_nodes_; | |||||
| TensorType type_; | |||||
| }; | |||||
| struct Subgraph { | |||||
| std::vector<uint32_t> nodes_; | |||||
| std::vector<uint32_t> heads_; | |||||
| std::vector<uint32_t> ends_; | |||||
| bool search_terminate_ = false; | |||||
| mindspore::kernel::KERNEL_ARCH device_; | |||||
| }; | |||||
| public: | |||||
| SearchSubGraph(Model *model, std::vector<size_t> output_nodes) { | |||||
| output_nodes_.insert(output_nodes_.end(), output_nodes.begin(), output_nodes.end()); | |||||
| node_list_ = model->all_nodes_; | |||||
| model_ = reinterpret_cast<LiteModel *>(model); | |||||
| } | |||||
| ~SearchSubGraph() = default; | |||||
| public: | |||||
| void SubGraphSplitByOutput(); | |||||
| private: | |||||
| void InitSearchTensor(); | |||||
| void InitSearchSubGraph(); | |||||
| void ConvertSubGraphToModel(); | |||||
| void InsertNode(uint32_t index, Subgraph *subgraph); | |||||
| bool IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes); | |||||
| const schema::Primitive *CreatePartialPrimitive(int64_t subgraph_index); | |||||
| void InitSubgraphDevice(); | |||||
| void SubgraphFusion(); | |||||
| void InitMainGraphDevice(); | |||||
| private: | |||||
| LiteModel *model_ = nullptr; | |||||
| std::vector<Tensor> tensors_; | |||||
| std::vector<Subgraph> sub_graphs_; | |||||
| std::vector<size_t> output_nodes_; | |||||
| std::vector<Model::Node *> node_list_; | |||||
| }; | |||||
| #endif | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_SRC_SUB_GRAPH_SPLIT_H_ | |||||
| @@ -352,8 +352,8 @@ void Tensor::DecRefCount() { | |||||
| if (this->IsConst() || this->IsGraphInput()) { | if (this->IsConst() || this->IsGraphInput()) { | ||||
| return; | return; | ||||
| } | } | ||||
| this->ref_count_--; | |||||
| if (this->ref_count_ <= 0) { | |||||
| bool free_data = --ref_count_ <= 0; | |||||
| if (free_data) { | |||||
| FreeData(); | FreeData(); | ||||
| this->ref_count_ = 0; | this->ref_count_ = 0; | ||||
| } | } | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <numeric> | #include <numeric> | ||||
| #include <functional> | #include <functional> | ||||
| #include <atomic> | |||||
| #include "include/ms_tensor.h" | #include "include/ms_tensor.h" | ||||
| #include "src/runtime/allocator.h" | #include "src/runtime/allocator.h" | ||||
| @@ -205,7 +206,7 @@ class Tensor : public mindspore::tensor::MSTensor { | |||||
| std::vector<int> shape_; | std::vector<int> shape_; | ||||
| schema::Format format_; | schema::Format format_; | ||||
| Category category_; | Category category_; | ||||
| size_t ref_count_ = 0; | |||||
| std::atomic_int ref_count_ = 0; | |||||
| size_t init_ref_count_ = 0; | size_t init_ref_count_ = 0; | ||||
| std::vector<QuantArg> quant_params_; | std::vector<QuantArg> quant_params_; | ||||
| std::vector<float> quant_clusters_; | std::vector<float> quant_clusters_; | ||||
| @@ -144,6 +144,7 @@ set(TEST_LITE_SRC | |||||
| ${LITE_DIR}/src/dequant.cc | ${LITE_DIR}/src/dequant.cc | ||||
| ${LITE_DIR}/src/huffman_decode.cc | ${LITE_DIR}/src/huffman_decode.cc | ||||
| ${LITE_DIR}/src/sub_graph_kernel.cc | ${LITE_DIR}/src/sub_graph_kernel.cc | ||||
| ${LITE_DIR}/src/sub_graph_split.cc | |||||
| ${LITE_DIR}/src/lite_model.cc | ${LITE_DIR}/src/lite_model.cc | ||||
| ${LITE_DIR}/src/scheduler.cc | ${LITE_DIR}/src/scheduler.cc | ||||
| ${LITE_DIR}/src/common/graph_util.cc | ${LITE_DIR}/src/common/graph_util.cc | ||||
| @@ -109,6 +109,7 @@ set(LITE_SRC | |||||
| ${SRC_DIR}/lite_kernel.cc | ${SRC_DIR}/lite_kernel.cc | ||||
| ${SRC_DIR}/scheduler.cc | ${SRC_DIR}/scheduler.cc | ||||
| ${SRC_DIR}/sub_graph_kernel.cc | ${SRC_DIR}/sub_graph_kernel.cc | ||||
| ${SRC_DIR}/sub_graph_split.cc | |||||
| ${SRC_DIR}/lite_session.cc | ${SRC_DIR}/lite_session.cc | ||||
| ${SRC_DIR}/executor.cc | ${SRC_DIR}/executor.cc | ||||
| ${SRC_DIR}/lite_model.cc | ${SRC_DIR}/lite_model.cc | ||||