1. Support Broadcast op
2. Support communication op as graph output
3. Optimize communication op memory allocation
4. Support HCCL multi-group

tags/v0.3.0-alpha
@@ -9,11 +9,11 @@ include(${GE_SOURCE_DIR}/cmake/external_libs/eigen.cmake)
 include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
 include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
 include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake)
-# for CPU/GPU mode, find c_sec and slog from local prebuild
+include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake)
+# for CPU/GPU mode, find slog from local prebuild
 if (NOT ENABLE_D)
   set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR})
-  find_library(c_sec libc_sec.so ${GE_PREBUILD_PATH})
   find_library(slog libslog.so ${GE_PREBUILD_PATH})
 elseif (DEFINED ENV{D_LINK_PATH})
   set(GE_LIB_PATH $ENV{D_LINK_PATH})
@@ -28,7 +28,6 @@ elseif (DEFINED ENV{D_LINK_PATH})
     message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
   endif()
   set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
-  find_library(c_sec libc_sec.so ${GE_LIB_PATH})
   find_library(slog libslog.so ${GE_LIB_PATH})
   find_library(mmpa libmmpa.so ${GE_LIB_PATH})
   find_library(runtime libruntime.so ${GE_LIB_PATH})
@@ -153,7 +153,7 @@ if (NOT ENABLE_GE)
     FILES
       ${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so
       ${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libslog.so
-      ${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libc_sec.so
+      ${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so
     DESTINATION ${INSTALL_LIB_DIR}
     COMPONENT mindspore
   )
@@ -1 +1 @@
-Subproject commit 995b6dadc0fbbe4b80a08196886a53a18bffa60e
+Subproject commit 579dcb75a990b533f9182733a6424f2bd66f0f23
@@ -333,8 +333,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
   bool status = ge::model_runner::ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first,
                                                                            model_iter->second, listener);
   if (!status) {
-    MS_LOG(ERROR) << "load task failed";
-    return false;
+    MS_LOG(EXCEPTION) << "Load Task Failed";
   }
   if (ProfilingManager::GetInstance().IsProfiling()) {
     auto task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_iter->first);
@@ -29,6 +29,7 @@ class GraphDescReporter : public DescReporter {
  public:
   GraphDescReporter(uint32_t device_id, const std::string &file_name, std::vector<CNodePtr> cnode_list)
       : DescReporter(device_id, file_name, std::move(cnode_list)) {}
+  ~GraphDescReporter() override = default;
   void ReportData() override;
 };
 }  // namespace ascend
@@ -60,7 +60,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
     const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0);
     ret = hcom_broadcast(tag_broadcast.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                          static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
-                         static_cast<u32>(task_info->root_id()), nullptr, stream);
+                         static_cast<u32>(task_info->root_id()), task_info->group().c_str(), stream);
     if (ret != HCCL_SUCCESS) {
       MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast<int>(ret);
       return false;
@@ -70,7 +70,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
     const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0);
     ret = hcom_all_gather(tag_all_gather.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                           reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<u64>(task_info->count()),
-                          static_cast<hcclDataType_t>(task_info->data_type()), nullptr, stream);
+                          static_cast<hcclDataType_t>(task_info->data_type()), task_info->group().c_str(), stream);
     if (ret != HCCL_SUCCESS) {
       MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret;
       return false;
@@ -81,7 +81,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
     ret = hcom_all_reduce(tag_all_reduce.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                           reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<u64>(task_info->count()),
                           static_cast<hcclDataType_t>(task_info->data_type()),
-                          static_cast<hcclRedOp_t>(task_info->op_type()), nullptr, stream);
+                          static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
     if (ret != HCCL_SUCCESS) {
       MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret;
       return false;
@@ -93,7 +93,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
     ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                               reinterpret_cast<void *>(task_info->output_data_addr()),
                               static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
-                              static_cast<hcclRedOp_t>(task_info->op_type()), nullptr, stream);
+                              static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
     if (ret != HCCL_SUCCESS) {
       MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret;
       return false;
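Note: the four hunks above replace the hard-coded nullptr group argument of every hcom_* call with the task's own group string, which is what enables per-operator HCCL groups (nullptr meant "use the default world group"). Below is a minimal, self-contained C++ sketch of the idea; FakeAllReduce and TaskInfo are illustrative stand-ins, not the MindSpore or HCCL APIs.

#include <cstdint>
#include <iostream>
#include <string>

enum HcclResult { HCCL_SUCCESS = 0, HCCL_FAIL = 1 };

// Stand-in for an hcom collective: a real call would hand the tag, device
// buffers, datatype and group name to the HCCL runtime.
HcclResult FakeAllReduce(const char *tag, const char *group, uint64_t count) {
  std::cout << "all_reduce tag=" << tag << " group=" << group << " count=" << count << "\n";
  return HCCL_SUCCESS;
}

struct TaskInfo {
  std::string group;   // "hccl_world_group" or a user-defined sub-group
  uint64_t count = 0;  // number of elements
};

bool Distribute(const TaskInfo &task, int task_counter) {
  // Tags follow the kHcom* + counter + "_0" pattern used in HcomDistribute.
  const std::string tag = "HcomAllReduce_" + std::to_string(task_counter) + "_0";
  // Before multi-group support the group argument was nullptr; now every
  // task carries its own group name.
  return FakeAllReduce(tag.c_str(), task.group.c_str(), task.count) == HCCL_SUCCESS;
}

int main() {
  Distribute({"hccl_world_group", 1024}, 0);  // default group
  Distribute({"sub_group_0", 256}, 1);        // sub-communicator
  return 0;
}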
@@ -15,6 +15,7 @@
  */
 #include "device/kernel_runtime.h"
+#include <vector>
 #include <utility>
 #include <numeric>
 #include <functional>
| @@ -130,20 +131,16 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) { | |||||
| mem_manager_->ResetDynamicMemory(); | mem_manager_->ResetDynamicMemory(); | ||||
| AssignStaticMemory(graph); | AssignStaticMemory(graph); | ||||
| AssignDynamicMemory(graph); | AssignDynamicMemory(graph); | ||||
| UpdateRefNodeOutputMem(graph); | UpdateRefNodeOutputMem(graph); | ||||
| } | } | ||||
| void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, | void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, | ||||
| session::KernelGraph *graph) { | session::KernelGraph *graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| // assign memory for input nodes | |||||
| RunOpAssignInputMemory(input_tensors, graph); | RunOpAssignInputMemory(input_tensors, graph); | ||||
| AssignStaticMemoryValueNode(graph); | AssignStaticMemoryValueNode(graph); | ||||
| for (const auto &cnode : graph->execution_order()) { | for (const auto &cnode : graph->execution_order()) { | ||||
| // assign memory for output nodes | |||||
| RunOpAssignOutputMemory(cnode); | RunOpAssignOutputMemory(cnode); | ||||
| // assign memory for workspace | |||||
| RunOpAssignWorkSpaceMemory(cnode); | RunOpAssignWorkSpaceMemory(cnode); | ||||
| } | } | ||||
| UpdateRefNodeOutputMem(graph); | UpdateRefNodeOutputMem(graph); | ||||
@@ -280,12 +277,22 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
 void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
+  std::vector<session::KernelWithIndex> non_communication_op;
+  // Assign communication op memory first.
   for (const auto &node : nodes) {
     auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
     MS_EXCEPTION_IF_NULL(item_with_index.first);
     if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
       continue;
     }
+    if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
+      AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
+    } else {
+      non_communication_op.emplace_back(item_with_index);
+    }
+  }
+  for (const auto &item_with_index : non_communication_op) {
     AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
   }
 }
@@ -322,6 +329,11 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
   }
 }

+void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
+  AssignCommunicationNodeInputMem(node);
+  AssignCommunicationNodeOutputMem(flag, node);
+}
+
 void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
@@ -335,8 +347,13 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   size_t total_size = 0;
+  size_t output_index = 0;
   std::vector<size_t> align_size_list;
   for (uint64_t mem_size : output_sizes) {
+    if (AnfAlgo::OutputAddrExist(node, output_index++)) {
+      MS_LOG(INFO) << "communication op addr exist";
+      continue;
+    }
     if (context_ptr->enable_hccl()) {
       mem_size = mem_manager_->GetCommonAlignSize(mem_size);
     }
@@ -353,7 +370,21 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
   }
 }

-void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
+DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
+  auto output_sizes = kernel_mod->GetOutputSizeList();
+  if (output_sizes.size() <= index) {
+    MS_LOG(EXCEPTION) << "Previous node output size < node index";
+  }
+  std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
+  auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
+  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
+  AnfAlgo::SetOutputAddr(address, index, anf_node.get());
+  return address;
+}
+
+void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   MS_EXCEPTION_IF_NULL(node);
@@ -361,12 +392,16 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
   size_t total_size = 0;
   std::vector<std::pair<mindspore::device::DeviceAddress *, size_t>> addr_size;
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
-    auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i);
-    MS_EXCEPTION_IF_NULL(address);
-    auto mem_size = address->size();
-    if (context_ptr->enable_hccl()) {
-      mem_size = mem_manager_->GetCommonAlignSize(mem_size);
+    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
+    auto input_node = input_node_with_index.first;
+    DeviceAddressPtr address = nullptr;
+    if (input_node->isa<CNode>()) {
+      address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
+    } else {
+      MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
     }
+    MS_EXCEPTION_IF_NULL(address);
+    auto mem_size = mem_manager_->GetCommonAlignSize(address->size());
     total_size += mem_size;
     addr_size.emplace_back(address.get(), mem_size);
   }
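Note: the rewritten input path first gives every producer output a placeholder DeviceAddress (PreAssignCNodeMemory creates it with a nullptr pointer) and accumulates aligned sizes, so that all inputs of one communication op can later be carved out of a single contiguous block (hence the total_size and addr/size pairs collected above). A simplified, self-contained model of that layout step follows; DeviceAddress, AlignUp and the 512-byte alignment are toy stand-ins, not the MindSpore classes or their real alignment rule.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

// Toy stand-in for mindspore::device::DeviceAddress.
struct DeviceAddress {
  void *ptr = nullptr;  // unset until the contiguous block is carved up
  size_t size = 0;
};
using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;

constexpr size_t kAlign = 512;  // assumed alignment, analogue of GetCommonAlignSize
size_t AlignUp(size_t n) { return (n + kAlign - 1) / kAlign * kAlign; }

int main() {
  // Placeholder addresses created up front, as PreAssignCNodeMemory does.
  std::vector<DeviceAddressPtr> inputs = {
      std::make_shared<DeviceAddress>(DeviceAddress{nullptr, 1000}),
      std::make_shared<DeviceAddress>(DeviceAddress{nullptr, 300}),
  };
  size_t total_size = 0;
  std::vector<std::pair<DeviceAddress *, size_t>> addr_size;
  for (auto &addr : inputs) {
    size_t mem_size = AlignUp(addr->size);
    total_size += mem_size;
    addr_size.emplace_back(addr.get(), mem_size);
  }
  // One allocation for all inputs, sliced by offset, so the collective sees
  // a single contiguous region.
  std::vector<uint8_t> block(total_size);
  uint8_t *base = block.data();
  size_t offset = 0;
  for (auto &[addr, mem_size] : addr_size) {
    addr->ptr = base + offset;
    offset += mem_size;
  }
  std::cout << "total " << total_size << " bytes across " << addr_size.size() << " inputs\n";
  return 0;
}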
@@ -381,11 +416,6 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
 void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
-  if (AnfAlgo::IsCommunicationOp(node)) {
-    UpdateCommunicationOpInputMem(node);
-    AssignCommunicationNodeOutputMem(flag, node);
-    return;
-  }
   if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
     MS_LOG(INFO) << "GetNext disable mem_reuse";
     flag = kDynamicMem;
@@ -506,10 +536,22 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
     mem_manager_->MallocReusedDynamicMem(graph);
     mem_flag = kReuseDynamicMem;
   }
-  auto &kernels = graph->execution_order();
-  for (auto &kernel : kernels) {
-    AssignNodeOutputMem(mem_flag, kernel, kGetAllOuts);
-    AssignWorkSpaceMem(mem_flag, kernel);
+  auto &execution_nodes = graph->execution_order();
+  std::vector<CNodePtr> compute_nodes;
+  // Assign communication nodes first.
+  for (auto &node : execution_nodes) {
+    if (AnfAlgo::IsCommunicationOp(node)) {
+      // Skipped internally if the memory is already allocated.
+      AssignCommunicationNodeMem(mem_flag, node);
+    } else {
+      compute_nodes.emplace_back(node);
+    }
+  }
+  // Then assign the compute nodes.
+  for (auto &node : compute_nodes) {
+    AssignNodeOutputMem(mem_flag, node, kGetAllOuts);
+    AssignWorkSpaceMem(mem_flag, node);
   }
 }
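Note: dynamic memory assignment is now two-phase — one pass assigns communication nodes (whose inputs and outputs need the contiguous placement above), a second pass assigns the remaining compute nodes. A minimal sketch of the partition; Node, IsCommunicationOp and the Assign* functions here are mocks, not the MindSpore APIs.

#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  bool is_comm;
};

bool IsCommunicationOp(const Node &n) { return n.is_comm; }
void AssignCommunicationNodeMem(const Node &n) { std::cout << "comm  mem: " << n.name << "\n"; }
void AssignNodeOutputMem(const Node &n) { std::cout << "other mem: " << n.name << "\n"; }

int main() {
  std::vector<Node> execution_order = {{"Conv2D", false}, {"AllReduce", true}, {"ReLU", false}};
  std::vector<Node> compute_nodes;
  // Phase 1: communication nodes first, so their fused buffers stay contiguous.
  for (const auto &node : execution_order) {
    if (IsCommunicationOp(node)) {
      AssignCommunicationNodeMem(node);
    } else {
      compute_nodes.push_back(node);
    }
  }
  // Phase 2: everything else.
  for (const auto &node : compute_nodes) {
    AssignNodeOutputMem(node);
  }
  return 0;
}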
@@ -73,9 +73,12 @@ class KernelRuntime {
   void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index);
   void AssignWorkSpaceMem(int flag, const AnfNodePtr &node);
   void AssignReuseWorkSpaceMem(const AnfNodePtr &node);
-  void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
   void UpdateRefNodeOutputMem(const session::KernelGraph *graph);
-  void UpdateCommunicationOpInputMem(const AnfNodePtr &node);
+  void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
+  void AssignCommunicationNodeInputMem(const AnfNodePtr &node);
+  void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node);
 #ifdef ENABLE_DUMP_E2E
   bool SetDumpConf();
 #endif
@@ -91,6 +94,7 @@ class KernelRuntime {
   void RunOpAssignOutputMemory(const AnfNodePtr &kernel);
   void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);
   void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
+  DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);

  protected:
   uint32_t device_id_{0};
@@ -90,6 +90,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
       return false;
     }
   }
+  HcomUtil::GetHcomGroup(NOT_NULL(anf_node), NOT_NULL(&group_));
   anf_node_ = anf_node;
   return true;
 }
@@ -147,7 +148,7 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
   HcclTaskInfoPtr task_info_ptr = std::make_shared<HcclTaskInfo>(
     stream_id, hccl_type, input_data_addr, output_data_addr, workspace_address, workspace_num, 0, private_def, nullptr,
-    hccl_count_, root_id_, op_type_, data_type, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel,
+    hccl_count_, root_id_, op_type_, data_type, group_, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel,
     RuntimeUtils::HcomDistribute);
   MS_EXCEPTION_IF_NULL(task_info_ptr);
   return {task_info_ptr};
@@ -54,6 +54,7 @@ class HcclKernel : public AscendKernelMod {
   mutable std::vector<size_t> workspace_size_list_;
   AnfNodePtr anf_node_;
   std::string op_name_;
+  std::string group_;
 };

 using HcclKernelCreater = std::function<std::shared_ptr<HcclKernel>()>;
@@ -176,11 +176,22 @@ bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) {
   auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
   MS_EXCEPTION_IF_NULL(primitive);
   if (primitive->GetAttr("root_rank") != nullptr) {
-    *root_id = GetValue<const vector<uint32_t>>(primitive->GetAttr("root_rank"))[0];
+    *root_id = (uint32_t)GetValue<int>(primitive->GetAttr("root_rank"));
   } else {
     MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_ROOT_INDEX fail, not support!";
     return false;
   }
   return true;
 }
+
+void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) {
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto attr = primitive->GetAttr("group");
+  if (attr != nullptr) {
+    *group = GetValue<std::string>(attr);
+  } else {
+    MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op:" << anf_node->fullname_with_scope() << " failed";
+  }
+}
 }  // namespace mindspore
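Note: GetHcomGroup makes the "group" attribute mandatory for HCCL ops — a missing attribute raises an exception instead of silently falling back to the world group. A toy model of that lookup and its failure mode; the map-based attribute dictionary is illustrative, not the Primitive API.

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Illustrative stand-in for a primitive's attribute dictionary.
using AttrMap = std::map<std::string, std::string>;

// Mirrors GetHcomGroup: a missing "group" is a hard error, matching the
// MS_LOG(EXCEPTION) branch above.
std::string GetHcomGroup(const AttrMap &attrs, const std::string &op_name) {
  auto it = attrs.find("group");
  if (it == attrs.end()) {
    throw std::runtime_error("Get Hcom Group Attr of Op:" + op_name + " failed");
  }
  return it->second;
}

int main() {
  AttrMap all_reduce_attrs = {{"group", "hccl_world_group"}};
  std::cout << GetHcomGroup(all_reduce_attrs, "AllReduce") << "\n";  // hccl_world_group
  try {
    GetHcomGroup({}, "Broadcast");  // no "group" attr -> exception
  } catch (const std::exception &e) {
    std::cout << e.what() << "\n";
  }
  return 0;
}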
@@ -23,6 +23,7 @@
 #include <memory>
 #include "ir/dtype.h"
 #include "hccl/base.h"
+#include "utils/contract.h"

 namespace mindspore {
 using std::map;
@@ -61,6 +62,7 @@ class HcomUtil {
                        const vector<vector<size_t>> &shape_list, uint64_t *total_count);
   static bool GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type);
   static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id);
+  static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group);
 };
 }  // namespace mindspore
@@ -66,8 +66,7 @@ const AnfNodePtr AddMemcpyAsync::Process(const FuncGraphPtr &func_graph, const A
     return nullptr;
   }
   auto cnode = node->cast<CNodePtr>();
-  auto op_name = AnfAlgo::GetCNodeName(cnode);
-  if (op_name != kAllReduceOpName && op_name != kAllGatherOpName && op_name != kReduceScatterOpName) {
+  if (!AnfAlgo::IsCommunicationOp(node)) {
     return nullptr;
   }
   return AddMemcpyAsyncIfInputIsUsedByOthers(func_graph, cnode);
@@ -173,6 +173,19 @@ const BaseRef DealRefTransAndCast::DefinePattern() const {
   return VectorRef({V, Xs});
 }

+void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+  if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
+    auto input_size = AnfAlgo::GetInputTensorNum(cnode);
+    for (size_t i = 0; i < input_size; ++i) {
+      auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
+      auto input_node = input_node_with_index.first;
+      MS_EXCEPTION_IF_NULL(input_node);
+      MS_LOG(INFO) << "origin node:" << input_node->fullname_with_scope();
+      AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index);
+    }
+  }
+}
+
 const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                               const EquivPtr &) const {
   if (node == nullptr || !node->isa<CNode>()) {
@@ -184,6 +197,9 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A
   if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
     return nullptr;
   }
+  DealBroadCastAsRef(graph, cnode);
+
   auto op_name = AnfAlgo::GetCNodeName(cnode);
   auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
   if (op_info == nullptr || !op_info->is_ref()) {
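Note: DealBroadCastAsRef registers each Broadcast input/output pair as a ref pair, reflecting that hcom_broadcast works in place — output i must alias the device memory of input i. A small sketch of what such a ref map looks like to downstream memory passes; the pair-keyed map and node names are toy stand-ins, not the kernel-graph API.

#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <utility>

// (node name, output index) -> (producer name, producer output index)
using KernelWithIndex = std::pair<std::string, size_t>;

int main() {
  std::map<KernelWithIndex, KernelWithIndex> ref_map;
  // Broadcast has one in-place output per input; register a pair per index,
  // as DealBroadCastAsRef does for each GetPrevNodeOutput.
  const size_t input_size = 2;
  for (size_t i = 0; i < input_size; ++i) {
    ref_map[{"Broadcast", i}] = {"Param" + std::to_string(i), 0};
  }
  // A memory pass can then reuse the producer's address instead of
  // allocating a fresh output buffer.
  for (const auto &[out, in] : ref_map) {
    std::cout << out.first << "#" << out.second << " aliases "
              << in.first << "#" << in.second << "\n";
  }
  return 0;
}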