Browse Source

parallel group

pull/1141/head
陈华 4 years ago
parent
commit
efad22801d
9 changed files with 718 additions and 5 deletions
  1. +2
    -0
      ge/CMakeLists.txt
  2. +43
    -0
      ge/graph/build/logical_stream_allocator.cc
  3. +7
    -0
      ge/graph/build/logical_stream_allocator.h
  4. +2
    -0
      ge/graph/manager/graph_manager.cc
  5. +307
    -0
      ge/graph/passes/parallel_group_pass.cc
  6. +47
    -0
      ge/graph/passes/parallel_group_pass.h
  7. +8
    -5
      tests/ut/ge/CMakeLists.txt
  8. +43
    -0
      tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc
  9. +259
    -0
      tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc

+ 2
- 0
ge/CMakeLists.txt View File

@@ -214,6 +214,7 @@ set(TRAIN_SRC_LIST
"graph/passes/inplace_support_check_pass.cc"
"graph/passes/flow_ctrl_pass.cc"
"graph/passes/global_step_insert_pass.cc"
"graph/passes/parallel_group_pass.cc"
"host_kernels/transpose_kernel.cc"
"host_kernels/add_kernel.cc"
"host_kernels/broadcast_args_kernel.cc"
@@ -606,6 +607,7 @@ set(INFER_SRC_LIST
"graph/passes/hccl_group_pass.cc"
"graph/passes/memcpy_addr_async_pass.cc"
"graph/passes/set_input_output_offset_pass.cc"
"graph/passes/parallel_group_pass.cc"
"graph/manager/model_manager/event_manager.cc"
"graph/manager/util/rt_context_util.cc"
"graph/manager/util/variable_accelerate_ctrl.cc"


+ 43
- 0
ge/graph/build/logical_stream_allocator.cc View File

@@ -366,6 +366,48 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr
return SUCCESS;
}

Status UpdateForParallelGroupPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) {
std::map<int, vector<OpDescPtr>> stream_op_map;
for (const SubgraphPtr &subgraph : subgraphs) {
auto compute_graph = subgraph->subgraph_info.GetSubGraph();
for (NodePtr &node : compute_graph->GetDirectNode()) {
OpDescPtr op_desc = node->GetOpDesc();
if (op_desc->HasAttr(ATTR_NAME_PARALLEL_GROUP)) {
int64_t op_desc_stream_id = op_desc->GetStreamId();
stream_op_map[op_desc_stream_id].push_back(op_desc);
}
}
}
for (const auto &itr : stream_op_map) {
if (itr.first == kInvalidStream) {
continue;
}
std::map<int, vector<OpDescPtr>> group_op;
for (const auto &op_desc : itr.second) {
int group_id;
if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id)) {
GELOGE(GRAPH_FAILED, "Get node %s ATTR_NAME_PARALLEL_GROUP failed.", op_desc->GetName().c_str());
return GRAPH_FAILED;
}
group_op[group_id].emplace_back(op_desc);
}
for (const auto &itr : group_op) {
const auto &op_vec = itr.second;
if (op_vec.empty()) {
continue;
}
int64_t new_stream_id = context.next_stream++;
for (const auto &op : op_vec) {
int64_t old_stream_id = op->GetStreamId();
op->SetStreamId(new_stream_id);
GELOGI("chenhua Node %s assigned stream %ld from stream %ld.", op->GetName().c_str(), new_stream_id,
old_stream_id);
}
}
}
return SUCCESS;
}

int64_t UpdateForSkippedEnginePass::GetSingleInoutStream(const NodePtr &node) const {
set<int64_t> stream_ids;

@@ -655,6 +697,7 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec
passes.emplace_back(MakeShared<IndependentStreamPass>());
passes.emplace_back(MakeShared<AssignByDependencyPass>());
passes.emplace_back(MakeShared<NodeStreamUpdatePass>());
passes.emplace_back(MakeShared<UpdateForParallelGroupPass>());
passes.emplace_back(MakeShared<AllReduceParallelPass>());
passes.emplace_back(MakeShared<UpdateForSkippedEnginePass>());
}


+ 7
- 0
ge/graph/build/logical_stream_allocator.h View File

@@ -149,6 +149,13 @@ class NodeStreamUpdatePass : public LogicalStreamPass {
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
};

// assign stream by parallel group
class UpdateForParallelGroupPass : public LogicalStreamPass {
public:
STREAM_PASS_DEFAULT_FUNC(UpdateForParallelGroupPass);
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
};

// Update the stream of subgraphs to nodes.
class UpdateForSkippedEnginePass : public LogicalStreamPass {
public:


+ 2
- 0
ge/graph/manager/graph_manager.cc View File

@@ -92,6 +92,7 @@
#include "graph/passes/global_step_insert_pass.h"
#include "graph/passes/memcpy_addr_async_pass.h"
#include "graph/passes/hccl_continuous_memcpy_pass.h"
#include "graph/passes/parallel_group_pass.h"
#include "graph/build/label_allocator.h"
#include "graph/utils/tensor_adapter.h"
#include "inc/pass_manager.h"
@@ -2203,6 +2204,7 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
GE_DUMP(compute_graph, "OptimizeStage1_2");
PassManager graph_pass;
// the prune pass should between SwitchPass and SwitchToStreamSwitchPass
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ParallelGroupPass", new (std::nothrow) ParallelGroupPass));
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass));
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass));
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass))


+ 307
- 0
ge/graph/passes/parallel_group_pass.cc View File

@@ -0,0 +1,307 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/parallel_group_pass.h"

#include "framework/common/debug/ge_log.h"
#include "common/ge/ge_util.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"

namespace ge {

Status ParallelGroupPass::Run(ComputeGraphPtr graph) {
GELOGI("ParallelGroupPass running");
if (graph == nullptr) {
GELOGE(PARAM_INVALID, "Input param graph is null, skip optimize parallel group graph.");
return PARAM_INVALID;
}

if (graph->GetParentGraph() != nullptr) {
GELOGW("Current graph %s is a subgraph, this pass only support root graph.",
graph->GetName().c_str());
return GRAPH_SUCCESS;
}

std::unordered_set<int> parallel_group;
if (ProcessAllGraph(graph, parallel_group) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Process graph %s failed.", graph->GetName().c_str());
return INTERNAL_ERROR;
}

if (graph->TopologicalSorting() != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Graph topological sort failed.");
return GRAPH_FAILED;
}

return GRAPH_SUCCESS;
}

Status ParallelGroupPass::ProcessAllGraph(ComputeGraphPtr graph, std::unordered_set<int> &parallel_group) {
if (graph == nullptr) {
GELOGE(PARAM_INVALID, "Input param graph is null, skip optimize parallel group subgraph.");
return PARAM_INVALID;
}

std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> node_2_switch_merge;
if (ProcessSwitch(graph, node_2_switch_merge) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Input param graph is null, skip optimize parallel group subgraph.");
return GRAPH_FAILED;
}

std::map<int, vector<NodePtr>> group_node;
const auto &candidates = graph->GetDirectNode();
for (const auto &node : candidates) {
OpDescPtr op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
continue;
}
int group_id = -1;
if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id)) {
bool is_unknown_shape = false;
auto ret = ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape);
if (ret != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Get node[%s] shape status failed!", node->GetName().c_str());
return GRAPH_FAILED;
}
// only handle known shape node
if (!is_unknown_shape) {
group_node[group_id].push_back(node);
parallel_group.insert(group_id);
GELOGI("Find hccl node:%s, group_id=%d", op_desc->GetName().c_str(), group_id);
}
}

const auto &subgraph_name = op_desc->GetSubgraphInstanceNames();
for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) {
const auto sub_graph = graph->GetSubgraph(*name_iter);
if (sub_graph != nullptr) {
std::unordered_set<int> sub_parallel_group;
auto ret = ProcessAllGraph(sub_graph, sub_parallel_group);
if (ret != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Process sub graph %s failed.", sub_graph->GetName().c_str());
return GRAPH_FAILED;
}
for (const auto &it : sub_parallel_group) {
parallel_group.insert(it);
group_node[it].emplace_back(node);
}
}
}
}

for (const auto &itr : group_node) {
const auto &node_vec = itr.second;
if (node_vec.empty()) {
continue;
}
NodePtr pre_node = node_vec[0];
NodePtr cur_node = nullptr;
for (int i = 1; i < node_vec.size(); i++) {
cur_node = node_vec[i];
auto tmp_pre_node = pre_node;
auto tmp_cur_node = cur_node;
GELOGI("original we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
ReplaceSwitchAndMerge(tmp_pre_node, tmp_cur_node, node_2_switch_merge);
pre_node = cur_node;
}
}

return GRAPH_SUCCESS;
}

void ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) {
if (pre_node == cur_node) {
GELOGI("--- pr_node == cur_node");
return;
}
const auto &in_node = cur_node->GetInAllNodes();
for (const auto &node : in_node) {
if (pre_node == node) {
GELOGI("--- pr_node and cur_node have linked");
return;
}
}
auto pre_out_ctrl_anchor = pre_node->GetOutControlAnchor();
auto cur_in_ctrl_anchor = cur_node->GetInControlAnchor();
pre_out_ctrl_anchor->LinkTo(cur_in_ctrl_anchor);
}

Status ParallelGroupPass::ProcessSwitch(ComputeGraphPtr graph,
std::map<NodePtr,
std::pair<std::set<NodePtr>,
NodePtr>> &node_2_switch_merge) {
std::string type;
const auto &direct_node = graph->GetDirectNode();
for (const auto &node : direct_node) {
auto ret = GetOriginalType(node, type);
if (ret != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Get node %s type failed.", node->GetName().c_str());
return GRAPH_FAILED;
}
if ((type != SWITCH) && (type != REFSWITCH)) {
continue;
}
InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
if (in_cond_anchor == nullptr) {
GELOGE(INTERNAL_ERROR, "in_cond_anchor is nullptr, node: %s.", node->GetName().c_str());
return INTERNAL_ERROR;
}
OutDataAnchorPtr pred_cond_anchor = in_cond_anchor->GetPeerOutAnchor();
if (pred_cond_anchor == nullptr) {
GELOGE(INTERNAL_ERROR, "pred_cond_anchor is nullptr, node: %s.", node->GetName().c_str());
return INTERNAL_ERROR;
}
// ignore while op
if (pred_cond_anchor->GetOwnerNode()->GetType() == LOOPCOND) {
continue;
}
std::deque<NodePtr> candidates;
std::vector<NodePtr> merge_vec;
std::set<NodePtr> group_set;
std::set<NodePtr> visited;
candidates.emplace_back(node);
while (!candidates.empty()) {
NodePtr tmp_node = candidates.front();
candidates.pop_front();
for (const auto &out_node : tmp_node->GetOutAllNodes()) {
ret = GetOriginalType(out_node, type);
if (ret != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Get node %s type failed.", node->GetName().c_str());
return GRAPH_FAILED;
}
if ((type == MERGE) || (type == REFMERGE)) {
merge_vec.emplace_back(out_node);
continue;
}
const auto op = out_node->GetOpDesc();
if (op != nullptr && op->HasAttr(ATTR_NAME_PARALLEL_GROUP)) {
group_set.emplace(out_node);
}
if (visited.count(out_node) > 0) {
continue;
}
candidates.push_back(out_node);
visited.insert(out_node);
}
}
std::sort(merge_vec.begin(), merge_vec.end(),
[&direct_node] (NodePtr a, NodePtr b) -> bool {
return std::find(direct_node.begin(), direct_node.end(), a) <
std::find(direct_node.begin(), direct_node.end(), b);
});
for (const auto &group_node : group_set) {
auto it = node_2_switch_merge.find(group_node);
if (it != node_2_switch_merge.end()) {
auto &tmp = it->second;
auto &switch_vec = tmp.first;
const auto &merge_node = tmp.second;
GELOGI(" --- hccl node: %s, switch node %s, merge node :%s.",
group_node->GetName().c_str(), node->GetName().c_str(), merge_node->GetName().c_str());
if (merge_node != merge_vec.back()) {
GELOGE(GRAPH_FAILED, "error: has two merge node: %s and %s.",
merge_node->GetName().c_str(), merge_vec.back()->GetName().c_str());
return GRAPH_FAILED;
}
switch_vec.insert(node);
} else {
node_2_switch_merge.emplace(group_node, std::make_pair(std::set<NodePtr>{node}, merge_vec.back()));
}
}
}

return GRAPH_SUCCESS;
}

Status ParallelGroupPass::GetOriginalType(const NodePtr &node, string &type) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "node is nullptr.");
return GRAPH_FAILED;
}
type = node->GetType();
GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS);
if (node->GetOpDesc() == nullptr) {
GELOGE(GRAPH_FAILED, "op_desc is nullptr.");
return GRAPH_FAILED;
}
bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
if (!ret) {
GELOGE(INTERNAL_ERROR, "node is nullptr.");
return INTERNAL_ERROR;
}
GELOGD("Get FrameWorkOp original type [%s]", type.c_str());
return SUCCESS;
}

void ParallelGroupPass::ReplaceSwitchAndMerge(NodePtr &pre_node,
NodePtr &cur_node,
std::map<NodePtr,
std::pair<std::set<NodePtr>,
NodePtr>> &node_2_switch_merge) {
auto pre_itr = node_2_switch_merge.find(pre_node);
auto cur_itr = node_2_switch_merge.find(cur_node);
if (pre_itr != node_2_switch_merge.end()) {
if (cur_itr != node_2_switch_merge.end()) {
const auto &pre_set = pre_itr->second.first;
const auto &cur_set = cur_itr->second.first;
if (!HasSameSwitch(pre_set, cur_set)) {
pre_node = pre_itr->second.second;
for (const auto &switch_node : cur_itr->second.first) {
AddCtrlEdge(pre_node, switch_node);
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), switch_node->GetName().c_str());
}
} else {
GELOGI("--- no need add ctrl edge");
}
} else {
pre_node = pre_itr->second.second;
AddCtrlEdge(pre_node, cur_node);
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
}
} else {
if (cur_itr != node_2_switch_merge.end()) {
for (const auto &switch_node : cur_itr->second.first) {
int64_t pre_id = pre_node->GetOpDesc()->GetId();
int64_t switch_id = switch_node->GetOpDesc()->GetId();
if (pre_id > switch_id) { // special handle for merge and group node
auto merge_node = cur_itr->second.second;
AddCtrlEdge(merge_node, pre_node);
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", merge_node->GetName().c_str(), pre_node->GetName().c_str());
} else {
AddCtrlEdge(pre_node, switch_node);
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), switch_node->GetName().c_str());
}
}
} else {
AddCtrlEdge(pre_node, cur_node);
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
}
}
}

bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1, const std::set<NodePtr> &switch_set2) {
for (const auto &node1 : switch_set1) {
for (const auto &node2 : switch_set2) {
if (node1 == node2) {
return true;
}
}
}
return false;
}
} // namespace ge

+ 47
- 0
ge/graph/passes/parallel_group_pass.h View File

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H
#define GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H


#include <map>
#include <map>
#include <unordered_set>
#include "graph/graph.h"
#include "inc/graph_pass.h"

namespace ge {
class ParallelGroupPass : public GraphPass {
public:
Status Run(ComputeGraphPtr graph) override;
private:
Status ProcessAllGraph(ComputeGraphPtr graph, std::unordered_set<int> &parallel_group);
void AddCtrlEdge(NodePtr pre_node, NodePtr cur_node);
Status GetOriginalType(const ge::NodePtr &node, std::string &type);
void ReplaceSwitchAndMerge(NodePtr &pre_node,
NodePtr &cur_node,
std::map<NodePtr,
std::pair<std::set<NodePtr>,
NodePtr>> &node_2_switch_merge);
bool HasSameSwitch(const std::set<NodePtr> &a, const std::set<NodePtr> &b);
Status ProcessSwitch(ComputeGraphPtr graph,
std::map<NodePtr,
std::pair<std::set<NodePtr>,
NodePtr>> &node_2_switch_merge);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H

+ 8
- 5
tests/ut/ge/CMakeLists.txt View File

@@ -266,8 +266,9 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc"
"${GE_CODE_DIR}/ge/model/ge_model.cc"
"${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/model_utils.cc"
@@ -503,14 +504,15 @@ set(GRAPH_PASS_COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/same_transdata_breadth_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/compile_nodes_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc"
"${GE_CODE_DIR}/ge/graph/common/transop_util.cc"
"${GE_CODE_DIR}/ge/graph/passes/flow_ctrl_pass.cc"
#"${GE_CODE_DIR}/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc"
@@ -671,7 +673,7 @@ set(PASS_TEST_FILES
"graph/passes/trans_op_depth_fusion_pass_unittest.cc"
"graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc"
"graph/passes/constant_folding_pass_unittest.cc"
"graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc"
"graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc"
"graph/passes/stop_gradient_pass_unittest.cc"
"graph/passes/prevent_gradient_pass_unittest.cc"
"graph/passes/identity_pass_unittest.cc"
@@ -688,6 +690,7 @@ set(PASS_TEST_FILES
"graph/passes/no_use_reshape_remove_pass_unittest.cc"
"graph/passes/infershape_pass_unittest.cc"
"graph/passes/multi_batch_clone_pass_unittest.cc"
"graph/passes/parallel_group_pass_unittest.cc"
)

set(KERNEL_TEST_FILES


+ 43
- 0
tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc View File

@@ -32,6 +32,7 @@
#include "graph/compute_graph.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/debug/ge_attr_define.h"

using namespace std;

@@ -148,6 +149,22 @@ class UtestLogicalStreamAllocator : public testing::Test {
return subgraph;
}

SubGraphInfoPtr CreateParallelGroupSubgraphWithName(const string &name, const string &engine,
const string &stream_label = "",
int group_id = 1) {
ComputeGraphPtr compute_graph = make_shared<ComputeGraph>(name);
OpDescPtr op_desc = std::make_shared<OpDesc>("relu", "Relu");
op_desc->AddInputDesc(GeTensorDesc());
op_desc->AddOutputDesc(GeTensorDesc());
AttrUtils::SetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id);
compute_graph->AddNode(op_desc);

SubGraphInfoPtr subgraph = BuildSubGraph(compute_graph, engine, stream_label);
AddPlaceHolderAndEnd(subgraph, 1, 1);

return subgraph;
}

SubGraphInfoPtr CreateSubgraph(const string &engine, const string &stream_label = "", int in_num = 1,
int out_num = 1) {
return CreateSubgraphWithName("graph", engine, stream_label, in_num, out_num);
@@ -878,4 +895,30 @@ TEST_F(UtestLogicalStreamAllocator, test_all_reduce_parallel_pass) {
EXPECT_EQ(ret, NOT_CHANGED);
}

TEST_F(UtestLogicalStreamAllocator, test_parallel_group) {
SubGraphInfoPtr data = CreateDataSubgraph();
SubGraphInfoPtr subgraph1 = CreateParallelGroupSubgraphWithName("graph1", "engine1", "");
SubGraphInfoPtr subgraph2 = CreateParallelGroupSubgraphWithName("graph2", "engine2", "", 2);
SubGraphInfoPtr subgraph3 = CreateParallelGroupSubgraphWithName("graph3", "engine3", "", 3);
SubGraphInfoPtr subgraph4 = CreateParallelGroupSubgraphWithName("graph4", "engine4", "", 4);
LinkSubGraph(data, "end", subgraph1, "placeholder");
LinkSubGraph(subgraph1, "end", subgraph2, "placeholder");
LinkSubGraph(subgraph2, "end", subgraph3, "placeholder");
LinkSubGraph(subgraph3, "end", subgraph4, "placeholder");

EngineConfPtr conf1 = make_shared<EngineConf>();
conf1->id = subgraph1->GetEngineName();
EngineConfPtr conf2 = make_shared<EngineConf>();
conf2->id = subgraph2->GetEngineName();
conf2->attach = false;
EngineConfPtr conf3 = make_shared<EngineConf>();
conf3->id = subgraph3->GetEngineName();
conf3->attach = false;
EngineConfPtr conf4 = make_shared<EngineConf>();
conf4->id = subgraph4->GetEngineName();

Status status = AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4}, {conf1, conf2, conf3, conf4});
EXPECT_EQ(status, ge::SUCCESS);
}

} // namespace ge

+ 259
- 0
tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc View File

@@ -0,0 +1,259 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <cstdint>
#include <string>

#define private public

#include "common/ge_inner_error_codes.h"
#include "inc/pass_manager.h"
#include "utils/graph_utils.h"
#include "graph/passes/parallel_group_pass.h"
#undef private

namespace ge {
namespace {

class UtestGraphPassesParallelGgroupPass : public testing::Test {
protected:
UtestGraphPassesParallelGgroupPass() {
graph_ = std::make_shared<ComputeGraph>("test");
sub_graph_ = std::make_shared<ComputeGraph>("test_subgraph");
vector<int64_t> shape_vec{1, 1, 1, 1};
GeShape shape = GeShape(shape_vec);
default_tensor_desc_ = std::make_shared<GeTensorDesc>();
default_tensor_desc_->SetShape(shape);
default_tensor_desc_->SetFormat(FORMAT_NCHW);
default_tensor_desc_->SetDataType(DT_FLOAT);
}

NodePtr NewNode(const std::string &name, const std::string &type,
int input_cnt, int output_cnt, bool isSubgraph = false) {
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
for (int i = 0; i < input_cnt; ++i) {
op_desc->AddInputDesc(default_tensor_desc_->Clone());
}

for (int i = 0; i < output_cnt; ++i) {
op_desc->AddOutputDesc(default_tensor_desc_->Clone());
}
NodePtr node = nullptr;
if (isSubgraph) {
node = sub_graph_->AddNode(op_desc);
(void)node->SetOwnerComputeGraph(sub_graph_);
} else {
node = graph_->AddNode(op_desc);
(void)node->SetOwnerComputeGraph(graph_);
}

return node;
}

void BuildDefaultGraph() {
/// input
/// \
/// sqrt1 pred
/// \ /
/// Switch
/// | |
/// F T
/// | |
/// Merge
/// |
/// relu
/// |
/// sqrt2
pred_node_ = NewNode("pred", GREATER, 2, 1);

input_node_ = NewNode("input", RELU, 0, 1);
AttrUtils::SetInt(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1);
sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1);
switch_node_ = NewNode("switch", SWITCH, 2, 2);
output_false_node_ = NewNode("false_output", RELU, 1, 1);
AttrUtils::SetInt(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1);
output_true_node_ = NewNode("true_output", RELU, 1, 1);
AttrUtils::SetInt(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1);
merge_node_ = NewNode("merge", MERGE, 2, 1);
relu_node_ = NewNode("relu", RELU, 1, 1);
sqrt_node2_ = NewNode("sqrt2", SQRT, 1, 1);
AttrUtils::SetInt(sqrt_node2_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1);

switch_node_->GetOpDesc()->SetIsInputConst({false, false});

GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0));
GraphUtils::AddEdge(sqrt_node1_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node2_->GetInDataAnchor(0));

output_false_node_->GetOpDesc()->SetIsInputConst({false});
output_true_node_->GetOpDesc()->SetIsInputConst({false});
}

void BuildDefaultGraph1() {
/// input pred
/// \ /
/// Switch
/// | |
/// ----F T----
/// \ | / \
/// \ Merge1 Merge2
/// \_________|
pred_node_ = NewNode("pred", GREATER, 2, 1);

input_node_ = NewNode("input", RELU, 0, 1);
AttrUtils::SetInt(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1);
switch_node_ = NewNode("switch", SWITCH, 2, 2);
output_false_node_ = NewNode("false_output", RELU, 1, 2);
AttrUtils::SetInt(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1);
output_true_node_ = NewNode("true_output", RELU, 1, 2);
AttrUtils::SetInt(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1);
merge_node_ = NewNode("merge", MERGE, 2, 1);
merge_node1_ = NewNode("merge1", MERGE, 2, 1);

switch_node_->GetOpDesc()->SetIsInputConst({false, false});

GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1));

output_false_node_->GetOpDesc()->SetIsInputConst({false});
output_true_node_->GetOpDesc()->SetIsInputConst({false});
}


void BuildDefaultGraph2() {
/// input pred input pred1
/// \ / \ /
/// Switch Switch1
/// | | _______|
/// | | /
/// ____F T____
/// \ | / \
/// \ Merge1 Merge2
/// \__________|
pred_node_ = NewNode("pred", GREATER, 2, 1);
pred_node1_ = NewNode("pred", LESS, 2, 1);
input_node_ = NewNode("input", RELU, 0, 2);
AttrUtils::SetInt(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1);
switch_node_ = NewNode("switch", SWITCH, 2, 2);
switch_node1_ = NewNode("switch1", SWITCH, 2, 2);
output_false_node_ = NewNode("false_output", RELU, 2, 2);
AttrUtils::SetInt(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1);
output_true_node_ = NewNode("true_output", RELU, 2, 2);
AttrUtils::SetInt(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 2);
merge_node_ = NewNode("merge", MERGE, 2, 1);
merge_node1_ = NewNode("merge1", MERGE, 2, 1);

switch_node_->GetOpDesc()->SetIsInputConst({false, false});

GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(input_node_->GetOutDataAnchor(1), switch_node1_->GetInDataAnchor(0));
GraphUtils::AddEdge(pred_node1_->GetOutDataAnchor(0), switch_node1_->GetInDataAnchor(1));
GraphUtils::AddEdge(switch_node1_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(switch_node1_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1));

output_false_node_->GetOpDesc()->SetIsInputConst({false});
output_true_node_->GetOpDesc()->SetIsInputConst({false});
}

ComputeGraphPtr graph_;
ComputeGraphPtr sub_graph_;
GeTensorDescPtr default_tensor_desc_;
ParallelGroupPass pass_;
NodePtr pred_node_;
NodePtr pred_node1_;
NodePtr sqrt_node1_;
NodePtr sqrt_node2_;
NodePtr input_node_;
NodePtr switch_node_;
NodePtr switch_node1_;
NodePtr output_false_node_;
NodePtr output_true_node_;
NodePtr merge_node_;
NodePtr merge_node1_;
NodePtr relu_node_;
};

TEST_F(UtestGraphPassesParallelGgroupPass, null_graph) {
ComputeGraphPtr graph = nullptr;
auto ret = pass_.Run(graph);
EXPECT_EQ(ret, PARAM_INVALID);
}

TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph) {
BuildDefaultGraph();
auto ret = pass_.Run(graph_);
EXPECT_EQ(ret, GRAPH_SUCCESS);
EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(switch_node_->GetInControlAnchor()));
EXPECT_EQ(true, merge_node_->GetOutControlAnchor()->IsLinkedWith(sqrt_node2_->GetInControlAnchor()));
EXPECT_EQ(false, output_false_node_->GetOutControlAnchor()->IsLinkedWith(output_true_node_->GetInControlAnchor()));
}

TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph1) {
BuildDefaultGraph1();
auto ret = pass_.Run(graph_);
EXPECT_EQ(ret, GRAPH_SUCCESS);
}

TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) {
BuildDefaultGraph2();
auto ret = pass_.Run(graph_);
EXPECT_EQ(ret, GRAPH_SUCCESS);
}

TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) {
BuildDefaultGraph1();
NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true);
NodePtr input_node2 = NewNode("input2", RELU, 0, 1, true);
NodePtr add = NewNode("add", ADD, 2, 1, true);
AttrUtils::SetInt(input_node1->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1);
AttrUtils::SetInt(input_node2->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, 1);

sub_graph_->SetParentNode(input_node_);
sub_graph_->SetParentGraph(graph_);
auto ret = graph_->AddSubgraph(sub_graph_->GetName(), sub_graph_);
EXPECT_EQ(ret, GRAPH_SUCCESS);
ret = input_node_->GetOpDesc()->AddSubgraphName(sub_graph_->GetName());
EXPECT_EQ(ret, GRAPH_SUCCESS);
ret = input_node_->GetOpDesc()->SetSubgraphInstanceName(0, sub_graph_->GetName());
EXPECT_EQ(ret, GRAPH_SUCCESS);
ret = pass_.Run(sub_graph_);
EXPECT_EQ(ret, GRAPH_SUCCESS);
ret = pass_.Run(graph_);
EXPECT_EQ(ret, GRAPH_SUCCESS);
}

} // namespace
} // namespace ge

Loading…
Cancel
Save