diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc index 8241b72..3fa0261 100644 --- a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc +++ b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc @@ -25,12 +25,12 @@ #include "ut/parser/parser_ut_utils.h" #include "external/ge/ge_api_types.h" #include "tests/depends/ops_stub/ops_stub.h" -#include "parser/onnx/onnx_parser.h" #define protected public #define private public #include "parser/onnx/onnx_constant_parser.h" #include "parser/onnx/onnx_util.h" +#include "parser/onnx/onnx_parser.h" #undef protected #undef private @@ -315,4 +315,62 @@ TEST_F(UtestOnnxParser, OnnxConstantParser_ParseConvertDataType_test) EXPECT_EQ(ret, FAILED); } +TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test) +{ + OnnxModelParser model_parser; + ge::onnx::ModelProto model_proto; + ge::onnx::GraphProto graph = model_proto.graph(); + std::map initializer_name_tensor; + bool is_subgraph = false; + + Status ret = model_parser.ParseInput(initializer_name_tensor, is_subgraph, graph); + EXPECT_EQ(ret, domi::FAILED); + + ret = model_parser.ParseOutput(graph); + EXPECT_EQ(ret, domi::FAILED); +} + +TEST_F(UtestOnnxParser, onnx_test_ConstructOriType) +{ + ge::onnx::ModelProto model_proto; + ge::onnx::GraphProto* graph = model_proto.mutable_graph(); + ge::onnx::NodeProto* add_node = graph->add_node(); + add_node->set_op_type("Add"); + add_node->set_domain("ai.onnx"); + + OnnxModelParser onnx_parser ; + onnx_parser.domain_verseion_["ai.onnx"] = 11; + string ori_type; + Status ret = onnx_parser.ConstructOriType(add_node, ori_type); + EXPECT_EQ(ret, domi::SUCCESS); + + ge::onnx::NodeProto* add_node1 = graph->add_node(); + add_node1->set_op_type("Add1"); + add_node1->set_domain("add.onnx"); + string op_type; + ret = onnx_parser.AdapterOpType(add_node1, ori_type, op_type); + EXPECT_EQ(ret, ge::PARAM_INVALID); + + add_node->set_op_type("Add1"); + ret = onnx_parser.AdapterOpType(add_node, ori_type, op_type); + EXPECT_EQ(ret, PARAM_INVALID); +} + +TEST_F(UtestOnnxParser, onnx_test_TransNodeToOperator) +{ + ge::onnx::ModelProto model_proto; + ge::onnx::GraphProto* graph = model_proto.mutable_graph(); + ge::onnx::NodeProto *node_proto = graph->add_node(); + node_proto->set_op_type("Add1"); + node_proto->set_domain("add.onnx"); + node_proto->set_name("Conv2D"); + ge::OpDescPtr op_desc_src = std::make_shared("Add", "add.onnx"); + ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); + std::string op_type = "Add"; + + OnnxModelParser onnx_parser; + Status ret = onnx_parser.TransNodeToOperator(node_proto, op, op_type); + EXPECT_EQ(ret, SUCCESS); +} + } // namespace ge diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc index fa7c547..55cdfe2 100644 --- a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc @@ -161,6 +161,12 @@ void UtestTensorflowParser::RegisterCustomOp() { domi::OpRegistry::Instance()->registrationDatas.clear(); } +struct DelTransposeInfo { + domi::tensorflow::NodeDef *node_def; // transpose + domi::tensorflow::NodeDef *nextNodeDef; // transpose --> [next] + int inputIdx; +}; + namespace { NodeDef* AddNode(GraphDef& graph, string type, string name) { NodeDef* nodeDef = graph.add_node(); @@ -4341,4 +4347,209 @@ TEST_F(UtestTensorflowParser, tensorflow_IsFusionOpChild) ASSERT_EQ(ret, true); } +TEST_F(UtestTensorflowParser, tensorflow_UpdateAllNodeOpContext) +{ + TensorFlowModelParser modelParser; + Status ret; + auto scope_graph = ge::parser::MakeShared(); + if (scope_graph == nullptr) { + GELOGE(FAILED, "Scope graph make shared failed."); + return; + } + if (scope_graph->Init() != SUCCESS) { + GELOGE(FAILED, "Scope graph init failed."); + return; + } + + ge::ScopeFusionOpInfo info; + info.node_name = "node_name"; + info.fusion_node_name = "fusion_node_name"; + info.fusion_op_type = "fusion_op_type"; + info.description = "description"; + info.scope_pass = "scope_pass"; + modelParser.fusion_op_children_["Const"] = info; + ge::OpNodeContext op_node_context; + op_node_context.input_map["pre_node_a"].push_back({0, 0}); + op_node_context.input_map["pre_node_ctrl_in"].push_back({-1, -1}); // ctrl edges + op_node_context.output_map["post_node_b"].push_back({0, 0}); + op_node_context.output_map["post_node_c"].push_back({1, 0}); + op_node_context.output_map["post_node_d"].push_back({-1, -1}); + op_node_context.output_map["_Retval"].push_back({0, 1}); + modelParser.op_node_context_map_["Const"] = op_node_context; + NodeDef *node = initNodeDef(); + node->set_op("NULL"); + modelParser.nodedef_map_["Const"] = node; + std::vector op_node_name_list = {"Const"}; + + ret = modelParser.UpdateAllNodeOpContext(scope_graph, op_node_name_list); + EXPECT_EQ(ret, SUCCESS); + + delete node; +} + +TEST_F(UtestTensorflowParser, tensorflow_UppdateInputMap) +{ + TensorFlowModelParser modelParser; + Status ret; + auto scope_graph = ge::parser::MakeShared(); + if (scope_graph == nullptr) { + GELOGE(FAILED, "Scope graph make shared failed."); + return; + } + if (scope_graph->Init() != SUCCESS) { + GELOGE(FAILED, "Scope graph init failed."); + return; + } + + ge::ScopeFusionOpInfo info; + ge::OpNodeContext fusion_op_node_context, normal_op_node_context; + info.node_name = "node_name"; + info.fusion_node_name = "fusion_node_name"; + info.fusion_op_type = "fusion_op_type"; + info.description = "description"; + info.scope_pass = "scope_pass"; + modelParser.fusion_op_children_["Const"] = info; + normal_op_node_context.input_map["Const"].push_back({0, 1}); + normal_op_node_context.output_map["Const"].push_back({0, 1}); + fusion_op_node_context.output_map["Const"].push_back({0, 1}); + + NodeDef *node = initNodeDef(); + node->set_op("NULL"); + modelParser.nodedef_map_["Const"] = node; + ret = modelParser.UppdateInputMap(scope_graph, info, fusion_op_node_context, normal_op_node_context); + EXPECT_EQ(ret, SUCCESS); + + ge::ScopeFusionOpInfo info1; + info1.fusion_node_name = "no fusion_node_name"; + ret = modelParser.UppdateInputMap(scope_graph, info1, fusion_op_node_context, normal_op_node_context); + EXPECT_EQ(ret, SUCCESS); + + delete node; +} + +TEST_F(UtestTensorflowParser, tensorflow_UppdateOutputMap) +{ + TensorFlowModelParser modelParser; + Status ret; + auto scope_graph = ge::parser::MakeShared(); + if (scope_graph == nullptr) { + GELOGE(FAILED, "Scope graph make shared failed."); + return; + } + if (scope_graph->Init() != SUCCESS) { + GELOGE(FAILED, "Scope graph init failed."); + return; + } + + ge::ScopeFusionOpInfo info; + ge::OpNodeContext fusion_op_node_context, normal_op_node_context; + info.node_name = "node_name"; + info.fusion_node_name = "fusion_node_name"; + info.fusion_op_type = "fusion_op_type"; + info.description = "description"; + info.scope_pass = "scope_pass"; + modelParser.fusion_op_children_["Const"] = info; + normal_op_node_context.output_map["Const"].push_back({0, 1}); + ret = modelParser.UppdateOutputMap(scope_graph, info, fusion_op_node_context, normal_op_node_context); + EXPECT_EQ(ret, SUCCESS); + + ge::ScopeFusionOpInfo info1; + info1.fusion_node_name = "no fusion_node_name"; + ret = modelParser.UppdateOutputMap(scope_graph, info1, fusion_op_node_context, normal_op_node_context); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestTensorflowParser, tensorflow_EraseNormalOpOutputIfChild) +{ + Status ret; + TensorFlowModelParser modelParser; + auto scope_graph = ge::parser::MakeShared(); + if (scope_graph == nullptr) { + GELOGE(FAILED, "Scope graph make shared failed."); + return; + } + if (scope_graph->Init() != SUCCESS) { + GELOGE(FAILED, "Scope graph init failed."); + return; + } + + const string op_node_name = "Const"; + OpNodeContext normal_op_node_context; + normal_op_node_context.input_map["pre_node_a"].push_back({0, 0}); + normal_op_node_context.output_map[op_node_name].push_back({0, 0}); + + ge::ScopeFusionOpInfo info; + info.node_name = "node_name"; + info.fusion_node_name = "fusion_node_name"; + info.fusion_op_type = "fusion_op_type"; + info.description = "description"; + info.scope_pass = "scope_pass"; + modelParser.fusion_op_children_["Const"] = info; + + NodeDef *node = initNodeDef(); + node->set_op("NULL"); + modelParser.nodedef_map_["Const"] = node; + + ret = modelParser.EraseNormalOpOutputIfChild(scope_graph, op_node_name, normal_op_node_context); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestTensorflowParser, tensorflow_UpdateNormalOpContext) +{ + Status ret; + TensorFlowModelParser modelParser; + auto scope_graph = ge::parser::MakeShared(); + if (scope_graph == nullptr) { + GELOGE(FAILED, "Scope graph make shared failed."); + return; + } + if (scope_graph->Init() != SUCCESS) { + GELOGE(FAILED, "Scope graph init failed."); + return; + } + + const string op_node_name = "Const"; + OpNodeContext normal_op_node_context; + normal_op_node_context.input_map[op_node_name].push_back({0, 0}); + + ge::ScopeFusionOpInfo info; + info.node_name = "node_name"; + info.fusion_node_name = "fusion_node_name"; + info.fusion_op_type = "fusion_op_type"; + info.description = "description"; + info.scope_pass = "scope_pass"; + modelParser.fusion_op_children_["Const"] = info; + + NodeDef *node = initNodeDef(); + node->set_op("NULL"); + modelParser.nodedef_map_["Const"] = node; + + ret = modelParser.UpdateNormalOpContext(scope_graph, op_node_name, normal_op_node_context); + EXPECT_EQ(ret, SUCCESS); + + node->set_op("Const"); + modelParser.nodedef_map_["Const"] = node; + + ret = modelParser.UpdateNormalOpContext(scope_graph, op_node_name, normal_op_node_context); + EXPECT_EQ(ret, SUCCESS); + + delete node; +} + +TEST_F(UtestTensorflowParser, tensorflow_OptimizeTranspose) +{ + TensorFlowModelParser modelParser; + DelTransposeInfo info; + info.node_def = new NodeDef(); + info.nextNodeDef = new NodeDef(); + info.node_def->add_input("ge"); + info.nextNodeDef->add_input("ge"); + info.inputIdx = 0; + std::map transposeInfo = {{"ge", info}}; + modelParser.OptimizeTranspose(transposeInfo); + + delete info.node_def; + delete info.nextNodeDef; +} + } // namespace ge