Browse Source

!15113 strategy_ckpt_file_adapt_optimizer_shard

From: @yao_yf
Reviewed-by: @stsuteng,@yangzhenzhang
Signed-off-by: @stsuteng
pull/15113/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
33e5a60bb6
1 changed file with 13 additions and 15 deletions
  1. +13
    -15
      mindspore/ccsrc/frontend/parallel/step_parallel.cc

+ 13
- 15
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -2755,6 +2755,11 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
}
}

bool IsCohesiveNode(const CNodePtr &cnode) {
return IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather);
}

std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node) {
std::vector<AnfNodePtr> node_inputs{node->inputs()};
std::vector<std::pair<std::string, int64_t>> param_names;
@@ -2768,17 +2773,11 @@ std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &n
} else if (input->isa<CNode>()) {
CNodePtr cnode = input->cast<CNodePtr>();
if (!IsValueNode<Primitive>(cnode->input(0))) {
return param_names;
continue;
}
if ((IsPrimitiveCNode(cnode, prim::kPrimCast) && cnode->inputs().size() >= 1) ||
IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
auto inp = cnode->input(1);
if (inp->isa<Parameter>()) {
auto inp_param = inp->cast<ParameterPtr>();
if (inp_param->has_default() && ParameterRequireGrad(inp_param)) {
param_names.push_back({inp_param->name(), i});
}
}
if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) {
auto input_param_names = NodeParameterName(cnode);
param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end());
}
}
}
@@ -3517,11 +3516,6 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
// if the input or parameter has multiple users, check whether its split strategies are consistent.
CheckParameterSplit(all_nodes);

// save strategy as checkpoint for multi-train
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
CheckpointStrategy(all_nodes);
}

HandleSymbolicKeyInstance(root, all_nodes);

// cover Parallel shape
@@ -3533,6 +3527,10 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
// set the shape for optimizer's clone tensor
SetClonedTensorShapeForOptimizer(root);

// save strategy as checkpoint for multi-train
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
CheckpointStrategy(all_nodes);
}
// ForwardCommunication BackwardCommunication TensorRedistribution
ParallelCommunication(root, all_nodes, manager);



Loading…
Cancel
Save