Browse Source

fix attr of hccl stub node

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
tags/v1.1.0
zhoufeng 5 years ago
parent
commit
5f25f34f46
1 changed files with 40 additions and 9 deletions
  1. +40
    -9
      mindspore/ccsrc/runtime/hccl_adapter/converter.cc

+ 40
- 9
mindspore/ccsrc/runtime/hccl_adapter/converter.cc View File

@@ -51,7 +51,44 @@ static ge::DataType ConvertHcclDTypeToGeDType(HcclDataType datatype) {
return iter->second; return iter->second;
} }


template <class T>
struct IsString {
// cppcheck-suppress unusedStructMember
static constexpr bool value = false;
};

template <>
struct IsString<std::string> {
// cppcheck-suppress unusedStructMember
static constexpr bool value = true;
};

namespace mindspore::hccl { namespace mindspore::hccl {
template <class T>
static T ConvertAttr(const CNodePtr &cnode, const ge::OpDescPtr &ge_op, const std::string &anf_attr_name,
const std::string &ge_attr_name) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(ge_op);
if (!AnfAlgo::HasNodeAttr(anf_attr_name, cnode)) {
MS_LOG(INFO) << "Node " << cnode->DebugString() << " has no attr " << anf_attr_name << ", skip.";
return T();
}

bool ret;
auto attr = AnfAlgo::GetNodeAttr<T>(cnode, anf_attr_name);
if constexpr (IsString<T>::value) {
ret = ge::AttrUtils::SetStr(*ge_op, ge_attr_name, attr);
} else {
ret = ge::AttrUtils::SetInt(*ge_op, ge_attr_name, attr);
}

if (!ret) {
MS_LOG(EXCEPTION) << "Set attr " << ge_attr_name << " for ge node of " << cnode->DebugString() << " failed.";
}
MS_LOG(INFO) << "Convert success, attr " << ge_attr_name << " is " << attr;
return attr;
}

std::string GetGeNodeName(const CNodePtr &cnode) { std::string GetGeNodeName(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (IsPrimitiveCNode(cnode, prim::kPrimAllReduce)) { if (IsPrimitiveCNode(cnode, prim::kPrimAllReduce)) {
@@ -93,15 +130,9 @@ std::tuple<ge::NodePtr, ge::ComputeGraphPtr> GenerateStubGeNode(const AnfNodePtr
<< " failed."; << " failed.";
} }


// set rank size
if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) {
auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
ret = ge::AttrUtils::SetInt(*op_desc, ge::HCOM_ATTR_RANK_SIZE, rank_size);
if (!ret) {
MS_LOG(EXCEPTION) << "Set attr " << ge::HCOM_ATTR_RANK_SIZE << " for ge node of " << cnode->DebugString()
<< " failed.";
}
}
// set node attr
(void)ConvertAttr<int64_t>(cnode, op_desc, kAttrRankSize, ge::HCOM_ATTR_RANK_SIZE);
(void)ConvertAttr<std::string>(cnode, op_desc, kAttrGroup, ge::HCOM_ATTR_GROUP);


ge::ComputeGraphPtr ge_graph = std::make_shared<ge::ComputeGraph>(kStubDataStructureName); ge::ComputeGraphPtr ge_graph = std::make_shared<ge::ComputeGraph>(kStubDataStructureName);
MS_EXCEPTION_IF_NULL(ge_graph); MS_EXCEPTION_IF_NULL(ge_graph);


Loading…
Cancel
Save