Browse Source

!11954 modify hccl op number per-stream

From: @zhoufeng54
Reviewed-by: @lilongfei15,@xsmq
Signed-off-by: @xsmq
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
9382fe64d8
1 changed files with 45 additions and 41 deletions
  1. +45
    -41
      mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc

+ 45
- 41
mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc View File

@@ -62,9 +62,31 @@ bool HasRefNodes(const vector<CNodePtr> &moved_backward_cnodes) {
}
return false;
}

StreamActiveKind GetStreamKind(uint32_t cur_stream_id, uint32_t pre_stream_id, uint32_t next_stream_id) {
// pre_stream_id equal to UINT32_MAX means no node active current StreamActive
// next_stream_id equal to UINT32_MAX means current StreamActive active no node
if (pre_stream_id == UINT32_MAX || next_stream_id == UINT32_MAX) {
return kInvalid;
}

if (cur_stream_id == pre_stream_id && cur_stream_id == next_stream_id) {
return kMiddle;
}

if (cur_stream_id == pre_stream_id) {
return kTail;
}

if (cur_stream_id == next_stream_id) {
return kHead;
}

return kInvalid;
}
} // namespace

const uint32_t kHcomMaxTask = 5;
const uint32_t kHcomMaxTask = 4;
const uint32_t kCommonMaxTask = 350;

void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
@@ -899,7 +921,7 @@ void AscendStreamAssign::InsertStreamActiveForIndependent(const NotNull<KernelGr
std::vector<CNodePtr> update_cnode_list;
auto exe_orders = graph_ptr->execution_order();

// first independent is been actived, active other independent stream
// first independent is been activated, active other independent stream
std::vector<uint32_t> streams;
std::copy(independent_streams.begin(), independent_streams.end(), std::back_inserter(streams));
std::sort(streams.begin(), streams.end());
@@ -999,10 +1021,10 @@ void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph
auto kind = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrStreamSwitchKind);
if (kind == kIndependentStreamSwitch) {
bool independent_empty = independent_stream_map_.empty();
// if indepdent empty: delete independent streamswitch
// if independent empty: delete independent streamswitch
if (!independent_empty) {
for (const auto &item : independent_stream_map_) {
// first independetn stream id is minimum and order by std map;
// first independent stream id is minimum and order by std map;
auto first_independent_stream = item.first;
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(first_independent_stream), switch_ptr);
orders->emplace_back(switch_ptr);
@@ -1028,7 +1050,7 @@ void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph
return;
}
auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrTrueBranchStream);
MS_LOG(INFO) << "Swtich stream id:" << AnfAlgo::GetStreamId(switch_ptr) << "; active stream id:" << true_stream_id;
MS_LOG(INFO) << "Switch stream id:" << AnfAlgo::GetStreamId(switch_ptr) << "; active stream id:" << true_stream_id;
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
AnfAlgo::SetStreamId(true_stream_id, active_ptr.get());
vector<uint32_t> active_ids;
@@ -1328,7 +1350,7 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
}
for (const auto &group : groups) {
auto cnode_ptr_list = graph_ptr->execution_order();
std::vector<std::pair<uint32_t, vector<size_t>>> stream_indexs;
std::vector<std::pair<uint32_t, vector<size_t>>> stream_indices;
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
auto cur_cnode = cnode_ptr_list[i];
if (!IsHcom(cur_cnode)) {
@@ -1346,11 +1368,11 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
continue;
}

if (stream_indexs.empty()) {
stream_indexs.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
if (stream_indices.empty()) {
stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
} else {
bool exit = false;
for (auto &item : stream_indexs) {
for (auto &item : stream_indices) {
if (item.first == cur_stream_id) {
item.second.emplace_back(i);
exit = true;
@@ -1358,17 +1380,17 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
}
}
if (!exit) {
stream_indexs.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
}
}
}

if (stream_indexs.size() < 2) {
if (stream_indices.size() < 2) {
MS_LOG(INFO) << "Group:" << group
<< "; different stream hcom size is less than 2, no need insert event between them";
continue;
}
InsertEventBetweenHcom(graph_ptr, stream_indexs);
InsertEventBetweenHcom(graph_ptr, stream_indices);
}
}

@@ -1474,7 +1496,7 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG

auto target = FindTargetOp(it, cnodes.end(), *(it - 1), false);
if (target == cnodes.end()) {
MS_LOG(DEBUG) << "Independ node[" << (*(it - 1))->fullname_with_scope()
MS_LOG(DEBUG) << "Independent node[" << (*(it - 1))->fullname_with_scope()
<< "] can't find target for insert recv op, no insert send/recv";
it = cnodes.erase(it);
continue;
@@ -1558,16 +1580,16 @@ uint32_t AscendStreamAssign::GetMaxIndexTarget(const NotNull<KernelGraphPtr> &gr
return UINT32_MAX;
}

std::set<uint32_t> indexs;
std::set<uint32_t> indices;
for (const auto &key : independent_targets_) {
auto index = GetIndexByKey(graph_ptr, key);
if (index == UINT32_MAX) {
MS_LOG(EXCEPTION) << "graph has no correspond key";
}
indexs.emplace(index);
indices.emplace(index);
}

return *(std::max_element(indexs.begin(), indexs.end()));
return *(std::max_element(indices.begin(), indices.end()));
}

uint32_t AscendStreamAssign::GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr) {
@@ -1623,7 +1645,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
CNodePtr cur_cnode_ptr = nullptr;
auto cnode_ptr_list = graph_ptr->execution_order();

// 1)stream witch kStreamNeedActivedFirst attr should be actived;
// 1)stream witch kStreamNeedActivedFirst attr should be activated;
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
@@ -1634,7 +1656,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst);
if (need_active) {
auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
MS_LOG(INFO) << "Stream id:" << stream_id << " is need actived at first";
MS_LOG(INFO) << "Stream id:" << stream_id << " is need activated at first";
need_first_active_streams_.push_back(stream_id);
}
}
@@ -1659,7 +1681,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
}
}

// 4)first stream 0 should be actived first;
// 4)first stream 0 should be activated first;
auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), 0);
if (it == need_first_active_streams_.end()) {
need_first_active_streams_.emplace_back(0);
@@ -2025,7 +2047,7 @@ void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull<KernelGraph

for (const auto &item : active_list) {
if (item <= active_current_node) {
MS_LOG(WARNING) << "Actived stream is less than activing stream";
MS_LOG(WARNING) << "Activated stream is less than activing stream";
continue;
}
auto it =
@@ -2054,7 +2076,7 @@ void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull<KernelGraph
} else {
for (const auto &stream : active_list) {
if (stream <= cur_stream_id) {
MS_LOG(WARNING) << "Actived stream is less than activing stream";
MS_LOG(WARNING) << "Activated stream is less than activing stream";
continue;
}
auto iter = std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), stream);
@@ -2131,25 +2153,7 @@ StreamActiveKind AscendStreamAssign::GetStreamActiveKind(const NotNull<KernelGra
break;
}

// pre_stream_id = UINT32_MAX:means no node active current StreamActive
// next_stream_id = UINT32_MAX:means current StreamActive active no node
if (pre_stream_id == UINT32_MAX || next_stream_id == UINT32_MAX) {
return kInvalid;
}

if (cur_stream_id == pre_stream_id && cur_stream_id == next_stream_id) {
return kMiddle;
}

if (cur_stream_id == pre_stream_id) {
return kTail;
}

if (cur_stream_id == next_stream_id) {
return kHead;
}

return kInvalid;
return GetStreamKind(cur_stream_id, pre_stream_id, next_stream_id);
}

uint32_t AscendStreamAssign::GetStreamByActivedStream(uint32_t actived_stream_id) {
@@ -2172,7 +2176,7 @@ void AscendStreamAssign::PrintStreamRelations() {
for (const auto &item : stream_relations_) {
MS_LOG(INFO) << "Stream:" << item.first;
for (const auto &stream : item.second) {
MS_LOG(INFO) << "--actived stream id:" << stream;
MS_LOG(INFO) << "--activated stream id:" << stream;
}
}
}


Loading…
Cancel
Save