|
|
|
@@ -95,6 +95,7 @@ int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector<tensor::Tenso |
|
|
|
<< ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type()); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
SetKernelTensorDataType(kernel); |
|
|
|
kernel->set_name(cNode->name()->str()); |
|
|
|
kernel->set_is_model_output(IsContain(graph_output_node_indexes, size_t(i))); |
|
|
|
kernels->emplace_back(kernel); |
|
|
|
@@ -129,6 +130,9 @@ void Scheduler::ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels) { |
|
|
|
for (auto kernel : temp_kernels) { |
|
|
|
for (auto tensor : kernel->out_tensors()) { |
|
|
|
tensor->set_allocator(context_->allocator.get()); |
|
|
|
if (context_->float16_priority && tensor->data_type() == kNumberTypeFloat16) { |
|
|
|
tensor->set_data_type(kNumberTypeFloat32); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
std::copy(temp_kernels.begin(), temp_kernels.end(), std::back_inserter(subgraph_kernels)); |
|
|
|
@@ -196,24 +200,51 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *> |
|
|
|
|
|
|
|
desc.arch = kernel::KERNEL_ARCH::kCPU; |
|
|
|
kernel::LiteKernel *kernel = nullptr; |
|
|
|
if (data_type == kNumberTypeFloat32) { |
|
|
|
if ((context_->float16_priority && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16) { |
|
|
|
// check if support fp16 |
|
|
|
kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type}; |
|
|
|
kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key); |
|
|
|
if (kernel != nullptr) { |
|
|
|
MS_LOG(DEBUG) << "Get fp16 op success."; |
|
|
|
desc.data_type = kNumberTypeFloat16; |
|
|
|
kernel->set_desc(desc); |
|
|
|
return kernel; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; |
|
|
|
kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); |
|
|
|
} else { |
|
|
|
kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); |
|
|
|
} |
|
|
|
if (data_type == kNumberTypeFloat16) { |
|
|
|
desc.data_type = kNumberTypeFloat32; |
|
|
|
} |
|
|
|
kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); |
|
|
|
if (kernel != nullptr) { |
|
|
|
kernel->set_desc(desc); |
|
|
|
return kernel; |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// Reconciles the float data type recorded on a CPU kernel's tensors with the
// data type stored in the kernel's descriptor (kernel->desc().data_type).
//
//  - Kernels whose arch is not KERNEL_ARCH::kCPU are left untouched.
//  - Descriptor type fp16: every fp32 output tensor is re-tagged as fp16.
//  - Descriptor type fp32: every fp16 input tensor whose TensorType() is not
//    schema::NodeType_ValueNode, and every fp16 output tensor, is re-tagged
//    as fp32.
//  - Any other descriptor type: nothing is changed.
//
// Only set_data_type() is called on the tensors here; no other tensor state
// is touched by this function.
// NOTE(review): `kernel` is dereferenced without a null check — callers are
// presumably expected to pass a valid kernel; confirm.
void Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) {
  const bool runs_on_cpu = (kernel->desc().arch == kernel::KERNEL_ARCH::kCPU);
  if (!runs_on_cpu) {
    return;
  }

  const auto kernel_dtype = kernel->desc().data_type;

  if (kernel_dtype == kNumberTypeFloat16) {
    // fp16 kernel: its fp32 outputs become fp16.
    for (auto *output : kernel->out_tensors()) {
      if (output->data_type() != kNumberTypeFloat32) {
        continue;
      }
      output->set_data_type(kNumberTypeFloat16);
    }
    return;
  }

  if (kernel_dtype != kNumberTypeFloat32) {
    return;
  }

  // fp32 kernel: fp16 inputs become fp32, except ValueNode tensors
  // (presumably constant/weight tensors — TODO confirm), which keep their type.
  for (auto *input : kernel->in_tensors()) {
    if (input->TensorType() == schema::NodeType_ValueNode || input->data_type() != kNumberTypeFloat16) {
      continue;
    }
    input->set_data_type(kNumberTypeFloat32);
  }

  // fp32 kernel: fp16 outputs become fp32 as well.
  for (auto *output : kernel->out_tensors()) {
    if (output->data_type() != kNumberTypeFloat16) {
      continue;
    }
    output->set_data_type(kNumberTypeFloat32);
  }
}
|
|
|
|
|
|
|
} // namespace mindspore::lite |