diff --git a/mindspore/ccsrc/backend/optimizer/somas/somas.cc b/mindspore/ccsrc/backend/optimizer/somas/somas.cc index 3dc798b618..3014d51d6c 100644 --- a/mindspore/ccsrc/backend/optimizer/somas/somas.cc +++ b/mindspore/ccsrc/backend/optimizer/somas/somas.cc @@ -86,6 +86,7 @@ bool Somas::InitSomasTensors(const session::KernelGraph *graph) { IndependentNodeOutputProcess(graph); SummaryInputProcess(graph); RefNodeProcess(graph); + NonTaskSplitProcess(graph); UnReuseNodeProcess(graph); GenContiguousList(graph); GetNextOutputProcess(graph); @@ -535,6 +536,27 @@ void Somas::RefNodeProcess(const session::KernelGraph *graph) { MS_LOG(INFO) << "Special Tensor total size: RefNode: input " << total_input_size << " output " << total_output_size; } +void Somas::NonTaskSplitProcess(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto kernel_cnodes = graph->execution_order(); + for (const auto &kernel : kernel_cnodes) { + auto op_name = AnfAlgo::GetCNodeName(kernel); + if (op_name == kSplitOpName && AnfAlgo::HasNodeAttr(kAttrNonTask, kernel)) { + std::vector refnode_input_output; + auto node = nodes_map_[kernel.get()]; + auto input_tensor = node->input_tensors_[0]; + input_tensor->type_ = kRefNodeInput; + refnode_input_output.push_back(input_tensor->GetId()); + + for (auto &output_tensor : node->output_tensors_) { + output_tensor->type_ = kRefNodeOutput; + refnode_input_output.push_back(output_tensor->GetId()); + } + ref_node_constraints_.push_back(refnode_input_output); + } + } +} + void Somas::UnReuseNodeProcess(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); vector full_name_list = {}; diff --git a/mindspore/ccsrc/backend/optimizer/somas/somas.h b/mindspore/ccsrc/backend/optimizer/somas/somas.h index 74a8a4e980..3bae53cd2a 100644 --- a/mindspore/ccsrc/backend/optimizer/somas/somas.h +++ b/mindspore/ccsrc/backend/optimizer/somas/somas.h @@ -114,6 +114,7 @@ class Somas { void IndependentNodeOutputProcess(const session::KernelGraph *graph); void SummaryInputProcess(const session::KernelGraph *graph); void RefNodeProcess(const session::KernelGraph *graph); + void NonTaskSplitProcess(const session::KernelGraph *graph); void UnReuseNodeProcess(const session::KernelGraph *graph); SomasTensorPtr CreateGapTensor(size_t gap_tensor_id); void GenContiguousList(const session::KernelGraph *graph); diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc index 745397a76d..6e33e432a1 100644 --- a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc +++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc @@ -136,7 +136,12 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i MS_EXCEPTION_IF_NULL(kernel_mod); kernel_mod->set_kernel_name(anf_node_ptr->fullname_with_scope()); auto op_name = AnfAlgo::GetCNodeName(anf_node_ptr); - if (AnfAlgo::GetCNodeName(anf_node_ptr) != kAtomicAddrCleanOpName) { + if (op_name == kSplitOpName && AnfAlgo::HasNodeAttr(kAttrNonTask, anf_node_ptr)) { + MS_LOG(INFO) << "Skip task generation for NnTask op " << anf_node_ptr->fullname_with_scope(); + return true; + } + + if (op_name != kAtomicAddrCleanOpName) { for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node_ptr); ++i) { if (op_name == kDynamicRNNOpName && i == 3) { continue; @@ -153,6 +158,21 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i AddressPtr input = std::make_shared
(); input->addr = device_address->ptr_; input->size = device_address->size_; + + auto prenode_with_index = AnfAlgo::GetPrevNodeOutput(anf_node_ptr, i); + if (AnfAlgo::GetCNodeName(prenode_with_index.first) == kSplitOpName && + AnfAlgo::HasNodeAttr(kAttrNonTask, prenode_with_index.first->cast())) { + // use memory offset to implement NonTask Type Split op + // when op A -> split(NonTask) -> op B, op B's input addr is split's input0's addr + offset + // offset is split's output index * split's output size + auto split_input0_device_address = AnfAlgo::GetPrevNodeOutputAddr(prenode_with_index.first, 0); + input->addr = + static_cast(split_input0_device_address->ptr_) + (prenode_with_index.second * input->size); + MS_LOG(INFO) << "Change " << anf_node_ptr->fullname_with_scope() << "'s input " << i << " address to " + << split_input0_device_address->ptr_ << " + " + << "prenode_with_index.second * input->size"; + } + kernel_inputs.push_back(input); }