diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc index 55cdfe2..feff339 100644 --- a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc @@ -76,6 +76,7 @@ #include "parser/tensorflow/graph_optimizer.h" #include "metadef/inc/register/scope/scope_pass_registry_impl.h" #include "register/scope/scope_fusion_pass_register.h" +#include "common/op_map.h" #undef protected #undef private @@ -90,6 +91,15 @@ using namespace google::protobuf; static const string GRAPH_DEFAULT_NAME = "default"; namespace ge { +struct DelTransposeInfo { + domi::tensorflow::NodeDef *node_def; // transpose + domi::tensorflow::NodeDef *nextNodeDef; // transpose --> [next] + int inputIdx; +}; + +Status GetTransposeInfo(GraphDef *graph_def, std::map &softmaxInfo, + std::map &transposeInfo); + class UtestTensorflowParser : public testing::Test { protected: void SetUp() { @@ -161,12 +171,6 @@ void UtestTensorflowParser::RegisterCustomOp() { domi::OpRegistry::Instance()->registrationDatas.clear(); } -struct DelTransposeInfo { - domi::tensorflow::NodeDef *node_def; // transpose - domi::tensorflow::NodeDef *nextNodeDef; // transpose --> [next] - int inputIdx; -}; - namespace { NodeDef* AddNode(GraphDef& graph, string type, string name) { NodeDef* nodeDef = graph.add_node(); @@ -4492,6 +4496,8 @@ TEST_F(UtestTensorflowParser, tensorflow_EraseNormalOpOutputIfChild) ret = modelParser.EraseNormalOpOutputIfChild(scope_graph, op_node_name, normal_op_node_context); EXPECT_EQ(ret, SUCCESS); + + delete node; } TEST_F(UtestTensorflowParser, tensorflow_UpdateNormalOpContext) @@ -4552,4 +4558,64 @@ TEST_F(UtestTensorflowParser, tensorflow_OptimizeTranspose) delete info.nextNodeDef; } +TEST_F(UtestTensorflowParser, tensorflow_SoftmaxAddAttr) +{ + TensorFlowModelParser modelParser; + domi::tensorflow::GraphDef graph_def; + graph_def.add_node(); + modelParser.SoftmaxAddAttr(&graph_def); +} + +TEST_F(UtestTensorflowParser, tensorflow_InferInputFormats) +{ + domiTensorFormat_t ret; + TensorFlowModelParser modelParser; + + GetParserContext().format = DOMI_TENSOR_RESERVED; + + NodeDef *node = MallocNodeDef("node", "DATA"); + modelParser.nodedef_map_["node"] = node; + tensorflow_op_map["DATA"] = "node"; + ret = modelParser.InferInputFormats(); + EXPECT_EQ(ret, domi::DOMI_TENSOR_NHWC); + delete node; + + NodeDef* node1 = nullptr; + modelParser.nodedef_map_["node"] = node1; + + ret = modelParser.InferInputFormats(); + EXPECT_EQ(ret, domi::DOMI_TENSOR_RESERVED); +} + +TEST_F(UtestTensorflowParser, tensorflow_GetTransposeInfo) +{ + Status ret; + DelTransposeInfo info; + tensorflow::GraphDef *graph = new tensorflow::GraphDef(); + std::map softmaxInfo = {{"ge", "ge"}}; + + info.node_def = new NodeDef(); + info.nextNodeDef = new NodeDef(); + info.node_def->add_input("ge"); + info.nextNodeDef->add_input("ge"); + info.inputIdx = 0; + + NodeDef *node = graph->add_node(); + node->set_op("Transpose"); + + std::map transposeInfo = {{"Softmax", info}}; + ret = ge::GetTransposeInfo(graph, softmaxInfo, transposeInfo); + EXPECT_EQ(ret, SUCCESS); + + node->set_op("Softmax"); + node->set_name("Softmax"); + node->add_input("Softmax"); + ret = ge::GetTransposeInfo(graph, softmaxInfo, transposeInfo); + EXPECT_EQ(ret, SUCCESS); + + delete info.node_def; + delete info.nextNodeDef; + delete graph; +} + } // namespace ge