|
|
|
@@ -26,40 +26,6 @@ namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace { |
|
|
|
constexpr size_t kDoubleNum = 2; |
|
|
|
void FillDefaultInputQuantParamIfNeed(const PrimitivePtr &prim, const size_t &input_size) { |
|
|
|
auto quant_tensor_info_ptr = prim->GetAttr("quant_params"); |
|
|
|
if (quant_tensor_info_ptr == nullptr) { |
|
|
|
prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>()); |
|
|
|
} |
|
|
|
auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>(); |
|
|
|
std::vector<schema::QuantParamT> quants; |
|
|
|
schema::QuantParamT quant_param; |
|
|
|
auto input_quant_params = quant_param_holder->input_quant_params(); |
|
|
|
if (input_quant_params.size() == kDoubleNum) { |
|
|
|
quants.clear(); |
|
|
|
quant_param.min = 0.0; |
|
|
|
quant_param.max = 0.0; |
|
|
|
quant_param.dstDtype = kNumberTypeInt32; |
|
|
|
quant_param.inited = input_quant_params.at(0).at(0).inited && input_quant_params.at(1).at(0).inited; |
|
|
|
quant_param.inited = false; |
|
|
|
quant_param.zeroPoint = 0; |
|
|
|
if (quant_param.inited) { |
|
|
|
quant_param.scale = input_quant_params.at(0).at(0).scale * input_quant_params.at(1).at(0).scale; |
|
|
|
} |
|
|
|
quant_param.roundType = 1; |
|
|
|
quant_param.multiplier = 1; |
|
|
|
quants.emplace_back(quant_param); |
|
|
|
input_quant_params.emplace_back(quants); |
|
|
|
} |
|
|
|
// fill input_quant_param_ by not inited quant_parm |
|
|
|
if (input_quant_params.size() < input_size) { |
|
|
|
schema::QuantParamT tmpQuantParam; |
|
|
|
quants.emplace_back(tmpQuantParam); |
|
|
|
input_quant_params.insert(input_quant_params.end(), input_size - input_quant_params.size(), quants); |
|
|
|
} |
|
|
|
quant_param_holder->set_input_quant_params(input_quant_params); |
|
|
|
} |
|
|
|
|
|
|
|
int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) { |
|
|
|
auto quant_tensor_info_ptr = prim->GetAttr("quant_params"); |
|
|
|
if (quant_tensor_info_ptr == nullptr) { |
|
|
|
@@ -212,7 +178,6 @@ int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &i |
|
|
|
MS_LOG(ERROR) << "compute int quant param failed."; |
|
|
|
return status; |
|
|
|
} |
|
|
|
FillDefaultInputQuantParamIfNeed(prim, inputs.size()); |
|
|
|
status = ConvertOutputQuantParam(prim, narrow_range_param, num_bits_param); |
|
|
|
if (status != lite::RET_OK) { |
|
|
|
MS_LOG(ERROR) << "compute output quant param failed."; |
|
|
|
|