diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 3ae51590..cafc0cfd 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -215,6 +215,7 @@ set(TRAIN_SRC_LIST "graph/passes/inplace_support_check_pass.cc" "graph/passes/flow_ctrl_pass.cc" "graph/passes/global_step_insert_pass.cc" + "graph/passes/parallel_group_pass.cc" "host_kernels/transpose_kernel.cc" "host_kernels/add_kernel.cc" "host_kernels/broadcast_args_kernel.cc" @@ -604,6 +605,7 @@ set(INFER_SRC_LIST "graph/passes/hccl_group_pass.cc" "graph/passes/memcpy_addr_async_pass.cc" "graph/passes/set_input_output_offset_pass.cc" + "graph/passes/parallel_group_pass.cc" "graph/manager/model_manager/event_manager.cc" "graph/manager/util/rt_context_util.cc" "graph/manager/util/variable_accelerate_ctrl.cc" diff --git a/ge/graph/build/logical_stream_allocator.cc b/ge/graph/build/logical_stream_allocator.cc index 8ea7fe71..a1f640a6 100644 --- a/ge/graph/build/logical_stream_allocator.cc +++ b/ge/graph/build/logical_stream_allocator.cc @@ -362,7 +362,48 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { + std::map> stream_op_map; + for (const SubgraphPtr &subgraph : subgraphs) { + auto compute_graph = subgraph->subgraph_info.GetSubGraph(); + for (NodePtr &node : compute_graph->GetDirectNode()) { + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc->HasAttr(ATTR_NAME_PARALLEL_GROUP)) { + int64_t op_desc_stream_id = op_desc->GetStreamId(); + stream_op_map[op_desc_stream_id].push_back(op_desc); + } + } + } + for (const auto &itr : stream_op_map) { + if (itr.first == kInvalidStream) { + continue; + } + std::map> group_op; + for (const auto &op_desc : itr.second) { + int group_id; + if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id)) { + GELOGD("Get node %s ATTR_NAME_PARALLEL_GROUP failed.", op_desc->GetName().c_str()); + return GRAPH_FAILED; + } + group_op[group_id].emplace_back(op_desc); + } + for (const auto &itr : group_op) { + const auto &op_vec = itr.second; + if (op_vec.empty()) { + continue; + } + int64_t new_stream_id = context.next_stream++; + for (const auto &op : op_vec) { + int64_t old_stream_id = op->GetStreamId(); + op->SetStreamId(new_stream_id); + GELOGI("chenhua Node %s assigned stream %ld from stream %ld.", op->GetName().c_str(), new_stream_id, + old_stream_id); + } + } + } return SUCCESS; } @@ -655,6 +696,7 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); + passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); } diff --git a/ge/graph/build/logical_stream_allocator.h b/ge/graph/build/logical_stream_allocator.h index b9aec611..2a94c254 100644 --- a/ge/graph/build/logical_stream_allocator.h +++ b/ge/graph/build/logical_stream_allocator.h @@ -149,6 +149,13 @@ class NodeStreamUpdatePass : public LogicalStreamPass { Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; }; +// assign stream by parallel group +class UpdateForParallelGroupPass : public LogicalStreamPass { + public: + STREAM_PASS_DEFAULT_FUNC(UpdateForParallelGroupPass); + Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; +}; + // Update the stream of subgraphs to nodes. class UpdateForSkippedEnginePass : public LogicalStreamPass { public: diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 8b57858d..2bf678fa 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -93,6 +93,7 @@ #include "graph/passes/global_step_insert_pass.h" #include "graph/passes/memcpy_addr_async_pass.h" #include "graph/passes/hccl_continuous_memcpy_pass.h" +#include "graph/passes/parallel_group_pass.h" #include "graph/build/label_allocator.h" #include "graph/utils/tensor_adapter.h" #include "inc/pass_manager.h" @@ -2215,6 +2216,7 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { GE_DUMP(compute_graph, "OptimizeStage1_2"); PassManager graph_pass; // the prune pass should between SwitchPass and SwitchToStreamSwitchPass + GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ParallelGroupPass", new (std::nothrow) ParallelGroupPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)) diff --git a/ge/graph/passes/parallel_group_pass.cc b/ge/graph/passes/parallel_group_pass.cc new file mode 100644 index 00000000..7aa607d7 --- /dev/null +++ b/ge/graph/passes/parallel_group_pass.cc @@ -0,0 +1,298 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/parallel_group_pass.h" + +#include "framework/common/debug/ge_log.h" +#include "common/ge/ge_util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" + +namespace ge { + +Status ParallelGroupPass::Run(ComputeGraphPtr graph) { + GELOGI("ParallelGroupPass running"); + if (graph == nullptr) { + GELOGE(PARAM_INVALID, "Input param graph is null, skip optimize parallel group graph."); + return PARAM_INVALID; + } + + if (graph->GetParentGraph() != nullptr) { + GELOGW("Current graph %s is a subgraph, this pass only support root graph.", + graph->GetName().c_str()); + return GRAPH_SUCCESS; + } + + std::unordered_set parallel_group; + if (ProcessAllGraph(graph, parallel_group) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Process graph %s failed.", graph->GetName().c_str()); + return INTERNAL_ERROR; + } + + if (graph->TopologicalSorting() != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Graph topological sort failed."); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +Status ParallelGroupPass::ProcessAllGraph(ComputeGraphPtr graph, std::unordered_set ¶llel_group) { + if (graph == nullptr) { + GELOGE(PARAM_INVALID, "Input param graph is null, skip optimize parallel group subgraph."); + return PARAM_INVALID; + } + + std::map, NodePtr>> node_2_switch_merge; + if (ProcessSwitch(graph, node_2_switch_merge) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Input param graph is null, skip optimize parallel group subgraph."); + return GRAPH_FAILED; + } + + std::map> group_node; + const auto &candidates = graph->GetDirectNode(); + for (const auto &node : candidates) { + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + int group_id = -1; + if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id)) { + group_node[group_id].push_back(node); + parallel_group.insert(group_id); + GELOGI("Find hccl node:%s, group_id=%d", op_desc->GetName().c_str(), group_id); + } + + const auto &subgraph_name = op_desc->GetSubgraphInstanceNames(); + for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) { + const auto sub_graph = graph->GetSubgraph(*name_iter); + if (sub_graph != nullptr) { + std::unordered_set sub_parallel_group; + auto ret = ProcessAllGraph(sub_graph, sub_parallel_group); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Process sub graph %s failed.", sub_graph->GetName().c_str()); + return GRAPH_FAILED; + } + for (const auto &it : sub_parallel_group) { + parallel_group.insert(it); + group_node[it].emplace_back(node); + } + } + } + } + + for (const auto &itr : group_node) { + const auto &node_vec = itr.second; + if (node_vec.empty()) { + continue; + } + NodePtr pre_node = node_vec[0]; + NodePtr cur_node = nullptr; + for (int i = 1; i < node_vec.size(); i++) { + cur_node = node_vec[i]; + auto tmp_pre_node = pre_node; + auto tmp_cur_node = cur_node; + GELOGI("original we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); + ReplaceSwitchAndMerge(tmp_pre_node, tmp_cur_node, node_2_switch_merge); + pre_node = cur_node; + } + } + + return GRAPH_SUCCESS; +} + +void ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) { + if (pre_node == cur_node) { + GELOGI("--- pr_node == cur_node"); + return; + } + const auto &in_node = cur_node->GetInAllNodes(); + for (const auto &node : in_node) { + if (pre_node == node) { + GELOGI("--- pr_node and cur_node have linked"); + return; + } + } + auto pre_out_ctrl_anchor = pre_node->GetOutControlAnchor(); + auto cur_in_ctrl_anchor = cur_node->GetInControlAnchor(); + pre_out_ctrl_anchor->LinkTo(cur_in_ctrl_anchor); +} + +Status ParallelGroupPass::ProcessSwitch(ComputeGraphPtr graph, + std::map, + NodePtr>> &node_2_switch_merge) { + std::string type; + const auto &direct_node = graph->GetDirectNode(); + for (const auto &node : direct_node) { + auto ret = GetOriginalType(node, type); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get node %s type failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + if ((type != SWITCH) && (type != REFSWITCH)) { + continue; + } + InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); + if (in_cond_anchor == nullptr) { + GELOGE(INTERNAL_ERROR, "in_cond_anchor is nullptr, node: %s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + OutDataAnchorPtr pred_cond_anchor = in_cond_anchor->GetPeerOutAnchor(); + if (pred_cond_anchor == nullptr) { + GELOGE(INTERNAL_ERROR, "pred_cond_anchor is nullptr, node: %s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + // ignore while op + if (pred_cond_anchor->GetOwnerNode()->GetType() == LOOPCOND) { + continue; + } + std::deque candidates; + std::vector merge_vec; + std::set group_set; + std::set visited; + candidates.emplace_back(node); + while (!candidates.empty()) { + NodePtr tmp_node = candidates.front(); + candidates.pop_front(); + for (const auto &out_node : tmp_node->GetOutAllNodes()) { + ret = GetOriginalType(out_node, type); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get node %s type failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + if ((type == MERGE) || (type == REFMERGE)) { + merge_vec.emplace_back(out_node); + continue; + } + const auto op = out_node->GetOpDesc(); + if (op != nullptr && op->HasAttr(ATTR_NAME_PARALLEL_GROUP)) { + group_set.emplace(out_node); + } + if (visited.count(out_node) > 0) { + continue; + } + candidates.push_back(out_node); + visited.insert(out_node); + } + } + std::sort(merge_vec.begin(), merge_vec.end(), + [&direct_node] (NodePtr a, NodePtr b) -> bool { + return std::find(direct_node.begin(), direct_node.end(), a) < + std::find(direct_node.begin(), direct_node.end(), b); + }); + for (const auto &group_node : group_set) { + auto it = node_2_switch_merge.find(group_node); + if (it != node_2_switch_merge.end()) { + auto &tmp = it->second; + auto &switch_vec = tmp.first; + const auto &merge_node = tmp.second; + GELOGI(" --- hccl node: %s, switch node %s, merge node :%s.", + group_node->GetName().c_str(), node->GetName().c_str(), merge_node->GetName().c_str()); + if (merge_node != merge_vec.back()) { + GELOGE(GRAPH_FAILED, "error: has two merge node: %s and %s.", + merge_node->GetName().c_str(), merge_vec.back()->GetName().c_str()); + return GRAPH_FAILED; + } + switch_vec.insert(node); + } else { + node_2_switch_merge.emplace(group_node, std::make_pair(std::set{node}, merge_vec.back())); + } + } + } + + return GRAPH_SUCCESS; +} + +Status ParallelGroupPass::GetOriginalType(const NodePtr &node, string &type) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "node is nullptr."); + return GRAPH_FAILED; + } + type = node->GetType(); + GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS); + if (node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "op_desc is nullptr."); + return GRAPH_FAILED; + } + bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); + if (!ret) { + GELOGE(INTERNAL_ERROR, "node is nullptr."); + return INTERNAL_ERROR; + } + GELOGD("Get FrameWorkOp original type [%s]", type.c_str()); + return SUCCESS; +} + +void ParallelGroupPass::ReplaceSwitchAndMerge(NodePtr &pre_node, + NodePtr &cur_node, + std::map, + NodePtr>> &node_2_switch_merge) { + auto pre_itr = node_2_switch_merge.find(pre_node); + auto cur_itr = node_2_switch_merge.find(cur_node); + if (pre_itr != node_2_switch_merge.end()) { + if (cur_itr != node_2_switch_merge.end()) { + const auto &pre_set = pre_itr->second.first; + const auto &cur_set = cur_itr->second.first; + if (!HasSameSwitch(pre_set, cur_set)) { + pre_node = pre_itr->second.second; + for (const auto &switch_node : cur_itr->second.first) { + AddCtrlEdge(pre_node, switch_node); + GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), switch_node->GetName().c_str()); + } + } else { + GELOGI("--- no need add ctrl edge"); + } + } else { + pre_node = pre_itr->second.second; + AddCtrlEdge(pre_node, cur_node); + GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); + } + } else { + if (cur_itr != node_2_switch_merge.end()) { + for (const auto &switch_node : cur_itr->second.first) { + int64_t pre_id = pre_node->GetOpDesc()->GetId(); + int64_t switch_id = switch_node->GetOpDesc()->GetId(); + if (pre_id > switch_id) { // special handle for merge and group node + auto merge_node = cur_itr->second.second; + AddCtrlEdge(merge_node, pre_node); + GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", merge_node->GetName().c_str(), pre_node->GetName().c_str()); + } else { + AddCtrlEdge(pre_node, switch_node); + GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), switch_node->GetName().c_str()); + } + } + } else { + AddCtrlEdge(pre_node, cur_node); + GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); + } + } +} + +bool ParallelGroupPass::HasSameSwitch(const std::set &a, const std::set &b) { + for (const auto &node1 : a) { + for (const auto &node2 : b) { + if (node1 == node2) { + return true; + } + } + } + return false; +} +} // namespace ge diff --git a/ge/graph/passes/parallel_group_pass.h b/ge/graph/passes/parallel_group_pass.h new file mode 100644 index 00000000..52ff0307 --- /dev/null +++ b/ge/graph/passes/parallel_group_pass.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H +#define GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H + + +#include +#include +#include +#include "graph/graph.h" +#include "inc/graph_pass.h" + +namespace ge { +class ParallelGroupPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph) override; + private: + Status ProcessAllGraph(ComputeGraphPtr graph, std::unordered_set ¶llel_group); + void AddCtrlEdge(NodePtr pre_node, NodePtr cur_node); + Status GetOriginalType(const ge::NodePtr &node, std::string &type); + void ReplaceSwitchAndMerge(NodePtr &pre_node, + NodePtr &cur_node, + std::map, + NodePtr>> &node_2_switch_merge); + bool HasSameSwitch(const std::set &a, const std::set &b); + Status ProcessSwitch(ComputeGraphPtr graph, + std::map, + NodePtr>> &node_2_switch_merge); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H diff --git a/ge/hybrid/executor/hybrid_model_async_executor.cc b/ge/hybrid/executor/hybrid_model_async_executor.cc index 67c85460..e4e63333 100644 --- a/ge/hybrid/executor/hybrid_model_async_executor.cc +++ b/ge/hybrid/executor/hybrid_model_async_executor.cc @@ -91,7 +91,7 @@ Status HybridModelAsyncExecutor::Init() { data_inputer_ = std::unique_ptr(new(std::nothrow) DataInputer()); GE_CHECK_NOTNULL(data_inputer_); GE_CHK_RT_RET(rtStreamCreate(&stream_, RT_STREAM_PRIORITY_DEFAULT)); - + GELOGI("HybridModelExecutor creat xtream successfully"); executor_ = std::unique_ptr(new(std::nothrow) HybridModelExecutor(model_, device_id_, stream_)); GE_CHECK_NOTNULL(executor_); GE_CHK_STATUS_RET(executor_->Init(), "Failed to init hybrid engine");