Browse Source

float16 priority

tags/v0.7.0-beta
zhangxuetong 5 years ago
parent
commit
86199d522e
4 changed files with 39 additions and 7 deletions
  1. +2
    -3
      mindspore/lite/include/context.h
  2. +1
    -0
      mindspore/lite/src/lite_session.cc
  3. +35
    -4
      mindspore/lite/src/scheduler.cc
  4. +1
    -0
      mindspore/lite/src/scheduler.h

+ 2
- 3
mindspore/lite/include/context.h View File

@@ -64,11 +64,10 @@ class MS_API Context {
/// \brief Destructor of MindSpore Lite Context.
virtual ~Context();

void InferShapeInterrupt() {
infer_shape_interrupt_ = true;
}
void InferShapeInterrupt() { infer_shape_interrupt_ = true; }

public:
bool float16_priority = true; /**< allow priority select float16 kernel */
DeviceContext device_ctx_{DT_CPU};
int thread_num_ = 2; /**< thread number config for thread pool */
std::shared_ptr<Allocator> allocator = nullptr;


+ 1
- 0
mindspore/lite/src/lite_session.cc View File

@@ -238,6 +238,7 @@ int LiteSession::Init(Context *context) {
MS_LOG(ERROR) << "new context failed";
return RET_MEMORY_FAILED;
}
this->context_->float16_priority = context->float16_priority;
this->context_->cpu_bind_mode_ = context->cpu_bind_mode_;
ConfigThreadPool(context->cpu_bind_mode_, context->thread_num_);
auto ret = KernelRegistry::GetInstance()->Init();


+ 35
- 4
mindspore/lite/src/scheduler.cc View File

@@ -95,6 +95,7 @@ int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector<tensor::Tenso
<< ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type());
return RET_ERROR;
}
SetKernelTensorDataType(kernel);
kernel->set_name(cNode->name()->str());
kernel->set_is_model_output(IsContain(graph_output_node_indexes, size_t(i)));
kernels->emplace_back(kernel);
@@ -129,6 +130,9 @@ void Scheduler::ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels) {
for (auto kernel : temp_kernels) {
for (auto tensor : kernel->out_tensors()) {
tensor->set_allocator(context_->allocator.get());
if (context_->float16_priority && tensor->data_type() == kNumberTypeFloat16) {
tensor->set_data_type(kNumberTypeFloat32);
}
}
}
std::copy(temp_kernels.begin(), temp_kernels.end(), std::back_inserter(subgraph_kernels));
@@ -196,24 +200,51 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *>

desc.arch = kernel::KERNEL_ARCH::kCPU;
kernel::LiteKernel *kernel = nullptr;
if (data_type == kNumberTypeFloat32) {
if ((context_->float16_priority && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16) {
// check if support fp16
kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type};
kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key);
if (kernel != nullptr) {
MS_LOG(DEBUG) << "Get fp16 op success.";
desc.data_type = kNumberTypeFloat16;
kernel->set_desc(desc);
return kernel;
}
MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
} else {
kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
}
if (data_type == kNumberTypeFloat16) {
desc.data_type = kNumberTypeFloat32;
}
kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
if (kernel != nullptr) {
kernel->set_desc(desc);
return kernel;
}
return nullptr;
}

// Align tensor data types with the data type of the CPU kernel that was
// actually selected, so fp16-capable kernels and fp32 fallbacks agree on
// the tensors they share.
void Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) {
  // Only CPU kernels take part in fp16/fp32 data-type alignment.
  if (kernel->desc().arch != kernel::KERNEL_ARCH::kCPU) {
    return;
  }
  auto kernel_data_type = kernel->desc().data_type;
  if (kernel_data_type == kNumberTypeFloat16) {
    // An fp16 kernel produces fp16 results: retag fp32-marked outputs.
    for (auto out_tensor : kernel->out_tensors()) {
      if (out_tensor->data_type() == kNumberTypeFloat32) {
        out_tensor->set_data_type(kNumberTypeFloat16);
      }
    }
    return;
  }
  if (kernel_data_type != kNumberTypeFloat32) {
    return;
  }
  // An fp32 kernel consumes and produces fp32: retag fp16-marked tensors.
  // ValueNode inputs (presumably constant/weight tensors — confirm against
  // the Tensor definition) keep their stored data type untouched.
  for (auto in_tensor : kernel->in_tensors()) {
    if (in_tensor->TensorType() != schema::NodeType_ValueNode && in_tensor->data_type() == kNumberTypeFloat16) {
      in_tensor->set_data_type(kNumberTypeFloat32);
    }
  }
  for (auto out_tensor : kernel->out_tensors()) {
    if (out_tensor->data_type() == kNumberTypeFloat16) {
      out_tensor->set_data_type(kNumberTypeFloat32);
    }
  }
}

} // namespace mindspore::lite

+ 1
- 0
mindspore/lite/src/scheduler.h View File

@@ -41,6 +41,7 @@ class Scheduler {
void ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels);

kernel::LiteKernel *CreateSubKernel(const std::vector<kernel::LiteKernel *> &kernels, kernel::KERNEL_ARCH arch);
void SetKernelTensorDataType(kernel::LiteKernel *kernel);

protected:
Context *context_ = nullptr;


Loading…
Cancel
Save