|
|
|
@@ -25,6 +25,7 @@ |
|
|
|
#include <set> |
|
|
|
#include <string> |
|
|
|
#include <utility> |
|
|
|
#include <cmath> |
|
|
|
|
|
|
|
#include "utils/hash_map.h" |
|
|
|
#include "base/core_ops.h" |
|
|
|
@@ -563,6 +564,447 @@ static bool IsOriginWeight(const ParameterPtr &param) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { |
|
|
|
if (IsValueNode<RefKey>(node)) { |
|
|
|
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph); |
|
|
|
if (param_v.size() != 1) { |
|
|
|
MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is " |
|
|
|
<< param_v.size(); |
|
|
|
} |
|
|
|
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>(); |
|
|
|
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) { |
|
|
|
return std::make_pair(nullptr, true); |
|
|
|
} |
|
|
|
return std::make_pair(node, true); |
|
|
|
} |
|
|
|
return std::make_pair(nullptr, false); |
|
|
|
} |
|
|
|
|
|
|
|
static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node) { |
|
|
|
auto param_ptr = node->user_data<parallel::TensorLayout>(); |
|
|
|
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) { |
|
|
|
return std::make_pair(nullptr, false); |
|
|
|
} |
|
|
|
return std::make_pair(node, false); |
|
|
|
} |
|
|
|
|
|
|
|
// Only used for InsertMirrorOps |
|
|
|
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { |
|
|
|
if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) { |
|
|
|
return std::make_pair(nullptr, false); |
|
|
|
} |
|
|
|
|
|
|
|
if (node->isa<Parameter>()) { |
|
|
|
return FindParameterByParameter(node); |
|
|
|
} |
|
|
|
|
|
|
|
if (node->isa<ValueNode>()) { |
|
|
|
return FindParameterByValueNode(node, func_graph); |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (!IsValueNode<Primitive>(cnode->input(0))) { |
|
|
|
for (size_t index = 0; index < cnode->inputs().size(); ++index) { |
|
|
|
auto res = FindParameter(cnode->input(index), func_graph); |
|
|
|
if (!res.first) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
return res; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// When not fully use opt shard, allgather and mirror would be both inserted. |
|
|
|
// Skip allgather here and find parameter recursively. |
|
|
|
if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) { |
|
|
|
return std::make_pair(nullptr, false); |
|
|
|
} |
|
|
|
|
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(prim_anf_node); |
|
|
|
for (size_t index = 0; index < cnode->inputs().size(); ++index) { |
|
|
|
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
if ((prim->name() == DEPEND || prim->name() == LOAD || IsInAllGatherNodeList(cnode)) && index != 1) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto res = FindParameter(cnode->input(index), func_graph); |
|
|
|
if (!res.first) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
return res; |
|
|
|
} |
|
|
|
return std::make_pair(nullptr, false); |
|
|
|
} |
|
|
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root) { |
|
|
|
MS_EXCEPTION_IF_NULL(root); |
|
|
|
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> adasum_param_map; |
|
|
|
for (auto ¶meter_node : root->parameters()) { |
|
|
|
MS_EXCEPTION_IF_NULL(parameter_node); |
|
|
|
auto cloned_parameter = parameter_node->cast<ParameterPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cloned_parameter); |
|
|
|
|
|
|
|
if (!ParameterIsCloned(parameter_node)) { |
|
|
|
auto parameter_tensor_layout = cloned_parameter->user_data<TensorLayout>(); |
|
|
|
adasum_param_map["adasum_delta_weight." + cloned_parameter->name()] = parameter_tensor_layout; |
|
|
|
} |
|
|
|
} |
|
|
|
return adasum_param_map; |
|
|
|
} |
|
|
|
|
|
|
|
// Converts a value sequence to a vector of int64_t, divides element i by scale[i], and additionally multiplies
// element 0 by `expand_ratio`.
// Raises an exception when `value_seq` is null, is not a ValueSequeue, cannot be converted, has a different length
// than `scale`, or when `scale` contains a zero (which would otherwise be an integer division by zero).
Shape ValueSequeueScaleToShape(const ValuePtr &value_seq, const Shape &scale, size_t expand_ratio = 1) {
  // Callers pass GetValueNode(...) / GetAttr(...) results, both of which can be null.
  MS_EXCEPTION_IF_NULL(value_seq);
  if (!value_seq->isa<ValueSequeue>()) {
    MS_LOG(EXCEPTION) << "The input is not a value_sequeue";
  }
  std::vector<int64_t> scaled_values;
  if (TransValueSequeueToVector(value_seq, &scaled_values) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Transform value_seq to vector failed";
  }
  if (scaled_values.size() != scale.size()) {
    MS_LOG(EXCEPTION) << "Shape not equal, cannot scale, value_seq size is: " << scaled_values.size()
                      << " scale size is: " << scale.size();
  }
  for (size_t i = 0; i < scale.size(); ++i) {
    if (scale[i] == 0) {
      MS_LOG(EXCEPTION) << "Cannot scale value_seq by zero, the zero scale index is: " << i;
    }
    scaled_values[i] /= scale[i];
  }
  if (!scaled_values.empty()) {
    // Only the first dimension is expanded; keep the multiplication in signed arithmetic.
    scaled_values[0] *= static_cast<int64_t>(expand_ratio);
  }
  return scaled_values;
}
|
|
|
|
|
|
|
// Scales `value_seq` with ValueSequeueScaleToShape and wraps the result back into a value sequence:
// a ValueTuple when the input is a ValueTuple, otherwise a ValueList.
ValuePtr ValueSequeueScale(const ValuePtr &value_seq, const Shape &scale, size_t expand_ratio = 1) {
  Shape scaled = ValueSequeueScaleToShape(value_seq, scale, expand_ratio);
  if (!value_seq->isa<ValueTuple>()) {
    return TransVectorToValueSequeue<ValueList>(scaled);
  }
  return TransVectorToValueSequeue<ValueTuple>(scaled);
}
|
|
|
|
|
|
|
void ReplaceAdaSumStridedSliceValue(const CNodePtr &stridedslice_cnode1, |
|
|
|
const std::shared_ptr<TensorLayout> &target_param_layout, |
|
|
|
size_t slice_expand_ratio) { |
|
|
|
auto target_param_info = std::make_shared<TensorInfo>(target_param_layout->SqueezeShape()); |
|
|
|
Dimensions param_strategy = target_param_info->InferStrategy(); |
|
|
|
auto new_begin1_value = |
|
|
|
ValueSequeueScale(GetValueNode(stridedslice_cnode1->input(2)), param_strategy, slice_expand_ratio); |
|
|
|
auto new_end1_value = |
|
|
|
ValueSequeueScale(GetValueNode(stridedslice_cnode1->input(3)), param_strategy, slice_expand_ratio); |
|
|
|
ValueNodePtr new_begin_value_node = std::make_shared<ValueNode>(new_begin1_value); |
|
|
|
ValueNodePtr new_end_value_node = std::make_shared<ValueNode>(new_end1_value); |
|
|
|
stridedslice_cnode1->set_input(2, new_begin_value_node); |
|
|
|
stridedslice_cnode1->set_input(3, new_end_value_node); |
|
|
|
} |
|
|
|
|
|
|
|
// Returns the rank list obtained from DeviceMatrix::GetDevicesByTensorMap for the layout's tensor map, using a
// device matrix built from the current global rank, the device list of this stage and the layout's device
// arrangement.
// Raises an exception when the layout or the device manager is null, or when the lookup fails.
RankList GetRankListByLayout(const std::shared_ptr<TensorLayout> &target_param_layout) {
  // Callers fetch the layout with map[key], which yields a null pointer for an unknown parameter name.
  MS_EXCEPTION_IF_NULL(target_param_layout);
  MS_EXCEPTION_IF_NULL(g_device_manager);
  int64_t rank = g_device_manager->global_rank();
  auto dev_shape = target_param_layout->device_arrangement().array();
  auto stage_device_list = g_device_manager->GetDeviceListInThisStage();
  DeviceMatrix dev_matrix(rank, stage_device_list, dev_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(target_param_layout->tensor_map().array(), &group_devices) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Get adasum parameter origin mirror group by tensor layout failed.";
  }
  return group_devices;
}
|
|
|
|
|
|
|
std::vector<bool> IsBorderAdaSumSendReceive(const AnfNodePtr &node, const RankList &group_devices) { |
|
|
|
bool is_send = IsPrimitiveCNode(node, prim::kPrimSend); |
|
|
|
PrimitivePtr send_rec_prim = GetCNodePrimitive(node); |
|
|
|
int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr("opposite_rank")); |
|
|
|
int64_t rank = g_device_manager->global_rank(); |
|
|
|
int64_t adasum_rank_distance = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1); |
|
|
|
if (adasum_rank_distance < ADASUM_MIN_DIS) { |
|
|
|
adasum_rank_distance = ADASUM_MIN_DIS; |
|
|
|
} |
|
|
|
size_t border_step = size_t(log2(adasum_rank_distance / ADASUM_MIN_DIS)); |
|
|
|
int64_t fusion_id = GetValue<int64_t>(send_rec_prim->GetAttr("origin_fusion")); |
|
|
|
// when cuting nodes, the fusion id should change. |
|
|
|
int64_t new_fusion_id = fusion_id + g_device_manager->DeviceNum() * (border_step + 1); |
|
|
|
send_rec_prim->set_attr(FUSION, MakeValue(new_fusion_id)); |
|
|
|
std::vector<int64_t> group_list; |
|
|
|
int64_t new_dest_src_rank; |
|
|
|
if (rank > origin_dest_rank) { |
|
|
|
group_list = {origin_dest_rank, rank}; |
|
|
|
new_dest_src_rank = 0; |
|
|
|
} else { |
|
|
|
group_list = {rank, origin_dest_rank}; |
|
|
|
new_dest_src_rank = 1; |
|
|
|
} |
|
|
|
Group adasum_send_rec_group = g_device_manager->CreateGroup(group_list); |
|
|
|
send_rec_prim->set_attr(GROUP, MakeValue(adasum_send_rec_group.name())); |
|
|
|
if (is_send) { |
|
|
|
send_rec_prim->set_attr(DEST_RANK, MakeValue(new_dest_src_rank)); |
|
|
|
} else { |
|
|
|
send_rec_prim->set_attr(SRC_RANK, MakeValue(new_dest_src_rank)); |
|
|
|
} |
|
|
|
int64_t rank_dis = abs(origin_dest_rank - rank); |
|
|
|
if (adasum_rank_distance == ADASUM_MIN_DIS) { |
|
|
|
return {false, false, false, false}; |
|
|
|
} |
|
|
|
bool is_origin_first_node_if_forward = false; |
|
|
|
bool is_new_first_node_if_forward = false; |
|
|
|
bool is_origin_last_node_if_rollback = false; |
|
|
|
bool is_new_last_node_if_rollback = false; |
|
|
|
if (rank_dis == ADASUM_MIN_DIS) { |
|
|
|
is_origin_first_node_if_forward = true; |
|
|
|
is_origin_last_node_if_rollback = true; |
|
|
|
} |
|
|
|
if (rank_dis == adasum_rank_distance) { |
|
|
|
is_new_first_node_if_forward = true; |
|
|
|
} |
|
|
|
if (rank_dis == adasum_rank_distance / 2) { |
|
|
|
is_new_last_node_if_rollback = true; |
|
|
|
} |
|
|
|
return {is_origin_first_node_if_forward, is_new_first_node_if_forward, is_origin_last_node_if_rollback, |
|
|
|
is_new_last_node_if_rollback}; |
|
|
|
} |
|
|
|
|
|
|
|
void HandleAdaSumReshape(const CNodePtr &reshape_cnode, const std::shared_ptr<TensorLayout> &target_param_layout) { |
|
|
|
auto slice_shape = target_param_layout->slice_shape().array(); |
|
|
|
auto slice_shape_value = TransVectorToValueSequeue<ValueTuple>(slice_shape); |
|
|
|
ValueNodePtr new_slice_shape_value_node = std::make_shared<ValueNode>(slice_shape_value); |
|
|
|
reshape_cnode->set_input(2, new_slice_shape_value_node); |
|
|
|
} |
|
|
|
|
|
|
|
void RemoveAdasumRedundantNodes(const FuncGraphManagerPtr &manager, |
|
|
|
std::unordered_map<std::string, CNodePtr> *forward_origin_first_node_map, |
|
|
|
std::unordered_map<std::string, CNodePtr> *forward_new_first_node_map, |
|
|
|
std::unordered_map<std::string, CNodePtr> *rollback_origin_last_node_map, |
|
|
|
std::unordered_map<std::string, CNodePtr> *rollback_new_last_node_map) { |
|
|
|
// connect forward last node and rollback first node |
|
|
|
if (forward_origin_first_node_map->size() != forward_new_first_node_map->size() || |
|
|
|
rollback_origin_last_node_map->size() != rollback_new_last_node_map->size()) { |
|
|
|
MS_LOG(EXCEPTION) << "The over border node is not equal in adasum forward process and rollback process."; |
|
|
|
} |
|
|
|
for (auto node : *forward_origin_first_node_map) { |
|
|
|
std::string target_param = node.first; |
|
|
|
CNodePtr forward_origin_first_node = node.second; |
|
|
|
CNodePtr forward_new_first_node = (*forward_new_first_node_map)[target_param]; |
|
|
|
manager->SetEdge(forward_new_first_node, 1, forward_origin_first_node->input(1)); |
|
|
|
} |
|
|
|
for (auto node : *rollback_origin_last_node_map) { |
|
|
|
std::string target_param = node.first; |
|
|
|
CNodePtr rollback_origin_last_node = node.second; |
|
|
|
CNodePtr rollback_new_last_node = (*rollback_new_last_node_map)[target_param]; |
|
|
|
manager->Replace(rollback_origin_last_node, rollback_new_last_node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void HandleAdasumAllReduce(const PrimitivePtr &prim, const RankList &group_devices) { |
|
|
|
size_t step = size_t(GetValue<int64_t>(prim->GetAttr("step"))); |
|
|
|
std::vector<int64_t> neighbor_ids; |
|
|
|
int64_t adasum_rank_distance = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1); |
|
|
|
if (adasum_rank_distance < ADASUM_MIN_DIS) { |
|
|
|
adasum_rank_distance = ADASUM_MIN_DIS; |
|
|
|
} |
|
|
|
size_t border_step = size_t(log2(adasum_rank_distance / ADASUM_MIN_DIS)); |
|
|
|
MS_LOG(INFO) << "current border step is: " << border_step; |
|
|
|
if (step < border_step) { |
|
|
|
return; |
|
|
|
} |
|
|
|
int64_t rank = g_device_manager->global_rank(); |
|
|
|
size_t double_d = size_t(2 << step); |
|
|
|
for (size_t index = 0; index < double_d; ++index) { |
|
|
|
int64_t node_rank = rank / ADASUM_MIN_DIS; |
|
|
|
int64_t neighbor_id = (node_rank / double_d * double_d + index) * ADASUM_MIN_DIS + rank % ADASUM_MIN_DIS; |
|
|
|
neighbor_ids.push_back(neighbor_id); |
|
|
|
} |
|
|
|
Group adasum_allreduce_group = g_device_manager->CreateGroup(neighbor_ids); |
|
|
|
auto new_group_name = MakeValue(adasum_allreduce_group.name()); |
|
|
|
int64_t fusion_id = GetValue<int64_t>(prim->GetAttr("origin_fusion")); |
|
|
|
int64_t new_fusion_id = fusion_id + g_device_manager->DeviceNum() * (border_step + 1); |
|
|
|
prim->set_attr(GROUP, new_group_name); |
|
|
|
prim->set_attr(FUSION, MakeValue(new_fusion_id)); |
|
|
|
} |
|
|
|
|
|
|
|
void HandleAdasumSlice(const AnfNodePtr &stridedslice_node1, const std::shared_ptr<TensorLayout> &target_param_layout, |
|
|
|
const std::string &target_param, size_t slice_expand_ratio) { |
|
|
|
auto stridedslice_cnode1 = stridedslice_node1->cast<CNodePtr>(); |
|
|
|
ReplaceAdaSumStridedSliceValue(stridedslice_cnode1, target_param_layout, slice_expand_ratio); |
|
|
|
auto squeeze_node = RealInputNode(stridedslice_cnode1, 1); |
|
|
|
if (!IsPrimitiveCNode(squeeze_node, prim::kPrimSqueeze)) { |
|
|
|
MS_LOG(EXCEPTION) << "The stridedslice input node should be squeeze in adasum"; |
|
|
|
} |
|
|
|
auto squeeze_cnode = squeeze_node->cast<CNodePtr>(); |
|
|
|
FuncGraphManagerPtr manager = squeeze_node->func_graph()->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
AnfNodeIndexSet node_set = manager->node_users()[squeeze_cnode]; |
|
|
|
for (auto &node_pair : node_set) { |
|
|
|
if (IsPrimitiveCNode(node_pair.first, prim::kPrimStridedSlice) && node_pair.first != stridedslice_node1) { |
|
|
|
CNodePtr use_apply = node_pair.first->cast<CNodePtr>(); |
|
|
|
ReplaceAdaSumStridedSliceValue(use_apply, target_param_layout, slice_expand_ratio); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void HandleAdaSumConcat(const AnfNodePtr &concat_node, const std::vector<bool> &border_info, |
|
|
|
const std::string &target_param, |
|
|
|
std::unordered_map<std::string, CNodePtr> *rollback_new_last_node_map, |
|
|
|
std::unordered_map<std::string, CNodePtr> *rollback_origin_last_node_map) { |
|
|
|
if (border_info[3]) { |
|
|
|
(*rollback_new_last_node_map)[target_param] = concat_node->cast<CNodePtr>(); |
|
|
|
} |
|
|
|
if (border_info[2]) { |
|
|
|
auto manager = concat_node->func_graph()->manager(); |
|
|
|
AnfNodeIndexSet concat_node_user_set = manager->node_users()[concat_node]; |
|
|
|
for (auto &node_pair : concat_node_user_set) { |
|
|
|
if (IsPrimitiveCNode(node_pair.first, prim::kPrimMakeTuple)) { |
|
|
|
AnfNodeIndexSet make_tuple_node_user_set = manager->node_users()[node_pair.first]; |
|
|
|
for (auto &tuple_user : make_tuple_node_user_set) { |
|
|
|
if (IsPrimitiveCNode(tuple_user.first, prim::kPrimConcat)) { |
|
|
|
(*rollback_origin_last_node_map)[target_param] = tuple_user.first->cast<CNodePtr>(); |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void HandleAdaSumSqueeze(const AnfNodePtr &stridedslice_node1, const std::vector<bool> &border_info, |
|
|
|
const std::string &target_param, |
|
|
|
std::unordered_map<std::string, CNodePtr> *forward_origin_first_node_map, |
|
|
|
std::unordered_map<std::string, CNodePtr> *forward_new_first_node_map) { |
|
|
|
auto squeeze_node = RealInputNode(stridedslice_node1->cast<CNodePtr>(), 1); |
|
|
|
if (border_info[0]) { |
|
|
|
(*forward_origin_first_node_map)[target_param] = squeeze_node->cast<CNodePtr>(); |
|
|
|
} |
|
|
|
if (border_info[1]) { |
|
|
|
(*forward_new_first_node_map)[target_param] = squeeze_node->cast<CNodePtr>(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool HandleAdaSum(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, |
|
|
|
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map) { |
|
|
|
std::unordered_map<std::string, CNodePtr> forward_origin_first_node_map; |
|
|
|
std::unordered_map<std::string, CNodePtr> forward_new_first_node_map; |
|
|
|
std::unordered_map<std::string, CNodePtr> rollback_origin_last_node_map; |
|
|
|
std::unordered_map<std::string, CNodePtr> rollback_new_last_node_map; |
|
|
|
bool is_adasum = false; |
|
|
|
for (auto &node : all_nodes) { |
|
|
|
bool is_allreduce = IsPrimitiveCNode(node, prim::kPrimAllReduce); |
|
|
|
bool is_reshape = IsPrimitiveCNode(node, prim::kPrimReshape); |
|
|
|
bool is_send = IsPrimitiveCNode(node, prim::kPrimSend); |
|
|
|
bool is_receive = IsPrimitiveCNode(node, prim::kPrimReceive); |
|
|
|
if (!is_allreduce && !is_reshape && !is_send && !is_receive) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
std::string target_param; |
|
|
|
CNodePtr cnode = node->cast<CNodePtr>(); |
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)->cast<ValueNodePtr>()); |
|
|
|
if (!prim->HasAttr("target_param")) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
target_param = GetValue<std::string>(prim->GetAttr("target_param")); |
|
|
|
auto target_param_layout = (*adasum_param_tensor_layout_map)[target_param]; |
|
|
|
RankList group_devices = GetRankListByLayout(target_param_layout); |
|
|
|
int64_t adasum_rank_distance = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1); |
|
|
|
// when the repeat dim is right, the parameter do not enable adasum. |
|
|
|
if (adasum_rank_distance == 1 && group_devices.size() < size_t(g_device_manager->stage_device_num())) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Apply adasum in auto parallel, current dealing node is: " << node->fullname_with_scope(); |
|
|
|
is_adasum = true; |
|
|
|
size_t slice_expand_ratio = adasum_rank_distance / ADASUM_MIN_DIS > 0 ? adasum_rank_distance / ADASUM_MIN_DIS : 1; |
|
|
|
if (is_reshape) { |
|
|
|
HandleAdaSumReshape(cnode, (*adasum_param_tensor_layout_map)[target_param]); |
|
|
|
} |
|
|
|
if (is_allreduce && prim->HasAttr("step")) { |
|
|
|
HandleAdasumAllReduce(prim, group_devices); |
|
|
|
} |
|
|
|
if (is_send || is_receive) { |
|
|
|
std::vector<bool> border_info = IsBorderAdaSumSendReceive(node, group_devices); |
|
|
|
if (is_receive) { |
|
|
|
auto target_param_info = std::make_shared<TensorInfo>(*target_param_layout); |
|
|
|
Dimensions param_strategy = target_param_info->InferStrategy(); |
|
|
|
Shape new_rec_shape = ValueSequeueScaleToShape(prim->GetAttr(SHAPE), param_strategy, slice_expand_ratio); |
|
|
|
auto new_rec_shape_value = TransVectorToValueSequeue<ValueList>(new_rec_shape); |
|
|
|
prim->set_attr(SHAPE, new_rec_shape_value); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto stridedslice_node1 = RealInputNode(cnode, 1); |
|
|
|
if (IsPrimitiveCNode(stridedslice_node1, prim::kPrimConcat)) { |
|
|
|
HandleAdaSumConcat(stridedslice_node1, border_info, target_param, &rollback_new_last_node_map, |
|
|
|
&rollback_origin_last_node_map); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (!IsPrimitiveCNode(stridedslice_node1, prim::kPrimStridedSlice)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
HandleAdasumSlice(stridedslice_node1, target_param_layout, target_param, slice_expand_ratio); |
|
|
|
HandleAdaSumSqueeze(stridedslice_node1, border_info, target_param, &forward_origin_first_node_map, |
|
|
|
&forward_new_first_node_map); |
|
|
|
} |
|
|
|
} |
|
|
|
RemoveAdasumRedundantNodes(root->manager(), &forward_origin_first_node_map, &forward_new_first_node_map, |
|
|
|
&rollback_origin_last_node_map, &rollback_new_last_node_map); |
|
|
|
return is_adasum; |
|
|
|
} |
|
|
|
|
|
|
|
void ResetMirrorAttr(const PrimitivePtr &prim, const RankList &new_group) { |
|
|
|
if (new_group.size() == 1) { |
|
|
|
prim->set_attr(DEV_NUM, MakeValue(new_group.size())); |
|
|
|
prim->set_attr(GROUP, MakeValue("one_rank_group")); |
|
|
|
prim->set_attr(GROUP_RANKS, MakeValue(std::to_string(new_group[0]))); |
|
|
|
return; |
|
|
|
} |
|
|
|
Group adasum_mirror_group = g_device_manager->CreateGroup(new_group); |
|
|
|
auto new_group_name = MakeValue(adasum_mirror_group.name()); |
|
|
|
prim->set_attr(GROUP, new_group_name); |
|
|
|
prim->set_attr(DEV_NUM, MakeValue(new_group.size())); |
|
|
|
std::string rank_list_name = g_device_manager->FindRankListNameByHashName(adasum_mirror_group.name()); |
|
|
|
prim->set_attr(GROUP_RANKS, MakeValue(rank_list_name)); |
|
|
|
} |
|
|
|
|
|
|
|
void HandleMirrorInAdaSum( |
|
|
|
const FuncGraphPtr &root, |
|
|
|
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map) { |
|
|
|
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(root->get_return()); |
|
|
|
for (auto &node : all_nodes) { |
|
|
|
if (!IsPrimitiveCNode(node, prim::kPrimMirror)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
CNodePtr mirror_cnode = node->cast<CNodePtr>(); |
|
|
|
auto param_node_pair = FindParameter(mirror_cnode->input(1), node->func_graph()); |
|
|
|
if (!param_node_pair.first) { |
|
|
|
MS_LOG(EXCEPTION) << "Mirror input is not a param"; |
|
|
|
} |
|
|
|
auto param_ptr = param_node_pair.first->cast<ParameterPtr>(); |
|
|
|
std::string param_name = param_ptr->name(); |
|
|
|
MS_LOG(INFO) << "Mirror param name is: " << param_name; |
|
|
|
std::string target_param = "adasum_delta_weight." + param_name; |
|
|
|
auto target_param_layout = (*adasum_param_tensor_layout_map)[target_param]; |
|
|
|
|
|
|
|
// Change mirror group |
|
|
|
RankList group_devices = GetRankListByLayout(target_param_layout); |
|
|
|
int64_t rank = g_device_manager->global_rank(); |
|
|
|
size_t group_dis = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1); |
|
|
|
auto prim = GetCNodePrimitive(node); |
|
|
|
if (group_dis < ADASUM_MIN_DIS) { |
|
|
|
size_t new_group_size = size_t(ADASUM_MIN_DIS) / group_dis; |
|
|
|
// compute new group range |
|
|
|
size_t group_begin = 0; |
|
|
|
for (size_t group_end = new_group_size; group_end < group_devices.size() + new_group_size; |
|
|
|
group_end += new_group_size) { |
|
|
|
int64_t max_group_value = |
|
|
|
group_end >= group_devices.size() ? (group_devices.back() + 1) : group_devices[group_end]; |
|
|
|
if (group_devices[group_begin] <= rank && rank < max_group_value) { |
|
|
|
std::vector<int64_t> new_group(group_devices.begin() + group_begin, group_devices.begin() + group_end); |
|
|
|
MS_LOG(INFO) << "Find new mirror group in adasum: " << new_group << " target_param:" << target_param; |
|
|
|
ResetMirrorAttr(prim, new_group); |
|
|
|
break; |
|
|
|
} |
|
|
|
group_begin = group_end; |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
ResetMirrorAttr(prim, {rank}); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void HandleAdaFactorOpt(const FuncGraphPtr &root) { |
|
|
|
MS_EXCEPTION_IF_NULL(root); |
|
|
|
for (auto ¶m_node : root->parameters()) { |
|
|
|
|