|
|
|
@@ -136,7 +136,12 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_mod); |
|
|
|
kernel_mod->set_kernel_name(anf_node_ptr->fullname_with_scope()); |
|
|
|
auto op_name = AnfAlgo::GetCNodeName(anf_node_ptr); |
|
|
|
if (AnfAlgo::GetCNodeName(anf_node_ptr) != kAtomicAddrCleanOpName) { |
|
|
|
if (op_name == kSplitOpName && AnfAlgo::HasNodeAttr(kAttrNonTask, anf_node_ptr)) { |
|
|
|
MS_LOG(INFO) << "Skip task generation for NnTask op " << anf_node_ptr->fullname_with_scope(); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
if (op_name != kAtomicAddrCleanOpName) { |
|
|
|
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node_ptr); ++i) { |
|
|
|
if (op_name == kDynamicRNNOpName && i == 3) { |
|
|
|
continue; |
|
|
|
@@ -153,6 +158,21 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i |
|
|
|
AddressPtr input = std::make_shared<Address>(); |
|
|
|
input->addr = device_address->ptr_; |
|
|
|
input->size = device_address->size_; |
|
|
|
|
|
|
|
auto prenode_with_index = AnfAlgo::GetPrevNodeOutput(anf_node_ptr, i); |
|
|
|
if (AnfAlgo::GetCNodeName(prenode_with_index.first) == kSplitOpName && |
|
|
|
AnfAlgo::HasNodeAttr(kAttrNonTask, prenode_with_index.first->cast<CNodePtr>())) { |
|
|
|
// use memory offset to implement NonTask Type Split op |
|
|
|
// when op A -> split(NonTask) -> op B, op B's input addr is split's input0's addr + offset |
|
|
|
// offset is split's output index * split's output size |
|
|
|
auto split_input0_device_address = AnfAlgo::GetPrevNodeOutputAddr(prenode_with_index.first, 0); |
|
|
|
input->addr = |
|
|
|
static_cast<uint8_t *>(split_input0_device_address->ptr_) + (prenode_with_index.second * input->size); |
|
|
|
MS_LOG(INFO) << "Change " << anf_node_ptr->fullname_with_scope() << "'s input " << i << " address to " |
|
|
|
<< split_input0_device_address->ptr_ << " + " |
|
|
|
<< "prenode_with_index.second * input->size"; |
|
|
|
} |
|
|
|
|
|
|
|
kernel_inputs.push_back(input); |
|
|
|
} |
|
|
|
|
|
|
|
|