From: @lichen666 Reviewed-by: @stsuteng Signed-off-by: @stsuteng tags/v1.2.0-rc1
| @@ -69,6 +69,8 @@ namespace mindspore { | |||
| namespace session { | |||
| const size_t kInvalidIndex = SIZE_MAX; | |||
| constexpr size_t kReturnDataIndex = 1; | |||
| constexpr char SR_TAG[] = "sr_tag"; | |||
| constexpr char BACKWARD[] = "backward"; | |||
| namespace { | |||
| void DumpGraphExeOrder(const std::vector<CNodePtr> &execution_order, const std::string &tag = "") { | |||
| MS_LOG(INFO) << "Dump execution_order size " << execution_order.size(); | |||
| @@ -460,6 +462,90 @@ GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNode | |||
| return graph_id; | |||
| } | |||
| // Returns true when the primitive stored in input(0) of cnode carries the BACKWARD attribute. | |||
| bool IsBackward(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| // Raise a framework exception instead of dereferencing a null primitive. | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| return prim->HasAttr(BACKWARD); | |||
| } | |||
| // Ordering predicate for send/recv nodes: true when the sr_tag attribute of node1 is strictly | |||
| // smaller than the sr_tag attribute of node2. Raises if either node has no primitive or no sr_tag. | |||
| bool comp(const CNodePtr &node1, const CNodePtr &node2) { | |||
| MS_EXCEPTION_IF_NULL(node1); | |||
| MS_EXCEPTION_IF_NULL(node2); | |||
| auto prim1 = GetValueNode<PrimitivePtr>(node1->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim1); | |||
| // The second primitive must come from node2; reading node1 twice makes every comparison false. | |||
| auto prim2 = GetValueNode<PrimitivePtr>(node2->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim2); | |||
| auto sr_tag_value1 = prim1->GetAttr(SR_TAG); | |||
| MS_EXCEPTION_IF_NULL(sr_tag_value1); | |||
| auto sr_tag_value2 = prim2->GetAttr(SR_TAG); | |||
| MS_EXCEPTION_IF_NULL(sr_tag_value2); | |||
| auto sr_tag1 = GetValue<int64_t>(sr_tag_value1); | |||
| auto sr_tag2 = GetValue<int64_t>(sr_tag_value2); | |||
| return sr_tag1 < sr_tag2; | |||
| } | |||
| // Reorder the execution order of send. | |||
| // Every node of op_v is gathered at the position currently held by the last element of op_v | |||
| // (as ordered on entry), and the gathered group is sorted with comp. | |||
| // execution_order: sequence modified in place. op_v: send nodes to gather (taken by value | |||
| // because it is sorted locally). | |||
| void ReorderSend(std::vector<CNodePtr> *execution_order, std::vector<CNodePtr> op_v) { | |||
| MS_EXCEPTION_IF_NULL(execution_order); | |||
| // back() on an empty vector is undefined behaviour; there is nothing to reorder anyway. | |||
| if (op_v.empty()) { | |||
| return; | |||
| } | |||
| // Remember the anchor before sorting, since sorting changes which element is last. | |||
| auto last_node = op_v.back(); | |||
| for (auto &node : op_v) { | |||
| if (node == last_node) { | |||
| continue; | |||
| } | |||
| auto node_iter = std::find(execution_order->begin(), execution_order->end(), node); | |||
| // erase(end()) is undefined behaviour, so only erase nodes that were actually found. | |||
| if (node_iter != execution_order->end()) { | |||
| (void)execution_order->erase(node_iter); | |||
| } | |||
| } | |||
| std::sort(op_v.begin(), op_v.end(), comp); | |||
| auto last_node_iter = std::find(execution_order->begin(), execution_order->end(), last_node); | |||
| if (last_node_iter == execution_order->end()) { | |||
| MS_LOG(EXCEPTION) << "The last send node is not in the execution order."; | |||
| } | |||
| auto node_iter = execution_order->erase(last_node_iter); | |||
| // All sends are inserted at the position the last send node occupied. | |||
| execution_order->insert(node_iter, op_v.begin(), op_v.end()); | |||
| } | |||
| // Reorder the execution order of receive. | |||
| // Every node of op_v is gathered at the position currently held by the first element of op_v | |||
| // (as ordered on entry), and the gathered group is sorted with comp. | |||
| // execution_order: sequence modified in place. op_v: receive nodes to gather (taken by value | |||
| // because it is sorted locally). | |||
| void ReorderRecv(std::vector<CNodePtr> *execution_order, std::vector<CNodePtr> op_v) { | |||
| MS_EXCEPTION_IF_NULL(execution_order); | |||
| // front() on an empty vector is undefined behaviour; there is nothing to reorder anyway. | |||
| if (op_v.empty()) { | |||
| return; | |||
| } | |||
| // Remember the anchor before sorting, since sorting changes which element is first. | |||
| auto begin_node = op_v.front(); | |||
| for (auto &node : op_v) { | |||
| if (node == begin_node) { | |||
| continue; | |||
| } | |||
| auto node_iter = std::find(execution_order->begin(), execution_order->end(), node); | |||
| // erase(end()) is undefined behaviour, so only erase nodes that were actually found. | |||
| if (node_iter != execution_order->end()) { | |||
| (void)execution_order->erase(node_iter); | |||
| } | |||
| } | |||
| std::sort(op_v.begin(), op_v.end(), comp); | |||
| auto begin_node_iter = std::find(execution_order->begin(), execution_order->end(), begin_node); | |||
| if (begin_node_iter == execution_order->end()) { | |||
| MS_LOG(EXCEPTION) << "The first receive node is not in the execution order."; | |||
| } | |||
| auto node_iter = execution_order->erase(begin_node_iter); | |||
| // All receives are inserted at the position the first receive node occupied. | |||
| execution_order->insert(node_iter, op_v.begin(), op_v.end()); | |||
| } | |||
| // Splits the Send/Receive kernels of execution_order into forward and backward groups | |||
| // (decided by IsBackward) and reorders each non-empty group in place. | |||
| void ReorderSendRecv(std::vector<CNodePtr> *execution_order) { | |||
| std::vector<CNodePtr> forward_send; | |||
| std::vector<CNodePtr> forward_recv; | |||
| std::vector<CNodePtr> backward_send; | |||
| std::vector<CNodePtr> backward_recv; | |||
| for (auto &kernel : *execution_order) { | |||
| if (IsPrimitiveCNode(kernel, prim::kPrimSend)) { | |||
| auto &bucket = IsBackward(kernel) ? backward_send : forward_send; | |||
| bucket.push_back(kernel); | |||
| } else if (IsPrimitiveCNode(kernel, prim::kPrimReceive)) { | |||
| auto &bucket = IsBackward(kernel) ? backward_recv : forward_recv; | |||
| bucket.push_back(kernel); | |||
| } | |||
| } | |||
| // Sends are handled first (forward, then backward), followed by receives in the same order. | |||
| if (!forward_send.empty()) { | |||
| ReorderSend(execution_order, forward_send); | |||
| } | |||
| if (!backward_send.empty()) { | |||
| ReorderSend(execution_order, backward_send); | |||
| } | |||
| if (!forward_recv.empty()) { | |||
| ReorderRecv(execution_order, forward_recv); | |||
| } | |||
| if (!backward_recv.empty()) { | |||
| ReorderRecv(execution_order, backward_recv); | |||
| } | |||
| } | |||
| GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { | |||
| MS_LOG(INFO) << "Start"; | |||
| std::vector<KernelGraphPtr> all_graphs; | |||
| @@ -520,6 +606,11 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { | |||
| // adjust kernel | |||
| AdjustKernel(root_graph); | |||
| // reorder send/recv | |||
| auto execution_order = root_graph->execution_order(); | |||
| ReorderSendRecv(&execution_order); | |||
| root_graph->set_execution_order(execution_order); | |||
| #if ENABLE_CPU && ENABLE_D | |||
| InitPsWorker(root_graph); | |||
| #endif | |||
| @@ -28,6 +28,7 @@ | |||
| #include "frontend/parallel/context.h" | |||
| #include "frontend/parallel/step_parallel.h" | |||
| #include "frontend/parallel/node_check.h" | |||
| #include "frontend/parallel/graph_util/node_info.h" | |||
| #include "ir/anf.h" | |||
| #include "base/core_ops.h" | |||
| #include "utils/comm_manager.h" | |||
| @@ -51,12 +52,37 @@ static bool IsInWhiteList(const CNodePtr &cnode) { | |||
| return false; | |||
| } | |||
| // Marks every direct and indirect user of node as requiring gradient. | |||
| // A user that is already marked is not visited again, which also ends the recursion. | |||
| // node_users_map is taken by const reference so the map is not copied at every recursion level. | |||
| static void SetGradTag(const AnfNodePtr &node, const NodeUsersMap &node_users_map) { | |||
| // Lookup without operator[] so that no empty entry is inserted for nodes without users. | |||
| const auto users_iter = node_users_map.find(node); | |||
| if (users_iter == node_users_map.end()) { | |||
| return; | |||
| } | |||
| for (const auto &user_pair : users_iter->second) { | |||
| const auto &user_node = user_pair.first; | |||
| if (!user_node->grad()) { | |||
| user_node->set_grad(true); | |||
| SetGradTag(user_node, node_users_map); | |||
| } | |||
| } | |||
| } | |||
| // Propagates the grad flag from every root parameter accepted by ParameterRequireGrad | |||
| // to the users of that parameter. | |||
| void PipelineTransformer::LabelRequiredGradCNode() { | |||
| const auto &root_params = root_->parameters(); | |||
| const auto &users_map = manager_->node_users(); | |||
| for (const auto &param : root_params) { | |||
| if (ParameterRequireGrad(param)) { | |||
| SetGradTag(param, users_map); | |||
| } | |||
| } | |||
| } | |||
| void PipelineTransformer::Coloring() { | |||
| auto need_coloring = true; | |||
| std::set<int64_t> stage_set; | |||
| while (need_coloring) { | |||
| need_coloring = false; | |||
| for (auto &fg : manager_->func_graphs()) { | |||
| if (fg == root_) { | |||
| continue; | |||
| } | |||
| auto value_nodes = fg->value_nodes(); | |||
| for (auto &value_pair : value_nodes) { | |||
| auto node = value_pair.first; | |||
| @@ -64,10 +90,12 @@ void PipelineTransformer::Coloring() { | |||
| continue; | |||
| } | |||
| auto graph = GetValueNode<FuncGraphPtr>(node); | |||
| auto need_grad = graph->get_return()->grad(); | |||
| auto node_users = manager_->node_users()[node]; | |||
| for (auto &user_pair : node_users) { | |||
| auto user_node = user_pair.first->cast<CNodePtr>(); | |||
| user_node->set_stage(graph->stage()); | |||
| user_node->set_grad(need_grad); | |||
| auto user_node_graph = user_node->func_graph(); | |||
| if (graph->stage() != -1) { | |||
| stage_set.insert(graph->stage()); | |||
| @@ -90,7 +118,11 @@ void PipelineTransformer::Coloring() { | |||
| // Runs DoBroadCast followed by SetNoStageNode on every managed graph that is not the root | |||
| // graph and whose stage is not -1. | |||
| void PipelineTransformer::BroadCastColoring() { | |||
| for (auto &graph : manager_->func_graphs()) { | |||
| const bool is_staged_subgraph = (graph != root_) && (graph->stage() != -1); | |||
| if (is_staged_subgraph) { | |||
| DoBroadCast(graph); | |||
| SetNoStageNode(graph); | |||
| } | |||
| } | |||
| } | |||
| @@ -190,32 +222,17 @@ void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) { | |||
| while (need_coloring) { | |||
| need_coloring = false; | |||
| auto all_nodes = func->nodes(); | |||
| auto node_users = manager_->node_users(); | |||
| for (auto &node : all_nodes) { | |||
| // only cnode can broadcast color. | |||
| if (!node->isa<CNode>()) { | |||
| if (node->isa<CNode>() || node->stage() == -1) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode->stage() == -1) { | |||
| // broadcast from inputs to outputs | |||
| for (auto &input : cnode->inputs()) { | |||
| if (input->isa<CNode>() && input->stage() == stage_) { | |||
| cnode->set_stage(input->stage()); | |||
| need_coloring = true; | |||
| } | |||
| } | |||
| } else if (cnode->stage() == stage_) { | |||
| // broadcast from outputs to inputs | |||
| for (auto &input : cnode->inputs()) { | |||
| if (input->stage() != -1 || !input->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto input_cnode = input->cast<CNodePtr>(); | |||
| auto prim = GetValueNode<PrimitivePtr>(input_cnode->input(0)); | |||
| if (prim != nullptr && prim->name() == VIRTUAL_DATA_SET) { | |||
| continue; | |||
| } | |||
| input->set_stage(cnode->stage()); | |||
| auto stage = node->stage(); | |||
| for (auto &user_pair : node_users[node]) { | |||
| auto user_node = user_pair.first->cast<CNodePtr>(); | |||
| auto user_node_stage = user_node->stage(); | |||
| if (IsValueNode<FuncGraph>(user_node->input(0)) && stage > user_node_stage) { | |||
| user_node->set_stage(stage); | |||
| need_coloring = true; | |||
| } | |||
| } | |||
| @@ -223,6 +240,16 @@ void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) { | |||
| } | |||
| } | |||
| // Assigns stage 0 to every CNode of func whose stage is still -1. | |||
| void PipelineTransformer::SetNoStageNode(const FuncGraphPtr &func) { | |||
| auto graph_nodes = func->nodes(); | |||
| for (auto &candidate : graph_nodes) { | |||
| const bool unstaged_cnode = candidate->isa<CNode>() && candidate->stage() == -1; | |||
| if (unstaged_cnode) { | |||
| candidate->set_stage(0); | |||
| } | |||
| } | |||
| } | |||
| void PipelineTransformer::HandleSharedParameter() { | |||
| auto parameters = root_->parameters(); | |||
| for (auto ¶meter : parameters) { | |||
| @@ -412,7 +439,12 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode | |||
| if (node->isa<Parameter>()) { | |||
| recv_input = {NewValueNode(recv_op), node}; | |||
| } else { | |||
| recv_input = {NewValueNode(recv_op), virtual_param_}; | |||
| if (node->grad()) { | |||
| recv_input = {NewValueNode(recv_op), virtual_param_}; | |||
| } else { | |||
| auto param = root_->parameters()[0]; | |||
| recv_input = {NewValueNode(recv_op), param}; | |||
| } | |||
| } | |||
| auto recv = graph->NewCNode(recv_input); | |||
| auto node_abstract = node->abstract(); | |||
| @@ -505,7 +537,11 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) { | |||
| manager_->Replace(graph->output(), out_input[1]); | |||
| } | |||
| if (out_input.size() > 2) { | |||
| auto out_node = graph->NewCNode(out_input); | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | |||
| make_tuple_inputs.insert(make_tuple_inputs.begin() + 1, out_input.begin() + 2, out_input.end()); | |||
| auto make_tuple = graph->NewCNode(make_tuple_inputs); | |||
| std::vector<AnfNodePtr> out_depend_inputs = {out_input[0], out_input[1], make_tuple}; | |||
| auto out_node = graph->NewCNode(out_depend_inputs); | |||
| manager_->Replace(graph->output(), out_node); | |||
| } | |||
| } | |||
| @@ -47,6 +47,7 @@ class PipelineTransformer { | |||
| global_rank_(global_rank), | |||
| per_stage_rank_num_(per_stage_rank_num) {} | |||
| virtual ~PipelineTransformer() = default; | |||
| void LabelRequiredGradCNode(); | |||
| void Coloring(); | |||
| void BroadCastColoring(); | |||
| void HandleSharedParameter(); | |||
| @@ -63,6 +64,7 @@ class PipelineTransformer { | |||
| int64_t node_stage); | |||
| void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index, | |||
| int64_t user_node_stage, int64_t node_stage); | |||
| void SetNoStageNode(const FuncGraphPtr &func); | |||
| void CutBorder(const FuncGraphPtr &graph); | |||
| bool IsStageNode(const CNodePtr &node); | |||
| AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node); | |||
| @@ -1291,7 +1291,7 @@ std::pair<AnfNodePtr, int64_t> FindParallelCareNode(const AnfNodePtr &node, int3 | |||
| MS_EXCEPTION_IF_NULL(prim_node_anf); | |||
| PrimitivePtr node_prim = prim_node_anf->value()->cast<PrimitivePtr>(); | |||
| MS_EXCEPTION_IF_NULL(node_prim); | |||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||
| if ((node_prim->name() == DEPEND && node_pair.second != 1) || IsPrimitiveCNode(cnode, prim::kPrimReceive)) { | |||
| continue; | |||
| } | |||
| if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) { | |||
| @@ -90,6 +90,7 @@ bool PipelineSplit(const ResourcePtr &res) { | |||
| auto transformer = | |||
| std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num); | |||
| // step1: Do color graph | |||
| transformer->LabelRequiredGradCNode(); | |||
| transformer->Coloring(); | |||
| // step2: Do color broadcast | |||
| transformer->BroadCastColoring(); | |||
| @@ -100,7 +100,8 @@ class AnfNode : public Base { | |||
| fullname_with_scope_(""), | |||
| hash_(std::hash<const AnfNode *>()), | |||
| kernel_info_(nullptr), | |||
| stage_(-1) { | |||
| stage_(-1), | |||
| need_grad_(false) { | |||
| scope_ = ScopeManager::GetInstance().GetCurrentScope(); | |||
| } | |||
| @@ -190,6 +191,9 @@ class AnfNode : public Base { | |||
| // Returns the stage stored on this node; the constructor initialises it to -1. | |||
| int64_t stage() const { return stage_; } | |||
| // Stores the stage on this node. Takes int64_t, the type of stage_, so values are not narrowed. | |||
| void set_stage(int64_t stage) { stage_ = stage; } | |||
| // Returns the need_grad_ flag of this node; the constructor initialises it to false. | |||
| bool grad() const { return need_grad_; } | |||
| // Sets the need_grad_ flag of this node. | |||
| void set_grad(bool need_grad) { need_grad_ = need_grad; } | |||
| protected: | |||
| // Hold a weak ref to Graph as Graph also hold ref to AnfNode. | |||
| // Otherwise, func_graph_ and AnfNode will make a reference cycle. | |||
| @@ -205,6 +209,7 @@ class AnfNode : public Base { | |||
| KernelInfoDevicePtr kernel_info_; | |||
| UserData user_data_; | |||
| int64_t stage_; | |||
| bool need_grad_; | |||
| }; | |||
| // CNode represents the complex node with a set of arguments. | |||
| @@ -638,7 +638,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP | |||
| (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr ¶m) -> void { | |||
| MS_EXCEPTION_IF_NULL(param); | |||
| TraceGuard trace_guard(std::make_shared<TraceCopy>(param->debug_info())); | |||
| (void)new_func_graph->add_parameter(); | |||
| (void)new_func_graph->add_parameter()->set_abstract(param->abstract()); | |||
| }); | |||
| Cloner cloner = Cloner(); | |||
| @@ -85,6 +85,7 @@ def get_bprop_send(self): | |||
| shape = self.get_attr_dict()["shape"] | |||
| dtype = self.get_attr_dict()["dtype"] | |||
| send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group) | |||
| send_grad.add_prim_attr("backward", True) | |||
| def bprop(x, out, dout): | |||
| dx = send_grad() | |||
| @@ -96,6 +97,7 @@ def get_bprop_send(self): | |||
| def get_bprop_receive(self): | |||
| """Generate bprop for Receive.""" | |||
| receive_grad = Send(self.tag, self.rank, self.group) | |||
| receive_grad.add_prim_attr("backward", True) | |||
| depend = P.Depend() | |||
| cast = P.Cast() | |||
| @@ -21,7 +21,7 @@ from ... import context | |||
| from ...common import dtype as mstype | |||
| from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register | |||
| from ..operations.math_ops import _infer_shape_reduce | |||
| from ...communication.management import get_rank, GlobalComm, _get_group | |||
| from ...communication.management import GlobalComm | |||
| class ExtractImagePatches(PrimitiveWithInfer): | |||
| @@ -409,7 +409,7 @@ class Send(PrimitiveWithInfer): | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP): | |||
| self.rank = get_rank(_get_group(group)) | |||
| self.rank = dest_rank | |||
| self.sr_tag = sr_tag | |||
| self.group = group | |||
| @@ -465,7 +465,7 @@ class Receive(PrimitiveWithInfer): | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP): | |||
| self.rank = get_rank(_get_group(group)) | |||
| self.rank = src_rank | |||
| self.tag = sr_tag | |||
| self.shape = shape | |||
| self.dtype = dtype | |||