|
|
|
@@ -249,6 +249,13 @@ OPERATOR_ONNX_CONVERT_DEFINE( |
|
|
|
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) |
|
|
|
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) |
|
|
|
|
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(GatherV2, Gather, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(make_tuple, SequenceConstruct, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo()) |
|
|
|
|
|
|
|
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name |
|
|
|
|
|
|
|
void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) { |
|
|
|
@@ -269,6 +276,12 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) { |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(Squeeze)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(MatMul)()); |
|
|
|
|
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(make_tuple)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(Concat)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(RealDiv)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(Sub)()); |
|
|
|
} |
|
|
|
|
|
|
|
class OpConvertRegistry { |
|
|
|
@@ -325,8 +338,8 @@ class OnnxExporter { |
|
|
|
|
|
|
|
void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimReduceMean(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
@@ -335,6 +348,12 @@ class OnnxExporter { |
|
|
|
onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
|
|
|
|
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *graph_proto); |
|
|
|
@@ -628,16 +647,19 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const |
|
|
|
node_proto->add_input(name_shape); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *const graph_proto) { |
|
|
|
void OnnxExporter::ExportPrimReduce(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { |
|
|
|
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); |
|
|
|
auto input_axis = node->input(2); |
|
|
|
|
|
|
|
auto node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[node] = node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
node_proto->set_op_type(prim::kPrimReduceMean->name()); |
|
|
|
auto name = prim::kPrimReduceMean->name(); |
|
|
|
if (node->IsApply(prim::kPrimReduceSum)) { |
|
|
|
name = prim::kPrimReduceSum->name(); |
|
|
|
} |
|
|
|
node_proto->set_op_type(name); |
|
|
|
node_proto->add_output(std::to_string(node_idx)); |
|
|
|
node_proto->add_input(input_data); |
|
|
|
|
|
|
|
@@ -646,13 +668,18 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, con |
|
|
|
attr_proto->set_name("axes"); |
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); |
|
|
|
auto axis_value = dyn_cast<ValueNode>(input_axis)->value(); |
|
|
|
auto tuple_ptr = dyn_cast<ValueTuple>(axis_value); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_ptr); |
|
|
|
for (size_t i = 0; i < tuple_ptr->size(); ++i) { |
|
|
|
attr_proto->add_ints(GetValue<int>((*tuple_ptr)[i])); |
|
|
|
auto int_ptr = dyn_cast<Int32Imm>(axis_value); |
|
|
|
if (int_ptr == nullptr) { |
|
|
|
auto tuple_ptr = dyn_cast<ValueTuple>(axis_value); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_ptr); |
|
|
|
for (size_t i = 0; i < tuple_ptr->size(); ++i) { |
|
|
|
attr_proto->add_ints(GetValue<int>((*tuple_ptr)[i])); |
|
|
|
} |
|
|
|
} else { |
|
|
|
attr_proto->add_ints(int_ptr->value()); |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for ReduceMean."; |
|
|
|
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << name; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -826,6 +853,83 @@ void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/ |
|
|
|
SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { |
|
|
|
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); |
|
|
|
auto multiples = node->input(2); |
|
|
|
std::string name_multiples; |
|
|
|
if (multiples->isa<ValueNode>()) { |
|
|
|
auto const_node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[multiples] = const_node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
name_multiples = std::to_string(const_node_idx); |
|
|
|
node_proto->add_output(name_multiples); |
|
|
|
|
|
|
|
node_proto->set_op_type("Constant"); |
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_name("repeat"); |
|
|
|
|
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); |
|
|
|
ConvertTupleToTensor(dyn_cast<ValueNode>(multiples)->value(), attr_proto->mutable_t()); |
|
|
|
} else { |
|
|
|
name_multiples = GetNodeInputName(multiples, node_map_ptr, graph_proto); |
|
|
|
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Tile."; |
|
|
|
} |
|
|
|
|
|
|
|
auto node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[node] = node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
node_proto->set_op_type("Tile"); |
|
|
|
node_proto->add_output(std::to_string(node_idx)); |
|
|
|
node_proto->add_input(name_x); |
|
|
|
node_proto->add_input(name_multiples); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { |
|
|
|
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); |
|
|
|
std::string name_exponent; |
|
|
|
auto const_node_idx = AllocateNodeIndex(); |
|
|
|
onnx::NodeProto *node_proto_exp = graph_proto->add_node(); |
|
|
|
name_exponent = std::to_string(const_node_idx); |
|
|
|
node_proto_exp->add_output(name_exponent); |
|
|
|
|
|
|
|
node_proto_exp->set_op_type("Constant"); |
|
|
|
onnx::AttributeProto *attr_proto = node_proto_exp->add_attribute(); |
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); |
|
|
|
onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); |
|
|
|
tensor_proto->set_name("exponent"); |
|
|
|
tensor_proto->add_dims(static_cast<::google::protobuf::int64>(1)); |
|
|
|
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); |
|
|
|
tensor_proto->add_int64_data(2); |
|
|
|
|
|
|
|
auto node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[node] = node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
node_proto->set_op_type("Pow"); |
|
|
|
node_proto->add_output(std::to_string(node_idx)); |
|
|
|
node_proto->add_input(name_x); |
|
|
|
node_proto->add_input(name_exponent); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { |
|
|
|
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); |
|
|
|
auto name_indices = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); |
|
|
|
auto axis = node->input(3)->cast<ValueNodePtr>()->value(); |
|
|
|
|
|
|
|
auto node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[node] = node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
node_proto->set_op_type("Gather"); |
|
|
|
node_proto->add_output(std::to_string(node_idx)); |
|
|
|
node_proto->add_input(name_x); |
|
|
|
node_proto->add_input(name_indices); |
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); |
|
|
|
attr_proto->set_i(static_cast<::google::protobuf::int64>(dyn_cast<Int32Imm>(axis)->value())); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { |
|
|
|
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert |
|
|
|
@@ -833,8 +937,8 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n |
|
|
|
return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
if (node->IsApply(prim::kPrimReduceMean)) { |
|
|
|
return ExportPrimReduceMean(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
if (node->IsApply(prim::kPrimReduceMean) || node->IsApply(prim::kPrimReduceSum)) { |
|
|
|
return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
// MindSpore Cast(x, T) --> ONNX Cast[to=T](x) |
|
|
|
@@ -857,6 +961,21 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n |
|
|
|
return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
// MindSpore Tile(x) --> ONNX Tile(x, repeat) |
|
|
|
if (node->IsApply(prim::kPrimTile)) { |
|
|
|
return ExportPrimTile(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
// MindSpore Square(x) --> ONNX Pow(x, 2) |
|
|
|
if (node->IsApply(prim::kPrimSquare)) { |
|
|
|
return ExportPrimSquare(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
// MindSpore GatherV2(x, indices, axis) --> ONNX Pow(x, indices) |
|
|
|
if (node->IsApply(prim::kPrimGatherV2)) { |
|
|
|
return ExportPrimGatherV2(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
auto inputs = node->inputs(); |
|
|
|
if (inputs.size() < 1) { |
|
|
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; |
|
|
|
@@ -1054,7 +1173,30 @@ void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *cons |
|
|
|
node_proto->set_op_type("Constant"); |
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_name("value"); |
|
|
|
MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; |
|
|
|
if (value->isa<Int32Imm>()) { |
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); |
|
|
|
auto casted_value = dyn_cast<Int32Imm>(value); |
|
|
|
if (casted_value == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; |
|
|
|
} |
|
|
|
auto attr_value = casted_value->value(); |
|
|
|
attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value)); |
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); |
|
|
|
} else if (value->isa<tensor::Tensor>()) { |
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); |
|
|
|
onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); |
|
|
|
auto data = dyn_cast<tensor::Tensor>(value); |
|
|
|
tensor_proto->set_raw_data(data->data().request(true).ptr, static_cast<size_t>(data->data().nbytes())); |
|
|
|
auto dtype = data->data_type(); |
|
|
|
auto shape = data->shape_c(); |
|
|
|
|
|
|
|
tensor_proto->set_data_type(GetOnnxDataType(dtype)); |
|
|
|
for (const auto &dim : shape) { |
|
|
|
tensor_proto->add_dims(dim); |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { |
|
|
|
|