|
|
|
@@ -98,9 +98,10 @@ bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_ |
|
|
|
void KernelRuntime::AssignMemory(const session::KernelGraph &graph) { |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
if (use_mem_scheduler()) { |
|
|
|
if (UseMemScheduler()) { |
|
|
|
AssignStaticMemoryValueNode(graph); |
|
|
|
ResetNodeAddress(graph); |
|
|
|
AssignCommunicationMem(graph); |
|
|
|
} else { |
|
|
|
MS_EXCEPTION_IF_NULL(mem_manager_); |
|
|
|
mem_manager_->ResetDynamicMemory(); |
|
|
|
@@ -110,9 +111,9 @@ void KernelRuntime::AssignMemory(const session::KernelGraph &graph) { |
|
|
|
UpdateRefNodeOutputMem(graph); |
|
|
|
} |
|
|
|
|
|
|
|
void KernelRuntime::RunOpGetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size, |
|
|
|
std::vector<DeviceAddressPtr> *address_list, |
|
|
|
std::vector<size_t> *align_size_list) const { |
|
|
|
void KernelRuntime::GetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size, |
|
|
|
DeviceAddressPtrList *address_list, |
|
|
|
std::vector<size_t> *align_size_list) const { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(total_size); |
|
|
|
MS_EXCEPTION_IF_NULL(address_list); |
|
|
|
@@ -140,24 +141,19 @@ void KernelRuntime::RunOpGetCommunicationInputInfo(const AnfNodePtr &node, size_ |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void KernelRuntime::RunOpAssignCommunicationInput(const AnfNodePtr &node) const { |
|
|
|
void KernelRuntime::AssignCommunicationInputFromMemoryPool(const AnfNodePtr &node) const { |
|
|
|
if (!AnfAlgo::IsCommunicationOp(node)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(mem_manager_); |
|
|
|
|
|
|
|
size_t total_size = 0; |
|
|
|
std::vector<DeviceAddressPtr> address_list; |
|
|
|
DeviceAddressPtrList address_list; |
|
|
|
std::vector<size_t> align_size_list; |
|
|
|
RunOpGetCommunicationInputInfo(node, &total_size, &address_list, &align_size_list); |
|
|
|
if (address_list.empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (cnode->inputs().size() < kMinInputSize) { |
|
|
|
MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope(); |
|
|
|
GetCommunicationInputInfo(node, &total_size, &address_list, &align_size_list); |
|
|
|
if (align_size_list.empty()) { |
|
|
|
MS_LOG(WARNING) << "No inputs for " << node->fullname_with_scope(); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -166,52 +162,53 @@ void KernelRuntime::RunOpAssignCommunicationInput(const AnfNodePtr &node) const |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void KernelRuntime::RunOpGetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size, |
|
|
|
std::vector<size_t> *align_size_list, |
|
|
|
std::vector<DeviceAddressPtr> *device_address_list) const { |
|
|
|
void KernelRuntime::GetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size, |
|
|
|
DeviceAddressPtrList *address_list, |
|
|
|
std::vector<size_t> *align_size_list) const { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(total_size); |
|
|
|
MS_EXCEPTION_IF_NULL(align_size_list); |
|
|
|
MS_EXCEPTION_IF_NULL(device_address_list); |
|
|
|
auto runtime_info = node->user_data<session::OpRuntimeInfo>(); |
|
|
|
auto output_num = AnfAlgo::GetOutputTensorNum(node); |
|
|
|
for (size_t i = 0; i < output_num; ++i) { |
|
|
|
MS_EXCEPTION_IF_NULL(runtime_info); |
|
|
|
MS_EXCEPTION_IF_NULL(address_list); |
|
|
|
|
|
|
|
const auto kernel_mod = AnfAlgo::GetKernelMod(node); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_mod); |
|
|
|
const auto output_size_list = kernel_mod->GetOutputSizeList(); |
|
|
|
for (size_t i = 0; i < output_size_list.size(); ++i) { |
|
|
|
DeviceAddressPtr address = nullptr; |
|
|
|
if (AnfAlgo::OutputAddrExist(node, i)) { |
|
|
|
address = AnfAlgo::GetMutableOutputAddr(node, i); |
|
|
|
} else { |
|
|
|
std::string output_format = runtime_info->output_format(i); |
|
|
|
auto output_type = runtime_info->output_type(i); |
|
|
|
address = |
|
|
|
CreateDeviceAddress(nullptr, runtime_info->output_tensor_size(i), output_format, output_type, {node, i}); |
|
|
|
const std::string output_format = AnfAlgo::GetOutputFormat(node, i); |
|
|
|
const auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i); |
|
|
|
const auto tensor_size = AnfAlgo::GetOutputTensorMemSize(node, i); |
|
|
|
address = CreateDeviceAddress(nullptr, tensor_size, output_format, output_type, {node, i}); |
|
|
|
AnfAlgo::SetOutputAddr(address, i, node.get()); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(address); |
|
|
|
auto align_size = MemoryManager::GetCommonAlignSize(address->size()); |
|
|
|
*total_size += align_size; |
|
|
|
align_size_list->emplace_back(align_size); |
|
|
|
device_address_list->emplace_back(address); |
|
|
|
address_list->emplace_back(address); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void KernelRuntime::RunOpAssignCommunicationOutput(const AnfNodePtr &node) const { |
|
|
|
void KernelRuntime::AssignCommunicationOutputFromMemoryPool(const AnfNodePtr &node) const { |
|
|
|
if (!AnfAlgo::IsCommunicationOp(node)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(mem_manager_); |
|
|
|
|
|
|
|
size_t total_size = 0; |
|
|
|
std::vector<size_t> align_size_list; |
|
|
|
std::vector<DeviceAddressPtr> device_address_list; |
|
|
|
RunOpGetCommunicationOutputInfo(node, &total_size, &align_size_list, &device_address_list); |
|
|
|
|
|
|
|
std::vector<DeviceAddressPtr> address_list; |
|
|
|
GetCommunicationOutputInfo(node, &total_size, &address_list, &align_size_list); |
|
|
|
if (align_size_list.empty()) { |
|
|
|
MS_LOG(WARNING) << "No output for " << node->fullname_with_scope(); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
if (!mem_manager_->MallocContinuousMemFromMemPool(device_address_list, total_size, align_size_list)) { |
|
|
|
if (!mem_manager_->MallocContinuousMemFromMemPool(address_list, total_size, align_size_list)) { |
|
|
|
MS_LOG(EXCEPTION) << "Allocate continuous memory failed, totol_size:" << total_size; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -318,8 +315,8 @@ void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &inpu |
|
|
|
mem_manager_->ResetDynamicMemory(); |
|
|
|
|
|
|
|
for (const auto &node : graph.execution_order()) { |
|
|
|
RunOpAssignCommunicationOutput(node); |
|
|
|
RunOpAssignCommunicationInput(node); |
|
|
|
AssignCommunicationOutputFromMemoryPool(node); |
|
|
|
AssignCommunicationInputFromMemoryPool(node); |
|
|
|
} |
|
|
|
|
|
|
|
RunOpAssignInputMemory(input_tensors, graph); |
|
|
|
@@ -688,62 +685,6 @@ void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &n |
|
|
|
AssignWorkSpaceMem(type, node); |
|
|
|
} |
|
|
|
|
|
|
|
// Builds the per-kernel pre-run and post-run callbacks for `graph` and stores them in
// graph_kernel_events_map_ under the graph id. Does nothing when the graph has no kernels or when
// an entry for this graph id already exists.
//
// For every kernel that AnfAlgo::IsCommunicationOp accepts, two device events are created:
//   - pre_event  (record stream: stream_, wait stream: communication_stream_) is recorded and
//     waited on in that kernel's pre-run callbacks;
//   - post_event (record stream: communication_stream_, wait stream: stream_) is recorded in that
//     kernel's post-run callbacks. Its wait is attached to the pre-run callbacks of the first later
//     non-communication kernel that takes this kernel as an input, or, if there is none, to this
//     kernel's own post-run callbacks.
//
// @param graph  Kernel graph whose execution order is scanned.
void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
  using KernelEventList = std::vector<std::vector<std::function<void()>>>;

  auto &kernels = graph.execution_order();
  if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
    return;
  }
  const size_t kernel_num = kernels.size();
  std::pair<KernelEventList, KernelEventList> kernel_events;
  auto &kernel_pre_run_events = kernel_events.first;
  auto &kernel_post_run_events = kernel_events.second;
  kernel_pre_run_events.resize(kernel_num);
  kernel_post_run_events.resize(kernel_num);

  for (size_t i = 0; i < kernel_num; ++i) {
    auto &kernel = kernels[i];
    if (!AnfAlgo::IsCommunicationOp(kernel)) {
      continue;
    }
    auto pre_event = CreateDeviceEvent();
    auto post_event = CreateDeviceEvent();
    MS_EXCEPTION_IF_NULL(pre_event);
    MS_EXCEPTION_IF_NULL(post_event);
    pre_event->set_wait_stream(communication_stream_);
    pre_event->set_record_stream(stream_);
    post_event->set_wait_stream(stream_);
    post_event->set_record_stream(communication_stream_);
    kernel_pre_run_events[i].emplace_back([pre_event]() {
      pre_event->RecordEvent();
      pre_event->WaitEvent();
    });
    kernel_post_run_events[i].emplace_back([post_event]() { post_event->RecordEvent(); });

    // Search the later kernels for the first non-communication consumer of this kernel's output.
    bool found_nearest_child = false;
    for (size_t j = i + 1; j < kernel_num; ++j) {
      auto &child = kernels[j];
      MS_EXCEPTION_IF_NULL(child);
      if (AnfAlgo::IsCommunicationOp(child)) {
        continue;
      }
      // inputs() includes one leading entry that is not a data input, hence the "- 1". The size is
      // unsigned, so an empty list must be handled explicitly to avoid wrapping around to SIZE_MAX.
      const auto &child_inputs = child->inputs();
      const size_t input_size = child_inputs.empty() ? 0 : child_inputs.size() - 1;
      for (size_t k = 0; k < input_size; ++k) {
        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true);
        if (kernel_index.first == kernel) {
          found_nearest_child = true;
          break;
        }
      }
      if (found_nearest_child) {
        kernel_pre_run_events[j].emplace_back([post_event]() { post_event->WaitEvent(); });
        break;
      }
    }
    if (!found_nearest_child) {
      // No later consumer was found: wait for the post event right after this kernel instead.
      kernel_post_run_events[i].emplace_back([post_event]() { post_event->WaitEvent(); });
    }
  }
  graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
}
|
|
|
|
|
|
|
void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(mem_manager_); |
|
|
|
@@ -1174,7 +1115,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool KernelRuntime::use_mem_scheduler() { |
|
|
|
bool KernelRuntime::UseMemScheduler() { |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER)) { |
|
|
|
@@ -1185,6 +1126,62 @@ bool KernelRuntime::use_mem_scheduler() { |
|
|
|
(context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode)); |
|
|
|
} |
|
|
|
|
|
|
|
// Builds the per-kernel pre-run and post-run callbacks for `graph` and stores them in
// graph_kernel_events_map_ under the graph id. Does nothing when the graph has no kernels or when
// an entry for this graph id already exists.
//
// For every kernel that AnfAlgo::IsCommunicationOp accepts, two device events are created:
//   - pre_event  (record stream: stream_, wait stream: communication_stream_) is recorded and
//     waited on in that kernel's pre-run callbacks;
//   - post_event (record stream: communication_stream_, wait stream: stream_) is recorded in that
//     kernel's post-run callbacks. Its wait is attached to the pre-run callbacks of the first later
//     non-communication kernel that takes this kernel as an input, or, if there is none, to this
//     kernel's own post-run callbacks.
//
// @param graph  Kernel graph whose execution order is scanned.
void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
  using KernelEventList = std::vector<std::vector<std::function<void()>>>;

  auto &kernels = graph.execution_order();
  if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
    return;
  }
  const size_t kernel_num = kernels.size();
  std::pair<KernelEventList, KernelEventList> kernel_events;
  auto &kernel_pre_run_events = kernel_events.first;
  auto &kernel_post_run_events = kernel_events.second;
  kernel_pre_run_events.resize(kernel_num);
  kernel_post_run_events.resize(kernel_num);

  for (size_t i = 0; i < kernel_num; ++i) {
    auto &kernel = kernels[i];
    if (!AnfAlgo::IsCommunicationOp(kernel)) {
      continue;
    }
    auto pre_event = CreateDeviceEvent();
    auto post_event = CreateDeviceEvent();
    MS_EXCEPTION_IF_NULL(pre_event);
    MS_EXCEPTION_IF_NULL(post_event);
    pre_event->set_wait_stream(communication_stream_);
    pre_event->set_record_stream(stream_);
    post_event->set_wait_stream(stream_);
    post_event->set_record_stream(communication_stream_);
    kernel_pre_run_events[i].emplace_back([pre_event]() {
      pre_event->RecordEvent();
      pre_event->WaitEvent();
    });
    kernel_post_run_events[i].emplace_back([post_event]() { post_event->RecordEvent(); });

    // Search the later kernels for the first non-communication consumer of this kernel's output.
    bool found_nearest_child = false;
    for (size_t j = i + 1; j < kernel_num; ++j) {
      auto &child = kernels[j];
      MS_EXCEPTION_IF_NULL(child);
      if (AnfAlgo::IsCommunicationOp(child)) {
        continue;
      }
      // inputs() includes one leading entry that is not a data input, hence the "- 1". The size is
      // unsigned, so an empty list must be handled explicitly to avoid wrapping around to SIZE_MAX.
      const auto &child_inputs = child->inputs();
      const size_t input_size = child_inputs.empty() ? 0 : child_inputs.size() - 1;
      for (size_t k = 0; k < input_size; ++k) {
        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true);
        if (kernel_index.first == kernel) {
          found_nearest_child = true;
          break;
        }
      }
      if (found_nearest_child) {
        kernel_pre_run_events[j].emplace_back([post_event]() { post_event->WaitEvent(); });
        break;
      }
    }
    if (!found_nearest_child) {
      // No later consumer was found: wait for the post event right after this kernel instead.
      kernel_post_run_events[i].emplace_back([post_event]() { post_event->WaitEvent(); });
    }
  }
  graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
}
|
|
|
|
|
|
|
void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs, |
|
|
|
const std::shared_ptr<MemScheduler> &mem_scheduler) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
@@ -1416,6 +1413,16 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Walks the graph's execution order and, for each kernel that AnfAlgo::IsCommunicationOp accepts,
// assigns its communication input memory and then its communication output memory from the
// memory pool.
//
// @param graph  Kernel graph whose kernels are visited in execution order.
void KernelRuntime::AssignCommunicationMem(const session::KernelGraph &graph) {
  const auto &ordered_kernels = graph.execution_order();
  for (const auto &cnode : ordered_kernels) {
    if (AnfAlgo::IsCommunicationOp(cnode)) {
      // Inputs are assigned before outputs.
      AssignCommunicationInputFromMemoryPool(cnode);
      AssignCommunicationOutputFromMemoryPool(cnode);
    }
  }
}
|
|
|
|
|
|
|
bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNodePtr &kernel, |
|
|
|
const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel); |
|
|
|
@@ -1465,7 +1472,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
std::shared_ptr<MemScheduler> mem_scheduler = nullptr; |
|
|
|
if (use_mem_scheduler()) { |
|
|
|
if (UseMemScheduler()) { |
|
|
|
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id()); |
|
|
|
MS_EXCEPTION_IF_NULL(mem_scheduler); |
|
|
|
mem_scheduler->SetMemHandler(mem_manager_); |
|
|
|
@@ -1533,7 +1540,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock |
|
|
|
void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) { |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
if (!use_mem_scheduler()) { |
|
|
|
if (!UseMemScheduler()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id()); |
|
|
|
|