Browse Source

ONNX converter: implement CRNN export

feature/build-system-rewrite
Alexander Malyshev 4 years ago
parent
commit
e9d771dab7
1 changed file with 218 additions and 10 deletions
  1. +218
    -10
      mindspore/ccsrc/transform/express_ir/onnx_exporter.cc

+ 218
- 10
mindspore/ccsrc/transform/express_ir/onnx_exporter.cc View File

@@ -28,6 +28,7 @@
#include "base/core_ops.h"
#include "proto/onnx.pb.h"
#include "utils/check_convert_utils.h"
#include "utils/ms_context.h"

namespace mindspore {
const int ONNX_VERSION = 11;
@@ -309,18 +310,22 @@ void AddClipOp(const std::string &input, const std::string &output, float min, f
AddOp("Clip", {input, min_input_name, max_input_name}, {output}, graph_proto);
}

void AddSliceOp(const std::string &input, const std::string &output, int64_t start, int64_t end, int64_t axis,
void AddSliceOp(const std::string &input, const std::string &output, const std::vector<int64_t> &start,
const std::vector<int64_t> &end, const std::vector<int64_t> &axis, const std::vector<int64_t> &step,
onnx::GraphProto *graph_proto) {
auto starts_name = output + "__starts_initializer";
AddInt64Tensor1DInitializer(starts_name, {start}, graph_proto);
AddInt64Tensor1DInitializer(starts_name, start, graph_proto);

auto ends_name = output + "__ends_initializer";
AddInt64Tensor1DInitializer(ends_name, {end}, graph_proto);
AddInt64Tensor1DInitializer(ends_name, end, graph_proto);

auto axes_name = output + "__axes_initializer";
AddInt64Tensor1DInitializer(axes_name, {axis}, graph_proto);
AddInt64Tensor1DInitializer(axes_name, axis, graph_proto);

AddOp("Slice", {input, starts_name, ends_name, axes_name}, {output}, graph_proto);
auto steps_name = output + "__steps_initializer";
AddInt64Tensor1DInitializer(steps_name, step, graph_proto);

AddOp("Slice", {input, starts_name, ends_name, axes_name, steps_name}, {output}, graph_proto);
}

void AddSplitOp(const std::string &input, const std::vector<std::string> &outputs, const std::vector<int64_t> &split,
@@ -510,7 +515,7 @@ void ConvertBoxesToXyxy(const std::string &centerpoints, const std::string &dime
void ClipPointsComponent(const std::string &points, const std::string &clipped, float max, int64_t component_idx,
onnx::TensorProto_DataType type, onnx::GraphProto *graph_proto) {
auto res_to_clip_name = clipped + "__clip";
AddSliceOp(points, res_to_clip_name, component_idx, component_idx + 1, 1, graph_proto);
AddSliceOp(points, res_to_clip_name, {component_idx}, {component_idx + 1}, {1}, {1}, graph_proto);
AddClipOp(res_to_clip_name, clipped, 0.0f, max, type, graph_proto);
}

@@ -747,6 +752,13 @@ OPERATOR_ONNX_CONVERT_DEFINE(Select, Where, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Log, Log, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Greater, Greater, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(LogicalAnd, And, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(ReverseSequence, ReverseSequence,
OpNameInfo()
.Attr("seq_dim", "time_axis", onnx::AttributeProto_AttributeType_INT,
SetAttrValueToProto<Int64Imm>)
.Attr("batch_dim", "batch_axis", onnx::AttributeProto_AttributeType_INT,
SetAttrValueToProto<Int64Imm>)
.CastInput(1, onnx::TensorProto_DataType_INT32, onnx::TensorProto_DataType_INT64))

#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name

@@ -791,6 +803,7 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
fn(OP_CONVERT_FUNCTION_NAME(Log)());
fn(OP_CONVERT_FUNCTION_NAME(Greater)());
fn(OP_CONVERT_FUNCTION_NAME(LogicalAnd)());
fn(OP_CONVERT_FUNCTION_NAME(ReverseSequence)());
}

class OpConvertRegistry {
@@ -912,6 +925,10 @@ class OnnxExporter {
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimSqueeze(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimLSTM(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportPrimReverseV2(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
@@ -2111,7 +2128,7 @@ void OnnxExporter::ExportPrimNMSWithMask(const FuncGraphPtr &, const CNodePtr &n
auto max_output_boxes_to_squeeze_name = boxes_count_name + "_to_reshape";
auto input_shape_name = node_name + "input_shape";
AddOp("Shape", {bboxes_input_name}, {input_shape_name}, graph_proto);
AddSliceOp(input_shape_name, max_output_boxes_to_squeeze_name, 0, 1, 0, graph_proto);
AddSliceOp(input_shape_name, max_output_boxes_to_squeeze_name, {0}, {1}, {0}, {1}, graph_proto);
AddReshapeOp(max_output_boxes_to_squeeze_name, boxes_count_name, {}, graph_proto);

auto scores_name = node_name + "scores";
@@ -2120,7 +2137,7 @@ void OnnxExporter::ExportPrimNMSWithMask(const FuncGraphPtr &, const CNodePtr &n
auto scores_to_flatten_name = scores_name + "_to_reshape";
auto descending_order_name = node_name + "descending_indices";
const int BBOX_NUM_EL = 4;
AddSliceOp(bboxes_input_name, scores_to_flatten_name, BBOX_NUM_EL, BBOX_NUM_EL + 1, 1, graph_proto);
AddSliceOp(bboxes_input_name, scores_to_flatten_name, {BBOX_NUM_EL}, {BBOX_NUM_EL + 1}, {1}, {1}, graph_proto);
AddReshapeOp(scores_to_flatten_name, flat_scores_name, {-1}, graph_proto);
AddOp("TopK", {flat_scores_name, max_output_boxes_to_squeeze_name}, {sorted_scores_name, descending_order_name},
graph_proto);
@@ -2132,7 +2149,7 @@ void OnnxExporter::ExportPrimNMSWithMask(const FuncGraphPtr &, const CNodePtr &n
graph_proto); // Output 0: boxes
auto boxes_name = node_name + "boxes";
auto boxes_to_reshape_name = boxes_name + "_to_reshape";
AddSliceOp(selected_boxes_output_name, boxes_to_reshape_name, 0, BBOX_NUM_EL, 1, graph_proto);
AddSliceOp(selected_boxes_output_name, boxes_to_reshape_name, {0}, {BBOX_NUM_EL}, {1}, {1}, graph_proto);
AddReshapeOp(boxes_to_reshape_name, boxes_name, {1, -1, BBOX_NUM_EL}, graph_proto);

if (onnx_input_type == onnx::TensorProto_DataType_FLOAT16) {
@@ -2156,7 +2173,8 @@ void OnnxExporter::ExportPrimNMSWithMask(const FuncGraphPtr &, const CNodePtr &n
auto flat_indices_name = node_name + "flat_indices";
auto flat_indices_to_squeeze_name = flat_indices_name + "__reshape";
const int BOX_INDEX_POS = 2;
AddSliceOp(selected_indices_name, flat_indices_to_squeeze_name, BOX_INDEX_POS, BOX_INDEX_POS + 1, 1, graph_proto);
AddSliceOp(selected_indices_name, flat_indices_to_squeeze_name, {BOX_INDEX_POS}, {BOX_INDEX_POS + 1}, {1}, {1},
graph_proto);
AddReshapeOp(flat_indices_to_squeeze_name, flat_indices_name, {-1}, graph_proto);

auto zero_name = node_name + "zero_initializer";
@@ -2530,6 +2548,194 @@ void OnnxExporter::ExportPrimSqueeze(const FuncGraphPtr &, const CNodePtr &node,
}
}

// Repacks one flat LSTM weight (or bias) tensor into the gate order ONNX expects.
// The tensor is reshaped to `output_shape`, split into four equal gate chunks along
// axis 1 (named i, f, c, o in split order), and re-concatenated as i, o, f, c —
// the gate order of the ONNX LSTM operator.
void MakeLSTMWeight(const std::string &input, const std::string &output, const std::vector<int64_t> &output_shape,
                    onnx::GraphProto *graph_proto) {
  const int64_t num_gates = 4;
  // Each gate chunk along axis 1 is one quarter of the packed dimension.
  int64_t hidden_size = output_shape[1] / num_gates;

  auto reshaped_name = output + "__split";
  AddReshapeOp(input, reshaped_name, output_shape, graph_proto);

  auto gate_i_name = output + "__concat_i";
  auto gate_o_name = output + "__concat_o";
  auto gate_f_name = output + "__concat_f";
  auto gate_c_name = output + "__concat_c";

  // Split in the source gate order (i, f, c, o) ...
  AddSplitOp(reshaped_name, {gate_i_name, gate_f_name, gate_c_name, gate_o_name},
             std::vector<int64_t>(num_gates, hidden_size), 1, graph_proto);

  // ... and concatenate back in ONNX gate order (i, o, f, c).
  AddConcatOp({gate_i_name, gate_o_name, gate_f_name, gate_c_name}, output, 1, graph_proto);
}

// Splits MindSpore's single packed LSTM weight tensor into the separate inputs the ONNX
// LSTM operator takes: W (input weights, written to onnx_input_weights_name), R (recurrent
// weights, onnx_hidden_weights_name) and, when has_bias is set, B (onnx_bias_name).
// The packed layout is device-dependent: on GPU the flat tensor is
// [input_weights, hidden_weights, input_bias, hidden_bias]; on CPU it carries only a single
// combined bias, so the hidden-bias half of B is synthesized as zeros.
// Only single-layer, unidirectional LSTM on CPU/GPU is supported; anything else raises.
void ExportLSTMWeights(const CNodePtr &node, const std::string &node_name, const std::string &weights_name,
onnx::TensorProto_DataType dtype, const std::string &onnx_input_weights_name,
const std::string &onnx_hidden_weights_name, const std::string &onnx_bias_name,
onnx::GraphProto *graph_proto) {
auto input_size = GetOpAttribute<int64_t>(node, "input_size");
auto hidden_size = GetOpAttribute<int64_t>(node, "hidden_size");
auto num_layers = GetOpAttribute<int64_t>(node, "num_layers");
auto has_bias = GetOpAttribute<bool>(node, "has_bias");
auto bidirectional = GetOpAttribute<bool>(node, "bidirectional");
// bool promotes to int: 2 directions when bidirectional, else 1.
auto num_dir = 1 + bidirectional;
auto num_gates = 4;
auto gate_size = num_gates * hidden_size;

if (num_layers != 1) {
MS_LOG(EXCEPTION) << "Converter for multilayer LSTM is not implemented";
}
if (bidirectional) {
MS_LOG(EXCEPTION) << "Bidirectional mode for P.LSTM is not implemented";
}
// The packed-weight layout differs per device, so exporting requires knowing the target.
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto target_device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (target_device != "CPU" && target_device != "GPU") {
MS_LOG(EXCEPTION) << "Unsupported target device: " << target_device;
}

auto input_weights_name = node_name + "_input_weights";
auto hidden_weights_name = node_name + "_hidden_weights";
auto input_bias_name = node_name + "_input_bias";
auto hidden_bias_name = node_name + "_hidden_bias";

// Segment sizes of the flat weight tensor, in packing order.
std::vector<int64_t> split_sizes = {input_size * gate_size, hidden_size * gate_size};
std::vector<std::string> split_outputs = {input_weights_name, hidden_weights_name};
if (has_bias) {
if (target_device == "GPU") {
// GPU packs two separate bias vectors (input and hidden).
split_sizes.insert(split_sizes.end(), {gate_size, gate_size});
split_outputs.insert(split_outputs.end(), {input_bias_name, hidden_bias_name});
} else if (target_device == "CPU") {
// CPU packs only one combined bias vector.
split_sizes.push_back(gate_size);
split_outputs.push_back(input_bias_name);
} else {
// Unreachable: target_device was validated above.
MS_LOG(EXCEPTION) << "Impossible branch";
}
}
AddSplitOp(weights_name, split_outputs, split_sizes, 0, graph_proto);

// Reorder gates into ONNX layout for W and R.
MakeLSTMWeight(input_weights_name, onnx_input_weights_name, {num_dir, gate_size, input_size}, graph_proto);
MakeLSTMWeight(hidden_weights_name, onnx_hidden_weights_name, {num_dir, gate_size, hidden_size}, graph_proto);
if (has_bias) {
auto onnx_input_bias_name = node_name + "_onnx_input_bias";
auto onnx_hidden_bias_name = node_name + "_onnx_hidden_bias";
if (target_device == "GPU") {
MakeLSTMWeight(input_bias_name, onnx_input_bias_name, {num_dir, gate_size}, graph_proto);
MakeLSTMWeight(hidden_bias_name, onnx_hidden_bias_name, {num_dir, gate_size}, graph_proto);
} else if (target_device == "CPU") {
MakeLSTMWeight(input_bias_name, onnx_input_bias_name, {num_dir, gate_size}, graph_proto);
// No hidden bias in the CPU layout: emit a zero tensor of the same shape so the
// ONNX B input still has both [Wb, Rb] halves.
auto bias_shape_name = node_name + "_bias_shape";
AddOp("Shape", {onnx_input_bias_name}, {bias_shape_name}, graph_proto);
onnx::TensorProto *zero_padding = AddConstantOfShapeOp(bias_shape_name, onnx_hidden_bias_name, graph_proto);
zero_padding->set_data_type(dtype);
if (dtype == onnx::TensorProto_DataType_FLOAT16) {
zero_padding->add_int32_data(0); // float 0 and int 0 have identical representations
} else if (dtype == onnx::TensorProto_DataType_FLOAT) {
zero_padding->add_float_data(0.0f);
} else {
MS_LOG(EXCEPTION) << "Unsupported type: " << dtype;
}
} else {
// Unreachable: target_device was validated above.
MS_LOG(EXCEPTION) << "Impossible branch";
}
// ONNX B is the concatenation [Wb, Rb] along the per-direction axis.
AddConcatOp({onnx_input_bias_name, onnx_hidden_bias_name}, onnx_bias_name, 1, graph_proto);
}
}

// Exports MindSpore P.LSTM as an ONNX LSTM node.
// Inputs 1-3 of the CNode are x, init_h, init_c; input 4 is the flat packed weight tensor,
// which ExportLSTMWeights splits into the ONNX W/R/B inputs. The ONNX LSTM output Y has
// shape [seq_len, num_dir, batch, hidden], so it is transposed and reshaped back to the
// MindSpore layout [seq_len, batch, num_dir * hidden] for output 0; outputs 1 and 2
// (final h and c) are passed through directly.
void OnnxExporter::ExportPrimLSTM(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto node_idx = AllocateNodeIndex();
auto node_name = std::to_string(node_idx);
(*node_map_ptr)[node] = node_idx;

// MS inputs
auto x_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
auto init_h_input_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
auto init_c_input_name = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);

auto hidden_size = GetOpAttribute<int64_t>(node, "hidden_size");
auto has_bias = GetOpAttribute<bool>(node, "has_bias");
auto bidirectional = GetOpAttribute<bool>(node, "bidirectional");
std::string direction = bidirectional ? "bidirectional" : "forward";
// x is laid out [seq_len, batch, input_size]; the static shape is needed for the final reshape.
auto x_input_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
auto seq_len = x_input_shape[0];
auto batch_size = x_input_shape[1];
// bool promotes to int: 2 directions when bidirectional, else 1.
auto num_dir = 1 + bidirectional;

auto weights_name = GetNodeInputName(node->input(kFourNum), node_map_ptr, graph_proto);
auto dtype = GetOutputType(node->input(kOneNum));
auto onnx_input_weights_name = node_name + "_onnx_input_weights";
auto onnx_hidden_weights_name = node_name + "_onnx_hidden_weights";
auto onnx_bias_name = node_name + "_onnx_bias";

// Emit the ops that unpack the flat MS weight tensor into ONNX W/R/B.
ExportLSTMWeights(node, node_name, weights_name, dtype, onnx_input_weights_name, onnx_hidden_weights_name,
onnx_bias_name, graph_proto);

// Create LSTM node
onnx::NodeProto *lstm_node_proto = graph_proto->add_node();
lstm_node_proto->set_op_type("LSTM");
lstm_node_proto->add_input(x_input_name);
lstm_node_proto->add_input(onnx_input_weights_name);
lstm_node_proto->add_input(onnx_hidden_weights_name);
// Empty input name means "optional input omitted" in ONNX.
lstm_node_proto->add_input(has_bias ? onnx_bias_name : "");
lstm_node_proto->add_input(""); // seqlens
lstm_node_proto->add_input(init_h_input_name);
lstm_node_proto->add_input(init_c_input_name);

// Y needs post-processing, so it gets an intermediate name; Y_h and Y_c map
// straight to MS outputs 1 and 2.
auto Y_output_name = node_name + "_Y";
lstm_node_proto->add_output(Y_output_name);
lstm_node_proto->add_output(MakeOutputName(node_name, kOneNum));
lstm_node_proto->add_output(MakeOutputName(node_name, kTwoNum));

onnx::AttributeProto *hidden_size_proto = lstm_node_proto->add_attribute();
hidden_size_proto->set_name("hidden_size");
hidden_size_proto->set_type(onnx::AttributeProto_AttributeType_INT);
hidden_size_proto->set_i(hidden_size);

onnx::AttributeProto *direction_proto = lstm_node_proto->add_attribute();
direction_proto->set_name("direction");
direction_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
direction_proto->set_s(direction);

// Transpose 1st output of the LSTM node:
// [seq_len, num_dir, batch, hidden] -> [seq_len, batch, num_dir, hidden]
onnx::NodeProto *transpose_node_proto = graph_proto->add_node();
auto transpose_node_name = node_name + "_Y_transposed";
transpose_node_proto->set_name(transpose_node_name);
transpose_node_proto->set_op_type("Transpose");
transpose_node_proto->add_input(Y_output_name);
transpose_node_proto->add_output(transpose_node_name);

onnx::AttributeProto *perm_proto = transpose_node_proto->add_attribute();
perm_proto->set_name("perm");
perm_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
perm_proto->add_ints(kZeroNum);
perm_proto->add_ints(kTwoNum);
perm_proto->add_ints(kOneNum);
perm_proto->add_ints(kThreeNum);

// Reshape to the MS output layout [seq_len, batch, num_dir * hidden], fusing the
// direction axis into the feature axis.
auto output_name = MakeOutputName(node_name, kZeroNum);
AddReshapeOp(transpose_node_name, output_name, {seq_len, batch_size, num_dir * hidden_size}, graph_proto);
}

// Exports ReverseV2 as an ONNX Slice with negative steps: for each reversed axis the
// slice walks from index -1 down to -(dim + 1) with step -1, which reverses that axis
// while leaving all other axes untouched.
void OnnxExporter::ExportPrimReverseV2(const FuncGraphPtr &, const CNodePtr &node,
                                       std::map<AnfNodePtr, size_t> *node_map_ptr,
                                       onnx::GraphProto *const graph_proto) {
  auto node_idx = AllocateNodeIndex();
  (*node_map_ptr)[node] = node_idx;
  auto input = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto output = std::to_string(node_idx);

  auto axes_ptr = GetOpAttributePtr<ValueSequeue>(node, "axis");
  auto axes = GetValue<std::vector<int64_t>>(axes_ptr);
  auto input_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();

  std::vector<int64_t> starts(axes.size(), -1);
  std::vector<int64_t> steps(axes.size(), -1);
  std::vector<int64_t> ends;
  ends.reserve(axes.size());
  for (auto axis : axes) {
    // -(dim + 1) is one step past the first element when iterating backwards.
    ends.push_back(-input_shape.at(axis) - 1);
  }

  AddSliceOp(input, output, starts, ends, axes, steps, graph_proto);
}

void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
using ExportFunc = std::function<void(OnnxExporter *, const FuncGraphPtr &, const CNodePtr &,
@@ -2566,6 +2772,8 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
{prim::kPrimExpandDims, &OnnxExporter::ExportPrimExpandDims},
{prim::kPrimBatchMatMul, &OnnxExporter::ExportPrimBatchMatMul},
{prim::kPrimGeLU, &OnnxExporter::ExportPrimGeLU},
{prim::kPrimLstm, &OnnxExporter::ExportPrimLSTM},
{prim::kPrimReverseV2, &OnnxExporter::ExportPrimReverseV2},
};

auto iter = std::find_if(export_table.begin(), export_table.end(),


Loading…
Cancel
Save