From b75e14d2149e08520abcfd527f854ef3085007a2 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 3 Mar 2021 17:02:23 +0800 Subject: [PATCH 01/10] Change check_supported interface. --- ge/engine_manager/dnnengine_manager.cc | 2 +- ge/generator/ge_generator.cc | 27 ++++++++++++--------- ge/graph/passes/cast_translate_pass.cc | 7 +++--- ge/graph/passes/cast_translate_pass.h | 2 +- ge/graph/passes/compile_nodes_pass.cc | 13 +++------- ge/graph/passes/compile_nodes_pass.h | 2 +- ge/graph/passes/transpose_transdata_pass.cc | 12 +++++---- ge/graph/passes/transpose_transdata_pass.h | 4 +-- 8 files changed, 36 insertions(+), 33 deletions(-) diff --git a/ge/engine_manager/dnnengine_manager.cc b/ge/engine_manager/dnnengine_manager.cc index b23993b6..7ff5ed42 100644 --- a/ge/engine_manager/dnnengine_manager.cc +++ b/ge/engine_manager/dnnengine_manager.cc @@ -217,7 +217,7 @@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) { std::string unsupported_reason; // It will be replaced by engine' checksupport uint64_t start_time = GetCurrentTimestamp(); - if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { + if (kernel_info_store->second->CheckSupported(node_ptr, unsupported_reason)) { checksupport_cost_[kernel_name] += GetCurrentTimestamp() - start_time; op_desc->SetOpEngineName(it.engine); op_desc->SetOpKernelLibName(kernel_name); diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 32d9e5a1..6023b1fb 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -66,7 +66,8 @@ bool ContainsDynamicInpus(const ge::OpDesc &op_desc) { } // namespace namespace ge { -static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engine_type) { +static Status CheckEngineTypeSupport(const Nodeptr &node, OpEngineType engine_type) { + const OpDescPtr &op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); if (engine_type == ENGINE_SYS) { GELOGI("CheckEngineType: use default engine."); @@ -123,7 +124,7 @@ static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engi auto kernel_info_store = kernel_map.find(kernel_name); if (kernel_info_store != kernel_map.end()) { std::string unsupported_reason; - if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { + if (kernel_info_store->second->CheckSupported(node, unsupported_reason)) { op_desc->SetOpEngineName(op_engine_name); op_desc->SetOpKernelLibName(kernel_name); GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), @@ -692,22 +693,26 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc); GE_CHECK_NOTNULL(op_desc_tmp); - // 1. check engine type when compile online + // 1. Create ComputeGraph. + string name = ge::CurrentTimeInStr() + "_" + model_file_name; + Graph graph; + if (BuildSingleOpGraph(op_desc, inputs, outputs, name, graph) != ge::SUCCESS) { + GELOGE(GRAPH_FAILED, "make graph fail."); + return GRAPH_FAILED; + } + + // 2. check engine type when compile online if (model_file_name == kFileNameSuffix) { - Status ret = CheckEngineTypeSupport(op_desc, engine_type); + auto comp_graph = GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(comp_graph); + auto node = comp_graph->FindNode(op_desc->GetName()); + Status ret = CheckEngineTypeSupport(node, engine_type); if (ret != SUCCESS) { GELOGE(ret, "check engine type failed."); return ret; } } - // 2. Create ComputeGraph. - string name = ge::CurrentTimeInStr() + "_" + model_file_name; - Graph graph; - if (BuildSingleOpGraph(op_desc, inputs, outputs, name, graph) != ge::SUCCESS) { - GELOGE(GRAPH_FAILED, "make graph fail."); - return GRAPH_FAILED; - } GELOGI("ATC parser success in single op build."); GeRootModelPtr ge_root_model = nullptr; diff --git a/ge/graph/passes/cast_translate_pass.cc b/ge/graph/passes/cast_translate_pass.cc index 01b5c96b..2e95c19f 100644 --- a/ge/graph/passes/cast_translate_pass.cc +++ b/ge/graph/passes/cast_translate_pass.cc @@ -167,7 +167,7 @@ bool CastTranslatePass::IsOpSupportedOptimize(NodePtr &cast_node, NodePtr &trans trans_op_outdesc->SetDataType(cast_out_datatype); } - if (!TranslateCheckAccuracySupported(trans_op_desc)) { + if (!TranslateCheckAccuracySupported(trans_node)) { if (is_src_cast) { trans_op_desc->MutableInputDesc(0)->SetDataType(trans_in_datatype); } else { @@ -271,7 +271,8 @@ Status CastTranslatePass::FuseDstNTranslates(NodePtr &node) { return SUCCESS; } -bool CastTranslatePass::TranslateCheckAccuracySupported(const OpDescPtr &op_desc) { +bool CastTranslatePass::TranslateCheckAccuracySupported(NodePtr &node) { + const OpDescPtr &op_desc = node->GetOpDesc(); std::shared_ptr instance_ptr = ge::GELib::GetInstance(); if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { GELOGW("GE is not initialized or is finalized."); @@ -293,7 +294,7 @@ bool CastTranslatePass::TranslateCheckAccuracySupported(const OpDescPtr &op_desc auto kernel_info_store = kernel_map.find(kernel_name); if (kernel_info_store != kernel_map.end()) { if (kernel_info_store->second != nullptr && - kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason)) { + kernel_info_store->second->CheckAccuracySupported(node, unsupported_reason)) { return true; } } diff --git a/ge/graph/passes/cast_translate_pass.h b/ge/graph/passes/cast_translate_pass.h index 04c03d42..5c1dcd9a 100755 --- a/ge/graph/passes/cast_translate_pass.h +++ b/ge/graph/passes/cast_translate_pass.h @@ -35,7 +35,7 @@ class CastTranslatePass : public BaseNodePass { bool IsOpSupportedOptimize(NodePtr &cast_node, NodePtr &trans_node, bool &is_src_cast); bool CheckOpSupportOptimize(NodePtr &node, bool &is_src_cast); Status FuseDstNTranslates(NodePtr &node); - bool TranslateCheckAccuracySupported(const OpDescPtr &op_desc); + bool TranslateCheckAccuracySupported(NodePtr &node); }; } // namespace ge #endif // GE_GRAPH_PASSES_CAST_TRANSLATE_PASS_H_ diff --git a/ge/graph/passes/compile_nodes_pass.cc b/ge/graph/passes/compile_nodes_pass.cc index 1ed9caf0..7de7fd48 100755 --- a/ge/graph/passes/compile_nodes_pass.cc +++ b/ge/graph/passes/compile_nodes_pass.cc @@ -110,7 +110,7 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: return ge::GE_GRAPH_PARAM_NULLPTR; } // begin accuracy supported check - if (!CheckAccuracySupport(kernel_info, instance, op_desc)) { + if (!CheckAccuracySupport(kernel_info, instance, node)) { // if check accuracy support failed , try to go to other engine. GELOGD("Check Accuracy Supported return not support, node name is %s. Try to go to other engine.", op_desc->GetName().c_str()); @@ -123,7 +123,7 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: continue; } OpsKernelInfoStorePtr tmp_kernel_info = it->second; - if (CheckAccuracySupport(tmp_kernel_info, instance, op_desc)) { + if (CheckAccuracySupport(tmp_kernel_info, instance, node)) { kernel_lib_name = tmp_kernel_name; GELOGD("Find kernel lib %s support node:%s, type:%s , get kernel lib success.", tmp_kernel_name.c_str(), node->GetName().c_str(), op_desc->GetType().c_str()); @@ -138,14 +138,9 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: } bool CompileNodesPass::CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_info, - const std::shared_ptr instance, OpDescPtr &op_desc) { - auto ge_desc = MakeShared(op_desc); - if (ge_desc == nullptr) { - GELOGE(GE_GRAPH_MEMORY_ALLOC_FAILED, "Fail to malloc op desc."); - return false; - } + const std::shared_ptr instance, const NodePtr &node) { string reason; - if (!(kernel_info->CheckAccuracySupported(*ge_desc, reason, true))) { + if (!(kernel_info->CheckAccuracySupported(node, reason, true))) { return false; } return true; diff --git a/ge/graph/passes/compile_nodes_pass.h b/ge/graph/passes/compile_nodes_pass.h index e2fb59c2..e9a77e07 100644 --- a/ge/graph/passes/compile_nodes_pass.h +++ b/ge/graph/passes/compile_nodes_pass.h @@ -39,7 +39,7 @@ class CompileNodesPass : public GraphPass { private: graphStatus GetSupportedKernel(const NodePtr &node, const std::shared_ptr instance, string &kernel_lib_name); bool CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_info, const std::shared_ptr instance, - OpDescPtr &op_desc); + const NodePtr &node); graphStatus CompileNodes(const std::shared_ptr instance, std::unordered_map> &kernel_to_compile_nodes); }; diff --git a/ge/graph/passes/transpose_transdata_pass.cc b/ge/graph/passes/transpose_transdata_pass.cc index 2178eac7..0f3e7e70 100644 --- a/ge/graph/passes/transpose_transdata_pass.cc +++ b/ge/graph/passes/transpose_transdata_pass.cc @@ -86,7 +86,7 @@ Status TransposeTransDataPass::Run(NodePtr &node) { if (CheckOneInAndOneOutDataAnchor(out_node)) { return FAILED; } - if (!FusionIfNeed(op_desc, out_op_desc)) { + if (!FusionIfNeed(node, out_op_desc)) { continue; } CopyInputEdges(node, out_node); @@ -152,7 +152,8 @@ Status TransposeTransDataPass::RemoveTranspose(NodePtr &node) { return SUCCESS; } -bool TransposeTransDataPass::FusionIfNeed(OpDescPtr &op_desc, OpDescPtr &transdata_op_desc) { +bool TransposeTransDataPass::FusionIfNeed(NodePtr &node, OpDescPtr &transdata_op_desc) { + auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(transdata_op_desc); auto out_input_desc = transdata_op_desc->MutableInputDesc(0); @@ -187,7 +188,7 @@ bool TransposeTransDataPass::FusionIfNeed(OpDescPtr &op_desc, OpDescPtr &transda out_input_desc->SetFormat(src_format); out_input_desc->SetShape(src_shape); - if (!TransDataCheckAccuracySupported(transdata_op_desc)) { + if (!TransDataCheckAccuracySupported(node)) { out_input_desc->SetFormat(out_input_format); out_input_desc->SetShape(out_input_shape); return false; @@ -224,7 +225,8 @@ void TransposeTransDataPass::CopyInputEdges(NodePtr &origin_node, NodePtr &new_n GraphUtils::CopyInCtrlEdges(origin_node, new_node) != GRAPH_SUCCESS, GELOGW("Copy in ctrl edges failed"); return); } -bool TransposeTransDataPass::TransDataCheckAccuracySupported(const OpDescPtr &op_desc) { +bool TransposeTransDataPass::TransDataCheckAccuracySupported(NodePtr &node) { + const OpDescPtr &op_desc = node->GetOpDesc(); std::shared_ptr instance_ptr = ge::GELib::GetInstance(); if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { GELOGW("GELib not initialized"); @@ -244,7 +246,7 @@ bool TransposeTransDataPass::TransDataCheckAccuracySupported(const OpDescPtr &op auto &kernel_name = it.opKernelLib; auto kernel_info_store = kernel_map.find(kernel_name); if (kernel_info_store != kernel_map.end()) { - if (kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason, true)) { + if (kernel_info_store->second->CheckAccuracySupported(node, unsupported_reason, true)) { return true; } } diff --git a/ge/graph/passes/transpose_transdata_pass.h b/ge/graph/passes/transpose_transdata_pass.h index a72893f6..ce42ba39 100644 --- a/ge/graph/passes/transpose_transdata_pass.h +++ b/ge/graph/passes/transpose_transdata_pass.h @@ -26,9 +26,9 @@ class TransposeTransDataPass : public BaseNodePass { private: Status CheckOneInAndOneOutDataAnchor(NodePtr &node) const; Status RemoveTranspose(NodePtr &node); - bool FusionIfNeed(OpDescPtr &op_desc, OpDescPtr &transdata_op_desc); + bool FusionIfNeed(NodePtr &node, OpDescPtr &transdata_op_desc); void CopyInputEdges(NodePtr &origin_node, NodePtr &new_node); - bool TransDataCheckAccuracySupported(const OpDescPtr &op_desc); + bool TransDataCheckAccuracySupported(NodePtr &node); }; } // namespace ge #endif // GE_GRAPH_PASSES_TRANSPOSE_TRANSDATA_PASS_H_ From 9cbba68a7fe80fc3ad08df485adf93b4190a9b9f Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 3 Mar 2021 17:07:51 +0800 Subject: [PATCH 02/10] Change check_supported interface. --- ge/generator/ge_generator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 6023b1fb..975ae7cd 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -66,7 +66,7 @@ bool ContainsDynamicInpus(const ge::OpDesc &op_desc) { } // namespace namespace ge { -static Status CheckEngineTypeSupport(const Nodeptr &node, OpEngineType engine_type) { +static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_type) { const OpDescPtr &op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); if (engine_type == ENGINE_SYS) { From 01238f06215aabf4b00ad855f651dd1cbc5fafc4 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 3 Mar 2021 18:54:56 +0800 Subject: [PATCH 03/10] Fit ut. --- ge/generator/ge_generator.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 975ae7cd..16233ef8 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -696,10 +696,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in // 1. Create ComputeGraph. string name = ge::CurrentTimeInStr() + "_" + model_file_name; Graph graph; - if (BuildSingleOpGraph(op_desc, inputs, outputs, name, graph) != ge::SUCCESS) { - GELOGE(GRAPH_FAILED, "make graph fail."); - return GRAPH_FAILED; - } + GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), "make graph fail."); // 2. check engine type when compile online if (model_file_name == kFileNameSuffix) { From ad108e2c3bc606909047a25664ab4b1d850d5b93 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 3 Mar 2021 20:33:04 +0800 Subject: [PATCH 04/10] Add ut. --- tests/ut/ge/generator/ge_generator_unittest.cc | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tests/ut/ge/generator/ge_generator_unittest.cc b/tests/ut/ge/generator/ge_generator_unittest.cc index 09ddf2ec..8fca22fa 100644 --- a/tests/ut/ge/generator/ge_generator_unittest.cc +++ b/tests/ut/ge/generator/ge_generator_unittest.cc @@ -53,26 +53,21 @@ TEST_F(UtestGeGenerator, test_build_single_op_offline) { EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, "offline_"), GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); } -/* TEST_F(UtestGeGenerator, test_build_single_op_online) { - GeTensorDesc tensor_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); - TensorUtils::SetSize(tensor_desc, 512); + GeTensorDesc tensor_desc; shared_ptr op_desc = make_shared("Add", "add"); - EXPECT_EQ(op_desc->AddInputDesc(tensor_desc), GRAPH_SUCCESS); - EXPECT_EQ(op_desc->AddInputDesc(tensor_desc), GRAPH_SUCCESS); - EXPECT_EQ(op_desc->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); + op_desc->AddInputDesc(tensor_desc); + op_desc->AddInputDesc(tensor_desc); + op_desc->AddOutputDesc(tensor_desc); GeTensor tensor(tensor_desc); const vector inputs = { tensor, tensor }; const vector outputs = { tensor }; - // not Initialize, impl is null. GeGenerator generator; generator.Initialize({}); ModelBufferData model_buffer; - EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, ENGINE_SYS, model_buffer), GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); + EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, ENGINE_AIVECTOR, model_buffer), FAILED); } -*/ - } // namespace ge From 155198f5deabf5d0e7a26f1479646944ce45abb6 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 3 Mar 2021 20:34:39 +0800 Subject: [PATCH 05/10] Add ut. --- tests/ut/ge/generator/ge_generator_unittest.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/ut/ge/generator/ge_generator_unittest.cc b/tests/ut/ge/generator/ge_generator_unittest.cc index 8fca22fa..e66cab14 100644 --- a/tests/ut/ge/generator/ge_generator_unittest.cc +++ b/tests/ut/ge/generator/ge_generator_unittest.cc @@ -55,7 +55,6 @@ TEST_F(UtestGeGenerator, test_build_single_op_offline) { TEST_F(UtestGeGenerator, test_build_single_op_online) { GeTensorDesc tensor_desc; - shared_ptr op_desc = make_shared("Add", "add"); op_desc->AddInputDesc(tensor_desc); op_desc->AddInputDesc(tensor_desc); From 9bf5113af653812a0fbe71ed4d89bc402c91a32d Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 4 Mar 2021 11:41:00 +0800 Subject: [PATCH 06/10] Fix bug. --- ge/graph/passes/transpose_transdata_pass.cc | 6 +++--- ge/graph/passes/transpose_transdata_pass.h | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ge/graph/passes/transpose_transdata_pass.cc b/ge/graph/passes/transpose_transdata_pass.cc index 0f3e7e70..810f5639 100644 --- a/ge/graph/passes/transpose_transdata_pass.cc +++ b/ge/graph/passes/transpose_transdata_pass.cc @@ -86,7 +86,7 @@ Status TransposeTransDataPass::Run(NodePtr &node) { if (CheckOneInAndOneOutDataAnchor(out_node)) { return FAILED; } - if (!FusionIfNeed(node, out_op_desc)) { + if (!FusionIfNeed(op_desc, out_node)) { continue; } CopyInputEdges(node, out_node); @@ -152,8 +152,8 @@ Status TransposeTransDataPass::RemoveTranspose(NodePtr &node) { return SUCCESS; } -bool TransposeTransDataPass::FusionIfNeed(NodePtr &node, OpDescPtr &transdata_op_desc) { - auto op_desc = node->GetOpDesc(); +bool TransposeTransDataPass::FusionIfNeed(OpDescPtr &op_desc, NodePtr &node) { + auto transdata_op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(transdata_op_desc); auto out_input_desc = transdata_op_desc->MutableInputDesc(0); diff --git a/ge/graph/passes/transpose_transdata_pass.h b/ge/graph/passes/transpose_transdata_pass.h index ce42ba39..c6ef0b36 100644 --- a/ge/graph/passes/transpose_transdata_pass.h +++ b/ge/graph/passes/transpose_transdata_pass.h @@ -26,7 +26,7 @@ class TransposeTransDataPass : public BaseNodePass { private: Status CheckOneInAndOneOutDataAnchor(NodePtr &node) const; Status RemoveTranspose(NodePtr &node); - bool FusionIfNeed(NodePtr &node, OpDescPtr &transdata_op_desc); + bool FusionIfNeed(OpDescPtr &op_desc, NodePtr &node); void CopyInputEdges(NodePtr &origin_node, NodePtr &new_node); bool TransDataCheckAccuracySupported(NodePtr &node); }; From 44fc429ef5ef1e65dcb767301ef624929adb4be7 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 4 Mar 2021 17:16:11 +0800 Subject: [PATCH 07/10] Add ut. --- tests/ut/ge/CMakeLists.txt | 1 + .../transpose_transdata_pass_unittest.cc | 67 +++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 1df848d5..fc5383c3 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -689,6 +689,7 @@ set(PASS_TEST_FILES "graph/passes/no_use_reshape_remove_pass_unittest.cc" "graph/passes/infershape_pass_unittest.cc" "graph/passes/multi_batch_clone_pass_unittest.cc" + "graph/passes/transpose_transdata_pass_unittest.cc" ) set(KERNEL_TEST_FILES diff --git a/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc b/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc new file mode 100644 index 00000000..08fdca57 --- /dev/null +++ b/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#define protected public +#define private public +#include "graph/passes/transpose_transdata_pass.h" +#include "graph_builder_utils.h" +#undef private +#undef protected + +#include "graph/graph.h" +#include "common/ge_inner_error_codes.h" +#include "common/types.h" +#include "graph/debug/ge_attr_define.h" + +namespace ge { +class UtestGraphTransposeTransDataPass : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +static ComputeGraphPtr BuildGraphTranposeD() { + auto builder = ut::GraphBuilder("g1"); + auto transdata1 = builder.AddNode("transdata1", "TransData", 1, 1, FORMAT_NC1HWC0, DT_FLOAT, std::vector({1, 1, 224, 224, 16})); + transdata1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); + transdata1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector({1, 1, 224, 224, 3}))); + + auto transpose1 = builder.AddNode("transpose1", "TransposeD", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector({1, 3, 224, 224})); + transpose1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); + transpose1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector({1, 1, 224, 224, 3}))); + + auto transdata2 = builder.AddNode("transdata2", "TransData", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector({1, 3, 224, 224})); + transdata2->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NC1HWC0); + transdata2->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector({1, 1, 224, 224, 16}))); + + builder.AddDataEdge(transdata1, 0, transpose1, 0); + builder.AddDataEdge(transpose1, 0, transdata2, 0); + + return builder.GetGraph(); +} + +TEST_F(UtestGraphTransposeTransDataPass, test_run) { + auto compute_graph = BuildGraphTranposeD(); + compute_graph->SetSessionID(0); + + auto transpose = compute_graph->FindNode("transpose1"); + TransposeTransdataPass pass; + EXPECT_EQ(pass.Run(transpose), SUCCESS); +} +} // namespace ge From 1c6bceb7d4d2402904c07ccfd881b23721eb5944 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 4 Mar 2021 17:18:27 +0800 Subject: [PATCH 08/10] Add ut. --- tests/ut/ge/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index fc5383c3..169ad0b8 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -671,6 +671,7 @@ set(PASS_TEST_FILES "graph/passes/trans_op_breadth_fusion_pass_unittest.cc" "graph/passes/trans_op_depth_fusion_pass_unittest.cc" "graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc" + "graph/passes/transpose_transdata_pass_unittest.cc" "graph/passes/constant_folding_pass_unittest.cc" "graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc" "graph/passes/stop_gradient_pass_unittest.cc" @@ -689,7 +690,6 @@ set(PASS_TEST_FILES "graph/passes/no_use_reshape_remove_pass_unittest.cc" "graph/passes/infershape_pass_unittest.cc" "graph/passes/multi_batch_clone_pass_unittest.cc" - "graph/passes/transpose_transdata_pass_unittest.cc" ) set(KERNEL_TEST_FILES From 7d1ea4564c43e1ecf91bc6ecb5f195e0a11d5935 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 4 Mar 2021 17:24:47 +0800 Subject: [PATCH 09/10] Add ut. --- .../passes/transpose_transdata_pass_unittest.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc b/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc index 08fdca57..7fc5973a 100644 --- a/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc @@ -30,21 +30,21 @@ #include "graph/debug/ge_attr_define.h" namespace ge { -class UtestGraphTransposeTransDataPass : public testing::Test { +class UtestGraphPassesTransposeTransDataPass : public testing::Test { protected: void SetUp() {} void TearDown() {} }; -static ComputeGraphPtr BuildGraphTranposeD() { +static ComputeGraphPtr BuildGraphTransposeD() { auto builder = ut::GraphBuilder("g1"); auto transdata1 = builder.AddNode("transdata1", "TransData", 1, 1, FORMAT_NC1HWC0, DT_FLOAT, std::vector({1, 1, 224, 224, 16})); transdata1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); transdata1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector({1, 1, 224, 224, 3}))); auto transpose1 = builder.AddNode("transpose1", "TransposeD", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector({1, 3, 224, 224})); - transpose1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); - transpose1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector({1, 1, 224, 224, 3}))); + transpose1->GetOpDesc()->MutableIntputDesc(0)->SetFormat(FORMAT_NHWC); + transpose1->GetOpDesc()->MutableIntputDesc(0)->SetShape(GeShape(std::vector({1, 1, 224, 224, 3}))); auto transdata2 = builder.AddNode("transdata2", "TransData", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector({1, 3, 224, 224})); transdata2->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NC1HWC0); @@ -56,12 +56,12 @@ static ComputeGraphPtr BuildGraphTranposeD() { return builder.GetGraph(); } -TEST_F(UtestGraphTransposeTransDataPass, test_run) { - auto compute_graph = BuildGraphTranposeD(); +TEST_F(UtestGraphPassesTransposeTransDataPass, test_run) { + auto compute_graph = BuildGraphTransposeD(); compute_graph->SetSessionID(0); auto transpose = compute_graph->FindNode("transpose1"); - TransposeTransdataPass pass; + TransposeTransDataPass pass; EXPECT_EQ(pass.Run(transpose), SUCCESS); } } // namespace ge From f6ba21ed1d2cc8c5ad49a22c9b094f6f30c70c41 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 4 Mar 2021 17:28:24 +0800 Subject: [PATCH 10/10] Add ut. --- .../graph/passes/transpose_transdata_pass_unittest.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc b/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc index 7fc5973a..07919dc6 100644 --- a/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/transpose_transdata_pass_unittest.cc @@ -30,7 +30,7 @@ #include "graph/debug/ge_attr_define.h" namespace ge { -class UtestGraphPassesTransposeTransDataPass : public testing::Test { +class UtestGraphPassesTransposeTransdataPass : public testing::Test { protected: void SetUp() {} void TearDown() {} @@ -40,11 +40,11 @@ static ComputeGraphPtr BuildGraphTransposeD() { auto builder = ut::GraphBuilder("g1"); auto transdata1 = builder.AddNode("transdata1", "TransData", 1, 1, FORMAT_NC1HWC0, DT_FLOAT, std::vector({1, 1, 224, 224, 16})); transdata1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); - transdata1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector({1, 1, 224, 224, 3}))); + transdata1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector({1, 224, 224, 3}))); auto transpose1 = builder.AddNode("transpose1", "TransposeD", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector({1, 3, 224, 224})); - transpose1->GetOpDesc()->MutableIntputDesc(0)->SetFormat(FORMAT_NHWC); - transpose1->GetOpDesc()->MutableIntputDesc(0)->SetShape(GeShape(std::vector({1, 1, 224, 224, 3}))); + transpose1->GetOpDesc()->MutableInputDesc(0)->SetFormat(FORMAT_NHWC); + transpose1->GetOpDesc()->MutableInputDesc(0)->SetShape(GeShape(std::vector({1, 224, 224, 3}))); auto transdata2 = builder.AddNode("transdata2", "TransData", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector({1, 3, 224, 224})); transdata2->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NC1HWC0); @@ -56,7 +56,7 @@ static ComputeGraphPtr BuildGraphTransposeD() { return builder.GetGraph(); } -TEST_F(UtestGraphPassesTransposeTransDataPass, test_run) { +TEST_F(UtestGraphPassesTransposeTransdataPass, test_run) { auto compute_graph = BuildGraphTransposeD(); compute_graph->SetSessionID(0);