|
|
|
@@ -29,11 +29,12 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
enum OpMergeMode { |
|
|
|
OP_MERGE_UNDEFINED = 0, // undefined behavior |
|
|
|
OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list |
|
|
|
OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv` |
|
|
|
OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm` |
|
|
|
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization` |
|
|
|
OP_MERGE_UNDEFINED = 0, // undefined behavior |
|
|
|
OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list |
|
|
|
OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv` |
|
|
|
OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm` |
|
|
|
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization` |
|
|
|
OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool` |
|
|
|
}; |
|
|
|
|
|
|
|
struct OpMergedInfo { |
|
|
|
@@ -233,6 +234,13 @@ OPERATOR_ONNX_CONVERT_DEFINE( |
|
|
|
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) |
|
|
|
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) |
|
|
|
|
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE( |
|
|
|
MaxPoolWithArgmax, MaxPool, |
|
|
|
OpNameInfo() |
|
|
|
.Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) |
|
|
|
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) |
|
|
|
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) |
|
|
|
|
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE( |
|
|
|
AvgPool, AveragePool, |
|
|
|
OpNameInfo() |
|
|
|
@@ -254,6 +262,7 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) { |
|
|
|
|
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(Flatten)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(MaxPool)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(AvgPool)()); |
|
|
|
|
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(Squeeze)()); |
|
|
|
@@ -328,6 +337,8 @@ class OnnxExporter { |
|
|
|
onnx::GraphProto *graph_proto); |
|
|
|
void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
void ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
|
|
|
|
void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *graph_proto); |
|
|
|
@@ -516,6 +527,12 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vecto |
|
|
|
op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM; |
|
|
|
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; |
|
|
|
op_merged_infos[cnode->input(1)].referred_count -= 1; |
|
|
|
} else if (cnode->IsApply(prim::kPrimTupleGetItem) && |
|
|
|
IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("MaxPoolWithArgmax")) && |
|
|
|
GetInt32Value(cnode->input(2)) == 0) { |
|
|
|
op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX; |
|
|
|
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; |
|
|
|
op_merged_infos[cnode->input(1)].referred_count -= 1; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -563,6 +580,9 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodeP |
|
|
|
case OP_MERGE_BATCH_NORM: |
|
|
|
ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto); |
|
|
|
break; |
|
|
|
case OP_MERGE_MAXPOOL_WITH_ARGMAX: |
|
|
|
ExportMergeMaxPoolWithArgmax(func_graph, cnode, node_map_ptr, graph_proto); |
|
|
|
break; |
|
|
|
default: |
|
|
|
ExportCNode(func_graph, cnode, node_map_ptr, graph_proto); |
|
|
|
break; |
|
|
|
@@ -811,6 +831,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CN |
|
|
|
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *const graph_proto) { |
|
|
|
auto maxpool_with_argmax_node = dyn_cast<CNode>(node->input(1)); |
|
|
|
|
|
|
|
PrimitivePtr prim_maxpool_with_argmax = |
|
|
|
dyn_cast<Primitive>((dyn_cast<ValueNode>(maxpool_with_argmax_node->input(0)))->value()); |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
for (size_t i = 1; i < maxpool_with_argmax_node->inputs().size(); i++) { |
|
|
|
inputs.push_back(maxpool_with_argmax_node->input(i)); |
|
|
|
} |
|
|
|
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_maxpool_with_argmax, inputs, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { |
|
|
|
if (node->inputs().size() != 2) { |
|
|
|
|