From 44fc429ef5ef1e65dcb767301ef624929adb4be7 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 4 Mar 2021 17:16:11 +0800 Subject: [PATCH] Add ut. --- tests/ut/ge/CMakeLists.txt | 1 + .../transpose_transdata_pass_unittest.cc | 67 +++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 1df848d5..fc5383c3 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -689,6 +689,7 @@ set(PASS_TEST_FILES "graph/passes/no_use_reshape_remove_pass_unittest.cc" "graph/passes/infershape_pass_unittest.cc" "graph/passes/multi_batch_clone_pass_unittest.cc" + "graph/passes/transpose_transdata_pass_unittest.cc" ) set(KERNEL_TEST_FILES diff --git a/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc b/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc new file mode 100644 index 00000000..08fdca57 --- /dev/null +++ b/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#define protected public +#define private public +#include "graph/passes/transpose_transdata_pass.h" +#include "graph_builder_utils.h" +#undef private +#undef protected + +#include "graph/graph.h" +#include "common/ge_inner_error_codes.h" +#include "common/types.h" +#include "graph/debug/ge_attr_define.h" + +namespace ge { +class UtestGraphTransposeTransDataPass : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +static ComputeGraphPtr BuildGraphTranposeD() { + auto builder = ut::GraphBuilder("g1"); + auto transdata1 = builder.AddNode("transdata1", "TransData", 1, 1, FORMAT_NC1HWC0, DT_FLOAT, std::vector({1, 1, 224, 224, 16})); + transdata1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); + transdata1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector({1, 1, 224, 224, 3}))); + + auto transpose1 = builder.AddNode("transpose1", "TransposeD", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector({1, 3, 224, 224})); + transpose1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); + transpose1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector({1, 1, 224, 224, 3}))); + + auto transdata2 = builder.AddNode("transdata2", "TransData", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector({1, 3, 224, 224})); + transdata2->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NC1HWC0); + transdata2->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector({1, 1, 224, 224, 16}))); + + builder.AddDataEdge(transdata1, 0, transpose1, 0); + builder.AddDataEdge(transpose1, 0, transdata2, 0); + + return builder.GetGraph(); +} + +TEST_F(UtestGraphTransposeTransDataPass, test_run) { + auto compute_graph = BuildGraphTranposeD(); + compute_graph->SetSessionID(0); + + auto transpose = compute_graph->FindNode("transpose1"); + TransposeTransdataPass pass; + EXPECT_EQ(pass.Run(transpose), SUCCESS); +} +} // namespace ge