|
|
|
@@ -33,7 +33,7 @@ enum OpMergeMode { |
|
|
|
OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list |
|
|
|
OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv` |
|
|
|
OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm` |
|
|
|
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX Batch Normalization` |
|
|
|
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization` |
|
|
|
OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool` |
|
|
|
}; |
|
|
|
|
|
|
|
@@ -261,10 +261,16 @@ OPERATOR_ONNX_CONVERT_DEFINE( |
|
|
|
|
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(Gather, Gather, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(MakeTuple, SequenceConstruct, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(Maximum, Max, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(Transpose, Transpose, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(StridedSlice, Slice, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(Exp, Exp, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(ResizeNearestNeighbor, Resize, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(Softplus, Softplus, OpNameInfo()) |
|
|
|
OPERATOR_ONNX_CONVERT_DEFINE(Tanh, Tanh, OpNameInfo()) |
|
|
|
|
|
|
|
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name |
|
|
|
|
|
|
|
@@ -288,10 +294,14 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) { |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(MatMul)()); |
|
|
|
|
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(MakeTuple)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(Concat)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(RealDiv)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(Sub)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(Maximum)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(Exp)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(ResizeNearestNeighbor)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(Softplus)()); |
|
|
|
fn(OP_CONVERT_FUNCTION_NAME(Tanh)()); |
|
|
|
} |
|
|
|
|
|
|
|
class OpConvertRegistry { |
|
|
|
@@ -350,6 +360,15 @@ class OnnxExporter { |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
|
|
|
|
void ExportPrimTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimResizeNearestNeighbor(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimConcat(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
@@ -380,6 +399,7 @@ class OnnxExporter { |
|
|
|
onnx::GraphProto *const graph_proto); |
|
|
|
|
|
|
|
void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto); |
|
|
|
|
|
|
|
void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto); |
|
|
|
|
|
|
|
size_t AllocateNodeIndex() { return ++onnx_node_index_; } |
|
|
|
@@ -446,14 +466,12 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphP |
|
|
|
MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter."; |
|
|
|
} |
|
|
|
|
|
|
|
// set onnx input. |
|
|
|
if (!param_ptr->has_default()) { |
|
|
|
onnx::ValueInfoProto *input_proto = graph_proto->add_input(); |
|
|
|
input_proto->set_name(param_ptr->ToString()); |
|
|
|
SetValueInfoType(param_ptr, input_proto); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
// parameter with default value is an ONNX initializer |
|
|
|
onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); |
|
|
|
initializer_proto->set_name(param_ptr->ToString()); |
|
|
|
@@ -597,6 +615,7 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodeP |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
|
|
|
|
auto iter = op_merged_infos.find(cnode); |
|
|
|
// the node is not referenced by any other nodes, skip it |
|
|
|
if (iter == op_merged_infos.end()) { |
|
|
|
@@ -699,6 +718,237 @@ void OnnxExporter::ExportPrimReduce(const FuncGraphPtr & /*func_graph*/, const C |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportPrimTranspose(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *const graph_proto) { |
|
|
|
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); |
|
|
|
auto input_perm = node->input(2); |
|
|
|
|
|
|
|
auto node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[node] = node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
auto name = prim::kPrimTranspose->name(); |
|
|
|
node_proto->set_op_type(name); |
|
|
|
node_proto->add_output(std::to_string(node_idx)); |
|
|
|
node_proto->add_input(input_data); |
|
|
|
|
|
|
|
if (input_perm->isa<ValueNode>()) { |
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_name("perm"); |
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); |
|
|
|
auto perm_value = dyn_cast<ValueNode>(input_perm)->value(); |
|
|
|
auto int_ptr = dyn_cast<Int32Imm>(perm_value); |
|
|
|
if (int_ptr == nullptr) { |
|
|
|
auto tuple_ptr = dyn_cast<ValueTuple>(perm_value); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_ptr); |
|
|
|
for (size_t i = 0; i < tuple_ptr->size(); ++i) { |
|
|
|
attr_proto->add_ints(GetValue<int64_t>((*tuple_ptr)[i])); |
|
|
|
} |
|
|
|
} else { |
|
|
|
attr_proto->add_ints(int_ptr->value()); |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << name; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportPrimStridedSlice(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *const graph_proto) { |
|
|
|
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); |
|
|
|
auto begin = node->input(2); |
|
|
|
std::string name_begin; |
|
|
|
if (begin->isa<ValueNode>()) { |
|
|
|
auto const_node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[begin] = const_node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
name_begin = std::to_string(const_node_idx); |
|
|
|
node_proto->add_output(name_begin); |
|
|
|
|
|
|
|
node_proto->set_op_type("Constant"); |
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_name("starts"); |
|
|
|
|
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); |
|
|
|
ConvertTupleToTensor(dyn_cast<ValueNode>(begin)->value(), attr_proto->mutable_t()); |
|
|
|
} else { |
|
|
|
name_begin = GetNodeInputName(begin, node_map_ptr, graph_proto); |
|
|
|
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for StridedSlice."; |
|
|
|
} |
|
|
|
|
|
|
|
auto end = node->input(3); |
|
|
|
std::string name_end; |
|
|
|
if (end->isa<ValueNode>()) { |
|
|
|
auto const_node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[end] = const_node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
name_end = std::to_string(const_node_idx); |
|
|
|
node_proto->add_output(name_end); |
|
|
|
|
|
|
|
node_proto->set_op_type("Constant"); |
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_name("ends"); |
|
|
|
|
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); |
|
|
|
ConvertTupleToTensor(dyn_cast<ValueNode>(end)->value(), attr_proto->mutable_t()); |
|
|
|
} else { |
|
|
|
name_begin = GetNodeInputName(end, node_map_ptr, graph_proto); |
|
|
|
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for StridedSlice."; |
|
|
|
} |
|
|
|
|
|
|
|
auto x_shape = dyn_cast<abstract::Shape>(node->input(1)->Shape()); |
|
|
|
int size = x_shape->shape().size(); |
|
|
|
std::vector<int32_t> axes_value; |
|
|
|
ValuePtr axes_value_ptr = nullptr; |
|
|
|
for (int i = 0; i < size; ++i) { |
|
|
|
axes_value.push_back(i); |
|
|
|
} |
|
|
|
axes_value_ptr = MakeValue<std::vector<int32_t>>(axes_value); |
|
|
|
auto axes = NewValueNode(axes_value_ptr)->cast<AnfNodePtr>(); |
|
|
|
std::string name_axes; |
|
|
|
if (axes->isa<ValueNode>()) { |
|
|
|
auto const_node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[axes] = const_node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
name_axes = std::to_string(const_node_idx); |
|
|
|
node_proto->add_output(name_axes); |
|
|
|
|
|
|
|
node_proto->set_op_type("Constant"); |
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_name("axes"); |
|
|
|
|
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); |
|
|
|
ConvertTupleToTensor(dyn_cast<ValueNode>(axes)->value(), attr_proto->mutable_t()); |
|
|
|
} else { |
|
|
|
name_begin = GetNodeInputName(axes, node_map_ptr, graph_proto); |
|
|
|
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for StridedSlice."; |
|
|
|
} |
|
|
|
|
|
|
|
auto strides = node->input(4); |
|
|
|
std::string name_strides; |
|
|
|
if (strides->isa<ValueNode>()) { |
|
|
|
auto const_node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[strides] = const_node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
name_strides = std::to_string(const_node_idx); |
|
|
|
node_proto->add_output(name_strides); |
|
|
|
|
|
|
|
node_proto->set_op_type("Constant"); |
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_name("steps"); |
|
|
|
|
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); |
|
|
|
ConvertTupleToTensor(dyn_cast<ValueNode>(strides)->value(), attr_proto->mutable_t()); |
|
|
|
} else { |
|
|
|
name_begin = GetNodeInputName(strides, node_map_ptr, graph_proto); |
|
|
|
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for StridedSlice."; |
|
|
|
} |
|
|
|
|
|
|
|
auto node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[node] = node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
node_proto->set_op_type("Slice"); |
|
|
|
node_proto->add_output(std::to_string(node_idx)); |
|
|
|
node_proto->add_input(input_data); |
|
|
|
node_proto->add_input(name_begin); |
|
|
|
node_proto->add_input(name_end); |
|
|
|
node_proto->add_input(name_axes); |
|
|
|
node_proto->add_input(name_strides); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportPrimResizeNearestNeighbor(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *const graph_proto) { |
|
|
|
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); |
|
|
|
auto x_shape = dyn_cast<abstract::Shape>(node->input(1)->Shape()); |
|
|
|
|
|
|
|
AnfNodePtr op = node->input(0); |
|
|
|
auto op_value = dyn_cast<ValueNode>(op); |
|
|
|
auto prim = dyn_cast<Primitive>(op_value->value()); |
|
|
|
std::vector<int64_t> resize_size; |
|
|
|
|
|
|
|
auto tuple_ptr = dyn_cast<ValueTuple>(prim->GetAttr("size")); |
|
|
|
|
|
|
|
for (size_t i = 0; i < x_shape->shape().size() - 2; i++) { |
|
|
|
resize_size.push_back(x_shape->shape()[i]); |
|
|
|
} |
|
|
|
for (size_t i = 0; i < tuple_ptr->size(); i++) { |
|
|
|
ValuePtr elem = (*tuple_ptr)[i]; |
|
|
|
resize_size.push_back(dyn_cast<Int64Imm>(elem)->value()); |
|
|
|
} |
|
|
|
auto resize_size_ptr = MakeValue<std::vector<int64_t>>(resize_size); |
|
|
|
auto size = NewValueNode(resize_size_ptr)->cast<AnfNodePtr>(); |
|
|
|
std::string name_size; |
|
|
|
|
|
|
|
if (size->isa<ValueNode>()) { |
|
|
|
auto const_node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[size] = const_node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
name_size = std::to_string(const_node_idx); |
|
|
|
node_proto->add_output(name_size); |
|
|
|
|
|
|
|
node_proto->set_op_type("Constant"); |
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_name("sizes"); |
|
|
|
|
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); |
|
|
|
ConvertTupleToTensor(resize_size_ptr, attr_proto->mutable_t()); |
|
|
|
} else { |
|
|
|
name_size = GetNodeInputName(size, node_map_ptr, graph_proto); |
|
|
|
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for ResizeNearestNeighbor."; |
|
|
|
} |
|
|
|
|
|
|
|
auto node_idx = AllocateNodeIndex(); |
|
|
|
|
|
|
|
onnx::TensorProto *roi_initializer_proto = graph_proto->add_initializer(); |
|
|
|
auto roi_name = std::to_string(node_idx) + "roi_initializer"; |
|
|
|
roi_initializer_proto->set_name(roi_name); |
|
|
|
roi_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32)); |
|
|
|
roi_initializer_proto->add_dims(0); |
|
|
|
|
|
|
|
onnx::TensorProto *scales_initializer_proto = graph_proto->add_initializer(); |
|
|
|
auto scales_name = std::to_string(node_idx) + "scales_initializer"; |
|
|
|
scales_initializer_proto->set_name(scales_name); |
|
|
|
scales_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32)); |
|
|
|
scales_initializer_proto->add_dims(0); |
|
|
|
|
|
|
|
(*node_map_ptr)[node] = node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
|
|
|
|
node_proto->set_op_type("Resize"); |
|
|
|
node_proto->add_output(std::to_string(node_idx)); |
|
|
|
node_proto->add_input(input_data); |
|
|
|
node_proto->add_input(roi_name); |
|
|
|
node_proto->add_input(scales_name); |
|
|
|
node_proto->add_input(name_size); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportPrimConcat(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { |
|
|
|
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); |
|
|
|
auto node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[node] = node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
|
|
|
|
AnfNodePtr op = node->input(0); |
|
|
|
auto op_value = dyn_cast<ValueNode>(op); |
|
|
|
auto prim = dyn_cast<Primitive>(op_value->value()); |
|
|
|
auto input_node = node->input(1)->cast<CNodePtr>(); |
|
|
|
|
|
|
|
if (input_node->IsApply(prim::kPrimMakeTuple)) { |
|
|
|
node_proto->set_op_type("ConcatFromSequence"); |
|
|
|
} else { |
|
|
|
node_proto->set_op_type("Concat"); |
|
|
|
} |
|
|
|
|
|
|
|
// set attr axis |
|
|
|
onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); |
|
|
|
onnx_attr_proto->set_name("axis"); |
|
|
|
SetAttrValueToProto<Int64Imm>(prim->GetAttr("axis"), onnx::AttributeProto_AttributeType_INT, onnx_attr_proto, prim); |
|
|
|
node_proto->add_output(std::to_string(node_idx)); |
|
|
|
node_proto->add_input(input_data); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { |
|
|
|
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); |
|
|
|
@@ -746,7 +996,6 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CN |
|
|
|
attr_proto->set_name("axes"); |
|
|
|
attr_proto->add_ints(1); |
|
|
|
attr_proto->add_ints(2); |
|
|
|
|
|
|
|
node_proto->add_input(input_slope); |
|
|
|
input_slope = std::to_string(node_idx); |
|
|
|
} |
|
|
|
@@ -958,6 +1207,19 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n |
|
|
|
return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
if (node->IsApply(prim::kPrimTranspose)) { |
|
|
|
return ExportPrimTranspose(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
if (node->IsApply(prim::kPrimStridedSlice)) { |
|
|
|
return ExportPrimStridedSlice(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
if (node->IsApply(prim::kPrimResizeNearestNeighbor)) { |
|
|
|
return ExportPrimResizeNearestNeighbor(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
if (node->IsApply(prim::kPrimConcat)) { |
|
|
|
return ExportPrimConcat(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
// MindSpore Cast(x, T) --> ONNX Cast[to=T](x) |
|
|
|
if (node->IsApply(prim::kPrimCast)) { |
|
|
|
return ExportPrimCast(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
@@ -997,7 +1259,6 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n |
|
|
|
if (inputs.size() < 1) { |
|
|
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr op = inputs[0]; |
|
|
|
std::vector<AnfNodePtr> op_inputs; |
|
|
|
// first process node input 1,2,..., since when node input is a ValueNode, here need to create a Constant Operator |
|
|
|
@@ -1007,6 +1268,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n |
|
|
|
} |
|
|
|
} |
|
|
|
auto op_value = dyn_cast<ValueNode>(op); |
|
|
|
|
|
|
|
if (op_value == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Need to support node op type " << op->type_name(); |
|
|
|
} |
|
|
|
@@ -1015,9 +1277,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n |
|
|
|
MS_LOG(EXCEPTION) << "Need to support node op type " << op_value->value()->type_name(); |
|
|
|
} |
|
|
|
|
|
|
|
if (!IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) { |
|
|
|
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); |
|
|
|
} |
|
|
|
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
@@ -1124,6 +1384,7 @@ void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNode |
|
|
|
std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &orig_node, std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *const graph_proto) { |
|
|
|
auto node = GetRealInput(orig_node); |
|
|
|
|
|
|
|
if (node->isa<CNode>()) { |
|
|
|
auto iter = node_map_ptr->find(node); |
|
|
|
if (iter == node_map_ptr->end()) { |
|
|
|
|