|
|
|
@@ -2755,6 +2755,11 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool IsCohesiveNode(const CNodePtr &cnode) { |
|
|
|
return IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) || |
|
|
|
IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node) { |
|
|
|
std::vector<AnfNodePtr> node_inputs{node->inputs()}; |
|
|
|
std::vector<std::pair<std::string, int64_t>> param_names; |
|
|
|
@@ -2768,17 +2773,11 @@ std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &n |
|
|
|
} else if (input->isa<CNode>()) { |
|
|
|
CNodePtr cnode = input->cast<CNodePtr>(); |
|
|
|
if (!IsValueNode<Primitive>(cnode->input(0))) { |
|
|
|
return param_names; |
|
|
|
continue; |
|
|
|
} |
|
|
|
if ((IsPrimitiveCNode(cnode, prim::kPrimCast) && cnode->inputs().size() >= 1) || |
|
|
|
IsPrimitiveCNode(cnode, prim::kPrimLoad)) { |
|
|
|
auto inp = cnode->input(1); |
|
|
|
if (inp->isa<Parameter>()) { |
|
|
|
auto inp_param = inp->cast<ParameterPtr>(); |
|
|
|
if (inp_param->has_default() && ParameterRequireGrad(inp_param)) { |
|
|
|
param_names.push_back({inp_param->name(), i}); |
|
|
|
} |
|
|
|
} |
|
|
|
if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) { |
|
|
|
auto input_param_names = NodeParameterName(cnode); |
|
|
|
param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -3517,11 +3516,6 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) |
|
|
|
// if the input or parameter has multiple users, check whether its split strategies are consistent. |
|
|
|
CheckParameterSplit(all_nodes); |
|
|
|
|
|
|
|
// save strategy as checkpoint for multi-train |
|
|
|
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) { |
|
|
|
CheckpointStrategy(all_nodes); |
|
|
|
} |
|
|
|
|
|
|
|
HandleSymbolicKeyInstance(root, all_nodes); |
|
|
|
|
|
|
|
// cover Parallel shape |
|
|
|
@@ -3533,6 +3527,10 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) |
|
|
|
// set the shape for optimizer's clone tensor |
|
|
|
SetClonedTensorShapeForOptimizer(root); |
|
|
|
|
|
|
|
// save strategy as checkpoint for multi-train |
|
|
|
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) { |
|
|
|
CheckpointStrategy(all_nodes); |
|
|
|
} |
|
|
|
// ForwardCommunication BackwardCommunication TensorRedistribution |
|
|
|
ParallelCommunication(root, all_nodes, manager); |
|
|
|
|
|
|
|
|