diff --git a/mindspore/lite/src/executor.cc b/mindspore/lite/src/executor.cc
index 0ea9b0d442..a1de30c098 100644
--- a/mindspore/lite/src/executor.cc
+++ b/mindspore/lite/src/executor.cc
@@ -51,7 +51,7 @@ int Executor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors,
   }
   std::queue<kernel::LiteKernel *> kernel_queue;
   for (auto kernel : kernels) {
-    if (kernel->IsReady()) {
+    if (kernel->IsReady(kernel->in_tensors())) {
       kernel_queue.push(kernel);
     }
   }
@@ -75,7 +75,7 @@ int Executor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors,
       return ret;
     }
     for (auto &out_kernel : cur_kernel->out_kernels()) {
-      if (out_kernel->IsReady()) {
+      if (out_kernel->IsReady(out_kernel->in_tensors())) {
         kernel_queue.push(out_kernel);
       }
     }
diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc
index f553f71a39..b3104fc505 100644
--- a/mindspore/lite/src/lite_kernel.cc
+++ b/mindspore/lite/src/lite_kernel.cc
@@ -41,9 +41,13 @@ void LiteKernel::FreeWorkspace() {
   workspace_ = nullptr;
 }
 
-bool LiteKernel::IsReady() {
+bool LiteKernel::IsReady(const std::vector<lite::Tensor *> &scope_tensors) {
   return std::all_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *kernel_in_tensor) {
-    return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1;
+    if (IsContain(scope_tensors, kernel_in_tensor)) {
+      return (kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1);
+    } else {
+      return true;
+    }
   });
 }
 
@@ -200,12 +204,16 @@ void LiteKernel::FindInoutKernels(const std::vector<kernel::LiteKernel *> &scope_kernels) {
   }
   for (auto *tensor : this->in_tensors_) {
     if (lite::IsContain(scope_kernel->out_tensors(), tensor)) {
-      this->AddInKernel(scope_kernel);
+      if (!lite::IsContain(this->in_kernels(), scope_kernel)) {
+        this->AddInKernel(scope_kernel);
+      }
     }
   }
   for (auto *tensor : this->out_tensors_) {
     if (lite::IsContain(scope_kernel->in_tensors(), tensor)) {
-      this->AddOutKernel(scope_kernel);
+      if (!lite::IsContain(this->out_kernels(), scope_kernel)) {
+        this->AddOutKernel(scope_kernel);
+      }
     }
   }
 }
diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h
index 46a1ce791c..d1466c4d27 100644
--- a/mindspore/lite/src/lite_kernel.h
+++ b/mindspore/lite/src/lite_kernel.h
@@ -156,7 +156,7 @@ class LiteKernel {
   const std::vector<LiteKernel *> &out_kernels() const { return this->out_kernels_; }
 
-  virtual bool IsReady();
+  virtual bool IsReady(const std::vector<lite::Tensor *> &scope_tensors);
 
   virtual void InitOutTensorInitRefCount();
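The new IsReady overload only gates a kernel on tensors that are produced inside the current scope; anything fed from outside the scope is treated as already available. A minimal standalone sketch of that rule, using hypothetical stand-in types (TensorStub and IsReadyInScope are made up for illustration, not MindSpore classes):

#include <algorithm>
#include <vector>

struct TensorStub {
  bool is_const = false;
  int ref_count = 0;
};

// A tensor only blocks readiness if it is produced inside the current scope;
// tensors coming from outside the scope are assumed to be ready already.
bool IsReadyInScope(const std::vector<TensorStub *> &in_tensors, const std::vector<TensorStub *> &scope_tensors) {
  return std::all_of(in_tensors.begin(), in_tensors.end(), [&](TensorStub *t) {
    bool in_scope = std::find(scope_tensors.begin(), scope_tensors.end(), t) != scope_tensors.end();
    return !in_scope || t->is_const || t->ref_count >= 1;
  });
}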
<< "Prepare kernel " << kernel->name() << " failed: " << ret; return ret; } + auto sub_graph = reinterpret_cast(kernel); + MS_ASSERT(sub_graph != nullptr); + auto kernel_in_subgraph = sub_graph->nodes(); + all_kernels.insert(all_kernels.end(), kernel_in_subgraph.begin(), kernel_in_subgraph.end()); + } + // find in_kernels and out_kernels for kernels + for (auto *kernel : all_kernels) { + kernel->FindInoutKernels(all_kernels); } + // init init_ref_count for subgraphs and kernels + for (auto *kernel : this->kernels_) { + kernel->InitOutTensorInitRefCount(); + } + AdjustModelOutputTensorInitRefCount(model); return RET_OK; } diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h index d0d3ec1bfd..002c3588eb 100644 --- a/mindspore/lite/src/lite_session.h +++ b/mindspore/lite/src/lite_session.h @@ -96,7 +96,7 @@ class LiteSession : public session::LiteSession { int ResizeInputs(const std::vector &inputs, const std::vector> &dims); - int PrepareKernels(); + int PrepareKernels(Model *model); static int ReSizeKernels(const std::vector &kernels); diff --git a/mindspore/lite/src/ops/merge.cc b/mindspore/lite/src/ops/merge.cc index e9895164cb..1c7e7bc91d 100644 --- a/mindspore/lite/src/ops/merge.cc +++ b/mindspore/lite/src/ops/merge.cc @@ -67,10 +67,10 @@ Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator); #endif int Merge::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(outputs_.size() == 1); - MS_ASSERT(inputs_.size() == 2); - outputs_[0]->set_data_type(inputs_[0]->data_type()); - + MS_ASSERT(inputs_.size() == 2 * outputs_.size()); + for (size_t i = 0; i < inputs_.size() / 2; i++) { + outputs_[i]->set_data_type(inputs_[i]->data_type()); + } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/merge.cc b/mindspore/lite/src/runtime/kernel/arm/base/merge.cc index 153428b941..ae3a25477d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/merge.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/merge.cc @@ -24,12 +24,38 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Merge; namespace mindspore::kernel { + +int MergeCPUKernel::FreeInWorkTensor() const { + for (auto &in_tensor : this->in_tensors_) { + MS_ASSERT(in_tensor != nullptr); + if (in_tensor->IsConst()) { + continue; + } + if (in_tensor->ref_count() > 0) { + in_tensor->set_ref_count(in_tensor->ref_count() - 1); + if (in_tensor->ref_count() <= 0) { + auto ret = in_tensor->FreeData(); + if (0 != ret) { + MS_LOG(ERROR) << "Free tensor data failed"; + return ret; + } + } + } + } + return RET_OK; +} + // if one of input of merge is const-tensor, merge is always ready, this will cause error. 
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/merge.cc b/mindspore/lite/src/runtime/kernel/arm/base/merge.cc
index 153428b941..ae3a25477d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/merge.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/merge.cc
@@ -24,12 +24,38 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Merge;
 
 namespace mindspore::kernel {
+
+int MergeCPUKernel::FreeInWorkTensor() const {
+  for (auto &in_tensor : this->in_tensors_) {
+    MS_ASSERT(in_tensor != nullptr);
+    if (in_tensor->IsConst()) {
+      continue;
+    }
+    if (in_tensor->ref_count() > 0) {
+      in_tensor->set_ref_count(in_tensor->ref_count() - 1);
+      if (in_tensor->ref_count() <= 0) {
+        auto ret = in_tensor->FreeData();
+        if (0 != ret) {
+          MS_LOG(ERROR) << "Free tensor data failed";
+          return ret;
+        }
+      }
+    }
+  }
+  return RET_OK;
+}
+
 // if one of input of merge is const-tensor, merge is always ready, this will cause error.
-bool MergeCPUKernel::IsReady() {
-  MS_ASSERT(in_tensors().size() == 2);
-  return std::any_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *kernel_in_tensor) {
-    return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1;
-  });
+bool MergeCPUKernel::IsReady(const std::vector<lite::Tensor *> &scope_tensors) {
+  MS_ASSERT(in_tensors().size() == 2 * out_tensors().size());
+  return std::all_of(this->in_tensors().begin(), this->in_tensors().begin() + in_tensors().size() / 2,
+                     [&](lite::Tensor *kernel_in_tensor) {
+                       return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1;
+                     }) ||
+         std::all_of(this->in_tensors().begin() + in_tensors().size() / 2, this->in_tensors().end(),
+                     [&](lite::Tensor *kernel_in_tensor) {
+                       return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1;
+                     });
 }
 
 int MergeCPUKernel::Init() { return RET_OK; }
@@ -37,14 +63,24 @@ int MergeCPUKernel::Init() { return RET_OK; }
 int MergeCPUKernel::ReSize() { return RET_ERROR; }
 
 int MergeCPUKernel::Run() {
-  MS_ASSERT(in_tensors_.size() == 2);
-  MS_ASSERT(out_tensors_.size() == 1);
-  auto out_data = out_tensors_.front()->data_c();
-  MS_ASSERT(out_data != nullptr);
-  for (size_t i = 0; i < in_tensors().size(); i++) {
-    if (in_tensors()[i]->data_c() != nullptr) {
+  MS_ASSERT(in_tensors_.size() == 2 * out_tensors_.size());
+  int in_tesnor_part_one = 0;
+  int in_tensor_part_two = out_tensors().size();
+  if (in_tensors_[in_tesnor_part_one]->data_c() != nullptr) {
+    for (size_t i = 0; i < out_tensors().size(); i++) {
+      auto out_data = out_tensors_[i]->data_c();
       auto in_data = in_tensors_[i]->data_c();
       MS_ASSERT(in_data != nullptr);
+      MS_ASSERT(out_data != nullptr);
+      memcpy(out_data, in_data, in_tensors_[i]->Size());
+    }
+  }
+  if (in_tensors_[in_tensor_part_two]->data_c() != nullptr) {
+    for (size_t i = 0; i < out_tensors().size(); i++) {
+      auto out_data = out_tensors_[i]->data_c();
+      auto in_data = in_tensors_[i + in_tensor_part_two]->data_c();
+      MS_ASSERT(in_data != nullptr);
+      MS_ASSERT(out_data != nullptr);
       memcpy(out_data, in_data, in_tensors_[i]->Size());
     }
   }
@@ -81,4 +117,5 @@ kernel::LiteKernel *CpuMergeKernelCreator(const std::vector<lite::Tensor *> &inputs,
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Merge, CpuMergeKernelCreator)
 REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Merge, CpuMergeKernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Merge, CpuMergeKernelCreator)
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/merge.h b/mindspore/lite/src/runtime/kernel/arm/base/merge.h
index be81359768..726a3ae243 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/merge.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/merge.h
@@ -34,7 +34,8 @@ class MergeCPUKernel : public LiteKernel {
     merge_param_ = reinterpret_cast<MergeParameter *>(op_parameter_);
   }
   ~MergeCPUKernel() override {}
-  bool IsReady() override;
+  int FreeInWorkTensor() const override;
+  bool IsReady(const std::vector<lite::Tensor *> &scope_tensors) override;
   int Init() override;
   int ReSize() override;
   int Run() override;
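FreeInWorkTensor releases Merge's own working inputs once the kernel has run: each non-const input loses one reference, and its buffer is freed when the count reaches zero. A simplified sketch of that ref-count pattern with hypothetical stand-in types (RefTensorStub is illustrative only):

#include <vector>

// Simplified stand-in for a ref-counted tensor buffer (hypothetical type).
struct RefTensorStub {
  bool is_const = false;
  int ref_count = 0;
  void *data = nullptr;
  void FreeData() { data = nullptr; }  // placeholder for the real deallocation
};

// Mirrors the idea behind MergeCPUKernel::FreeInWorkTensor: every non-const
// input loses one reference after the kernel ran; the buffer is released at zero.
void ReleaseInputsAfterRun(const std::vector<RefTensorStub *> &in_tensors) {
  for (auto *t : in_tensors) {
    if (t == nullptr || t->is_const || t->ref_count <= 0) {
      continue;
    }
    if (--t->ref_count <= 0) {
      t->FreeData();
    }
  }
}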
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/switch.cc b/mindspore/lite/src/runtime/kernel/arm/base/switch.cc
index c56583a172..1a16bdaaaa 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/switch.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/switch.cc
@@ -53,7 +53,6 @@ int SwitchCPUKernel::ReSize() { return RET_ERROR; }
 // output: true-data*n, false-data*n
 int SwitchCPUKernel::Run() {
   MS_ASSERT(in_tensors_.size() >= 2);
-  MS_ASSERT(out_tensors_.size() == 2 * in_tensors_.size());
   auto bool_tensor = in_tensors_.front();
   MS_ASSERT(bool_tensor != nullptr);
   MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool);
@@ -71,8 +70,8 @@ int SwitchCPUKernel::Run() {
     auto out_tensor = out_tensors_.at(out_index++);
     MS_ASSERT(in_tensor != nullptr);
     MS_ASSERT(out_tensor != nullptr);
-    auto input = reinterpret_cast<float *>(in_tensor->data_c());
-    auto output = reinterpret_cast<float *>(out_tensor->data_c());
+    auto input = in_tensor->data_c();
+    auto output = out_tensor->data_c();
     MS_ASSERT(in_tensor->Size() == out_tensor->Size());
     if (input == nullptr || output == nullptr) {
       MS_LOG(ERROR) << "input tensor or output tensor have not been malloced";
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc
index 8b5cce77a8..cad53923c5 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc
@@ -221,7 +221,7 @@ int OpenCLSubGraph::Init() {
     return ret;
   }
   nodes_.insert(nodes_.end(), out_convert_ops_.begin(), out_convert_ops_.end());
-
+  GetInOutNodes();
   UpdateTensorDataType();
 
   ret = SubGraphKernel::Prepare();
@@ -283,6 +283,8 @@ void OpenCLSubGraph::GetKernelFromToTensor(const std::vector<lite::Tensor *> &in_tensors,
 }
 
 void OpenCLSubGraph::GetInOutNodes() {
+  this->in_nodes_.clear();
+  this->out_nodes_.clear();
   for (auto *node : nodes_) {
     for (auto *tensor : node->in_tensors()) {
       if (std::find(in_tensors_.begin(), in_tensors_.end(), tensor) != in_tensors_.end()) {
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h
index f48760e1f1..bcaf85bf34 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h
@@ -74,8 +74,6 @@ class OpenCLSubGraph : public SubGraphKernel {
   std::vector<OpenCLToFormatParameter *> out_parameters_;
   std::vector<LiteKernel *> in_convert_ops_;
   std::vector<LiteKernel *> out_convert_ops_;
-  std::vector<LiteKernel *> in_nodes_;
-  std::vector<LiteKernel *> out_nodes_;
   std::set<LiteKernel *> nodes_set_;
   lite::opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_;
   lite::opencl::OpenCLRuntime *ocl_runtime_{nullptr};
diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc
index 9ec55d58a5..4d2f514672 100644
--- a/mindspore/lite/src/scheduler.cc
+++ b/mindspore/lite/src/scheduler.cc
@@ -67,8 +67,6 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
     MS_LOG(ERROR) << "ConstructSubGraphs failed.";
     return ret;
   }
-  FindAllInoutKernels(*dst_kernels);
-  kernel::LiteKernelUtil::InitTensorInitRefCount(*dst_kernels);
   MS_LOG(DEBUG) << "schedule kernels success.";
   return RET_OK;
 }
diff --git a/mindspore/lite/src/sub_graph_kernel.cc b/mindspore/lite/src/sub_graph_kernel.cc
index 4f2de9a1f4..d3888ef2bd 100644
--- a/mindspore/lite/src/sub_graph_kernel.cc
+++ b/mindspore/lite/src/sub_graph_kernel.cc
@@ -53,11 +53,11 @@ std::string SubGraphKernel::ToString() const {
   for (auto tensor : out_tensors_) {
     oss << " " << tensor;
   }
-  oss << std::endl << "Subgraph input kernels :" << std::endl;
+  oss << std::endl << "Subgraph input nodes :" << std::endl;
   for (auto kernel : this->in_nodes_) {
     oss << " " << kernel->ToString() << std::endl;
   }
-  oss << std::endl << "Subgraph output kernels :" << std::endl;
+  oss << std::endl << "Subgraph output nodes :" << std::endl;
   for (auto kernel : this->out_nodes_) {
     oss << " " << kernel->ToString() << std::endl;
   }
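Graph wiring now happens in PrepareKernels rather than in the scheduler: every kernel scans its scope for producers of its inputs and consumers of its outputs, skipping links it already has. A rough sketch of that linking step with hypothetical stand-in types (NodeStub and LinkNode are illustrative, not MindSpore code):

#include <algorithm>
#include <vector>

// Hypothetical node standing in for kernel::LiteKernel: a kernel with its
// input/output tensor ids and the producer/consumer links being built.
struct NodeStub {
  std::vector<int> in_tensor_ids;
  std::vector<int> out_tensor_ids;
  std::vector<NodeStub *> in_nodes;
  std::vector<NodeStub *> out_nodes;
};

template <typename T>
bool Contains(const std::vector<T> &vec, const T &value) {
  return std::find(vec.begin(), vec.end(), value) != vec.end();
}

// Same idea as FindInoutKernels with the added duplicate guard: a scope kernel
// producing one of our inputs becomes an in-node (once), and vice versa.
void LinkNode(NodeStub *node, const std::vector<NodeStub *> &scope) {
  for (auto *other : scope) {
    if (other == node) continue;
    for (int id : node->in_tensor_ids) {
      if (Contains(other->out_tensor_ids, id) && !Contains(node->in_nodes, other)) {
        node->in_nodes.push_back(other);
      }
    }
    for (int id : node->out_tensor_ids) {
      if (Contains(other->in_tensor_ids, id) && !Contains(node->out_nodes, other)) {
        node->out_nodes.push_back(other);
      }
    }
  }
}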
diff --git a/mindspore/lite/src/sub_graph_kernel.h b/mindspore/lite/src/sub_graph_kernel.h
index 8bcbd4d289..9f595c9151 100644
--- a/mindspore/lite/src/sub_graph_kernel.h
+++ b/mindspore/lite/src/sub_graph_kernel.h
@@ -69,6 +69,22 @@ class SubGraphKernel : public LiteKernel {
     nodes_.clear();
   }
 
+  void FindInoutKernels(const std::vector<kernel::LiteKernel *> &scope_kernels) override {
+    LiteKernel::FindInoutKernels(scope_kernels);
+    std::vector<kernel::LiteKernel *> new_scope_kernels = {};
+    new_scope_kernels.insert(new_scope_kernels.end(), this->in_kernels().begin(), this->in_kernels().end());
+    new_scope_kernels.insert(new_scope_kernels.end(), this->out_kernels().begin(), this->out_kernels().end());
+    new_scope_kernels.insert(new_scope_kernels.end(), nodes_.begin(), nodes_.end());
+    for (auto *node : nodes_) {
+      node->FindInoutKernels(new_scope_kernels);
+    }
+  }
+
+  bool IsReady(const std::vector<lite::Tensor *> &scope_tensors) override {
+    return std::all_of(this->in_nodes_.begin(), this->in_nodes_.end(),
+                       [&](LiteKernel *kernel) { return kernel->IsReady(scope_tensors); });
+  }
+
   // called while compiling graph. Call node->Prepare() by default.
   int Prepare() override;
   // called before Run
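For context, the executor drives these kernels with a ready-queue: seed it with kernels whose inputs are available, run one, release its working inputs, and enqueue any consumers that just became ready. A condensed sketch of that loop with a hypothetical, simplified kernel interface (the scope_tensors argument and error codes are reduced here):

#include <queue>
#include <vector>

// Hypothetical kernel interface reduced to what the loop needs.
struct KernelStub {
  std::vector<KernelStub *> out_kernels;
  virtual ~KernelStub() = default;
  virtual bool IsReady() const = 0;
  virtual int Run() = 0;
  virtual void FreeInWorkTensors() = 0;
};

// Ready-queue scheduling pattern: run whatever is ready, free its inputs,
// then push consumers that became ready as a result.
int RunReadyQueue(const std::vector<KernelStub *> &kernels) {
  std::queue<KernelStub *> ready;
  for (auto *k : kernels) {
    if (k->IsReady()) ready.push(k);
  }
  while (!ready.empty()) {
    auto *cur = ready.front();
    ready.pop();
    int ret = cur->Run();
    if (ret != 0) return ret;
    cur->FreeInWorkTensors();
    for (auto *consumer : cur->out_kernels) {
      if (consumer->IsReady()) ready.push(consumer);
    }
  }
  return 0;
}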