|
|
|
@@ -30,21 +30,21 @@ |
|
|
|
#include "graph/debug/ge_attr_define.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
class UtestGraphTransposeTransDataPass : public testing::Test { |
|
|
|
class UtestGraphPassesTransposeTransDataPass : public testing::Test { |
|
|
|
protected: |
|
|
|
void SetUp() {} |
|
|
|
void TearDown() {} |
|
|
|
}; |
|
|
|
|
|
|
|
static ComputeGraphPtr BuildGraphTranposeD() { |
|
|
|
static ComputeGraphPtr BuildGraphTransposeD() { |
|
|
|
auto builder = ut::GraphBuilder("g1"); |
|
|
|
auto transdata1 = builder.AddNode("transdata1", "TransData", 1, 1, FORMAT_NC1HWC0, DT_FLOAT, std::vector<int64_t>({1, 1, 224, 224, 16})); |
|
|
|
transdata1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); |
|
|
|
transdata1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 1, 224, 224, 3}))); |
|
|
|
|
|
|
|
auto transpose1 = builder.AddNode("transpose1", "TransposeD", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({1, 3, 224, 224})); |
|
|
|
transpose1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); |
|
|
|
transpose1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 1, 224, 224, 3}))); |
|
|
|
transpose1->GetOpDesc()->MutableIntputDesc(0)->SetFormat(FORMAT_NHWC); |
|
|
|
transpose1->GetOpDesc()->MutableIntputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 1, 224, 224, 3}))); |
|
|
|
|
|
|
|
auto transdata2 = builder.AddNode("transdata2", "TransData", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({1, 3, 224, 224})); |
|
|
|
transdata2->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NC1HWC0); |
|
|
|
@@ -56,12 +56,12 @@ static ComputeGraphPtr BuildGraphTranposeD() { |
|
|
|
return builder.GetGraph(); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(UtestGraphTransposeTransDataPass, test_run) { |
|
|
|
auto compute_graph = BuildGraphTranposeD(); |
|
|
|
TEST_F(UtestGraphPassesTransposeTransDataPass, test_run) { |
|
|
|
auto compute_graph = BuildGraphTransposeD(); |
|
|
|
compute_graph->SetSessionID(0); |
|
|
|
|
|
|
|
auto transpose = compute_graph->FindNode("transpose1"); |
|
|
|
TransposeTransdataPass pass; |
|
|
|
TransposeTransDataPass pass; |
|
|
|
EXPECT_EQ(pass.Run(transpose), SUCCESS); |
|
|
|
} |
|
|
|
} // namespace ge |