@@ -1571,7 +1571,7 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf
 }
 static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res,
-                              const AnfNodePtr &node) {
+                              const AnfNodePtr &node, const std::string &op_name) {
   MS_EXCEPTION_IF_NULL(res.first);
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = res.first->cast<CNodePtr>();
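
Note: the extra op_name argument lets the caller decide which gather operator InsertAllGatherOp should build, instead of the helper querying the gradient-accumulation setting itself (see the next two hunks).
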
@@ -1579,10 +1579,9 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
   MS_EXCEPTION_IF_NULL(graph);
   auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
   MS_EXCEPTION_IF_NULL(cnode_prim);
-  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
   Operator op;
   CNodePtr allgather;
-  if (grad_accumulation_step > 1) {
+  if (op_name == MINI_STEP_ALL_GATHER) {
     op = CreateMiniStepAllGatherOp(group);
     auto param_name = node->cast<ParameterPtr>()->name();
     if (cnode_prim->name() == CAST) {
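
Note: inside the helper, the branch now keys on the operator name it was handed: MINI_STEP_ALL_GATHER selects CreateMiniStepAllGatherOp, while the grad_accumulation_step lookup moves to ApplyParallelOptOnParam in the hunk below. A standalone sketch of the combined flow follows the last hunk.
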
@@ -1613,21 +1612,41 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
   }
   FuncGraphManagerPtr manager = root->manager();
   MS_EXCEPTION_IF_NULL(manager);
+  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
+  std::string op_name;
+  if (grad_accumulation_step > 1) {
+    op_name = MINI_STEP_ALL_GATHER;
+  } else {
+    op_name = ALL_GATHER;
+  }
   auto param_sub_set = manager->node_users()[parameter];
+  bool insert_flag = false;
   for (auto &param_pair : param_sub_set) {
     auto cnode = param_pair.first->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if (cnode->in_forward_flag()) {
       OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
       if (distribute_operator == nullptr) {
-        MS_LOG(WARNING) << "Parallel optimizer: " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
+        MS_LOG(WARNING) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr";
       } else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
         MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
                           << distribute_operator->inputs_tensor_info().size();
       }
-      // insert allgather operator between shard parameter and cnode
-      InsertAllGatherOp(root, opt_shard_group, param_pair, parameter);
-      MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString();
+      if (insert_flag) {
+        auto next_cnode = FindCNode(parameter, op_name, cnode->func_graph());
+        if (next_cnode.first) {
+          manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second);
+          MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
+                       << GetPrimName(cnode);
+          continue;
+        }
+      } else {
+        // insert allgather operator between shard parameter and cnode
+        InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name);
+        MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
+                     << GetPrimName(cnode);
+        insert_flag = true;
+      }
     }
   }
 }
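
For readers outside the MindSpore tree, the sketch below illustrates the pattern the last hunk implements: pick the gather-operator name once from the gradient-accumulation step, insert the operator for the first forward user of a parameter, and re-point every later user at the node that was already created instead of inserting a duplicate. It is a minimal stand-alone approximation under assumed simplifications; Node, InsertGather, FindGather and the other names are hypothetical stand-ins, not MindSpore APIs.

#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Simplified stand-in for a graph node; not a MindSpore type.
struct Node {
  std::string name;                           // parameter or operator name
  std::vector<std::shared_ptr<Node>> inputs;  // producer edges
};
using NodePtr = std::shared_ptr<Node>;

constexpr const char *kAllGather = "AllGather";
constexpr const char *kMiniStepAllGather = "MiniStepAllGather";

// Build a gather node of the requested kind over `param` (stand-in for InsertAllGatherOp).
NodePtr InsertGather(const NodePtr &param, const std::string &op_name) {
  auto gather = std::make_shared<Node>();
  gather->name = op_name;
  gather->inputs.push_back(param);
  return gather;
}

// Return an already-created gather of the requested kind over `param`, if any
// (stand-in for FindCNode).
NodePtr FindGather(const std::vector<NodePtr> &created, const NodePtr &param,
                   const std::string &op_name) {
  for (const auto &node : created) {
    if (node->name == op_name && !node->inputs.empty() && node->inputs[0] == param) {
      return node;
    }
  }
  return nullptr;
}

// Route every user of `param` through one shared gather node:
// create it for the first user, reuse it for the rest.
void ApplyParallelOptOnParam(const NodePtr &param, const std::vector<NodePtr> &users,
                             int64_t grad_accumulation_step) {
  const std::string op_name =
      grad_accumulation_step > 1 ? kMiniStepAllGather : kAllGather;
  bool insert_flag = false;
  std::vector<NodePtr> created;
  for (const auto &user : users) {
    NodePtr gather;
    if (insert_flag) {
      gather = FindGather(created, param, op_name);  // reuse the node made for the first user
    }
    if (gather == nullptr) {
      gather = InsertGather(param, op_name);         // first forward user: create it
      created.push_back(gather);
      insert_flag = true;
    }
    // SetEdge equivalent: replace the direct parameter input with the gather node.
    for (auto &input : user->inputs) {
      if (input == param) input = gather;
    }
    std::cout << "parallel optimizer applied between " << param->name << " and "
              << user->name << '\n';
  }
}

int main() {
  auto param = std::make_shared<Node>(Node{"layer.weight", {}});
  std::vector<NodePtr> users = {
      std::make_shared<Node>(Node{"MatMul", {param}}),
      std::make_shared<Node>(Node{"Cast", {param}}),
  };
  ApplyParallelOptOnParam(param, users, /*grad_accumulation_step=*/4);
  // Both users now read from the same MiniStepAllGather node.
  std::cout << (users[0]->inputs[0] == users[1]->inputs[0] ? "shared gather\n"
                                                           : "separate gathers\n");
  return 0;
}

In the real pass, FindCNode plays the role of FindGather and manager->SetEdge re-points the consumer's input edge, but the effect is the same: one gather node per sharded parameter, shared by all of its forward users.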