|
|
|
@@ -732,53 +732,27 @@ bool AscendKernelRuntime::Run(const session::KernelGraph &graph, bool is_task_si |
|
|
|
|
|
|
|
void AscendKernelRuntime::SetKernelModStream(const std::vector<CNodePtr> &kernels, |
|
|
|
std::vector<size_t> *last_stream_nodes) { |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE); |
|
|
|
std::map<void *, size_t> last_kernel; |
|
|
|
for (size_t i = 0; i < kernels.size(); ++i) { |
|
|
|
auto &node = kernels[i]; |
|
|
|
auto kernel_mod = AnfAlgo::GetKernelMod(node); |
|
|
|
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod); |
|
|
|
MS_EXCEPTION_IF_NULL(ascend_kernel_mod); |
|
|
|
if (AnfAlgo::IsCommunicationOp(node)) { |
|
|
|
auto group = AnfAlgo::GetNodeAttr<std::string>(node, kAttrGroup); |
|
|
|
auto iter = group_stream_id_map_.find(group); |
|
|
|
if (iter == group_stream_id_map_.end()) { |
|
|
|
void *stream = nullptr; |
|
|
|
auto ret = rtStreamCreate(&stream, 0); |
|
|
|
if (ret != RT_ERROR_NONE) { |
|
|
|
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret; |
|
|
|
} |
|
|
|
auto id = SizeToUint(stream_id_map_.size()); |
|
|
|
group_stream_id_map_[group] = id; |
|
|
|
stream_id_map_[id] = stream; |
|
|
|
AnfAlgo::SetStreamId(id, node.get()); |
|
|
|
ascend_kernel_mod->SetStream(stream); |
|
|
|
last_kernel[stream] = i; |
|
|
|
} else { |
|
|
|
auto id = iter->second; |
|
|
|
AnfAlgo::SetStreamId(id, node.get()); |
|
|
|
ascend_kernel_mod->SetStream(stream_id_map_[id]); |
|
|
|
last_kernel[stream_id_map_[id]] = i; |
|
|
|
auto stream_id = AnfAlgo::GetStreamId(kernels[i]); |
|
|
|
auto iter = stream_id_map_.find(stream_id); |
|
|
|
if (iter == stream_id_map_.end()) { |
|
|
|
void *stream = nullptr; |
|
|
|
auto ret = rtStreamCreate(&stream, 0); |
|
|
|
if (ret != RT_ERROR_NONE) { |
|
|
|
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret; |
|
|
|
} |
|
|
|
} else if (AnfAlgo::IsIndependentNode(node) && mode != kPynativeMode) { |
|
|
|
AnfAlgo::SetStreamId(1, node.get()); |
|
|
|
ascend_kernel_mod->SetStream(independent_stream_); |
|
|
|
last_kernel[independent_stream_] = i; |
|
|
|
stream_id_map_[stream_id] = stream; |
|
|
|
ascend_kernel_mod->SetStream(stream); |
|
|
|
} else { |
|
|
|
AnfAlgo::SetStreamId(0, node.get()); |
|
|
|
ascend_kernel_mod->SetStream(stream_); |
|
|
|
ascend_kernel_mod->SetStream(iter->second); |
|
|
|
} |
|
|
|
} |
|
|
|
for (size_t i = 1; i < kernels.size(); ++i) { |
|
|
|
if (AnfAlgo::GetCNodeName(kernels[i - 1]) == kAtomicAddrCleanOpName) { |
|
|
|
auto stream_id = AnfAlgo::GetStreamId(kernels[i]); |
|
|
|
AnfAlgo::SetStreamId(stream_id, kernels[i - 1].get()); |
|
|
|
auto kernel_mod = AnfAlgo::GetKernelMod(kernels[i - 1]); |
|
|
|
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod); |
|
|
|
MS_EXCEPTION_IF_NULL(ascend_kernel_mod); |
|
|
|
ascend_kernel_mod->SetStream(stream_id_map_[stream_id]); |
|
|
|
if (stream_id > 0) { |
|
|
|
last_kernel[stream_id_map_[stream_id]] = i; |
|
|
|
} |
|
|
|
} |
|
|
|
(void)std::transform(last_kernel.begin(), last_kernel.end(), std::back_inserter(*last_stream_nodes), |
|
|
|
@@ -798,6 +772,8 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { |
|
|
|
auto &kernel_post_run_events = kernel_events.second; |
|
|
|
kernel_pre_run_events.resize(kernels.size()); |
|
|
|
kernel_post_run_events.resize(kernels.size()); |
|
|
|
auto stream_num = stream_id_map_.size(); |
|
|
|
std::vector<std::vector<bool>> kernel_hit(kernels.size(), std::vector<bool>(stream_num, false)); |
|
|
|
for (size_t i = 0; i < kernels.size(); ++i) { |
|
|
|
auto &kernel = kernels[i]; |
|
|
|
auto curr_stream_id = AnfAlgo::GetStreamId(kernel); |
|
|
|
@@ -805,7 +781,6 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { |
|
|
|
MS_LOG(EXCEPTION) << "Stream " << curr_stream_id << "has not been created"; |
|
|
|
} |
|
|
|
auto wait_stream = stream_id_map_[curr_stream_id]; |
|
|
|
auto stream_num = stream_id_map_.size(); |
|
|
|
std::vector<bool> stream_hit(stream_num, false); |
|
|
|
std::vector<AnfNodePtr> used_kernels; |
|
|
|
std::set<AnfNodePtr> visited_kernels; |
|
|
|
@@ -819,8 +794,9 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
for (auto &visited : used_kernels) { |
|
|
|
if (visited == pre_cnode && !stream_hit[pre_cnode_stream_id]) { |
|
|
|
if (visited == pre_cnode && !stream_hit[pre_cnode_stream_id] && !kernel_hit[IntToSize(k)][curr_stream_id]) { |
|
|
|
stream_hit[pre_cnode_stream_id] = true; |
|
|
|
kernel_hit[IntToSize(k)][curr_stream_id] = true; |
|
|
|
found_depend = true; |
|
|
|
auto record_stream = stream_id_map_[pre_cnode_stream_id]; |
|
|
|
auto event = CreateDeviceEvent(); |
|
|
|
@@ -1045,11 +1021,10 @@ bool AscendKernelRuntime::InitDevice() { |
|
|
|
if (ret != RT_ERROR_NONE) { |
|
|
|
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret; |
|
|
|
} |
|
|
|
const int kCommunicationStreamID = 2; |
|
|
|
stream_id_map_[0] = stream_; |
|
|
|
stream_id_map_[1] = independent_stream_; |
|
|
|
stream_id_map_[kCommunicationStreamID] = communication_stream_; |
|
|
|
group_stream_id_map_[kHcclWorldGroup] = kCommunicationStreamID; |
|
|
|
|
|
|
|
stream_id_map_[kDefaultStreamIndex] = stream_; |
|
|
|
stream_id_map_[kIndependentStreamIndex] = independent_stream_; |
|
|
|
stream_id_map_[kWorldGroupStreamIndex] = communication_stream_; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
|