| @@ -46,6 +46,7 @@ enum OpMergeMode { | |||
| OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX Batch Normalization` | |||
| OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool` | |||
| OP_MERGE_LAYER_NORM = 6, // indicate `MindSpore LayerNorm(x)[0]` --> `ONNX MeanVarianceNormalization` | |||
| OP_MERGE_CONV2D_TRANSPOSE = 7, // indicate `MindSpore ConvTranspose + BiasAdd` --> `ONNX ConvTranspose` | |||
| }; | |||
| struct OpMergedInfo { | |||
| @@ -408,6 +409,59 @@ void AddReduceOp(const std::string &op_type, const std::string &input, const std | |||
| } | |||
| } | |||
| void AddMeanVarianceNormalizationOp(const std::string &input, const std::string &gamma, const std::string &beta, | |||
| const std::string &output, const std::vector<int64_t> &axes, float epsilon, | |||
| const std::vector<int64_t> &input_shape, onnx::TensorProto_DataType input_type, | |||
| onnx::GraphProto *graph_proto) { | |||
| auto input_name = output + "_input"; | |||
| AddCastOp(input, input_name, onnx::TensorProto_DataType_FLOAT, graph_proto); | |||
| auto gamma_name = output + "_gamma"; | |||
| AddCastOp(gamma, gamma_name, onnx::TensorProto_DataType_FLOAT, graph_proto); | |||
| auto beta_name = output + "_beta"; | |||
| AddCastOp(beta, beta_name, onnx::TensorProto_DataType_FLOAT, graph_proto); | |||
| // MeanVarianceNormalization is replaced with equivalent ops because it is not supported by CUDAExecutionProvider | |||
| auto meanvariancenormal_node_name = output + "_normalized"; | |||
| auto mean_name = output + "_mean"; | |||
| AddReduceOp("ReduceMean", input_name, mean_name, axes, true, graph_proto); | |||
| auto centered_name = output + "_centered"; | |||
| AddOp("Sub", {input_name, mean_name}, {centered_name}, graph_proto); | |||
| auto sqsum_name = output + "_sqsum"; | |||
| AddReduceOp("ReduceSumSquare", centered_name, sqsum_name, axes, true, graph_proto); | |||
| float reduce_size = std::accumulate(axes.begin(), axes.end(), 1.0f, | |||
| [&input_shape](auto acc, auto axis) { return acc * input_shape[axis]; }); | |||
| auto reduce_size_name = output + "_reduce_size"; | |||
| AddFloatScalarInitializer(reduce_size_name, reduce_size, onnx::TensorProto_DataType_FLOAT, graph_proto); | |||
| auto variance_name = output + "_variance"; | |||
| AddOp("Div", {sqsum_name, reduce_size_name}, {variance_name}, graph_proto); | |||
| auto epsilon_name = output + "_epsilon"; | |||
| AddFloatScalarInitializer(epsilon_name, epsilon, onnx::TensorProto_DataType_FLOAT, graph_proto); | |||
| auto variance_with_epsilon_name = output + "_variance_with_epsilon"; | |||
| AddOp("Add", {variance_name, epsilon_name}, {variance_with_epsilon_name}, graph_proto); | |||
| auto std_name = output + "_std"; | |||
| AddOp("Sqrt", {variance_with_epsilon_name}, {std_name}, graph_proto); | |||
| AddOp("Div", {centered_name, std_name}, {meanvariancenormal_node_name}, graph_proto); | |||
| // Add mul and add node | |||
| auto mul_node_name = output + "_rescaled"; | |||
| AddOp("Mul", {meanvariancenormal_node_name, gamma_name}, {mul_node_name}, graph_proto); | |||
| // add beta | |||
| auto add_node_name = output; | |||
| if (input_type == onnx::TensorProto_DataType_FLOAT16) { | |||
| add_node_name += "_shifted"; | |||
| } | |||
| AddOp("Add", {mul_node_name, beta_name}, {add_node_name}, graph_proto); | |||
| if (input_type == onnx::TensorProto_DataType_FLOAT16) { | |||
| AddCastOp(add_node_name, output, onnx::TensorProto_DataType_FLOAT16, graph_proto); | |||
| } | |||
| } | |||
| void AddConcatOp(const std::vector<std::string> &inputs, const std::string &output, int axis, | |||
| onnx::GraphProto *graph_proto) { | |||
| onnx::NodeProto *concat_proto = graph_proto->add_node(); | |||
| @@ -570,9 +624,6 @@ OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo()) | |||
| OPERATOR_ONNX_CONVERT_DEFINE(Sigmoid, Sigmoid, OpNameInfo()) | |||
| OPERATOR_ONNX_CONVERT_DEFINE(Flatten, Flatten, OpNameInfo()) | |||
| OPERATOR_ONNX_CONVERT_DEFINE(Squeeze, Squeeze, | |||
| OpNameInfo().Attr("axis", "axes", onnx::AttributeProto_AttributeType_INTS, | |||
| SetAttrTupleValueToProto<0>)) | |||
| OPERATOR_ONNX_CONVERT_DEFINE( | |||
| Conv2D, Conv, | |||
| @@ -714,7 +765,6 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) { | |||
| fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)()); | |||
| fn(OP_CONVERT_FUNCTION_NAME(AvgPool)()); | |||
| fn(OP_CONVERT_FUNCTION_NAME(Squeeze)()); | |||
| fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)()); | |||
| fn(OP_CONVERT_FUNCTION_NAME(MatMul)()); | |||
| fn(OP_CONVERT_FUNCTION_NAME(MakeTuple)()); | |||
| @@ -854,8 +904,14 @@ class OnnxExporter { | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||
| void ExportPrimOneHot(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||
| void PrimConv2DTransposeExportHelper(const CNodePtr &conv_node, const CNodePtr &bias_add_node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto); | |||
| void ExportPrimConv2DTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||
| void ExportPrimGreaterEqual(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||
| void ExportPrimSqueeze(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||
| void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *graph_proto); | |||
| void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| @@ -866,6 +922,8 @@ class OnnxExporter { | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||
| void ExportMergeLayerNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||
| void ExportMergeConv2DTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||
| void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *graph_proto); | |||
| @@ -1047,6 +1105,7 @@ void OnnxExporter::MatchAndMarkCNode(const FuncGraphPtr &func_graph, const CNode | |||
| const std::vector<MergeRule> first_input_merge_rules = { | |||
| {prim::kPrimBiasAdd, prim::kPrimConv2D, OP_MERGE_CONV}, | |||
| {prim::kPrimBiasAdd, prim::kPrimConv2DTranspose, OP_MERGE_CONV2D_TRANSPOSE}, | |||
| {prim::kPrimBiasAdd, prim::kPrimConv3D, OP_MERGE_CONV}, | |||
| {prim::kPrimBiasAdd, prim::kPrimConv3DTranspose, OP_MERGE_CONV}, | |||
| {prim::kPrimBiasAdd, prim::kPrimMatMul, OP_MERGE_GEMM}, | |||
| @@ -1143,6 +1202,9 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodeP | |||
| case OP_MERGE_LAYER_NORM: | |||
| ExportMergeLayerNorm(func_graph, cnode, node_map_ptr, graph_proto); | |||
| break; | |||
| case OP_MERGE_CONV2D_TRANSPOSE: | |||
| ExportMergeConv2DTranspose(func_graph, cnode, node_map_ptr, graph_proto); | |||
| break; | |||
| default: | |||
| ExportCNode(func_graph, cnode, node_map_ptr, graph_proto); | |||
| break; | |||
| @@ -2362,6 +2424,73 @@ void OnnxExporter::ExportPrimOneHot(const FuncGraphPtr &, const CNodePtr &node, | |||
| one_hot_axis_attr_proto->set_i(axis); | |||
| } | |||
/*
  Based on nn.Conv2dTranspose
  Warning: `output_shape` is an input in MindSpore but an attribute in ONNX. Hence
  it is not possible to change the output shape at runtime.
*/
// Exports a Conv2DTranspose node (optionally fused with a following BiasAdd) as a
// single ONNX ConvTranspose node. When bias_add_node is non-null, its bias input
// becomes the third ConvTranspose input and the merged result is registered under
// the BiasAdd node; otherwise the result is registered under the conv node itself.
void OnnxExporter::PrimConv2DTransposeExportHelper(const CNodePtr &conv_node, const CNodePtr &bias_add_node,
                                                   std::map<AnfNodePtr, size_t> *node_map_ptr,
                                                   onnx::GraphProto *const graph_proto) {
  auto node_idx = AllocateNodeIndex();

  // ONNX ConvTranspose inputs: (x, weight[, bias]).
  std::vector<AnfNodePtr> inputs{conv_node->input(kOneNum), conv_node->input(kTwoNum)};
  if (bias_add_node != nullptr) {
    inputs.push_back(bias_add_node->input(kTwoNum));
    (*node_map_ptr)[bias_add_node] = node_idx;
  } else {
    (*node_map_ptr)[conv_node] = node_idx;
  }

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("ConvTranspose");
  for (const auto &input : inputs) {
    node_proto->add_input(GetNodeInputName(input, node_map_ptr, graph_proto));
  }
  node_proto->add_output(std::to_string(node_idx));

  // Translate MindSpore conv attributes to their ONNX counterparts.
  auto prim = GetPrimitive(conv_node);
  auto attrs_convert_info =
    OpNameInfo()
      .Attr("dilation", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>)
      .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int64Imm>)
      .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
      .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetConvTransposePadding)
      .Attr("stride", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>);
  for (const auto &attr_info : attrs_convert_info.op_attrs()) {
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name(attr_info.onnx_attr_name());
    auto ms_attr = GetOpAttributePtr<Value>(conv_node, attr_info.attr_name());
    MS_EXCEPTION_IF_NULL(ms_attr);
    attr_info.fn_gen_attr()(ms_attr, attr_info.onnx_attr_type(), attr_proto, prim);
  }

  // Set output shape. In MindSpore the output shape is the third (tuple) input of the
  // conv node; ONNX expects it as the `output_shape` attribute, so it must be a
  // compile-time constant here.
  auto input_shape_node = GetRealInput(conv_node->input(kThreeNum));
  if (!input_shape_node->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "For ONNX export third argument must be constant "
                         "(Python tuple). Instead got "
                      << input_shape_node->ToString();
  }
  auto input_shape_value_ptr = input_shape_node->cast<ValueNodePtr>()->value();
  if (!input_shape_value_ptr->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "Expected ValueTuple, got " << input_shape_value_ptr->ToString() << " of type "
                      << input_shape_value_ptr->type()->ToString();
  }
  onnx::AttributeProto *output_shape_attr_proto = node_proto->add_attribute();
  output_shape_attr_proto->set_name("output_shape");
  SetAttrTupleValueToProto<0>(input_shape_value_ptr, onnx::AttributeProto_AttributeType_INTS, output_shape_attr_proto,
                              prim);
}
// Exports a standalone Conv2DTranspose (no fused BiasAdd) as ONNX ConvTranspose.
void OnnxExporter::ExportPrimConv2DTranspose(const FuncGraphPtr &, const CNodePtr &node,
                                             std::map<AnfNodePtr, size_t> *node_map_ptr,
                                             onnx::GraphProto *graph_proto) {
  PrimConv2DTransposeExportHelper(node, nullptr, node_map_ptr, graph_proto);
}
| void OnnxExporter::ExportPrimGreaterEqual(const FuncGraphPtr &, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *const graph_proto) { | |||
| @@ -2377,6 +2506,30 @@ void OnnxExporter::ExportPrimGreaterEqual(const FuncGraphPtr &, const CNodePtr & | |||
| AddOp("Not", {less_name}, {node_name}, graph_proto); | |||
| } | |||
| void OnnxExporter::ExportPrimSqueeze(const FuncGraphPtr &, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||
| auto node_idx = AllocateNodeIndex(); | |||
| (*node_map_ptr)[node] = node_idx; | |||
| auto input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto); | |||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||
| node_proto->set_op_type("Squeeze"); | |||
| node_proto->add_input(input_name); | |||
| node_proto->add_output(std::to_string(node_idx)); | |||
| auto axes = GetOpAttributePtr<ValueSequence>(node, "axis"); | |||
| auto axes_value = GetValue<std::vector<int64_t>>(axes); | |||
| if (!axes_value.empty()) { | |||
| onnx::AttributeProto *axes_proto = node_proto->add_attribute(); | |||
| axes_proto->set_name("axes"); | |||
| axes_proto->set_type(onnx::AttributeProto_AttributeType_INTS); | |||
| for (auto axis : axes_value) { | |||
| axes_proto->add_ints(axis); | |||
| } | |||
| } | |||
| } | |||
| void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||
| using ExportFunc = std::function<void(OnnxExporter *, const FuncGraphPtr &, const CNodePtr &, | |||
| @@ -2407,7 +2560,9 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n | |||
| {prim::kPrimOnesLike, &OnnxExporter::ExportPrimOnesLike}, | |||
| {prim::kPrimArgMaxWithValue, &OnnxExporter::ExportPrimArgMaxWithValue}, | |||
| {prim::kPrimOneHot, &OnnxExporter::ExportPrimOneHot}, | |||
| {prim::kPrimConv2DTranspose, &OnnxExporter::ExportPrimConv2DTranspose}, | |||
| {prim::kPrimGreaterEqual, &OnnxExporter::ExportPrimGreaterEqual}, | |||
| {prim::kPrimSqueeze, &OnnxExporter::ExportPrimSqueeze}, | |||
| {prim::kPrimExpandDims, &OnnxExporter::ExportPrimExpandDims}, | |||
| {prim::kPrimBatchMatMul, &OnnxExporter::ExportPrimBatchMatMul}, | |||
| {prim::kPrimGeLU, &OnnxExporter::ExportPrimGeLU}, | |||
| @@ -2613,12 +2768,44 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CN | |||
| onnx::GraphProto *const graph_proto) { | |||
| auto batch_norm_node = dyn_cast<CNode>(node->input(kOneNum)); | |||
| PrimitivePtr prim_batch_norm = dyn_cast<Primitive>((dyn_cast<ValueNode>(batch_norm_node->input(kZeroNum)))->value()); | |||
| std::vector<AnfNodePtr> inputs; | |||
| for (size_t i = 1; i < batch_norm_node->inputs().size(); i++) { | |||
| inputs.push_back(batch_norm_node->input(i)); | |||
| auto is_training = GetOpAttribute<bool>(batch_norm_node, "is_training"); | |||
| if (is_training) { | |||
| auto input_x_name = GetNodeInputName(batch_norm_node->input(kOneNum), node_map_ptr, graph_proto); | |||
| auto scale_input_name = GetNodeInputName(batch_norm_node->input(kTwoNum), node_map_ptr, graph_proto); | |||
| auto bias_input_name = GetNodeInputName(batch_norm_node->input(kThreeNum), node_map_ptr, graph_proto); | |||
| auto onnx_type = GetOutputType(batch_norm_node->input(kOneNum)); | |||
| auto output_index = AllocateNodeIndex(); | |||
| auto output_name = std::to_string(output_index); | |||
| (*node_map_ptr)[node] = output_index; | |||
| auto input_shape_ptr = batch_norm_node->input(kOneNum)->Shape(); | |||
| auto input_shape = input_shape_ptr->cast<abstract::ShapePtr>()->shape(); | |||
| std::vector<int64_t> normalize_axes = {0}; | |||
| for (size_t i = kTwoNum; i < input_shape.size(); ++i) { | |||
| normalize_axes.push_back(static_cast<int64_t>(i)); | |||
| } | |||
| std::vector<int64_t> scale_bias_shape(input_shape.size(), 1); | |||
| scale_bias_shape[1] = -1; | |||
| auto reshaped_scale_name = output_name + "_reshaped_scale"; | |||
| AddReshapeOp(scale_input_name, reshaped_scale_name, scale_bias_shape, graph_proto); | |||
| auto reshaped_bias_name = output_name + "_reshaped_bias"; | |||
| AddReshapeOp(bias_input_name, reshaped_bias_name, scale_bias_shape, graph_proto); | |||
| auto epsilon = GetOpAttribute<float>(batch_norm_node, "epsilon"); | |||
| AddMeanVarianceNormalizationOp(input_x_name, reshaped_scale_name, reshaped_bias_name, output_name, normalize_axes, | |||
| epsilon, input_shape, onnx_type, graph_proto); | |||
| } else { | |||
| PrimitivePtr prim_batch_norm = GetPrimitive(batch_norm_node); | |||
| std::vector<AnfNodePtr> inputs; | |||
| for (size_t i = 1; i < batch_norm_node->inputs().size(); i++) { | |||
| inputs.push_back(batch_norm_node->input(i)); | |||
| } | |||
| (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); | |||
| } | |||
| (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); | |||
| } | |||
| void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| @@ -2644,109 +2831,28 @@ void OnnxExporter::ExportMergeLayerNorm(const FuncGraphPtr &, const CNodePtr &no | |||
| auto layernorm_input_gamma = GetNodeInputName(LayerNormNode->input(kTwoNum), node_map_ptr, graph_proto); | |||
| auto layernorm_input_beta = GetNodeInputName(LayerNormNode->input(kThreeNum), node_map_ptr, graph_proto); | |||
| auto layernorm_input_x_node = LayerNormNode->input(kOneNum); | |||
| auto dtype = layernorm_input_x_node->Type(); | |||
| auto elem_type = dyn_cast<TensorType>(dtype)->element()->type_id(); | |||
| size_t pre_cast_node_idx = 0; | |||
| // if type is float16, add cast node cast type from float16 to float32 | |||
| if (elem_type == kNumberTypeFloat16) { | |||
| pre_cast_node_idx = AllocateNodeIndex(); | |||
| AddCastOp(layernorm_input_x, std::to_string(pre_cast_node_idx), onnx::TensorProto_DataType_FLOAT, graph_proto); | |||
| } | |||
| // reshape before MeanVarianceNormalization | |||
| auto input_shape = dyn_cast<abstract::Shape>(LayerNormNode->input(kOneNum)->Shape()); | |||
| std::vector<int64_t> new_input_shape; | |||
| int64_t n_shape = 1; | |||
| int64_t c_shape = 1; | |||
| int64_t h_shape = 1; | |||
| size_t input_shape_size = input_shape->shape().size(); | |||
| for (size_t i = 0; i < input_shape_size - 1; i++) { | |||
| c_shape = c_shape * input_shape->shape()[i]; | |||
| } | |||
| new_input_shape.push_back(n_shape); | |||
| new_input_shape.push_back(c_shape); | |||
| new_input_shape.push_back(h_shape); | |||
| new_input_shape.push_back(input_shape->shape()[input_shape_size - kOneNum]); | |||
| // Add shape node for reshape(before MeanVarianceNormalization) | |||
| auto new_shape_value = MakeValue<std::vector<int64_t>>(new_input_shape); | |||
| auto shape_node = NewValueNode(new_shape_value)->cast<AnfNodePtr>(); | |||
| auto shape_node_idx = AllocateNodeIndex(); | |||
| onnx::NodeProto *shape_node_proto = graph_proto->add_node(); | |||
| shape_node_proto->add_output(std::to_string(shape_node_idx)); | |||
| shape_node_proto->set_op_type("Constant"); | |||
| onnx::AttributeProto *shape_attr_proto = shape_node_proto->add_attribute(); | |||
| shape_attr_proto->set_name("value"); | |||
| shape_attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | |||
| ConvertTupleToTensor(dyn_cast<ValueNode>(shape_node)->value(), shape_attr_proto->mutable_t()); | |||
| // Add reshape node before MeanVarianceNormalization | |||
| auto pre_reshape_node_idx = AllocateNodeIndex(); | |||
| onnx::NodeProto *pre_reshape_node_proto = graph_proto->add_node(); | |||
| pre_reshape_node_proto->set_op_type("Reshape"); | |||
| pre_reshape_node_proto->add_output(std::to_string(pre_reshape_node_idx)); | |||
| if (elem_type == kNumberTypeFloat16) { | |||
| pre_reshape_node_proto->add_input(std::to_string(pre_cast_node_idx)); | |||
| } else { | |||
| pre_reshape_node_proto->add_input(layernorm_input_x); | |||
| auto begin_norm_axis = GetOpAttribute<int64_t>(LayerNormNode, "begin_norm_axis"); | |||
| auto begin_params_axis = GetOpAttribute<int64_t>(LayerNormNode, "begin_params_axis"); | |||
| if (begin_norm_axis != -1 || begin_params_axis != -1) { | |||
| MS_LOG(EXCEPTION) << "begin_norm_axis != -1 and begin_params_axis != -1 are not implemented"; | |||
| } | |||
| pre_reshape_node_proto->add_input(std::to_string(shape_node_idx)); | |||
| // MeanVarianceNormalization | |||
| auto meanvariancenormal_node_idx = AllocateNodeIndex(); | |||
| onnx::NodeProto *meanvariancenormal_node_proto = graph_proto->add_node(); | |||
| meanvariancenormal_node_proto->set_op_type("MeanVarianceNormalization"); | |||
| meanvariancenormal_node_proto->add_output(std::to_string(meanvariancenormal_node_idx)); | |||
| meanvariancenormal_node_proto->add_input(std::to_string(pre_reshape_node_idx)); | |||
| // if cast type from float16 to float32, add cast node cast type from float32 to float16 | |||
| size_t aft_cast_node_idx = 0; | |||
| if (elem_type == kNumberTypeFloat16) { | |||
| aft_cast_node_idx = AllocateNodeIndex(); | |||
| AddCastOp(std::to_string(meanvariancenormal_node_idx), std::to_string(aft_cast_node_idx), | |||
| onnx::TensorProto_DataType_FLOAT16, graph_proto); | |||
| } | |||
| auto onnx_type = GetOutputType(LayerNormNode->input(kOneNum)); | |||
| auto input_shape = dyn_cast<abstract::Shape>(LayerNormNode->input(kOneNum)->Shape())->shape(); | |||
| auto node_idx = AllocateNodeIndex(); | |||
| (*node_map_ptr)[node] = node_idx; | |||
| auto epsilon = GetOpAttribute<float>(LayerNormNode, "epsilon"); | |||
| std::vector<int64_t> reduce_axes = {static_cast<int64_t>(input_shape.size()) - 1}; | |||
| // Add mul and add node | |||
| auto mul_node_idx = AllocateNodeIndex(); | |||
| onnx::NodeProto *mul_node_proto = graph_proto->add_node(); | |||
| mul_node_proto->set_op_type("Mul"); | |||
| if (elem_type == kNumberTypeFloat16) { | |||
| mul_node_proto->add_input(std::to_string(aft_cast_node_idx)); | |||
| } else { | |||
| mul_node_proto->add_input(std::to_string(meanvariancenormal_node_idx)); | |||
| } | |||
| mul_node_proto->add_input(layernorm_input_gamma); | |||
| mul_node_proto->add_output(std::to_string(mul_node_idx)); | |||
| AddMeanVarianceNormalizationOp(layernorm_input_x, layernorm_input_gamma, layernorm_input_beta, | |||
| std::to_string(node_idx), reduce_axes, epsilon, input_shape, onnx_type, graph_proto); | |||
| } | |||
| // add beta | |||
| auto add_node_idx = AllocateNodeIndex(); | |||
| AddOp("Add", {std::to_string(mul_node_idx), layernorm_input_beta}, {std::to_string(add_node_idx)}, graph_proto); | |||
| // reshape after MeanVarianceNormalization | |||
| // Add shape node for reshape(after MeanVarianceNormalization) | |||
| auto output_shape_value = MakeValue<std::vector<int64_t>>(input_shape->shape()); | |||
| auto output_shape_node = NewValueNode(output_shape_value)->cast<AnfNodePtr>(); | |||
| auto output_shape_node_idx = AllocateNodeIndex(); | |||
| onnx::NodeProto *output_shape_node_proto = graph_proto->add_node(); | |||
| output_shape_node_proto->add_output(std::to_string(output_shape_node_idx)); | |||
| output_shape_node_proto->set_op_type("Constant"); | |||
| onnx::AttributeProto *output_shape_attr_proto = output_shape_node_proto->add_attribute(); | |||
| output_shape_attr_proto->set_name("value"); | |||
| output_shape_attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | |||
| ConvertTupleToTensor(dyn_cast<ValueNode>(output_shape_node)->value(), output_shape_attr_proto->mutable_t()); | |||
| // Add reshape node after MeanVarianceNormalization | |||
| auto aft_reshape_node_idx = AllocateNodeIndex(); | |||
| (*node_map_ptr)[node] = aft_reshape_node_idx; | |||
| onnx::NodeProto *aft_reshape_node_proto = graph_proto->add_node(); | |||
| aft_reshape_node_proto->set_op_type("Reshape"); | |||
| aft_reshape_node_proto->add_output(std::to_string(aft_reshape_node_idx)); | |||
| aft_reshape_node_proto->add_input(std::to_string(add_node_idx)); | |||
| aft_reshape_node_proto->add_input(std::to_string(output_shape_node_idx)); | |||
// Exports a fused `Conv2DTranspose + BiasAdd` pair (marked OP_MERGE_CONV2D_TRANSPOSE)
// as a single ONNX ConvTranspose node. `node` is the BiasAdd; its first input is the
// conv node being merged.
void OnnxExporter::ExportMergeConv2DTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                              std::map<AnfNodePtr, size_t> *node_map_ptr,
                                              onnx::GraphProto *const graph_proto) {
  auto conv_node = dyn_cast<CNode>(node->input(kOneNum));
  PrimConv2DTransposeExportHelper(conv_node, node, node_map_ptr, graph_proto);
}
| /* | |||