|
|
|
@@ -51,7 +51,44 @@ static ge::DataType ConvertHcclDTypeToGeDType(HcclDataType datatype) { |
|
|
|
return iter->second; |
|
|
|
} |
|
|
|
|
|
|
|
template <class T> |
|
|
|
struct IsString { |
|
|
|
// cppcheck-suppress unusedStructMember |
|
|
|
static constexpr bool value = false; |
|
|
|
}; |
|
|
|
|
|
|
|
template <> |
|
|
|
struct IsString<std::string> { |
|
|
|
// cppcheck-suppress unusedStructMember |
|
|
|
static constexpr bool value = true; |
|
|
|
}; |
|
|
|
|
|
|
|
namespace mindspore::hccl { |
|
|
|
template <class T> |
|
|
|
static T ConvertAttr(const CNodePtr &cnode, const ge::OpDescPtr &ge_op, const std::string &anf_attr_name, |
|
|
|
const std::string &ge_attr_name) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(ge_op); |
|
|
|
if (!AnfAlgo::HasNodeAttr(anf_attr_name, cnode)) { |
|
|
|
MS_LOG(INFO) << "Node " << cnode->DebugString() << " has no attr " << anf_attr_name << ", skip."; |
|
|
|
return T(); |
|
|
|
} |
|
|
|
|
|
|
|
bool ret; |
|
|
|
auto attr = AnfAlgo::GetNodeAttr<T>(cnode, anf_attr_name); |
|
|
|
if constexpr (IsString<T>::value) { |
|
|
|
ret = ge::AttrUtils::SetStr(*ge_op, ge_attr_name, attr); |
|
|
|
} else { |
|
|
|
ret = ge::AttrUtils::SetInt(*ge_op, ge_attr_name, attr); |
|
|
|
} |
|
|
|
|
|
|
|
if (!ret) { |
|
|
|
MS_LOG(EXCEPTION) << "Set attr " << ge_attr_name << " for ge node of " << cnode->DebugString() << " failed."; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Convert success, attr " << ge_attr_name << " is " << attr; |
|
|
|
return attr; |
|
|
|
} |
|
|
|
|
|
|
|
std::string GetGeNodeName(const CNodePtr &cnode) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (IsPrimitiveCNode(cnode, prim::kPrimAllReduce)) { |
|
|
|
@@ -93,15 +130,9 @@ std::tuple<ge::NodePtr, ge::ComputeGraphPtr> GenerateStubGeNode(const AnfNodePtr |
|
|
|
<< " failed."; |
|
|
|
} |
|
|
|
|
|
|
|
// set rank size |
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) { |
|
|
|
auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize); |
|
|
|
ret = ge::AttrUtils::SetInt(*op_desc, ge::HCOM_ATTR_RANK_SIZE, rank_size); |
|
|
|
if (!ret) { |
|
|
|
MS_LOG(EXCEPTION) << "Set attr " << ge::HCOM_ATTR_RANK_SIZE << " for ge node of " << cnode->DebugString() |
|
|
|
<< " failed."; |
|
|
|
} |
|
|
|
} |
|
|
|
// set node attr |
|
|
|
(void)ConvertAttr<int64_t>(cnode, op_desc, kAttrRankSize, ge::HCOM_ATTR_RANK_SIZE); |
|
|
|
(void)ConvertAttr<std::string>(cnode, op_desc, kAttrGroup, ge::HCOM_ATTR_GROUP); |
|
|
|
|
|
|
|
ge::ComputeGraphPtr ge_graph = std::make_shared<ge::ComputeGraph>(kStubDataStructureName); |
|
|
|
MS_EXCEPTION_IF_NULL(ge_graph); |
|
|
|
|