Browse Source

Merge remote-tracking branch 'refs/remotes/origin/master'

pull/1146/head^2
陈华 4 years ago
parent
commit
ff5df9a358
7 changed files with 46 additions and 23 deletions
  1. +8
    -7
      build.sh
  2. +1
    -1
      ge/CMakeLists.txt
  3. +18
    -13
      ge/graph/passes/parallel_group_pass.cc
  4. +1
    -1
      metadef
  5. +1
    -1
      parser
  6. +1
    -0
      tests/ut/ge/CMakeLists.txt
  7. +16
    -0
      tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc

+ 8
- 7
build.sh View File

@@ -76,8 +76,8 @@ checkopts()
ENABLE_GE_ST="on" ENABLE_GE_ST="on"
;; ;;
t) t)
ENABLE_GE_UT="on"
;;
ENABLE_GE_UT="on"
;;
c) c)
ENABLE_GE_COV="on" ENABLE_GE_COV="on"
;; ;;
@@ -214,13 +214,14 @@ if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then
cp ${BUILD_PATH}/tests/ut/ge/ut_libge_others_utest ${OUTPUT_PATH} cp ${BUILD_PATH}/tests/ut/ge/ut_libge_others_utest ${OUTPUT_PATH}
cp ${BUILD_PATH}/tests/ut/ge/ut_libge_kernel_utest ${OUTPUT_PATH} cp ${BUILD_PATH}/tests/ut/ge/ut_libge_kernel_utest ${OUTPUT_PATH}


${OUTPUT_PATH}/ut_libgraph &&
${OUTPUT_PATH}/ut_libge_multiparts_utest &&
${OUTPUT_PATH}/ut_libge_distinct_load_utest &&
${OUTPUT_PATH}/ut_libge_others_utest &&
${OUTPUT_PATH}/ut_libge_kernel_utest
RUN_TEST_CASE=${OUTPUT_PATH}/ut_libgraph && ${RUN_TEST_CASE} &&
RUN_TEST_CASE=${OUTPUT_PATH}/ut_libge_multiparts_utest && ${RUN_TEST_CASE} &&
RUN_TEST_CASE=${OUTPUT_PATH}/ut_libge_distinct_load_utest && ${RUN_TEST_CASE} &&
RUN_TEST_CASE=${OUTPUT_PATH}/ut_libge_others_utest && ${RUN_TEST_CASE} &&
RUN_TEST_CASE=${OUTPUT_PATH}/ut_libge_kernel_utest && ${RUN_TEST_CASE}
if [[ "$?" -ne 0 ]]; then if [[ "$?" -ne 0 ]]; then
echo "!!! UT FAILED, PLEASE CHECK YOUR CHANGES !!!" echo "!!! UT FAILED, PLEASE CHECK YOUR CHANGES !!!"
echo -e "\033[31m${RUN_TEST_CASE}\033[0m"
exit 1; exit 1;
fi fi
echo "Generating coverage statistics, please wait..." echo "Generating coverage statistics, please wait..."


+ 1
- 1
ge/CMakeLists.txt View File

@@ -607,8 +607,8 @@ set(INFER_SRC_LIST
"graph/passes/hccl_group_pass.cc" "graph/passes/hccl_group_pass.cc"
"graph/passes/memcpy_addr_async_pass.cc" "graph/passes/memcpy_addr_async_pass.cc"
"graph/passes/set_input_output_offset_pass.cc" "graph/passes/set_input_output_offset_pass.cc"
"graph/passes/parallel_group_pass.cc"
"graph/manager/model_manager/event_manager.cc" "graph/manager/model_manager/event_manager.cc"
"graph/passes/parallel_group_pass.cc"
"graph/manager/util/rt_context_util.cc" "graph/manager/util/rt_context_util.cc"
"graph/manager/util/variable_accelerate_ctrl.cc" "graph/manager/util/variable_accelerate_ctrl.cc"
"graph/manager/util/debug.cc" "graph/manager/util/debug.cc"


+ 18
- 13
ge/graph/passes/parallel_group_pass.cc View File

@@ -83,7 +83,7 @@ Status ParallelGroupPass::ProcessAllGraph(ComputeGraphPtr graph, std::unordered_
if (!is_unknown_shape) { if (!is_unknown_shape) {
group_node[group_id].push_back(node); group_node[group_id].push_back(node);
parallel_group.insert(group_id); parallel_group.insert(group_id);
GELOGI("Find hccl node:%s, group_id=%d", op_desc->GetName().c_str(), group_id);
GELOGD("Find group node:%s, group_id=%d", node->GetName().c_str(), group_id);
} }
} }


@@ -116,7 +116,8 @@ Status ParallelGroupPass::ProcessAllGraph(ComputeGraphPtr graph, std::unordered_
cur_node = node_vec[i]; cur_node = node_vec[i];
auto tmp_pre_node = pre_node; auto tmp_pre_node = pre_node;
auto tmp_cur_node = cur_node; auto tmp_cur_node = cur_node;
GELOGI("original we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
GELOGD("original add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(),
cur_node->GetName().c_str());
ReplaceSwitchAndMerge(tmp_pre_node, tmp_cur_node, node_2_switch_merge); ReplaceSwitchAndMerge(tmp_pre_node, tmp_cur_node, node_2_switch_merge);
pre_node = cur_node; pre_node = cur_node;
} }
@@ -127,13 +128,13 @@ Status ParallelGroupPass::ProcessAllGraph(ComputeGraphPtr graph, std::unordered_


void ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) { void ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) {
if (pre_node == cur_node) { if (pre_node == cur_node) {
GELOGI("--- pr_node == cur_node");
return; return;
} }
const auto &in_node = cur_node->GetInAllNodes(); const auto &in_node = cur_node->GetInAllNodes();
for (const auto &node : in_node) { for (const auto &node : in_node) {
if (pre_node == node) { if (pre_node == node) {
GELOGI("--- pr_node and cur_node have linked");
GELOGD("node:%s and node:%s has linked", pre_node->GetName().c_str(),
cur_node->GetName().c_str());
return; return;
} }
} }
@@ -211,7 +212,7 @@ Status ParallelGroupPass::ProcessSwitch(ComputeGraphPtr graph,
auto &tmp = it->second; auto &tmp = it->second;
auto &switch_vec = tmp.first; auto &switch_vec = tmp.first;
const auto &merge_node = tmp.second; const auto &merge_node = tmp.second;
GELOGI(" --- hccl node: %s, switch node %s, merge node :%s.",
GELOGD("Find group node: %s in switch node %s and merge node :%s.",
group_node->GetName().c_str(), node->GetName().c_str(), merge_node->GetName().c_str()); group_node->GetName().c_str(), node->GetName().c_str(), merge_node->GetName().c_str());
if (merge_node != merge_vec.back()) { if (merge_node != merge_vec.back()) {
GELOGE(GRAPH_FAILED, "error: has two merge node: %s and %s.", GELOGE(GRAPH_FAILED, "error: has two merge node: %s and %s.",
@@ -263,15 +264,15 @@ void ParallelGroupPass::ReplaceSwitchAndMerge(NodePtr &pre_node,
pre_node = pre_itr->second.second; pre_node = pre_itr->second.second;
for (const auto &switch_node : cur_itr->second.first) { for (const auto &switch_node : cur_itr->second.first) {
AddCtrlEdge(pre_node, switch_node); AddCtrlEdge(pre_node, switch_node);
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), switch_node->GetName().c_str());
GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(),
switch_node->GetName().c_str());
} }
} else {
GELOGI("--- no need add ctrl edge");
} }
} else { } else {
pre_node = pre_itr->second.second; pre_node = pre_itr->second.second;
AddCtrlEdge(pre_node, cur_node); AddCtrlEdge(pre_node, cur_node);
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(),
cur_node->GetName().c_str());
} }
} else { } else {
if (cur_itr != node_2_switch_merge.end()) { if (cur_itr != node_2_switch_merge.end()) {
@@ -281,20 +282,24 @@ void ParallelGroupPass::ReplaceSwitchAndMerge(NodePtr &pre_node,
if (pre_id > switch_id) { // special handle for merge and group node if (pre_id > switch_id) { // special handle for merge and group node
auto merge_node = cur_itr->second.second; auto merge_node = cur_itr->second.second;
AddCtrlEdge(merge_node, pre_node); AddCtrlEdge(merge_node, pre_node);
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", merge_node->GetName().c_str(), pre_node->GetName().c_str());
GELOGD("finally add ctrl anchor for node:%s-->node:%s", merge_node->GetName().c_str(),
pre_node->GetName().c_str());
} else { } else {
AddCtrlEdge(pre_node, switch_node); AddCtrlEdge(pre_node, switch_node);
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), switch_node->GetName().c_str());
GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(),
switch_node->GetName().c_str());
} }
} }
} else { } else {
AddCtrlEdge(pre_node, cur_node); AddCtrlEdge(pre_node, cur_node);
GELOGI("changed we should add ctrl anchor for node1:%s------>node2:%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(),
cur_node->GetName().c_str());
} }
} }
} }


bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1, const std::set<NodePtr> &switch_set2) {
bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1,
const std::set<NodePtr> &switch_set2) {
for (const auto &node1 : switch_set1) { for (const auto &node1 : switch_set1) {
for (const auto &node2 : switch_set2) { for (const auto &node2 : switch_set2) {
if (node1 == node2) { if (node1 == node2) {


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit a2b80cb22a62a6757c7dd31e684ca632e0b79268
Subproject commit b6de68fdf0f131fd5f8aa3a84245ad7779b348f5

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit cfabf622b803d5957563a73652a0ce5086aab99d
Subproject commit 7a6311351f8294eb11033b10e9f7b2b993cc3c2a

+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -522,6 +522,7 @@ set(GRAPH_PASS_COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/hccl_memcpy_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/hccl_memcpy_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/infershape_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/parallel_group_pass.cc"
"${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" "${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc"
"${GE_CODE_DIR}/ge/analyzer/analyzer.cc" "${GE_CODE_DIR}/ge/analyzer/analyzer.cc"
"${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc"


+ 16
- 0
tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc View File

@@ -170,6 +170,22 @@ class UtestLogicalStreamAllocator : public testing::Test {
return CreateSubgraphWithName("graph", engine, stream_label, in_num, out_num); return CreateSubgraphWithName("graph", engine, stream_label, in_num, out_num);
} }


SubGraphInfoPtr CreateParallelGroupSubgraphWithName(const string &name, const string &engine,
const string &stream_label = "",
int group_id = 1) {
ComputeGraphPtr compute_graph = make_shared<ComputeGraph>(name);
OpDescPtr op_desc = std::make_shared<OpDesc>("relu", "Relu");
op_desc->AddInputDesc(GeTensorDesc());
op_desc->AddOutputDesc(GeTensorDesc());
AttrUtils::SetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id);
compute_graph->AddNode(op_desc);

SubGraphInfoPtr subgraph = BuildSubGraph(compute_graph, engine, stream_label);
AddPlaceHolderAndEnd(subgraph, 1, 1);

return subgraph;
}

void LinkSubGraph(SubGraphInfoPtr subgraph1, const string &end_name, SubGraphInfoPtr subgraph2, void LinkSubGraph(SubGraphInfoPtr subgraph1, const string &end_name, SubGraphInfoPtr subgraph2,
const string &placeholder_name) { const string &placeholder_name) {
NodePtr end_node = subgraph1->GetSubGraph()->FindNode(end_name); NodePtr end_node = subgraph1->GetSubGraph()->FindNode(end_name);


Loading…
Cancel
Save