Browse Source

!9169 reorder quant params

From: @cjh9368
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @zhanghaibo5
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
7c67bd24ed
2 changed files with 36 additions and 11 deletions
  1. +23
    -11
      mindspore/lite/tools/anf_exporter/anf_exporter.cc
  2. +13
    -0
      mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc

+ 23
- 11
mindspore/lite/tools/anf_exporter/anf_exporter.cc View File

@@ -123,33 +123,45 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty";
}
// output
auto output_index = dst_node->outputIndex[0];
auto tensor_output = meta_graph->allTensors[output_index].get();

auto output_quant_params = primitive->output_quant_params();
if (output_quant_params.empty()) {
if (node_type != schema::PrimitiveType_QuantDTypeCast) {
MS_LOG(DEBUG) << "node: " << dst_node->name << " output quant params is empty";
}
} else {
for (auto output_quant_param : output_quant_params[0]) {
if (tensor_output->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
std::make_unique<schema::QuantParamT>(output_quant_param);
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
<< " zp: " << output_quant_param_ptr->zeroPoint;
tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr));
if (dst_node->outputIndex.size() != output_quant_params.size()) {
MS_LOG(INFO) << "node: " << dst_node->name << " output has " << output_quant_params.size()
<< " quant_params; but only " << dst_node->outputIndex.size() << " output";
return RET_ERROR;
}
int output_idx = 0;
for (const auto &output_quant_param : output_quant_params) {
auto output_tensor = meta_graph->allTensors[dst_node->outputIndex[output_idx]].get();
output_idx++;
for (const auto &channel_quant_param : output_quant_param) {
if (output_tensor->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
std::make_unique<schema::QuantParamT>(channel_quant_param);
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
<< " zp: " << output_quant_param_ptr->zeroPoint;
output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr));
}
}
}
}

auto first_output_index = dst_node->outputIndex[0];
auto first_tensor_output = meta_graph->allTensors[first_output_index].get();
if (dst_node->quantType == schema::QuantType_PostTraining) {
if (node_type != schema::PrimitiveType_QuantDTypeCast) {
tensor_output->dataType = kNumberTypeInt8;
first_tensor_output->dataType = kNumberTypeInt8;
} else {
MS_ASSERT(utils::isa<std::shared_ptr<QuantDTypeCast>>(primitive));
auto primc = utils::cast<std::shared_ptr<QuantDTypeCast>>(primitive);
MS_ASSERT(primc != nullptr);
if (primc->GetDstT() != kNumberTypeFloat32) {
tensor_output->dataType = kNumberTypeInt8;
first_tensor_output->dataType = kNumberTypeInt8;
}
}
}


+ 13
- 0
mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc View File

@@ -38,11 +38,24 @@ bool TfliteInputsOrderExchangePass::Run(const FuncGraphPtr &graph) {
auto inputs = cnode->inputs();
inputs.pop_back();
cnode->set_inputs(inputs);

auto input_quant_params = primitive_c->input_quant_params();
input_quant_params[0] = input_quant_params.at(2);
input_quant_params.pop_back();
primitive_c->set_input_quant_params(input_quant_params);
continue;
}

if (opt::GetCNodeType(node) == schema::PrimitiveType_Split && cnode->inputs().size() == split_inputs_size) {
cnode->set_input(1, cnode->input(2));
auto inputs = cnode->inputs();
inputs.pop_back();
cnode->set_inputs(inputs);

auto input_quant_params = primitive_c->input_quant_params();
input_quant_params[0] = input_quant_params.at(1);
input_quant_params.pop_back();
primitive_c->set_input_quant_params(input_quant_params);
continue;
}



Loading…
Cancel
Save