Browse Source

Ascend swap: support communication ops

tags/v1.6.0
tanghuikang 4 years ago
parent
commit
b6035cf1a1
6 changed files with 118 additions and 112 deletions
  1. +4
    -4
      mindspore/ccsrc/backend/session/ascend_session.cc
  2. +1
    -1
      mindspore/ccsrc/backend/session/executor.cc
  3. +1
    -3
      mindspore/ccsrc/backend/session/session_basic.cc
  4. +1
    -1
      mindspore/ccsrc/runtime/device/ascend/distribute/ascend_collective.cc
  5. +102
    -95
      mindspore/ccsrc/runtime/device/kernel_runtime.cc
  6. +9
    -8
      mindspore/ccsrc/runtime/device/kernel_runtime.h

+ 4
- 4
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -294,7 +294,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
MS_EXCEPTION_IF_NULL(kernel_graph);
device::KernelAdjust::GetInstance().LoadDeviceLoopCtrlParameters(kernel_graph);
auto &input_nodes = kernel_graph->input_nodes();
if (device::KernelRuntime::use_mem_scheduler()) {
if (device::KernelRuntime::UseMemScheduler()) {
kernel_graph->SetInputTensors(inputs);
return;
}
@@ -539,7 +539,7 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) {
} else {
// alloc memory, including static memory and dynamic memory
MemoryAlloc(graph.get());
if (!device::KernelRuntime::use_mem_scheduler()) {
if (!device::KernelRuntime::UseMemScheduler()) {
AnfAlgo::CacheAddrForGraph(graph);
}
// generate and load task info to device if it is sink mode
@@ -576,7 +576,7 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
// optimize graph
HardwareOptimize(child_graph);
// assign static memory of parameters
if (!device::KernelRuntime::use_mem_scheduler()) {
if (!device::KernelRuntime::UseMemScheduler()) {
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->AssignStaticMemoryInput(*child_graph);
@@ -1800,7 +1800,7 @@ void AscendSession::ExecuteAllTaskInQueue() {
void AscendSession::UpdateOutputTensors(const VectorRef *outputs,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
std::map<DeviceAddressPtr, DeviceAddressPtr> *) {
if (device::KernelRuntime::use_mem_scheduler()) {
if (device::KernelRuntime::UseMemScheduler()) {
return;
}
MS_EXCEPTION_IF_NULL(outputs);


+ 1
- 1
mindspore/ccsrc/backend/session/executor.cc View File

@@ -130,7 +130,7 @@ void RunGraphTask::Run() {
return;
}
graph->ResetGraphRunningStatus();
if (device::KernelRuntime::use_mem_scheduler()) {
if (device::KernelRuntime::UseMemScheduler()) {
graph->SetOutputNodeToTensor(node_to_tensor_);
}
try {


+ 1
- 3
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -1691,8 +1691,6 @@ void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vecto
MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, node_to_tensor));
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
}

void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
@@ -1700,7 +1698,7 @@ void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
std::map<DeviceAddressPtr, DeviceAddressPtr> *) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (device::KernelRuntime::use_mem_scheduler()) {
if (device::KernelRuntime::UseMemScheduler()) {
return;
}
MS_EXCEPTION_IF_NULL(outputs);


+ 1
- 1
mindspore/ccsrc/runtime/device/ascend/distribute/ascend_collective.cc View File

@@ -49,7 +49,7 @@ bool HcclCollectiveGroup::InitCollective() {
<< "Loading libascend_collective.so failed. Many reasons could cause this:\n1.libascend_collective.so is not "
"installed.\n2.hccl is not "
"installed or found.\n3.mpi is not installed or found, please check if lib files of OpenMPI is added to "
"LD_LIBRATY_PATH.";
"LD_LIBRARY_PATH.";
}
init_mpi_ = DlsymFuncObj(InitMPI, collective_handle_);
finalize_mpi_ = DlsymFuncObj(FinalizeMPI, collective_handle_);


+ 102
- 95
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -98,9 +98,10 @@ bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_
void KernelRuntime::AssignMemory(const session::KernelGraph &graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (use_mem_scheduler()) {
if (UseMemScheduler()) {
AssignStaticMemoryValueNode(graph);
ResetNodeAddress(graph);
AssignCommunicationMem(graph);
} else {
MS_EXCEPTION_IF_NULL(mem_manager_);
mem_manager_->ResetDynamicMemory();
@@ -110,9 +111,9 @@ void KernelRuntime::AssignMemory(const session::KernelGraph &graph) {
UpdateRefNodeOutputMem(graph);
}

void KernelRuntime::RunOpGetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size,
std::vector<DeviceAddressPtr> *address_list,
std::vector<size_t> *align_size_list) const {
void KernelRuntime::GetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size,
DeviceAddressPtrList *address_list,
std::vector<size_t> *align_size_list) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(total_size);
MS_EXCEPTION_IF_NULL(address_list);
@@ -140,24 +141,19 @@ void KernelRuntime::RunOpGetCommunicationInputInfo(const AnfNodePtr &node, size_
}
}

void KernelRuntime::RunOpAssignCommunicationInput(const AnfNodePtr &node) const {
void KernelRuntime::AssignCommunicationInputFromMemoryPool(const AnfNodePtr &node) const {
if (!AnfAlgo::IsCommunicationOp(node)) {
return;
}
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(mem_manager_);

size_t total_size = 0;
std::vector<DeviceAddressPtr> address_list;
DeviceAddressPtrList address_list;
std::vector<size_t> align_size_list;
RunOpGetCommunicationInputInfo(node, &total_size, &address_list, &align_size_list);
if (address_list.empty()) {
return;
}

auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() < kMinInputSize) {
MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope();
GetCommunicationInputInfo(node, &total_size, &address_list, &align_size_list);
if (align_size_list.empty()) {
MS_LOG(WARNING) << "No inputs for " << node->fullname_with_scope();
return;
}

@@ -166,52 +162,53 @@ void KernelRuntime::RunOpAssignCommunicationInput(const AnfNodePtr &node) const
}
}

void KernelRuntime::RunOpGetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size,
std::vector<size_t> *align_size_list,
std::vector<DeviceAddressPtr> *device_address_list) const {
void KernelRuntime::GetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size,
DeviceAddressPtrList *address_list,
std::vector<size_t> *align_size_list) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(total_size);
MS_EXCEPTION_IF_NULL(align_size_list);
MS_EXCEPTION_IF_NULL(device_address_list);
auto runtime_info = node->user_data<session::OpRuntimeInfo>();
auto output_num = AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < output_num; ++i) {
MS_EXCEPTION_IF_NULL(runtime_info);
MS_EXCEPTION_IF_NULL(address_list);

const auto kernel_mod = AnfAlgo::GetKernelMod(node);
MS_EXCEPTION_IF_NULL(kernel_mod);
const auto output_size_list = kernel_mod->GetOutputSizeList();
for (size_t i = 0; i < output_size_list.size(); ++i) {
DeviceAddressPtr address = nullptr;
if (AnfAlgo::OutputAddrExist(node, i)) {
address = AnfAlgo::GetMutableOutputAddr(node, i);
} else {
std::string output_format = runtime_info->output_format(i);
auto output_type = runtime_info->output_type(i);
address =
CreateDeviceAddress(nullptr, runtime_info->output_tensor_size(i), output_format, output_type, {node, i});
const std::string output_format = AnfAlgo::GetOutputFormat(node, i);
const auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
const auto tensor_size = AnfAlgo::GetOutputTensorMemSize(node, i);
address = CreateDeviceAddress(nullptr, tensor_size, output_format, output_type, {node, i});
AnfAlgo::SetOutputAddr(address, i, node.get());
}
MS_EXCEPTION_IF_NULL(address);
auto align_size = MemoryManager::GetCommonAlignSize(address->size());
*total_size += align_size;
align_size_list->emplace_back(align_size);
device_address_list->emplace_back(address);
address_list->emplace_back(address);
}
}

void KernelRuntime::RunOpAssignCommunicationOutput(const AnfNodePtr &node) const {
void KernelRuntime::AssignCommunicationOutputFromMemoryPool(const AnfNodePtr &node) const {
if (!AnfAlgo::IsCommunicationOp(node)) {
return;
}

MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(mem_manager_);

size_t total_size = 0;
std::vector<size_t> align_size_list;
std::vector<DeviceAddressPtr> device_address_list;
RunOpGetCommunicationOutputInfo(node, &total_size, &align_size_list, &device_address_list);

std::vector<DeviceAddressPtr> address_list;
GetCommunicationOutputInfo(node, &total_size, &address_list, &align_size_list);
if (align_size_list.empty()) {
MS_LOG(WARNING) << "No output for " << node->fullname_with_scope();
return;
}

if (!mem_manager_->MallocContinuousMemFromMemPool(device_address_list, total_size, align_size_list)) {
if (!mem_manager_->MallocContinuousMemFromMemPool(address_list, total_size, align_size_list)) {
MS_LOG(EXCEPTION) << "Allocate continuous memory failed, totol_size:" << total_size;
}
}
@@ -318,8 +315,8 @@ void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &inpu
mem_manager_->ResetDynamicMemory();

for (const auto &node : graph.execution_order()) {
RunOpAssignCommunicationOutput(node);
RunOpAssignCommunicationInput(node);
AssignCommunicationOutputFromMemoryPool(node);
AssignCommunicationInputFromMemoryPool(node);
}

RunOpAssignInputMemory(input_tensors, graph);
@@ -688,62 +685,6 @@ void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &n
AssignWorkSpaceMem(type, node);
}

void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
auto &kernels = graph.execution_order();
if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
return;
}
auto kernel_events =
std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>();
auto &kernel_pre_run_events = kernel_events.first;
auto &kernel_post_run_events = kernel_events.second;
kernel_pre_run_events.resize(kernels.size());
kernel_post_run_events.resize(kernels.size());
for (size_t i = 0; i < kernels.size(); ++i) {
auto &kernel = kernels[i];
if (!AnfAlgo::IsCommunicationOp(kernel)) {
continue;
}
auto pre_event = CreateDeviceEvent();
auto post_event = CreateDeviceEvent();
MS_EXCEPTION_IF_NULL(pre_event);
MS_EXCEPTION_IF_NULL(post_event);
pre_event->set_wait_stream(communication_stream_);
pre_event->set_record_stream(stream_);
post_event->set_wait_stream(stream_);
post_event->set_record_stream(communication_stream_);
kernel_pre_run_events[i].emplace_back([pre_event]() {
pre_event->RecordEvent();
pre_event->WaitEvent();
});
kernel_post_run_events[i].emplace_back([post_event]() { post_event->RecordEvent(); });
bool found_nearest_child = false;
for (size_t j = i + 1; j < kernels.size(); ++j) {
auto &child = kernels[j];
MS_EXCEPTION_IF_NULL(child);
if (AnfAlgo::IsCommunicationOp(child)) {
continue;
}
auto input_size = child->inputs().size() - 1;
for (size_t k = 0; k < input_size; ++k) {
auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true);
if (kernel_index.first == kernel) {
found_nearest_child = true;
break;
}
}
if (found_nearest_child) {
kernel_pre_run_events[j].emplace_back([post_event]() { post_event->WaitEvent(); });
break;
}
}
if (!found_nearest_child) {
kernel_post_run_events[i].emplace_back([post_event]() { post_event->WaitEvent(); });
}
}
graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
}

void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(mem_manager_);
@@ -1174,7 +1115,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
}
}

bool KernelRuntime::use_mem_scheduler() {
bool KernelRuntime::UseMemScheduler() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER)) {
@@ -1185,6 +1126,62 @@ bool KernelRuntime::use_mem_scheduler() {
(context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode));
}

// Generate synchronization events between the compute stream (stream_) and the
// communication stream (communication_stream_) for every communication kernel
// in the graph, and cache them per graph id in graph_kernel_events_map_.
// - pre-event: recorded on the compute stream and waited on by the
//   communication stream, so the communication op starts only after its
//   producers on the compute stream have finished.
// - post-event: recorded on the communication stream and waited on by the
//   nearest later non-communication kernel that consumes this op's output, or
//   by the communication kernel's own post-run events if no consumer is found.
void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
  auto &kernels = graph.execution_order();
  // Nothing to do for an empty graph; events for a graph are generated once.
  if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
    return;
  }
  // first: per-kernel pre-run callbacks, second: per-kernel post-run callbacks,
  // both indexed by position in execution_order().
  auto kernel_events =
    std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>();
  auto &kernel_pre_run_events = kernel_events.first;
  auto &kernel_post_run_events = kernel_events.second;
  kernel_pre_run_events.resize(kernels.size());
  kernel_post_run_events.resize(kernels.size());
  for (size_t i = 0; i < kernels.size(); ++i) {
    auto &kernel = kernels[i];
    if (!AnfAlgo::IsCommunicationOp(kernel)) {
      continue;
    }
    auto pre_event = CreateDeviceEvent();
    auto post_event = CreateDeviceEvent();
    MS_EXCEPTION_IF_NULL(pre_event);
    MS_EXCEPTION_IF_NULL(post_event);
    pre_event->set_wait_stream(communication_stream_);
    pre_event->set_record_stream(stream_);
    post_event->set_wait_stream(stream_);
    post_event->set_record_stream(communication_stream_);
    // Before the communication kernel launches: compute stream records the
    // pre-event, communication stream waits on it.
    kernel_pre_run_events[i].emplace_back([pre_event]() {
      pre_event->RecordEvent();
      pre_event->WaitEvent();
    });
    // After the communication kernel launches: mark completion on the
    // communication stream.
    kernel_post_run_events[i].emplace_back([post_event]() { post_event->RecordEvent(); });
    bool found_nearest_child = false;
    // Search forward for the first non-communication kernel whose inputs
    // include this communication kernel's output; that consumer must wait for
    // the post-event before it runs.
    for (size_t j = i + 1; j < kernels.size(); ++j) {
      auto &child = kernels[j];
      MS_EXCEPTION_IF_NULL(child);
      if (AnfAlgo::IsCommunicationOp(child)) {
        continue;
      }
      // NOTE(review): size() - 1 presumably skips inputs()[0], the primitive
      // of the CNode (MindSpore convention) — confirm against AnfAlgo usage.
      auto input_size = child->inputs().size() - 1;
      for (size_t k = 0; k < input_size; ++k) {
        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true);
        if (kernel_index.first == kernel) {
          found_nearest_child = true;
          break;
        }
      }
      if (found_nearest_child) {
        kernel_pre_run_events[j].emplace_back([post_event]() { post_event->WaitEvent(); });
        break;
      }
    }
    // No downstream consumer found: synchronize right after the kernel itself.
    if (!found_nearest_child) {
      kernel_post_run_events[i].emplace_back([post_event]() { post_event->WaitEvent(); });
    }
  }
  graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
}

void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs,
const std::shared_ptr<MemScheduler> &mem_scheduler) {
MS_EXCEPTION_IF_NULL(cnode);
@@ -1416,6 +1413,16 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
}
}

// Assign memory-pool buffers for the inputs and outputs of every communication
// op in the graph's execution order. Called from AssignMemory on the
// UseMemScheduler() path, where communication buffers are allocated up front
// instead of through the static/dynamic assignment passes.
void KernelRuntime::AssignCommunicationMem(const session::KernelGraph &graph) {
  for (const auto &kernel : graph.execution_order()) {
    // Only communication ops need the continuous pool buffers.
    if (!AnfAlgo::IsCommunicationOp(kernel)) {
      continue;
    }
    AssignCommunicationInputFromMemoryPool(kernel);
    AssignCommunicationOutputFromMemoryPool(kernel);
  }
}

bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNodePtr &kernel,
const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock) {
MS_EXCEPTION_IF_NULL(kernel);
@@ -1465,7 +1472,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::shared_ptr<MemScheduler> mem_scheduler = nullptr;
if (use_mem_scheduler()) {
if (UseMemScheduler()) {
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
MS_EXCEPTION_IF_NULL(mem_scheduler);
mem_scheduler->SetMemHandler(mem_manager_);
@@ -1533,7 +1540,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!use_mem_scheduler()) {
if (!UseMemScheduler()) {
return;
}
auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());


+ 9
- 8
mindspore/ccsrc/runtime/device/kernel_runtime.h View File

@@ -57,8 +57,8 @@ class KernelRuntime {
virtual void AssignMemory(const session::KernelGraph &graph);
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node = {});
void RunOpAssignCommunicationOutput(const AnfNodePtr &node) const;
void RunOpAssignCommunicationInput(const AnfNodePtr &node) const;
void AssignCommunicationOutputFromMemoryPool(const AnfNodePtr &node) const;
void AssignCommunicationInputFromMemoryPool(const AnfNodePtr &node) const;
void RunOpClearMemory(const session::KernelGraph &graph) const;
void RunOpMallocPre(const session::KernelGraph &graph, const std::vector<tensor::TensorPtr> &input_tensors);
#ifdef ENABLE_DEBUGGER
@@ -94,7 +94,7 @@ class KernelRuntime {
virtual void ReleaseDeviceRes() {}
void set_device_id(uint32_t device_id) { device_id_ = device_id; }
uint32_t device_id() { return device_id_; }
static bool use_mem_scheduler();
static bool UseMemScheduler();

#ifdef ENABLE_DEBUGGER
// set debugger
@@ -152,6 +152,8 @@ class KernelRuntime {
void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph);
void SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph,
const AnfNodePtr &kernel, bool mock);

void AssignCommunicationMem(const session::KernelGraph &graph);
void AssignStaticMemoryOutput(const session::KernelGraph &graph);
bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false);
void LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &run_events, size_t index) const;
@@ -171,11 +173,10 @@ class KernelRuntime {
void CheckIfSupportPSEmbeddingCache(const session::KernelGraph &graph);
void CheckSparsePSEmbeddingCache(const CNodePtr &node);
#endif
void RunOpGetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size,
std::vector<DeviceAddressPtr> *address_list,
std::vector<size_t> *align_size_list) const;
void RunOpGetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size, std::vector<size_t> *align_size_list,
std::vector<DeviceAddressPtr> *device_address_list) const;
void GetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size, DeviceAddressPtrList *address_list,
std::vector<size_t> *align_size_list) const;
void GetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size, DeviceAddressPtrList *address_list,
std::vector<size_t> *align_size_list) const;

protected:
uint32_t device_id_{0};


Loading…
Cancel
Save