diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index c140b53b75..516bb40ccc 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -2760,15 +2760,16 @@ bool IsCohesiveNode(const CNodePtr &cnode) { IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather); } -std::vector> NodeParameterName(const CNodePtr &node) { +std::vector> NodeParameterName(const CNodePtr &node, int64_t index) { std::vector node_inputs{node->inputs()}; std::vector> param_names; for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) { + int64_t idx = index > i ? index : i; auto input = node_inputs[i]; if (input->isa()) { auto input_parameter = input->cast(); if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) { - param_names.push_back({input_parameter->name(), i}); + param_names.push_back({input_parameter->name(), idx}); } } else if (input->isa()) { CNodePtr cnode = input->cast(); @@ -2776,7 +2777,7 @@ std::vector> NodeParameterName(const CNodePtr &n continue; } if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) { - auto input_param_names = NodeParameterName(cnode); + auto input_param_names = NodeParameterName(cnode, idx); param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end()); } } diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h index 69e2f4662f..d19ac3a13a 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -139,7 +139,7 @@ bool IsLastStage(); void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, const FuncGraphManagerPtr &manager); -std::vector> NodeParameterName(const CNodePtr &node); +std::vector> NodeParameterName(const CNodePtr &node, int64_t index = -1); void CheckpointStrategy(const std::vector &all_nodes);