From 9428ffe860237c6b251ecbd88a68bb41d7c6c44f Mon Sep 17 00:00:00 2001
From: xutianchun
Date: Wed, 20 Jan 2021 11:15:54 +0800
Subject: [PATCH] fix gather weight quant bug

---
 mindspore/lite/src/dequant.cc                 | 11 ++++--
 mindspore/lite/src/dequant.h                  |  4 +-
 mindspore/lite/src/lite_session.cc            | 37 ++++++++++---------
 .../src/runtime/kernel/arm/fp32/lstm_fp32.cc  | 26 +++++++------
 mindspore/lite/src/scheduler.cc               | 11 +++++-
 mindspore/lite/src/scheduler.h                |  6 +++
 .../converter/quantizer/weight_quantizer.cc   | 32 ++++++++--------
 .../converter/quantizer/weight_quantizer.h    |  2 +-
 8 files changed, 76 insertions(+), 53 deletions(-)

diff --git a/mindspore/lite/src/dequant.cc b/mindspore/lite/src/dequant.cc
index f13c403b54..22064b37e1 100644
--- a/mindspore/lite/src/dequant.cc
+++ b/mindspore/lite/src/dequant.cc
@@ -51,7 +51,7 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i
 }
 
 std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,
-                                                                         TypeId data_type) {
+                                                                         TypeId data_type, bool need_restore) {
   std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data;
   if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) {
     for (auto weight_tensor : in_tensors) {
@@ -59,16 +59,21 @@ std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const s
       auto *restore_data = weight_tensor->data_c();
       auto restore_type = weight_tensor->data_type();
       bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
-                          restore_data != nullptr;
+                          restore_data != nullptr &&
+                          (restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16);
       if (dequant_flag) {
         auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor);
         if (dequant_weight == nullptr) {
           MS_LOG(ERROR) << "dequant data is nullptr.";
           return tensor_origin_data;
         }
+        if (need_restore) {
+          tensor_origin_data[weight_tensor] = {restore_type, restore_data};
+        } else {
+          weight_tensor->FreeData();
+        }
         weight_tensor->set_data(dequant_weight);
         weight_tensor->set_data_type(kNumberTypeFloat32);
-        tensor_origin_data[weight_tensor] = {restore_type, restore_data};
       }
     }
   }
diff --git a/mindspore/lite/src/dequant.h b/mindspore/lite/src/dequant.h
index b052515103..094b8468ef 100644
--- a/mindspore/lite/src/dequant.h
+++ b/mindspore/lite/src/dequant.h
@@ -34,7 +34,7 @@ class DequantUtil {
   static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
 
   static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
-                                                                     TypeId data_type);
+                                                                     TypeId data_type, bool need_restore = true);
 
   static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map);
 
@@ -79,7 +79,7 @@ class DequantUtil {
     auto var_corr = param.var_corr;
     auto mean_corr = param.mean_corr;
     if (var_corr < 0 || var_corr > 10) {
-      MS_LOG(WARNING) << "unexpeted var_corr: " << var_corr;
+      MS_LOG(WARNING) << "unexpected var_corr: " << var_corr;
       var_corr = 1;
     }
     for (size_t j = 0; j < per_channel_size; j++) {
diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc
index 2f2e746438..48e415f109 100644
--- a/mindspore/lite/src/lite_session.cc
+++ b/mindspore/lite/src/lite_session.cc
@@ -38,10 +38,6 @@
 
 namespace mindspore {
 namespace lite {
-static std::vector<schema::PrimitiveType> packed_op = {
-  schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DepthwiseConv2D,
-  schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_MatMul};
-
 // this method will not check whether tensor_idx is a weight tensor index, caller should ensure this.
 static bool WeightTensorNeedCopy(const lite::Model *model, const uint32_t tensor_idx) {
 #ifdef SUPPORT_TRAIN
@@ -92,8 +88,13 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
                                     lite::Tensor *dst_tensor) {
   MS_ASSERT(src_tensor != nullptr);
   MS_ASSERT(dst_tensor != nullptr);
+  auto NeedUnPack = [&src_tensor, &dst_tensor]() -> bool {
+    auto data_type = src_tensor->dataType();
+    int pack_size = src_tensor->data()->size();
+    int org_size = dst_tensor->Size();
+    return (pack_size != org_size) && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16);
+  };
   auto src_category = TensorCategory(src_tensor);
-  auto data_type = src_tensor->dataType();
   if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) &&
       src_tensor->data() != nullptr && src_tensor->data()->size() > 0) {
     if (src_tensor->dataType() == kObjectTypeTensorType) {
@@ -112,18 +113,20 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
         MS_LOG(ERROR) << "Data from tensor is nullptr";
         return RET_NULL_PTR;
       }
-      memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size());
+      if (NeedUnPack()) {
+        DequantUtil::UnPackToInt(src_tensor, dst_data);
+      } else {
+        memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size());
+      }
       copyed_tensor_idxes_.emplace_back(tensor_index);
     } else {
-      int pack_size = src_tensor->data()->size();
-      int org_size = dst_tensor->Size();
-      if (pack_size != org_size && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16)) {
-        auto ret = dst_tensor->MallocData();
-        if (ret != RET_OK) {
-          MS_LOG(ERROR) << "Malloc data for tensor failed ";
-          return RET_ERROR;
+      if (NeedUnPack()) {
+        auto dst_data = dst_tensor->MutableData();
+        if (dst_data == nullptr) {
+          MS_LOG(ERROR) << "Data from tensor is nullptr";
+          return RET_NULL_PTR;
         }
-        DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData());
+        DequantUtil::UnPackToInt(src_tensor, dst_data);
         copyed_tensor_idxes_.emplace_back(tensor_index);
       } else {
         dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));
@@ -713,12 +716,12 @@ int LiteSession::InitGPURuntime() {
 session::LiteSession *session::LiteSession::CreateSession(const lite::Context *context) {
   auto session = new (std::nothrow) lite::LiteSession();
   if (session == nullptr) {
-    MS_LOG(ERROR) << "create sesssion failed";
+    MS_LOG(ERROR) << "create session failed";
     return nullptr;
   }
   auto ret = session->Init(context);
   if (ret != mindspore::lite::RET_OK) {
-    MS_LOG(ERROR) << "init sesssion failed";
+    MS_LOG(ERROR) << "init session failed";
     delete session;
     return nullptr;
   }
@@ -729,7 +732,7 @@ session::LiteSession *session::LiteSession::CreateSession(const char *model_buf,
                                                           const lite::Context *context) {
   auto *session = LiteSession::CreateSession(context);
   if (session == nullptr) {
-    MS_LOG(ERROR) << "Create sesssion failed";
+    MS_LOG(ERROR) << "Create session failed";
     return nullptr;
   }
   auto *model = lite::ImportFromBuffer(model_buf, size, true);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
index 2ab97225c4..298d535bf5 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm_fp32.cc
@@ -107,8 +107,10 @@ int LstmCPUKernel::InitWeightBias() {
   }
   memcpy(weight_h_ptr_, weight_h->MutableData(), weight_h->ElementsNum() * sizeof(float));
 
+  std::vector<int> w_shape = weight_i->shape();
+  auto hidden_size = w_shape.at(1) / 4;
   // init bias
-  int bias_num = lstm_parm_->bidirectional_ ? 2 * 4 * lstm_parm_->hidden_size_ : 4 * lstm_parm_->hidden_size_;
+  int bias_num = lstm_parm_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size;
   bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
   if (bias_ptr_ == nullptr) {
     MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error.";
@@ -116,13 +118,13 @@ int LstmCPUKernel::InitWeightBias() {
   }
 
   auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->MutableData());
-  const int state_bias_offset = 4 * lstm_parm_->hidden_size_;
+  const int state_bias_offset = 4 * hidden_size;
   for (int i = 0; i < state_bias_offset; i++) {
     bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset];
   }
   if (lstm_parm_->bidirectional_) {
-    bias_data += 4 * lstm_parm_->hidden_size_ * 2;
-    auto backward_bias = bias_ptr_ + 4 * lstm_parm_->hidden_size_;
+    bias_data += 4 * hidden_size * 2;
+    auto backward_bias = bias_ptr_ + 4 * hidden_size;
     for (int i = 0; i < state_bias_offset; i++) {
       backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset];
     }
@@ -131,6 +133,14 @@ int LstmCPUKernel::InitWeightBias() {
 }
 
 int LstmCPUKernel::Init() {
+  FreeTmpBuffer();
+  auto ret = InitWeightBias();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error.";
+    FreeTmpBuffer();
+    return RET_ERROR;
+  }
+
   if (!InferShapeDone()) {
     return RET_OK;
   }
@@ -138,20 +148,12 @@ int LstmCPUKernel::Init() {
 }
 
 int LstmCPUKernel::ReSize() {
-  FreeTmpBuffer();
   auto ret = InitParam();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "LstmCPUKernel InitParam error.";
     return RET_ERROR;
   }
 
-  ret = InitWeightBias();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error.";
-    FreeTmpBuffer();
-    return RET_ERROR;
-  }
-
   ret = InitBuffer();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "LstmCPUKernel InitBuffer error.";
diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc
index 551b720b28..d3dbbbe9db 100644
--- a/mindspore/lite/src/scheduler.cc
+++ b/mindspore/lite/src/scheduler.cc
@@ -184,6 +184,13 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
                                                  const Model::Node *node) {
   MS_ASSERT(primitive != nullptr);
   TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
+  bool need_restore = true;
+  if (primitive->quant_type() == schema::QuantType_WeightQuant) {
+    data_type = kNumberTypeFloat32;
+  }
+  if (!IsContain(packed_op, (schema::PrimitiveType)primitive->Type())) {
+    need_restore = false;
+  }
   kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())};
 #if SUPPORT_GPU
   if (context_->IsGpuEnabled()) {
@@ -216,7 +223,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
   if (mindspore::lite::IsSupportFloat16() &&
       ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) {
     kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type};
-    auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type);
+    auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type, need_restore);
     auto *kernel =
       KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc);
     DequantUtil::RestoreTensorData(tensor_origin_data_map);
@@ -230,7 +237,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
     MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
     desc.data_type = kNumberTypeFloat32;
   }
-  auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type);
+  auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type, need_restore);
   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
   DequantUtil::RestoreTensorData(tensor_origin_data_map);
   if (kernel != nullptr) {
diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h
index bd5c9fac17..4866502838 100644
--- a/mindspore/lite/src/scheduler.h
+++ b/mindspore/lite/src/scheduler.h
@@ -26,6 +26,12 @@
 #include "src/ops/primitive_c.h"
 
 namespace mindspore::lite {
+
+static std::vector<schema::PrimitiveType> packed_op = {
+  schema::PrimitiveType_Conv2D,          schema::PrimitiveType_DeConv2D,
+  schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
+  schema::PrimitiveType_MatMul,          schema::PrimitiveType_Lstm};
+
 class Scheduler {
  public:
   Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors)
diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
index f079a3913e..09a037986b 100644
--- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
@@ -253,11 +253,11 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
   }
   auto status = RET_ERROR;
   if (type_id_ == kNumberTypeInt8) {
-    status =
-      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+    status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                 false, 1);
   } else if (type_id_ == kNumberTypeInt16) {
-    status =
-      QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+    status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                  false, 1);
   }
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed : " << status;
@@ -316,11 +316,11 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
   }
   auto status = RET_ERROR;
   if (type_id_ == kNumberTypeInt8) {
-    status =
-      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+    status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                 false, 3);
   } else if (type_id_ == kNumberTypeInt16) {
     status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
-                                  false);
+                                  false, 3);
   }
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed : " << status;
@@ -340,10 +340,10 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
   auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
   MS_ASSERT(primitive_c != nullptr);
-  auto weight_h = cnode->input(1);
+  auto first_input = cnode->input(1);
   ParameterPtr param_node;
   ParamValueLitePtr param_value;
-  GetLiteParameter(weight_h, &param_node, &param_value);
+  GetLiteParameter(first_input, &param_node, &param_value);
   if (param_node == nullptr || param_value == nullptr || param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
     MS_LOG(INFO) << "This Gather op " << cnode->fullname_with_scope() << " can not quant weight";
     return RET_OK;
   }
@@ -358,10 +358,10 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
   auto status = RET_ERROR;
   if (type_id_ == kNumberTypeInt8) {
     status =
-      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, 0);
   } else if (type_id_ == kNumberTypeInt16) {
     status =
-      QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+      QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, 0);
   }
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed : " << status;
@@ -510,7 +510,7 @@ STATUS WeightQuantizer::RunFp32Graph(FuncGraphPtr func_graph) {
   return RET_OK;
 }
 
-STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
+STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
   // 0.2 Parse input calib files
   auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
   if (status != RET_OK) {
@@ -652,7 +652,7 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
      delete quant_sm.model;
      return RET_ERROR;
    }
-    // 3. compare betwen quant and fp32
+    // 3. compare between quant and fp32
    auto quant_outputs = quant_session->GetOutputs();
    mean_error += CompareOutputData(fp32_output_tensors_[i], quant_outputs);
  }  // end_for: calib data loop
@@ -690,8 +690,8 @@ STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
  for (auto &cnode : func_graph->GetOrderedCnodes()) {
    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
    if (primitive_c == nullptr) {
-      MS_LOG(ERROR) << "primitive_c is nullptr";
-      return RET_ERROR;
+      MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive_c is nullptr";
+      continue;
    }
    auto op_name = cnode->fullname_with_scope();
    auto op_type = (schema::PrimitiveType)primitive_c->Type();
@@ -744,7 +744,7 @@ STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
    quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
    type_id_ = kNumberTypeInt8;
    MS_LOG(INFO) << "Do mixed bit quantization";
-    return DoMiexedQuant(func_graph);
+    return DoMixedQuant(func_graph);
  }
 
  return DoFixedQuant(func_graph);
diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h
index 2fc8e0199a..0d749b3aaf 100644
--- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h
+++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h
@@ -62,7 +62,7 @@ class WeightQuantizer : public Quantizer {
   std::vector<std::vector<std::string>> images_;  // multi_input, [[mode_input_0], [model_input_1]...]
   std::vector<std::unordered_map<std::string, mindspore::tensor::MSTensor *>> fp32_output_tensors_;
 
-  STATUS DoMiexedQuant(FuncGraphPtr);
+  STATUS DoMixedQuant(FuncGraphPtr);
   STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c);
   STATUS DoFixedQuant(FuncGraphPtr);
   STATUS RunFp32Graph(FuncGraphPtr);