| @@ -214,6 +214,7 @@ set(TRAIN_SRC_LIST | |||||
| "graph/passes/inplace_support_check_pass.cc" | "graph/passes/inplace_support_check_pass.cc" | ||||
| "graph/passes/flow_ctrl_pass.cc" | "graph/passes/flow_ctrl_pass.cc" | ||||
| "graph/passes/global_step_insert_pass.cc" | "graph/passes/global_step_insert_pass.cc" | ||||
| "graph/passes/parallel_group_pass.cc" | |||||
| "host_kernels/transpose_kernel.cc" | "host_kernels/transpose_kernel.cc" | ||||
| "host_kernels/add_kernel.cc" | "host_kernels/add_kernel.cc" | ||||
| "host_kernels/broadcast_args_kernel.cc" | "host_kernels/broadcast_args_kernel.cc" | ||||
| @@ -606,6 +607,7 @@ set(INFER_SRC_LIST | |||||
| "graph/passes/hccl_group_pass.cc" | "graph/passes/hccl_group_pass.cc" | ||||
| "graph/passes/memcpy_addr_async_pass.cc" | "graph/passes/memcpy_addr_async_pass.cc" | ||||
| "graph/passes/set_input_output_offset_pass.cc" | "graph/passes/set_input_output_offset_pass.cc" | ||||
| "graph/passes/parallel_group_pass.cc" | |||||
| "graph/manager/model_manager/event_manager.cc" | "graph/manager/model_manager/event_manager.cc" | ||||
| "graph/manager/util/rt_context_util.cc" | "graph/manager/util/rt_context_util.cc" | ||||
| "graph/manager/util/variable_accelerate_ctrl.cc" | "graph/manager/util/variable_accelerate_ctrl.cc" | ||||
| @@ -366,6 +366,48 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status UpdateForParallelGroupPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
| std::map<int, vector<OpDescPtr>> stream_op_map; | |||||
| for (const SubgraphPtr &subgraph : subgraphs) { | |||||
| auto compute_graph = subgraph->subgraph_info.GetSubGraph(); | |||||
| for (NodePtr &node : compute_graph->GetDirectNode()) { | |||||
| OpDescPtr op_desc = node->GetOpDesc(); | |||||
| if (op_desc->HasAttr(ATTR_NAME_PARALLEL_GROUP)) { | |||||
| int64_t op_desc_stream_id = op_desc->GetStreamId(); | |||||
| stream_op_map[op_desc_stream_id].push_back(op_desc); | |||||
| } | |||||
| } | |||||
| } | |||||
| for (const auto &itr : stream_op_map) { | |||||
| if (itr.first == kInvalidStream) { | |||||
| continue; | |||||
| } | |||||
| std::map<int, vector<OpDescPtr>> group_op; | |||||
| for (const auto &op_desc : itr.second) { | |||||
| int group_id; | |||||
| if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id)) { | |||||
| GELOGE(GRAPH_FAILED, "Get node %s ATTR_NAME_PARALLEL_GROUP failed.", op_desc->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| group_op[group_id].emplace_back(op_desc); | |||||
| } | |||||
| for (const auto &itr : group_op) { | |||||
| const auto &op_vec = itr.second; | |||||
| if (op_vec.empty()) { | |||||
| continue; | |||||
| } | |||||
| int64_t new_stream_id = context.next_stream++; | |||||
| for (const auto &op : op_vec) { | |||||
| int64_t old_stream_id = op->GetStreamId(); | |||||
| op->SetStreamId(new_stream_id); | |||||
| GELOGI("chenhua Node %s assigned stream %ld from stream %ld.", op->GetName().c_str(), new_stream_id, | |||||
| old_stream_id); | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| int64_t UpdateForSkippedEnginePass::GetSingleInoutStream(const NodePtr &node) const { | int64_t UpdateForSkippedEnginePass::GetSingleInoutStream(const NodePtr &node) const { | ||||
| set<int64_t> stream_ids; | set<int64_t> stream_ids; | ||||
| @@ -655,6 +697,7 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec | |||||
| passes.emplace_back(MakeShared<IndependentStreamPass>()); | passes.emplace_back(MakeShared<IndependentStreamPass>()); | ||||
| passes.emplace_back(MakeShared<AssignByDependencyPass>()); | passes.emplace_back(MakeShared<AssignByDependencyPass>()); | ||||
| passes.emplace_back(MakeShared<NodeStreamUpdatePass>()); | passes.emplace_back(MakeShared<NodeStreamUpdatePass>()); | ||||
| passes.emplace_back(MakeShared<UpdateForParallelGroupPass>()); | |||||
| passes.emplace_back(MakeShared<AllReduceParallelPass>()); | passes.emplace_back(MakeShared<AllReduceParallelPass>()); | ||||
| passes.emplace_back(MakeShared<UpdateForSkippedEnginePass>()); | passes.emplace_back(MakeShared<UpdateForSkippedEnginePass>()); | ||||
| } | } | ||||
| @@ -149,6 +149,13 @@ class NodeStreamUpdatePass : public LogicalStreamPass { | |||||
| Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | ||||
| }; | }; | ||||
| // assign stream by parallel group | |||||
| class UpdateForParallelGroupPass : public LogicalStreamPass { | |||||
| public: | |||||
| STREAM_PASS_DEFAULT_FUNC(UpdateForParallelGroupPass); | |||||
| Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
| }; | |||||
| // Update the stream of subgraphs to nodes. | // Update the stream of subgraphs to nodes. | ||||
| class UpdateForSkippedEnginePass : public LogicalStreamPass { | class UpdateForSkippedEnginePass : public LogicalStreamPass { | ||||
| public: | public: | ||||
| @@ -92,6 +92,7 @@ | |||||
| #include "graph/passes/global_step_insert_pass.h" | #include "graph/passes/global_step_insert_pass.h" | ||||
| #include "graph/passes/memcpy_addr_async_pass.h" | #include "graph/passes/memcpy_addr_async_pass.h" | ||||
| #include "graph/passes/hccl_continuous_memcpy_pass.h" | #include "graph/passes/hccl_continuous_memcpy_pass.h" | ||||
| #include "graph/passes/parallel_group_pass.h" | |||||
| #include "graph/build/label_allocator.h" | #include "graph/build/label_allocator.h" | ||||
| #include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
| #include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
| @@ -2203,6 +2204,7 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
| GE_DUMP(compute_graph, "OptimizeStage1_2"); | GE_DUMP(compute_graph, "OptimizeStage1_2"); | ||||
| PassManager graph_pass; | PassManager graph_pass; | ||||
| // the prune pass should between SwitchPass and SwitchToStreamSwitchPass | // the prune pass should between SwitchPass and SwitchToStreamSwitchPass | ||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ParallelGroupPass", new (std::nothrow) ParallelGroupPass)); | |||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); | ||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); | ||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)) | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)) | ||||
| @@ -0,0 +1,307 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/passes/parallel_group_pass.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| namespace ge { | |||||
| Status ParallelGroupPass::Run(ComputeGraphPtr graph) { | |||||
| GELOGI("ParallelGroupPass running"); | |||||
| if (graph == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "Input param graph is null, skip optimize parallel group graph."); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| if (graph->GetParentGraph() != nullptr) { | |||||
| GELOGW("Current graph %s is a subgraph, this pass only support root graph.", | |||||
| graph->GetName().c_str()); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| std::unordered_set<int> parallel_group; | |||||
| if (ProcessAllGraph(graph, parallel_group) != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Process graph %s failed.", graph->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (graph->TopologicalSorting() != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Graph topological sort failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| Status ParallelGroupPass::ProcessAllGraph(ComputeGraphPtr graph, std::unordered_set<int> ¶llel_group) { | |||||
| if (graph == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "Input param graph is null, skip optimize parallel group subgraph."); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> node_2_switch_merge; | |||||
| if (ProcessSwitch(graph, node_2_switch_merge) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Input param graph is null, skip optimize parallel group subgraph."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::map<int, vector<NodePtr>> group_node; | |||||
| const auto &candidates = graph->GetDirectNode(); | |||||
| for (const auto &node : candidates) { | |||||
| OpDescPtr op_desc = node->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| int group_id = -1; | |||||
| if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id)) { | |||||
| bool is_unknown_shape = false; | |||||
| auto ret = ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Get node[%s] shape status failed!", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| // only handle known shape node | |||||
| if (!is_unknown_shape) { | |||||
| group_node[group_id].push_back(node); | |||||
| parallel_group.insert(group_id); | |||||
| GELOGI("Find hccl node:%s, group_id=%d", op_desc->GetName().c_str(), group_id); | |||||
| } | |||||
| } | |||||
| const auto &subgraph_name = op_desc->GetSubgraphInstanceNames(); | |||||
| for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) { | |||||
| const auto sub_graph = graph->GetSubgraph(*name_iter); | |||||
| if (sub_graph != nullptr) { | |||||
| std::unordered_set<int> sub_parallel_group; | |||||
| auto ret = ProcessAllGraph(sub_graph, sub_parallel_group); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Process sub graph %s failed.", sub_graph->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (const auto &it : sub_parallel_group) { | |||||
| parallel_group.insert(it); | |||||
| group_node[it].emplace_back(node); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| for (const auto &itr : group_node) { | |||||
| const auto &node_vec = itr.second; | |||||
| if (node_vec.empty()) { | |||||
| continue; | |||||
| } | |||||
| NodePtr pre_node = node_vec[0]; | |||||
| NodePtr cur_node = nullptr; | |||||
| for (int i = 1; i < node_vec.size(); i++) { | |||||
| cur_node = node_vec[i]; | |||||
| auto tmp_pre_node = pre_node; | |||||
| auto tmp_cur_node = cur_node; | |||||
| GELOGI("original we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); | |||||
| ReplaceSwitchAndMerge(tmp_pre_node, tmp_cur_node, node_2_switch_merge); | |||||
| pre_node = cur_node; | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| void ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) { | |||||
| if (pre_node == cur_node) { | |||||
| GELOGI("--- pr_node == cur_node"); | |||||
| return; | |||||
| } | |||||
| const auto &in_node = cur_node->GetInAllNodes(); | |||||
| for (const auto &node : in_node) { | |||||
| if (pre_node == node) { | |||||
| GELOGI("--- pr_node and cur_node have linked"); | |||||
| return; | |||||
| } | |||||
| } | |||||
| auto pre_out_ctrl_anchor = pre_node->GetOutControlAnchor(); | |||||
| auto cur_in_ctrl_anchor = cur_node->GetInControlAnchor(); | |||||
| pre_out_ctrl_anchor->LinkTo(cur_in_ctrl_anchor); | |||||
| } | |||||
| Status ParallelGroupPass::ProcessSwitch(ComputeGraphPtr graph, | |||||
| std::map<NodePtr, | |||||
| std::pair<std::set<NodePtr>, | |||||
| NodePtr>> &node_2_switch_merge) { | |||||
| std::string type; | |||||
| const auto &direct_node = graph->GetDirectNode(); | |||||
| for (const auto &node : direct_node) { | |||||
| auto ret = GetOriginalType(node, type); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Get node %s type failed.", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if ((type != SWITCH) && (type != REFSWITCH)) { | |||||
| continue; | |||||
| } | |||||
| InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||||
| if (in_cond_anchor == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "in_cond_anchor is nullptr, node: %s.", node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| OutDataAnchorPtr pred_cond_anchor = in_cond_anchor->GetPeerOutAnchor(); | |||||
| if (pred_cond_anchor == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "pred_cond_anchor is nullptr, node: %s.", node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| // ignore while op | |||||
| if (pred_cond_anchor->GetOwnerNode()->GetType() == LOOPCOND) { | |||||
| continue; | |||||
| } | |||||
| std::deque<NodePtr> candidates; | |||||
| std::vector<NodePtr> merge_vec; | |||||
| std::set<NodePtr> group_set; | |||||
| std::set<NodePtr> visited; | |||||
| candidates.emplace_back(node); | |||||
| while (!candidates.empty()) { | |||||
| NodePtr tmp_node = candidates.front(); | |||||
| candidates.pop_front(); | |||||
| for (const auto &out_node : tmp_node->GetOutAllNodes()) { | |||||
| ret = GetOriginalType(out_node, type); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Get node %s type failed.", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if ((type == MERGE) || (type == REFMERGE)) { | |||||
| merge_vec.emplace_back(out_node); | |||||
| continue; | |||||
| } | |||||
| const auto op = out_node->GetOpDesc(); | |||||
| if (op != nullptr && op->HasAttr(ATTR_NAME_PARALLEL_GROUP)) { | |||||
| group_set.emplace(out_node); | |||||
| } | |||||
| if (visited.count(out_node) > 0) { | |||||
| continue; | |||||
| } | |||||
| candidates.push_back(out_node); | |||||
| visited.insert(out_node); | |||||
| } | |||||
| } | |||||
| std::sort(merge_vec.begin(), merge_vec.end(), | |||||
| [&direct_node] (NodePtr a, NodePtr b) -> bool { | |||||
| return std::find(direct_node.begin(), direct_node.end(), a) < | |||||
| std::find(direct_node.begin(), direct_node.end(), b); | |||||
| }); | |||||
| for (const auto &group_node : group_set) { | |||||
| auto it = node_2_switch_merge.find(group_node); | |||||
| if (it != node_2_switch_merge.end()) { | |||||
| auto &tmp = it->second; | |||||
| auto &switch_vec = tmp.first; | |||||
| const auto &merge_node = tmp.second; | |||||
| GELOGI(" --- hccl node: %s, switch node %s, merge node :%s.", | |||||
| group_node->GetName().c_str(), node->GetName().c_str(), merge_node->GetName().c_str()); | |||||
| if (merge_node != merge_vec.back()) { | |||||
| GELOGE(GRAPH_FAILED, "error: has two merge node: %s and %s.", | |||||
| merge_node->GetName().c_str(), merge_vec.back()->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| switch_vec.insert(node); | |||||
| } else { | |||||
| node_2_switch_merge.emplace(group_node, std::make_pair(std::set<NodePtr>{node}, merge_vec.back())); | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| Status ParallelGroupPass::GetOriginalType(const NodePtr &node, string &type) { | |||||
| if (node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "node is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| type = node->GetType(); | |||||
| GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS); | |||||
| if (node->GetOpDesc() == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "op_desc is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | |||||
| if (!ret) { | |||||
| GELOGE(INTERNAL_ERROR, "node is nullptr."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGD("Get FrameWorkOp original type [%s]", type.c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| void ParallelGroupPass::ReplaceSwitchAndMerge(NodePtr &pre_node, | |||||
| NodePtr &cur_node, | |||||
| std::map<NodePtr, | |||||
| std::pair<std::set<NodePtr>, | |||||
| NodePtr>> &node_2_switch_merge) { | |||||
| auto pre_itr = node_2_switch_merge.find(pre_node); | |||||
| auto cur_itr = node_2_switch_merge.find(cur_node); | |||||
| if (pre_itr != node_2_switch_merge.end()) { | |||||
| if (cur_itr != node_2_switch_merge.end()) { | |||||
| const auto &pre_set = pre_itr->second.first; | |||||
| const auto &cur_set = cur_itr->second.first; | |||||
| if (!HasSameSwitch(pre_set, cur_set)) { | |||||
| pre_node = pre_itr->second.second; | |||||
| for (const auto &switch_node : cur_itr->second.first) { | |||||
| AddCtrlEdge(pre_node, switch_node); | |||||
| GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||||
| } | |||||
| } else { | |||||
| GELOGI("--- no need add ctrl edge"); | |||||
| } | |||||
| } else { | |||||
| pre_node = pre_itr->second.second; | |||||
| AddCtrlEdge(pre_node, cur_node); | |||||
| GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); | |||||
| } | |||||
| } else { | |||||
| if (cur_itr != node_2_switch_merge.end()) { | |||||
| for (const auto &switch_node : cur_itr->second.first) { | |||||
| int64_t pre_id = pre_node->GetOpDesc()->GetId(); | |||||
| int64_t switch_id = switch_node->GetOpDesc()->GetId(); | |||||
| if (pre_id > switch_id) { // special handle for merge and group node | |||||
| auto merge_node = cur_itr->second.second; | |||||
| AddCtrlEdge(merge_node, pre_node); | |||||
| GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", merge_node->GetName().c_str(), pre_node->GetName().c_str()); | |||||
| } else { | |||||
| AddCtrlEdge(pre_node, switch_node); | |||||
| GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| AddCtrlEdge(pre_node, cur_node); | |||||
| GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1, const std::set<NodePtr> &switch_set2) { | |||||
| for (const auto &node1 : switch_set1) { | |||||
| for (const auto &node2 : switch_set2) { | |||||
| if (node1 == node2) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H | |||||
| #define GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H | |||||
#include <map>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>

#include "graph/graph.h"
#include "inc/graph_pass.h"
| namespace ge { | |||||
| class ParallelGroupPass : public GraphPass { | |||||
| public: | |||||
| Status Run(ComputeGraphPtr graph) override; | |||||
| private: | |||||
| Status ProcessAllGraph(ComputeGraphPtr graph, std::unordered_set<int> ¶llel_group); | |||||
| void AddCtrlEdge(NodePtr pre_node, NodePtr cur_node); | |||||
| Status GetOriginalType(const ge::NodePtr &node, std::string &type); | |||||
| void ReplaceSwitchAndMerge(NodePtr &pre_node, | |||||
| NodePtr &cur_node, | |||||
| std::map<NodePtr, | |||||
| std::pair<std::set<NodePtr>, | |||||
| NodePtr>> &node_2_switch_merge); | |||||
| bool HasSameSwitch(const std::set<NodePtr> &a, const std::set<NodePtr> &b); | |||||
| Status ProcessSwitch(ComputeGraphPtr graph, | |||||
| std::map<NodePtr, | |||||
| std::pair<std::set<NodePtr>, | |||||
| NodePtr>> &node_2_switch_merge); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H | |||||
| @@ -266,8 +266,9 @@ set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/model/ge_model.cc" | "${GE_CODE_DIR}/ge/model/ge_model.cc" | ||||
| "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" | "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/load/model_manager/model_utils.cc" | "${GE_CODE_DIR}/ge/graph/load/model_manager/model_utils.cc" | ||||
| @@ -503,14 +504,15 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/same_transdata_breadth_fusion_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/same_transdata_breadth_fusion_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/common/transop_util.cc" | "${GE_CODE_DIR}/ge/graph/common/transop_util.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc" | ||||
| #"${GE_CODE_DIR}/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc" | #"${GE_CODE_DIR}/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc" | ||||
| @@ -671,7 +673,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/trans_op_depth_fusion_pass_unittest.cc" | "graph/passes/trans_op_depth_fusion_pass_unittest.cc" | ||||
| "graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc" | "graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc" | ||||
| "graph/passes/constant_folding_pass_unittest.cc" | "graph/passes/constant_folding_pass_unittest.cc" | ||||
| "graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc" | |||||
| "graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc" | |||||
| "graph/passes/stop_gradient_pass_unittest.cc" | "graph/passes/stop_gradient_pass_unittest.cc" | ||||
| "graph/passes/prevent_gradient_pass_unittest.cc" | "graph/passes/prevent_gradient_pass_unittest.cc" | ||||
| "graph/passes/identity_pass_unittest.cc" | "graph/passes/identity_pass_unittest.cc" | ||||
| @@ -688,6 +690,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/no_use_reshape_remove_pass_unittest.cc" | "graph/passes/no_use_reshape_remove_pass_unittest.cc" | ||||
| "graph/passes/infershape_pass_unittest.cc" | "graph/passes/infershape_pass_unittest.cc" | ||||
| "graph/passes/multi_batch_clone_pass_unittest.cc" | "graph/passes/multi_batch_clone_pass_unittest.cc" | ||||
| "graph/passes/parallel_group_pass_unittest.cc" | |||||
| ) | ) | ||||
| set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
| @@ -32,6 +32,7 @@ | |||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| using namespace std; | using namespace std; | ||||
| @@ -148,6 +149,22 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||||
| return subgraph; | return subgraph; | ||||
| } | } | ||||
| SubGraphInfoPtr CreateParallelGroupSubgraphWithName(const string &name, const string &engine, | |||||
| const string &stream_label = "", | |||||
| int group_id = 1) { | |||||
| ComputeGraphPtr compute_graph = make_shared<ComputeGraph>(name); | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>("relu", "Relu"); | |||||
| op_desc->AddInputDesc(GeTensorDesc()); | |||||
| op_desc->AddOutputDesc(GeTensorDesc()); | |||||
| AttrUtils::SetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id); | |||||
| compute_graph->AddNode(op_desc); | |||||
| SubGraphInfoPtr subgraph = BuildSubGraph(compute_graph, engine, stream_label); | |||||
| AddPlaceHolderAndEnd(subgraph, 1, 1); | |||||
| return subgraph; | |||||
| } | |||||
| SubGraphInfoPtr CreateSubgraph(const string &engine, const string &stream_label = "", int in_num = 1, | SubGraphInfoPtr CreateSubgraph(const string &engine, const string &stream_label = "", int in_num = 1, | ||||
| int out_num = 1) { | int out_num = 1) { | ||||
| return CreateSubgraphWithName("graph", engine, stream_label, in_num, out_num); | return CreateSubgraphWithName("graph", engine, stream_label, in_num, out_num); | ||||
| @@ -878,4 +895,30 @@ TEST_F(UtestLogicalStreamAllocator, test_all_reduce_parallel_pass) { | |||||
| EXPECT_EQ(ret, NOT_CHANGED); | EXPECT_EQ(ret, NOT_CHANGED); | ||||
| } | } | ||||
| TEST_F(UtestLogicalStreamAllocator, test_parallel_group) { | |||||
| SubGraphInfoPtr data = CreateDataSubgraph(); | |||||
| SubGraphInfoPtr subgraph1 = CreateParallelGroupSubgraphWithName("graph1", "engine1", ""); | |||||
| SubGraphInfoPtr subgraph2 = CreateParallelGroupSubgraphWithName("graph2", "engine2", "", 2); | |||||
| SubGraphInfoPtr subgraph3 = CreateParallelGroupSubgraphWithName("graph3", "engine3", "", 3); | |||||
| SubGraphInfoPtr subgraph4 = CreateParallelGroupSubgraphWithName("graph4", "engine4", "", 4); | |||||
| LinkSubGraph(data, "end", subgraph1, "placeholder"); | |||||
| LinkSubGraph(subgraph1, "end", subgraph2, "placeholder"); | |||||
| LinkSubGraph(subgraph2, "end", subgraph3, "placeholder"); | |||||
| LinkSubGraph(subgraph3, "end", subgraph4, "placeholder"); | |||||
| EngineConfPtr conf1 = make_shared<EngineConf>(); | |||||
| conf1->id = subgraph1->GetEngineName(); | |||||
| EngineConfPtr conf2 = make_shared<EngineConf>(); | |||||
| conf2->id = subgraph2->GetEngineName(); | |||||
| conf2->attach = false; | |||||
| EngineConfPtr conf3 = make_shared<EngineConf>(); | |||||
| conf3->id = subgraph3->GetEngineName(); | |||||
| conf3->attach = false; | |||||
| EngineConfPtr conf4 = make_shared<EngineConf>(); | |||||
| conf4->id = subgraph4->GetEngineName(); | |||||
| Status status = AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4}, {conf1, conf2, conf3, conf4}); | |||||
| EXPECT_EQ(status, ge::SUCCESS); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -0,0 +1,259 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <cstdint> | |||||
| #include <string> | |||||
| #define private public | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #include "inc/pass_manager.h" | |||||
| #include "utils/graph_utils.h" | |||||
| #include "graph/passes/parallel_group_pass.h" | |||||
| #undef private | |||||
| namespace ge { | |||||
| namespace { | |||||
| class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||||
| protected: | |||||
| UtestGraphPassesParallelGgroupPass() { | |||||
| graph_ = std::make_shared<ComputeGraph>("test"); | |||||
| sub_graph_ = std::make_shared<ComputeGraph>("test_subgraph"); | |||||
| vector<int64_t> shape_vec{1, 1, 1, 1}; | |||||
| GeShape shape = GeShape(shape_vec); | |||||
| default_tensor_desc_ = std::make_shared<GeTensorDesc>(); | |||||
| default_tensor_desc_->SetShape(shape); | |||||
| default_tensor_desc_->SetFormat(FORMAT_NCHW); | |||||
| default_tensor_desc_->SetDataType(DT_FLOAT); | |||||
| } | |||||
| NodePtr NewNode(const std::string &name, const std::string &type, | |||||
| int input_cnt, int output_cnt, bool isSubgraph = false) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
| for (int i = 0; i < input_cnt; ++i) { | |||||
| op_desc->AddInputDesc(default_tensor_desc_->Clone()); | |||||
| } | |||||
| for (int i = 0; i < output_cnt; ++i) { | |||||
| op_desc->AddOutputDesc(default_tensor_desc_->Clone()); | |||||
| } | |||||
| NodePtr node = nullptr; | |||||
| if (isSubgraph) { | |||||
| node = sub_graph_->AddNode(op_desc); | |||||
| (void)node->SetOwnerComputeGraph(sub_graph_); | |||||
| } else { | |||||
| node = graph_->AddNode(op_desc); | |||||
| (void)node->SetOwnerComputeGraph(graph_); | |||||
| } | |||||
| return node; | |||||
| } | |||||
| void BuildDefaultGraph() { | |||||
| /// input | |||||
| /// \ | |||||
| /// sqrt1 pred | |||||
| /// \ / | |||||
| /// Switch | |||||
| /// | | | |||||
| /// F T | |||||
| /// | | | |||||
| /// Merge | |||||
| /// | | |||||
| /// relu | |||||
| /// | | |||||
| /// sqrt2 | |||||
| pred_node_ = NewNode("pred", GREATER, 2, 1); | |||||
| input_node_ = NewNode("input", RELU, 0, 1); | |||||
| AttrUtils::SetInt(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); | |||||
| sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1); | |||||
| switch_node_ = NewNode("switch", SWITCH, 2, 2); | |||||
| output_false_node_ = NewNode("false_output", RELU, 1, 1); | |||||
| AttrUtils::SetInt(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); | |||||
| output_true_node_ = NewNode("true_output", RELU, 1, 1); | |||||
| AttrUtils::SetInt(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); | |||||
| merge_node_ = NewNode("merge", MERGE, 2, 1); | |||||
| relu_node_ = NewNode("relu", RELU, 1, 1); | |||||
| sqrt_node2_ = NewNode("sqrt2", SQRT, 1, 1); | |||||
| AttrUtils::SetInt(sqrt_node2_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); | |||||
| switch_node_->GetOpDesc()->SetIsInputConst({false, false}); | |||||
| GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(sqrt_node1_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node2_->GetInDataAnchor(0)); | |||||
| output_false_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
| output_true_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
| } | |||||
| void BuildDefaultGraph1() { | |||||
| /// input pred | |||||
| /// \ / | |||||
| /// Switch | |||||
| /// | | | |||||
| /// ----F T---- | |||||
| /// \ | / \ | |||||
| /// \ Merge1 Merge2 | |||||
| /// \_________| | |||||
| pred_node_ = NewNode("pred", GREATER, 2, 1); | |||||
| input_node_ = NewNode("input", RELU, 0, 1); | |||||
| AttrUtils::SetInt(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); | |||||
| switch_node_ = NewNode("switch", SWITCH, 2, 2); | |||||
| output_false_node_ = NewNode("false_output", RELU, 1, 2); | |||||
| AttrUtils::SetInt(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); | |||||
| output_true_node_ = NewNode("true_output", RELU, 1, 2); | |||||
| AttrUtils::SetInt(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); | |||||
| merge_node_ = NewNode("merge", MERGE, 2, 1); | |||||
| merge_node1_ = NewNode("merge1", MERGE, 2, 1); | |||||
| switch_node_->GetOpDesc()->SetIsInputConst({false, false}); | |||||
| GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1)); | |||||
| output_false_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
| output_true_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
| } | |||||
| void BuildDefaultGraph2() { | |||||
| /// input pred input pred1 | |||||
| /// \ / \ / | |||||
| /// Switch Switch1 | |||||
| /// | | _______| | |||||
| /// | | / | |||||
| /// ____F T____ | |||||
| /// \ | / \ | |||||
| /// \ Merge1 Merge2 | |||||
| /// \__________| | |||||
| pred_node_ = NewNode("pred", GREATER, 2, 1); | |||||
| pred_node1_ = NewNode("pred", LESS, 2, 1); | |||||
| input_node_ = NewNode("input", RELU, 0, 2); | |||||
| AttrUtils::SetInt(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); | |||||
| switch_node_ = NewNode("switch", SWITCH, 2, 2); | |||||
| switch_node1_ = NewNode("switch1", SWITCH, 2, 2); | |||||
| output_false_node_ = NewNode("false_output", RELU, 2, 2); | |||||
| AttrUtils::SetInt(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); | |||||
| output_true_node_ = NewNode("true_output", RELU, 2, 2); | |||||
| AttrUtils::SetInt(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 2); | |||||
| merge_node_ = NewNode("merge", MERGE, 2, 1); | |||||
| merge_node1_ = NewNode("merge1", MERGE, 2, 1); | |||||
| switch_node_->GetOpDesc()->SetIsInputConst({false, false}); | |||||
| GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(input_node_->GetOutDataAnchor(1), switch_node1_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(pred_node1_->GetOutDataAnchor(0), switch_node1_->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(switch_node1_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(switch_node1_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1)); | |||||
| output_false_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
| output_true_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
| } | |||||
| ComputeGraphPtr graph_; | |||||
| ComputeGraphPtr sub_graph_; | |||||
| GeTensorDescPtr default_tensor_desc_; | |||||
| ParallelGroupPass pass_; | |||||
| NodePtr pred_node_; | |||||
| NodePtr pred_node1_; | |||||
| NodePtr sqrt_node1_; | |||||
| NodePtr sqrt_node2_; | |||||
| NodePtr input_node_; | |||||
| NodePtr switch_node_; | |||||
| NodePtr switch_node1_; | |||||
| NodePtr output_false_node_; | |||||
| NodePtr output_true_node_; | |||||
| NodePtr merge_node_; | |||||
| NodePtr merge_node1_; | |||||
| NodePtr relu_node_; | |||||
| }; | |||||
| TEST_F(UtestGraphPassesParallelGgroupPass, null_graph) { | |||||
| ComputeGraphPtr graph = nullptr; | |||||
| auto ret = pass_.Run(graph); | |||||
| EXPECT_EQ(ret, PARAM_INVALID); | |||||
| } | |||||
| TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph) { | |||||
| BuildDefaultGraph(); | |||||
| auto ret = pass_.Run(graph_); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(switch_node_->GetInControlAnchor())); | |||||
| EXPECT_EQ(true, merge_node_->GetOutControlAnchor()->IsLinkedWith(sqrt_node2_->GetInControlAnchor())); | |||||
| EXPECT_EQ(false, output_false_node_->GetOutControlAnchor()->IsLinkedWith(output_true_node_->GetInControlAnchor())); | |||||
| } | |||||
| TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph1) { | |||||
| BuildDefaultGraph1(); | |||||
| auto ret = pass_.Run(graph_); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| } | |||||
| TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) { | |||||
| BuildDefaultGraph2(); | |||||
| auto ret = pass_.Run(graph_); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| } | |||||
| TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) { | |||||
| BuildDefaultGraph1(); | |||||
| NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true); | |||||
| NodePtr input_node2 = NewNode("input2", RELU, 0, 1, true); | |||||
| NodePtr add = NewNode("add", ADD, 2, 1, true); | |||||
| AttrUtils::SetInt(input_node1->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); | |||||
| AttrUtils::SetInt(input_node2->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1); | |||||
| sub_graph_->SetParentNode(input_node_); | |||||
| sub_graph_->SetParentGraph(graph_); | |||||
| auto ret = graph_->AddSubgraph(sub_graph_->GetName(), sub_graph_); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| ret = input_node_->GetOpDesc()->AddSubgraphName(sub_graph_->GetName()); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| ret = input_node_->GetOpDesc()->SetSubgraphInstanceName(0, sub_graph_->GetName()); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| ret = pass_.Run(sub_graph_); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| ret = pass_.Run(graph_); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| } | |||||
| } // namespace | |||||
| } // namespace ge | |||||