Browse Source

Pre Merge pull request !16118 from liuyang/add_op_export_onnx

pull/16118/MERGE
liuyang Gitee 5 years ago
parent
commit
f74fd2d1b2
2 changed files with 272 additions and 10 deletions
  1. +271
    -10
      mindspore/ccsrc/transform/express_ir/onnx_exporter.cc
  2. +1
    -0
      mindspore/core/base/core_ops.h

+ 271
- 10
mindspore/ccsrc/transform/express_ir/onnx_exporter.cc View File

@@ -33,7 +33,7 @@ enum OpMergeMode {
OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list
OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv`
OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm`
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX Batch Normalization`
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization`
OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool`
};

@@ -261,10 +261,16 @@ OPERATOR_ONNX_CONVERT_DEFINE(

OPERATOR_ONNX_CONVERT_DEFINE(Gather, Gather, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(MakeTuple, SequenceConstruct, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Maximum, Max, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Transpose, Transpose, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(StridedSlice, Slice, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Exp, Exp, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(ResizeNearestNeighbor, Resize, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Softplus, Softplus, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Tanh, Tanh, OpNameInfo())

#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name

@@ -288,10 +294,14 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
fn(OP_CONVERT_FUNCTION_NAME(MatMul)());

fn(OP_CONVERT_FUNCTION_NAME(MakeTuple)());
fn(OP_CONVERT_FUNCTION_NAME(Concat)());
fn(OP_CONVERT_FUNCTION_NAME(RealDiv)());
fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)());
fn(OP_CONVERT_FUNCTION_NAME(Sub)());
fn(OP_CONVERT_FUNCTION_NAME(Maximum)());
fn(OP_CONVERT_FUNCTION_NAME(Exp)());
fn(OP_CONVERT_FUNCTION_NAME(ResizeNearestNeighbor)());
fn(OP_CONVERT_FUNCTION_NAME(Softplus)());
fn(OP_CONVERT_FUNCTION_NAME(Tanh)());
}

class OpConvertRegistry {
@@ -350,6 +360,15 @@ class OnnxExporter {
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);

void ExportPrimTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimResizeNearestNeighbor(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimConcat(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
@@ -380,6 +399,7 @@ class OnnxExporter {
onnx::GraphProto *const graph_proto);

void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto);

void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto);

size_t AllocateNodeIndex() { return ++onnx_node_index_; }
@@ -446,14 +466,12 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphP
MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
}

// set onnx input.
if (!param_ptr->has_default()) {
onnx::ValueInfoProto *input_proto = graph_proto->add_input();
input_proto->set_name(param_ptr->ToString());
SetValueInfoType(param_ptr, input_proto);
continue;
}

// parameter with default value is an ONNX initializer
onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
initializer_proto->set_name(param_ptr->ToString());
@@ -597,6 +615,7 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodeP
continue;
}
auto cnode = node->cast<CNodePtr>();

auto iter = op_merged_infos.find(cnode);
// the node is not referenced by any other nodes, skip it
if (iter == op_merged_infos.end()) {
@@ -699,6 +718,237 @@ void OnnxExporter::ExportPrimReduce(const FuncGraphPtr & /*func_graph*/, const C
}
}

void OnnxExporter::ExportPrimTranspose(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto input_perm = node->input(2);

auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
auto name = prim::kPrimTranspose->name();
node_proto->set_op_type(name);
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_data);

if (input_perm->isa<ValueNode>()) {
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("perm");
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
auto perm_value = dyn_cast<ValueNode>(input_perm)->value();
auto int_ptr = dyn_cast<Int32Imm>(perm_value);
if (int_ptr == nullptr) {
auto tuple_ptr = dyn_cast<ValueTuple>(perm_value);
MS_EXCEPTION_IF_NULL(tuple_ptr);
for (size_t i = 0; i < tuple_ptr->size(); ++i) {
attr_proto->add_ints(GetValue<int64_t>((*tuple_ptr)[i]));
}
} else {
attr_proto->add_ints(int_ptr->value());
}
} else {
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << name;
}
}

void OnnxExporter::ExportPrimStridedSlice(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto begin = node->input(2);
std::string name_begin;
if (begin->isa<ValueNode>()) {
auto const_node_idx = AllocateNodeIndex();
(*node_map_ptr)[begin] = const_node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
name_begin = std::to_string(const_node_idx);
node_proto->add_output(name_begin);

node_proto->set_op_type("Constant");
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("starts");

attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
ConvertTupleToTensor(dyn_cast<ValueNode>(begin)->value(), attr_proto->mutable_t());
} else {
name_begin = GetNodeInputName(begin, node_map_ptr, graph_proto);
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for StridedSlice.";
}

auto end = node->input(3);
std::string name_end;
if (end->isa<ValueNode>()) {
auto const_node_idx = AllocateNodeIndex();
(*node_map_ptr)[end] = const_node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
name_end = std::to_string(const_node_idx);
node_proto->add_output(name_end);

node_proto->set_op_type("Constant");
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("ends");

attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
ConvertTupleToTensor(dyn_cast<ValueNode>(end)->value(), attr_proto->mutable_t());
} else {
name_begin = GetNodeInputName(end, node_map_ptr, graph_proto);
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for StridedSlice.";
}

auto x_shape = dyn_cast<abstract::Shape>(node->input(1)->Shape());
int size = x_shape->shape().size();
std::vector<int32_t> axes_value;
ValuePtr axes_value_ptr = nullptr;
for (int i = 0; i < size; ++i) {
axes_value.push_back(i);
}
axes_value_ptr = MakeValue<std::vector<int32_t>>(axes_value);
auto axes = NewValueNode(axes_value_ptr)->cast<AnfNodePtr>();
std::string name_axes;
if (axes->isa<ValueNode>()) {
auto const_node_idx = AllocateNodeIndex();
(*node_map_ptr)[axes] = const_node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
name_axes = std::to_string(const_node_idx);
node_proto->add_output(name_axes);

node_proto->set_op_type("Constant");
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("axes");

attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
ConvertTupleToTensor(dyn_cast<ValueNode>(axes)->value(), attr_proto->mutable_t());
} else {
name_begin = GetNodeInputName(axes, node_map_ptr, graph_proto);
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for StridedSlice.";
}

auto strides = node->input(4);
std::string name_strides;
if (strides->isa<ValueNode>()) {
auto const_node_idx = AllocateNodeIndex();
(*node_map_ptr)[strides] = const_node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
name_strides = std::to_string(const_node_idx);
node_proto->add_output(name_strides);

node_proto->set_op_type("Constant");
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("steps");

attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
ConvertTupleToTensor(dyn_cast<ValueNode>(strides)->value(), attr_proto->mutable_t());
} else {
name_begin = GetNodeInputName(strides, node_map_ptr, graph_proto);
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for StridedSlice.";
}

auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type("Slice");
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_data);
node_proto->add_input(name_begin);
node_proto->add_input(name_end);
node_proto->add_input(name_axes);
node_proto->add_input(name_strides);
}

void OnnxExporter::ExportPrimResizeNearestNeighbor(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto x_shape = dyn_cast<abstract::Shape>(node->input(1)->Shape());

AnfNodePtr op = node->input(0);
auto op_value = dyn_cast<ValueNode>(op);
auto prim = dyn_cast<Primitive>(op_value->value());
std::vector<int64_t> resize_size;

auto tuple_ptr = dyn_cast<ValueTuple>(prim->GetAttr("size"));

for (size_t i = 0; i < x_shape->shape().size() - 2; i++) {
resize_size.push_back(x_shape->shape()[i]);
}
for (size_t i = 0; i < tuple_ptr->size(); i++) {
ValuePtr elem = (*tuple_ptr)[i];
resize_size.push_back(dyn_cast<Int64Imm>(elem)->value());
}
auto resize_size_ptr = MakeValue<std::vector<int64_t>>(resize_size);
auto size = NewValueNode(resize_size_ptr)->cast<AnfNodePtr>();
std::string name_size;

if (size->isa<ValueNode>()) {
auto const_node_idx = AllocateNodeIndex();
(*node_map_ptr)[size] = const_node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
name_size = std::to_string(const_node_idx);
node_proto->add_output(name_size);

node_proto->set_op_type("Constant");
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("sizes");

attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
ConvertTupleToTensor(resize_size_ptr, attr_proto->mutable_t());
} else {
name_size = GetNodeInputName(size, node_map_ptr, graph_proto);
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for ResizeNearestNeighbor.";
}

auto node_idx = AllocateNodeIndex();

onnx::TensorProto *roi_initializer_proto = graph_proto->add_initializer();
auto roi_name = std::to_string(node_idx) + "roi_initializer";
roi_initializer_proto->set_name(roi_name);
roi_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
roi_initializer_proto->add_dims(0);

onnx::TensorProto *scales_initializer_proto = graph_proto->add_initializer();
auto scales_name = std::to_string(node_idx) + "scales_initializer";
scales_initializer_proto->set_name(scales_name);
scales_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
scales_initializer_proto->add_dims(0);

(*node_map_ptr)[node] = node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();

node_proto->set_op_type("Resize");
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_data);
node_proto->add_input(roi_name);
node_proto->add_input(scales_name);
node_proto->add_input(name_size);
}

void OnnxExporter::ExportPrimConcat(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();

AnfNodePtr op = node->input(0);
auto op_value = dyn_cast<ValueNode>(op);
auto prim = dyn_cast<Primitive>(op_value->value());
auto input_node = node->input(1)->cast<CNodePtr>();

if (input_node->IsApply(prim::kPrimMakeTuple)) {
node_proto->set_op_type("ConcatFromSequence");
} else {
node_proto->set_op_type("Concat");
}

// set attr axis
onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute();
onnx_attr_proto->set_name("axis");
SetAttrValueToProto<Int64Imm>(prim->GetAttr("axis"), onnx::AttributeProto_AttributeType_INT, onnx_attr_proto, prim);
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_data);
}

void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
@@ -746,7 +996,6 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CN
attr_proto->set_name("axes");
attr_proto->add_ints(1);
attr_proto->add_ints(2);

node_proto->add_input(input_slope);
input_slope = std::to_string(node_idx);
}
@@ -958,6 +1207,19 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto);
}

if (node->IsApply(prim::kPrimTranspose)) {
return ExportPrimTranspose(func_graph, node, node_map_ptr, graph_proto);
}
if (node->IsApply(prim::kPrimStridedSlice)) {
return ExportPrimStridedSlice(func_graph, node, node_map_ptr, graph_proto);
}
if (node->IsApply(prim::kPrimResizeNearestNeighbor)) {
return ExportPrimResizeNearestNeighbor(func_graph, node, node_map_ptr, graph_proto);
}
if (node->IsApply(prim::kPrimConcat)) {
return ExportPrimConcat(func_graph, node, node_map_ptr, graph_proto);
}

// MindSpore Cast(x, T) --> ONNX Cast[to=T](x)
if (node->IsApply(prim::kPrimCast)) {
return ExportPrimCast(func_graph, node, node_map_ptr, graph_proto);
@@ -997,7 +1259,6 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
if (inputs.size() < 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}

AnfNodePtr op = inputs[0];
std::vector<AnfNodePtr> op_inputs;
// first process node input 1,2,..., since when node input is a ValueNode, here need to create a Constant Operator
@@ -1007,6 +1268,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
}
}
auto op_value = dyn_cast<ValueNode>(op);

if (op_value == nullptr) {
MS_LOG(EXCEPTION) << "Need to support node op type " << op->type_name();
}
@@ -1015,9 +1277,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
MS_LOG(EXCEPTION) << "Need to support node op type " << op_value->value()->type_name();
}

if (!IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
}
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
}

size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map<AnfNodePtr, size_t> *node_map_ptr,
@@ -1124,6 +1384,7 @@ void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNode
std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &orig_node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto node = GetRealInput(orig_node);

if (node->isa<CNode>()) {
auto iter = node_map_ptr->find(node);
if (iter == node_map_ptr->end()) {


+ 1
- 0
mindspore/core/base/core_ops.h View File

@@ -190,6 +190,7 @@ inline const PrimitivePtr kPrimRank = std::make_shared<Primitive>("Rank");
inline const PrimitivePtr kPrimResizeBilinear = std::make_shared<Primitive>("ResizeBilinear");
inline const PrimitivePtr kPrimResizeGrad = std::make_shared<Primitive>("ResizeGrad");
inline const PrimitivePtr kPrimSort = std::make_shared<Primitive>("Sort");
inline const PrimitivePtr kPrimResizeNearestNeighbor = std::make_shared<Primitive>("ResizeNearestNeighbor");

// NN
inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam");


Loading…
Cancel
Save