Browse Source

fix find nodes with param

pull/15253/head
yao_yf 4 years ago
parent
commit
17354e3c4e
2 changed files with 5 additions and 4 deletions
  1. +4
    -3
      mindspore/ccsrc/frontend/parallel/step_parallel.cc
  2. +1
    -1
      mindspore/ccsrc/frontend/parallel/step_parallel.h

+ 4
- 3
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -2760,15 +2760,16 @@ bool IsCohesiveNode(const CNodePtr &cnode) {
IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather);
}

std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node) {
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index) {
std::vector<AnfNodePtr> node_inputs{node->inputs()};
std::vector<std::pair<std::string, int64_t>> param_names;
for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
int64_t idx = index > i ? index : i;
auto input = node_inputs[i];
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
param_names.push_back({input_parameter->name(), i});
param_names.push_back({input_parameter->name(), idx});
}
} else if (input->isa<CNode>()) {
CNodePtr cnode = input->cast<CNodePtr>();
@@ -2776,7 +2777,7 @@ std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &n
continue;
}
if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) {
auto input_param_names = NodeParameterName(cnode);
auto input_param_names = NodeParameterName(cnode, idx);
param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end());
}
}


+ 1
- 1
mindspore/ccsrc/frontend/parallel/step_parallel.h View File

@@ -139,7 +139,7 @@ bool IsLastStage();
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager);

std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node);
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index = -1);

void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);



Loading…
Cancel
Save