From 6b667016a70b9db445a40f24369fd1094104aa17 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Wed, 28 Oct 2020 15:36:10 +0800 Subject: [PATCH] fix bug in sub_graph_kernel --- mindspore/lite/src/lite_kernel.cc | 49 ++++++++++++++++++++++--- mindspore/lite/src/lite_kernel.h | 2 + mindspore/lite/src/scheduler.cc | 20 ++++++---- mindspore/lite/src/scheduler.h | 4 +- mindspore/lite/src/sub_graph_kernel.cc | 51 +++++++++----------------- 5 files changed, 76 insertions(+), 50 deletions(-) diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc index 997878b598..a0805dd2cd 100644 --- a/mindspore/lite/src/lite_kernel.cc +++ b/mindspore/lite/src/lite_kernel.cc @@ -16,6 +16,7 @@ #include "src/lite_kernel.h" #include +#include #include "src/tensor.h" namespace mindspore::kernel { @@ -120,19 +121,19 @@ std::string LiteKernel::ToString() const { std::ostringstream oss; oss << "LiteKernel: " << this->name_; oss << ", Type: " << this->type_str(); - oss << std::endl << this->in_tensors_.size() << " InputTensors:"; + oss << ", " << this->in_tensors_.size() << " InputTensors:"; for (auto tensor : in_tensors_) { - oss << " " << tensor << ":" << tensor->ToString(); + oss << " " << tensor; } - oss << std::endl << this->out_tensors_.size() << " OutputTensors:"; + oss << ", " << this->out_tensors_.size() << " OutputTensors:"; for (auto tensor : out_tensors_) { - oss << " " << tensor << ":" << tensor->ToString(); + oss << " " << tensor; } - oss << std::endl << this->in_kernels_.size() << " InputKernels:"; + oss << ", " << this->in_kernels_.size() << " InputKernels:"; for (auto in_kernel : in_kernels_) { oss << " " << in_kernel->name_; } - oss << std::endl << this->out_kernels_.size() << " OutputKernels:"; + oss << ", " << this->out_kernels_.size() << " OutputKernels:"; for (auto out_kernel : out_kernels_) { oss << " " << out_kernel->name_; } @@ -239,6 +240,42 @@ std::vector LiteKernelUtil::SubgraphOutputTensors(const std::vec return output_tensors; } +int LiteKernelUtil::TopologicalSortKernels(std::vector *kernels) { + auto old_kernels = *kernels; + kernels->clear(); + std::queue kernel_queue; + for (auto kernel : old_kernels) { + if (kernel->in_kernels().empty()) { + kernel_queue.push(kernel); + kernels->emplace_back(kernel); + } + } + while (!kernel_queue.empty()) { + auto cur_kernel = kernel_queue.front(); + kernel_queue.pop(); + MS_ASSERT(cur_kernel != nullptr); + auto next_kernels = cur_kernel->out_kernels(); + for (auto next_kernel : next_kernels) { + auto in_kernels = next_kernel->in_kernels(); + if (lite::IsContain(*kernels, const_cast(next_kernel))) { + MS_LOG(ERROR) << "TopologicalSortKernels failed, loop exist"; + return RET_ERROR; + } + if (std::all_of(in_kernels.begin(), in_kernels.end(), [&](const kernel::LiteKernel *in_kernel) { + return lite::IsContain(*kernels, const_cast(in_kernel)); + })) { + kernel_queue.push(next_kernel); + } + } + } + if (kernels->size() != old_kernels.size()) { + MS_LOG(ERROR) << "TopologicalSortKernels failed, kernels size before sort: " << old_kernels.size() + << ", kernels size after sort: " << kernels->size(); + return RET_ERROR; + } + return RET_OK; +} + void LiteKernelUtil::InitIOKernels(std::vector &kernels) { for (auto *kernel : kernels) { // clean io kernels diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index 64571c263d..8defe1be0d 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -202,6 +202,8 @@ class LiteKernelUtil { static std::vector SubgraphOutputTensors(const std::vector &kernels); + static int TopologicalSortKernels(std::vector *kernels); + static void InitTensorRefCount(std::vector &kernels); static int SetInput(LiteKernel &kernelMod, std::vector inputs); diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 8bfdcb5c69..9e25581bf0 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -38,17 +38,21 @@ int Scheduler::Schedule(const lite::Model *model, std::vector *tensors int ret = InferShape(model, tensors); if (ret != RET_OK) { MS_LOG(ERROR) << "op infer shape failed."; - return RET_ERROR; + return ret; } - ret = InitOp2Kernel(model, tensors, kernels); + ret = BuildKernels(model, tensors, kernels); if (ret != RET_OK) { MS_LOG(ERROR) << "init op to kernel failed."; - return RET_ERROR; + return ret; } kernel::LiteKernelUtil::InitIOKernels(*kernels); - ConstructSubGraphs(kernels); + ret = ConstructSubGraphs(kernels); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConstructSubGraphs failed."; + return ret; + } kernel::LiteKernelUtil::InitIOKernels(*kernels); @@ -129,8 +133,8 @@ int Scheduler::InferShape(const lite::Model *model, std::vector *tenso return RET_OK; } -int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector *tensors, - std::vector *kernels) { +int Scheduler::BuildKernels(const lite::Model *model, std::vector *tensors, + std::vector *kernels) { MS_ASSERT(model != nullptr); MS_ASSERT(tensors != nullptr); uint32_t kernelCount = model->nodes_.size(); @@ -194,7 +198,7 @@ int Scheduler::ConstructSubGraphs(std::vector *kernels) { std::vector sub_kernels; std::queue kernel_queue; kernel_queue.emplace(head_kernel); - auto cur_sub_graph_type = this->GetKernelSubGraphType(head_kernel); + auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel); while (!kernel_queue.empty()) { auto cur_kernel = kernel_queue.front(); kernel_queue.pop(); @@ -202,7 +206,7 @@ int Scheduler::ConstructSubGraphs(std::vector *kernels) { sub_kernels.emplace_back(cur_kernel); auto post_kernels = cur_kernel->out_kernels(); for (auto post_kernel : post_kernels) { - if (cur_sub_graph_type == this->GetKernelSubGraphType(post_kernel)) { + if (cur_sub_graph_type == mindspore::lite::Scheduler::GetKernelSubGraphType(post_kernel)) { auto post_kernel_inputs = post_kernel->in_kernels(); if (std::all_of(post_kernel_inputs.begin(), post_kernel_inputs.end(), [&](kernel::LiteKernel *kernel) { return is_kernel_sinked[kernel]; })) { diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h index e219258018..b2e8d3fefb 100644 --- a/mindspore/lite/src/scheduler.h +++ b/mindspore/lite/src/scheduler.h @@ -37,8 +37,8 @@ class Scheduler { kernel::LiteKernel *ScheduleNode(const std::vector &in_tensors, const std::vector &out_tensors, const mindspore::lite::PrimitiveC *primitive, const Model::Node *cnode); - int InitOp2Kernel(const lite::Model *model, std::vector *tensors, - std::vector *kernels); + int BuildKernels(const lite::Model *model, std::vector *tensors, + std::vector *kernels); static int InferShape(const lite::Model *model, std::vector *tensors); diff --git a/mindspore/lite/src/sub_graph_kernel.cc b/mindspore/lite/src/sub_graph_kernel.cc index 50d66aa57f..0d941f7e2c 100644 --- a/mindspore/lite/src/sub_graph_kernel.cc +++ b/mindspore/lite/src/sub_graph_kernel.cc @@ -44,23 +44,23 @@ int SubGraphKernel::Prepare() { std::string SubGraphKernel::ToString() const { std::ostringstream oss; oss << "===============================================" << std::endl << "Subgraph type : " << this->subgraph_type_; - oss << std::endl << this->in_tensors_.size() << " InputTensors:"; + oss << std::endl << this->in_tensors_.size() << "Subgraph inputTensors:"; for (auto tensor : in_tensors_) { - oss << " " << tensor << ":" << tensor->ToString(); + oss << " " << tensor; } - oss << std::endl << this->out_tensors_.size() << " OutputTensors:"; + oss << std::endl << this->out_tensors_.size() << "Subgraph outputTensors:"; for (auto tensor : out_tensors_) { - oss << " " << tensor << ":" << tensor->ToString(); + oss << " " << tensor; } - oss << std::endl << "input kernels :"; + oss << std::endl << "Subgraph input kernels :" << std::endl; for (auto kernel : this->in_kernels_) { - oss << " " << kernel->ToString(); + oss << " " << kernel->ToString() << std::endl; } - oss << std::endl << "output kernels :"; + oss << std::endl << "Subgraph output kernels :" << std::endl; for (auto kernel : this->out_kernels_) { - oss << " " << kernel->ToString(); + oss << " " << kernel->ToString() << std::endl; } - oss << std::endl << nodes_.size() << " nodes :"; + oss << std::endl << nodes_.size() << " nodes in subgraph :"; for (auto kernel : this->nodes_) { oss << " " << kernel->name(); } @@ -178,36 +178,18 @@ int CpuFp16SubGraph::PreProcess() { } int CpuFp16SubGraph::PostProcess() { - auto fp16_to_fp32_cast_func = kernel::Float16CastUtil::GetInstance()->float16_to_float32_func_; + auto fp16_to_fp32_cast_func = Float16CastUtil::GetInstance()->float16_to_float32_func_; if (fp16_to_fp32_cast_func == nullptr) { MS_LOG(ERROR) << "Can not find cast fp16 to fp32 func"; return RET_ERROR; } for (auto tensor : this->out_tensors_) { if (tensor->data_type() == kNumberTypeFloat16) { - void *float16_data = nullptr; - if (this->context_ != nullptr && this->context_->allocator != nullptr) { - float16_data = this->context_->allocator->Malloc(tensor->Size()); - } else { - float16_data = malloc(tensor->Size()); - } - if (float16_data == nullptr) { - MS_LOG(ERROR) << "malloc data failed"; - return RET_ERROR; - } - memcpy(float16_data, tensor->data_c(), tensor->Size()); - auto ret = tensor->FreeData(); - if (RET_OK != ret) { - MS_LOG(ERROR) << "free data failed"; - if (this->context_ != nullptr && this->context_->allocator != nullptr) { - this->context_->allocator->Free(float16_data); - } else { - free(float16_data); - } - return RET_ERROR; - } + auto float16_data = tensor->data_c(); + MS_ASSERT(float16_data != nullptr); + tensor->set_data(nullptr); tensor->set_data_type(TypeId::kNumberTypeFloat32); - ret = tensor->MallocData(); + auto ret = tensor->MallocData(); if (RET_OK != ret) { MS_LOG(ERROR) << "malloc data failed"; if (this->context_ != nullptr && this->context_->allocator != nullptr) { @@ -217,9 +199,10 @@ int CpuFp16SubGraph::PostProcess() { } return RET_ERROR; } + MS_ASSERT(tensor->data_c() != nullptr); fp16_to_fp32_cast_func(float16_data, tensor->data_c(), tensor->ElementsNum()); - if (this->context_ != nullptr && this->context_->allocator != nullptr) { - this->context_->allocator->Free(float16_data); + if (tensor->allocator() != nullptr) { + tensor->allocator()->Free(float16_data); } else { free(float16_data); }