Browse Source

!14203 npu do weight quant

From: @zhaozhenlong
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
pull/14203/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
b31f102e47
1 changed files with 12 additions and 0 deletions
  1. +12
    -0
      mindspore/lite/src/scheduler.cc

+ 12
- 0
mindspore/lite/src/scheduler.cc View File

@@ -380,13 +380,25 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
}
}
kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type};

// weight quant
std::map<Tensor *, Tensor *> restored_origin_tensors;
for (auto &tensor : in_tensors) {
auto channel_first = IsChannelFirst(in_tensors, op_parameter);
auto *restore_tensor = DequantUtil::DequantTensor(tensor, desc.data_type, channel_first, kNumberTypeFloat32);
if (restore_tensor != nullptr) {
restored_origin_tensors[tensor] = restore_tensor;
}
}
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, npu_desc, op_parameter);
if (kernel != nullptr) {
MS_LOG(DEBUG) << "Get npu op success: " << PrimitiveCurVersionTypeName(npu_desc.type) << " " << node->name_;
FreeRestoreTensors(&restored_origin_tensors);
return kernel;
} else {
MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(npu_desc.type) << " "
<< node->name_;
RestoreTensorData(restored_origin_tensors);
auto ret = InferNodeShape(node, &infer_shape_interrupt);
if (ret == RET_INFER_INVALID || ret == RET_OK) {
op_parameter = op_parameters_[node->output_indices_.at(0)];


Loading…
Cancel
Save