|
|
|
@@ -37,20 +37,6 @@ inline int64_t NormalizeDim(const std::vector<size_t> &shape, int64_t dim) { |
|
|
|
return dim < 0 ? SizeToLong(shape.size()) + dim : dim; |
|
|
|
} |
|
|
|
|
|
|
|
uint32_t GetRankSize(const std::string &group) { |
|
|
|
uint32_t rank_size; |
|
|
|
const void *collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); |
|
|
|
MS_EXCEPTION_IF_NULL(collective_handle_); |
|
|
|
|
|
|
|
// Get group size |
|
|
|
auto get_group_size_funcptr = |
|
|
|
reinterpret_cast<GetGroupRanks>(dlsym(const_cast<void *>(collective_handle_), "GetGroupRanks")); |
|
|
|
MS_EXCEPTION_IF_NULL(get_group_size_funcptr); |
|
|
|
std::vector<int> group_ranks = (*get_group_size_funcptr)(group); |
|
|
|
rank_size = group_ranks.size(); |
|
|
|
return rank_size; |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(all_to_all); |
|
|
|
@@ -113,7 +99,7 @@ CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_a |
|
|
|
std::vector<TypeId> dtypes(split_count, single_type); |
|
|
|
std::vector<std::vector<size_t>> shapes(split_count, single_shape); |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, all_to_all_v.get()); |
|
|
|
uint32_t rank_size = GetRankSize(group); |
|
|
|
uint32_t rank_size = device::gpu::CollectiveInitializer::instance().GetGroupSize(group); |
|
|
|
std::vector<int64_t> rank_ids(rank_size, 0); |
|
|
|
for (uint32_t i = 0; i < rank_size; ++i) { |
|
|
|
rank_ids[i] = static_cast<int64_t>(i); |
|
|
|
|