diff --git a/tests/st/testcase/test_tensorflow_parser.cc b/tests/st/testcase/test_tensorflow_parser.cc index 27199b1..549e85d 100644 --- a/tests/st/testcase/test_tensorflow_parser.cc +++ b/tests/st/testcase/test_tensorflow_parser.cc @@ -2663,6 +2663,13 @@ TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test) Status ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); EXPECT_EQ(ret, INTERNAL_ERROR); + + // op_node_context for fusion op + ge::OpNodeContext op_node_context; + op_node_context.input_map["pre_node_a"].push_back({0, 0}); + op_node_context.input_map["pre_node_b"].push_back({0, 1}); + tensorflow_parser.op_node_context_map_[fusion_op_name] = op_node_context; + GraphDef graph; curr_node_name = "pre_node_a"; nodedef_map.emplace("pre_node_a", node_def); @@ -2673,6 +2680,57 @@ TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test) delete node_def; } +TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test1) +{ + TensorFlowModelParser model_parser; + NodeDef *node_def = new NodeDef(); + node_def->set_name("Placeholder"); + node_def->set_op("Placeholder_0"); + std::map nodedef_map; + + curr_node_name = "pre_node_a"; + nodedef_map.emplace("pre_node_b", node_def); + node_def->set_op("pre_node_a"); + GenOriginContext(&model_parser, curr_node_name); + ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); + EXPECT_EQ(ret, INTERNAL_ERROR); + delete node_def; +} + +TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test1) +{ + TensorFlowModelParser model_parser; + NodeDef *node_def = new NodeDef(); + node_def->set_name("Placeholder"); + node_def->set_op("Placeholder_0"); + std::map nodedef_map; + + curr_node_name = "pre_node_a"; + nodedef_map.emplace("pre_node_a", node_def); + node_def->set_op("pre_node_a"); + GenOriginContext(&model_parser, curr_node_name); + ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); + EXPECT_EQ(ret, INTERNAL_ERROR); + delete node_def; +} + +TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test1) +{ + TensorFlowModelParser model_parser; + NodeDef *node_def = new NodeDef(); + node_def->set_name("Retval_1"); + node_def->set_op("_Retval"); + std::map nodedef_map; + + curr_node_name = "pre_node_a"; + nodedef_map.emplace("pre_node_b", node_def); + node_def->set_op("pre_node_a"); + GenOriginContext(&model_parser, curr_node_name); + ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); + EXPECT_EQ(ret, SUCCESS); + delete node_def; +} + TEST_F(STestTensorflowParser, tensorflow_OptimizeSnapShot_test) { TensorFlowModelParser model_parser; @@ -2863,6 +2921,7 @@ TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test) Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize); EXPECT_EQ(ret, ge::PARAM_INVALID); } + TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { std::string caseDir = __FILE__; std::size_t idx = caseDir.find_last_of("/");