| @@ -76,6 +76,7 @@ | |||||
| #include "parser/tensorflow/graph_optimizer.h" | #include "parser/tensorflow/graph_optimizer.h" | ||||
| #include "metadef/inc/register/scope/scope_pass_registry_impl.h" | #include "metadef/inc/register/scope/scope_pass_registry_impl.h" | ||||
| #include "register/scope/scope_fusion_pass_register.h" | #include "register/scope/scope_fusion_pass_register.h" | ||||
| #include "common/op_map.h" | |||||
| #undef protected | #undef protected | ||||
| #undef private | #undef private | ||||
| @@ -90,6 +91,15 @@ using namespace google::protobuf; | |||||
| static const string GRAPH_DEFAULT_NAME = "default"; | static const string GRAPH_DEFAULT_NAME = "default"; | ||||
| namespace ge { | namespace ge { | ||||
| struct DelTransposeInfo { | |||||
| domi::tensorflow::NodeDef *node_def; // transpose | |||||
| domi::tensorflow::NodeDef *nextNodeDef; // transpose --> [next] | |||||
| int inputIdx; | |||||
| }; | |||||
| Status GetTransposeInfo(GraphDef *graph_def, std::map<std::string, std::string> &softmaxInfo, | |||||
| std::map<std::string, DelTransposeInfo> &transposeInfo); | |||||
| class UtestTensorflowParser : public testing::Test { | class UtestTensorflowParser : public testing::Test { | ||||
| protected: | protected: | ||||
| void SetUp() { | void SetUp() { | ||||
| @@ -161,12 +171,6 @@ void UtestTensorflowParser::RegisterCustomOp() { | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
| } | } | ||||
| struct DelTransposeInfo { | |||||
| domi::tensorflow::NodeDef *node_def; // transpose | |||||
| domi::tensorflow::NodeDef *nextNodeDef; // transpose --> [next] | |||||
| int inputIdx; | |||||
| }; | |||||
| namespace { | namespace { | ||||
| NodeDef* AddNode(GraphDef& graph, string type, string name) { | NodeDef* AddNode(GraphDef& graph, string type, string name) { | ||||
| NodeDef* nodeDef = graph.add_node(); | NodeDef* nodeDef = graph.add_node(); | ||||
| @@ -4492,6 +4496,8 @@ TEST_F(UtestTensorflowParser, tensorflow_EraseNormalOpOutputIfChild) | |||||
| ret = modelParser.EraseNormalOpOutputIfChild(scope_graph, op_node_name, normal_op_node_context); | ret = modelParser.EraseNormalOpOutputIfChild(scope_graph, op_node_name, normal_op_node_context); | ||||
| EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
| delete node; | |||||
| } | } | ||||
| TEST_F(UtestTensorflowParser, tensorflow_UpdateNormalOpContext) | TEST_F(UtestTensorflowParser, tensorflow_UpdateNormalOpContext) | ||||
| @@ -4552,4 +4558,64 @@ TEST_F(UtestTensorflowParser, tensorflow_OptimizeTranspose) | |||||
| delete info.nextNodeDef; | delete info.nextNodeDef; | ||||
| } | } | ||||
| TEST_F(UtestTensorflowParser, tensorflow_SoftmaxAddAttr) | |||||
| { | |||||
| TensorFlowModelParser modelParser; | |||||
| domi::tensorflow::GraphDef graph_def; | |||||
| graph_def.add_node(); | |||||
| modelParser.SoftmaxAddAttr(&graph_def); | |||||
| } | |||||
| TEST_F(UtestTensorflowParser, tensorflow_InferInputFormats) | |||||
| { | |||||
| domiTensorFormat_t ret; | |||||
| TensorFlowModelParser modelParser; | |||||
| GetParserContext().format = DOMI_TENSOR_RESERVED; | |||||
| NodeDef *node = MallocNodeDef("node", "DATA"); | |||||
| modelParser.nodedef_map_["node"] = node; | |||||
| tensorflow_op_map["DATA"] = "node"; | |||||
| ret = modelParser.InferInputFormats(); | |||||
| EXPECT_EQ(ret, domi::DOMI_TENSOR_NHWC); | |||||
| delete node; | |||||
| NodeDef* node1 = nullptr; | |||||
| modelParser.nodedef_map_["node"] = node1; | |||||
| ret = modelParser.InferInputFormats(); | |||||
| EXPECT_EQ(ret, domi::DOMI_TENSOR_RESERVED); | |||||
| } | |||||
| TEST_F(UtestTensorflowParser, tensorflow_GetTransposeInfo) | |||||
| { | |||||
| Status ret; | |||||
| DelTransposeInfo info; | |||||
| tensorflow::GraphDef *graph = new tensorflow::GraphDef(); | |||||
| std::map<std::string, std::string> softmaxInfo = {{"ge", "ge"}}; | |||||
| info.node_def = new NodeDef(); | |||||
| info.nextNodeDef = new NodeDef(); | |||||
| info.node_def->add_input("ge"); | |||||
| info.nextNodeDef->add_input("ge"); | |||||
| info.inputIdx = 0; | |||||
| NodeDef *node = graph->add_node(); | |||||
| node->set_op("Transpose"); | |||||
| std::map<std::string, DelTransposeInfo> transposeInfo = {{"Softmax", info}}; | |||||
| ret = ge::GetTransposeInfo(graph, softmaxInfo, transposeInfo); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| node->set_op("Softmax"); | |||||
| node->set_name("Softmax"); | |||||
| node->add_input("Softmax"); | |||||
| ret = ge::GetTransposeInfo(graph, softmaxInfo, transposeInfo); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| delete info.node_def; | |||||
| delete info.nextNodeDef; | |||||
| delete graph; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||