| @@ -2755,6 +2755,11 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo | |||
| } | |||
| } | |||
| bool IsCohesiveNode(const CNodePtr &cnode) { | |||
| return IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) || | |||
| IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather); | |||
| } | |||
| std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node) { | |||
| std::vector<AnfNodePtr> node_inputs{node->inputs()}; | |||
| std::vector<std::pair<std::string, int64_t>> param_names; | |||
| @@ -2768,17 +2773,11 @@ std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &n | |||
| } else if (input->isa<CNode>()) { | |||
| CNodePtr cnode = input->cast<CNodePtr>(); | |||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||
| return param_names; | |||
| continue; | |||
| } | |||
| if ((IsPrimitiveCNode(cnode, prim::kPrimCast) && cnode->inputs().size() >= 1) || | |||
| IsPrimitiveCNode(cnode, prim::kPrimLoad)) { | |||
| auto inp = cnode->input(1); | |||
| if (inp->isa<Parameter>()) { | |||
| auto inp_param = inp->cast<ParameterPtr>(); | |||
| if (inp_param->has_default() && ParameterRequireGrad(inp_param)) { | |||
| param_names.push_back({inp_param->name(), i}); | |||
| } | |||
| } | |||
| if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) { | |||
| auto input_param_names = NodeParameterName(cnode); | |||
| param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end()); | |||
| } | |||
| } | |||
| } | |||
| @@ -3517,11 +3516,6 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) | |||
| // if the input or parameter has multiple users, check whether its split strategies are consistent. | |||
| CheckParameterSplit(all_nodes); | |||
| // save strategy as checkpoint for multi-train | |||
| if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) { | |||
| CheckpointStrategy(all_nodes); | |||
| } | |||
| HandleSymbolicKeyInstance(root, all_nodes); | |||
| // cover Parallel shape | |||
| @@ -3533,6 +3527,10 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) | |||
| // set the shape for optimizer's clone tensor | |||
| SetClonedTensorShapeForOptimizer(root); | |||
| // save strategy as checkpoint for multi-train | |||
| if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) { | |||
| CheckpointStrategy(all_nodes); | |||
| } | |||
| // ForwardCommunication BackwardCommunication TensorRedistribution | |||
| ParallelCommunication(root, all_nodes, manager); | |||