| @@ -380,13 +380,25 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in | |||||
| } | } | ||||
| } | } | ||||
| kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type}; | kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type}; | ||||
| // weight quant | |||||
| std::map<Tensor *, Tensor *> restored_origin_tensors; | |||||
| for (auto &tensor : in_tensors) { | |||||
| auto channel_first = IsChannelFirst(in_tensors, op_parameter); | |||||
| auto *restore_tensor = DequantUtil::DequantTensor(tensor, desc.data_type, channel_first, kNumberTypeFloat32); | |||||
| if (restore_tensor != nullptr) { | |||||
| restored_origin_tensors[tensor] = restore_tensor; | |||||
| } | |||||
| } | |||||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, npu_desc, op_parameter); | auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, npu_desc, op_parameter); | ||||
| if (kernel != nullptr) { | if (kernel != nullptr) { | ||||
| MS_LOG(DEBUG) << "Get npu op success: " << PrimitiveCurVersionTypeName(npu_desc.type) << " " << node->name_; | MS_LOG(DEBUG) << "Get npu op success: " << PrimitiveCurVersionTypeName(npu_desc.type) << " " << node->name_; | ||||
| FreeRestoreTensors(&restored_origin_tensors); | |||||
| return kernel; | return kernel; | ||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(npu_desc.type) << " " | MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(npu_desc.type) << " " | ||||
| << node->name_; | << node->name_; | ||||
| RestoreTensorData(restored_origin_tensors); | |||||
| auto ret = InferNodeShape(node, &infer_shape_interrupt); | auto ret = InferNodeShape(node, &infer_shape_interrupt); | ||||
| if (ret == RET_INFER_INVALID || ret == RET_OK) { | if (ret == RET_INFER_INVALID || ret == RET_OK) { | ||||
| op_parameter = op_parameters_[node->output_indices_.at(0)]; | op_parameter = op_parameters_[node->output_indices_.at(0)]; | ||||