Merge pull request !22386 from zhoufeng/fix-neighbor-empty-input-bak (tag: v1.5.0-rc1)
| @@ -39,7 +39,7 @@ std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) { | |||
| if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) { | |||
| return kOpFormat_DEFAULT; | |||
| } | |||
| if (op_name == kReceive || op_name == kHcomSend) { | |||
| if (op_name == kReceive || op_name == kHcomSend || op_name == kAllToAllv) { | |||
| return kOpFormat_DEFAULT; | |||
| } | |||
| auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index); | |||
| @@ -14,6 +14,9 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/hccl/hcom_all_to_all.h" | |||
| #include "runtime/hccl_adapter/hccl_adapter.h" | |||
| #include "runtime/device/ascend/ge_runtime/task_info.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| namespace mindspore::kernel { | |||
// Default-construct the kernel; all real setup is deferred to Init(), which
// runs once the graph node (and thus data type / shapes) is known.
HcomAllToAllKernel::HcomAllToAllKernel() {}
| @@ -25,5 +28,87 @@ bool HcomAllToAllKernel::Launch(const std::vector<AddressPtr> &, const std::vect | |||
| return true; | |||
| } | |||
| bool HcomAllToAllKernel::Init(const AnfNodePtr &anf_node) { | |||
| bool ret = HcclKernel::Init(anf_node); | |||
| if (!ret) { | |||
| return ret; | |||
| } | |||
| if (hccl_data_type_list_.empty()) { | |||
| auto recv_type = AnfAlgo::GetNodeAttr<TypePtr>(anf_node, kAttrRecvType); | |||
| MS_EXCEPTION_IF_NULL(recv_type); | |||
| data_type_ = HcomUtil::ConvertHcclType(recv_type->type_id()); | |||
| } else { | |||
| data_type_ = hccl_data_type_list_[0]; | |||
| } | |||
| workspace_size_list_ = {LongToSize(hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node, data_type_))}; | |||
| return true; | |||
| } | |||
| const std::vector<size_t> &HcomAllToAllKernel::GetOutputSizeList() const { | |||
| if (!output_size_list_.empty()) { | |||
| return output_size_list_; | |||
| } | |||
| for (size_t i = 0; i < hccl_kernel_output_shape_list_.size(); ++i) { | |||
| size_t size = 0; | |||
| if (!HcomUtil::GetHcclOpSize(data_type_, hccl_kernel_output_shape_list_[i], &size)) { | |||
| MS_LOG(EXCEPTION) << "AllToAllv get output size failed."; | |||
| } | |||
| output_size_list_.push_back(size); | |||
| } | |||
| return output_size_list_; | |||
| } | |||
// Build the runtime task list for this AllToAllv node.
//
// Asks the HCCL adapter to generate its task descriptions, then wraps each
// one in a ge::model_runner::HcclTaskInfo carrying the device addresses and
// this kernel's communication parameters.
// Raises when the owning ANF node has expired, when task generation fails,
// when the private-def blob cannot be copied, or when a task needs workspace
// but none was allocated.
std::vector<TaskInfoPtr> HcomAllToAllKernel::GenTask(const std::vector<AddressPtr> &inputs,
                                                     const std::vector<AddressPtr> &workspace,
                                                     const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
  auto anf_node = anf_node_.lock();
  if (!anf_node) {
    MS_LOG(EXCEPTION) << "anf_node pointer is expired.";
  }
  stream_id_ = stream_id;
  // AllToAllv may legally have no inputs (empty send list) or no outputs
  // (empty recv list); pass a null device address in that case.
  void *input_data_addr = inputs.empty() ? nullptr : inputs.at(0)->addr;
  void *output_data_addr = outputs.empty() ? nullptr : outputs.at(0)->addr;
  std::vector<uint8_t> private_def;
  std::vector<hccl::HcclTaskInfo> task_info;
  bool ret = hccl::HcclAdapter::GetInstance().GenTask(anf_node, data_type_, &task_info);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Gen Task for " << anf_node->DebugString() << " failed.";
  }
  std::vector<TaskInfoPtr> results;
  for (auto &task : task_info) {
    MS_LOG(INFO) << "AlltoAll Task : stream_id=" << stream_id << ", count=" << hccl_count_ << ", root_id=" << root_id_
                 << ", op_type=" << static_cast<int>(op_type_) << ", data_type=" << static_cast<int>(data_type_)
                 << ", workspace_size=" << task.workspace_size << ", stream_num=" << task.stream_num
                 << ", private_def_size=" << task.private_def.size();
    // Copy the adapter's opaque per-task blob; memcpy_s validates the sizes.
    private_def.resize(task.private_def.size());
    auto sec_ret = memcpy_s(private_def.data(), private_def.size(), task.private_def.data(), task.private_def.size());
    if (sec_ret != 0) {
      MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret;
    }
    // Only attach a workspace address when this task actually requires one;
    // in that case the workspace list must have been populated by Init().
    void *workspace_addr = nullptr;
    if (task.workspace_size != 0) {
      if (workspace.empty()) {
        MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node->DebugString() << " is empty";
      }
      MS_EXCEPTION_IF_NULL(workspace.at(0));
      workspace_addr = workspace.at(0)->addr;
    }
    results.emplace_back(std::make_shared<ge::model_runner::HcclTaskInfo>(
      unique_name_, stream_id, hccl::HcclAdapter::GetHcclType(anf_node), input_data_addr, output_data_addr,
      workspace_addr, task.workspace_size, task.stream_num, private_def,
      hccl::HcclAdapter::GetInstance().GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type_, group_,
      NeedDump()));
  }
  return results;
}
| MS_HCCL_REG_KERNEL(AllToAllv, HcomAllToAllKernel); | |||
| } // namespace mindspore::kernel | |||
| @@ -26,8 +26,15 @@ class HcomAllToAllKernel : public HcclKernel { | |||
| public: | |||
| HcomAllToAllKernel(); | |||
| ~HcomAllToAllKernel() override; | |||
| bool Init(const AnfNodePtr &anf_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, | |||
| void *) override; | |||
| const std::vector<size_t> &GetOutputSizeList() const override; | |||
| std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uint32_t stream_id) override; | |||
| private: | |||
| HcclDataType data_type_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_TO_ALL_H_ | |||
| @@ -54,6 +54,14 @@ bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<si | |||
| return true; | |||
| } | |||
| ::HcclDataType HcomUtil::ConvertHcclType(TypeId type_id) { | |||
| auto iter = kConstOpHcomDataTypeMap.find(type_id); | |||
| if (iter == kConstOpHcomDataTypeMap.end()) { | |||
| MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << type_id; | |||
| } | |||
| return iter->second; | |||
| } | |||
| bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(data_type_list); | |||
| @@ -69,17 +77,14 @@ bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> | |||
| } else { | |||
| type_ptr = AnfAlgo::GetInputDeviceDataType(anf_node, i); | |||
| } | |||
| auto iter = kConstOpHcomDataTypeMap.find(type_ptr); | |||
| if (iter == kConstOpHcomDataTypeMap.end()) { | |||
| MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << type_ptr; | |||
| } | |||
| data_type_list->emplace_back(iter->second); | |||
| data_type_list->emplace_back(ConvertHcclType(type_ptr)); | |||
| } | |||
| auto type_base = *(std::begin(*data_type_list)); | |||
| if (std::any_of(data_type_list->begin(), data_type_list->end(), | |||
| [&type_base](HcclDataType type) { return type != type_base; })) { | |||
| MS_LOG(ERROR) << "hccl have different data type"; | |||
| return false; | |||
| if (!data_type_list->empty()) { | |||
| if (std::any_of(data_type_list->begin(), data_type_list->end(), | |||
| [&data_type_list](HcclDataType type) { return type != *(data_type_list->begin()); })) { | |||
| MS_LOG(ERROR) << "hccl have different data type"; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| @@ -59,6 +59,7 @@ class HcomUtil { | |||
| public: | |||
| static bool GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list); | |||
| static bool GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list); | |||
| static ::HcclDataType ConvertHcclType(TypeId type_id); | |||
| static bool GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list); | |||
| static bool GetHcclOpSize(const HcclDataType &data_type, const vector<size_t> &shape, size_t *size); | |||
| static bool GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size); | |||
| @@ -131,7 +131,7 @@ void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &in | |||
| } | |||
| } | |||
| abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto recv_shapes = primitive->GetAttr(kRecvShapes); | |||
| MS_EXCEPTION_IF_NULL(recv_shapes); | |||
| @@ -147,15 +147,14 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec | |||
| MS_EXCEPTION_IF_NULL(base_shape); | |||
| base_shape_list.push_back(base_shape); | |||
| } | |||
| if (base_shape_list.empty()) { | |||
| return std::make_shared<abstract::Shape>(); | |||
| } | |||
| return std::make_shared<abstract::TupleShape>(base_shape_list); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| TypePtr InferType(const PrimitivePtr &primitive) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| (void)CheckAndConvertUtils::CheckInteger("NeighborExchange infer", SizeToLong(input_args.size()), kEqual, 1, | |||
| prim_name); | |||
| MS_EXCEPTION_IF_NULL(input_args[0]); | |||
| auto recv_shapes = primitive->GetAttr(kRecvShapes); | |||
| MS_EXCEPTION_IF_NULL(recv_shapes); | |||
| auto shapes_seq = recv_shapes->cast<ValueSequeuePtr>(); | |||
| @@ -165,14 +164,17 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP | |||
| auto recv_type = primitive->GetAttr(kRecvType)->cast<TypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(recv_type); | |||
| std::vector<TypePtr> type_vec(out_num, recv_type); | |||
| if (type_vec.empty()) { | |||
| return std::make_shared<TypeNone>(); | |||
| } | |||
| return std::make_shared<Tuple>(type_vec); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| Check(primitive, input_args); | |||
| auto type = InferType(primitive, input_args); | |||
| auto shape = InferShape(primitive, input_args); | |||
| auto type = InferType(primitive); | |||
| auto shape = InferShape(primitive); | |||
| return abstract::MakeAbstract(shape, type); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true); | |||
| @@ -91,7 +91,51 @@ def test_NeighborExchange_single_input_success(): | |||
| compile_net(net) | |||
| def test_NeighborExchage_empty_send_empty_recv_success(): | |||
def test_NeighborExchange_empty_send_success():
    """
    Feature: NeighborExchange
    Description: empty inputs, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            # Empty send_rank_ids/send_shapes: this rank only receives.
            self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[1], recv_shapes=([1],),
                                              send_shapes=(), recv_type=ms.float32)

        def construct(self, x1):
            # Called with no inputs since nothing is sent.
            self.alltoallv()
            return x1

    net = Net()
    _executor.compile(net, _x1)
def test_NeighborExchange_empty_recv_success():
    """
    Feature: NeighborExchange
    Description: empty outputs, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            # Empty recv_rank_ids/recv_shapes: this rank only sends.
            self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[], recv_shapes=(),
                                              send_shapes=([32, 16],), recv_type=ms.float32)

        def construct(self, x1):
            self.alltoallv((x1,))
            return x1

    net = Net()
    _executor.compile(net, _x1)
| def test_NeighborExchange_empty_send_empty_recv_success(): | |||
| """ | |||
| Feature: NeighborExchange | |||
| Description: empty inputs and empty outputs, with valid arguments | |||
| @@ -102,20 +146,18 @@ def test_NeighborExchage_empty_send_empty_recv_success(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[], | |||
| recv_shapes=(), | |||
| send_shapes=(), recv_type=ms.float32, group=("str",)) | |||
| self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[], recv_shapes=(), | |||
| send_shapes=(), recv_type=ms.float32) | |||
| def construct(self, x1): | |||
| self.alltoallv() | |||
| return x1 | |||
| net = Net() | |||
| with pytest.raises(TypeError): | |||
| _executor.compile(net, _x1) | |||
| _executor.compile(net, _x1) | |||
| def test_NeighborExchage_recv_shape_num_diff_with_recv_rank_size_failed(): | |||
| def test_NeighborExchange_recv_shape_num_diff_with_recv_rank_size_failed(): | |||
| """ | |||
| Feature: NeighborExchange | |||
| Description: send_rank_ids and send_shapes are set as 1 input, but gives 2 | |||
| @@ -143,7 +185,7 @@ def test_NeighborExchage_recv_shape_num_diff_with_recv_rank_size_failed(): | |||
| compile_net(net) | |||
| def test_NeighborExchage_send_shape_num_diff_with_send_rank_size_failed(): | |||
| def test_NeighborExchange_send_shape_num_diff_with_send_rank_size_failed(): | |||
| """ | |||
| Feature: NeighborExchange | |||
| Description: send_rank_ids is set as 2 inputs, but send_shapes are set as 1 input | |||
| @@ -172,7 +214,7 @@ def test_NeighborExchage_send_shape_num_diff_with_send_rank_size_failed(): | |||
| compile_net(net) | |||
| def test_NeighborExchage_send_shape_num_diff_with_input_num_failed(): | |||
| def test_NeighborExchange_send_shape_num_diff_with_input_num_failed(): | |||
| """ | |||
| Feature: NeighborExchange | |||
| Description: send_rank_ids and send_shapes are set as 2 inputs, but has only 1 input | |||
| @@ -201,7 +243,7 @@ def test_NeighborExchage_send_shape_num_diff_with_input_num_failed(): | |||
| compile_net(net) | |||
| def test_NeighborExchage_send_shape_diff_with_input_shape_failed(): | |||
| def test_NeighborExchange_send_shape_diff_with_input_shape_failed(): | |||
| """ | |||
| Feature: NeighborExchange | |||
| Description: send_shapes is set as [16, 16], but input is [32, 32] | |||
| @@ -229,7 +271,7 @@ def test_NeighborExchage_send_shape_diff_with_input_shape_failed(): | |||
| compile_net(net) | |||
| def test_NeighborExchage_attr_check_send_rank_ids_is_tuple_failed(): | |||
| def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_failed(): | |||
| """ | |||
| Feature: NeighborExchange | |||
| Description: send_rank_ids should be list, but a tuple is given | |||
| @@ -252,7 +294,7 @@ def test_NeighborExchage_attr_check_send_rank_ids_is_tuple_failed(): | |||
| _executor.compile(net, _x1) | |||
| def test_NeighborExchage_attr_check_send_rank_ids_is_float_failed(): | |||
| def test_NeighborExchange_attr_check_send_rank_ids_is_float_failed(): | |||
| """ | |||
| Feature: NeighborExchange | |||
| Description: send_rank_ids should be int, but a float is given | |||
| @@ -276,7 +318,7 @@ def test_NeighborExchage_attr_check_send_rank_ids_is_float_failed(): | |||
| _executor.compile(net, _x1) | |||
| def test_NeighborExchage_attr_check_recv_rank_ids_is_tuple_failed(): | |||
| def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_failed(): | |||
| """ | |||
| Feature: NeighborExchange | |||
| Description: recv_rank_ids should be list, but a tuple is given | |||
| @@ -300,7 +342,7 @@ def test_NeighborExchage_attr_check_recv_rank_ids_is_tuple_failed(): | |||
| _executor.compile(net, _x1) | |||
| def test_NeighborExchage_attr_check_recv_rank_ids_is_float_failed(): | |||
| def test_NeighborExchange_attr_check_recv_rank_ids_is_float_failed(): | |||
| """ | |||
| Feature: NeighborExchange | |||
| Description: recv_rank_ids should be int, but a float is given | |||
| @@ -324,7 +366,7 @@ def test_NeighborExchage_attr_check_recv_rank_ids_is_float_failed(): | |||
| _executor.compile(net, _x1) | |||
| def test_NeighborExchage_attr_check_send_shape_not_tuple_failed(): | |||
| def test_NeighborExchange_attr_check_send_shape_not_tuple_failed(): | |||
| """ | |||
| Feature: NeighborExchange | |||
| Description: send_shapes should be tuple(list), but a list is given | |||
| @@ -348,7 +390,7 @@ def test_NeighborExchage_attr_check_send_shape_not_tuple_failed(): | |||
| _executor.compile(net, _x1) | |||
| def test_NeighborExchage_attr_check_recv_type_numpy_failed(): | |||
| def test_NeighborExchange_attr_check_recv_type_numpy_failed(): | |||
| """ | |||
| Feature: NeighborExchange | |||
| Description: recv_type should be mindspore type, but a numpy type is given | |||
| @@ -372,7 +414,7 @@ def test_NeighborExchage_attr_check_recv_type_numpy_failed(): | |||
| _executor.compile(net, _x1) | |||
| def test_NeighborExchage_attr_invalid_grpup_failed(): | |||
| def test_NeighborExchange_attr_invalid_grpup_failed(): | |||
| """ | |||
| Feature: NeighborExchange | |||
| Description: group should be str, but a tuple is given | |||