From d7641123bb3bc1281bd42438479d89626fb44711 Mon Sep 17 00:00:00 2001 From: yao_yf Date: Tue, 13 Apr 2021 21:14:26 +0800 Subject: [PATCH] strategy_ckpt_file_adapt_optimizer_shard --- .../ccsrc/frontend/parallel/step_parallel.cc | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 8f79fc4c17..e7f6ec9b01 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -2755,6 +2755,11 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector> NodeParameterName(const CNodePtr &node) { std::vector node_inputs{node->inputs()}; std::vector> param_names; @@ -2768,17 +2773,11 @@ std::vector> NodeParameterName(const CNodePtr &n } else if (input->isa()) { CNodePtr cnode = input->cast(); if (!IsValueNode(cnode->input(0))) { - return param_names; + continue; } - if ((IsPrimitiveCNode(cnode, prim::kPrimCast) && cnode->inputs().size() >= 1) || - IsPrimitiveCNode(cnode, prim::kPrimLoad)) { - auto inp = cnode->input(1); - if (inp->isa()) { - auto inp_param = inp->cast(); - if (inp_param->has_default() && ParameterRequireGrad(inp_param)) { - param_names.push_back({inp_param->name(), i}); - } - } + if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) { + auto input_param_names = NodeParameterName(cnode); + param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end()); } } } @@ -3517,11 +3516,6 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) // if the input or parameter has multiple users, check whether its split strategies are consistent. CheckParameterSplit(all_nodes); - // save strategy as checkpoint for multi-train - if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) { - CheckpointStrategy(all_nodes); - } - HandleSymbolicKeyInstance(root, all_nodes); // cover Parallel shape @@ -3533,6 +3527,10 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) // set the shape for optimizer's clone tensor SetClonedTensorShapeForOptimizer(root); + // save strategy as checkpoint for multi-train + if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) { + CheckpointStrategy(all_nodes); + } // ForwardCommunication BackwardCommunication TensorRedistribution ParallelCommunication(root, all_nodes, manager);