From: @laiyongqiang Reviewed-by: @kisnwang, @jjfeing Signed-off-by: @jjfeing pull/14492/MERGE
| @@ -884,6 +884,14 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_ | |||||
| return kernel_info->OutputAddrExist(output_idx); | return kernel_info->OutputAddrExist(output_idx); | ||||
| } | } | ||||
| bool AnfRuntimeAlgorithm::WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice` | |||||
| auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| return kernel_info->WorkspaceAddrExist(output_idx); | |||||
| } | |||||
| const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, | const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, | ||||
| bool visit_nop_node) { | bool visit_nop_node) { | ||||
| KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); | KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); | ||||
| @@ -153,6 +153,8 @@ class AnfRuntimeAlgorithm { | |||||
| static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); | static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); | ||||
| // check whether the output addr exists | // check whether the output addr exists | ||||
| static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx); | static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx); | ||||
| // check whether the workspace addr exists | |||||
| static bool WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx); | |||||
| // get the address from the prev node; input_idx is the input index of the current node related to the prev node | // get the address from the prev node; input_idx is the input index of the current node related to the prev node | ||||
| static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx, | static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx, | ||||
| bool visit_nop_node = true); | bool visit_nop_node = true); | ||||
| @@ -81,6 +81,13 @@ DeviceAddressPtr KernelInfo::GetMutableWorkspaceAddr(size_t index) const { | |||||
| return workspace_address_list_[index]; | return workspace_address_list_[index]; | ||||
| } | } | ||||
| bool KernelInfo::WorkspaceAddrExist(size_t index) const { | |||||
| if (index >= workspace_address_list_.size()) { | |||||
| return false; | |||||
| } | |||||
| return workspace_address_list_[index] != nullptr; | |||||
| } | |||||
| bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) { | bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) { | ||||
| if (workspace_address_list_.empty()) { | if (workspace_address_list_.empty()) { | ||||
| // parameter and valuenode | // parameter and valuenode | ||||
| @@ -55,6 +55,7 @@ class KernelInfo : public KernelInfoDevice { | |||||
| bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index); | bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index); | ||||
| DeviceAddress *GetWorkspaceAddr(size_t index) const; | DeviceAddress *GetWorkspaceAddr(size_t index) const; | ||||
| DeviceAddressPtr GetMutableWorkspaceAddr(size_t index) const; | DeviceAddressPtr GetMutableWorkspaceAddr(size_t index) const; | ||||
| bool WorkspaceAddrExist(size_t index) const; | |||||
| bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index); | bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index); | ||||
| void set_kernel_mod(const kernel::KernelModPtr &kernel_mod); | void set_kernel_mod(const kernel::KernelModPtr &kernel_mod); | ||||
| kernel::KernelMod *MutableKernelMod() const; | kernel::KernelMod *MutableKernelMod() const; | ||||
| @@ -454,8 +454,8 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode | |||||
| std::vector<size_t> align_size_list; | std::vector<size_t> align_size_list; | ||||
| for (uint64_t mem_size : output_sizes) { | for (uint64_t mem_size : output_sizes) { | ||||
| if (AnfAlgo::OutputAddrExist(node, output_index++)) { | if (AnfAlgo::OutputAddrExist(node, output_index++)) { | ||||
| MS_LOG(INFO) << "communication op addr exist"; | |||||
| continue; | |||||
| MS_LOG(INFO) << "Communication op " << node->fullname_with_scope() << " has output device address"; | |||||
| return; | |||||
| } | } | ||||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) { | if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) { | ||||
| mem_size = mem_manager_->GetCommonAlignSize(mem_size); | mem_size = mem_manager_->GetCommonAlignSize(mem_size); | ||||
| @@ -464,6 +464,10 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode | |||||
| align_size_list.emplace_back(mem_size); | align_size_list.emplace_back(mem_size); | ||||
| } | } | ||||
| if (align_size_list.empty()) { | |||||
| return; | |||||
| } | |||||
| if (type == kReuseDynamicMem) { | if (type == kReuseDynamicMem) { | ||||
| // reuse communication op's all outputs' memory | // reuse communication op's all outputs' memory | ||||
| type = kReuseDynamicCommMem; | type = kReuseDynamicCommMem; | ||||
| @@ -533,6 +537,10 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP | |||||
| for (size_t i = 0; i < input_num; ++i) { | for (size_t i = 0; i < input_num; ++i) { | ||||
| auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); | auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); | ||||
| auto input_node = input_node_with_index.first; | auto input_node = input_node_with_index.first; | ||||
| if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) { | |||||
| MS_LOG(INFO) << "Communication op " << input_node->fullname_with_scope() << " has input device address"; | |||||
| return; | |||||
| } | |||||
| DeviceAddressPtr address = nullptr; | DeviceAddressPtr address = nullptr; | ||||
| if (input_node->isa<CNode>()) { | if (input_node->isa<CNode>()) { | ||||
| address = PreAssignCNodeMemory(input_node, input_node_with_index.second); | address = PreAssignCNodeMemory(input_node, input_node_with_index.second); | ||||
| @@ -811,6 +819,10 @@ void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | MS_EXCEPTION_IF_NULL(kernel_mod); | ||||
| size_t index = 0; | size_t index = 0; | ||||
| for (auto &size : kernel_mod->GetWorkspaceSizeList()) { | for (auto &size : kernel_mod->GetWorkspaceSizeList()) { | ||||
| if (AnfAlgo::WorkspaceAddrExist(node, index)) { | |||||
| MS_LOG(INFO) << "Op " << node->fullname_with_scope() << " has workspace device address"; | |||||
| return; | |||||
| } | |||||
| auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size); | auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size); | ||||
| AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get()); | AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get()); | ||||
| index++; | index++; | ||||