From 304664bd091136ba60fde648f28fd249b129109e Mon Sep 17 00:00:00 2001
From: cjh9368
Date: Tue, 13 Apr 2021 17:04:22 +0800
Subject: [PATCH] [MS][LITE] move conv bias quant param to propogator

---
 .../conv_quant_param_propogator.cc            | 31 ++++++++++++++++
 .../optimizer/fusion/matmul_add_fusion.cc     |  3 +-
 .../optimizer/graph/mindir_adjust_pass.cc     | 35 -------------------
 3 files changed, 33 insertions(+), 36 deletions(-)

diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.cc
index 061e31b8aa..94348068cb 100644
--- a/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.cc
+++ b/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.cc
@@ -15,6 +15,7 @@
  */
 #include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
 #include "mindspore/core/ir/dtype/type_id.h"
+
 namespace mindspore::lite {
 static constexpr size_t kBiasAdd = 3;
 
@@ -22,6 +23,36 @@ STATUS ConvQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGra
                                                       const mindspore::schema::CNodeT &node) {
   if (node.inputIndex.size() == kBiasAdd) {
     auto &bias_tensor = graph->allTensors.at(node.inputIndex.at(kBiasAdd - 1));
+    if (bias_tensor->quantParams.empty() || !bias_tensor->quantParams.front()->inited) {
+      // check input and weight quant params
+      auto &input_tensor = graph->allTensors.at(node.inputIndex.at(0));
+      auto &weight_tensor = graph->allTensors.at(node.inputIndex.at(1));
+      if (input_tensor->quantParams.empty() || !input_tensor->quantParams.front()->inited) {
+        return RET_OK;
+      }
+
+      if (weight_tensor->quantParams.empty() || !weight_tensor->quantParams.front()->inited) {
+        return RET_OK;
+      }
+      auto &input_quant_param = input_tensor->quantParams.at(0);
+      auto &weight_quant_param = weight_tensor->quantParams.at(0);
+
+      if (bias_tensor->quantParams.empty()) {
+        auto tmp_quant_param = std::make_unique<schema::QuantParamT>();
+        bias_tensor->quantParams.emplace_back(std::move(tmp_quant_param));
+      }
+      auto &bias_quant_param = bias_tensor->quantParams.front();
+      bias_quant_param->min = 0.0;
+      bias_quant_param->max = 0.0;
+      bias_quant_param->dstDtype = kNumberTypeInt32;
+      bias_quant_param->inited = input_quant_param->inited && weight_quant_param->inited;
+      bias_quant_param->zeroPoint = 0;
+      if (bias_quant_param->inited) {
+        bias_quant_param->scale = input_quant_param->scale * weight_quant_param->scale;
+      }
+      bias_quant_param->roundType = 1;
+      bias_quant_param->multiplier = 1;
+    }
     for (auto &quantParam : bias_tensor->quantParams) {
       quantParam->dstDtype = TypeId::kNumberTypeInt32;
     }
diff --git a/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc b/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc
index a96d8dff16..401f6d40b5 100644
--- a/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc
@@ -64,7 +64,8 @@ bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
     }
     auto matmul_cnode = cnode->input(index)->cast<CNodePtr>();
     auto bias_node = cnode->input(kAddInputSize - index);
-    if (!utils::isa<ParameterPtr>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param()) {
+    if (!utils::isa<CNodePtr>(bias_node) &&
+        (!utils::isa<ParameterPtr>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param())) {
       continue;
     }
     matmul_cnode->add_input(bias_node);
diff --git a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
index e183b906ac..98eb45e385 100644
--- a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
@@ -26,40 +26,6 @@ namespace mindspore {
 namespace opt {
 namespace {
 constexpr size_t kDoubleNum = 2;
-void FillDefaultInputQuantParamIfNeed(const PrimitivePtr &prim, const size_t &input_size) {
-  auto quant_tensor_info_ptr = prim->GetAttr("quant_params");
-  if (quant_tensor_info_ptr == nullptr) {
-    prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>());
-  }
-  auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
-  std::vector<schema::QuantParamT> quants;
-  schema::QuantParamT quant_param;
-  auto input_quant_params = quant_param_holder->input_quant_params();
-  if (input_quant_params.size() == kDoubleNum) {
-    quants.clear();
-    quant_param.min = 0.0;
-    quant_param.max = 0.0;
-    quant_param.dstDtype = kNumberTypeInt32;
-    quant_param.inited = input_quant_params.at(0).at(0).inited && input_quant_params.at(1).at(0).inited;
-    quant_param.inited = false;
-    quant_param.zeroPoint = 0;
-    if (quant_param.inited) {
-      quant_param.scale = input_quant_params.at(0).at(0).scale * input_quant_params.at(1).at(0).scale;
-    }
-    quant_param.roundType = 1;
-    quant_param.multiplier = 1;
-    quants.emplace_back(quant_param);
-    input_quant_params.emplace_back(quants);
-  }
-  // fill input_quant_param_ by not inited quant_parm
-  if (input_quant_params.size() < input_size) {
-    schema::QuantParamT tmpQuantParam;
-    quants.emplace_back(tmpQuantParam);
-    input_quant_params.insert(input_quant_params.end(), input_size - input_quant_params.size(), quants);
-  }
-  quant_param_holder->set_input_quant_params(input_quant_params);
-}
-
 int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
   auto quant_tensor_info_ptr = prim->GetAttr("quant_params");
   if (quant_tensor_info_ptr == nullptr) {
@@ -212,7 +178,6 @@ int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &i
     MS_LOG(ERROR) << "compute int quant param failed.";
     return status;
   }
-  FillDefaultInputQuantParamIfNeed(prim, inputs.size());
   status = ConvertOutputQuantParam(prim, narrow_range_param, num_bits_param);
   if (status != lite::RET_OK) {
     MS_LOG(ERROR) << "compute output quant param failed.";
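
Note for reviewers: the block moved into conv_quant_param_propogator.cc encodes the
usual convention for quantized convolution biases: the bias is stored as int32 with
zeroPoint = 0 and scale equal to the product of the input and weight scales, so the
int32 accumulator sum(x_q * w_q) + bias_q lines up with the float computation. The
sketch below is a minimal standalone illustration of that rule, not MindSpore code;
QuantizeBias and the numeric values are invented for this example.

    // bias_quant_sketch.cc -- standalone illustration, not part of the patch.
    #include <cmath>
    #include <cstdint>
    #include <iostream>

    // Hypothetical helper (not a MindSpore API): quantize a float bias using
    // the scales already attached to the conv input and weight tensors,
    // mirroring the rule applied in PropogateQuantParams above.
    int32_t QuantizeBias(float bias, float input_scale, float weight_scale) {
      const float bias_scale = input_scale * weight_scale;  // derived bias scale
      return static_cast<int32_t>(std::round(bias / bias_scale));
    }

    int main() {
      // Example scales, invented for illustration.
      const float input_scale = 0.05f;
      const float weight_scale = 0.002f;
      const float bias = 0.37f;
      const int32_t bias_q = QuantizeBias(bias, input_scale, weight_scale);
      // Dequantizing with the derived scale recovers the bias up to rounding:
      // 3700 * 0.0001 = 0.37.
      std::cout << "bias_q = " << bias_q
                << ", dequant = " << bias_q * input_scale * weight_scale << "\n";
      return 0;
    }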