From 065ef12ec9413d1984a1c364282821ac62480bc7 Mon Sep 17 00:00:00 2001 From: caifubi Date: Thu, 17 Dec 2020 16:13:57 +0800 Subject: [PATCH] PyNative support Broadcast --- .../hccl/hcom_all_broadcast.cc | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc index 1fb30c5190..6c8eb8c6b5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc @@ -17,13 +17,27 @@ #include "backend/kernel_compiler/hccl/hcom_all_broadcast.h" #include #include "utils/ms_context.h" +#include "backend/kernel_compiler/hccl/hccl_context.h" +#include "external/hccl/hccl.h" namespace mindspore { namespace kernel { -bool HcomAllBroadCastKernel::Launch(const std::vector & /*inputs*/, +bool HcomAllBroadCastKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, - const std::vector & /*outputs*/, void * /*stream_ptr*/) { - MS_LOG(INFO) << "HcomAllBroadCast launch"; + const std::vector & /*outputs*/, void *stream_ptr) { + if (inputs.empty() || hccl_data_type_list_.empty()) { + MS_LOG(ERROR) << "BroadCast param is empty"; + return false; + } + MS_EXCEPTION_IF_NULL(inputs[0]); + MS_EXCEPTION_IF_NULL(stream_ptr); + auto hccl_result = HcclBroadcast(inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], root_id_, + HcclContext::GetInstance().hccl_comm(), stream_ptr); + + if (hccl_result != HCCL_SUCCESS) { + MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << hccl_result; + return false; + } return true; } } // namespace kernel