|
|
@@ -30,6 +30,7 @@ |
|
|
#include "kernel/common_utils.h" |
|
|
#include "kernel/common_utils.h" |
|
|
#include "kernel/oplib/oplib.h" |
|
|
#include "kernel/oplib/oplib.h" |
|
|
#include "ir/value.h" |
|
|
#include "ir/value.h" |
|
|
|
|
|
#include "pre_activate/common/helper.h" |
|
|
using mindspore::kernel::Address; |
|
|
using mindspore::kernel::Address; |
|
|
using mindspore::kernel::AddressPtr; |
|
|
using mindspore::kernel::AddressPtr; |
|
|
|
|
|
|
|
|
@@ -632,7 +633,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, |
|
|
|
|
|
|
|
|
void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const mindspore::AnfNodePtr &kernel, |
|
|
AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces, |
|
|
AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces, |
|
|
AddressPtrList *kernel_outputs) { |
|
|
AddressPtrList *kernel_outputs) { |
|
|
MS_EXCEPTION_IF_NULL(kernel); |
|
|
MS_EXCEPTION_IF_NULL(kernel); |
|
|
@@ -644,9 +645,15 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod |
|
|
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) { |
|
|
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) { |
|
|
return GenAddrCleanLaunchArgs(cnode, kernel_inputs); |
|
|
return GenAddrCleanLaunchArgs(cnode, kernel_inputs); |
|
|
} |
|
|
} |
|
|
|
|
|
auto is_all_nop_node = opt::IsAllNopNode(&graph); |
|
|
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { |
|
|
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { |
|
|
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i); |
|
|
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i); |
|
|
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input); |
|
|
|
|
|
|
|
|
DeviceAddressPtr device_address; |
|
|
|
|
|
if (is_all_nop_node) { |
|
|
|
|
|
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, false); |
|
|
|
|
|
} else { |
|
|
|
|
|
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, true); |
|
|
|
|
|
} |
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
kernel::AddressPtr input = std::make_shared<kernel::Address>(); |
|
|
kernel::AddressPtr input = std::make_shared<kernel::Address>(); |
|
|
MS_EXCEPTION_IF_NULL(input); |
|
|
MS_EXCEPTION_IF_NULL(input); |
|
|
@@ -656,8 +663,16 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod |
|
|
kernel_inputs->emplace_back(input); |
|
|
kernel_inputs->emplace_back(input); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) { |
|
|
|
|
|
auto device_address = AnfAlgo::GetOutputAddr(kernel, i); |
|
|
|
|
|
|
|
|
auto kernel_mod = AnfAlgo::GetKernelMod(kernel); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_mod); |
|
|
|
|
|
for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) { |
|
|
|
|
|
DeviceAddressPtr device_address; |
|
|
|
|
|
if (is_all_nop_node) { |
|
|
|
|
|
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); |
|
|
|
|
|
} else { |
|
|
|
|
|
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, true); |
|
|
|
|
|
} |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
kernel::AddressPtr output = std::make_shared<kernel::Address>(); |
|
|
kernel::AddressPtr output = std::make_shared<kernel::Address>(); |
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
output->addr = device_address->ptr_; |
|
|
output->addr = device_address->ptr_; |
|
|
@@ -666,7 +681,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod |
|
|
kernel_outputs->emplace_back(output); |
|
|
kernel_outputs->emplace_back(output); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) { |
|
|
|
|
|
|
|
|
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { |
|
|
auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); |
|
|
auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); |
|
|
kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); |
|
|
kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); |
|
|
MS_EXCEPTION_IF_NULL(workspace); |
|
|
MS_EXCEPTION_IF_NULL(workspace); |
|
|
@@ -721,7 +736,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { |
|
|
AddressPtrList kernel_inputs; |
|
|
AddressPtrList kernel_inputs; |
|
|
AddressPtrList kernel_workspaces; |
|
|
AddressPtrList kernel_workspaces; |
|
|
AddressPtrList kernel_outputs; |
|
|
AddressPtrList kernel_outputs; |
|
|
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); |
|
|
|
|
|
|
|
|
GenLaunchArgs(graph, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); |
|
|
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); |
|
|
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); |
|
|
if (!ret) { |
|
|
if (!ret) { |
|
|
MS_LOG(ERROR) << "Launch kernel failed."; |
|
|
MS_LOG(ERROR) << "Launch kernel failed."; |
|
|
|