Browse Source

!15810 fix gather_p_info judgement

From: @yao_yf
Reviewed-by: @stsuteng,@yangzhenzhang
Signed-off-by: @stsuteng
pull/15810/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
a2a24f7833
3 changed files with 25 additions and 15 deletions
  1. +1
    -1
      mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.cc
  2. +13
    -12
      mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
  3. +11
    -2
      mindspore/ccsrc/frontend/parallel/step_parallel.cc

+ 1
- 1
mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.cc View File

@@ -95,7 +95,7 @@ Status GatherNdInfo::InferTensorMap() {

// cannot use dev_matrix_shape_ replace inputs_shape_[0], because it may not be fully split in all devices.
TensorMap indices_tensor_map;
int64_t size = SizeToLong(inputs_shape_[0].size());
int64_t size = SizeToLong(inputs_shape_[1].size());
for (int64_t i = 0; i < size; ++i) {
indices_tensor_map.push_back(size - i - 1);
}


+ 13
- 12
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc View File

@@ -354,6 +354,16 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
return operator_info;
}

bool IsFindWrong(const OperatorInfoPtr current_op_ptr, const std::string &prim_name) {
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
(current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
(current_op_ptr->name().find(prim_name + "Info") == std::string::npos);
if (prim_name == GATHERV2) {
is_find_wrong = is_find_wrong && (current_op_ptr->name().find(prim_name + "PInfo") == std::string::npos);
}
return is_find_wrong;
}

// Using CNode's UniqueIds to construct nodes
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
@@ -399,10 +409,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
if (single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) {
const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
(current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
(current_op_ptr->name().find(prim->name()) == std::string::npos);
if (is_find_wrong) {
if (IsFindWrong(current_op_ptr, prim->name())) {
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
<< " does not match the Prim: " << prim->name()
<< ". The fullname_with_scope: " << cnode->fullname_with_scope();
@@ -456,10 +463,7 @@ void SetOperatorToCNode(const OperatorInfoPtr &current_op_ptr, const PrimitivePt
if (current_op_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
} else {
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
(current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
(current_op_ptr->name().find(prim->name()) == std::string::npos);
if (is_find_wrong) {
if (IsFindWrong(current_op_ptr, prim->name())) {
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
<< " does not match the Prim: " << prim->name();
}
@@ -518,10 +522,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
bool is_op_created = single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size());
if (is_op_created) {
const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
(current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
(current_op_ptr->name().find(prim->name()) == std::string::npos);
if (is_find_wrong) {
if (IsFindWrong(current_op_ptr, prim->name())) {
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
<< " does not match the Prim: " << prim->name()
<< ". The fullname_with_scope: " << cnode->fullname_with_scope();


+ 11
- 2
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -2842,6 +2842,16 @@ std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &n
return param_names;
}

bool IsGatherPInfo(const std::string &name) {
std::vector<std::string> gather_p_info_names = {"GatherPInfo", "SparseGatherV2Info", "EmbeddingLookupInfo"};
for (std::string info_name : gather_p_info_names) {
if (name.find(info_name) != std::string::npos) {
return true;
}
}
return false;
}

void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
StrategyMap stra_map;
TensorInfoMap tensor_info_map;
@@ -2873,8 +2883,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
}
tensor_info_map[param_name_pair.first] = input_tensor_info[param_name_pair.second - 1];
}
if (operator_info->name().find(EMBEDDING_LOOKUP) != std::string::npos ||
operator_info->name().find(GATHERV2) != std::string::npos) {
if (IsGatherPInfo(operator_info->name())) {
auto gatherv2_info = std::dynamic_pointer_cast<GatherPInfo>(operator_info);
auto param_split_shapes = gatherv2_info->param_split_shapes();
auto index_offsets = gatherv2_info->index_offsets();


Loading…
Cancel
Save