Browse Source

!25070 fix single mode bug

Merge pull request !25070 from baihuawei/single_mode_bug
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
ee70517a5b
5 changed files with 71 additions and 47 deletions
  1. +20
    -45
      mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
  2. +0
    -1
      mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
  3. +46
    -1
      mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc
  4. +2
    -0
      mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h
  5. +3
    -0
      mindspore/ccsrc/utils/utils.h

+ 20
- 45
mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc View File

@@ -732,53 +732,27 @@ bool AscendKernelRuntime::Run(const session::KernelGraph &graph, bool is_task_si

void AscendKernelRuntime::SetKernelModStream(const std::vector<CNodePtr> &kernels,
std::vector<size_t> *last_stream_nodes) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
std::map<void *, size_t> last_kernel;
for (size_t i = 0; i < kernels.size(); ++i) {
auto &node = kernels[i];
auto kernel_mod = AnfAlgo::GetKernelMod(node);
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
if (AnfAlgo::IsCommunicationOp(node)) {
auto group = AnfAlgo::GetNodeAttr<std::string>(node, kAttrGroup);
auto iter = group_stream_id_map_.find(group);
if (iter == group_stream_id_map_.end()) {
void *stream = nullptr;
auto ret = rtStreamCreate(&stream, 0);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
}
auto id = SizeToUint(stream_id_map_.size());
group_stream_id_map_[group] = id;
stream_id_map_[id] = stream;
AnfAlgo::SetStreamId(id, node.get());
ascend_kernel_mod->SetStream(stream);
last_kernel[stream] = i;
} else {
auto id = iter->second;
AnfAlgo::SetStreamId(id, node.get());
ascend_kernel_mod->SetStream(stream_id_map_[id]);
last_kernel[stream_id_map_[id]] = i;
auto stream_id = AnfAlgo::GetStreamId(kernels[i]);
auto iter = stream_id_map_.find(stream_id);
if (iter == stream_id_map_.end()) {
void *stream = nullptr;
auto ret = rtStreamCreate(&stream, 0);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
}
} else if (AnfAlgo::IsIndependentNode(node) && mode != kPynativeMode) {
AnfAlgo::SetStreamId(1, node.get());
ascend_kernel_mod->SetStream(independent_stream_);
last_kernel[independent_stream_] = i;
stream_id_map_[stream_id] = stream;
ascend_kernel_mod->SetStream(stream);
} else {
AnfAlgo::SetStreamId(0, node.get());
ascend_kernel_mod->SetStream(stream_);
ascend_kernel_mod->SetStream(iter->second);
}
}
for (size_t i = 1; i < kernels.size(); ++i) {
if (AnfAlgo::GetCNodeName(kernels[i - 1]) == kAtomicAddrCleanOpName) {
auto stream_id = AnfAlgo::GetStreamId(kernels[i]);
AnfAlgo::SetStreamId(stream_id, kernels[i - 1].get());
auto kernel_mod = AnfAlgo::GetKernelMod(kernels[i - 1]);
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
ascend_kernel_mod->SetStream(stream_id_map_[stream_id]);
if (stream_id > 0) {
last_kernel[stream_id_map_[stream_id]] = i;
}
}
(void)std::transform(last_kernel.begin(), last_kernel.end(), std::back_inserter(*last_stream_nodes),
@@ -798,6 +772,8 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
auto &kernel_post_run_events = kernel_events.second;
kernel_pre_run_events.resize(kernels.size());
kernel_post_run_events.resize(kernels.size());
auto stream_num = stream_id_map_.size();
std::vector<std::vector<bool>> kernel_hit(kernels.size(), std::vector<bool>(stream_num, false));
for (size_t i = 0; i < kernels.size(); ++i) {
auto &kernel = kernels[i];
auto curr_stream_id = AnfAlgo::GetStreamId(kernel);
@@ -805,7 +781,6 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
MS_LOG(EXCEPTION) << "Stream " << curr_stream_id << "has not been created";
}
auto wait_stream = stream_id_map_[curr_stream_id];
auto stream_num = stream_id_map_.size();
std::vector<bool> stream_hit(stream_num, false);
std::vector<AnfNodePtr> used_kernels;
std::set<AnfNodePtr> visited_kernels;
@@ -819,8 +794,9 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
continue;
}
for (auto &visited : used_kernels) {
if (visited == pre_cnode && !stream_hit[pre_cnode_stream_id]) {
if (visited == pre_cnode && !stream_hit[pre_cnode_stream_id] && !kernel_hit[IntToSize(k)][curr_stream_id]) {
stream_hit[pre_cnode_stream_id] = true;
kernel_hit[IntToSize(k)][curr_stream_id] = true;
found_depend = true;
auto record_stream = stream_id_map_[pre_cnode_stream_id];
auto event = CreateDeviceEvent();
@@ -1045,11 +1021,10 @@ bool AscendKernelRuntime::InitDevice() {
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
}
const int kCommunicationStreamID = 2;
stream_id_map_[0] = stream_;
stream_id_map_[1] = independent_stream_;
stream_id_map_[kCommunicationStreamID] = communication_stream_;
group_stream_id_map_[kHcclWorldGroup] = kCommunicationStreamID;

stream_id_map_[kDefaultStreamIndex] = stream_;
stream_id_map_[kIndependentStreamIndex] = independent_stream_;
stream_id_map_[kWorldGroupStreamIndex] = communication_stream_;
return true;
}



+ 0
- 1
mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h View File

@@ -119,7 +119,6 @@ class AscendKernelRuntime : public KernelRuntime {
static std::map<std::string, uint32_t> overflow_tasks_;
static std::vector<rtExceptionInfo> task_fail_infoes_;
std::map<uint32_t, void *> stream_id_map_;
std::map<std::string, uint32_t> group_stream_id_map_;
};

MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime);


+ 46
- 1
mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc View File

@@ -205,6 +205,7 @@ StreamActiveKind GetStreamKind(uint32_t cur_stream_id, uint32_t pre_stream_id, u

return kInvalid;
}

void SetNodeStreamIDAttr(const NotNull<KernelGraphPtr> &graph_ptr) {
auto exec_orders = graph_ptr->execution_order();
for (auto node : exec_orders) {
@@ -213,8 +214,51 @@ void SetNodeStreamIDAttr(const NotNull<KernelGraphPtr> &graph_ptr) {
}
} // namespace

void AscendStreamAssign::AssignStreamForNonTaskSink(const NotNull<KernelGraphPtr> &graph_ptr) {
auto &kernels = graph_ptr->execution_order();
if (kernels.empty()) {
return;
}
if (stream_groups_.empty()) {
stream_groups_.emplace_back(std::vector<uint32_t>{kDefaultStreamIndex});
stream_groups_.emplace_back(std::vector<uint32_t>{kIndependentStreamIndex});
stream_groups_.emplace_back(std::vector<uint32_t>{kWorldGroupStreamIndex});
}
group_stream_id_map_[kHcclWorldGroup] = kWorldGroupStreamIndex;
for (size_t i = 0; i < kernels.size(); ++i) {
auto &node = kernels[i];
if (AnfAlgo::IsCommunicationOp(node)) {
auto group = AnfAlgo::GetNodeAttr<std::string>(node, kAttrGroup);
auto iter = group_stream_id_map_.find(group);
if (iter == group_stream_id_map_.end()) {
auto id = SizeToUint(group_stream_id_map_.size()) + kWorldGroupStreamIndex;
group_stream_id_map_[group] = id;
AnfAlgo::SetStreamId(id, node.get());
stream_groups_.emplace_back(std::vector<uint32_t>{id});
} else {
auto id = iter->second;
AnfAlgo::SetStreamId(id, node.get());
}
} else if (AnfAlgo::IsIndependentNode(node)) {
AnfAlgo::SetStreamId(kIndependentStreamIndex, node.get());
} else {
AnfAlgo::SetStreamId(kDefaultStreamIndex, node.get());
}
}
for (size_t i = 1; i < kernels.size(); ++i) {
if (AnfAlgo::GetCNodeName(kernels[i - 1]) == kAtomicAddrCleanOpName) {
auto stream_id = AnfAlgo::GetStreamId(kernels[i]);
AnfAlgo::SetStreamId(stream_id, kernels[i - 1].get());
}
}
}

void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) {
if (!IsTaskSink()) {
AssignStreamForNonTaskSink(graph_ptr);
return;
}
if (!graph_ptr->is_dynamic_shape()) {
MS_LOG(INFO) << "Communication parallel mode: " << parallel::ParallelContext::GetInstance()->communi_parallel_mode()
<< ".";

@@ -1009,6 +1053,7 @@ void AscendStreamAssign::ActiveRootGraphIndependent(const NotNull<KernelGraphPtr
independent_stream_activated_ = true;
graph_ptr->set_execution_order(update_cnode_list);
}

void AscendStreamAssign::InsertStreamActiveForCommon(const NotNull<KernelGraphPtr> &graph_ptr) {
MS_LOG(INFO) << "Start";
GetProcessedStream(graph_ptr);


+ 2
- 0
mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h View File

@@ -68,6 +68,7 @@ class AscendStreamAssign {
void Reset();
CNodePtr CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
CNodePtr CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
void AssignStreamForNonTaskSink(const NotNull<KernelGraphPtr> &graph_ptr);
void CheckResourceAssign(const NotNull<KernelGraphPtr> &graph_ptr);
void CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr);
void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr);
@@ -164,6 +165,7 @@ class AscendStreamAssign {
std::set<uint32_t> processed_streams_{};
std::vector<uint32_t> need_first_active_streams_{};
std::set<CNodeKey> independent_targets_;
std::map<std::string, uint32_t> group_stream_id_map_;

// key:group name, value:key1:graph id, value1:stream id
std::map<std::string, std::map<uint32_t, std::set<uint32_t>>> group_hcom_graph_map_;


+ 3
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -505,6 +505,9 @@ const int kValueNodeTensorMask = 2;
constexpr auto kNCHWShapeSize = 4;

// define special index in special node
constexpr auto kDefaultStreamIndex = 0;
constexpr auto kIndependentStreamIndex = 1;
constexpr auto kWorldGroupStreamIndex = 2;
constexpr auto kAnfPrimitiveIndex = 0;
constexpr auto kFirstDataInputIndex = 1;
constexpr auto kRealInputNodeIndexInTupleGetItem = 1;


Loading…
Cancel
Save