@@ -29,6 +29,7 @@
 #include "tools/common/graph_util.h"
 #include "include/errorcode.h"
 #include "schema/inner/model_generated.h"
+#include "src/ops/primitive_c.h"

 namespace mindspore {
 namespace lite {
@@ -263,7 +264,15 @@ bool FusionPass::MatchTree(schema::MetaGraphT *graph, size_t nodeIdx, const std:
     return true;
   }
   for (auto preNodeIdx : preNodeIdxes) {
-    MS_ASSERT(subGraph->nodes.size() > preNodeIdx);
+    MS_ASSERT(graph->nodes.size() > preNodeIdx);
+    // Nodes with more than two inputs or more than one output are not supported.
+    if (GetInputNodeIdx(*graph, preNodeIdx).size() > kDoubleNum ||
+        GetOutputNodeIdx(*graph, preNodeIdx).size() > kSingleNum) {
+      sinkIdes.erase((sinkIdes.end() - 1));
+      pathSinkIdes.erase((pathSinkIdes.end() - 1));
+      target->UnSetPath();
+      return false;
+    }
     // match left
     if (MatchTree(graph, preNodeIdx, target->left, sinkIdes, pathSinkIdes)) {
       // match right
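The guard added in this hunk stops matching as soon as a predecessor has more than two inputs or more than one output, and it rolls back the bookkeeping `MatchTree` did before recursing: the indices recorded in `sinkIdes`/`pathSinkIdes` and the path mark on `target`. A minimal sketch of that record-then-roll-back pattern, using hypothetical simplified types rather than the real fusion-pattern classes:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the pattern node; the real target carries more state.
struct PatternNode {
  bool on_path = false;
  void SetPath() { on_path = true; }
  void UnSetPath() { on_path = false; }
};

// Record a candidate node, try to extend the match, and undo every side effect
// if the candidate turns out to be unusable (e.g. it fans out to several
// consumers and therefore cannot be fused away safely).
bool TryMatch(size_t node_idx, PatternNode *target, std::vector<size_t> *sink_ids,
              std::vector<size_t> *path_sink_ids, bool candidate_is_usable) {
  sink_ids->push_back(node_idx);
  path_sink_ids->push_back(node_idx);
  target->SetPath();
  if (!candidate_is_usable) {
    // Same rollback the diff performs: drop the last recorded index from both
    // lists and clear the path mark before reporting failure.
    sink_ids->pop_back();
    path_sink_ids->pop_back();
    target->UnSetPath();
    return false;
  }
  return true;
}
```

Keeping the rollback symmetric with the earlier bookkeeping is what lets the matcher backtrack and try other candidate subgraphs safely.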
@@ -75,6 +75,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
   attr->dilateW = 1;
   attr->group = 1;
   attr->padMode = schema::PadMode_NOTSET;
+  attr->format = schema::Format::Format_NCHW;
   // set opdef each attr params
   for (const auto &onnx_node_attr : onnx_node.attribute()) {
     if (onnx_node_attr.name() == "group") {
@@ -161,7 +162,6 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
     attr->channelOut = dims[0];
     attr->channelIn = dims[3] * attr->group;
   }
-  attr->format = schema::Format::Format_NCHW;
  attr->hasBias = onnx_node.input().size() == 3;
  if (onnx_node.op_type() == "ConvRelu" || onnx_node.op_type() == "Int8ConvRelu") {
    attr->activationType = schema::ActivationType_RELU;
@@ -244,6 +244,16 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
       MS_LOG(ERROR) << "memcpy_s failed";
       return RET_ERROR;
     }
+    // set quantParams to Int8GivenTensor.
+    std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<schema::QuantParamT>();
+    for (const auto &onnx_node_attr : onnx_node.attribute()) {
+      if (onnx_node_attr.name() == "Y_scale") {
+        quant_param->scale = onnx_node_attr.f();
+      } else if (onnx_node_attr.name() == "Y_zero_point") {
+        quant_param->zeroPoint = static_cast<int32_t>(onnx_node_attr.i());
+      }
+    }
+    tensor->quantParams.emplace_back(std::move(quant_param));
   } else {
     MS_LOG(ERROR) << "unsupported data type " << tensor->dataType;
     return RET_ERROR;
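The new block attaches a `schema::QuantParamT` built from the node's `Y_scale` and `Y_zero_point` attributes to the constant tensor. Assuming these encode the usual asymmetric affine scheme, real ≈ scale × (q − zero_point), here is a small standalone sketch of that relation (illustrative only, not MindSpore's dequantization kernel):

```cpp
#include <cstdint>
#include <vector>

// Per-tensor affine quantization parameters, mirroring the two attributes the
// parser reads (scale and zero point); the struct and field names are illustrative.
struct QuantParam {
  float scale = 1.0f;
  int32_t zero_point = 0;
};

// Recover approximate real values from int8 data: real = scale * (q - zero_point).
std::vector<float> Dequantize(const std::vector<int8_t> &data, const QuantParam &param) {
  std::vector<float> out;
  out.reserve(data.size());
  for (int8_t q : data) {
    out.push_back(param.scale * (static_cast<int32_t>(q) - param.zero_point));
  }
  return out;
}
```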
@@ -256,9 +266,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
 }

 STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                                             schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
-                                             TensorCache *tensor_cache, const QuantType &quantType,
-                                             schema::MetaGraphT *dst_graph) {
+                                             schema::CNodeT *dst_op, TensorCache *tensor_cache,
+                                             const QuantType &quantType, schema::MetaGraphT *dst_graph) {
   // change op_type() to name(), that is unique
   static bool interrupt = false;
   dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
@@ -267,7 +276,6 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
   MS_LOG(DEBUG) << "onnx op name " << onnx_node.op_type() << ", dst op name: " << dst_op->name << ", input size "
                 << onnx_node.input_size();
   // get the real op type
-  SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache);
   if (onnx_node.op_type() == "Loop") {
     NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
     interrupt = true;
@@ -305,6 +313,13 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
     MS_LOG(ERROR) << "SetOpInputIndex failed";
     return RET_ERROR;
   }
+  if (dst_op->primitive->value.type == schema::PrimitiveType_Conv2D) {
+    auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
+    weight_tensor->format = dst_op->primitive->value.AsConv2D()->format;
+  } else if (dst_op->primitive->value.type == schema::PrimitiveType_DeConv2D) {
+    auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
+    weight_tensor->format = dst_op->primitive->value.AsDeConv2D()->format;
+  }
   // set op output index
   std::vector<string> node_outputs;
   (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end());
@@ -314,6 +329,13 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
     MS_LOG(ERROR) << "SetOpOutputIndex failed";
     return RET_ERROR;
   }
+  auto &output_tensor = tensor_cache->GetCachedTensor().at(dst_op->outputIndex.front());
+  if (output_tensor == nullptr) {
+    interrupt = true;
| MS_LOG(ERROR) << "Output tensor of node " << onnx_node.op_type() << "is nullptr."; | |||||
+    return RET_ERROR;
+  }
+  SetOpQuantParams(onnx_graph, onnx_node, dst_op, output_tensor, tensor_cache);
   return RET_OK;
 }
@@ -572,9 +594,7 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT
     }

     std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
-    std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
-    status_node =
-      ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType, dst_graph);
+    status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), &tensor_cache, quantType, dst_graph);
     if (status_node != RET_OK) {
       status = (status == RET_OK ? status_node : status);
       continue;
@@ -66,8 +66,8 @@ class OnnxModelParser : public ModelParser
                               TensorCache *tensor_cache, int *index);

   STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                              schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache,
-                              const QuantType &quantType, schema::MetaGraphT *dst_graph);
+                              schema::CNodeT *dst_op, TensorCache *tensor_cache, const QuantType &quantType,
+                              schema::MetaGraphT *dst_graph);

   void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                          schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache,
@@ -24,7 +24,7 @@
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
 #include "schema/inner/model_generated.h"
+#include "ir/dtype/type_id.h"

 namespace mindspore {
 namespace lite {
 class OnnxNodeParser {
@@ -14,14 +14,14 @@
  * limitations under the License.
  */

-#include "tools/converter/parser/onnx/onnx_unuseful_node_parser.h"
+#include "tools/converter/parser/onnx/onnx_quantize_parser.h"
 #include <memory>

 namespace mindspore {
 namespace lite {
-STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                                     schema::CNodeT *op) {
-  MS_LOG(DEBUG) << "onnx UnusefulNodeParser";
+STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
+                                 schema::CNodeT *op) {
+  MS_LOG(DEBUG) << "onnx QuantizeDequantizeParser";
   if (op == nullptr) {
     MS_LOG(ERROR) << "op is null";
     return RET_NULL_PTR;
@@ -32,30 +32,27 @@ STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph, const o
     return RET_NULL_PTR;
   }

+  std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed.";
+    return RET_NULL_PTR;
+  }
   if (onnx_node.op_type() == "Int8Quantize") {
-    std::unique_ptr<schema::OnnxInt8QuantizeT> attr = std::make_unique<schema::OnnxInt8QuantizeT>();
-    if (attr == nullptr) {
-      MS_LOG(ERROR) << "new op failed";
-      return RET_NULL_PTR;
-    }
-    op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize;
-    op->primitive->value.value = attr.release();
+    attr->srcT = kNumberTypeFloat32;
+    attr->dstT = kNumberTypeInt8;
   } else if (onnx_node.op_type() == "Int8Dequantize") {
-    std::unique_ptr<schema::OnnxInt8DequantizeT> attr = std::make_unique<schema::OnnxInt8DequantizeT>();
-    if (attr == nullptr) {
-      MS_LOG(ERROR) << "new op failed";
-      return RET_NULL_PTR;
-    }
-    op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
-    op->primitive->value.value = attr.release();
+    attr->srcT = kNumberTypeInt8;
+    attr->dstT = kNumberTypeFloat32;
   } else {
     MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str();
     return RET_ERROR;
   }
+  op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
+  op->primitive->value.value = attr.release();
   return RET_OK;
 }

-OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxUnusefulNodeParser());
-OnnxNodeRegistrar g_onnxInt8DequantizeParser("Int8Dequantize", new OnnxUnusefulNodeParser());
+OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxQuantizeParser());
+OnnxNodeRegistrar g_onnxInt8DequantizeParser("Int8Dequantize", new OnnxQuantizeParser());
 } // namespace lite
 } // namespace mindspore
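With this rewrite, both `Int8Quantize` and `Int8Dequantize` produce a single `QuantDTypeCast` primitive, with the direction carried by `srcT`/`dstT` instead of by two distinct primitive types. A small sketch of that mapping, with a plain enum standing in for the real `TypeId` constants (names here are illustrative):

```cpp
#include <optional>
#include <string>
#include <utility>

// Illustrative stand-in for the converter's TypeId constants.
enum class DataType { kFloat32, kInt8 };

// Map an ONNX quantize/dequantize op type to the (source, destination) types a
// QuantDTypeCast-style primitive would carry; unknown op types yield no value.
std::optional<std::pair<DataType, DataType>> QuantCastTypes(const std::string &op_type) {
  if (op_type == "Int8Quantize") {
    return std::make_pair(DataType::kFloat32, DataType::kInt8);
  }
  if (op_type == "Int8Dequantize") {
    return std::make_pair(DataType::kInt8, DataType::kFloat32);
  }
  return std::nullopt;
}
```

Folding both directions into one primitive keeps a single code path in later passes; only the source/destination types differ.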
@@ -14,21 +14,21 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
-#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_QUANTIZE_PARSER_H
+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_QUANTIZE_PARSER_H

 #include "tools/converter/parser/onnx/onnx_node_parser.h"
 #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

 namespace mindspore {
 namespace lite {
-class OnnxUnusefulNodeParser : public OnnxNodeParser {
+class OnnxQuantizeParser : public OnnxNodeParser {
  public:
-  OnnxUnusefulNodeParser() : OnnxNodeParser("UnusefulNode") {}
-  ~OnnxUnusefulNodeParser() override = default;
+  OnnxQuantizeParser() : OnnxNodeParser("Quantize") {}
+  ~OnnxQuantizeParser() override = default;
   STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
 };
 } // namespace lite
 } // namespace mindspore

-#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_QUANTIZE_PARSER_H
@@ -79,7 +79,11 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
       // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
       if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D ||
           op_type == schema::PrimitiveType_DeConv2D || op_type == schema::PrimitiveType_DeDepthwiseConv2D) {
-        param_value->set_format(schema::Format::Format_KCHW);
+        if (param_value->format() == schema::Format::Format_NHWC) {
+          param_value->set_format(schema::Format::Format_KHWC);
+        } else {
+          param_value->set_format(schema::Format::Format_KCHW);
+        }
       } else {
         MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
                       << ", node: " << conv_node->fullname_with_scope();