| @@ -76,6 +76,7 @@ | |||
| #include "parser/tensorflow/graph_optimizer.h" | |||
| #include "metadef/inc/register/scope/scope_pass_registry_impl.h" | |||
| #include "register/scope/scope_fusion_pass_register.h" | |||
| #include "common/op_map.h" | |||
| #undef protected | |||
| #undef private | |||
| @@ -90,6 +91,15 @@ using namespace google::protobuf; | |||
| static const string GRAPH_DEFAULT_NAME = "default"; | |||
| namespace ge { | |||
| struct DelTransposeInfo { | |||
| domi::tensorflow::NodeDef *node_def; // transpose | |||
| domi::tensorflow::NodeDef *nextNodeDef; // transpose --> [next] | |||
| int inputIdx; | |||
| }; | |||
| Status GetTransposeInfo(GraphDef *graph_def, std::map<std::string, std::string> &softmaxInfo, | |||
| std::map<std::string, DelTransposeInfo> &transposeInfo); | |||
| class UtestTensorflowParser : public testing::Test { | |||
| protected: | |||
| void SetUp() { | |||
| @@ -161,12 +171,6 @@ void UtestTensorflowParser::RegisterCustomOp() { | |||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | |||
| } | |||
| struct DelTransposeInfo { | |||
| domi::tensorflow::NodeDef *node_def; // transpose | |||
| domi::tensorflow::NodeDef *nextNodeDef; // transpose --> [next] | |||
| int inputIdx; | |||
| }; | |||
| namespace { | |||
| NodeDef* AddNode(GraphDef& graph, string type, string name) { | |||
| NodeDef* nodeDef = graph.add_node(); | |||
| @@ -4492,6 +4496,8 @@ TEST_F(UtestTensorflowParser, tensorflow_EraseNormalOpOutputIfChild) | |||
| ret = modelParser.EraseNormalOpOutputIfChild(scope_graph, op_node_name, normal_op_node_context); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| delete node; | |||
| } | |||
| TEST_F(UtestTensorflowParser, tensorflow_UpdateNormalOpContext) | |||
| @@ -4552,4 +4558,64 @@ TEST_F(UtestTensorflowParser, tensorflow_OptimizeTranspose) | |||
| delete info.nextNodeDef; | |||
| } | |||
| TEST_F(UtestTensorflowParser, tensorflow_SoftmaxAddAttr) | |||
| { | |||
| TensorFlowModelParser modelParser; | |||
| domi::tensorflow::GraphDef graph_def; | |||
| graph_def.add_node(); | |||
| modelParser.SoftmaxAddAttr(&graph_def); | |||
| } | |||
| TEST_F(UtestTensorflowParser, tensorflow_InferInputFormats) | |||
| { | |||
| domiTensorFormat_t ret; | |||
| TensorFlowModelParser modelParser; | |||
| GetParserContext().format = DOMI_TENSOR_RESERVED; | |||
| NodeDef *node = MallocNodeDef("node", "DATA"); | |||
| modelParser.nodedef_map_["node"] = node; | |||
| tensorflow_op_map["DATA"] = "node"; | |||
| ret = modelParser.InferInputFormats(); | |||
| EXPECT_EQ(ret, domi::DOMI_TENSOR_NHWC); | |||
| delete node; | |||
| NodeDef* node1 = nullptr; | |||
| modelParser.nodedef_map_["node"] = node1; | |||
| ret = modelParser.InferInputFormats(); | |||
| EXPECT_EQ(ret, domi::DOMI_TENSOR_RESERVED); | |||
| } | |||
| TEST_F(UtestTensorflowParser, tensorflow_GetTransposeInfo) | |||
| { | |||
| Status ret; | |||
| DelTransposeInfo info; | |||
| tensorflow::GraphDef *graph = new tensorflow::GraphDef(); | |||
| std::map<std::string, std::string> softmaxInfo = {{"ge", "ge"}}; | |||
| info.node_def = new NodeDef(); | |||
| info.nextNodeDef = new NodeDef(); | |||
| info.node_def->add_input("ge"); | |||
| info.nextNodeDef->add_input("ge"); | |||
| info.inputIdx = 0; | |||
| NodeDef *node = graph->add_node(); | |||
| node->set_op("Transpose"); | |||
| std::map<std::string, DelTransposeInfo> transposeInfo = {{"Softmax", info}}; | |||
| ret = ge::GetTransposeInfo(graph, softmaxInfo, transposeInfo); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| node->set_op("Softmax"); | |||
| node->set_name("Softmax"); | |||
| node->add_input("Softmax"); | |||
| ret = ge::GetTransposeInfo(graph, softmaxInfo, transposeInfo); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| delete info.node_def; | |||
| delete info.nextNodeDef; | |||
| delete graph; | |||
| } | |||
| } // namespace ge | |||