|
|
|
@@ -158,16 +158,16 @@ STATUS OnnxModelParser::ConvertNodes() { |
|
|
|
status = RET_ERROR; |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (IsSpecialOnnxNode(onnx_node)) { |
|
|
|
auto status_node = ConvertSpecialOnnxNode(onnx_node, primitive_c); |
|
|
|
status = status == RET_OK ? status_node : status; |
|
|
|
continue; |
|
|
|
} |
|
|
|
status = ConvertOpQuantParams(onnx_node, primitive_c); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed."; |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (IsSpecialOnnxNode(onnx_node)) { |
|
|
|
auto status_node = ConvertSpecialOnnxNode(onnx_node, primitive_c); |
|
|
|
status = status == RET_OK ? status_node : status; |
|
|
|
continue; |
|
|
|
} |
|
|
|
// build CNode |
|
|
|
status = BuildCNode(onnx_node, primitive_c); |
|
|
|
if (status != RET_OK) { |
|
|
|
@@ -512,8 +512,10 @@ STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite |
|
|
|
return RET_ERROR; |
|
|
|
} else { |
|
|
|
op_inputs.push_back(nodes_[onnx_node.input(i)]); |
|
|
|
prim_ptr->AddInputQuantParam(primitive_c->input_quant_params().at(i)); |
|
|
|
} |
|
|
|
} |
|
|
|
prim_ptr->AddOutputQuantParam(std::vector<schema::QuantParamT>(1)); |
|
|
|
auto new_cnode = func_graph_ptr_->NewCNode(prim_ptr, op_inputs); |
|
|
|
new_cnode->set_fullname_with_scope("Gemm_MatMul_" + onnx_node.output(0)); |
|
|
|
new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); |
|
|
|
@@ -526,6 +528,9 @@ STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite |
|
|
|
} |
|
|
|
op_inputs.push_back(nodes_["Gemm_MatMul_" + onnx_node.output(0)]); |
|
|
|
op_inputs.push_back(nodes_[onnx_node.input(2)]); |
|
|
|
prim_ptr->AddInputQuantParam(std::vector<schema::QuantParamT>(1)); |
|
|
|
prim_ptr->AddInputQuantParam(primitive_c->input_quant_params().at(2)); |
|
|
|
prim_ptr->AddOutputQuantParam(primitive_c->output_quant_params().front()); |
|
|
|
auto new_cnode = func_graph_ptr_->NewCNode(prim_ptr, op_inputs); |
|
|
|
new_cnode->set_fullname_with_scope("Gemm_BiasAdd_" + onnx_node.output(0)); |
|
|
|
new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); |
|
|
|
|