Browse Source

st

pull/544/head
xueteng 3 years ago
parent
commit
29633d5683
1 changed files with 59 additions and 0 deletions
  1. +59
    -0
      tests/st/testcase/test_tensorflow_parser.cc

+ 59
- 0
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -2663,6 +2663,13 @@ TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test)
Status ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
EXPECT_EQ(ret, INTERNAL_ERROR);


// op_node_context for fusion op
ge::OpNodeContext op_node_context;
op_node_context.input_map["pre_node_a"].push_back({0, 0});
op_node_context.input_map["pre_node_b"].push_back({0, 1});
tensorflow_parser.op_node_context_map_[fusion_op_name] = op_node_context;

GraphDef graph;
curr_node_name = "pre_node_a";
nodedef_map.emplace("pre_node_a", node_def);
@@ -2673,6 +2680,57 @@ TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test)
delete node_def;
}

TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test1)
{
TensorFlowModelParser model_parser;
NodeDef *node_def = new NodeDef();
node_def->set_name("Placeholder");
node_def->set_op("Placeholder_0");
std::map<string, NodeDef *> nodedef_map;

curr_node_name = "pre_node_a";
nodedef_map.emplace("pre_node_b", node_def);
node_def->set_op("pre_node_a");
GenOriginContext(&model_parser, curr_node_name);
ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
EXPECT_EQ(ret, INTERNAL_ERROR);
delete node_def;
}

TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test1)
{
TensorFlowModelParser model_parser;
NodeDef *node_def = new NodeDef();
node_def->set_name("Placeholder");
node_def->set_op("Placeholder_0");
std::map<string, NodeDef *> nodedef_map;

curr_node_name = "pre_node_a";
nodedef_map.emplace("pre_node_a", node_def);
node_def->set_op("pre_node_a");
GenOriginContext(&model_parser, curr_node_name);
ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
EXPECT_EQ(ret, INTERNAL_ERROR);
delete node_def;
}

TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test1)
{
TensorFlowModelParser model_parser;
NodeDef *node_def = new NodeDef();
node_def->set_name("Retval_1");
node_def->set_op("_Retval");
std::map<string, NodeDef *> nodedef_map;

curr_node_name = "pre_node_a";
nodedef_map.emplace("pre_node_b", node_def);
node_def->set_op("pre_node_a");
GenOriginContext(&model_parser, curr_node_name);
ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
EXPECT_EQ(ret, SUCCESS);
delete node_def;
}

TEST_F(STestTensorflowParser, tensorflow_OptimizeSnapShot_test)
{
TensorFlowModelParser model_parser;
@@ -2863,6 +2921,7 @@ TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test)
Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize);
EXPECT_EQ(ret, ge::PARAM_INVALID);
}

TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) {
std::string caseDir = __FILE__;
std::size_t idx = caseDir.find_last_of("/");


Loading…
Cancel
Save