Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
tags/v1.2.0-rc1
| @@ -1 +1 @@ | |||
| Subproject commit f65be61197ed36dfc9dc10b91b58bf93835fa27b | |||
| Subproject commit 40e5c42a12c4daa1530e8db9d006d5b3be5b378f | |||
| @@ -46,7 +46,7 @@ std::string MsOpNameToHcomOpType(const std::string &ms_op_type) { | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) { | |||
| void HcclKernelFactory::Register(const std::string &name, HcclKernelCreater &&fun) { | |||
| hcclKernelMap_.emplace(name, std::move(fun)); | |||
| } | |||
| @@ -99,7 +99,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) { | |||
| if (op_name_ == kReceive) { | |||
| auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_); | |||
| if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) { | |||
| MS_LOG(ERROR) << "HcomDataType cann't support Current Ascend Data Type : " << receive_type_; | |||
| MS_LOG(ERROR) << "HcomDataType cannot support Current Ascend Data Type : " << receive_type_; | |||
| return false; | |||
| } | |||
| hccl_data_type_list_.emplace_back(iter->second); | |||
| @@ -180,9 +180,17 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const { | |||
| return output_size_list_; | |||
| } | |||
| const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { | |||
| if (!workspace_size_list_.empty() || hccl_data_type_list_.empty()) { | |||
| return workspace_size_list_; | |||
| } | |||
| workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_, hccl_data_type_list_[0])); | |||
| return workspace_size_list_; | |||
| } | |||
| std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uint32_t stream_id) { | |||
| std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_); | |||
| if (hccl_type == kReceive) { | |||
| @@ -221,10 +229,19 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu | |||
| MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret; | |||
| } | |||
| void *workspace_addr = nullptr; | |||
| if (task.workspace_size != 0) { | |||
| if (workspace.empty()) { | |||
| MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node_->DebugString() << " is empty"; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(workspace.at(0)); | |||
| workspace_addr = workspace.at(0)->addr; | |||
| } | |||
| results.emplace_back(std::make_shared<HcclTaskInfo>( | |||
| kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, task.workspace_size, | |||
| task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type, | |||
| group_, NeedDump())); | |||
| kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, workspace_addr, | |||
| task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, | |||
| op_type_, data_type, group_, NeedDump())); | |||
| } | |||
| return results; | |||
| @@ -68,7 +68,7 @@ class HcclKernelFactory { | |||
| public: | |||
| static HcclKernelFactory &Get(); | |||
| void Registe(const string &name, HcclKernelCreater &&fun); | |||
| void Register(const string &name, HcclKernelCreater &&fun); | |||
| static std::shared_ptr<HcclKernel> Get(const string &name); | |||
| private: | |||
| @@ -78,7 +78,7 @@ class HcclKernelFactory { | |||
| class _HcclKernelRegister { | |||
| public: | |||
| _HcclKernelRegister(const string &name, HcclKernelCreater &&fun) { | |||
| HcclKernelFactory::Get().Registe(name, std::move(fun)); | |||
| HcclKernelFactory::Get().Register(name, std::move(fun)); | |||
| } | |||
| ~_HcclKernelRegister() = default; | |||
| }; | |||
| @@ -433,6 +433,7 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) { | |||
| void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) { | |||
| AssignCommunicationNodeInputMem(type, node); | |||
| AssignCommunicationNodeOutputMem(type, node); | |||
| AssignWorkSpaceMem(type, node); | |||
| } | |||
| void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) { | |||
| @@ -99,7 +99,7 @@ bool FinalizeHccl() { | |||
| if (ops_kernel_info_store != nullptr) { | |||
| auto ret = ops_kernel_info_store->Finalize(); | |||
| if (ret != ge::SUCCESS) { | |||
| MS_LOG(ERROR) << "Destory info store failed, ret = " << ret; | |||
| MS_LOG(ERROR) << "Destroy info store failed, ret = " << ret; | |||
| return false; | |||
| } | |||
| } | |||
| @@ -107,7 +107,7 @@ bool FinalizeHccl() { | |||
| if (ops_kernel_builder != nullptr) { | |||
| auto ret = ops_kernel_builder->Finalize(); | |||
| if (ret != ge::SUCCESS) { | |||
| MS_LOG(ERROR) << "Destory builder failed, ret = " << ret; | |||
| MS_LOG(ERROR) << "Destroy builder failed, ret = " << ret; | |||
| return false; | |||
| } | |||
| } | |||
| @@ -151,7 +151,30 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTask | |||
| return true; | |||
| } | |||
| bool CalcOpRunningParam(const AnfNodePtr &node) { return true; } | |||
| int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) { | |||
| MS_EXCEPTION_IF_NULL(ops_kernel_builder); | |||
| MS_LOG(INFO) << "Start calc workspace size for hccl node " << node->DebugString() << " ,dtype is " << datatype; | |||
| auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype); | |||
| MS_EXCEPTION_IF_NULL(ge_node); | |||
| auto op = ge_node->GetOpDesc(); | |||
| MS_EXCEPTION_IF_NULL(op); | |||
| MS_LOG(INFO) << "Start to call CalcOpRunningParam"; | |||
| ge::Status ret = ops_kernel_builder->CalcOpRunningParam(*ge_node); | |||
| if (ret != ge::SUCCESS) { | |||
| MS_LOG(ERROR) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret; | |||
| return false; | |||
| } | |||
| auto workspace_sizes = op->GetWorkspaceBytes(); | |||
| if (workspace_sizes.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "Unexpected workspace size " << workspace_sizes.size(); | |||
| } | |||
| int64_t workspace_size = workspace_sizes[0]; | |||
| MS_LOG(INFO) << "Node " << node->DebugString() << " workspace size is " << workspace_size; | |||
| ge_graph.reset(); | |||
| return workspace_size; | |||
| } | |||
| void *GetHcclOpsKernelInfoStore() { return ops_kernel_info_store.get(); } | |||
| @@ -23,21 +23,18 @@ | |||
| #include "mindspore/core/ir/anf.h" | |||
| #include "hccl/hccl_types.h" | |||
| #define MS_API __attribute__((visibility("default"))) | |||
| namespace mindspore::hccl { | |||
| struct MS_API HcclTaskInfo { | |||
| struct HcclTaskInfo { | |||
| std::string private_def; | |||
| int64_t workspace_size; | |||
| int64_t stream_num; | |||
| }; | |||
| MS_API bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file); | |||
| MS_API bool FinalizeHccl(); | |||
| MS_API bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists); | |||
| MS_API bool CalcOpRunningParam(const AnfNodePtr &node); | |||
| MS_API void *GetHcclOpsKernelInfoStore(); | |||
| MS_API std::string GetHcclType(const AnfNodePtr &node); | |||
| bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file); | |||
| bool FinalizeHccl(); | |||
| bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists); | |||
| int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype); | |||
| void *GetHcclOpsKernelInfoStore(); | |||
| std::string GetHcclType(const AnfNodePtr &node); | |||
| } // namespace mindspore::hccl | |||
| #undef MS_API | |||
| #endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H | |||
| @@ -64,7 +64,7 @@ namespace hccl { | |||
| bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; } | |||
| bool FinalizeHccl() { return true; } | |||
| bool GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) { return true; } | |||
| bool CalcOpRunningParam(const AnfNodePtr &) { return true; } | |||
| int64_t CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) { return 0; } | |||
| void *GetHcclOpsKernelInfoStore() { return nullptr; } | |||
| std::string GetHcclType(const AnfNodePtr &) { return ""; } | |||
| } // namespace hccl | |||