| @@ -158,16 +158,16 @@ STATUS OnnxModelParser::ConvertNodes() { | |||||
| status = RET_ERROR; | status = RET_ERROR; | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsSpecialOnnxNode(onnx_node)) { | |||||
| auto status_node = ConvertSpecialOnnxNode(onnx_node, primitive_c); | |||||
| status = status == RET_OK ? status_node : status; | |||||
| continue; | |||||
| } | |||||
| status = ConvertOpQuantParams(onnx_node, primitive_c); | status = ConvertOpQuantParams(onnx_node, primitive_c); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed."; | MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed."; | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsSpecialOnnxNode(onnx_node)) { | |||||
| auto status_node = ConvertSpecialOnnxNode(onnx_node, primitive_c); | |||||
| status = status == RET_OK ? status_node : status; | |||||
| continue; | |||||
| } | |||||
| // build CNode | // build CNode | ||||
| status = BuildCNode(onnx_node, primitive_c); | status = BuildCNode(onnx_node, primitive_c); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| @@ -512,8 +512,10 @@ STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } else { | } else { | ||||
| op_inputs.push_back(nodes_[onnx_node.input(i)]); | op_inputs.push_back(nodes_[onnx_node.input(i)]); | ||||
| prim_ptr->AddInputQuantParam(primitive_c->input_quant_params().at(i)); | |||||
| } | } | ||||
| } | } | ||||
| prim_ptr->AddOutputQuantParam(std::vector<schema::QuantParamT>(1)); | |||||
| auto new_cnode = func_graph_ptr_->NewCNode(prim_ptr, op_inputs); | auto new_cnode = func_graph_ptr_->NewCNode(prim_ptr, op_inputs); | ||||
| new_cnode->set_fullname_with_scope("Gemm_MatMul_" + onnx_node.output(0)); | new_cnode->set_fullname_with_scope("Gemm_MatMul_" + onnx_node.output(0)); | ||||
| new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | ||||
| @@ -526,6 +528,9 @@ STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite | |||||
| } | } | ||||
| op_inputs.push_back(nodes_["Gemm_MatMul_" + onnx_node.output(0)]); | op_inputs.push_back(nodes_["Gemm_MatMul_" + onnx_node.output(0)]); | ||||
| op_inputs.push_back(nodes_[onnx_node.input(2)]); | op_inputs.push_back(nodes_[onnx_node.input(2)]); | ||||
| prim_ptr->AddInputQuantParam(std::vector<schema::QuantParamT>(1)); | |||||
| prim_ptr->AddInputQuantParam(primitive_c->input_quant_params().at(2)); | |||||
| prim_ptr->AddOutputQuantParam(primitive_c->output_quant_params().front()); | |||||
| auto new_cnode = func_graph_ptr_->NewCNode(prim_ptr, op_inputs); | auto new_cnode = func_graph_ptr_->NewCNode(prim_ptr, op_inputs); | ||||
| new_cnode->set_fullname_with_scope("Gemm_BiasAdd_" + onnx_node.output(0)); | new_cnode->set_fullname_with_scope("Gemm_BiasAdd_" + onnx_node.output(0)); | ||||
| new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | ||||