From 17354e3c4e4e2a8621def5217a448f77906ec02c Mon Sep 17 00:00:00 2001 From: yao_yf Date: Fri, 16 Apr 2021 09:19:57 +0800 Subject: [PATCH] fix find nodes with param --- mindspore/ccsrc/frontend/parallel/step_parallel.cc | 7 ++++--- mindspore/ccsrc/frontend/parallel/step_parallel.h | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index c140b53b75..516bb40ccc 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -2760,15 +2760,16 @@ bool IsCohesiveNode(const CNodePtr &cnode) { IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather); } -std::vector> NodeParameterName(const CNodePtr &node) { +std::vector> NodeParameterName(const CNodePtr &node, int64_t index) { std::vector node_inputs{node->inputs()}; std::vector> param_names; for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) { + int64_t idx = index > i ? index : i; auto input = node_inputs[i]; if (input->isa()) { auto input_parameter = input->cast(); if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) { - param_names.push_back({input_parameter->name(), i}); + param_names.push_back({input_parameter->name(), idx}); } } else if (input->isa()) { CNodePtr cnode = input->cast(); @@ -2776,7 +2777,7 @@ std::vector> NodeParameterName(const CNodePtr &n continue; } if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) { - auto input_param_names = NodeParameterName(cnode); + auto input_param_names = NodeParameterName(cnode, idx); param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end()); } } diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h index 69e2f4662f..d19ac3a13a 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -139,7 +139,7 @@ bool IsLastStage(); void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, const FuncGraphManagerPtr &manager); -std::vector> NodeParameterName(const CNodePtr &node); +std::vector> NodeParameterName(const CNodePtr &node, int64_t index = -1); void CheckpointStrategy(const std::vector &all_nodes);