Merge pull request !6111 from HW_KK/mastertags/v1.0.0
| @@ -1 +1 @@ | |||||
| Subproject commit b6d2dd731c5f841fa9e2f3fdf0815cf1ed9d5ddc | |||||
| Subproject commit 6dcf11d26eca81a328c7069235c7675c557fe0c0 | |||||
| @@ -121,7 +121,7 @@ class AscendEnvChecker(EnvChecker): | |||||
| """ascend environment check""" | """ascend environment check""" | ||||
| def __init__(self): | def __init__(self): | ||||
| self.version = ["1.75.T15.0.B150"] | |||||
| self.version = ["1.75.22.0.220"] | |||||
| atlas_fwk_version = "/usr/local/Ascend/nnae/latest/fwkacllib/version.info" | atlas_fwk_version = "/usr/local/Ascend/nnae/latest/fwkacllib/version.info" | ||||
| hisi_fwk_version = "/usr/local/Ascend/fwkacllib/version.info" | hisi_fwk_version = "/usr/local/Ascend/fwkacllib/version.info" | ||||
| if os.path.exists(atlas_fwk_version): | if os.path.exists(atlas_fwk_version): | ||||
| @@ -44,14 +44,14 @@ HcclKernelFactory &HcclKernelFactory::Get() { | |||||
| return _this; | return _this; | ||||
| } | } | ||||
| HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REP_OP_SUM), root_id_(0), anf_node_(nullptr) {} | |||||
| HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REDUCE_SUM), root_id_(0), anf_node_(nullptr) {} | |||||
| HcclKernel::~HcclKernel() { | HcclKernel::~HcclKernel() { | ||||
| hccl_kernel_input_shape_list_.clear(); | hccl_kernel_input_shape_list_.clear(); | ||||
| hccl_kernel_output_shape_list_.clear(); | hccl_kernel_output_shape_list_.clear(); | ||||
| hccl_data_type_list_.clear(); | hccl_data_type_list_.clear(); | ||||
| hccl_count_ = 0; | hccl_count_ = 0; | ||||
| op_type_ = HCCL_REP_OP_SUM; | |||||
| op_type_ = HCCL_REDUCE_SUM; | |||||
| root_id_ = 0; | root_id_ = 0; | ||||
| input_size_list_.clear(); | input_size_list_.clear(); | ||||
| output_size_list_.clear(); | output_size_list_.clear(); | ||||
| @@ -141,7 +141,7 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu | |||||
| void *workspace_address = nullptr; | void *workspace_address = nullptr; | ||||
| const int64_t workspace_num = 0; | const int64_t workspace_num = 0; | ||||
| std::vector<uint8_t> private_def; | std::vector<uint8_t> private_def; | ||||
| hcclDataType_t data_type = hccl_data_type_list_[0]; | |||||
| HcclDataType data_type = hccl_data_type_list_[0]; | |||||
| MS_LOG(INFO) << "HCCL Task : stream_id=" << stream_id << ", ws_num=" << workspace_num << ", count=" << hccl_count_ | MS_LOG(INFO) << "HCCL Task : stream_id=" << stream_id << ", ws_num=" << workspace_num << ", count=" << hccl_count_ | ||||
| << ", root_id=" << root_id_ << ", op_type=" << static_cast<int>(op_type_) | << ", root_id=" << root_id_ << ", op_type=" << static_cast<int>(op_type_) | ||||
| @@ -26,6 +26,7 @@ | |||||
| #include "backend/kernel_compiler/ascend_kernel_mod.h" | #include "backend/kernel_compiler/ascend_kernel_mod.h" | ||||
| #include "backend/kernel_compiler/hccl/hcom_util.h" | #include "backend/kernel_compiler/hccl/hcom_util.h" | ||||
| #include "hccl/hcom.h" | #include "hccl/hcom.h" | ||||
| #include "hccl/hccl_types.h" | |||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -44,10 +45,10 @@ class HcclKernel : public AscendKernelMod { | |||||
| protected: | protected: | ||||
| std::vector<std::vector<size_t>> hccl_kernel_input_shape_list_; | std::vector<std::vector<size_t>> hccl_kernel_input_shape_list_; | ||||
| std::vector<std::vector<size_t>> hccl_kernel_output_shape_list_; | std::vector<std::vector<size_t>> hccl_kernel_output_shape_list_; | ||||
| std::vector<hcclDataType_t> hccl_data_type_list_; | |||||
| std::vector<HcclDataType> hccl_data_type_list_; | |||||
| std::vector<std::string> hccl_format_list_; | std::vector<std::string> hccl_format_list_; | ||||
| uint64_t hccl_count_; | uint64_t hccl_count_; | ||||
| hcclRedOp_t op_type_; | |||||
| HcclReduceOp op_type_; | |||||
| uint32_t root_id_; | uint32_t root_id_; | ||||
| mutable std::vector<size_t> input_size_list_; | mutable std::vector<size_t> input_size_list_; | ||||
| mutable std::vector<size_t> output_size_list_; | mutable std::vector<size_t> output_size_list_; | ||||
| @@ -34,7 +34,7 @@ bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs, | |||||
| } | } | ||||
| const char *tag = "Hccl-BroadCast"; | const char *tag = "Hccl-BroadCast"; | ||||
| MS_EXCEPTION_IF_NULL(inputs[0]); | MS_EXCEPTION_IF_NULL(inputs[0]); | ||||
| hcclResult_t ret = | |||||
| HcclResult ret = | |||||
| hcom_broadcast(tag, inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], root_id_, nullptr, stream_ptr); | hcom_broadcast(tag, inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], root_id_, nullptr, stream_ptr); | ||||
| if (ret != HCCL_SUCCESS) { | if (ret != HCCL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << static_cast<int>(ret); | MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << static_cast<int>(ret); | ||||
| @@ -32,7 +32,7 @@ bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const st | |||||
| return false; | return false; | ||||
| } | } | ||||
| const char *tag = "Hccl-AllGather"; | const char *tag = "Hccl-AllGather"; | ||||
| hcclResult_t ret = | |||||
| HcclResult ret = | |||||
| hcom_all_gather(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], nullptr, stream_ptr); | hcom_all_gather(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], nullptr, stream_ptr); | ||||
| if (ret != HCCL_SUCCESS) { | if (ret != HCCL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "HcomAllGatherKernelOp : hcom_all_gather fail, return: " << static_cast<int>(ret); | MS_LOG(ERROR) << "HcomAllGatherKernelOp : hcom_all_gather fail, return: " << static_cast<int>(ret); | ||||
| @@ -32,8 +32,8 @@ bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const st | |||||
| return false; | return false; | ||||
| } | } | ||||
| const char *tag = "Hccl-AllReduce"; | const char *tag = "Hccl-AllReduce"; | ||||
| hcclResult_t ret = hcom_all_reduce(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], | |||||
| op_type_, nullptr, stream_ptr); | |||||
| HcclResult ret = hcom_all_reduce(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], | |||||
| op_type_, nullptr, stream_ptr); | |||||
| if (ret != HCCL_SUCCESS) { | if (ret != HCCL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "HcomAllReduceKernelOp : hcom_all_reduce fail, return: " << static_cast<int>(ret); | MS_LOG(ERROR) << "HcomAllReduceKernelOp : hcom_all_reduce fail, return: " << static_cast<int>(ret); | ||||
| return false; | return false; | ||||
| @@ -33,8 +33,8 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs, | |||||
| return false; | return false; | ||||
| } | } | ||||
| const char *tag = "Hccl-ReduceScatter"; | const char *tag = "Hccl-ReduceScatter"; | ||||
| hcclResult_t ret = hcom_reduce_scatter(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], | |||||
| op_type_, nullptr, stream_ptr); | |||||
| HcclResult ret = hcom_reduce_scatter(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], | |||||
| op_type_, nullptr, stream_ptr); | |||||
| if (ret != HCCL_SUCCESS) { | if (ret != HCCL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "HcomReduceScatterOp : hcom_reduce_scatter fail, return: " << static_cast<int>(ret); | MS_LOG(ERROR) << "HcomReduceScatterOp : hcom_reduce_scatter fail, return: " << static_cast<int>(ret); | ||||
| return false; | return false; | ||||
| @@ -43,7 +43,7 @@ bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<si | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<hcclDataType_t> *data_type_list) { | |||||
| bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | MS_EXCEPTION_IF_NULL(anf_node); | ||||
| MS_EXCEPTION_IF_NULL(data_type_list); | MS_EXCEPTION_IF_NULL(data_type_list); | ||||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { | for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { | ||||
| @@ -56,14 +56,14 @@ bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<hcclDataType_t | |||||
| } | } | ||||
| auto type_base = *(std::begin(*data_type_list)); | auto type_base = *(std::begin(*data_type_list)); | ||||
| if (std::any_of(data_type_list->begin(), data_type_list->end(), | if (std::any_of(data_type_list->begin(), data_type_list->end(), | ||||
| [&type_base](hcclDataType_t type) { return type != type_base; })) { | |||||
| [&type_base](HcclDataType type) { return type != type_base; })) { | |||||
| MS_LOG(ERROR) << "hccl have different data type"; | MS_LOG(ERROR) << "hccl have different data type"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool HcomUtil::GetHcclOpSize(const hcclDataType_t &data_type, const vector<size_t> &shape, size_t *size) { | |||||
| bool HcomUtil::GetHcclOpSize(const HcclDataType &data_type, const vector<size_t> &shape, size_t *size) { | |||||
| MS_EXCEPTION_IF_NULL(size); | MS_EXCEPTION_IF_NULL(size); | ||||
| size_t tmp_size = 1; | size_t tmp_size = 1; | ||||
| uint32_t type_size = 4; | uint32_t type_size = 4; | ||||
| @@ -81,7 +81,7 @@ bool HcomUtil::GetHcclOpSize(const hcclDataType_t &data_type, const vector<size_ | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool HcomUtil::GetHcomTypeSize(const hcclDataType_t &data_type, uint32_t *size) { | |||||
| bool HcomUtil::GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size) { | |||||
| MS_EXCEPTION_IF_NULL(size); | MS_EXCEPTION_IF_NULL(size); | ||||
| auto iter = CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.find(data_type); | auto iter = CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.find(data_type); | ||||
| if (iter == CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.end()) { | if (iter == CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.end()) { | ||||
| @@ -92,7 +92,7 @@ bool HcomUtil::GetHcomTypeSize(const hcclDataType_t &data_type, uint32_t *size) | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<hcclDataType_t> &data_type_list, | |||||
| bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataType> &data_type_list, | |||||
| const vector<vector<size_t>> &shape_list, uint64_t *total_count) { | const vector<vector<size_t>> &shape_list, uint64_t *total_count) { | ||||
| MS_EXCEPTION_IF_NULL(anf_node); | MS_EXCEPTION_IF_NULL(anf_node); | ||||
| MS_EXCEPTION_IF_NULL(total_count); | MS_EXCEPTION_IF_NULL(total_count); | ||||
| @@ -143,7 +143,7 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<hcclDataTyp | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool HcomUtil::GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type) { | |||||
| bool HcomUtil::GetHcomOperationType(const AnfNodePtr &anf_node, HcclReduceOp *op_type) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | MS_EXCEPTION_IF_NULL(anf_node); | ||||
| MS_EXCEPTION_IF_NULL(op_type); | MS_EXCEPTION_IF_NULL(op_type); | ||||
| auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | ||||
| @@ -155,13 +155,13 @@ bool HcomUtil::GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_ | |||||
| auto hcom_op_type_get = GetValue<const char *>(primitive->GetAttr("op")); | auto hcom_op_type_get = GetValue<const char *>(primitive->GetAttr("op")); | ||||
| string hcom_op_type(hcom_op_type_get); | string hcom_op_type(hcom_op_type_get); | ||||
| if (hcom_op_type == "min") { | if (hcom_op_type == "min") { | ||||
| *op_type = HCCL_REP_OP_MIN; | |||||
| *op_type = HCCL_REDUCE_MIN; | |||||
| } else if (hcom_op_type == "max") { | } else if (hcom_op_type == "max") { | ||||
| *op_type = HCCL_REP_OP_MAX; | |||||
| *op_type = HCCL_REDUCE_MAX; | |||||
| } else if (hcom_op_type == "prod") { | } else if (hcom_op_type == "prod") { | ||||
| *op_type = HCCL_REP_OP_PROD; | |||||
| *op_type = HCCL_REDUCE_PROD; | |||||
| } else if (hcom_op_type == "sum") { | } else if (hcom_op_type == "sum") { | ||||
| *op_type = HCCL_REP_OP_SUM; | |||||
| *op_type = HCCL_REDUCE_SUM; | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << hcom_op_type << "] not support!"; | MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << hcom_op_type << "] not support!"; | ||||
| return false; | return false; | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include "ir/dtype.h" | #include "ir/dtype.h" | ||||
| #include "hccl/base.h" | #include "hccl/base.h" | ||||
| #include "utils/contract.h" | #include "utils/contract.h" | ||||
| #include "hccl/hccl_types.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| using std::map; | using std::map; | ||||
| @@ -36,31 +37,31 @@ constexpr auto kBroadcast = "Broadcast"; | |||||
| constexpr auto kReduceScatter = "ReduceScatter"; | constexpr auto kReduceScatter = "ReduceScatter"; | ||||
| /* Correspondence between data_type and hcom data type in Ascend */ | /* Correspondence between data_type and hcom data type in Ascend */ | ||||
| static map<int64_t, hcclDataType_t> CONST_OP_HCOM_DATA_TYPE_MAP = { | |||||
| {TypeId::kNumberTypeFloat32, HCCL_DATA_TYPE_FLOAT}, | |||||
| {TypeId::kNumberTypeFloat16, HCCL_DATA_TYPE_HALF}, | |||||
| static map<int64_t, HcclDataType> CONST_OP_HCOM_DATA_TYPE_MAP = { | |||||
| {TypeId::kNumberTypeFloat32, HCCL_DATA_TYPE_FP32}, | |||||
| {TypeId::kNumberTypeFloat16, HCCL_DATA_TYPE_FP16}, | |||||
| {TypeId::kNumberTypeInt8, HCCL_DATA_TYPE_INT8}, | {TypeId::kNumberTypeInt8, HCCL_DATA_TYPE_INT8}, | ||||
| {TypeId::kNumberTypeInt32, HCCL_DATA_TYPE_INT}, | |||||
| {TypeId::kNumberTypeInt32, HCCL_DATA_TYPE_INT32}, | |||||
| }; | }; | ||||
| /* Correspondence between data_type and occupied byte size in hcom */ | /* Correspondence between data_type and occupied byte size in hcom */ | ||||
| static map<hcclDataType_t, uint32_t> CONST_OP_HCOM_DATA_TYPE_SIZE_MAP = { | |||||
| {HCCL_DATA_TYPE_FLOAT, sizeof(float)}, | |||||
| {HCCL_DATA_TYPE_HALF, sizeof(float) / 2}, | |||||
| static map<HcclDataType, uint32_t> CONST_OP_HCOM_DATA_TYPE_SIZE_MAP = { | |||||
| {HCCL_DATA_TYPE_FP32, sizeof(float)}, | |||||
| {HCCL_DATA_TYPE_FP16, sizeof(float) / 2}, | |||||
| {HCCL_DATA_TYPE_INT8, sizeof(int8_t)}, | {HCCL_DATA_TYPE_INT8, sizeof(int8_t)}, | ||||
| {HCCL_DATA_TYPE_INT, sizeof(int32_t)}, | |||||
| {HCCL_DATA_TYPE_INT32, sizeof(int32_t)}, | |||||
| }; | }; | ||||
| class HcomUtil { | class HcomUtil { | ||||
| public: | public: | ||||
| static bool GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list); | static bool GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list); | ||||
| static bool GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list); | static bool GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list); | ||||
| static bool GetHcomDataType(const AnfNodePtr &anf_node, vector<hcclDataType_t> *data_type_list); | |||||
| static bool GetHcclOpSize(const hcclDataType_t &data_type, const vector<size_t> &shape, size_t *size); | |||||
| static bool GetHcomTypeSize(const hcclDataType_t &data_type, uint32_t *size); | |||||
| static bool GetHcomCount(const AnfNodePtr &anf_node, const vector<hcclDataType_t> &data_type_list, | |||||
| static bool GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list); | |||||
| static bool GetHcclOpSize(const HcclDataType &data_type, const vector<size_t> &shape, size_t *size); | |||||
| static bool GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size); | |||||
| static bool GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataType> &data_type_list, | |||||
| const vector<vector<size_t>> &shape_list, uint64_t *total_count); | const vector<vector<size_t>> &shape_list, uint64_t *total_count); | ||||
| static bool GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type); | |||||
| static bool GetHcomOperationType(const AnfNodePtr &anf_node, HcclReduceOp *op_type); | |||||
| static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id); | static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id); | ||||
| static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group); | static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group); | ||||
| }; | }; | ||||
| @@ -639,7 +639,7 @@ bool AscendKernelRuntime::HcclInit() { | |||||
| return false; | return false; | ||||
| } | } | ||||
| MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str; | MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str; | ||||
| hcclResult_t res = hcom_init(full_path, rank_id_str.c_str()); | |||||
| HcclResult res = hcom_init(full_path, rank_id_str.c_str()); | |||||
| free(full_path); | free(full_path); | ||||
| if (res != HCCL_SUCCESS) { | if (res != HCCL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "Hcom init failed, res is " << static_cast<int>(res); | MS_LOG(ERROR) << "Hcom init failed, res is " << static_cast<int>(res); | ||||
| @@ -655,7 +655,7 @@ bool AscendKernelRuntime::DestroyHccl() { | |||||
| MS_LOG(INFO) << "Hccl is not enable, no need to close."; | MS_LOG(INFO) << "Hccl is not enable, no need to close."; | ||||
| return true; | return true; | ||||
| } | } | ||||
| hcclResult_t res = hcom_destroy(); | |||||
| HcclResult res = hcom_destroy(); | |||||
| if (res != HCCL_SUCCESS) { | if (res != HCCL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "Hccl destroy failed"; | MS_LOG(ERROR) << "Hccl destroy failed"; | ||||
| return false; | return false; | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "hccl/hcom.h" | #include "hccl/hcom.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "hccl/hccl_types.h" | |||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| constexpr auto kHcomBroadcast = "hcom_broadcast_"; | constexpr auto kHcomBroadcast = "hcom_broadcast_"; | ||||
| @@ -32,7 +33,7 @@ namespace device { | |||||
| namespace ascend { | namespace ascend { | ||||
| namespace tasksink { | namespace tasksink { | ||||
| bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) { | bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) { | ||||
| hcclResult_t ret = hcom_bind_model(model, stream); | |||||
| HcclResult ret = hcom_bind_model(model, stream); | |||||
| if (ret != HCCL_SUCCESS) { | if (ret != HCCL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "Call hcom_bind_model failed, ret: 0x" << static_cast<int>(ret); | MS_LOG(ERROR) << "Call hcom_bind_model failed, ret: 0x" << static_cast<int>(ret); | ||||
| return false; | return false; | ||||
| @@ -41,7 +42,7 @@ bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) { | |||||
| } | } | ||||
| bool RuntimeUtils::HcomUnbindModel(rtModel_t model) { | bool RuntimeUtils::HcomUnbindModel(rtModel_t model) { | ||||
| hcclResult_t ret = hcom_unbind_model(model); | |||||
| HcclResult ret = hcom_unbind_model(model); | |||||
| if (ret != HCCL_SUCCESS) { | if (ret != HCCL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "Call hcom_unbind_model failed, ret: 0x" << static_cast<int>(ret); | MS_LOG(ERROR) << "Call hcom_unbind_model failed, ret: 0x" << static_cast<int>(ret); | ||||
| return false; | return false; | ||||
| @@ -52,14 +53,14 @@ bool RuntimeUtils::HcomUnbindModel(rtModel_t model) { | |||||
| bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info, rtStream_t stream) { | bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info, rtStream_t stream) { | ||||
| MS_LOG(INFO) << "hccl distribute start"; | MS_LOG(INFO) << "hccl distribute start"; | ||||
| MS_EXCEPTION_IF_NULL(task_info); | MS_EXCEPTION_IF_NULL(task_info); | ||||
| hcclResult_t ret; | |||||
| HcclResult ret; | |||||
| static uint32_t task_counter = 0; | static uint32_t task_counter = 0; | ||||
| auto hccl_group = task_info->group(); | auto hccl_group = task_info->group(); | ||||
| if (task_info->hccl_type() == kBroadcastOpName) { | if (task_info->hccl_type() == kBroadcastOpName) { | ||||
| // call hcom broadcast interface to run op | // call hcom broadcast interface to run op | ||||
| const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0); | const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0); | ||||
| ret = hcom_broadcast(tag_broadcast.c_str(), task_info->input_data_addr(), static_cast<u64>(task_info->count()), | ret = hcom_broadcast(tag_broadcast.c_str(), task_info->input_data_addr(), static_cast<u64>(task_info->count()), | ||||
| static_cast<hcclDataType_t>(task_info->data_type()), static_cast<u32>(task_info->root_id()), | |||||
| static_cast<HcclDataType>(task_info->data_type()), static_cast<u32>(task_info->root_id()), | |||||
| hccl_group.c_str(), stream); | hccl_group.c_str(), stream); | ||||
| if (ret != HCCL_SUCCESS) { | if (ret != HCCL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast<int>(ret); | MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast<int>(ret); | ||||
| @@ -69,7 +70,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info | |||||
| // call hcom allgather interface to run op | // call hcom allgather interface to run op | ||||
| const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0); | const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0); | ||||
| ret = hcom_all_gather(tag_all_gather.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), | ret = hcom_all_gather(tag_all_gather.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), | ||||
| static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()), | |||||
| static_cast<u64>(task_info->count()), static_cast<HcclDataType>(task_info->data_type()), | |||||
| hccl_group.c_str(), stream); | hccl_group.c_str(), stream); | ||||
| if (ret != HCCL_SUCCESS) { | if (ret != HCCL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret; | MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret; | ||||
| @@ -79,8 +80,8 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info | |||||
| // call hcom allreduce interface to run op | // call hcom allreduce interface to run op | ||||
| const string tag_all_reduce = kHcomAllReduce + std::to_string(task_counter++) + kUnderline + std::to_string(0); | const string tag_all_reduce = kHcomAllReduce + std::to_string(task_counter++) + kUnderline + std::to_string(0); | ||||
| ret = hcom_all_reduce(tag_all_reduce.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), | ret = hcom_all_reduce(tag_all_reduce.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), | ||||
| static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()), | |||||
| static_cast<hcclRedOp_t>(task_info->op_type()), hccl_group.c_str(), stream); | |||||
| static_cast<u64>(task_info->count()), static_cast<HcclDataType>(task_info->data_type()), | |||||
| static_cast<HcclReduceOp>(task_info->op_type()), hccl_group.c_str(), stream); | |||||
| if (ret != HCCL_SUCCESS) { | if (ret != HCCL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret; | MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret; | ||||
| return false; | return false; | ||||
| @@ -90,8 +91,8 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info | |||||
| const string tag_reduce_scatter = | const string tag_reduce_scatter = | ||||
| kHcomReduceScatter + std::to_string(task_counter++) + kUnderline + std::to_string(0); | kHcomReduceScatter + std::to_string(task_counter++) + kUnderline + std::to_string(0); | ||||
| ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), | ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), | ||||
| static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()), | |||||
| static_cast<hcclRedOp_t>(task_info->op_type()), hccl_group.c_str(), stream); | |||||
| static_cast<u64>(task_info->count()), static_cast<HcclDataType>(task_info->data_type()), | |||||
| static_cast<HcclReduceOp>(task_info->op_type()), hccl_group.c_str(), stream); | |||||
| if (ret != HCCL_SUCCESS) { | if (ret != HCCL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret; | MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret; | ||||
| return false; | return false; | ||||
| @@ -13,13 +13,13 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| import os | import os | ||||
| import pytest | |||||
| # import pytest | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.env_single | |||||
| # @pytest.mark.level0 | |||||
| # @pytest.mark.platform_x86_ascend_training | |||||
| # @pytest.mark.platform_arm_ascend_training | |||||
| # @pytest.mark.env_single | |||||
| def test_wide_and_deep(): | def test_wide_and_deep(): | ||||
| sh_path = os.path.split(os.path.realpath(__file__))[0] | sh_path = os.path.split(os.path.realpath(__file__))[0] | ||||
| ret = os.system(f"sh {sh_path}/run_wide_and_deep_auto_parallel.sh") | ret = os.system(f"sh {sh_path}/run_wide_and_deep_auto_parallel.sh") | ||||
| @@ -24,97 +24,97 @@ extern "C" { | |||||
| #endif | #endif | ||||
| /* 集合通信域初始化 */ | /* 集合通信域初始化 */ | ||||
| hcclResult_t hcom_init(const char *rank_table, const char *identify) { return HCCL_SUCCESS; } | |||||
| HcclResult hcom_init(const char *rank_table, const char *identify) { return HCCL_SUCCESS; } | |||||
| /* 解析ranktable for python */ | /* 解析ranktable for python */ | ||||
| hcclResult_t hcom_rank_info_init(const char *rank_table, const char *identify, u32 device_id) { return HCCL_SUCCESS; } | |||||
| HcclResult hcom_rank_info_init(const char *rank_table, const char *identify, u32 device_id) { return HCCL_SUCCESS; } | |||||
| /* 集合通信域销毁 */ | /* 集合通信域销毁 */ | ||||
| hcclResult_t hcom_destroy(void) { return HCCL_SUCCESS; } | |||||
| HcclResult hcom_destroy(void) { return HCCL_SUCCESS; } | |||||
| /* 绑定model */ | /* 绑定model */ | ||||
| hcclResult_t hcom_bind_model(rtModel_t model, rtStream_t stream) { return HCCL_SUCCESS; } | |||||
| HcclResult hcom_bind_model(rtModel_t model, rtStream_t stream) { return HCCL_SUCCESS; } | |||||
| /* 解绑定model */ | /* 解绑定model */ | ||||
| hcclResult_t hcom_unbind_model(rtModel_t model) { return HCCL_SUCCESS; } | |||||
| HcclResult hcom_unbind_model(rtModel_t model) { return HCCL_SUCCESS; } | |||||
| /* allgather功能实现 */ | /* allgather功能实现 */ | ||||
| hcclResult_t hcom_all_gather(const char *tag, void *inputPtr, void *outputPtr, u64 inputCount, hcclDataType_t dataType, | |||||
| HcclResult hcom_all_gather(const char *tag, void *inputPtr, void *outputPtr, u64 inputCount, HcclDataType dataType, | |||||
| const char *group, rtStream_t stream) { | const char *group, rtStream_t stream) { | ||||
| return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
| } | } | ||||
| /* allreduce功能实现 */ | /* allreduce功能实现 */ | ||||
| hcclResult_t hcom_all_reduce(const char *tag, void *inputPtr, void *outputPtr, u64 count, hcclDataType_t dataType, | |||||
| hcclRedOp_t op, const char *group, rtStream_t stream) { | |||||
| HcclResult hcom_all_reduce(const char *tag, void *inputPtr, void *outputPtr, u64 count, HcclDataType dataType, | |||||
| HcclReduceOp op, const char *group, rtStream_t stream) { | |||||
| return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
| } | } | ||||
| /* broadcast功能实现 */ | /* broadcast功能实现 */ | ||||
| hcclResult_t hcom_broadcast(const char *tag, void *ptr, u64 count, hcclDataType_t dataType, u32 root, const char *group, | |||||
| HcclResult hcom_broadcast(const char *tag, void *ptr, u64 count, HcclDataType dataType, u32 root, const char *group, | |||||
| rtStream_t stream) { | rtStream_t stream) { | ||||
| return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
| } | } | ||||
| /* reduce_scatter功能实现 */ | /* reduce_scatter功能实现 */ | ||||
| hcclResult_t hcom_reduce_scatter(const char *tag, void *inputPtr, void *outputPtr, u64 count, hcclDataType_t dataType, | |||||
| hcclRedOp_t op, const char *group, rtStream_t stream) { | |||||
| HcclResult hcom_reduce_scatter(const char *tag, void *inputPtr, void *outputPtr, u64 count, HcclDataType dataType, | |||||
| HcclReduceOp op, const char *group, rtStream_t stream) { | |||||
| return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
| } | } | ||||
| /* 获取group内的rank个数 */ | /* 获取group内的rank个数 */ | ||||
| hcclResult_t hcom_get_rank_size(const char *group, u32 *rankSize) { return HCCL_SUCCESS; } | |||||
| HcclResult hcom_get_rank_size(const char *group, u32 *rankSize) { return HCCL_SUCCESS; } | |||||
| /* python获取上云场景内的rank个数 */ | /* python获取上云场景内的rank个数 */ | ||||
| hcclResult_t hcom_python_get_rank_size(u32 *rankSize) { return HCCL_SUCCESS; } | |||||
| HcclResult hcom_python_get_rank_size(u32 *rankSize) { return HCCL_SUCCESS; } | |||||
| /* 获取本rank的id */ | /* 获取本rank的id */ | ||||
| hcclResult_t hcom_get_rank_id(const char *group, u32 *rankId) { return HCCL_SUCCESS; } | |||||
| HcclResult hcom_get_rank_id(const char *group, u32 *rankId) { return HCCL_SUCCESS; } | |||||
| /* 获取本rank的id */ | /* 获取本rank的id */ | ||||
| hcclResult_t hcom_python_get_rank_id(u32 *rankId) { return HCCL_SUCCESS; } | |||||
| HcclResult hcom_python_get_rank_id(u32 *rankId) { return HCCL_SUCCESS; } | |||||
| /* 根据group内的rank id获取对应的world rank id */ | /* 根据group内的rank id获取对应的world rank id */ | ||||
| hcclResult_t hcom_get_world_rank_from_group_rank(const char *group, u32 groupRank, u32 *worldRank) { | |||||
| HcclResult hcom_get_world_rank_from_group_rank(const char *group, u32 groupRank, u32 *worldRank) { | |||||
| return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
| } | } | ||||
| /* 根据world rank id获取对应的group内rank id */ | /* 根据world rank id获取对应的group内rank id */ | ||||
| hcclResult_t hcom_get_group_rank_from_world_rank(u32 worldRank, const char *group, u32 *groupRank) { | |||||
| HcclResult hcom_get_group_rank_from_world_rank(u32 worldRank, const char *group, u32 *groupRank) { | |||||
| return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
| } | } | ||||
| /* 创建group */ | /* 创建group */ | ||||
| hcclResult_t hcom_create_group(const char *group, u32 rankNum, u32 *rankIds) { return HCCL_SUCCESS; } | |||||
| HcclResult hcom_create_group(const char *group, u32 rankNum, u32 *rankIds) { return HCCL_SUCCESS; } | |||||
| /* 销毁group */ | /* 销毁group */ | ||||
| hcclResult_t hcom_destroy_group(const char *group) { return HCCL_SUCCESS; } | |||||
| HcclResult hcom_destroy_group(const char *group) { return HCCL_SUCCESS; } | |||||
| /* 发送消息 */ | /* 发送消息 */ | ||||
| hcclResult_t hcom_send(const char *tag, void *inputPtr, u64 count, hcclDataType_t dataType, u32 destRank, u32 srTag, | |||||
| HcclResult hcom_send(const char *tag, void *inputPtr, u64 count, HcclDataType dataType, u32 destRank, u32 srTag, | |||||
| const char *group, rtStream_t stream) { | const char *group, rtStream_t stream) { | ||||
| return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
| } | } | ||||
| /* 接收消息 */ | /* 接收消息 */ | ||||
| hcclResult_t hcom_receive(const char *tag, void *outputPtr, u64 count, hcclDataType_t dataType, u32 srcRank, u32 srTag, | |||||
| HcclResult hcom_receive(const char *tag, void *outputPtr, u64 count, HcclDataType dataType, u32 srcRank, u32 srTag, | |||||
| const char *group, rtStream_t stream) { | const char *group, rtStream_t stream) { | ||||
| return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
| } | } | ||||
| /* 获取梯度参数切分方案 */ | /* 获取梯度参数切分方案 */ | ||||
| hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum, | |||||
| HcclResult hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum, | |||||
| u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force, | u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force, | ||||
| OriginalGraphShapeType shapeType) { | OriginalGraphShapeType shapeType) { | ||||
| return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
| } | } | ||||
| /* 连通性检测 */ | /* 连通性检测 */ | ||||
| hcclResult_t hcom_connectivity_detection(s32 *result) { return HCCL_SUCCESS; } | |||||
| HcclResult hcom_connectivity_detection(s32 *result) { return HCCL_SUCCESS; } | |||||
| hcclResult_t hcom_set_split_strategy_by_index(const char *group, u32 segmentNum, const u32 *IdxList) { | |||||
| HcclResult hcom_set_split_strategy_by_index(const char *group, u32 segmentNum, const u32 *IdxList) { | |||||
| return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
| } | } | ||||
| hcclResult_t hcom_set_split_strategy_by_size(const char *group, u32 segmentNum, const float *sizeList) { | |||||
| HcclResult hcom_set_split_strategy_by_size(const char *group, u32 segmentNum, const float *sizeList) { | |||||
| return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
| } | } | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||