From 86199d522e344069a8c788d0225e67e3af5d0ed7 Mon Sep 17 00:00:00 2001
From: zhangxuetong
Date: Sun, 16 Aug 2020 15:37:49 +0800
Subject: [PATCH] float16 priority

---
 mindspore/lite/include/context.h   |  5 ++--
 mindspore/lite/src/lite_session.cc |  1 +
 mindspore/lite/src/scheduler.cc    | 39 +++++++++++++++++++++++++++---
 mindspore/lite/src/scheduler.h     |  1 +
 4 files changed, 39 insertions(+), 7 deletions(-)

diff --git a/mindspore/lite/include/context.h b/mindspore/lite/include/context.h
index 5171de4c67..d709dba358 100644
--- a/mindspore/lite/include/context.h
+++ b/mindspore/lite/include/context.h
@@ -64,11 +64,10 @@ class MS_API Context {
   /// \brief Destructor of MindSpore Lite Context.
   virtual ~Context();
 
-  void InferShapeInterrupt() {
-    infer_shape_interrupt_ = true;
-  }
+  void InferShapeInterrupt() { infer_shape_interrupt_ = true; }
 
  public:
+  bool float16_priority = true; /**< allow priority select float16 kernel */
   DeviceContext device_ctx_{DT_CPU};
   int thread_num_ = 2; /**< thread number config for thread pool */
   std::shared_ptr<Allocator> allocator = nullptr;
diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc
index 333623101e..4366655cd7 100644
--- a/mindspore/lite/src/lite_session.cc
+++ b/mindspore/lite/src/lite_session.cc
@@ -238,6 +238,7 @@ int LiteSession::Init(Context *context) {
     MS_LOG(ERROR) << "new context failed";
     return RET_MEMORY_FAILED;
   }
+  this->context_->float16_priority = context->float16_priority;
   this->context_->cpu_bind_mode_ = context->cpu_bind_mode_;
   ConfigThreadPool(context->cpu_bind_mode_, context->thread_num_);
   auto ret = KernelRegistry::GetInstance()->Init();
diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc
index 1e4570f326..a264e43915 100644
--- a/mindspore/lite/src/scheduler.cc
+++ b/mindspore/lite/src/scheduler.cc
@@ -95,6 +95,7 @@ int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector<tensor::Tensor *> *tensors,
                     << schema::EnumNamePrimitiveType(cNode->primitive()->value_type());
       return RET_ERROR;
     }
+    SetKernelTensorDataType(kernel);
     kernel->set_name(cNode->name()->str());
     kernel->set_is_model_output(IsContain(graph_output_node_indexes, size_t(i)));
     kernels->emplace_back(kernel);
@@ -129,6 +130,9 @@ void Scheduler::ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels) {
   for (auto kernel : temp_kernels) {
     for (auto tensor : kernel->out_tensors()) {
       tensor->set_allocator(context_->allocator.get());
+      if (context_->float16_priority && tensor->data_type() == kNumberTypeFloat16) {
+        tensor->set_data_type(kNumberTypeFloat32);
+      }
     }
   }
   std::copy(temp_kernels.begin(), temp_kernels.end(), std::back_inserter(subgraph_kernels));
@@ -196,24 +200,51 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *> &in_tensors,
   }
   desc.arch = kernel::KERNEL_ARCH::kCPU;
   kernel::LiteKernel *kernel = nullptr;
-  if (data_type == kNumberTypeFloat32) {
+  if ((context_->float16_priority && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16) {
     // check if support fp16
     kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type};
     kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key);
     if (kernel != nullptr) {
       MS_LOG(DEBUG) << "Get fp16 op success.";
+      desc.data_type = kNumberTypeFloat16;
       kernel->set_desc(desc);
       return kernel;
     }
     MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
-    kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
-  } else {
-    kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
   }
+  if (data_type == kNumberTypeFloat16) {
+    desc.data_type = kNumberTypeFloat32;
+  }
+  kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
   if (kernel != nullptr) {
     kernel->set_desc(desc);
     return kernel;
   }
   return nullptr;
 }
+
+void Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) {
+  if (kernel->desc().arch != kernel::KERNEL_ARCH::kCPU) {
+    return;
+  }
+  if (kernel->desc().data_type == kNumberTypeFloat16) {
+    for (auto tensor : kernel->out_tensors()) {
+      if (tensor->data_type() == kNumberTypeFloat32) {
+        tensor->set_data_type(kNumberTypeFloat16);
+      }
+    }
+  } else if (kernel->desc().data_type == kNumberTypeFloat32) {
+    for (auto tensor : kernel->in_tensors()) {
+      if (tensor->TensorType() != schema::NodeType_ValueNode && tensor->data_type() == kNumberTypeFloat16) {
+        tensor->set_data_type(kNumberTypeFloat32);
+      }
+    }
+    for (auto tensor : kernel->out_tensors()) {
+      if (tensor->data_type() == kNumberTypeFloat16) {
+        tensor->set_data_type(kNumberTypeFloat32);
+      }
+    }
+  }
+}
+
 }  // namespace mindspore::lite
diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h
index 4aed217c99..880135252f 100644
--- a/mindspore/lite/src/scheduler.h
+++ b/mindspore/lite/src/scheduler.h
@@ -41,6 +41,7 @@ class Scheduler {
   void ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels);
 
   kernel::LiteKernel *CreateSubKernel(const std::vector<kernel::LiteKernel *> &kernels, kernel::KERNEL_ARCH arch);
+  void SetKernelTensorDataType(kernel::LiteKernel *kernel);
 
  protected:
   Context *context_ = nullptr;
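
Note (not part of the patch): a minimal usage sketch of the new flag. It assumes the era's public entry point, mindspore::session::LiteSession::CreateSession(lite::Context *), and illustrative include paths; only float16_priority itself comes from this patch.

    #include "include/context.h"
    #include "include/lite_session.h"

    int main() {
      mindspore::lite::Context context;
      // float16_priority defaults to true after this patch, so the scheduler
      // prefers fp16 CPU kernels; set it to false to force fp32 kernels.
      context.float16_priority = false;
      context.thread_num_ = 2;
      auto *session = mindspore::session::LiteSession::CreateSession(&context);
      // ... compile and run the model with `session` ...
      delete session;
      return 0;
    }

With the flag left on, ScheduleNode first asks KernelFactory for a kNumberTypeFloat16 kernel and falls back to the tensor's original data type only when none is registered; SetKernelTensorDataType then rewrites tensor data types so that fp16 producers and fp32 consumers stay consistent across kernel boundaries.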