Browse Source

Fix bug when converting ONNX models to quantized models

tags/v1.1.0
xuanyue 5 years ago
parent
commit
c1dd0b8e3d
1 changed files with 10 additions and 5 deletions
  1. +10
    -5
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc

+ 10
- 5
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc View File

@@ -158,16 +158,16 @@ STATUS OnnxModelParser::ConvertNodes() {
status = RET_ERROR;
continue;
}
if (IsSpecialOnnxNode(onnx_node)) {
auto status_node = ConvertSpecialOnnxNode(onnx_node, primitive_c);
status = status == RET_OK ? status_node : status;
continue;
}
status = ConvertOpQuantParams(onnx_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed.";
continue;
}
if (IsSpecialOnnxNode(onnx_node)) {
auto status_node = ConvertSpecialOnnxNode(onnx_node, primitive_c);
status = status == RET_OK ? status_node : status;
continue;
}
// build CNode
status = BuildCNode(onnx_node, primitive_c);
if (status != RET_OK) {
@@ -512,8 +512,10 @@ STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite
return RET_ERROR;
} else {
op_inputs.push_back(nodes_[onnx_node.input(i)]);
prim_ptr->AddInputQuantParam(primitive_c->input_quant_params().at(i));
}
}
prim_ptr->AddOutputQuantParam(std::vector<schema::QuantParamT>(1));
auto new_cnode = func_graph_ptr_->NewCNode(prim_ptr, op_inputs);
new_cnode->set_fullname_with_scope("Gemm_MatMul_" + onnx_node.output(0));
new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
@@ -526,6 +528,9 @@ STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite
}
op_inputs.push_back(nodes_["Gemm_MatMul_" + onnx_node.output(0)]);
op_inputs.push_back(nodes_[onnx_node.input(2)]);
prim_ptr->AddInputQuantParam(std::vector<schema::QuantParamT>(1));
prim_ptr->AddInputQuantParam(primitive_c->input_quant_params().at(2));
prim_ptr->AddOutputQuantParam(primitive_c->output_quant_params().front());
auto new_cnode = func_graph_ptr_->NewCNode(prim_ptr, op_inputs);
new_cnode->set_fullname_with_scope("Gemm_BiasAdd_" + onnx_node.output(0));
new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));


Loading…
Cancel
Save