From 3d327b69ec26f6f83beee0f582489c4032a6b4a2 Mon Sep 17 00:00:00 2001 From: liuyang_655 Date: Sat, 8 May 2021 16:11:55 +0800 Subject: [PATCH] export yolo onnx --- .../transform/express_ir/onnx_exporter.cc | 281 +++++++++++++++++- mindspore/core/base/core_ops.h | 1 + 2 files changed, 272 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc b/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc index 3aa53a78f2..1b68813cc8 100644 --- a/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc @@ -33,7 +33,7 @@ enum OpMergeMode { OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv` OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm` - OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX Batch Normalization` + OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization` OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool` }; @@ -261,10 +261,16 @@ OPERATOR_ONNX_CONVERT_DEFINE( OPERATOR_ONNX_CONVERT_DEFINE(Gather, Gather, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(MakeTuple, SequenceConstruct, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Maximum, Max, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Transpose, Transpose, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(StridedSlice, Slice, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Exp, Exp, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(ResizeNearestNeighbor, Resize, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Softplus, Softplus, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Tanh, Tanh, OpNameInfo()) #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name @@ -288,10 +294,14 @@ void RegisterOpConverters(const std::function &fn) { fn(OP_CONVERT_FUNCTION_NAME(MatMul)()); fn(OP_CONVERT_FUNCTION_NAME(MakeTuple)()); - fn(OP_CONVERT_FUNCTION_NAME(Concat)()); fn(OP_CONVERT_FUNCTION_NAME(RealDiv)()); fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)()); fn(OP_CONVERT_FUNCTION_NAME(Sub)()); + fn(OP_CONVERT_FUNCTION_NAME(Maximum)()); + fn(OP_CONVERT_FUNCTION_NAME(Exp)()); + fn(OP_CONVERT_FUNCTION_NAME(ResizeNearestNeighbor)()); + fn(OP_CONVERT_FUNCTION_NAME(Softplus)()); + fn(OP_CONVERT_FUNCTION_NAME(Tanh)()); } class OpConvertRegistry { @@ -350,6 +360,15 @@ class OnnxExporter { std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); + + void ExportPrimTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimResizeNearestNeighbor(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimConcat(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, @@ -380,6 +399,7 @@ class OnnxExporter { onnx::GraphProto *const graph_proto); void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto); + void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto); size_t AllocateNodeIndex() { return ++onnx_node_index_; } @@ -446,14 +466,12 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphP MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter."; } - // set onnx input. if (!param_ptr->has_default()) { onnx::ValueInfoProto *input_proto = graph_proto->add_input(); input_proto->set_name(param_ptr->ToString()); SetValueInfoType(param_ptr, input_proto); continue; } - // parameter with default value is an ONNX initializer onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); initializer_proto->set_name(param_ptr->ToString()); @@ -597,6 +615,7 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::mapcast(); + auto iter = op_merged_infos.find(cnode); // the node is not referenced by any other nodes, skip it if (iter == op_merged_infos.end()) { @@ -699,6 +718,237 @@ void OnnxExporter::ExportPrimReduce(const FuncGraphPtr & /*func_graph*/, const C } } +void OnnxExporter::ExportPrimTranspose(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto input_perm = node->input(2); + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + auto name = prim::kPrimTranspose->name(); + node_proto->set_op_type(name); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_data); + + if (input_perm->isa()) { + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("perm"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); + auto perm_value = dyn_cast(input_perm)->value(); + auto int_ptr = dyn_cast(perm_value); + if (int_ptr == nullptr) { + auto tuple_ptr = dyn_cast(perm_value); + MS_EXCEPTION_IF_NULL(tuple_ptr); + for (size_t i = 0; i < tuple_ptr->size(); ++i) { + attr_proto->add_ints(GetValue((*tuple_ptr)[i])); + } + } else { + attr_proto->add_ints(int_ptr->value()); + } + } else { + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << name; + } +} + +void OnnxExporter::ExportPrimStridedSlice(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto begin = node->input(2); + std::string name_begin; + if (begin->isa()) { + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[begin] = const_node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + name_begin = std::to_string(const_node_idx); + node_proto->add_output(name_begin); + + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("starts"); + + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(dyn_cast(begin)->value(), attr_proto->mutable_t()); + } else { + name_begin = GetNodeInputName(begin, node_map_ptr, graph_proto); + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for StridedSlice."; + } + + auto end = node->input(3); + std::string name_end; + if (end->isa()) { + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[end] = const_node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + name_end = std::to_string(const_node_idx); + node_proto->add_output(name_end); + + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("ends"); + + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(dyn_cast(end)->value(), attr_proto->mutable_t()); + } else { + name_begin = GetNodeInputName(end, node_map_ptr, graph_proto); + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for StridedSlice."; + } + + auto x_shape = dyn_cast(node->input(1)->Shape()); + int size = x_shape->shape().size(); + std::vector axes_value; + ValuePtr axes_value_ptr = nullptr; + for (int i = 0; i < size; ++i) { + axes_value.push_back(i); + } + axes_value_ptr = MakeValue>(axes_value); + auto axes = NewValueNode(axes_value_ptr)->cast(); + std::string name_axes; + if (axes->isa()) { + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[axes] = const_node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + name_axes = std::to_string(const_node_idx); + node_proto->add_output(name_axes); + + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("axes"); + + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(dyn_cast(axes)->value(), attr_proto->mutable_t()); + } else { + name_begin = GetNodeInputName(axes, node_map_ptr, graph_proto); + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for StridedSlice."; + } + + auto strides = node->input(4); + std::string name_strides; + if (strides->isa()) { + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[strides] = const_node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + name_strides = std::to_string(const_node_idx); + node_proto->add_output(name_strides); + + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("steps"); + + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(dyn_cast(strides)->value(), attr_proto->mutable_t()); + } else { + name_begin = GetNodeInputName(strides, node_map_ptr, graph_proto); + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for StridedSlice."; + } + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Slice"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_data); + node_proto->add_input(name_begin); + node_proto->add_input(name_end); + node_proto->add_input(name_axes); + node_proto->add_input(name_strides); +} + +void OnnxExporter::ExportPrimResizeNearestNeighbor(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto x_shape = dyn_cast(node->input(1)->Shape()); + + AnfNodePtr op = node->input(0); + auto op_value = dyn_cast(op); + auto prim = dyn_cast(op_value->value()); + std::vector resize_size; + + auto tuple_ptr = dyn_cast(prim->GetAttr("size")); + + for (size_t i = 0; i < x_shape->shape().size() - 2; i++) { + resize_size.push_back(x_shape->shape()[i]); + } + for (size_t i = 0; i < tuple_ptr->size(); i++) { + ValuePtr elem = (*tuple_ptr)[i]; + resize_size.push_back(dyn_cast(elem)->value()); + } + auto resize_size_ptr = MakeValue>(resize_size); + auto size = NewValueNode(resize_size_ptr)->cast(); + std::string name_size; + + if (size->isa()) { + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[size] = const_node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + name_size = std::to_string(const_node_idx); + node_proto->add_output(name_size); + + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("sizes"); + + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(resize_size_ptr, attr_proto->mutable_t()); + } else { + name_size = GetNodeInputName(size, node_map_ptr, graph_proto); + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for ResizeNearestNeighbor."; + } + + auto node_idx = AllocateNodeIndex(); + + onnx::TensorProto *roi_initializer_proto = graph_proto->add_initializer(); + auto roi_name = std::to_string(node_idx) + "roi_initializer"; + roi_initializer_proto->set_name(roi_name); + roi_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32)); + roi_initializer_proto->add_dims(0); + + onnx::TensorProto *scales_initializer_proto = graph_proto->add_initializer(); + auto scales_name = std::to_string(node_idx) + "scales_initializer"; + scales_initializer_proto->set_name(scales_name); + scales_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32)); + scales_initializer_proto->add_dims(0); + + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + + node_proto->set_op_type("Resize"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_data); + node_proto->add_input(roi_name); + node_proto->add_input(scales_name); + node_proto->add_input(name_size); +} + +void OnnxExporter::ExportPrimConcat(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + + AnfNodePtr op = node->input(0); + auto op_value = dyn_cast(op); + auto prim = dyn_cast(op_value->value()); + auto input_node = node->input(1)->cast(); + + if (input_node->IsApply(prim::kPrimMakeTuple)) { + node_proto->set_op_type("ConcatFromSequence"); + } else { + node_proto->set_op_type("Concat"); + } + + // set attr axis + onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("axis"); + SetAttrValueToProto(prim->GetAttr("axis"), onnx::AttributeProto_AttributeType_INT, onnx_attr_proto, prim); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_data); +} + void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); @@ -746,7 +996,6 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CN attr_proto->set_name("axes"); attr_proto->add_ints(1); attr_proto->add_ints(2); - node_proto->add_input(input_slope); input_slope = std::to_string(node_idx); } @@ -958,6 +1207,19 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto); } + if (node->IsApply(prim::kPrimTranspose)) { + return ExportPrimTranspose(func_graph, node, node_map_ptr, graph_proto); + } + if (node->IsApply(prim::kPrimStridedSlice)) { + return ExportPrimStridedSlice(func_graph, node, node_map_ptr, graph_proto); + } + if (node->IsApply(prim::kPrimResizeNearestNeighbor)) { + return ExportPrimResizeNearestNeighbor(func_graph, node, node_map_ptr, graph_proto); + } + if (node->IsApply(prim::kPrimConcat)) { + return ExportPrimConcat(func_graph, node, node_map_ptr, graph_proto); + } + // MindSpore Cast(x, T) --> ONNX Cast[to=T](x) if (node->IsApply(prim::kPrimCast)) { return ExportPrimCast(func_graph, node, node_map_ptr, graph_proto); @@ -997,7 +1259,6 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n if (inputs.size() < 1) { MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; } - AnfNodePtr op = inputs[0]; std::vector op_inputs; // first process node input 1,2,..., since when node input is a ValueNode, here need to create a Constant Operator @@ -1007,6 +1268,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n } } auto op_value = dyn_cast(op); + if (op_value == nullptr) { MS_LOG(EXCEPTION) << "Need to support node op type " << op->type_name(); } @@ -1015,9 +1277,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n MS_LOG(EXCEPTION) << "Need to support node op type " << op_value->value()->type_name(); } - if (!IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) { - (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); - } + (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); } size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map *node_map_ptr, @@ -1124,6 +1384,7 @@ void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNode std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &orig_node, std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto node = GetRealInput(orig_node); + if (node->isa()) { auto iter = node_map_ptr->find(node); if (iter == node_map_ptr->end()) { diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 3f796b03ad..1bfdbec4af 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -190,6 +190,7 @@ inline const PrimitivePtr kPrimRank = std::make_shared("Rank"); inline const PrimitivePtr kPrimResizeBilinear = std::make_shared("ResizeBilinear"); inline const PrimitivePtr kPrimResizeGrad = std::make_shared("ResizeGrad"); inline const PrimitivePtr kPrimSort = std::make_shared("Sort"); +inline const PrimitivePtr kPrimResizeNearestNeighbor = std::make_shared("ResizeNearestNeighbor"); // NN inline const PrimitivePtr kPrimAdam = std::make_shared("Adam");