|
|
@@ -209,6 +209,9 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in |
|
|
#endif |
|
|
#endif |
|
|
#if SUPPORT_NPU |
|
|
#if SUPPORT_NPU |
|
|
if (context_->IsNpuEnabled()) { |
|
|
if (context_->IsNpuEnabled()) { |
|
|
|
|
|
if (desc.data_type == kNumberTypeFloat16) { |
|
|
|
|
|
desc.data_type = kNumberTypeFloat32; |
|
|
|
|
|
} |
|
|
kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type}; |
|
|
kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type}; |
|
|
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc); |
|
|
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc); |
|
|
if (kernel != nullptr) { |
|
|
if (kernel != nullptr) { |
|
|
|