|
|
|
@@ -380,13 +380,25 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in |
|
|
|
} |
|
|
|
} |
|
|
|
kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type}; |
|
|
|
|
|
|
|
// weight quant |
|
|
|
std::map<Tensor *, Tensor *> restored_origin_tensors; |
|
|
|
for (auto &tensor : in_tensors) { |
|
|
|
auto channel_first = IsChannelFirst(in_tensors, op_parameter); |
|
|
|
auto *restore_tensor = DequantUtil::DequantTensor(tensor, desc.data_type, channel_first, kNumberTypeFloat32); |
|
|
|
if (restore_tensor != nullptr) { |
|
|
|
restored_origin_tensors[tensor] = restore_tensor; |
|
|
|
} |
|
|
|
} |
|
|
|
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, npu_desc, op_parameter); |
|
|
|
if (kernel != nullptr) { |
|
|
|
MS_LOG(DEBUG) << "Get npu op success: " << PrimitiveCurVersionTypeName(npu_desc.type) << " " << node->name_; |
|
|
|
FreeRestoreTensors(&restored_origin_tensors); |
|
|
|
return kernel; |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(npu_desc.type) << " " |
|
|
|
<< node->name_; |
|
|
|
RestoreTensorData(restored_origin_tensors); |
|
|
|
auto ret = InferNodeShape(node, &infer_shape_interrupt); |
|
|
|
if (ret == RET_INFER_INVALID || ret == RET_OK) { |
|
|
|
op_parameter = op_parameters_[node->output_indices_.at(0)]; |
|
|
|
|