|
|
@@ -1571,7 +1571,7 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res, |
|
|
static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res, |
|
|
const AnfNodePtr &node) { |
|
|
|
|
|
|
|
|
const AnfNodePtr &node, const std::string &op_name) { |
|
|
MS_EXCEPTION_IF_NULL(res.first); |
|
|
MS_EXCEPTION_IF_NULL(res.first); |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
auto cnode = res.first->cast<CNodePtr>(); |
|
|
auto cnode = res.first->cast<CNodePtr>(); |
|
|
@@ -1579,10 +1579,9 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0)); |
|
|
auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0)); |
|
|
MS_EXCEPTION_IF_NULL(cnode_prim); |
|
|
MS_EXCEPTION_IF_NULL(cnode_prim); |
|
|
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); |
|
|
|
|
|
Operator op; |
|
|
Operator op; |
|
|
CNodePtr allgather; |
|
|
CNodePtr allgather; |
|
|
if (grad_accumulation_step > 1) { |
|
|
|
|
|
|
|
|
if (op_name == MINI_STEP_ALL_GATHER) { |
|
|
op = CreateMiniStepAllGatherOp(group); |
|
|
op = CreateMiniStepAllGatherOp(group); |
|
|
auto param_name = node->cast<ParameterPtr>()->name(); |
|
|
auto param_name = node->cast<ParameterPtr>()->name(); |
|
|
if (cnode_prim->name() == CAST) { |
|
|
if (cnode_prim->name() == CAST) { |
|
|
@@ -1613,21 +1612,41 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr & |
|
|
} |
|
|
} |
|
|
FuncGraphManagerPtr manager = root->manager(); |
|
|
FuncGraphManagerPtr manager = root->manager(); |
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
|
|
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); |
|
|
|
|
|
std::string op_name; |
|
|
|
|
|
if (grad_accumulation_step > 1) { |
|
|
|
|
|
op_name = MINI_STEP_ALL_GATHER; |
|
|
|
|
|
} else { |
|
|
|
|
|
op_name = ALL_GATHER; |
|
|
|
|
|
} |
|
|
auto param_sub_set = manager->node_users()[parameter]; |
|
|
auto param_sub_set = manager->node_users()[parameter]; |
|
|
|
|
|
bool insert_flag = false; |
|
|
for (auto ¶m_pair : param_sub_set) { |
|
|
for (auto ¶m_pair : param_sub_set) { |
|
|
auto cnode = param_pair.first->cast<CNodePtr>(); |
|
|
auto cnode = param_pair.first->cast<CNodePtr>(); |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
if (cnode->in_forward_flag()) { |
|
|
if (cnode->in_forward_flag()) { |
|
|
OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>(); |
|
|
OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>(); |
|
|
if (distribute_operator == nullptr) { |
|
|
if (distribute_operator == nullptr) { |
|
|
MS_LOG(WARNING) << "Parallel optimizer: " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; |
|
|
|
|
|
|
|
|
MS_LOG(WARNING) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr"; |
|
|
} else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) { |
|
|
} else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) { |
|
|
MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is " |
|
|
MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is " |
|
|
<< distribute_operator->inputs_tensor_info().size(); |
|
|
<< distribute_operator->inputs_tensor_info().size(); |
|
|
} |
|
|
} |
|
|
// insert allgather operator between shard parameter and cnode |
|
|
|
|
|
InsertAllGatherOp(root, opt_shard_group, param_pair, parameter); |
|
|
|
|
|
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString(); |
|
|
|
|
|
|
|
|
if (insert_flag) { |
|
|
|
|
|
auto next_cnode = FindCNode(parameter, op_name, cnode->func_graph()); |
|
|
|
|
|
if (next_cnode.first) { |
|
|
|
|
|
manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second); |
|
|
|
|
|
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " |
|
|
|
|
|
<< GetPrimName(cnode); |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
// insert allgather operator between shard parameter and cnode |
|
|
|
|
|
InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name); |
|
|
|
|
|
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " |
|
|
|
|
|
<< GetPrimName(cnode); |
|
|
|
|
|
insert_flag = true; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|