From: @ling_qiao_min Reviewed-by: Signed-off-by: pull/14542/MERGE
@@ -28,6 +28,7 @@ option(ENABLE_VERBOSE "" off)
option(ENABLE_SSE "if x86_64 support SSE instruction set" off)
option(ENABLE_AVX "if x86_64 support AVX instruction set" off)
option(ENABLE_MINDRT "if support mindrt" on)
option(SUBGRAPH_SPLIT "if support sub graph split" off)
set(DIR_PREFIX mindspore-lite)
set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})
@@ -57,6 +58,9 @@ else()
    set(PROCESS_UNIT cpu)
endif()
if(SUBGRAPH_SPLIT)
    add_compile_definitions(SUBGRAPH_SPLIT)
endif()
if(SUPPORT_NPU)
    set(DDK_PATH "$ENV{HWHIAI_DDK}/ddk/ai_ddk_lib")
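
The new SUBGRAPH_SPLIT option is forwarded with add_compile_definitions, so the whole split pass is a compile-time feature (configure with -DSUBGRAPH_SPLIT=on). A minimal sketch of the guard pattern the new sources rely on, with an illustrative function name:

// Present only in builds configured with -DSUBGRAPH_SPLIT=on.
#ifdef SUBGRAPH_SPLIT
void RunSubGraphSplitPass() { /* split the graph by device affinity */ }
#else
void RunSubGraphSplitPass() { /* no-op in default builds */ }
#endif
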
@@ -132,6 +132,7 @@ set(LITE_SRC
    ${LITE_DIR}/src/common/tensor_util.cc
    ${LITE_DIR}/src/runtime/infer_manager.cc
    ${LITE_DIR}/src/lite_model.cc
    ${LITE_DIR}/src/sub_graph_split.cc
    ${LITE_DIR}/src/tensorlist.cc
    ${LITE_DIR}/src/tensor.cc
    ${LITE_DIR}/src/dequant.cc
@@ -59,6 +59,7 @@ set(LITE_SRC
    ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_split.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
@@ -43,6 +43,16 @@ void LiteKernel::FreeWorkspace() {
  free(workspace_);
  workspace_ = nullptr;
}
int LiteKernel::DecOutTensorRefCount() {
  for (auto *tensor : this->out_tensors_) {
    tensor->set_ref_count(tensor->ref_count() - 1);
    if (0 >= tensor->ref_count()) {
      tensor->FreeData();
    }
  }
  return 0;
}
#endif
bool LiteKernel::IsReady(const std::vector<lite::Tensor *> &scope_tensors) {
  return std::all_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *in_tensor) {
@@ -66,16 +76,6 @@ void LiteKernel::InitOutTensorInitRefCount() {
  }
}
int LiteKernel::DecOutTensorRefCount() {
  for (auto *tensor : this->out_tensors_) {
    tensor->set_ref_count(tensor->ref_count() - 1);
    if (0 >= tensor->ref_count()) {
      tensor->FreeData();
    }
  }
  return 0;
}
int LiteKernel::FreeInWorkTensor() const {
  for (auto &in_tensor : this->in_tensors_) {
    MS_ASSERT(in_tensor != nullptr);
@@ -35,7 +35,16 @@ static constexpr int kPerTensor = 1;
static constexpr size_t kPerBatch = 3;
namespace mindspore::kernel {
enum KERNEL_ARCH { kCPU, kGPU, kAPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU };
enum KERNEL_ARCH {
  kCPU,
  kGPU,
  kAPU,
  kNPU,
  kALL, /* Support GPU NPU CPU */
  kKernelArch_MIN = kCPU,
  kKernelArch_MAX = kALL
};
struct KernelKey {
  KERNEL_ARCH arch;
  TypeId data_type;
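
kALL marks kernels that may be placed on CPU, GPU, or NPU, and the kKernelArch_MAX sentinel moves with it, so range checks written against the sentinels remain correct. An illustrative check, not part of the patch:

// Still valid after the enum grows, because the sentinels moved together.
inline bool IsValidArch(KERNEL_ARCH arch) {
  return arch >= kKernelArch_MIN && arch <= kKernelArch_MAX;  // now includes kALL
}
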
@@ -161,8 +170,6 @@ class LiteKernel {
  virtual void InitOutTensorInitRefCount();
  int DecOutTensorRefCount();
  virtual int FreeInWorkTensor() const;
  KernelKey desc() const { return desc_; }
@@ -171,6 +178,8 @@ class LiteKernel {
  SubGraphType subgraph_type() const { return this->subgraph_type_; }
  const lite::InnerContext *context() const { return this->context_; }
  virtual std::string ToString() const;
#ifdef SUPPORT_TRAIN
@@ -179,6 +188,7 @@ class LiteKernel {
  static void AllocWorkspace(size_t size);
  static void FreeWorkspace();
  void *workspace() { return workspace_; }
  int DecOutTensorRefCount();
#endif

 protected:
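
Net effect of the lite_kernel hunks: DecOutTensorRefCount moves from the public section into the SUPPORT_TRAIN block in both the .cc and the header, making it a training-only helper; inference code paths switch to Tensor::DecRefCount (see the NPU executor and tensor.cc hunks further down). A call-site fragment showing what the move implies (out_tensors() accessor assumed):

#ifdef SUPPORT_TRAIN
  kernel->DecOutTensorRefCount();  // compiled only for training builds
#else
  for (auto *tensor : kernel->out_tensors()) {
    tensor->DecRefCount();  // inference builds use the tensor-side API
  }
#endif
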
@@ -32,7 +32,7 @@ int LiteOpActor::CompileArrow() {
        }
      }
      if (to_input_index == -1) {
        break;
        continue;
      }
      auto id = out->name() + this->GetAID().Url();
      auto arrow = std::make_shared<OpArrow>(i, id, to_input_index);
@@ -41,12 +41,19 @@ int LiteOpActor::CompileArrow() {
        return RET_ERROR;
      }
      output_op_arrows_.emplace_back(std::move(arrow));
      break;
    }
  }
  return RET_OK;
}
void LiteOpActor::AsyncOutput(OpContext<Tensor> *context) {
  for (auto op_arrow : output_op_arrows_) {
    auto data = context->outputData_->at(op_arrow->from_output_index_);
    Async(op_arrow->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data, context);
  }
  return;
}
void LiteOpActor::SetOutputData(OpContext<Tensor> *context) {
  auto size = context->outputData_->size();
  MS_ASSERT(size == context->results_->size());
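
The CompileArrow hunks fix an early-exit bug: the scan used to break on the first candidate kernel that did not consume the output, so later consumers never got an arrow; it now continues past non-consumers and breaks only once the arrow is created. A standalone model of the corrected scan (data layout invented for illustration):

#include <cstddef>
#include <vector>

// Returns the first consumer whose input list contains out_id, scanning past
// non-consumers (the old code aborted the scan at the first miss).
int FirstConsumer(const std::vector<std::vector<int>> &consumer_inputs, int out_id) {
  for (std::size_t k = 0; k < consumer_inputs.size(); ++k) {
    int to_input_index = -1;
    for (std::size_t j = 0; j < consumer_inputs[k].size(); ++j) {
      if (consumer_inputs[k][j] == out_id) to_input_index = static_cast<int>(j);
    }
    if (to_input_index == -1) {
      continue;  // was `break`: skip this kernel instead of giving up
    }
    return static_cast<int>(k);  // arrow target found; stop scanning
  }
  return -1;
}
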
@@ -50,6 +50,7 @@ class LiteOpActor : public OpActor<lite::Tensor> {
      return;
    }
    input_op_datas_.erase(op_uuid);
    AsyncOutput(context);
    SetOutputData(context);
  }
  void Init() {
@@ -83,6 +84,7 @@ class LiteOpActor : public OpActor<lite::Tensor> {
 private:
  void SetOutputData(OpContext<Tensor> *context);
  void AsyncOutput(OpContext<Tensor> *context);
  kernel::LiteKernel *kernel_;
};
@@ -108,6 +108,12 @@ void LiteModel::Free() {
    tensor_buf = nullptr;
  }
  attr_tensor_bufs_.resize(0);
  for (auto &node_buf : node_bufs_) {
    free(node_buf);
    node_buf = nullptr;
  }
  node_bufs_.resize(0);
}
void LiteModel::Destroy() {
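
This completes an ownership loop introduced by the split pass: CreatePartialPrimitive (in the new sub_graph_split.cc below) mallocs a flatbuffer copy for each synthesized partial node and parks it in node_bufs_, and LiteModel::Free() is now the single release point. A standalone sketch of that convention, with illustrative names:

#include <cstdlib>
#include <cstring>
#include <vector>

// Hedged model of the buffer-ownership convention: buffers adopted into
// node_bufs_ are owned by the model and freed exactly once in Free().
struct OwnedBufs {
  std::vector<char *> node_bufs_;
  char *Adopt(const void *src, std::size_t size) {
    char *buf = static_cast<char *>(std::malloc(size));
    if (buf != nullptr) {
      std::memcpy(buf, src, size);
      node_bufs_.push_back(buf);  // ownership passes to the model
    }
    return buf;
  }
  void Free() {
    for (auto &buf : node_bufs_) {
      std::free(buf);
      buf = nullptr;
    }
    node_bufs_.resize(0);
  }
};
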
@@ -192,6 +192,7 @@ class LiteModel : public Model {
 public:
  size_t buf_size_ = 0;
  std::vector<char *> node_bufs_;

 protected:
  std::vector<char *> attr_tensor_bufs_;
@@ -399,8 +399,15 @@ int LiteSession::CompileGraph(Model *model) {
#endif
  InitGraphInOutTensors(model);
  ret = PrepareKernels(model);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare kernels failed: " << ret;
    is_running_.store(false);
    return ret;
  }
#ifdef ENABLE_MINDRT
  if (context_->IsCpuEnabled() && !context_->IsGpuEnabled() && !context_->IsNpuEnabled() && kernels_.size() == 1) {
  if (kernels_.size() == 1) {
    executor_ = new (std::nothrow) MindrtExecutor();
  } else {
    executor_ = new (std::nothrow) Executor();
@@ -420,16 +427,10 @@ int LiteSession::CompileGraph(Model *model) {
    is_running_.store(false);
    return ret;
  }
  ret = PrepareKernels(model);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare kernels failed: " << ret;
    is_running_.store(false);
    return ret;
  }
  is_running_.store(false);
  return RET_OK;
}
}  // namespace lite
int LiteSession::PrepareKernels(Model *model) {
  std::vector<kernel::LiteKernel *> all_kernels;
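
Two things change in CompileGraph: PrepareKernels is hoisted ahead of executor selection, and the mindrt executor is now chosen for any graph that scheduled into a single kernel, not only CPU-only ones. An illustrative condensation of the new selection (assuming MindrtExecutor derives from Executor, as the assignment implies):

// After scheduling and PrepareKernels, pick the executor by kernel count alone.
executor_ = (kernels_.size() == 1) ? static_cast<Executor *>(new (std::nothrow) MindrtExecutor())
                                   : new (std::nothrow) Executor();
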
@@ -102,10 +102,7 @@ int NPUExecutor::Run(const std::vector<Tensor *> &in_tensors, const std::vector<
      memcpy(npu_input_tensors_[i]->GetBuffer(), data, in_tensors[index]->Size());
      inputs_visited[index] = true;
      in_tensors[index]->set_ref_count(in_tensors[index]->ref_count() - 1);
      if (in_tensors[index]->ref_count() <= 0) {
        in_tensors[index]->FreeData();
      }
      in_tensors[index]->DecRefCount();
      break;
    }
  }
@@ -38,6 +38,7 @@ class SubGraphNpuKernel : public SubGraphKernel {
                    const lite::InnerContext *ctx = nullptr, lite::NPUManager *npu_manager = nullptr)
      : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx), npu_manager_(npu_manager) {
    subgraph_type_ = kNpuSubGraph;
    desc_.arch = kernel::KERNEL_ARCH::kNPU;
  }
  ~SubGraphNpuKernel() override;
@@ -70,7 +70,7 @@ class DefaultAllocator : public Allocator {
  std::multimap<size_t, MemBuf *> freeList_;
  // 6 is empirical value
  int shiftFactor_ = 6;
  bool lockFlag_ = false;
  bool lockFlag_ = true;
};
constexpr int64_t MAX_MALLOC_SIZE = static_cast<size_t>(2000) * 1024 * 1024;
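
Flipping lockFlag_ to true makes DefaultAllocator take its lock by default, which matters once the mindrt actor executor can run several ops (and thus several allocations) concurrently. A minimal model of the guarded-allocation pattern this default enables (simplified, not the real allocator):

#include <cstddef>
#include <cstdlib>
#include <mutex>

// Simplified model: with lock_flag_ defaulting to true, concurrent Malloc
// calls from actor threads are serialized inside the allocator.
class LockedAllocator {
 public:
  void *Malloc(std::size_t size) {
    std::unique_lock<std::mutex> guard(lock_, std::defer_lock);
    if (lock_flag_) guard.lock();  // take the lock only when enabled
    return std::malloc(size);
  }

 private:
  std::mutex lock_;
  bool lock_flag_ = true;  // the new default
};
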
@@ -34,6 +34,7 @@ class OpenCLSubGraph : public SubGraphKernel {
      : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx) {
    ocl_runtime_ = ocl_runtime_wrap_.GetInstance();
    subgraph_type_ = kGpuSubGraph;
    desc_.arch = kernel::KERNEL_ARCH::kGPU;
    this->name_ = "GpuSubGraph";
    nodes_set_.insert(nodes.begin(), nodes.end());
    all_kernels_infer_done_ = std::all_of(nodes_.begin(), nodes_.end(), [](const kernel::LiteKernel *kernel) {
@@ -30,6 +30,7 @@
#include "src/common/version_manager.h"
#include "src/common/prim_util.h"
#include "src/runtime/infer_manager.h"
#include "src/sub_graph_split.h"
#include "src/dequant.h"
#include "nnacl/matmul_parameter.h"
#if GPU_OPENCL
@@ -71,6 +72,12 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
  }
  this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_);
#ifdef SUBGRAPH_SPLIT
  auto search_sub_graph = SearchSubGraph(src_model_, this->graph_output_node_indexes_);
  search_sub_graph.SubGraphSplitByOutput();
#endif
  bool infer_shape_interrupt = false;
  auto ret = InferSubGraphShape(kMainSubGraphIndex, &infer_shape_interrupt);
  if (ret != RET_OK) {
@@ -89,7 +96,11 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
    MS_LOG(ERROR) << "Schedule run pass failed.";
    return ret;
  }
  ret = ConstructSubGraphs(dst_kernels);
  auto src_kernel = *dst_kernels;
  dst_kernels->clear();
  std::map<const kernel::LiteKernel *, bool> is_kernel_finish;
  ret = ConstructSubGraphs(src_kernel, dst_kernels, &is_kernel_finish);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConstructSubGraphs failed.";
    return ret;
@@ -473,6 +484,14 @@ kernel::LiteKernel *Scheduler::SchedulePartialToKernel(const lite::Model::Node *
    MS_LOG(ERROR) << "Schedule partial failed, name: " << src_node->name_;
    return nullptr;
  }
  FindAllInoutKernels(sub_kernels);
  ret = RunPass(&sub_kernels);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "SchedulePartialToKernel run pass failed.";
    return nullptr;
  }
  auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(sub_kernels.front());
  auto subgraph = CreateSubGraphKernel(sub_kernels, &in_tensors, &out_tensors, cur_sub_graph_type);
  subgraph->set_name("subgraph_" + src_node->name_);
@@ -602,35 +621,33 @@ std::vector<kernel::LiteKernel *> Scheduler::FindAllSubGraphKernels(
  return sub_kernels;
}
int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) {
  auto old_kernels = *kernels;
  kernels->clear();
  std::map<const kernel::LiteKernel *, bool> is_kernel_finish;
  for (auto kernel : old_kernels) {
    is_kernel_finish[kernel] = false;
int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> src_kernel,
                                  std::vector<kernel::LiteKernel *> *dst_kernel,
                                  std::map<const kernel::LiteKernel *, bool> *is_kernel_finish) {
  for (auto kernel : src_kernel) {
    (*is_kernel_finish)[kernel] = false;
  }
  while (true) {
    auto head_kernel_iter = std::find_if(old_kernels.begin(), old_kernels.end(), [&](const kernel::LiteKernel *kernel) {
    auto head_kernel_iter = std::find_if(src_kernel.begin(), src_kernel.end(), [&](const kernel::LiteKernel *kernel) {
      auto kernel_inputs = kernel->in_kernels();
      if (is_kernel_finish[kernel]) {
      if ((*is_kernel_finish)[kernel]) {
        return false;
      }
      // when merge is removed, this if is removed automatically
      if (kernel->Type() == schema::PrimitiveType_Merge) {
        return MergeOpIsReady(kernel, is_kernel_finish);
        return MergeOpIsReady(kernel, (*is_kernel_finish));
      } else {
        return std::all_of(kernel_inputs.begin(), kernel_inputs.end(),
                           [&](kernel::LiteKernel *kernel) { return is_kernel_finish[kernel]; });
                           [&](kernel::LiteKernel *kernel) { return (*is_kernel_finish)[kernel]; });
      }
    });
    if (head_kernel_iter == old_kernels.end()) {
    if (head_kernel_iter == src_kernel.end()) {
      break;
    }
    auto head_kernel = *head_kernel_iter;
    if (head_kernel->subgraph_type() != kernel::kNotSubGraph) {
      is_kernel_finish[head_kernel] = true;
      kernels->emplace_back(head_kernel);
      (*is_kernel_finish)[head_kernel] = true;
      dst_kernel->push_back(head_kernel);
      continue;
    }
    if (head_kernel->desc().arch == mindspore::kernel::kAPU) {
@@ -638,15 +655,15 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) {
      return RET_NOT_SUPPORT;
    }
    auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel);
    auto sub_kernels = FindAllSubGraphKernels(head_kernel, &is_kernel_finish);
    auto sub_kernels = FindAllSubGraphKernels(head_kernel, is_kernel_finish);
    auto subgraph = CreateSubGraphKernel(sub_kernels, nullptr, nullptr, cur_sub_graph_type);
    if (subgraph == nullptr) {
      MS_LOG(ERROR) << "Create SubGraphKernel failed";
      return RET_ERROR;
    }
    kernels->emplace_back(subgraph);
    dst_kernel->emplace_back(subgraph);
  }
  for (auto *subgraph : *kernels) {
  for (auto *subgraph : *dst_kernel) {
    auto ret = subgraph->Init();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Init SubGraph failed: " << ret;
@@ -840,6 +857,7 @@ int Scheduler::RunPass(std::vector<kernel::LiteKernel *> *dst_kernels) {
    npu_pass_manager_->AddPass(fusion_pass);
    ret = npu_pass_manager_->Run();
    npu_pass_manager_->Clear();
#endif
  return ret;
}
@@ -74,7 +74,8 @@ class Scheduler {
  static void FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels);
  // vector<LiteKernel/SubGraphKernel> --> vector<SubGraphKernel>
  int ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels);
  int ConstructSubGraphs(std::vector<kernel::LiteKernel *> src_kernel, std::vector<kernel::LiteKernel *> *dst_kernel,
                         std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map);
  // create subgraph_kernel from a vector of kernel
  kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels,
@@ -128,6 +128,7 @@ class CpuSubGraph : public SubGraphKernel {
               std::vector<LiteKernel *> nodes, const lite::InnerContext *ctx)
      : SubGraphKernel(inputs, outputs, std::move(in_kernels), std::move(out_kernels), std::move(nodes), ctx) {
    subgraph_type_ = kCpuFP32SubGraph;
    desc_.arch = kernel::KERNEL_ARCH::kCPU;
  }
  ~CpuSubGraph() override { delete this->executor_; }
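
With the CPU, OpenCL, and NPU subgraph constructors all stamping their backend into desc_.arch, a subgraph's device is now queryable through the existing desc() accessor instead of switching on subgraph_type_. A hypothetical helper showing the intended use:

// Hypothetical helper (not in the patch): backend query via the stamped desc.
inline bool RunsOnNpu(const kernel::LiteKernel &kernel) {
  return kernel.desc().arch == kernel::KERNEL_ARCH::kNPU;
}
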
@@ -0,0 +1,269 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/sub_graph_split.h"
#include <vector>
#include <utility>
#include "src/tensor.h"
#include "schema/inner/ops_generated.h"
#include "schema/inner/model_generated.h"
namespace mindspore::lite {
#ifdef SUBGRAPH_SPLIT
const schema::Primitive *SearchSubGraph::CreatePartialPrimitive(int64_t subgraph_index) {
  flatbuffers::FlatBufferBuilder fbb(1024);
  auto val_offset = schema::CreatePartialFusion(fbb, subgraph_index);
  auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_PartialFusion, val_offset.o);
  fbb.Finish(prim_offset);
  auto tmp_buf = fbb.GetBufferPointer();
  auto prim_buf = reinterpret_cast<char *>(malloc(fbb.GetSize()));
  memcpy(prim_buf, tmp_buf, fbb.GetSize());
  auto primitive = flatbuffers::GetRoot<schema::Primitive>(prim_buf);
  fbb.Clear();
  model_->node_bufs_.push_back(prim_buf);
  return std::move(primitive);
}
void SearchSubGraph::ConvertSubGraphToModel() {
  Model::SubGraph *main_graphs = model_->sub_graphs_.front();
  for (Subgraph &subgraph : sub_graphs_) {
    if (subgraph.nodes_.empty()) {
      continue;
    }
    mindspore::kernel::KERNEL_ARCH device = subgraph.device_;
    int new_sub_index = model_->sub_graphs_.size();
    int partial_index = model_->all_nodes_.size();
    Model::SubGraph *new_sub_graph = new (std::nothrow) Model::SubGraph();
    if (new_sub_graph == nullptr) {
      MS_LOG(ERROR) << "New sub graph failed!";
      return;
    }
    new_sub_graph->name_ = "Subgraph-split-" + std::to_string(new_sub_index);
    Model::Node *new_partial_node = new (std::nothrow) Model::Node();
    if (new_partial_node == nullptr) {
      MS_LOG(ERROR) << "New partial node failed!";
      return;
    }
    new_partial_node->name_ = "Partial-subgraph-split-" + std::to_string(new_sub_index);
    new_partial_node->node_type_ = mindspore::lite::NodeType_ValueNode;
    new_partial_node->primitive_ = CreatePartialPrimitive(new_sub_index);
    while (!subgraph.nodes_.empty()) {
      uint32_t node_index = subgraph.nodes_.front();
      new_sub_graph->node_indices_.push_back(node_index);
      VectorErase(&main_graphs->node_indices_, node_index);
      VectorErase(&subgraph.nodes_, node_index);
      model_->all_nodes_[node_index]->device_type_ = device;
    }
    for (uint32_t head_index : subgraph.heads_) {
      Model::Node *head_node = model_->all_nodes_[head_index];
      std::vector<uint32_t> inputs = head_node->input_indices_;
      for (auto input : inputs) {
        if (tensors_[input].type_ == CONST) {
          continue;
        }
        if (std::find(new_sub_graph->input_indices_.begin(), new_sub_graph->input_indices_.end(), input) !=
            new_sub_graph->input_indices_.end()) {
          continue;
        }
        new_sub_graph->input_indices_.insert(new_sub_graph->input_indices_.end(), input);
        new_partial_node->input_indices_.insert(new_partial_node->input_indices_.end(), input);
      }
    }
    for (uint32_t end_index : subgraph.ends_) {
      Model::Node *end_node = model_->all_nodes_[end_index];
      std::vector<uint32_t> outputs = end_node->output_indices_;
      new_sub_graph->output_indices_.insert(new_sub_graph->output_indices_.end(), outputs.begin(), outputs.end());
      new_partial_node->output_indices_.insert(new_partial_node->output_indices_.end(), outputs.begin(), outputs.end());
    }
    main_graphs->node_indices_.push_back(partial_index);
    model_->all_nodes_.push_back(std::move(new_partial_node));
    model_->sub_graphs_.push_back(std::move(new_sub_graph));
  }
  return;
}
bool SearchSubGraph::IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes) {
  std::vector<uint32_t> output_indexes = node_list_[node_index]->output_indices_;
  std::vector<uint32_t> output_nodes;
  for (uint32_t out_t : output_indexes) {
    std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_;
    output_nodes.insert(output_nodes.end(), cur_nodes.begin(), cur_nodes.end());
  }
  for (uint32_t out_n : output_nodes) {
    if (find(ready_nodes.begin(), ready_nodes.end(), out_n) == ready_nodes.end()) {
      return true;
    }
  }
  return false;
}
void SearchSubGraph::InsertNode(uint32_t index, Subgraph *subgraph) {
  if (subgraph->search_terminate_) {
    return;
  }
  Model::Node *node = node_list_[index];
  if (node == nullptr) {
    return;
  }
  std::vector<uint32_t> input = node->input_indices_;
  /* remove const node */
  for (int i = input.size() - 1; i >= 0; i--) {
    if (tensors_[input[i]].type_ == CONST) {
      input.erase(input.begin() + i);
    }
  }
  /* all node_input is graph_input */
  for (size_t i = 0; i < input.size(); i++) {
    if (tensors_[input[i]].type_ != INPUT) {
      break;
    }
    subgraph->heads_.clear();
    subgraph->ends_.clear();
    subgraph->nodes_.clear();
    subgraph->search_terminate_ = true;
    return;
  }
  /* split in graph */
  if (IsNodeSubGraphHead(index, subgraph->nodes_)) {
    if (subgraph->nodes_.empty()) {
      subgraph->search_terminate_ = true;
      return;
    }
    subgraph->heads_.push_back(subgraph->nodes_.front());
    return;
  }
  if (find(output_nodes_.begin(), output_nodes_.end(), index) != output_nodes_.end()) {
    subgraph->ends_.push_back(index);
  }
  /* node insert in current subgraph */
  subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
  node_list_[index] = nullptr;
  /* search for next node */
  for (uint32_t in : input) {
    auto next_nodes = tensors_[in].out_nodes_;
    for (uint32_t next_node : next_nodes) {
      InsertNode(next_node, subgraph);
    }
  }
  return;
}
void SearchSubGraph::InitSearchSubGraph() {
  for (uint32_t out : output_nodes_) {
    Subgraph subgraph;
    InsertNode(out, &subgraph);
    sub_graphs_.push_back(std::move(subgraph));
  }
  return;
}
void SearchSubGraph::InitSearchTensor() {
  tensors_.resize(model_->all_tensors_.size());
  /* Set Tensor Type */
  for (size_t i = 0; i < tensors_.size(); i++) {
    tensors_[i].type_ = NORMAL;
    mindspore::schema::Tensor *src_tensor = model_->all_tensors_[i];
    auto category = TensorCategory(src_tensor);
    if (category == mindspore::lite::Tensor::Category::CONST_TENSOR ||
        category == mindspore::lite::Tensor::Category::CONST_SCALAR) {
      tensors_[i].type_ = CONST;
    }
  }
  std::vector<uint32_t> graph_input = model_->sub_graphs_[0]->input_indices_;
  for (auto in : graph_input) {
    tensors_[in].type_ = INPUT;
  }
  /* Set Tensor In and out Node */
  for (size_t index = 0; index < model_->all_nodes_.size(); index++) {
    Model::Node *node = model_->all_nodes_[index];
    std::vector<uint32_t> input = node->input_indices_;
    for (uint32_t in : input) {
      tensors_[in].in_nodes_.push_back(index);
    }
    std::vector<uint32_t> output = node->output_indices_;
    for (uint32_t out : output) {
      tensors_[out].out_nodes_.push_back(index);
    }
  }
  return;
}
void SearchSubGraph::InitSubgraphDevice() {
  sub_graphs_[0].device_ = kernel::KERNEL_ARCH::kCPU;
  sub_graphs_[1].device_ = kernel::KERNEL_ARCH::kALL;
}
void SearchSubGraph::InitMainGraphDevice() {
  kernel::KERNEL_ARCH main_device = kernel::KERNEL_ARCH::kALL;
  Model::SubGraph *main_graph = model_->sub_graphs_.front();
  for (uint32_t node_index : main_graph->node_indices_) {
    Model::Node *node = model_->all_nodes_[node_index];
    node->device_type_ = main_device;
  }
}
void SearchSubGraph::SubgraphFusion() {
  Subgraph new_npu_sub;
  Subgraph &npu_sub1 = sub_graphs_[1];
  Subgraph &npu_sub2 = sub_graphs_[2];
  new_npu_sub.nodes_.insert(new_npu_sub.nodes_.end(), npu_sub1.nodes_.begin(), npu_sub1.nodes_.end());
  new_npu_sub.nodes_.insert(new_npu_sub.nodes_.end(), npu_sub2.nodes_.begin(), npu_sub2.nodes_.end());
  new_npu_sub.heads_.insert(new_npu_sub.heads_.end(), npu_sub1.heads_.begin(), npu_sub1.heads_.end());
  new_npu_sub.heads_.insert(new_npu_sub.heads_.end(), npu_sub2.heads_.begin(), npu_sub2.heads_.end());
  new_npu_sub.ends_.insert(new_npu_sub.ends_.end(), npu_sub1.ends_.begin(), npu_sub1.ends_.end());
  new_npu_sub.ends_.insert(new_npu_sub.ends_.end(), npu_sub2.ends_.begin(), npu_sub2.ends_.end());
  sub_graphs_.erase(sub_graphs_.begin() + 2);
  sub_graphs_.erase(sub_graphs_.begin() + 1);
  sub_graphs_.insert(sub_graphs_.end(), std::move(new_npu_sub));
  return;
}
void SearchSubGraph::SubGraphSplitByOutput() {
  InitSearchTensor();
  InitSearchSubGraph();
  SubgraphFusion();
  InitSubgraphDevice();
  ConvertSubGraphToModel();
  InitMainGraphDevice();
}
#endif
}  // namespace mindspore::lite
@@ -0,0 +1,78 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_SUB_GRAPH_SPLIT_H_
#define MINDSPORE_LITE_SRC_SUB_GRAPH_SPLIT_H_
#include <stack>
#include <vector>
#include "include/model.h"
#include "src/lite_kernel.h"
#include "src/lite_model.h"
namespace mindspore::lite {
#ifdef SUBGRAPH_SPLIT
class SearchSubGraph {
  enum TensorType { NORMAL, CONST, INPUT };
  struct Tensor {
    std::vector<uint32_t> in_nodes_; /* nodes that use this tensor as an input */
    std::vector<uint32_t> out_nodes_;
    TensorType type_;
  };
  struct Subgraph {
    std::vector<uint32_t> nodes_;
    std::vector<uint32_t> heads_;
    std::vector<uint32_t> ends_;
    bool search_terminate_ = false;
    mindspore::kernel::KERNEL_ARCH device_;
  };

 public:
  SearchSubGraph(Model *model, std::vector<size_t> output_nodes) {
    output_nodes_.insert(output_nodes_.end(), output_nodes.begin(), output_nodes.end());
    node_list_ = model->all_nodes_;
    model_ = reinterpret_cast<LiteModel *>(model);
  }
  ~SearchSubGraph() = default;

 public:
  void SubGraphSplitByOutput();

 private:
  void InitSearchTensor();
  void InitSearchSubGraph();
  void ConvertSubGraphToModel();
  void InsertNode(uint32_t index, Subgraph *subgraph);
  bool IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes);
  const schema::Primitive *CreatePartialPrimitive(int64_t subgraph_index);
  void InitSubgraphDevice();
  void SubgraphFusion();
  void InitMainGraphDevice();

 private:
  LiteModel *model_ = nullptr;
  std::vector<Tensor> tensors_;
  std::vector<Subgraph> sub_graphs_;
  std::vector<size_t> output_nodes_;
  std::vector<Model::Node *> node_list_;
};
#endif
}  // namespace mindspore::lite
#endif  // MINDSPORE_LITE_SRC_SUB_GRAPH_SPLIT_H_
@@ -352,8 +352,8 @@ void Tensor::DecRefCount() {
  if (this->IsConst() || this->IsGraphInput()) {
    return;
  }
  this->ref_count_--;
  if (this->ref_count_ <= 0) {
  bool free_data = --ref_count_ <= 0;
  if (free_data) {
    FreeData();
    this->ref_count_ = 0;
  }
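
Folding the decrement and comparison into one expression pays off once ref_count_ becomes std::atomic_int in the next hunk: --ref_count_ is a single atomic read-modify-write, so exactly one thread can observe the transition to zero and free the data. A standalone model of the pattern:

#include <atomic>

// Standalone model: only the thread performing the final decrement sees a
// result <= 0, so FreeData() runs exactly once even under contention.
struct RefCounted {
  std::atomic_int ref_count{1};
  void FreeData() { /* release the backing buffer */ }
  void DecRefCount() {
    bool free_data = --ref_count <= 0;  // atomic fetch_sub(1) minus 1
    if (free_data) {
      FreeData();
      ref_count = 0;
    }
  }
};
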
@@ -22,6 +22,7 @@
#include <string>
#include <numeric>
#include <functional>
#include <atomic>
#include "include/ms_tensor.h"
#include "src/runtime/allocator.h"
@@ -205,7 +206,7 @@ class Tensor : public mindspore::tensor::MSTensor {
  std::vector<int> shape_;
  schema::Format format_;
  Category category_;
  size_t ref_count_ = 0;
  std::atomic_int ref_count_ = 0;
  size_t init_ref_count_ = 0;
  std::vector<QuantArg> quant_params_;
  std::vector<float> quant_clusters_;
@@ -144,6 +144,7 @@ set(TEST_LITE_SRC
    ${LITE_DIR}/src/dequant.cc
    ${LITE_DIR}/src/huffman_decode.cc
    ${LITE_DIR}/src/sub_graph_kernel.cc
    ${LITE_DIR}/src/sub_graph_split.cc
    ${LITE_DIR}/src/lite_model.cc
    ${LITE_DIR}/src/scheduler.cc
    ${LITE_DIR}/src/common/graph_util.cc
@@ -109,6 +109,7 @@ set(LITE_SRC
    ${SRC_DIR}/lite_kernel.cc
    ${SRC_DIR}/scheduler.cc
    ${SRC_DIR}/sub_graph_kernel.cc
    ${SRC_DIR}/sub_graph_split.cc
    ${SRC_DIR}/lite_session.cc
    ${SRC_DIR}/executor.cc
    ${SRC_DIR}/lite_model.cc