diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index db316ffa..a7fd2794 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -214,6 +214,7 @@ set(TRAIN_SRC_LIST "graph/passes/inplace_support_check_pass.cc" "graph/passes/flow_ctrl_pass.cc" "graph/passes/global_step_insert_pass.cc" + "graph/passes/parallel_group_pass.cc" "host_kernels/transpose_kernel.cc" "host_kernels/add_kernel.cc" "host_kernels/broadcast_args_kernel.cc" @@ -606,6 +607,7 @@ set(INFER_SRC_LIST "graph/passes/hccl_group_pass.cc" "graph/passes/memcpy_addr_async_pass.cc" "graph/passes/set_input_output_offset_pass.cc" + "graph/passes/parallel_group_pass.cc" "graph/manager/model_manager/event_manager.cc" "graph/manager/util/rt_context_util.cc" "graph/manager/util/variable_accelerate_ctrl.cc" diff --git a/ge/graph/build/logical_stream_allocator.cc b/ge/graph/build/logical_stream_allocator.cc index 8ea7fe71..7dc2f6c7 100644 --- a/ge/graph/build/logical_stream_allocator.cc +++ b/ge/graph/build/logical_stream_allocator.cc @@ -366,6 +366,48 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { + std::map> stream_op_map; + for (const SubgraphPtr &subgraph : subgraphs) { + auto compute_graph = subgraph->subgraph_info.GetSubGraph(); + for (NodePtr &node : compute_graph->GetDirectNode()) { + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc->HasAttr(ATTR_NAME_PARALLEL_GROUP)) { + int64_t op_desc_stream_id = op_desc->GetStreamId(); + stream_op_map[op_desc_stream_id].push_back(op_desc); + } + } + } + for (const auto &itr : stream_op_map) { + if (itr.first == kInvalidStream) { + continue; + } + std::map> group_op; + for (const auto &op_desc : itr.second) { + int group_id; + if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id)) { + GELOGE(GRAPH_FAILED, "Get node %s ATTR_NAME_PARALLEL_GROUP failed.", op_desc->GetName().c_str()); + return GRAPH_FAILED; + } + group_op[group_id].emplace_back(op_desc); + } + for (const auto &itr : group_op) { + const auto &op_vec = itr.second; + if (op_vec.empty()) { + continue; + } + int64_t new_stream_id = context.next_stream++; + for (const auto &op : op_vec) { + int64_t old_stream_id = op->GetStreamId(); + op->SetStreamId(new_stream_id); + GELOGI("chenhua Node %s assigned stream %ld from stream %ld.", op->GetName().c_str(), new_stream_id, + old_stream_id); + } + } + } + return SUCCESS; +} + int64_t UpdateForSkippedEnginePass::GetSingleInoutStream(const NodePtr &node) const { set stream_ids; @@ -655,6 +697,7 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); + passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); } diff --git a/ge/graph/build/logical_stream_allocator.h b/ge/graph/build/logical_stream_allocator.h index b9aec611..2a94c254 100644 --- a/ge/graph/build/logical_stream_allocator.h +++ b/ge/graph/build/logical_stream_allocator.h @@ -149,6 +149,13 @@ class NodeStreamUpdatePass : public LogicalStreamPass { Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; }; +// assign stream by parallel group +class UpdateForParallelGroupPass : public LogicalStreamPass { + public: + STREAM_PASS_DEFAULT_FUNC(UpdateForParallelGroupPass); + Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; +}; + // Update the stream of subgraphs to nodes. class UpdateForSkippedEnginePass : public LogicalStreamPass { public: diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index a57f0e61..1dfc307b 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -92,6 +92,7 @@ #include "graph/passes/global_step_insert_pass.h" #include "graph/passes/memcpy_addr_async_pass.h" #include "graph/passes/hccl_continuous_memcpy_pass.h" +#include "graph/passes/parallel_group_pass.h" #include "graph/build/label_allocator.h" #include "graph/utils/tensor_adapter.h" #include "inc/pass_manager.h" @@ -2203,6 +2204,7 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { GE_DUMP(compute_graph, "OptimizeStage1_2"); PassManager graph_pass; // the prune pass should between SwitchPass and SwitchToStreamSwitchPass + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ParallelGroupPass", new (std::nothrow) ParallelGroupPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)) diff --git a/ge/graph/passes/parallel_group_pass.cc b/ge/graph/passes/parallel_group_pass.cc new file mode 100644 index 00000000..b3225dd0 --- /dev/null +++ b/ge/graph/passes/parallel_group_pass.cc @@ -0,0 +1,307 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/parallel_group_pass.h" + +#include "framework/common/debug/ge_log.h" +#include "common/ge/ge_util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" + +namespace ge { + +Status ParallelGroupPass::Run(ComputeGraphPtr graph) { + GELOGI("ParallelGroupPass running"); + if (graph == nullptr) { + GELOGE(PARAM_INVALID, "Input param graph is null, skip optimize parallel group graph."); + return PARAM_INVALID; + } + + if (graph->GetParentGraph() != nullptr) { + GELOGW("Current graph %s is a subgraph, this pass only support root graph.", + graph->GetName().c_str()); + return GRAPH_SUCCESS; + } + + std::unordered_set parallel_group; + if (ProcessAllGraph(graph, parallel_group) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Process graph %s failed.", graph->GetName().c_str()); + return INTERNAL_ERROR; + } + + if (graph->TopologicalSorting() != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Graph topological sort failed."); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +Status ParallelGroupPass::ProcessAllGraph(ComputeGraphPtr graph, std::unordered_set ¶llel_group) { + if (graph == nullptr) { + GELOGE(PARAM_INVALID, "Input param graph is null, skip optimize parallel group subgraph."); + return PARAM_INVALID; + } + + std::map, NodePtr>> node_2_switch_merge; + if (ProcessSwitch(graph, node_2_switch_merge) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Input param graph is null, skip optimize parallel group subgraph."); + return GRAPH_FAILED; + } + + std::map> group_node; + const auto &candidates = graph->GetDirectNode(); + for (const auto &node : candidates) { + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + int group_id = -1; + if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id)) { + bool is_unknown_shape = false; + auto ret = ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get node[%s] shape status failed!", node->GetName().c_str()); + return GRAPH_FAILED; + } + // only handle known shape node + if (!is_unknown_shape) { + group_node[group_id].push_back(node); + parallel_group.insert(group_id); + GELOGI("Find hccl node:%s, group_id=%d", op_desc->GetName().c_str(), group_id); + } + } + + const auto &subgraph_name = op_desc->GetSubgraphInstanceNames(); + for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) { + const auto sub_graph = graph->GetSubgraph(*name_iter); + if (sub_graph != nullptr) { + std::unordered_set sub_parallel_group; + auto ret = ProcessAllGraph(sub_graph, sub_parallel_group); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Process sub graph %s failed.", sub_graph->GetName().c_str()); + return GRAPH_FAILED; + } + for (const auto &it : sub_parallel_group) { + parallel_group.insert(it); + group_node[it].emplace_back(node); + } + } + } + } + + for (const auto &itr : group_node) { + const auto &node_vec = itr.second; + if (node_vec.empty()) { + continue; + } + NodePtr pre_node = node_vec[0]; + NodePtr cur_node = nullptr; + for (int i = 1; i < node_vec.size(); i++) { + cur_node = node_vec[i]; + auto tmp_pre_node = pre_node; + auto tmp_cur_node = cur_node; + GELOGI("original we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); + ReplaceSwitchAndMerge(tmp_pre_node, tmp_cur_node, node_2_switch_merge); + pre_node = cur_node; + } + } + + return GRAPH_SUCCESS; +} + +void ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) { + if (pre_node == cur_node) { + GELOGI("--- pr_node == cur_node"); + return; + } + const auto &in_node = cur_node->GetInAllNodes(); + for (const auto &node : in_node) { + if (pre_node == node) { + GELOGI("--- pr_node and cur_node have linked"); + return; + } + } + auto pre_out_ctrl_anchor = pre_node->GetOutControlAnchor(); + auto cur_in_ctrl_anchor = cur_node->GetInControlAnchor(); + pre_out_ctrl_anchor->LinkTo(cur_in_ctrl_anchor); +} + +Status ParallelGroupPass::ProcessSwitch(ComputeGraphPtr graph, + std::map, + NodePtr>> &node_2_switch_merge) { + std::string type; + const auto &direct_node = graph->GetDirectNode(); + for (const auto &node : direct_node) { + auto ret = GetOriginalType(node, type); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get node %s type failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + if ((type != SWITCH) && (type != REFSWITCH)) { + continue; + } + InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); + if (in_cond_anchor == nullptr) { + GELOGE(INTERNAL_ERROR, "in_cond_anchor is nullptr, node: %s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + OutDataAnchorPtr pred_cond_anchor = in_cond_anchor->GetPeerOutAnchor(); + if (pred_cond_anchor == nullptr) { + GELOGE(INTERNAL_ERROR, "pred_cond_anchor is nullptr, node: %s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + // ignore while op + if (pred_cond_anchor->GetOwnerNode()->GetType() == LOOPCOND) { + continue; + } + std::deque candidates; + std::vector merge_vec; + std::set group_set; + std::set visited; + candidates.emplace_back(node); + while (!candidates.empty()) { + NodePtr tmp_node = candidates.front(); + candidates.pop_front(); + for (const auto &out_node : tmp_node->GetOutAllNodes()) { + ret = GetOriginalType(out_node, type); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get node %s type failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + if ((type == MERGE) || (type == REFMERGE)) { + merge_vec.emplace_back(out_node); + continue; + } + const auto op = out_node->GetOpDesc(); + if (op != nullptr && op->HasAttr(ATTR_NAME_PARALLEL_GROUP)) { + group_set.emplace(out_node); + } + if (visited.count(out_node) > 0) { + continue; + } + candidates.push_back(out_node); + visited.insert(out_node); + } + } + std::sort(merge_vec.begin(), merge_vec.end(), + [&direct_node] (NodePtr a, NodePtr b) -> bool { + return std::find(direct_node.begin(), direct_node.end(), a) < + std::find(direct_node.begin(), direct_node.end(), b); + }); + for (const auto &group_node : group_set) { + auto it = node_2_switch_merge.find(group_node); + if (it != node_2_switch_merge.end()) { + auto &tmp = it->second; + auto &switch_vec = tmp.first; + const auto &merge_node = tmp.second; + GELOGI(" --- hccl node: %s, switch node %s, merge node :%s.", + group_node->GetName().c_str(), node->GetName().c_str(), merge_node->GetName().c_str()); + if (merge_node != merge_vec.back()) { + GELOGE(GRAPH_FAILED, "error: has two merge node: %s and %s.", + merge_node->GetName().c_str(), merge_vec.back()->GetName().c_str()); + return GRAPH_FAILED; + } + switch_vec.insert(node); + } else { + node_2_switch_merge.emplace(group_node, std::make_pair(std::set{node}, merge_vec.back())); + } + } + } + + return GRAPH_SUCCESS; +} + +Status ParallelGroupPass::GetOriginalType(const NodePtr &node, string &type) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "node is nullptr."); + return GRAPH_FAILED; + } + type = node->GetType(); + GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS); + if (node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "op_desc is nullptr."); + return GRAPH_FAILED; + } + bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); + if (!ret) { + GELOGE(INTERNAL_ERROR, "node is nullptr."); + return INTERNAL_ERROR; + } + GELOGD("Get FrameWorkOp original type [%s]", type.c_str()); + return SUCCESS; +} + +void ParallelGroupPass::ReplaceSwitchAndMerge(NodePtr &pre_node, + NodePtr &cur_node, + std::map, + NodePtr>> &node_2_switch_merge) { + auto pre_itr = node_2_switch_merge.find(pre_node); + auto cur_itr = node_2_switch_merge.find(cur_node); + if (pre_itr != node_2_switch_merge.end()) { + if (cur_itr != node_2_switch_merge.end()) { + const auto &pre_set = pre_itr->second.first; + const auto &cur_set = cur_itr->second.first; + if (!HasSameSwitch(pre_set, cur_set)) { + pre_node = pre_itr->second.second; + for (const auto &switch_node : cur_itr->second.first) { + AddCtrlEdge(pre_node, switch_node); + GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), switch_node->GetName().c_str()); + } + } else { + GELOGI("--- no need add ctrl edge"); + } + } else { + pre_node = pre_itr->second.second; + AddCtrlEdge(pre_node, cur_node); + GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); + } + } else { + if (cur_itr != node_2_switch_merge.end()) { + for (const auto &switch_node : cur_itr->second.first) { + int64_t pre_id = pre_node->GetOpDesc()->GetId(); + int64_t switch_id = switch_node->GetOpDesc()->GetId(); + if (pre_id > switch_id) { // special handle for merge and group node + auto merge_node = cur_itr->second.second; + AddCtrlEdge(merge_node, pre_node); + GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", merge_node->GetName().c_str(), pre_node->GetName().c_str()); + } else { + AddCtrlEdge(pre_node, switch_node); + GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), switch_node->GetName().c_str()); + } + } + } else { + AddCtrlEdge(pre_node, cur_node); + GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); + } + } +} + +bool ParallelGroupPass::HasSameSwitch(const std::set &switch_set1, const std::set &switch_set2) { + for (const auto &node1 : switch_set1) { + for (const auto &node2 : switch_set2) { + if (node1 == node2) { + return true; + } + } + } + return false; +} +} // namespace ge diff --git a/ge/graph/passes/parallel_group_pass.h b/ge/graph/passes/parallel_group_pass.h new file mode 100644 index 00000000..52ff0307 --- /dev/null +++ b/ge/graph/passes/parallel_group_pass.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H +#define GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H + + +#include +#include +#include +#include "graph/graph.h" +#include "inc/graph_pass.h" + +namespace ge { +class ParallelGroupPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph) override; + private: + Status ProcessAllGraph(ComputeGraphPtr graph, std::unordered_set ¶llel_group); + void AddCtrlEdge(NodePtr pre_node, NodePtr cur_node); + Status GetOriginalType(const ge::NodePtr &node, std::string &type); + void ReplaceSwitchAndMerge(NodePtr &pre_node, + NodePtr &cur_node, + std::map, + NodePtr>> &node_2_switch_merge); + bool HasSameSwitch(const std::set &a, const std::set &b); + Status ProcessSwitch(ComputeGraphPtr graph, + std::map, + NodePtr>> &node_2_switch_merge); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index a09d5789..85d530bd 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -266,8 +266,9 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" "${GE_CODE_DIR}/ge/model/ge_model.cc" "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/model_utils.cc" @@ -503,14 +504,15 @@ set(GRAPH_PASS_COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" - "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/same_transdata_breadth_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" "${GE_CODE_DIR}/ge/graph/common/transop_util.cc" "${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc" #"${GE_CODE_DIR}/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc" @@ -671,7 +673,7 @@ set(PASS_TEST_FILES "graph/passes/trans_op_depth_fusion_pass_unittest.cc" "graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc" "graph/passes/constant_folding_pass_unittest.cc" - "graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc" + "graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc" "graph/passes/stop_gradient_pass_unittest.cc" "graph/passes/prevent_gradient_pass_unittest.cc" "graph/passes/identity_pass_unittest.cc" @@ -688,6 +690,7 @@ set(PASS_TEST_FILES "graph/passes/no_use_reshape_remove_pass_unittest.cc" "graph/passes/infershape_pass_unittest.cc" "graph/passes/multi_batch_clone_pass_unittest.cc" + "graph/passes/parallel_group_pass_unittest.cc" ) set(KERNEL_TEST_FILES diff --git a/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc b/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc index 5b87939f..20e005c5 100644 --- a/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc +++ b/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc @@ -32,6 +32,7 @@ #include "graph/compute_graph.h" #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" +#include "graph/debug/ge_attr_define.h" using namespace std; @@ -148,6 +149,22 @@ class UtestLogicalStreamAllocator : public testing::Test { return subgraph; } + SubGraphInfoPtr CreateParallelGroupSubgraphWithName(const string &name, const string &engine, + const string &stream_label = "", + int group_id = 1) { + ComputeGraphPtr compute_graph = make_shared(name); + OpDescPtr op_desc = std::make_shared("relu", "Relu"); + op_desc->AddInputDesc(GeTensorDesc()); + op_desc->AddOutputDesc(GeTensorDesc()); + AttrUtils::SetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id); + compute_graph->AddNode(op_desc); + + SubGraphInfoPtr subgraph = BuildSubGraph(compute_graph, engine, stream_label); + AddPlaceHolderAndEnd(subgraph, 1, 1); + + return subgraph; + } + SubGraphInfoPtr CreateSubgraph(const string &engine, const string &stream_label = "", int in_num = 1, int out_num = 1) { return CreateSubgraphWithName("graph", engine, stream_label, in_num, out_num); @@ -878,4 +895,30 @@ TEST_F(UtestLogicalStreamAllocator, test_all_reduce_parallel_pass) { EXPECT_EQ(ret, NOT_CHANGED); } +TEST_F(UtestLogicalStreamAllocator, test_parallel_group) { + SubGraphInfoPtr data = CreateDataSubgraph(); + SubGraphInfoPtr subgraph1 = CreateParallelGroupSubgraphWithName("graph1", "engine1", ""); + SubGraphInfoPtr subgraph2 = CreateParallelGroupSubgraphWithName("graph2", "engine2", "", 2); + SubGraphInfoPtr subgraph3 = CreateParallelGroupSubgraphWithName("graph3", "engine3", "", 3); + SubGraphInfoPtr subgraph4 = CreateParallelGroupSubgraphWithName("graph4", "engine4", "", 4); + LinkSubGraph(data, "end", subgraph1, "placeholder"); + LinkSubGraph(subgraph1, "end", subgraph2, "placeholder"); + LinkSubGraph(subgraph2, "end", subgraph3, "placeholder"); + LinkSubGraph(subgraph3, "end", subgraph4, "placeholder"); + + EngineConfPtr conf1 = make_shared(); + conf1->id = subgraph1->GetEngineName(); + EngineConfPtr conf2 = make_shared(); + conf2->id = subgraph2->GetEngineName(); + conf2->attach = false; + EngineConfPtr conf3 = make_shared(); + conf3->id = subgraph3->GetEngineName(); + conf3->attach = false; + EngineConfPtr conf4 = make_shared(); + conf4->id = subgraph4->GetEngineName(); + + Status status = AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4}, {conf1, conf2, conf3, conf4}); + EXPECT_EQ(status, ge::SUCCESS); +} + } // namespace ge diff --git a/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc b/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc new file mode 100644 index 00000000..3b020dfe --- /dev/null +++ b/tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc @@ -0,0 +1,259 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#define private public + +#include "common/ge_inner_error_codes.h" +#include "inc/pass_manager.h" +#include "utils/graph_utils.h" +#include "graph/passes/parallel_group_pass.h" +#undef private + +namespace ge { +namespace { + +class UtestGraphPassesParallelGgroupPass : public testing::Test { + protected: + UtestGraphPassesParallelGgroupPass() { + graph_ = std::make_shared("test"); + sub_graph_ = std::make_shared("test_subgraph"); + vector shape_vec{1, 1, 1, 1}; + GeShape shape = GeShape(shape_vec); + default_tensor_desc_ = std::make_shared(); + default_tensor_desc_->SetShape(shape); + default_tensor_desc_->SetFormat(FORMAT_NCHW); + default_tensor_desc_->SetDataType(DT_FLOAT); + } + + NodePtr NewNode(const std::string &name, const std::string &type, + int input_cnt, int output_cnt, bool isSubgraph = false) { + OpDescPtr op_desc = std::make_shared(name, type); + for (int i = 0; i < input_cnt; ++i) { + op_desc->AddInputDesc(default_tensor_desc_->Clone()); + } + + for (int i = 0; i < output_cnt; ++i) { + op_desc->AddOutputDesc(default_tensor_desc_->Clone()); + } + NodePtr node = nullptr; + if (isSubgraph) { + node = sub_graph_->AddNode(op_desc); + (void)node->SetOwnerComputeGraph(sub_graph_); + } else { + node = graph_->AddNode(op_desc); + (void)node->SetOwnerComputeGraph(graph_); + } + + return node; + } + + void BuildDefaultGraph() { + /// input + /// \ + /// sqrt1 pred + /// \ / + /// Switch + /// | | + /// F T + /// | | + /// Merge + /// | + /// relu + /// | + /// sqrt2 + pred_node_ = NewNode("pred", GREATER, 2, 1); + + input_node_ = NewNode("input", RELU, 0, 1); + AttrUtils::SetInt(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); + sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1); + switch_node_ = NewNode("switch", SWITCH, 2, 2); + output_false_node_ = NewNode("false_output", RELU, 1, 1); + AttrUtils::SetInt(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); + output_true_node_ = NewNode("true_output", RELU, 1, 1); + AttrUtils::SetInt(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); + merge_node_ = NewNode("merge", MERGE, 2, 1); + relu_node_ = NewNode("relu", RELU, 1, 1); + sqrt_node2_ = NewNode("sqrt2", SQRT, 1, 1); + AttrUtils::SetInt(sqrt_node2_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); + + switch_node_->GetOpDesc()->SetIsInputConst({false, false}); + + GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0)); + GraphUtils::AddEdge(sqrt_node1_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node2_->GetInDataAnchor(0)); + + output_false_node_->GetOpDesc()->SetIsInputConst({false}); + output_true_node_->GetOpDesc()->SetIsInputConst({false}); + } + + void BuildDefaultGraph1() { + /// input pred + /// \ / + /// Switch + /// | | + /// ----F T---- + /// \ | / \ + /// \ Merge1 Merge2 + /// \_________| + pred_node_ = NewNode("pred", GREATER, 2, 1); + + input_node_ = NewNode("input", RELU, 0, 1); + AttrUtils::SetInt(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); + switch_node_ = NewNode("switch", SWITCH, 2, 2); + output_false_node_ = NewNode("false_output", RELU, 1, 2); + AttrUtils::SetInt(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); + output_true_node_ = NewNode("true_output", RELU, 1, 2); + AttrUtils::SetInt(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); + merge_node_ = NewNode("merge", MERGE, 2, 1); + merge_node1_ = NewNode("merge1", MERGE, 2, 1); + + switch_node_->GetOpDesc()->SetIsInputConst({false, false}); + + GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1)); + + output_false_node_->GetOpDesc()->SetIsInputConst({false}); + output_true_node_->GetOpDesc()->SetIsInputConst({false}); + } + + + void BuildDefaultGraph2() { + /// input pred input pred1 + /// \ / \ / + /// Switch Switch1 + /// | | _______| + /// | | / + /// ____F T____ + /// \ | / \ + /// \ Merge1 Merge2 + /// \__________| + pred_node_ = NewNode("pred", GREATER, 2, 1); + pred_node1_ = NewNode("pred", LESS, 2, 1); + input_node_ = NewNode("input", RELU, 0, 2); + AttrUtils::SetInt(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); + switch_node_ = NewNode("switch", SWITCH, 2, 2); + switch_node1_ = NewNode("switch1", SWITCH, 2, 2); + output_false_node_ = NewNode("false_output", RELU, 2, 2); + AttrUtils::SetInt(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); + output_true_node_ = NewNode("true_output", RELU, 2, 2); + AttrUtils::SetInt(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 2); + merge_node_ = NewNode("merge", MERGE, 2, 1); + merge_node1_ = NewNode("merge1", MERGE, 2, 1); + + switch_node_->GetOpDesc()->SetIsInputConst({false, false}); + + GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(input_node_->GetOutDataAnchor(1), switch_node1_->GetInDataAnchor(0)); + GraphUtils::AddEdge(pred_node1_->GetOutDataAnchor(0), switch_node1_->GetInDataAnchor(1)); + GraphUtils::AddEdge(switch_node1_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(switch_node1_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); + GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0)); + GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1)); + + output_false_node_->GetOpDesc()->SetIsInputConst({false}); + output_true_node_->GetOpDesc()->SetIsInputConst({false}); + } + + ComputeGraphPtr graph_; + ComputeGraphPtr sub_graph_; + GeTensorDescPtr default_tensor_desc_; + ParallelGroupPass pass_; + NodePtr pred_node_; + NodePtr pred_node1_; + NodePtr sqrt_node1_; + NodePtr sqrt_node2_; + NodePtr input_node_; + NodePtr switch_node_; + NodePtr switch_node1_; + NodePtr output_false_node_; + NodePtr output_true_node_; + NodePtr merge_node_; + NodePtr merge_node1_; + NodePtr relu_node_; +}; + +TEST_F(UtestGraphPassesParallelGgroupPass, null_graph) { + ComputeGraphPtr graph = nullptr; + auto ret = pass_.Run(graph); + EXPECT_EQ(ret, PARAM_INVALID); +} + +TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph) { + BuildDefaultGraph(); + auto ret = pass_.Run(graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); + EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(switch_node_->GetInControlAnchor())); + EXPECT_EQ(true, merge_node_->GetOutControlAnchor()->IsLinkedWith(sqrt_node2_->GetInControlAnchor())); + EXPECT_EQ(false, output_false_node_->GetOutControlAnchor()->IsLinkedWith(output_true_node_->GetInControlAnchor())); +} + +TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph1) { + BuildDefaultGraph1(); + auto ret = pass_.Run(graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); +} + +TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) { + BuildDefaultGraph2(); + auto ret = pass_.Run(graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); +} + +TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) { + BuildDefaultGraph1(); + NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true); + NodePtr input_node2 = NewNode("input2", RELU, 0, 1, true); + NodePtr add = NewNode("add", ADD, 2, 1, true); + AttrUtils::SetInt(input_node1->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); + AttrUtils::SetInt(input_node2->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); + + sub_graph_->SetParentNode(input_node_); + sub_graph_->SetParentGraph(graph_); + auto ret = graph_->AddSubgraph(sub_graph_->GetName(), sub_graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); + ret = input_node_->GetOpDesc()->AddSubgraphName(sub_graph_->GetName()); + EXPECT_EQ(ret, GRAPH_SUCCESS); + ret = input_node_->GetOpDesc()->SetSubgraphInstanceName(0, sub_graph_->GetName()); + EXPECT_EQ(ret, GRAPH_SUCCESS); + ret = pass_.Run(sub_graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); + ret = pass_.Run(graph_); + EXPECT_EQ(ret, GRAPH_SUCCESS); +} + +} // namespace +} // namespace ge