@@ -29,6 +29,7 @@
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
@@ -263,7 +264,15 @@ bool FusionPass::MatchTree(schema::MetaGraphT *graph, size_t nodeIdx, const std:
    return true;
  }
  for (auto preNodeIdx : preNodeIdxes) {
    MS_ASSERT(subGraph->nodes.size() > preNodeIdx);
    MS_ASSERT(graph->nodes.size() > preNodeIdx);
    // Nodes with more than two inputs or more than one output are not supported.
    if (GetInputNodeIdx(*graph, preNodeIdx).size() > kDoubleNum ||
        GetOutputNodeIdx(*graph, preNodeIdx).size() > kSingleNum) {
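      // this branch cannot match: drop the current node from the matched path and reset it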
      sinkIdes.erase((sinkIdes.end() - 1));
      pathSinkIdes.erase((pathSinkIdes.end() - 1));
      target->UnSetPath();
      return false;
    }
    // match left
    if (MatchTree(graph, preNodeIdx, target->left, sinkIdes, pathSinkIdes)) {
      // match right
@@ -75,6 +75,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
  attr->dilateW = 1;
  attr->group = 1;
  attr->padMode = schema::PadMode_NOTSET;
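  // ONNX convolutions take NCHW data, so record that layout before walking the node attributes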
  attr->format = schema::Format::Format_NCHW;
  // set each attr param of the op def
  for (const auto &onnx_node_attr : onnx_node.attribute()) {
    if (onnx_node_attr.name() == "group") {
@@ -161,7 +162,6 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
    attr->channelOut = dims[0];
    attr->channelIn = dims[3] * attr->group;
  }
  attr->format = schema::Format::Format_NCHW;
  attr->hasBias = onnx_node.input().size() == 3;
  if (onnx_node.op_type() == "ConvRelu" || onnx_node.op_type() == "Int8ConvRelu") {
    attr->activationType = schema::ActivationType_RELU;
@@ -244,6 +244,16 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
      MS_LOG(ERROR) << "memcpy_s failed";
      return RET_ERROR;
    }
    // set quantParams for the Int8GivenTensor.
    std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<schema::QuantParamT>();
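    // Y_scale and Y_zero_point carry the per-tensor quantization parameters of the Int8 tensor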
    for (const auto &onnx_node_attr : onnx_node.attribute()) {
      if (onnx_node_attr.name() == "Y_scale") {
        quant_param->scale = onnx_node_attr.f();
      } else if (onnx_node_attr.name() == "Y_zero_point") {
        quant_param->zeroPoint = static_cast<int32_t>(onnx_node_attr.i());
      }
    }
    tensor->quantParams.emplace_back(std::move(quant_param));
  } else {
    MS_LOG(ERROR) << "unsupported data type " << tensor->dataType;
    return RET_ERROR;
@@ -256,9 +266,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
}
STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                             schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
                                             TensorCache *tensor_cache, const QuantType &quantType,
                                             schema::MetaGraphT *dst_graph) {
                                             schema::CNodeT *dst_op, TensorCache *tensor_cache,
                                             const QuantType &quantType, schema::MetaGraphT *dst_graph) {
  // build a unique name from op_type() and the first output name
  static bool interrupt = false;
  dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
@@ -267,7 +276,6 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
  MS_LOG(DEBUG) << "onnx op name " << onnx_node.op_type() << ", dst op name: " << dst_op->name << ", input size "
                << onnx_node.input_size();
  // get the real op type
  SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache);
  if (onnx_node.op_type() == "Loop") {
    NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
    interrupt = true;
@@ -305,6 +313,13 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
    MS_LOG(ERROR) << "SetOpInputIndex failed";
    return RET_ERROR;
  }
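  // mirror the primitive's format onto the cached weight tensor of convolution ops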
  if (dst_op->primitive->value.type == schema::PrimitiveType_Conv2D) {
    auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
    weight_tensor->format = dst_op->primitive->value.AsConv2D()->format;
  } else if (dst_op->primitive->value.type == schema::PrimitiveType_DeConv2D) {
    auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
    weight_tensor->format = dst_op->primitive->value.AsDeConv2D()->format;
  }
  // set op output index
  std::vector<string> node_outputs;
  (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end());
@@ -314,6 +329,13 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
    MS_LOG(ERROR) << "SetOpOutputIndex failed";
    return RET_ERROR;
  }
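  // quant params are attached to the node's first output tensor, so it is looked up after the output indices are set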
  auto &output_tensor = tensor_cache->GetCachedTensor().at(dst_op->outputIndex.front());
  if (output_tensor == nullptr) {
    interrupt = true;
    MS_LOG(ERROR) << "Output tensor of node " << onnx_node.op_type() << " is nullptr.";
    return RET_ERROR;
  }
  SetOpQuantParams(onnx_graph, onnx_node, dst_op, output_tensor, tensor_cache);
  return RET_OK;
}
@@ -572,9 +594,7 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT
    }
    std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
    std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
    status_node =
      ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType, dst_graph);
    status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), &tensor_cache, quantType, dst_graph);
    if (status_node != RET_OK) {
      status = (status == RET_OK ? status_node : status);
      continue;
@@ -66,8 +66,8 @@ class OnnxModelParser : public ModelParser {
                              TensorCache *tensor_cache, int *index);
  STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                              schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache,
                              const QuantType &quantType, schema::MetaGraphT *dst_graph);
                              schema::CNodeT *dst_op, TensorCache *tensor_cache, const QuantType &quantType,
                              schema::MetaGraphT *dst_graph);
  void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                         schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache,
@@ -24,7 +24,7 @@
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "schema/inner/model_generated.h"
#include "ir/dtype/type_id.h"
namespace mindspore {
namespace lite {
class OnnxNodeParser {
@@ -14,14 +14,14 @@
 * limitations under the License.
 */
#include "tools/converter/parser/onnx/onnx_unuseful_node_parser.h"
#include "tools/converter/parser/onnx/onnx_quantize_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                     schema::CNodeT *op) {
  MS_LOG(DEBUG) << "onnx UnusefulNodeParser";
STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                 schema::CNodeT *op) {
  MS_LOG(DEBUG) << "onnx QuantizeDequantizeParser";
  if (op == nullptr) {
    MS_LOG(ERROR) << "op is null";
    return RET_NULL_PTR;
@@ -32,30 +32,27 @@ STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph, const o
    return RET_NULL_PTR;
  }
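  // both Int8Quantize and Int8Dequantize are lowered to a single QuantDTypeCast primitive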
  std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new op failed.";
    return RET_NULL_PTR;
  }
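  // Int8Quantize casts float32 to int8; Int8Dequantize casts int8 back to float32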
  if (onnx_node.op_type() == "Int8Quantize") {
    std::unique_ptr<schema::OnnxInt8QuantizeT> attr = std::make_unique<schema::OnnxInt8QuantizeT>();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new op failed";
      return RET_NULL_PTR;
    }
    op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize;
    op->primitive->value.value = attr.release();
    attr->srcT = kNumberTypeFloat32;
    attr->dstT = kNumberTypeInt8;
  } else if (onnx_node.op_type() == "Int8Dequantize") {
    std::unique_ptr<schema::OnnxInt8DequantizeT> attr = std::make_unique<schema::OnnxInt8DequantizeT>();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new op failed";
      return RET_NULL_PTR;
    }
    op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
    op->primitive->value.value = attr.release();
    attr->srcT = kNumberTypeInt8;
    attr->dstT = kNumberTypeFloat32;
  } else {
    MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str();
    return RET_ERROR;
  }
  op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
  op->primitive->value.value = attr.release();
  return RET_OK;
}
OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxUnusefulNodeParser());
OnnxNodeRegistrar g_onnxInt8DequantizeParser("Int8Dequantize", new OnnxUnusefulNodeParser());
OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxQuantizeParser());
OnnxNodeRegistrar g_onnxInt8DequantizeParser("Int8Dequantize", new OnnxQuantizeParser());
} // namespace lite
} // namespace mindspore
@@ -14,21 +14,21 @@
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_QUANTIZE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_QUANTIZE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxUnusefulNodeParser : public OnnxNodeParser {
class OnnxQuantizeParser : public OnnxNodeParser {
 public:
  OnnxUnusefulNodeParser() : OnnxNodeParser("UnusefulNode") {}
  ~OnnxUnusefulNodeParser() override = default;
  OnnxQuantizeParser() : OnnxNodeParser("Quantize") {}
  ~OnnxQuantizeParser() override = default;
  STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_QUANTIZE_PARSER_H
@@ -79,7 +79,11 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
  // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
  if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D ||
      op_type == schema::PrimitiveType_DeConv2D || op_type == schema::PrimitiveType_DeDepthwiseConv2D) {
    param_value->set_format(schema::Format::Format_KCHW);
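    // NHWC-formatted weights map to KHWC; everything else keeps the default KCHW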
    if (param_value->format() == schema::Format::Format_NHWC) {
      param_value->set_format(schema::Format::Format_KHWC);
    } else {
      param_value->set_format(schema::Format::Format_KCHW);
    }
  } else {
    MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
                  << ", node: " << conv_node->fullname_with_scope();