Browse Source

!14492 fix reallocate memory bug for communication op

From: @laiyongqiang
Reviewed-by: @kisnwang,@jjfeing
Signed-off-by: @jjfeing
pull/14492/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
a1c18ecce1
5 changed files with 32 additions and 2 deletions
  1. +8
    -0
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  2. +2
    -0
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
  3. +7
    -0
      mindspore/ccsrc/runtime/device/kernel_info.cc
  4. +1
    -0
      mindspore/ccsrc/runtime/device/kernel_info.h
  5. +14
    -2
      mindspore/ccsrc/runtime/device/kernel_runtime.cc

+ 8
- 0
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -884,6 +884,14 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_
return kernel_info->OutputAddrExist(output_idx);
}

bool AnfRuntimeAlgorithm::WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
// Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->WorkspaceAddrExist(output_idx);
}

const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
bool visit_nop_node) {
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);


+ 2
- 0
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h View File

@@ -153,6 +153,8 @@ class AnfRuntimeAlgorithm {
static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
// check whether the output address exists
static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx);
// check whether the workspace address exists
static bool WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx);
// get address from prev node; input_idx is the input index of the current node related to the prev node
static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx,
bool visit_nop_node = true);


+ 7
- 0
mindspore/ccsrc/runtime/device/kernel_info.cc View File

@@ -81,6 +81,13 @@ DeviceAddressPtr KernelInfo::GetMutableWorkspaceAddr(size_t index) const {
return workspace_address_list_[index];
}

// Returns true only when `index` lies inside workspace_address_list_ and the
// entry stored at that position is non-null; an out-of-range index yields false.
bool KernelInfo::WorkspaceAddrExist(size_t index) const {
  const bool in_range = index < workspace_address_list_.size();
  // Short-circuit keeps the element access from running on an out-of-range index.
  return in_range && workspace_address_list_[index] != nullptr;
}

bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) {
if (workspace_address_list_.empty()) {
// parameter and valuenode


+ 1
- 0
mindspore/ccsrc/runtime/device/kernel_info.h View File

@@ -55,6 +55,7 @@ class KernelInfo : public KernelInfoDevice {
bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index);
DeviceAddress *GetWorkspaceAddr(size_t index) const;
DeviceAddressPtr GetMutableWorkspaceAddr(size_t index) const;
bool WorkspaceAddrExist(size_t index) const;
bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index);
void set_kernel_mod(const kernel::KernelModPtr &kernel_mod);
kernel::KernelMod *MutableKernelMod() const;


+ 14
- 2
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -454,8 +454,8 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
std::vector<size_t> align_size_list;
for (uint64_t mem_size : output_sizes) {
if (AnfAlgo::OutputAddrExist(node, output_index++)) {
MS_LOG(INFO) << "communication op addr exist";
continue;
MS_LOG(INFO) << "Communication op " << node->fullname_with_scope() << " has output device address";
return;
}
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
mem_size = mem_manager_->GetCommonAlignSize(mem_size);
@@ -464,6 +464,10 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
align_size_list.emplace_back(mem_size);
}

if (align_size_list.empty()) {
return;
}

if (type == kReuseDynamicMem) {
// reuse communication op's all outputs' memory
type = kReuseDynamicCommMem;
@@ -533,6 +537,10 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP
for (size_t i = 0; i < input_num; ++i) {
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
auto input_node = input_node_with_index.first;
if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) {
MS_LOG(INFO) << "Communication op " << input_node->fullname_with_scope() << " has input device address";
return;
}
DeviceAddressPtr address = nullptr;
if (input_node->isa<CNode>()) {
address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
@@ -811,6 +819,10 @@ void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(kernel_mod);
size_t index = 0;
for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
if (AnfAlgo::WorkspaceAddrExist(node, index)) {
MS_LOG(INFO) << "Op " << node->fullname_with_scope() << " has workspace device address";
return;
}
auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
index++;


Loading…
Cancel
Save