diff --git a/mindspore/ccsrc/onnx/onnx_exporter.cc b/mindspore/ccsrc/onnx/onnx_exporter.cc index f6f4ec2f1f..a0c8de75af 100644 --- a/mindspore/ccsrc/onnx/onnx_exporter.cc +++ b/mindspore/ccsrc/onnx/onnx_exporter.cc @@ -249,6 +249,13 @@ OPERATOR_ONNX_CONVERT_DEFINE( .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) +OPERATOR_ONNX_CONVERT_DEFINE(GatherV2, Gather, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(make_tuple, SequenceConstruct, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo()) + #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name void RegisterOpConverters(const std::function &fn) { @@ -269,6 +276,12 @@ void RegisterOpConverters(const std::function &fn) { fn(OP_CONVERT_FUNCTION_NAME(Squeeze)()); fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)()); fn(OP_CONVERT_FUNCTION_NAME(MatMul)()); + + fn(OP_CONVERT_FUNCTION_NAME(make_tuple)()); + fn(OP_CONVERT_FUNCTION_NAME(Concat)()); + fn(OP_CONVERT_FUNCTION_NAME(RealDiv)()); + fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)()); + fn(OP_CONVERT_FUNCTION_NAME(Sub)()); } class OpConvertRegistry { @@ -325,8 +338,8 @@ class OnnxExporter { void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); - void ExportPrimReduceMean(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, @@ -335,6 +348,12 @@ class OnnxExporter { onnx::GraphProto *graph_proto); void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); @@ -628,16 +647,19 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const node_proto->add_input(name_shape); } -void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, - onnx::GraphProto *const graph_proto) { +void OnnxExporter::ExportPrimReduce(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_axis = node->input(2); auto node_idx = AllocateNodeIndex(); (*node_map_ptr)[node] = node_idx; onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type(prim::kPrimReduceMean->name()); + auto name = prim::kPrimReduceMean->name(); + if (node->IsApply(prim::kPrimReduceSum)) { + name = prim::kPrimReduceSum->name(); + } + node_proto->set_op_type(name); node_proto->add_output(std::to_string(node_idx)); node_proto->add_input(input_data); @@ -646,13 +668,18 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, con attr_proto->set_name("axes"); attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); auto axis_value = dyn_cast(input_axis)->value(); - auto tuple_ptr = dyn_cast(axis_value); - MS_EXCEPTION_IF_NULL(tuple_ptr); - for (size_t i = 0; i < tuple_ptr->size(); ++i) { - attr_proto->add_ints(GetValue((*tuple_ptr)[i])); + auto int_ptr = dyn_cast(axis_value); + if (int_ptr == nullptr) { + auto tuple_ptr = dyn_cast(axis_value); + MS_EXCEPTION_IF_NULL(tuple_ptr); + for (size_t i = 0; i < tuple_ptr->size(); ++i) { + attr_proto->add_ints(GetValue((*tuple_ptr)[i])); + } + } else { + attr_proto->add_ints(int_ptr->value()); } } else { - MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for ReduceMean."; + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << name; } } @@ -826,6 +853,83 @@ void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/ SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); } +void OnnxExporter::ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto multiples = node->input(2); + std::string name_multiples; + if (multiples->isa()) { + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[multiples] = const_node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + name_multiples = std::to_string(const_node_idx); + node_proto->add_output(name_multiples); + + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("repeat"); + + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(dyn_cast(multiples)->value(), attr_proto->mutable_t()); + } else { + name_multiples = GetNodeInputName(multiples, node_map_ptr, graph_proto); + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Tile."; + } + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Tile"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(name_x); + node_proto->add_input(name_multiples); +} + +void OnnxExporter::ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + std::string name_exponent; + auto const_node_idx = AllocateNodeIndex(); + onnx::NodeProto *node_proto_exp = graph_proto->add_node(); + name_exponent = std::to_string(const_node_idx); + node_proto_exp->add_output(name_exponent); + + node_proto_exp->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto_exp->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + tensor_proto->set_name("exponent"); + tensor_proto->add_dims(static_cast<::google::protobuf::int64>(1)); + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + tensor_proto->add_int64_data(2); + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Pow"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(name_x); + node_proto->add_input(name_exponent); +} + +void OnnxExporter::ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto name_indices = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); + auto axis = node->input(3)->cast()->value(); + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Gather"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(name_x); + node_proto->add_input(name_indices); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + attr_proto->set_i(static_cast<::google::protobuf::int64>(dyn_cast(axis)->value())); +} + void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { // Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert @@ -833,8 +937,8 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto); } - if (node->IsApply(prim::kPrimReduceMean)) { - return ExportPrimReduceMean(func_graph, node, node_map_ptr, graph_proto); + if (node->IsApply(prim::kPrimReduceMean) || node->IsApply(prim::kPrimReduceSum)) { + return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto); } // MindSpore Cast(x, T) --> ONNX Cast[to=T](x) @@ -857,6 +961,21 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto); } + // MindSpore Tile(x) --> ONNX Tile(x, repeat) + if (node->IsApply(prim::kPrimTile)) { + return ExportPrimTile(func_graph, node, node_map_ptr, graph_proto); + } + + // MindSpore Square(x) --> ONNX Pow(x, 2) + if (node->IsApply(prim::kPrimSquare)) { + return ExportPrimSquare(func_graph, node, node_map_ptr, graph_proto); + } + + // MindSpore GatherV2(x, indices, axis) --> ONNX Pow(x, indices) + if (node->IsApply(prim::kPrimGatherV2)) { + return ExportPrimGatherV2(func_graph, node, node_map_ptr, graph_proto); + } + auto inputs = node->inputs(); if (inputs.size() < 1) { MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; @@ -1054,7 +1173,30 @@ void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *cons node_proto->set_op_type("Constant"); onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name("value"); - MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; + if (value->isa()) { + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + auto casted_value = dyn_cast(value); + if (casted_value == nullptr) { + MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; + } + auto attr_value = casted_value->value(); + attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value)); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + } else if (value->isa()) { + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + auto data = dyn_cast(value); + tensor_proto->set_raw_data(data->data().request(true).ptr, static_cast(data->data().nbytes())); + auto dtype = data->data_type(); + auto shape = data->shape_c(); + + tensor_proto->set_data_type(GetOnnxDataType(dtype)); + for (const auto &dim : shape) { + tensor_proto->add_dims(dim); + } + } else { + MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; + } } std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { diff --git a/tests/ut/python/onnx/test_onnx.py b/tests/ut/python/onnx/test_onnx.py index ca877548de..e7d04fcfc7 100644 --- a/tests/ut/python/onnx/test_onnx.py +++ b/tests/ut/python/onnx/test_onnx.py @@ -142,6 +142,20 @@ class DepthwiseConv2dAndReLU6(nn.Cell): x = self.relu6(x) return x +class DeepFMOpNet(nn.Cell): + """Net definition with Gatherv2 and Tile and Square.""" + + def __init__(self): + super(DeepFMOpNet, self).__init__() + self.gather = P.GatherV2() + self.square = P.Square() + self.tile = P.Tile() + + def construct(self, x, y): + x = self.tile(x, (1000, 1)) + x = self.square(x) + x = self.gather(x, y, 0) + return x # generate mindspore Tensor by shape and numpy datatype def gen_tensor(shape, dtype=np.float32): @@ -153,6 +167,7 @@ net_cfgs = [ ('lenet', LeNet5(), gen_tensor([1, 1, 32, 32])), ('maxpoolwithargmax', DefinedNet(), gen_tensor([1, 3, 224, 224])), ('depthwiseconv_relu6', DepthwiseConv2dAndReLU6(3, kernel_size=3), gen_tensor([1, 3, 32, 32])), + ('deepfm_ops', DeepFMOpNet(), (gen_tensor([1, 1]), gen_tensor([1000, 1], dtype=np.int32))) ] @@ -164,7 +179,10 @@ def get_id(cfg): @pytest.mark.parametrize('name, net, inp', net_cfgs, ids=get_id(net_cfgs)) def test_onnx_export(name, net, inp): onnx_file = name + ".onnx" - export(net, inp, file_name=onnx_file, file_format='ONNX') + if isinstance(inp, (tuple, list)): + export(net, *inp, file_name=onnx_file, file_format='ONNX') + else: + export(net, inp, file_name=onnx_file, file_format='ONNX') # check existence of exported onnx file and delete it assert os.path.exists(onnx_file)