Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
tags/v1.2.0-rc1
| @@ -1 +1 @@ | |||||
| Subproject commit f65be61197ed36dfc9dc10b91b58bf93835fa27b | |||||
| Subproject commit 40e5c42a12c4daa1530e8db9d006d5b3be5b378f | |||||
| @@ -46,7 +46,7 @@ std::string MsOpNameToHcomOpType(const std::string &ms_op_type) { | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) { | |||||
// Stores the creator `fun` in hcclKernelMap_ under `name`; `fun` is moved from.
// NOTE(review): with a standard associative container, emplace leaves an existing entry for `name`
// untouched — confirm the type of hcclKernelMap_ if duplicate registration matters.
| void HcclKernelFactory::Register(const std::string &name, HcclKernelCreater &&fun) { | |||||
| hcclKernelMap_.emplace(name, std::move(fun)); | hcclKernelMap_.emplace(name, std::move(fun)); | ||||
| } | } | ||||
| @@ -99,7 +99,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) { | |||||
| if (op_name_ == kReceive) { | if (op_name_ == kReceive) { | ||||
| auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_); | auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_); | ||||
| if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) { | if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) { | ||||
| MS_LOG(ERROR) << "HcomDataType cann't support Current Ascend Data Type : " << receive_type_; | |||||
| MS_LOG(ERROR) << "HcomDataType cannot support Current Ascend Data Type : " << receive_type_; | |||||
| return false; | return false; | ||||
| } | } | ||||
| hccl_data_type_list_.emplace_back(iter->second); | hccl_data_type_list_.emplace_back(iter->second); | ||||
| @@ -180,9 +180,17 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const { | |||||
| return output_size_list_; | return output_size_list_; | ||||
| } | } | ||||
| const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } | |||||
| const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { | |||||
| if (!workspace_size_list_.empty() || hccl_data_type_list_.empty()) { | |||||
| return workspace_size_list_; | |||||
| } | |||||
| workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_, hccl_data_type_list_[0])); | |||||
| return workspace_size_list_; | |||||
| } | |||||
| std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||||
| std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs, | |||||
| const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, uint32_t stream_id) { | const std::vector<AddressPtr> &outputs, uint32_t stream_id) { | ||||
| std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_); | std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_); | ||||
| if (hccl_type == kReceive) { | if (hccl_type == kReceive) { | ||||
| @@ -221,10 +229,19 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu | |||||
| MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret; | MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret; | ||||
| } | } | ||||
| void *workspace_addr = nullptr; | |||||
| if (task.workspace_size != 0) { | |||||
| if (workspace.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node_->DebugString() << " is empty"; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(workspace.at(0)); | |||||
| workspace_addr = workspace.at(0)->addr; | |||||
| } | |||||
| results.emplace_back(std::make_shared<HcclTaskInfo>( | results.emplace_back(std::make_shared<HcclTaskInfo>( | ||||
| kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, task.workspace_size, | |||||
| task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type, | |||||
| group_, NeedDump())); | |||||
| kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, workspace_addr, | |||||
| task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, | |||||
| op_type_, data_type, group_, NeedDump())); | |||||
| } | } | ||||
| return results; | return results; | ||||
| @@ -68,7 +68,7 @@ class HcclKernelFactory { | |||||
| public: | public: | ||||
| static HcclKernelFactory &Get(); | static HcclKernelFactory &Get(); | ||||
| void Registe(const string &name, HcclKernelCreater &&fun); | |||||
| void Register(const string &name, HcclKernelCreater &&fun); | |||||
| static std::shared_ptr<HcclKernel> Get(const string &name); | static std::shared_ptr<HcclKernel> Get(const string &name); | ||||
| private: | private: | ||||
| @@ -78,7 +78,7 @@ class HcclKernelFactory { | |||||
| class _HcclKernelRegister { | class _HcclKernelRegister { | ||||
| public: | public: | ||||
| _HcclKernelRegister(const string &name, HcclKernelCreater &&fun) { | _HcclKernelRegister(const string &name, HcclKernelCreater &&fun) { | ||||
| HcclKernelFactory::Get().Registe(name, std::move(fun)); | |||||
| HcclKernelFactory::Get().Register(name, std::move(fun)); | |||||
| } | } | ||||
| ~_HcclKernelRegister() = default; | ~_HcclKernelRegister() = default; | ||||
| }; | }; | ||||
| @@ -433,6 +433,7 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) { | |||||
// Assigns memory for a communication node by calling, in this order, the input, output and workspace
// assignment helpers with the same `type` and `node`.
| void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) { | void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) { | ||||
| AssignCommunicationNodeInputMem(type, node); | AssignCommunicationNodeInputMem(type, node); | ||||
| AssignCommunicationNodeOutputMem(type, node); | AssignCommunicationNodeOutputMem(type, node); | ||||
| AssignWorkSpaceMem(type, node); | |||||
| } | } | ||||
| void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) { | void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) { | ||||
| @@ -99,7 +99,7 @@ bool FinalizeHccl() { | |||||
| if (ops_kernel_info_store != nullptr) { | if (ops_kernel_info_store != nullptr) { | ||||
| auto ret = ops_kernel_info_store->Finalize(); | auto ret = ops_kernel_info_store->Finalize(); | ||||
| if (ret != ge::SUCCESS) { | if (ret != ge::SUCCESS) { | ||||
| MS_LOG(ERROR) << "Destory info store failed, ret = " << ret; | |||||
| MS_LOG(ERROR) << "Destroy info store failed, ret = " << ret; | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -107,7 +107,7 @@ bool FinalizeHccl() { | |||||
| if (ops_kernel_builder != nullptr) { | if (ops_kernel_builder != nullptr) { | ||||
| auto ret = ops_kernel_builder->Finalize(); | auto ret = ops_kernel_builder->Finalize(); | ||||
| if (ret != ge::SUCCESS) { | if (ret != ge::SUCCESS) { | ||||
| MS_LOG(ERROR) << "Destory builder failed, ret = " << ret; | |||||
| MS_LOG(ERROR) << "Destroy builder failed, ret = " << ret; | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -151,7 +151,30 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTask | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool CalcOpRunningParam(const AnfNodePtr &node) { return true; } | |||||
| int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) { | |||||
| MS_EXCEPTION_IF_NULL(ops_kernel_builder); | |||||
| MS_LOG(INFO) << "Start calc workspace size for hccl node " << node->DebugString() << " ,dtype is " << datatype; | |||||
| auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype); | |||||
| MS_EXCEPTION_IF_NULL(ge_node); | |||||
| auto op = ge_node->GetOpDesc(); | |||||
| MS_EXCEPTION_IF_NULL(op); | |||||
| MS_LOG(INFO) << "Start to call CalcOpRunningParam"; | |||||
| ge::Status ret = ops_kernel_builder->CalcOpRunningParam(*ge_node); | |||||
| if (ret != ge::SUCCESS) { | |||||
| MS_LOG(ERROR) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret; | |||||
| return false; | |||||
| } | |||||
| auto workspace_sizes = op->GetWorkspaceBytes(); | |||||
| if (workspace_sizes.size() != 1) { | |||||
| MS_LOG(EXCEPTION) << "Unexpected workspace size " << workspace_sizes.size(); | |||||
| } | |||||
| int64_t workspace_size = workspace_sizes[0]; | |||||
| MS_LOG(INFO) << "Node " << node->DebugString() << " workspace size is " << workspace_size; | |||||
| ge_graph.reset(); | |||||
| return workspace_size; | |||||
| } | |||||
| void *GetHcclOpsKernelInfoStore() { return ops_kernel_info_store.get(); } | void *GetHcclOpsKernelInfoStore() { return ops_kernel_info_store.get(); } | ||||
| @@ -23,21 +23,18 @@ | |||||
| #include "mindspore/core/ir/anf.h" | #include "mindspore/core/ir/anf.h" | ||||
| #include "hccl/hccl_types.h" | #include "hccl/hccl_types.h" | ||||
| #define MS_API __attribute__((visibility("default"))) | |||||
| namespace mindspore::hccl { | namespace mindspore::hccl { | ||||
| struct MS_API HcclTaskInfo { | |||||
| struct HcclTaskInfo { | |||||
| std::string private_def; | std::string private_def; | ||||
| int64_t workspace_size; | int64_t workspace_size; | ||||
| int64_t stream_num; | int64_t stream_num; | ||||
| }; | }; | ||||
| MS_API bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file); | |||||
| MS_API bool FinalizeHccl(); | |||||
| MS_API bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists); | |||||
| MS_API bool CalcOpRunningParam(const AnfNodePtr &node); | |||||
| MS_API void *GetHcclOpsKernelInfoStore(); | |||||
| MS_API std::string GetHcclType(const AnfNodePtr &node); | |||||
// NOTE(review): the real InitHccl implementation is not visible in this view — confirm its contract.
| bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file); | |||||
// Calls Finalize on the ops kernel info store and the ops kernel builder when they are set;
// returns false if either call does not return ge::SUCCESS.
| bool FinalizeHccl(); | |||||
// NOTE(review): presumably fills *task_info_lists with the task info generated for `node` — confirm.
| bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists); | |||||
// Returns the workspace size the ops kernel builder reports for `node` with the given data type.
| int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype); | |||||
// Returns the raw, non-owning pointer held by the ops kernel info store handle; may be nullptr.
| void *GetHcclOpsKernelInfoStore(); | |||||
| std::string GetHcclType(const AnfNodePtr &node); | |||||
| } // namespace mindspore::hccl | } // namespace mindspore::hccl | ||||
| #undef MS_API | |||||
| #endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H | #endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H | ||||
| @@ -64,7 +64,7 @@ namespace hccl { | |||||
// No-op definitions: every argument is ignored and `true` is returned unconditionally.
// NOTE(review): presumably stand-ins for builds or tests without the real HCCL adapter — confirm.
| bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; } | bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; } | ||||
| bool FinalizeHccl() { return true; } | bool FinalizeHccl() { return true; } | ||||
| bool GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) { return true; } | bool GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) { return true; } | ||||
| bool CalcOpRunningParam(const AnfNodePtr &) { return true; } | |||||
// No-op definitions: arguments are ignored; they return a zero workspace size, a null info-store
// pointer and an empty HCCL type string respectively.
| int64_t CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) { return 0; } | |||||
| void *GetHcclOpsKernelInfoStore() { return nullptr; } | void *GetHcclOpsKernelInfoStore() { return nullptr; } | ||||
| std::string GetHcclType(const AnfNodePtr &) { return ""; } | std::string GetHcclType(const AnfNodePtr &) { return ""; } | ||||
| } // namespace hccl | } // namespace hccl | ||||