| @@ -204,6 +204,11 @@ int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) { | |||||
| if (trans_kernel->out_kernels().empty()) { | if (trans_kernel->out_kernels().empty()) { | ||||
| // kernel is a trans kernel, it's input kernel num and input tensor num must be 1 | // kernel is a trans kernel, it's input kernel num and input tensor num must be 1 | ||||
| kernel->in_kernels()[0]->set_out_tensors({trans_kernel->out_tensors()[0]}); | kernel->in_kernels()[0]->set_out_tensors({trans_kernel->out_tensors()[0]}); | ||||
| // in fp16 mode, tensor data type fp16 need to be changed back. | |||||
| auto tensor = kernel->in_kernels()[0]->out_tensors()[0]; | |||||
| if (tensor->data_type() == kNumberTypeFloat16) { | |||||
| tensor->set_data_type(kNumberTypeFloat32); | |||||
| } | |||||
| } | } | ||||
| for (const auto &post_kernel : trans_kernel->out_kernels()) { | for (const auto &post_kernel : trans_kernel->out_kernels()) { | ||||
| // update tensor | // update tensor | ||||
| @@ -212,6 +212,11 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in | |||||
| if (desc.data_type == kNumberTypeFloat16) { | if (desc.data_type == kNumberTypeFloat16) { | ||||
| desc.data_type = kNumberTypeFloat32; | desc.data_type = kNumberTypeFloat32; | ||||
| } | } | ||||
| for (auto tensor : in_tensors) { | |||||
| if (tensor->data_type() == kNumberTypeFloat16) { | |||||
| tensor->set_data_type(kNumberTypeFloat32); | |||||
| } | |||||
| } | |||||
| kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type}; | kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type}; | ||||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc); | auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc); | ||||
| if (kernel != nullptr) { | if (kernel != nullptr) { | ||||