
!30157 auto_parallel_adasum_cpp_part

Merge pull request !30157 from yao_yf/auto_parallel_adasum_cpp_part
i-robot (Gitee) committed 4 years ago
commit dd593d2b87
14 changed files with 554 additions and 140 deletions
  1. mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc (+9 -1)
  2. mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc (+1 -0)
  3. mindspore/ccsrc/frontend/parallel/device_manager.cc (+11 -7)
  4. mindspore/ccsrc/frontend/parallel/device_manager.h (+0 -1)
  5. mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc (+48 -0)
  6. mindspore/ccsrc/frontend/parallel/graph_util/node_info.h (+17 -0)
  7. mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h (+1 -0)
  8. mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc (+4 -25)
  9. mindspore/ccsrc/frontend/parallel/parameter_manager.cc (+442 -0)
  10. mindspore/ccsrc/frontend/parallel/parameter_manager.h (+9 -0)
  11. mindspore/ccsrc/frontend/parallel/step_parallel.cc (+12 -77)
  12. mindspore/ccsrc/frontend/parallel/step_parallel.h (+0 -2)
  13. mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc (+0 -26)
  14. mindspore/ccsrc/frontend/parallel/step_parallel_utils.h (+0 -1)

mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc (+9 -1)

@@ -28,7 +28,15 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
PConstant const_(node);
PConstant const_2(node);
PConstant any_const(node);

// If the node has the keep_alive attr, it will not be eliminated.
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim != nullptr && prim->HasAttr("keep_alive") && GetValue<bool>(prim->GetAttr("keep_alive"))) {
MS_LOG(INFO) << "keep node " << node->fullname_with_scope() << " alive";
return nullptr;
}
}
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
MATCH_REPLACE(node, x + zero_, x); // Add by zero
MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero
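
How this guard is used downstream: a pass that deliberately inserts arithmetic that must survive simplification (AdaSum's rewrites are one such case) can tag the primitive with the attribute checked above. A minimal sketch, assuming cnode is the CNode to protect (only the attribute name "keep_alive" is taken from the code above; the surrounding context is hypothetical):

auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim != nullptr) {
  prim->set_attr("keep_alive", MakeValue(true));  // the simplifier above now returns nullptr and keeps the node
}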


mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc (+1 -0)

@@ -25,6 +25,7 @@
#include "frontend/parallel/costmodel_context.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/step_parallel.h"
#include "utils/log_adapter.h"



mindspore/ccsrc/frontend/parallel/device_manager.cc (+11 -7)

@@ -301,20 +301,24 @@ RankList DeviceManager::FindRankListByHashName(const std::string &hash_name) {

std::string HashName(const std::string &origin_name) { return std::to_string(std::hash<string>{}(origin_name)); }

-// Group name is generated using the increasing ranks of the devices.
-// E.g. the devices' ranks are '<0, 5, 3, 7, 1>', and the generated group name
-// is '0-1-3-5-7'.
-std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) {
+std::string RankListName(const RankList &ranks) {
std::string rank_list_name;
-std::vector<int64_t>::iterator it;
-std::sort(ranks.begin(), ranks.end());  // sorted in increasing order
-for (it = ranks.begin(); it != ranks.end(); ++it) {
+for (auto it = ranks.begin(); it != ranks.end(); ++it) {
if (it == ranks.begin()) {
rank_list_name = std::to_string(*it);
} else {
rank_list_name += "-" + std::to_string(*it);
}
}
return rank_list_name;
}

+// Group name is generated using the increasing ranks of the devices.
+// E.g. the devices' ranks are '<0, 5, 3, 7, 1>', and the generated group name
+// is '0-1-3-5-7'.
+std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) {
+std::sort(ranks.begin(), ranks.end());  // sorted in increasing order
+std::string rank_list_name = RankListName(ranks);

// hash rank-list-name and add ranks' size as prefix
std::string group_hash_name = HashName(rank_list_name);
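
To make the naming scheme concrete, here is a standalone sketch (plain std types, not the MindSpore RankList/HashName helpers) of the joining and hashing steps; the exact way the ranks' size is prefixed to the hash is an assumption:

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<int64_t> ranks = {0, 5, 3, 7, 1};
  std::sort(ranks.begin(), ranks.end());  // increasing order, as in GenerateGroupNameByRanks
  std::string rank_list_name;
  for (size_t i = 0; i < ranks.size(); ++i) {
    rank_list_name += (i == 0 ? "" : "-") + std::to_string(ranks[i]);
  }
  // rank_list_name == "0-1-3-5-7"
  std::string hash_name = std::to_string(std::hash<std::string>{}(rank_list_name));
  std::cout << std::to_string(ranks.size()) + "-" + hash_name << std::endl;  // size prefix + hash (format assumed)
  return 0;
}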


mindspore/ccsrc/frontend/parallel/device_manager.h (+0 -1)

@@ -69,7 +69,6 @@ class DeviceManager {

Device CreateNewDeviceByRank(int64_t rank) const;
std::vector<Device> CreateDeviceListByRankList(RankList ranks);

std::string GenerateGroupNameByRanks(RankList dev_ranks);
Group CreateGroup(const std::string &group_name, const std::vector<Device> &devices);
Group CreateGroup(const RankList &dev_ranks);


mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc (+48 -0)

@@ -418,5 +418,53 @@ void SetUserAttrs(const mindspore::HashMap<std::string, ValuePtr> &origin_prim_a
}
}
}

// Convert ValueTuple/ValueList to vector
Status TransValueSequeueToVector(const ValuePtr &input_value, std::vector<int64_t> *input) {
MS_EXCEPTION_IF_NULL(input_value);
if (!input_value->isa<ValueSequeue>()) {
MS_LOG(ERROR) << "Input value must be ValueTuplePtr.";
return FAILED;
}
ValueSequeuePtr value_seq = input_value->cast<ValueSequeuePtr>();
for (auto &element : value_seq->value()) {
MS_EXCEPTION_IF_NULL(element);
if (element->isa<Int64Imm>()) {
int64_t value = element->cast<Int64ImmPtr>()->value();
input->push_back(value);
} else {
MS_LOG(ERROR) << "The value must be int64";
return FAILED;
}
}
return SUCCESS;
}

// Get the input of cnode, skipping DEPEND/LOAD/UPDATESTATE
const AnfNodePtr RealInputNode(const CNodePtr cnode, size_t index) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() <= index) {
MS_LOG(EXCEPTION) << "cnode inputs size: " << cnode->size() << " is less equal index: " << index;
}
auto input0 = cnode->input(index);
if (!input0->isa<CNode>()) {
return input0;
}
auto prim = GetCNodePrimitive(input0);
MS_EXCEPTION_IF_NULL(prim);
while (prim->name() == LOAD || prim->name() == DEPEND || prim->name() == UPDATESTATE) {
if (prim->name() == LOAD || prim->name() == DEPEND) {
input0 = input0->cast<CNodePtr>()->input(1);
} else {
input0 = input0->cast<CNodePtr>()->input(2);
}
if (!input0->isa<CNode>()) {
return input0;
}
prim = GetCNodePrimitive(input0);
MS_EXCEPTION_IF_NULL(prim);
}
return input0;
}
} // namespace parallel
} // namespace mindspore
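
The wrapper-skipping loop above only ever hops through Load, Depend and UpdateState nodes. A toy standalone analog (hypothetical Node struct, not the ANF API) showing which input index each wrapper forwards:

#include <memory>
#include <string>
#include <vector>

// Stand-in for an ANF node: inputs[0] would hold the primitive itself,
// data arguments start at index 1.
struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> inputs;
};

std::shared_ptr<Node> RealInput(std::shared_ptr<Node> n) {
  while (n && (n->op == "Load" || n->op == "Depend" || n->op == "UpdateState")) {
    // Load and Depend forward their first argument; UpdateState forwards its second.
    size_t idx = (n->op == "UpdateState") ? 2 : 1;
    n = (idx < n->inputs.size()) ? n->inputs[idx] : nullptr;
  }
  return n;
}

int main() {
  auto param = std::make_shared<Node>(Node{"param", {}});
  auto load = std::make_shared<Node>(Node{"Load", {nullptr, param}});
  auto depend = std::make_shared<Node>(Node{"Depend", {nullptr, load}});
  return RealInput(depend)->op == "param" ? 0 : 1;  // unwraps Depend -> Load -> param
}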

mindspore/ccsrc/frontend/parallel/graph_util/node_info.h (+17 -0)

@@ -52,7 +52,24 @@ bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_op

bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index,
bool *is_next_reshape, size_t curr_depth);

void SetUserAttrs(const mindspore::HashMap<std::string, ValuePtr> &origin_prim_attrs, const PrimitivePtr &self_prim);

Status TransValueSequeueToVector(const ValuePtr &input_value, std::vector<int64_t> *input);

template <typename T>
std::shared_ptr<typename std::enable_if<std::is_base_of<ValueSequeue, T>::value, T>::type> TransVectorToValueSequeue(
const std::vector<int64_t> &input) {
std::vector<ValuePtr> elements;
for (auto dim : input) {
ValuePtr value_dim = MakeValue<int64_t>(dim);
elements.push_back(value_dim);
}
std::shared_ptr<T> seq_value = std::make_shared<T>(elements);
return seq_value;
}

const AnfNodePtr RealInputNode(const CNodePtr cnode, size_t index);
} // namespace parallel
} // namespace mindspore
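
The enable_if gate on TransVectorToValueSequeue restricts instantiation to ValueSequeue-derived types (ValueTuple and ValueList). A standalone analog of the same pattern with stand-in types, to show what the constraint buys:

#include <memory>
#include <type_traits>
#include <vector>

struct Seq { std::vector<long long> v; };  // stand-in for ValueSequeue
struct Tuple : Seq {};                     // stand-in for ValueTuple
struct List : Seq {};                      // stand-in for ValueList

template <typename T>
std::shared_ptr<typename std::enable_if<std::is_base_of<Seq, T>::value, T>::type> FromVector(
    const std::vector<long long> &input) {
  auto seq = std::make_shared<T>();
  seq->v = input;  // the real helper wraps each element with MakeValue<int64_t>
  return seq;
}

int main() {
  auto tuple = FromVector<Tuple>({0, 8, 0});  // compiles: Tuple derives from Seq
  // FromVector<int>({1});                    // would not compile: int is not a Seq
  return tuple->v.size() == 3 ? 0 : 1;
}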



mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h (+1 -0)

@@ -31,6 +31,7 @@ constexpr int64_t NO_SPLIT_MAP = -1;
constexpr int64_t NO_SPLIT_STRATEGY = 1;
constexpr int64_t SPLIT_FLAG = 1;
constexpr int64_t NO_SPLIT_FLAG = 0;
constexpr int64_t ADASUM_MIN_DIS = 8;
constexpr size_t MATMUL_ATTRS_SIZE = 2;
constexpr size_t SLICE_BEGIN_INDEX = 1;
constexpr size_t SLICE_SIZE_INDEX = 2;


mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc (+4 -25)

@@ -23,6 +23,7 @@

#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"

@@ -50,28 +51,6 @@ Status StridedSliceInfo::GetMask(const std::string &mask_name, int64_t *mask_val
return SUCCESS;
}

Status GetInput(const ValuePtr &input_value, std::vector<int64_t> *input) {
MS_EXCEPTION_IF_NULL(input_value);
ValueTuplePtr value_tuple = input_value->cast<ValueTuplePtr>();
if (value_tuple == nullptr) {
MS_LOG(ERROR) << "Input value must be ValueTuplePtr.";
return FAILED;
}

for (auto &element : value_tuple->value()) {
MS_EXCEPTION_IF_NULL(element);
if (element->isa<Int64Imm>()) {
int64_t value = element->cast<Int64ImmPtr>()->value();
input->push_back(value);
} else {
MS_LOG(ERROR) << "The value must be int64";
return FAILED;
}
}

return SUCCESS;
}

Status StridedSliceInfo::GetAttrs() {
if (attrs_.size() < STRIDED_SLICE_ATTRS_SIZE) {
MS_LOG(ERROR) << name_ << ": The size of attrs small than " << STRIDED_SLICE_ATTRS_SIZE;
@@ -91,9 +70,9 @@ Status StridedSliceInfo::GetAttrs() {
return FAILED;
}

-if ((GetInput(input_value_[STRIDED_SLICE_BEGIN_INDEX], &begin_) != SUCCESS) ||
-(GetInput(input_value_[STRIDED_SLICE_END_INDEX], &end_) != SUCCESS) ||
-(GetInput(input_value_[STRIDED_SLICE_STRIDES_INDEX], &strides_) != SUCCESS)) {
+if ((TransValueSequeueToVector(input_value_[STRIDED_SLICE_BEGIN_INDEX], &begin_) != SUCCESS) ||
+(TransValueSequeueToVector(input_value_[STRIDED_SLICE_END_INDEX], &end_) != SUCCESS) ||
+(TransValueSequeueToVector(input_value_[STRIDED_SLICE_STRIDES_INDEX], &strides_) != SUCCESS)) {
return FAILED;
}



mindspore/ccsrc/frontend/parallel/parameter_manager.cc (+442 -0)

@@ -25,6 +25,7 @@
#include <set>
#include <string>
#include <utility>
#include <cmath>

#include "utils/hash_map.h"
#include "base/core_ops.h"
@@ -563,6 +564,447 @@ static bool IsOriginWeight(const ParameterPtr &param) {
return true;
}

static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
if (IsValueNode<RefKey>(node)) {
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
if (param_v.size() != 1) {
MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
<< param_v.size();
}
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
return std::make_pair(nullptr, true);
}
return std::make_pair(node, true);
}
return std::make_pair(nullptr, false);
}

static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node) {
auto param_ptr = node->user_data<parallel::TensorLayout>();
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
return std::make_pair(nullptr, false);
}
return std::make_pair(node, false);
}

// Only used for InsertMirrorOps
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
return std::make_pair(nullptr, false);
}

if (node->isa<Parameter>()) {
return FindParameterByParameter(node);
}

if (node->isa<ValueNode>()) {
return FindParameterByValueNode(node, func_graph);
}

CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
auto res = FindParameter(cnode->input(index), func_graph);
if (!res.first) {
continue;
}
return res;
}
}

// When not fully use opt shard, allgather and mirror would be both inserted.
// Skip allgather here and find parameter recursively.
if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
return std::make_pair(nullptr, false);
}

ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(prim_anf_node);
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim);
if ((prim->name() == DEPEND || prim->name() == LOAD || IsInAllGatherNodeList(cnode)) && index != 1) {
continue;
}
auto res = FindParameter(cnode->input(index), func_graph);
if (!res.first) {
continue;
}
return res;
}
return std::make_pair(nullptr, false);
}

std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(root);
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> adasum_param_map;
for (auto &parameter_node : root->parameters()) {
MS_EXCEPTION_IF_NULL(parameter_node);
auto cloned_parameter = parameter_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(cloned_parameter);

if (!ParameterIsCloned(parameter_node)) {
auto parameter_tensor_layout = cloned_parameter->user_data<TensorLayout>();
adasum_param_map["adasum_delta_weight." + cloned_parameter->name()] = parameter_tensor_layout;
}
}
return adasum_param_map;
}

Shape ValueSequeueScaleToShape(const ValuePtr &value_seq, const Shape &scale, size_t expand_ratio = 1) {
if (!value_seq->isa<ValueSequeue>()) {
MS_LOG(EXCEPTION) << "The input is not a value_sequeue";
}
std::vector<int64_t> origin_value_vector;
if (TransValueSequeueToVector(value_seq, &origin_value_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << "Transform value_seq to vector failed";
}
if (origin_value_vector.size() != scale.size()) {
MS_LOG(EXCEPTION) << "Shape not equal, cannot scale, value_seq size is: " << origin_value_vector.size()
<< " scale size is: " << scale.size();
}
for (size_t i = 0; i < scale.size(); ++i) {
origin_value_vector[i] = origin_value_vector[i] / scale[i];
if (i == 0) {
origin_value_vector[i] = origin_value_vector[i] * expand_ratio;
}
}
return origin_value_vector;
}

ValuePtr ValueSequeueScale(const ValuePtr &value_seq, const Shape &scale, size_t expand_ratio = 1) {
Shape origin_value_vector = ValueSequeueScaleToShape(value_seq, scale, expand_ratio);
if (value_seq->isa<ValueTuple>()) {
return TransVectorToValueSequeue<ValueTuple>(origin_value_vector);
}
return TransVectorToValueSequeue<ValueList>(origin_value_vector);
}
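
// Worked example of the scaling above (hypothetical numbers): for a value
// sequence (32, 64) with scale (4, 1) and expand_ratio 2:
//   i = 0: 32 / 4 = 8, then 8 * 2 = 16 (only dim 0 is stretched by the ratio)
//   i = 1: 64 / 1 = 64
// giving (16, 64); ValueSequeueScale then re-wraps the result as a ValueTuple
// or ValueList to match the input's concrete type.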

void ReplaceAdaSumStridedSliceValue(const CNodePtr &stridedslice_cnode1,
const std::shared_ptr<TensorLayout> &target_param_layout,
size_t slice_expand_ratio) {
auto target_param_info = std::make_shared<TensorInfo>(target_param_layout->SqueezeShape());
Dimensions param_strategy = target_param_info->InferStrategy();
auto new_begin1_value =
ValueSequeueScale(GetValueNode(stridedslice_cnode1->input(2)), param_strategy, slice_expand_ratio);
auto new_end1_value =
ValueSequeueScale(GetValueNode(stridedslice_cnode1->input(3)), param_strategy, slice_expand_ratio);
ValueNodePtr new_begin_value_node = std::make_shared<ValueNode>(new_begin1_value);
ValueNodePtr new_end_value_node = std::make_shared<ValueNode>(new_end1_value);
stridedslice_cnode1->set_input(2, new_begin_value_node);
stridedslice_cnode1->set_input(3, new_end_value_node);
}

RankList GetRankListByLayout(const std::shared_ptr<TensorLayout> &target_param_layout) {
int64_t rank = g_device_manager->global_rank();
auto dev_shape = target_param_layout->device_arrangement().array();
auto stage_device_list = g_device_manager->GetDeviceListInThisStage();
DeviceMatrix dev_matrix(rank, stage_device_list, dev_shape);
RankList group_devices;
if (dev_matrix.GetDevicesByTensorMap(target_param_layout->tensor_map().array(), &group_devices) != SUCCESS) {
MS_LOG(EXCEPTION) << "Get adasum parameter origin mirror group by tensor layout failed.";
}
return group_devices;
}

std::vector<bool> IsBorderAdaSumSendReceive(const AnfNodePtr &node, const RankList &group_devices) {
bool is_send = IsPrimitiveCNode(node, prim::kPrimSend);
PrimitivePtr send_rec_prim = GetCNodePrimitive(node);
int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr("opposite_rank"));
int64_t rank = g_device_manager->global_rank();
int64_t adasum_rank_distance = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
if (adasum_rank_distance < ADASUM_MIN_DIS) {
adasum_rank_distance = ADASUM_MIN_DIS;
}
size_t border_step = size_t(log2(adasum_rank_distance / ADASUM_MIN_DIS));
int64_t fusion_id = GetValue<int64_t>(send_rec_prim->GetAttr("origin_fusion"));
// When cutting nodes, the fusion id should change.
int64_t new_fusion_id = fusion_id + g_device_manager->DeviceNum() * (border_step + 1);
send_rec_prim->set_attr(FUSION, MakeValue(new_fusion_id));
std::vector<int64_t> group_list;
int64_t new_dest_src_rank;
if (rank > origin_dest_rank) {
group_list = {origin_dest_rank, rank};
new_dest_src_rank = 0;
} else {
group_list = {rank, origin_dest_rank};
new_dest_src_rank = 1;
}
Group adasum_send_rec_group = g_device_manager->CreateGroup(group_list);
send_rec_prim->set_attr(GROUP, MakeValue(adasum_send_rec_group.name()));
if (is_send) {
send_rec_prim->set_attr(DEST_RANK, MakeValue(new_dest_src_rank));
} else {
send_rec_prim->set_attr(SRC_RANK, MakeValue(new_dest_src_rank));
}
int64_t rank_dis = abs(origin_dest_rank - rank);
if (adasum_rank_distance == ADASUM_MIN_DIS) {
return {false, false, false, false};
}
bool is_origin_first_node_if_forward = false;
bool is_new_first_node_if_forward = false;
bool is_origin_last_node_if_rollback = false;
bool is_new_last_node_if_rollback = false;
if (rank_dis == ADASUM_MIN_DIS) {
is_origin_first_node_if_forward = true;
is_origin_last_node_if_rollback = true;
}
if (rank_dis == adasum_rank_distance) {
is_new_first_node_if_forward = true;
}
if (rank_dis == adasum_rank_distance / 2) {
is_new_last_node_if_rollback = true;
}
return {is_origin_first_node_if_forward, is_new_first_node_if_forward, is_origin_last_node_if_rollback,
is_new_last_node_if_rollback};
}
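
// Worked example of the border flags above (hypothetical ranks): for
// group_devices = {0, 16, 32, 48}, adasum_rank_distance = (48 - 0) / 3 = 16
// and border_step = log2(16 / 8) = 1. Then for a send/receive pair:
//   rank_dis = 8  (== ADASUM_MIN_DIS and == 16 / 2): origin first (forward),
//                 origin last (rollback) and new last (rollback) are true;
//   rank_dis = 16 (== adasum_rank_distance): new first (forward) is true.
// For group_devices = {0, 8, 16, 24} the distance is already ADASUM_MIN_DIS,
// so all four flags stay false.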

void HandleAdaSumReshape(const CNodePtr &reshape_cnode, const std::shared_ptr<TensorLayout> &target_param_layout) {
auto slice_shape = target_param_layout->slice_shape().array();
auto slice_shape_value = TransVectorToValueSequeue<ValueTuple>(slice_shape);
ValueNodePtr new_slice_shape_value_node = std::make_shared<ValueNode>(slice_shape_value);
reshape_cnode->set_input(2, new_slice_shape_value_node);
}

void RemoveAdasumRedundantNodes(const FuncGraphManagerPtr &manager,
std::unordered_map<std::string, CNodePtr> *forward_origin_first_node_map,
std::unordered_map<std::string, CNodePtr> *forward_new_first_node_map,
std::unordered_map<std::string, CNodePtr> *rollback_origin_last_node_map,
std::unordered_map<std::string, CNodePtr> *rollback_new_last_node_map) {
// connect forward last node and rollback first node
if (forward_origin_first_node_map->size() != forward_new_first_node_map->size() ||
rollback_origin_last_node_map->size() != rollback_new_last_node_map->size()) {
MS_LOG(EXCEPTION) << "The over border node is not equal in adasum forward process and rollback process.";
}
for (auto node : *forward_origin_first_node_map) {
std::string target_param = node.first;
CNodePtr forward_origin_first_node = node.second;
CNodePtr forward_new_first_node = (*forward_new_first_node_map)[target_param];
manager->SetEdge(forward_new_first_node, 1, forward_origin_first_node->input(1));
}
for (auto node : *rollback_origin_last_node_map) {
std::string target_param = node.first;
CNodePtr rollback_origin_last_node = node.second;
CNodePtr rollback_new_last_node = (*rollback_new_last_node_map)[target_param];
manager->Replace(rollback_origin_last_node, rollback_new_last_node);
}
}

void HandleAdasumAllReduce(const PrimitivePtr &prim, const RankList &group_devices) {
size_t step = size_t(GetValue<int64_t>(prim->GetAttr("step")));
std::vector<int64_t> neighbor_ids;
int64_t adasum_rank_distance = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
if (adasum_rank_distance < ADASUM_MIN_DIS) {
adasum_rank_distance = ADASUM_MIN_DIS;
}
size_t border_step = size_t(log2(adasum_rank_distance / ADASUM_MIN_DIS));
MS_LOG(INFO) << "current border step is: " << border_step;
if (step < border_step) {
return;
}
int64_t rank = g_device_manager->global_rank();
size_t double_d = size_t(2 << step);
for (size_t index = 0; index < double_d; ++index) {
int64_t node_rank = rank / ADASUM_MIN_DIS;
int64_t neighbor_id = (node_rank / double_d * double_d + index) * ADASUM_MIN_DIS + rank % ADASUM_MIN_DIS;
neighbor_ids.push_back(neighbor_id);
}
Group adasum_allreduce_group = g_device_manager->CreateGroup(neighbor_ids);
auto new_group_name = MakeValue(adasum_allreduce_group.name());
int64_t fusion_id = GetValue<int64_t>(prim->GetAttr("origin_fusion"));
int64_t new_fusion_id = fusion_id + g_device_manager->DeviceNum() * (border_step + 1);
prim->set_attr(GROUP, new_group_name);
prim->set_attr(FUSION, MakeValue(new_fusion_id));
}
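
// Worked example of the neighbor computation above (hypothetical values):
// rank = 9, ADASUM_MIN_DIS = 8, step = 1 => double_d = 2 << 1 = 4,
// node_rank = 9 / 8 = 1, block base = 1 / 4 * 4 = 0, so
// neighbor_id(index) = (0 + index) * 8 + 9 % 8 = 8 * index + 1,
// giving the new AllReduce group {1, 9, 17, 25}.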

void HandleAdasumSlice(const AnfNodePtr &stridedslice_node1, const std::shared_ptr<TensorLayout> &target_param_layout,
const std::string &target_param, size_t slice_expand_ratio) {
auto stridedslice_cnode1 = stridedslice_node1->cast<CNodePtr>();
ReplaceAdaSumStridedSliceValue(stridedslice_cnode1, target_param_layout, slice_expand_ratio);
auto squeeze_node = RealInputNode(stridedslice_cnode1, 1);
if (!IsPrimitiveCNode(squeeze_node, prim::kPrimSqueeze)) {
MS_LOG(EXCEPTION) << "The stridedslice input node should be squeeze in adasum";
}
auto squeeze_cnode = squeeze_node->cast<CNodePtr>();
FuncGraphManagerPtr manager = squeeze_node->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
AnfNodeIndexSet node_set = manager->node_users()[squeeze_cnode];
for (auto &node_pair : node_set) {
if (IsPrimitiveCNode(node_pair.first, prim::kPrimStridedSlice) && node_pair.first != stridedslice_node1) {
CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
ReplaceAdaSumStridedSliceValue(use_apply, target_param_layout, slice_expand_ratio);
}
}
}

void HandleAdaSumConcat(const AnfNodePtr &concat_node, const std::vector<bool> &border_info,
const std::string &target_param,
std::unordered_map<std::string, CNodePtr> *rollback_new_last_node_map,
std::unordered_map<std::string, CNodePtr> *rollback_origin_last_node_map) {
if (border_info[3]) {
(*rollback_new_last_node_map)[target_param] = concat_node->cast<CNodePtr>();
}
if (border_info[2]) {
auto manager = concat_node->func_graph()->manager();
AnfNodeIndexSet concat_node_user_set = manager->node_users()[concat_node];
for (auto &node_pair : concat_node_user_set) {
if (IsPrimitiveCNode(node_pair.first, prim::kPrimMakeTuple)) {
AnfNodeIndexSet make_tuple_node_user_set = manager->node_users()[node_pair.first];
for (auto &tuple_user : make_tuple_node_user_set) {
if (IsPrimitiveCNode(tuple_user.first, prim::kPrimConcat)) {
(*rollback_origin_last_node_map)[target_param] = tuple_user.first->cast<CNodePtr>();
return;
}
}
return;
}
}
}
}

void HandleAdaSumSqueeze(const AnfNodePtr &stridedslice_node1, const std::vector<bool> &border_info,
const std::string &target_param,
std::unordered_map<std::string, CNodePtr> *forward_origin_first_node_map,
std::unordered_map<std::string, CNodePtr> *forward_new_first_node_map) {
auto squeeze_node = RealInputNode(stridedslice_node1->cast<CNodePtr>(), 1);
if (border_info[0]) {
(*forward_origin_first_node_map)[target_param] = squeeze_node->cast<CNodePtr>();
}
if (border_info[1]) {
(*forward_new_first_node_map)[target_param] = squeeze_node->cast<CNodePtr>();
}
}

bool HandleAdaSum(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map) {
std::unordered_map<std::string, CNodePtr> forward_origin_first_node_map;
std::unordered_map<std::string, CNodePtr> forward_new_first_node_map;
std::unordered_map<std::string, CNodePtr> rollback_origin_last_node_map;
std::unordered_map<std::string, CNodePtr> rollback_new_last_node_map;
bool is_adasum = false;
for (auto &node : all_nodes) {
bool is_allreduce = IsPrimitiveCNode(node, prim::kPrimAllReduce);
bool is_reshape = IsPrimitiveCNode(node, prim::kPrimReshape);
bool is_send = IsPrimitiveCNode(node, prim::kPrimSend);
bool is_receive = IsPrimitiveCNode(node, prim::kPrimReceive);
if (!is_allreduce && !is_reshape && !is_send && !is_receive) {
continue;
}
std::string target_param;
CNodePtr cnode = node->cast<CNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)->cast<ValueNodePtr>());
if (!prim->HasAttr("target_param")) {
continue;
}
target_param = GetValue<std::string>(prim->GetAttr("target_param"));
auto target_param_layout = (*adasum_param_tensor_layout_map)[target_param];
RankList group_devices = GetRankListByLayout(target_param_layout);
int64_t adasum_rank_distance = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
// When the repeated dim is on the right, the parameter does not enable adasum.
if (adasum_rank_distance == 1 && group_devices.size() < size_t(g_device_manager->stage_device_num())) {
continue;
}
MS_LOG(INFO) << "Apply adasum in auto parallel, current dealing node is: " << node->fullname_with_scope();
is_adasum = true;
size_t slice_expand_ratio = adasum_rank_distance / ADASUM_MIN_DIS > 0 ? adasum_rank_distance / ADASUM_MIN_DIS : 1;
if (is_reshape) {
HandleAdaSumReshape(cnode, (*adasum_param_tensor_layout_map)[target_param]);
}
if (is_allreduce && prim->HasAttr("step")) {
HandleAdasumAllReduce(prim, group_devices);
}
if (is_send || is_receive) {
std::vector<bool> border_info = IsBorderAdaSumSendReceive(node, group_devices);
if (is_receive) {
auto target_param_info = std::make_shared<TensorInfo>(*target_param_layout);
Dimensions param_strategy = target_param_info->InferStrategy();
Shape new_rec_shape = ValueSequeueScaleToShape(prim->GetAttr(SHAPE), param_strategy, slice_expand_ratio);
auto new_rec_shape_value = TransVectorToValueSequeue<ValueList>(new_rec_shape);
prim->set_attr(SHAPE, new_rec_shape_value);
continue;
}
auto stridedslice_node1 = RealInputNode(cnode, 1);
if (IsPrimitiveCNode(stridedslice_node1, prim::kPrimConcat)) {
HandleAdaSumConcat(stridedslice_node1, border_info, target_param, &rollback_new_last_node_map,
&rollback_origin_last_node_map);
continue;
}
if (!IsPrimitiveCNode(stridedslice_node1, prim::kPrimStridedSlice)) {
continue;
}
HandleAdasumSlice(stridedslice_node1, target_param_layout, target_param, slice_expand_ratio);
HandleAdaSumSqueeze(stridedslice_node1, border_info, target_param, &forward_origin_first_node_map,
&forward_new_first_node_map);
}
}
RemoveAdasumRedundantNodes(root->manager(), &forward_origin_first_node_map, &forward_new_first_node_map,
&rollback_origin_last_node_map, &rollback_new_last_node_map);
return is_adasum;
}

void ResetMirrorAttr(const PrimitivePtr &prim, const RankList &new_group) {
if (new_group.size() == 1) {
prim->set_attr(DEV_NUM, MakeValue(new_group.size()));
prim->set_attr(GROUP, MakeValue("one_rank_group"));
prim->set_attr(GROUP_RANKS, MakeValue(std::to_string(new_group[0])));
return;
}
Group adasum_mirror_group = g_device_manager->CreateGroup(new_group);
auto new_group_name = MakeValue(adasum_mirror_group.name());
prim->set_attr(GROUP, new_group_name);
prim->set_attr(DEV_NUM, MakeValue(new_group.size()));
std::string rank_list_name = g_device_manager->FindRankListNameByHashName(adasum_mirror_group.name());
prim->set_attr(GROUP_RANKS, MakeValue(rank_list_name));
}

void HandleMirrorInAdaSum(
const FuncGraphPtr &root,
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map) {
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(root->get_return());
for (auto &node : all_nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimMirror)) {
continue;
}
CNodePtr mirror_cnode = node->cast<CNodePtr>();
auto param_node_pair = FindParameter(mirror_cnode->input(1), node->func_graph());
if (!param_node_pair.first) {
MS_LOG(EXCEPTION) << "Mirror input is not a param";
}
auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
std::string param_name = param_ptr->name();
MS_LOG(INFO) << "Mirror param name is: " << param_name;
std::string target_param = "adasum_delta_weight." + param_name;
auto target_param_layout = (*adasum_param_tensor_layout_map)[target_param];

// Change mirror group
RankList group_devices = GetRankListByLayout(target_param_layout);
int64_t rank = g_device_manager->global_rank();
size_t group_dis = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
auto prim = GetCNodePrimitive(node);
if (group_dis < ADASUM_MIN_DIS) {
size_t new_group_size = size_t(ADASUM_MIN_DIS) / group_dis;
// compute new group range
size_t group_begin = 0;
for (size_t group_end = new_group_size; group_end < group_devices.size() + new_group_size;
group_end += new_group_size) {
int64_t max_group_value =
group_end >= group_devices.size() ? (group_devices.back() + 1) : group_devices[group_end];
if (group_devices[group_begin] <= rank && rank < max_group_value) {
std::vector<int64_t> new_group(group_devices.begin() + group_begin, group_devices.begin() + group_end);
MS_LOG(INFO) << "Find new mirror group in adasum: " << new_group << " target_param:" << target_param;
ResetMirrorAttr(prim, new_group);
break;
}
group_begin = group_end;
}
continue;
}
ResetMirrorAttr(prim, {rank});
}
}
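
// Worked example of the regrouping above (hypothetical ranks): for
// group_devices = {0, 2, 4, 6, 8, 10, 12, 14}, group_dis = (14 - 0) / 7 = 2,
// which is < ADASUM_MIN_DIS, so new_group_size = 8 / 2 = 4 and the list
// splits into windows {0, 2, 4, 6} and {8, 10, 12, 14}; rank 10 satisfies
// 8 <= 10 < 15 in the second window, which becomes its new mirror group.
// If group_dis were >= ADASUM_MIN_DIS, the mirror collapses to the
// single-rank group {rank} via ResetMirrorAttr.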

void HandleAdaFactorOpt(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(root);
for (auto &param_node : root->parameters()) {


mindspore/ccsrc/frontend/parallel/parameter_manager.h (+9 -0)

@@ -20,6 +20,8 @@
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include <unordered_map>
#include "base/base.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/step_parallel_utils.h"
@@ -40,6 +42,13 @@ void HandleNoUsedParameter(const FuncGraphPtr &root);
void HandleFullySplitParameters(const FuncGraphPtr &root);
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root);
void HandleAdaFactorOpt(const FuncGraphPtr &root);
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root);
bool HandleAdaSum(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map);
void HandleMirrorInAdaSum(
const FuncGraphPtr &root,
std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map);
bool ParameterIsCloned(const AnfNodePtr &parameter_node);
bool IsFullySplitParameter(const ParameterPtr &param_ptr, size_t allow_repeat_num = 1);
} // namespace parallel


mindspore/ccsrc/frontend/parallel/step_parallel.cc (+12 -77)

@@ -944,79 +944,6 @@ void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
}
}

static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
if (IsValueNode<RefKey>(node)) {
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
if (param_v.size() != 1) {
MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
<< param_v.size();
}
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
return std::make_pair(nullptr, true);
}
return std::make_pair(node, true);
}
return std::make_pair(nullptr, false);
}

static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node) {
auto param_ptr = node->user_data<parallel::TensorLayout>();
if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
return std::make_pair(nullptr, false);
}
return std::make_pair(node, false);
}

// Only used for InsertMirrorOps
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
return std::make_pair(nullptr, false);
}

if (node->isa<Parameter>()) {
return FindParameterByParameter(node);
}

if (node->isa<ValueNode>()) {
return FindParameterByValueNode(node, func_graph);
}

CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
auto res = FindParameter(cnode->input(index), func_graph);
if (!res.first) {
continue;
}
return res;
}
}

// When not fully use opt shard, allgather and mirror would be both inserted.
// Skip allgather here and find parameter recursively.
if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
return std::make_pair(nullptr, false);
}

ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(prim_anf_node);
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim);
if ((prim->name() == DEPEND || prim->name() == LOAD || IsInAllGatherNodeList(cnode)) && index != 1) {
continue;
}
auto res = FindParameter(cnode->input(index), func_graph);
if (!res.first) {
continue;
}
return res;
}
return std::make_pair(nullptr, false);
}

// only used for FindCNode
CNodePtr SkipTrivialNodesMoveDown(const FuncGraphManagerPtr &manager, CNodePtr node) {
MS_EXCEPTION_IF_NULL(node);
@@ -2906,6 +2833,10 @@ void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes)
if (prim->name() != RESHAPE) {
continue;
}
Shape origin_dst_shape = GetValue<std::vector<int64_t>>(cnode->input(2)->cast<ValueNodePtr>()->value());
if (origin_dst_shape.size() == 1 && origin_dst_shape[0] == -1) {
continue;
}
auto root = node->func_graph();
auto grad_node = FindGrad(cnode, 0);
if (grad_node) {
@@ -3185,10 +3116,8 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
(root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) {
if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) {
-if (HasStrategy(root)) {
-MS_LOG(INFO) << "Strategies ignored in " << parallel_mode
-<< ", set_strategy() only valid in [semi_]auto_parallel.";
-}
+MS_LOG(WARNING) << "Strategies would be ignored in " << parallel_mode
+<< ", shard() only valid in [semi_]auto_parallel.";
root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
}
ReorderForPipelineSplit(root, manager, pipeline_stages);
@@ -3254,12 +3183,18 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)

HandleAdaFactorOpt(root);

auto adasum_param_tensor_layout_map = AdaSumParamTensorLayout(root);
bool is_apply_adasum = HandleAdaSum(root, all_nodes, &adasum_param_tensor_layout_map);

// save strategy as checkpoint for multi-train
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
CheckpointStrategy(all_nodes, root);
}
// ForwardCommunication BackwardCommunication TensorRedistribution
ParallelCommunication(root, all_nodes, manager);
if (is_apply_adasum) {
HandleMirrorInAdaSum(root, &adasum_param_tensor_layout_map);
}

PipelinePostProcess(root, all_nodes);



mindspore/ccsrc/frontend/parallel/step_parallel.h (+0 -2)

@@ -86,8 +86,6 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node);

void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node);

std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph);

std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph,
size_t max_depth);



mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc (+0 -26)

@@ -477,32 +477,6 @@ AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, cons
return new_node;
}

AnfNodePtr RealInputNode(const CNodePtr cnode, size_t index) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() <= index) {
MS_LOG(EXCEPTION) << "cnode inputs size: " << cnode->size() << " is less equal index: " << index;
}
auto input0 = cnode->input(index);
if (!input0->isa<CNode>()) {
return input0;
}
auto prim = GetCNodePrimitive(input0);
MS_EXCEPTION_IF_NULL(prim);
while (prim->name() == LOAD || prim->name() == DEPEND || prim->name() == UPDATESTATE) {
if (prim->name() == LOAD || prim->name() == DEPEND) {
input0 = input0->cast<CNodePtr>()->input(1);
} else if (prim->name() == UPDATESTATE) {
input0 = input0->cast<CNodePtr>()->input(2);
}
if (!input0->isa<CNode>()) {
return input0;
}
prim = GetCNodePrimitive(input0);
MS_EXCEPTION_IF_NULL(prim);
}
return input0;
}

void LabelGenMaskMicro(const FuncGraphPtr &root) {
AnfNodePtr ret = root->get_return();
MS_EXCEPTION_IF_NULL(ret);


mindspore/ccsrc/frontend/parallel/step_parallel_utils.h (+0 -1)

@@ -31,7 +31,6 @@ const int64_t TWO_INPUT_SIZE = 2;
// common method
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
bool IsParallelCareNode(const CNodePtr &cnode);
AnfNodePtr RealInputNode(const CNodePtr cnode, size_t index);
Shapes GetNodeShape(const AnfNodePtr &node);
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
const CNodePtr &node);

