|
|
|
@@ -17,13 +17,27 @@ |
|
|
|
#include "backend/kernel_compiler/hccl/hcom_all_broadcast.h" |
|
|
|
#include <memory> |
|
|
|
#include "utils/ms_context.h" |
|
|
|
#include "backend/kernel_compiler/hccl/hccl_context.h" |
|
|
|
#include "external/hccl/hccl.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace kernel { |
|
|
|
bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> & /*inputs*/, |
|
|
|
bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs, |
|
|
|
const std::vector<AddressPtr> & /*workspace*/, |
|
|
|
const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) { |
|
|
|
MS_LOG(INFO) << "HcomAllBroadCast launch"; |
|
|
|
const std::vector<AddressPtr> & /*outputs*/, void *stream_ptr) { |
|
|
|
if (inputs.empty() || hccl_data_type_list_.empty()) { |
|
|
|
MS_LOG(ERROR) << "BroadCast param is empty"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(inputs[0]); |
|
|
|
MS_EXCEPTION_IF_NULL(stream_ptr); |
|
|
|
auto hccl_result = HcclBroadcast(inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], root_id_, |
|
|
|
HcclContext::GetInstance().hccl_comm(), stream_ptr); |
|
|
|
|
|
|
|
if (hccl_result != HCCL_SUCCESS) { |
|
|
|
MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << hccl_result; |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
} // namespace kernel |
|
|
|
|