| @@ -30,21 +30,21 @@ | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| namespace ge { | namespace ge { | ||||
| class UtestGraphTransposeTransDataPass : public testing::Test { | |||||
| class UtestGraphPassesTransposeTransDataPass : public testing::Test { | |||||
| protected: | protected: | ||||
| void SetUp() {} | void SetUp() {} | ||||
| void TearDown() {} | void TearDown() {} | ||||
| }; | }; | ||||
| static ComputeGraphPtr BuildGraphTranposeD() { | |||||
| static ComputeGraphPtr BuildGraphTransposeD() { | |||||
| auto builder = ut::GraphBuilder("g1"); | auto builder = ut::GraphBuilder("g1"); | ||||
| auto transdata1 = builder.AddNode("transdata1", "TransData", 1, 1, FORMAT_NC1HWC0, DT_FLOAT, std::vector<int64_t>({1, 1, 224, 224, 16})); | auto transdata1 = builder.AddNode("transdata1", "TransData", 1, 1, FORMAT_NC1HWC0, DT_FLOAT, std::vector<int64_t>({1, 1, 224, 224, 16})); | ||||
| transdata1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); | transdata1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); | ||||
| transdata1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 1, 224, 224, 3}))); | transdata1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 1, 224, 224, 3}))); | ||||
| auto transpose1 = builder.AddNode("transpose1", "TransposeD", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({1, 3, 224, 224})); | auto transpose1 = builder.AddNode("transpose1", "TransposeD", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({1, 3, 224, 224})); | ||||
| transpose1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); | |||||
| transpose1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 1, 224, 224, 3}))); | |||||
| transpose1->GetOpDesc()->MutableIntputDesc(0)->SetFormat(FORMAT_NHWC); | |||||
| transpose1->GetOpDesc()->MutableIntputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 1, 224, 224, 3}))); | |||||
| auto transdata2 = builder.AddNode("transdata2", "TransData", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({1, 3, 224, 224})); | auto transdata2 = builder.AddNode("transdata2", "TransData", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({1, 3, 224, 224})); | ||||
| transdata2->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NC1HWC0); | transdata2->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NC1HWC0); | ||||
| @@ -56,12 +56,12 @@ static ComputeGraphPtr BuildGraphTranposeD() { | |||||
| return builder.GetGraph(); | return builder.GetGraph(); | ||||
| } | } | ||||
| TEST_F(UtestGraphTransposeTransDataPass, test_run) { | |||||
| auto compute_graph = BuildGraphTranposeD(); | |||||
| TEST_F(UtestGraphPassesTransposeTransDataPass, test_run) { | |||||
| auto compute_graph = BuildGraphTransposeD(); | |||||
| compute_graph->SetSessionID(0); | compute_graph->SetSessionID(0); | ||||
| auto transpose = compute_graph->FindNode("transpose1"); | auto transpose = compute_graph->FindNode("transpose1"); | ||||
| TransposeTransdataPass pass; | |||||
| TransposeTransDataPass pass; | |||||
| EXPECT_EQ(pass.Run(transpose), SUCCESS); | EXPECT_EQ(pass.Run(transpose), SUCCESS); | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||