diff --git a/mindspore/ccsrc/runtime/hccl_adapter/converter.cc b/mindspore/ccsrc/runtime/hccl_adapter/converter.cc index 9ecefcd5d5..432093a641 100644 --- a/mindspore/ccsrc/runtime/hccl_adapter/converter.cc +++ b/mindspore/ccsrc/runtime/hccl_adapter/converter.cc @@ -51,7 +51,44 @@ static ge::DataType ConvertHcclDTypeToGeDType(HcclDataType datatype) { return iter->second; } +template +struct IsString { + // cppcheck-suppress unusedStructMember + static constexpr bool value = false; +}; + +template <> +struct IsString { + // cppcheck-suppress unusedStructMember + static constexpr bool value = true; +}; + namespace mindspore::hccl { +template +static T ConvertAttr(const CNodePtr &cnode, const ge::OpDescPtr &ge_op, const std::string &anf_attr_name, + const std::string &ge_attr_name) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(ge_op); + if (!AnfAlgo::HasNodeAttr(anf_attr_name, cnode)) { + MS_LOG(INFO) << "Node " << cnode->DebugString() << " has no attr " << anf_attr_name << ", skip."; + return T(); + } + + bool ret; + auto attr = AnfAlgo::GetNodeAttr(cnode, anf_attr_name); + if constexpr (IsString::value) { + ret = ge::AttrUtils::SetStr(*ge_op, ge_attr_name, attr); + } else { + ret = ge::AttrUtils::SetInt(*ge_op, ge_attr_name, attr); + } + + if (!ret) { + MS_LOG(EXCEPTION) << "Set attr " << ge_attr_name << " for ge node of " << cnode->DebugString() << " failed."; + } + MS_LOG(INFO) << "Convert success, attr " << ge_attr_name << " is " << attr; + return attr; +} + std::string GetGeNodeName(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); if (IsPrimitiveCNode(cnode, prim::kPrimAllReduce)) { @@ -93,15 +130,9 @@ std::tuple GenerateStubGeNode(const AnfNodePtr << " failed."; } - // set rank size - if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) { - auto rank_size = AnfAlgo::GetNodeAttr(cnode, kAttrRankSize); - ret = ge::AttrUtils::SetInt(*op_desc, ge::HCOM_ATTR_RANK_SIZE, rank_size); - if (!ret) { - MS_LOG(EXCEPTION) << "Set attr " << ge::HCOM_ATTR_RANK_SIZE << " for ge node of " << cnode->DebugString() - << " failed."; - } - } + // set node attr + (void)ConvertAttr(cnode, op_desc, kAttrRankSize, ge::HCOM_ATTR_RANK_SIZE); + (void)ConvertAttr(cnode, op_desc, kAttrGroup, ge::HCOM_ATTR_GROUP); ge::ComputeGraphPtr ge_graph = std::make_shared(kStubDataStructureName); MS_EXCEPTION_IF_NULL(ge_graph);