From: @huangbingjian Reviewed-by: Signed-off-by:pull/14612/MERGE
| @@ -108,7 +108,7 @@ bool InputCheck(const AnfNodePtr &node) { | |||
| MS_LOG(INFO) << "Data->TransData->split, can not optimizer."; | |||
| return false; | |||
| } | |||
| if (in_node_name == prim::kPrimControlDepend->name() || in_node_name == prim::kPrimDepend->name()) { | |||
| if (in_node_name == prim::kPrimDepend->name()) { | |||
| return false; | |||
| } | |||
| if ((AnfAlgo::HasNodeAttr("non_task", in_cnode) && AnfAlgo::GetNodeAttr<bool>(in_node, "non_task")) || | |||
| @@ -131,7 +131,7 @@ bool OutputCheck(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| return false; | |||
| } | |||
| for (const auto &item : outputs) { | |||
| if (IsPrimitiveCNode(item, prim::kPrimControlDepend) || IsPrimitiveCNode(item, prim::kPrimDepend)) { | |||
| if (IsPrimitiveCNode(item, prim::kPrimDepend)) { | |||
| MS_LOG(INFO) << "Split has control edge, can not optimizer."; | |||
| return false; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -168,7 +168,7 @@ const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, c | |||
| (void)manager->Replace(output, bn_training_update_v2_outputs[output_index]); | |||
| output_index++; | |||
| } | |||
| // Return the new node for control depends. | |||
| // Return the new node. | |||
| return bn_training_update_v2; | |||
| } | |||
| } // namespace opt | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -201,20 +201,6 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f | |||
| } | |||
| } | |||
| } | |||
| if (dropout_do_mask1 != nullptr) { | |||
| // Dropout is used by ControlDepend in some situation, need to replace ControlDepend. | |||
| auto &users = manager->node_users(); | |||
| iter = users.find(dropout_node); | |||
| if (iter != users.end()) { | |||
| for (auto &node_index : iter->second) { | |||
| auto used_node = node_index.first; | |||
| if (AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimControlDepend)) { | |||
| (void)manager->Replace(used_node, dropout_do_mask1); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // CreateDropoutDoMask-backward | |||
| if (equiv->find(grad_input_) == equiv->end()) { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -426,9 +426,6 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOu | |||
| } | |||
| auto output_info_list = iter->second; | |||
| for (const auto &output_info : output_info_list) { | |||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && | |||
| output_info.second == kDependAttachNodeIndex) { | |||
| continue; | |||
| @@ -908,16 +905,12 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| // find BatchNorm's output which is a Depend or ControlDepend | |||
| // find BatchNorm's output which is a Depend | |||
| for (const auto &node_index : manager->node_users()[old_node]) { | |||
| AnfNodePtr output = node_index.first; | |||
| size_t index = IntToSize(node_index.second); | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { | |||
| auto control_depend = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(control_depend); | |||
| control_depend->set_input(index, new_node); | |||
| } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) { | |||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) { | |||
| auto depend = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(depend); | |||
| depend->set_input(index, new_node); | |||
| @@ -210,7 +210,7 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor | |||
| // Create a new value node of func graph,not kernel graph | |||
| ValueNodePtr MakeValueNode(const ValueNodePtr &value_node); | |||
| // Transfer depend or control_depend to the new node | |||
| // Transfer depend to the new node | |||
| void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node); | |||
| AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list); | |||
| @@ -327,7 +327,7 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, | |||
| void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, | |||
| const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) { | |||
| // Create depend node to hold new control depend node. | |||
| // Create depend node to hold execution order. | |||
| AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node}; | |||
| auto depend_cnode = main_graph->NewCNode(d_inputs); | |||
| depend_cnode->set_abstract(clean_node->abstract()); | |||
| @@ -501,18 +501,17 @@ bool AtomicCleanInsertter::IsExistStructuralObstacle(const KernelGraphPtr &main_ | |||
| const FuncGraphManagerPtr &mng) { | |||
| auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false); | |||
| // If reduce user is MakeTuple and not last node, there is no cheap method to set right running order between reduce | |||
| // node and user node. If reduce is Depend or ControlDepend node, the origin node may be wrong! | |||
| return std::all_of(reduce_users.cbegin(), reduce_users.cend(), | |||
| [&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool { | |||
| auto &user = user_info.first; | |||
| if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend) || | |||
| IsPrimitiveCNode(user, prim::kPrimControlDepend)) && | |||
| !(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) { | |||
| return false; | |||
| } else { | |||
| return true; | |||
| } | |||
| }); | |||
| // node and user node. If reduce is Depend node, the origin node may be wrong! | |||
| return std::all_of( | |||
| reduce_users.cbegin(), reduce_users.cend(), [&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool { | |||
| auto &user = user_info.first; | |||
| if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend)) && | |||
| !(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) { | |||
| return false; | |||
| } else { | |||
| return true; | |||
| } | |||
| }); | |||
| } | |||
| bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { | |||
| @@ -123,9 +123,9 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr | |||
| bool changed = false; | |||
| auto mng = kernel_graph->manager(); | |||
| // depend_prior[depend] = pair(prior, controlDependNode) | |||
| // depend_prior[depend] = pair(prior, behind) | |||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior; | |||
| InitDependPrior(todos, &depend_prior); | |||
| // InitDependPrior(todos, &depend_prior); | |||
| for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { | |||
| auto node = (*iter)->cast<CNodePtr>(); | |||
| @@ -657,76 +657,6 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) { | |||
| #endif | |||
| } | |||
| void InitDependPrior(const std::vector<AnfNodePtr> &todos, | |||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) { | |||
| for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { | |||
| auto cnode = (*iter)->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| continue; | |||
| } | |||
| if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimControlDepend)) { | |||
| continue; | |||
| } | |||
| auto prior_node = cnode->input(kControlDependPriorIndex); | |||
| auto depend_node = cnode->input(kControlDependBehindIndex); | |||
| MS_EXCEPTION_IF_NULL(prior_node); | |||
| MS_EXCEPTION_IF_NULL(depend_node); | |||
| std::vector<AnfNodePtr> prior_nodes = {prior_node}; | |||
| std::vector<AnfNodePtr> depend_nodes = {depend_node}; | |||
| int depend_mode = 0; | |||
| if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { | |||
| depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode); | |||
| } | |||
| auto GetOutputNodes = [cnode](const AnfNodePtr ¶m) -> std::vector<AnfNodePtr> { | |||
| std::vector<AnfNodePtr> out_nodes; | |||
| auto user_set = param->func_graph()->manager()->node_users()[param]; | |||
| for (auto iter = user_set.cbegin(); iter != user_set.cend(); ++iter) { | |||
| if (iter->first != cnode) { | |||
| out_nodes.push_back(iter->first); | |||
| } | |||
| } | |||
| return out_nodes; | |||
| }; | |||
| if (prior_node->isa<Parameter>() && depend_mode == 1) { | |||
| prior_nodes = GetOutputNodes(prior_node); | |||
| } | |||
| if (depend_node->isa<Parameter>()) { | |||
| depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{}; | |||
| } | |||
| std::vector<AnfNodePtr> real_prior_nodes; | |||
| std::set<AnfNodePtr> prior_visited; | |||
| for (const auto &tmp : prior_nodes) { | |||
| AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); | |||
| } | |||
| prior_visited.clear(); | |||
| std::vector<AnfNodePtr> real_depend_nodes; | |||
| std::set<AnfNodePtr> depend_visited; | |||
| for (const auto &tmp : depend_nodes) { | |||
| AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); | |||
| } | |||
| depend_visited.clear(); | |||
| for (auto &prior : real_prior_nodes) { | |||
| if (AnfAlgo::CheckPrimitiveType(prior, prim::kPrimControlDepend)) { | |||
| continue; | |||
| } | |||
| for (auto &depend : real_depend_nodes) { | |||
| if (AnfAlgo::CheckPrimitiveType(depend, prim::kPrimControlDepend)) { | |||
| continue; | |||
| } | |||
| depend_prior->insert({depend, std::make_pair(prior, cnode)}); | |||
| } | |||
| } | |||
| real_prior_nodes.clear(); | |||
| real_depend_nodes.clear(); | |||
| } | |||
| } | |||
| void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior, | |||
| const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) { | |||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri; | |||
| @@ -75,8 +75,6 @@ std::vector<PrimitivePtr> GetFusibleOpList(); | |||
| bool IsBasicFuseOp(const AnfNodePtr &node); | |||
| bool IsFusibleOp(const AnfNodePtr &node); | |||
| void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); | |||
| void InitDependPrior(const std::vector<AnfNodePtr> &todos, | |||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior); | |||
| void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior, | |||
| const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -55,7 +55,7 @@ CNodePtr GetRealPrevCNode(const AnfNodePtr &node, size_t index, std::vector<Kern | |||
| auto item_idx = GetValue<int64_t>(value_node->value()); | |||
| pass_vector->push_back(make_pair(cnode, IntToSize(1))); | |||
| return GetRealPrevCNode(cnode->input(1), LongToSize(item_idx), pass_vector); | |||
| } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { | |||
| } else if (IsPrimitive(input0, prim::kPrimDepend)) { | |||
| pass_vector->push_back(make_pair(cnode, IntToSize(1))); | |||
| return GetRealPrevCNode(cnode->input(1), 0, pass_vector); | |||
| } else if (IsPrimitive(input0, prim::kPrimUpdateState)) { | |||
| @@ -92,8 +92,7 @@ const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNode | |||
| auto pass_size = pass_vector->size(); | |||
| for (size_t idx = 1; idx <= pass_size - 1; ++idx) { | |||
| auto nd = (*pass_vector)[idx].first; | |||
| if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend) || | |||
| AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { | |||
| if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend)) { | |||
| has_depend_node = true; | |||
| } | |||
| if (users[nd].size() >= 2) { | |||
| @@ -248,7 +248,7 @@ class AnfRuntimeAlgorithm { | |||
| static void InferShape(const CNodePtr &node); | |||
| static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); | |||
| static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); | |||
| // Find control_depend real input nodes. | |||
| // Find real input nodes. | |||
| static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, | |||
| std::set<AnfNodePtr> *visited); | |||
| }; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -534,14 +534,17 @@ void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNul | |||
| return_node->set_input(kFirstDataInputIndex, depend_node); | |||
| } | |||
| void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node, | |||
| NotNull<AnfNodePtr> second_node) { | |||
| MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() | |||
| << ", the second node is " << second_node->DebugString(); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())), | |||
| first_node, second_node}; | |||
| auto control_depend = kg->NewCNode(inputs); | |||
| InsertDependToGraph(kg, NOT_NULL(control_depend)); | |||
| void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> prior_node, | |||
| NotNull<AnfNodePtr> behind_node) { | |||
| MS_LOG(INFO) << "Insert control dependence at the end of graph, the prior node is " << prior_node->DebugString() | |||
| << ", the behind node is " << behind_node->DebugString(); | |||
| auto manager = kg->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| AnfNodePtrList inputs = {NewValueNode(prim::kPrimDepend), behind_node, prior_node}; | |||
| auto depend_cnode = kg->NewCNode(inputs); | |||
| if (!manager->Replace(behind_node, depend_cnode)) { | |||
| MS_LOG(EXCEPTION) << behind_node->DebugString() << ", replace node failed."; | |||
| } | |||
| } | |||
| void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -422,7 +422,7 @@ void KernelGraph::CheckLoop() { | |||
| none_zero_nodes[it.first] = it.second; | |||
| } | |||
| } | |||
| // if don't consider control depend and loop exit,a exception will be throw | |||
| // if don't consider loop exit,a exception will be throw | |||
| if (!none_zero_nodes.empty()) { | |||
| MS_LOG(WARNING) << "Nums of loop:" << GetLoopNum(none_zero_nodes); | |||
| MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); | |||
| @@ -815,61 +815,10 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) { | |||
| return output_nodes; | |||
| } | |||
| // update the depend relations of control depend | |||
| void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) { | |||
| for (const auto &node : depends) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| return; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { | |||
| MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend"; | |||
| } | |||
| auto prior_node = cnode->input(kControlDependPriorIndex); | |||
| auto depend_node = cnode->input(kControlDependBehindIndex); | |||
| MS_EXCEPTION_IF_NULL(prior_node); | |||
| MS_EXCEPTION_IF_NULL(depend_node); | |||
| std::vector<AnfNodePtr> prior_nodes = {prior_node}; | |||
| std::vector<AnfNodePtr> depend_nodes = {depend_node}; | |||
| int depend_mode = 0; | |||
| if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { | |||
| depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode); | |||
| } | |||
| MS_LOG(DEBUG) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString() | |||
| << "], depend_mode :" << depend_mode << "."; | |||
| if (prior_node->isa<Parameter>() && depend_mode == 1) { | |||
| prior_nodes = GetOutputNodes(prior_node); | |||
| } | |||
| if (depend_node->isa<Parameter>()) { | |||
| depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{}; | |||
| } | |||
| std::vector<AnfNodePtr> real_prior_nodes; | |||
| std::set<AnfNodePtr> prior_visited; | |||
| for (const auto &tmp : prior_nodes) { | |||
| AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); | |||
| } | |||
| std::vector<AnfNodePtr> real_depend_nodes; | |||
| std::set<AnfNodePtr> depend_visited; | |||
| for (const auto &tmp : depend_nodes) { | |||
| AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); | |||
| } | |||
| UpdateNodeInputOutputEdges(real_prior_nodes, real_depend_nodes); | |||
| } | |||
| } | |||
| void KernelGraph::UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes, | |||
| const std::vector<AnfNodePtr> &real_depend_nodes) { | |||
| for (auto &first_node : real_prior_nodes) { | |||
| if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { | |||
| continue; | |||
| } | |||
| for (auto &second_node : real_depend_nodes) { | |||
| if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { | |||
| continue; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(first_node); | |||
| MS_EXCEPTION_IF_NULL(second_node); | |||
| MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); | |||
| @@ -878,35 +827,6 @@ void KernelGraph::UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real | |||
| } | |||
| } | |||
| bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, | |||
| std::unordered_set<AnfNodePtr> *visited_nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(que); | |||
| MS_EXCEPTION_IF_NULL(visited_nodes); | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { | |||
| return false; | |||
| } | |||
| // set the control depend visited but don't push it into the que | |||
| if (visited_nodes->find(node) != visited_nodes->end()) { | |||
| return true; | |||
| } | |||
| (void)visited_nodes->insert(cnode); | |||
| // add a 0 depend num to keep the link relations to prepare for finding zero output nodes | |||
| auto prior_node = cnode->input(kControlDependPriorIndex); | |||
| auto depend_node = cnode->input(kControlDependBehindIndex); | |||
| for (const auto &input : cnode->inputs()) { | |||
| AddDependEdge(node, input, 0); | |||
| } | |||
| PushNoVisitedNode(depend_node, que, visited_nodes); | |||
| PushNoVisitedNode(prior_node, que, visited_nodes); | |||
| return true; | |||
| } | |||
| void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) { | |||
| MS_EXCEPTION_IF_NULL(seed_nodes); | |||
| node_output_edges_.clear(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -286,15 +286,11 @@ class KernelGraph : public FuncGraph { | |||
| std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true); | |||
| // update node edge list | |||
| void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes); | |||
| // add node depend edge by data edge or control depend | |||
| // add node depend edge by data edge | |||
| void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); | |||
| void UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes, | |||
| const std::vector<AnfNodePtr> &real_depend_nodes); | |||
| // handle control depend | |||
| std::vector<AnfNodePtr> GetOutputNodes(const AnfNodePtr &node); | |||
| bool HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, | |||
| std::unordered_set<AnfNodePtr> *visited_nodes); | |||
| void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends); | |||
| AnfNodePtr TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value); | |||
| AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract); | |||
| AnfNodePtr TransCNodeTuple(const CNodePtr &node); | |||
| @@ -223,10 +223,8 @@ BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &gra | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| VectorRef ret; | |||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||
| if (!AnfAlgo::CheckPrimitiveType(cnode->input(i), prim::kPrimControlDepend)) { | |||
| auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node); | |||
| ret.push_back(out); | |||
| } | |||
| auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node); | |||
| ret.push_back(out); | |||
| } | |||
| return ret; | |||
| } | |||
| @@ -386,22 +384,6 @@ bool ExistSummaryNode(const KernelGraph *graph) { | |||
| return false; | |||
| } | |||
| bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { | |||
| return false; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const auto &node_inputs = cnode->inputs(); | |||
| for (size_t i = 1; i < node_inputs.size(); ++i) { | |||
| if (!AnfAlgo::CheckPrimitiveType(node_inputs[i], prim::kPrimControlDepend)) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs, | |||
| std::map<AnfNodePtr, size_t> *parameter_index) { | |||
| size_t index = 0; | |||
| @@ -692,9 +674,6 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const | |||
| AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| if (IgnoreCreateParameterForMakeTuple(node)) { | |||
| return nullptr; | |||
| } | |||
| auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); | |||
| auto parameters = AnfAlgo::GetAllOutput(new_parameter); | |||
| std::vector<AnfNodePtr> pre_graph_out = {node}; | |||
| @@ -1872,9 +1851,6 @@ std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr | |||
| auto &users = front_func_graph_manager->node_users()[front_node]; | |||
| std::vector<AnfNodePtr> result; | |||
| for (auto &user : users) { | |||
| if (IsPrimitiveCNode(user.first, prim::kPrimControlDepend)) { | |||
| continue; | |||
| } | |||
| if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) { | |||
| auto depend_cnode = user.first->cast<CNodePtr>(); | |||
| if (depend_cnode == nullptr) { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -84,7 +84,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { | |||
| } | |||
| } | |||
| std::vector<PrimitivePtr> adapter_convert_ops = {prim::kPrimDepend, prim::kPrimControlDepend, prim::kPrimLoad}; | |||
| std::vector<PrimitivePtr> adapter_convert_ops = {prim::kPrimDepend, prim::kPrimLoad}; | |||
| for (auto &item : adapter_convert_ops) { | |||
| if (IsPrimitiveCNode(node, item)) { | |||
| return true; | |||
| @@ -243,8 +243,7 @@ CNodePtr MergeNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, int64_t sw | |||
| return merge_op; | |||
| } | |||
| // construct a depend node with merge output node, merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data)) | |||
| // control_depend(output_node, square_op) | |||
| // merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data)) | |||
| AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &output_node, | |||
| int64_t switch_idx) { | |||
| tensor::TensorPtr const_data = GetConstData(); | |||
| @@ -259,54 +258,21 @@ AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr | |||
| SetSquareOp(switch_idx, square_op); | |||
| } | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| AnfNodePtrList inputs = {NewValueNode(prim::kPrimDepend), square_op, output_node}; | |||
| auto depend_cnode = graph->NewCNode(inputs); | |||
| if (!manager->Replace(square_op, depend_cnode)) { | |||
| MS_LOG(EXCEPTION) << square_op->DebugString() << ", replace node failed."; | |||
| } | |||
| CNodePtr merge_op = GetMergeOp(switch_idx); | |||
| if (merge_op == nullptr) { | |||
| merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op); | |||
| SetMergeOp(switch_idx, merge_op); | |||
| } | |||
| std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), output_node, square_op}; | |||
| auto control_depend_op = graph->NewCNode(control_depend_nodes); | |||
| std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), merge_op, control_depend_op}; | |||
| auto depend_op = graph->NewCNode(depend_nodes); | |||
| return depend_op; | |||
| } | |||
| // construct a merge output and add dependency with the netoutput node from control_depend | |||
| // we need to reserve the control_depend node, besides the generated merge node and control_depend node | |||
| CNodePtr GenerateSwitchControlDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, | |||
| const AnfNodePtr &ctrl_dep_node, const AnfNodePtr &ctrl_depend_dst, | |||
| int64_t switch_idx) { | |||
| auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>(); | |||
| auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast<PrimitivePtr>(); | |||
| std::vector<int64_t> shp = {1}; | |||
| tensor::TensorPtr const_data = std::make_shared<tensor::Tensor>(kInt64->type_id(), shp); | |||
| auto *val = static_cast<int64_t *>(const_data->data_c()); | |||
| *val = 0; | |||
| // for the control_depend netoutput node , add two const data to merge the flow ,one for depended node with same | |||
| // switch the other use the opposite | |||
| auto ctrl_data = NewValueNode(const_data); | |||
| auto oppsite_ctrl_data = NewValueNode(const_data); | |||
| auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx); | |||
| auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx); | |||
| std::vector<AnfNodePtr> square_nodes{NewValueNode(PrimSquare), ctrl_node}; | |||
| auto square_op = graph->NewCNode(square_nodes); | |||
| std::vector<AnfNodePtr> merge_nodes; | |||
| merge_nodes.push_back(NewValueNode(PrimMerge)); | |||
| std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node}; | |||
| merge_nodes.push_back(graph->NewCNode(make_tuple_nodes)); | |||
| auto merge_output = graph->NewCNode(merge_nodes); | |||
| std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), ctrl_depend_dst, square_op}; | |||
| auto cond_dep_output = graph->NewCNode(control_depend_nodes); | |||
| std::vector<AnfNodePtr> depended_make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), ctrl_dep_node, merge_output, | |||
| cond_dep_output}; | |||
| return graph->NewCNode(depended_make_tuple_nodes); | |||
| return merge_op; | |||
| } | |||
| // generate switch nodes for true graph node inputs | |||
| @@ -321,26 +287,12 @@ AnfNodePtr GenerateSwitchDependFalseNode(const FuncGraphPtr &graph, const AnfNod | |||
| return GenerateSwitchDependNode(graph, cond, data, 0); | |||
| } | |||
| // generate switch nodes for true graph node inputs | |||
| CNodePtr GenerateSwitchControlDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, | |||
| const AnfNodePtr &con_input, const AnfNodePtr &output) { | |||
| // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch | |||
| return GenerateSwitchControlDependNode(graph, cond, con_input, output, 1); | |||
| } | |||
| // generate switch nodes for false graph node inputs | |||
| CNodePtr GenerateSwitchControlDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, | |||
| const AnfNodePtr &con_input, const AnfNodePtr &output) { | |||
| // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch | |||
| return GenerateSwitchControlDependNode(graph, cond, con_input, output, 0); | |||
| } | |||
| // to judge if the node used in ControlDepend is a net output node | |||
| // to judge if the node used in Depend is a net output node | |||
| bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { | |||
| auto uses = manager->node_users()[node]; | |||
| bool is_output_node = true; | |||
| for (auto &item : uses) { | |||
| if (IsPrimitiveCNode(item.first, prim::kPrimControlDepend) || IsPrimitiveCNode(item.first, prim::kPrimDepend)) { | |||
| if (IsPrimitiveCNode(item.first, prim::kPrimDepend)) { | |||
| continue; | |||
| } | |||
| is_output_node = false; | |||
| @@ -353,8 +305,7 @@ bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) | |||
| void GenerateReplNodeForDependMakeTuple( | |||
| const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond, | |||
| const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node, | |||
| const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func, | |||
| const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) { | |||
| const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) { | |||
| MS_EXCEPTION_IF_NULL(graph->manager()); | |||
| auto make_tuple_inputs = depended_node->cast<CNodePtr>()->inputs(); | |||
| @@ -368,26 +319,6 @@ void GenerateReplNodeForDependMakeTuple( | |||
| new_make_tuple_nodes.push_back(depended_tuple_input_node); | |||
| continue; | |||
| } | |||
| if (IsPrimitiveCNode(depended_tuple_input_node->cast<CNodePtr>(), prim::kPrimControlDepend)) { | |||
| // only when the control depend input is not square op (the op to use as merge output) | |||
| auto control_inputs = depended_tuple_input_node->cast<CNodePtr>()->inputs(); | |||
| if (control_inputs.size() != 3) { | |||
| MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); | |||
| } | |||
| // control inputs: primitive, src, dst | |||
| auto dst_node = control_inputs[2]; | |||
| if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { | |||
| auto gen_node = gen_ctl_depd_func(graph, cond, make_tuple_inputs[idx], dst_node); | |||
| MS_EXCEPTION_IF_NULL(gen_node); | |||
| auto tuple_inputs = gen_node->inputs(); | |||
| // add depended tuple inputs to new_make_tuple directly | |||
| for (size_t i = 1; i < tuple_inputs.size(); i++) { | |||
| new_make_tuple_nodes.push_back(tuple_inputs[i]); | |||
| } | |||
| } | |||
| replace_make_tuple = true; | |||
| continue; | |||
| } | |||
| if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) { | |||
| auto gen_node = generate_func(graph, cond, depended_tuple_input_node); | |||
| @@ -408,8 +339,7 @@ void GenerateReplNodeForDependMakeTuple( | |||
| void GenerateRepDepend( | |||
| const CNodePtr &node, const FuncGraphPtr &graph, const AnfNodePtr &cond, | |||
| const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node, | |||
| const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func, | |||
| const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) { | |||
| const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) { | |||
| auto inputs = node->inputs(); | |||
| if (inputs.size() != 3) { | |||
| MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node]."; | |||
| @@ -422,19 +352,7 @@ void GenerateRepDepend( | |||
| new_depened_inputs.push_back(inputs[1]); | |||
| // depended node should be make_tuple or a single depended node | |||
| if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) { | |||
| GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func, gen_ctl_depd_func); | |||
| } else if (IsPrimitiveCNode(depended_node, prim::kPrimControlDepend)) { | |||
| // only when the control depend input is not square op (the op to use as merge output) | |||
| auto control_inputs = depended_node->cast<CNodePtr>()->inputs(); | |||
| // control inputs: primitive, src, dst | |||
| if (control_inputs.size() != 3) { | |||
| MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); | |||
| } | |||
| auto dst_node = control_inputs[2]; | |||
| if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { | |||
| auto gen_node = gen_ctl_depd_func(graph, cond, depended_node, dst_node); | |||
| (*repl_node)[depended_node] = gen_node; | |||
| } | |||
| GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func); | |||
| } else { | |||
| // Check if there is only single user for depend_node. | |||
| if (graph->manager()->node_users()[depended_node].size() == 1) { | |||
| @@ -448,11 +366,9 @@ void GenerateRepDepend( | |||
| // generate depend node for netoutput node, to resolve the stream synchronize problem of ge | |||
| // traverse all nodes of depend node, find the graph output node , generaete a merge node of (square, const) | |||
| // and add control_depend of graph output node and square node. | |||
| FuncGraphPtr TransformGraphDependNode( | |||
| const FuncGraphPtr &graph, const AnfNodePtr &cond, | |||
| const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &gen_depend_func, | |||
| const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) { | |||
| const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &gen_depend_func) { | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| @@ -478,7 +394,7 @@ FuncGraphPtr TransformGraphDependNode( | |||
| if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) { | |||
| continue; | |||
| } | |||
| GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func, gen_ctl_depd_func); | |||
| GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func); | |||
| } | |||
| } | |||
| ResetSharedOp(); | |||
| @@ -494,12 +410,12 @@ FuncGraphPtr TransformGraphDependNode( | |||
| FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { | |||
| (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode); | |||
| return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode, GenerateSwitchControlDependTrueNode); | |||
| return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode); | |||
| } | |||
| FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { | |||
| (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode); | |||
| return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode, GenerateSwitchControlDependFalseNode); | |||
| return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode); | |||
| } | |||
| // judge if the true and false graph output is compatible(they shall have same tuple size) | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -218,10 +218,10 @@ bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) { | |||
| if (output_set_iter == node_users.end()) { | |||
| return false; | |||
| } | |||
| for (const auto &node_index_set : output_set_iter->second) { | |||
| if (!IsBpropNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) { | |||
| return true; | |||
| } | |||
| if (std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(), | |||
| [](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); })) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -367,7 +367,6 @@ constexpr char HISTOGRAMSUMMARY[] = "HistogramSummary"; | |||
| constexpr char DEBUG[] = "Debug"; | |||
| constexpr char BROADCASTGRADIENTARGS[] = "BroadcastGradientArgs"; | |||
| constexpr char INVERTPERMUTATION[] = "InvertPermutation"; | |||
| constexpr char CONTROLDEPEND[] = "ControlDepend"; | |||
| constexpr char DOT[] = "dot"; | |||
| constexpr char IM2COL[] = "im2col"; | |||
| constexpr char COL2IM[] = "col2im"; | |||
| @@ -259,10 +259,8 @@ BaseRef CreateOutputTensors(const AnfNodePtr &output_node, const KernelGraphPtr | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| VectorRef ret; | |||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||
| if (!AnfAlgo::CheckPrimitiveType(cnode->input(i), prim::kPrimControlDepend)) { | |||
| auto out = CreateOutputTensors(cnode->input(i), graph, input_tensors); | |||
| ret.push_back(out); | |||
| } | |||
| auto out = CreateOutputTensors(cnode->input(i), graph, input_tensors); | |||
| ret.push_back(out); | |||
| } | |||
| return ret; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1044,7 +1044,7 @@ bool DfGraphConvertor::IsControlEdgeNode(const AnfNodePtr &node) { | |||
| OperatorPtr DfGraphConvertor::ToOperatorPtr(const AnfNodePtr &node) { | |||
| auto op = Convert(GetRealOpNode(node)); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "Convert control depend node to operator failed, " << node->ToString(); | |||
| MS_LOG(ERROR) << "Convert real op node to operator failed, " << node->ToString(); | |||
| error_ = FAILED; | |||
| return nullptr; | |||
| } | |||
| @@ -1170,13 +1170,13 @@ void DfGraphConvertor::AutoMonadSetControlInput(const AnfNodePtr &node) { | |||
| void DfGraphConvertor::SetOpControlInput(const AnfNodePtr &node) { | |||
| AutoMonadSetControlInput(node); | |||
| if (control_depend_cache_.find(node.get()) == control_depend_cache_.end()) { | |||
| if (control_edge_cache_.find(node.get()) == control_edge_cache_.end()) { | |||
| return; | |||
| } | |||
| std::vector<ControlEdge> control_edges = control_depend_cache_[node.get()]; | |||
| std::vector<ControlEdge> control_edges = control_edge_cache_[node.get()]; | |||
| if ((control_edges.empty())) { | |||
| MS_LOG(ERROR) << "Get control depend node's src or dest operator failed"; | |||
| MS_LOG(ERROR) << "Get control edge node's src or dest operator failed"; | |||
| return; | |||
| } | |||
| @@ -1600,7 +1600,7 @@ std::vector<OperatorPtr> DfGraphConvertor::ConvertDependNode(const AnfNodePtr no | |||
| for (size_t index = 1; index < node_inputs.size(); index++) { | |||
| auto op = Convert(GetRealOpNode(node_inputs[index])); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "Convert control depend node to operator failed"; | |||
| MS_LOG(ERROR) << "Convert real op node to operator failed"; | |||
| error_ = FAILED; | |||
| return std::vector<OperatorPtr>({}); | |||
| } | |||
| @@ -1611,194 +1611,13 @@ std::vector<OperatorPtr> DfGraphConvertor::ConvertDependNode(const AnfNodePtr no | |||
| auto op = Convert(GetRealOpNode(node)); | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "Convert control depend node to operator failed"; | |||
| MS_LOG(ERROR) << "Convert real op node to operator failed"; | |||
| error_ = FAILED; | |||
| return std::vector<OperatorPtr>({}); | |||
| } | |||
| return std::vector<OperatorPtr>({op}); | |||
| } | |||
| // get the anf node list for depend | |||
| std::vector<AnfNodePtr> DfGraphConvertor::GetDependNodes(const AnfNodePtr &node) { | |||
| std::vector<AnfNodePtr> nodes; | |||
| // for make tuple, should control depend on the tuple items | |||
| if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { | |||
| auto node_inputs = node->cast<CNodePtr>()->inputs(); | |||
| for (size_t index = 1; index < node_inputs.size(); index++) { | |||
| nodes.push_back(GetRealOpNode(node_inputs[index])); | |||
| } | |||
| return nodes; | |||
| } | |||
| // for parameter ,find the apply that used the parameter as the control depended node | |||
| if (node->isa<Parameter>()) { | |||
| auto uses = node->func_graph()->manager()->node_users()[node]; | |||
| for (auto &use : uses) { | |||
| auto use_node = use.first; | |||
| if ((use_node->isa<CNode>()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { | |||
| nodes.push_back(GetRealOpNode(use_node)); | |||
| } | |||
| } | |||
| return nodes; | |||
| } | |||
| nodes.push_back(GetRealOpNode(node)); | |||
| return nodes; | |||
| } | |||
| void DfGraphConvertor::DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node) { | |||
| #ifdef DRAW_GE_GRAPH | |||
| auto src_depend_nodes = GetDependNodes(src_node); | |||
| auto dst_depend_nodes = GetDependNodes(dest_node); | |||
| if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() > 1) { | |||
| for (auto &item : dst_depend_nodes) { | |||
| compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[item.get()] | |||
| << "[style=\"dotted\"]" << endl; | |||
| } | |||
| } else if (src_depend_nodes.size() > 1 && dst_depend_nodes.size() == 1) { | |||
| for (auto &item : src_depend_nodes) { | |||
| compute_sout_ << op_draw_name_[item.get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] | |||
| << "[style=\"dotted\"]" << endl; | |||
| } | |||
| } else if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() == 1) { | |||
| compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] | |||
| << "[style=\"dotted\"]" << endl; | |||
| } | |||
| #endif | |||
| } | |||
| void DfGraphConvertor::GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, | |||
| const AnfNodePtr &dest_node, | |||
| const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list, | |||
| const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list) { | |||
| if (src_node->isa<Parameter>()) { | |||
| auto uses = node->func_graph()->manager()->node_users()[src_node]; | |||
| for (auto &use : uses) { | |||
| auto use_node = use.first; | |||
| if ((use_node->isa<CNode>()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && | |||
| (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { | |||
| auto converted_list = ConvertDependNode(use_node); | |||
| src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); | |||
| } | |||
| } | |||
| } | |||
| if (dest_node->isa<Parameter>()) { | |||
| auto uses = node->func_graph()->manager()->node_users()[dest_node]; | |||
| for (auto &use : uses) { | |||
| auto use_node = use.first; | |||
| if ((use_node->isa<CNode>()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && | |||
| (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { | |||
| auto converted_list = ConvertDependNode(use_node); | |||
| dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| bool DfGraphConvertor::GetControlDependList(const CNodePtr &node, | |||
| const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list, | |||
| const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list) { | |||
| const int CONTROL_DEPEND_INDEX = 0; | |||
| const int SRC_NODE_INDEX = 1; | |||
| const int DEST_NODE_INDEX = 2; | |||
| const int DEPEND_MODE_NORMAL_USE = 0; | |||
| const int DEPEND_MODE_ON_PARAMETER_USE = 1; | |||
| auto node_inputs = node->inputs(); | |||
| if (node_inputs.size() <= DEST_NODE_INDEX) { | |||
| MS_LOG(WARNING) << "Control depend node input size error"; | |||
| return false; | |||
| } | |||
| auto src_node = node_inputs[SRC_NODE_INDEX]; | |||
| auto dest_node = node_inputs[DEST_NODE_INDEX]; | |||
| if ((src_node == nullptr) || (dest_node == nullptr)) { | |||
| MS_LOG(ERROR) << "Control depend node miss src or dest node"; | |||
| error_ = FAILED; | |||
| return false; | |||
| } | |||
| AnfNodePtr fn = node_inputs[CONTROL_DEPEND_INDEX]; | |||
| PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(fn); | |||
| ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); | |||
| int depend_mode = DEPEND_MODE_NORMAL_USE; | |||
| if (mode_ptr != nullptr) { | |||
| auto mode_int = mode_ptr->cast<Int64ImmPtr>(); | |||
| MS_EXCEPTION_IF_NULL(mode_int); | |||
| depend_mode = mode_int->value(); | |||
| MS_LOG(DEBUG) << "depend_mode = " << depend_mode; | |||
| } | |||
| if (depend_mode == DEPEND_MODE_ON_PARAMETER_USE) { | |||
| GetDependOnParameterUse(node, src_node, dest_node, src_ops_list, dst_ops_list); | |||
| } | |||
| if (src_node->isa<CNode>()) { | |||
| auto converted_list = ConvertDependNode(src_node); | |||
| src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); | |||
| } | |||
| if (dest_node->isa<CNode>()) { | |||
| auto converted_list = ConvertDependNode(dest_node); | |||
| dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); | |||
| } | |||
| if (src_ops_list->empty() || dst_ops_list->empty()) { | |||
| MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it"; | |||
| error_ = SUCCESS; | |||
| } | |||
| return true; | |||
| } | |||
| void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) { | |||
| const int SRC_NODE_INDEX = 1; | |||
| const int DEST_NODE_INDEX = 2; | |||
| if (control_depend_cache_.find(node.get()) != control_depend_cache_.end()) { | |||
| return; | |||
| } | |||
| auto node_inputs = node->inputs(); | |||
| if (node_inputs.size() <= DEST_NODE_INDEX) { | |||
| MS_LOG(WARNING) << "Control depend node input size error"; | |||
| return; | |||
| } | |||
| auto src_node = node_inputs[SRC_NODE_INDEX]; | |||
| auto dest_node = node_inputs[DEST_NODE_INDEX]; | |||
| if ((src_node == nullptr) || (dest_node == nullptr)) { | |||
| MS_LOG(ERROR) << "Control depend node miss src or dest node"; | |||
| error_ = FAILED; | |||
| return; | |||
| } | |||
| std::shared_ptr<std::vector<OperatorPtr>> src_ops_list = std::make_shared<std::vector<OperatorPtr>>(); | |||
| std::shared_ptr<std::vector<OperatorPtr>> dst_ops_list = std::make_shared<std::vector<OperatorPtr>>(); | |||
| if (!GetControlDependList(node, src_ops_list, dst_ops_list)) { | |||
| MS_LOG(ERROR) << "Get depend list failed"; | |||
| error_ = FAILED; | |||
| return; | |||
| } | |||
| std::vector<ControlEdge> control_edges; | |||
| if (src_ops_list->size() == 1 && dst_ops_list->size() > 1) { | |||
| (void)std::transform(dst_ops_list->begin(), dst_ops_list->end(), std::back_inserter(control_edges), | |||
| [src_ops_list](const OperatorPtr &op) -> ControlEdge { | |||
| return {(*src_ops_list)[0], op}; | |||
| }); | |||
| } else if (src_ops_list->size() > 1 && dst_ops_list->size() == 1) { | |||
| (void)std::transform(src_ops_list->begin(), src_ops_list->end(), std::back_inserter(control_edges), | |||
| [dst_ops_list](const OperatorPtr &op) -> ControlEdge { | |||
| return {op, (*dst_ops_list)[0]}; | |||
| }); | |||
| } else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) { | |||
| control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]}); | |||
| } else if (src_ops_list->empty() || dst_ops_list->empty()) { | |||
| MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it"; | |||
| } else { | |||
| MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size() | |||
| << " -> dst:" << dst_ops_list->size(); | |||
| error_ = FAILED; | |||
| return; | |||
| } | |||
| control_depend_cache_[node.get()] = control_edges; | |||
| #ifdef DRAW_GE_GRAPH | |||
| DrawControlDepend(src_node, dest_node); | |||
| #endif | |||
| } | |||
| bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) { | |||
| // ignore apply node of return | |||
| if (name == "" || name == prim::kPrimReturn->name() || name == prim::kPrimDepend->name() || | |||
| @@ -1818,12 +1637,6 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) | |||
| return false; | |||
| } | |||
| // ControlDepend | |||
| if (name == prim::kPrimControlDepend->name()) { | |||
| ConvertControlDependNode(node); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -145,19 +145,11 @@ class DfGraphConvertor { | |||
| OperatorPtr ConvertCNode(CNodePtr node); | |||
| std::vector<OperatorPtr> ConvertDependNode(AnfNodePtr node); | |||
| AnfNodePtr GetRealOpNode(AnfNodePtr node); | |||
| std::vector<AnfNodePtr> GetDependNodes(const AnfNodePtr &node); | |||
| OperatorPtr ConvertParameter(AnfNodePtr node); | |||
| Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); | |||
| OperatorPtr ConvertValueNode(ValueNodePtr node); | |||
| void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node); | |||
| void ConvertTupleGetItem(const CNodePtr node); | |||
| void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, | |||
| const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list, | |||
| const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list); | |||
| bool GetControlDependList(const CNodePtr &node, const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list, | |||
| const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list); | |||
| void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node); | |||
| void ConvertControlDependNode(const CNodePtr node); | |||
| void ConvertMakeTuple(const CNodePtr node); | |||
| bool CheckCNode(const std::string &name, const CNodePtr node); | |||
| void TraceOutput(AnfNodePtr node); | |||
| @@ -195,7 +187,7 @@ class DfGraphConvertor { | |||
| std::shared_ptr<DfGraph> broadcast_graph_{nullptr}; | |||
| std::unordered_map<AnfNode *, DfGraph> branches_map_; | |||
| std::unordered_map<AnfNode *, OperatorPtr> op_cache_; | |||
| std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_depend_cache_; | |||
| std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_edge_cache_; | |||
| std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>> monad_control_edge_cache_; | |||
| /* record "tuple_getitem"<->"out_handler" mapping */ | |||
| std::unordered_map<AnfNode *, OutHandler> out_handle_cache_; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -51,88 +51,6 @@ std::string GetOtherTarget(const std::vector<AnfNodePtr> &nodes) { | |||
| } | |||
| return ""; | |||
| } | |||
| bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node, | |||
| std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) { | |||
| MS_EXCEPTION_IF_NULL(prior_node); | |||
| MS_EXCEPTION_IF_NULL(behind_node); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto &node_users = manager->node_users(); | |||
| if (prior_node->isa<Parameter>()) { | |||
| for (auto &user : node_users[prior_node]) { | |||
| auto cnode = user.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||
| prior_nodes->emplace_back(cnode); | |||
| } | |||
| } | |||
| } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) { | |||
| prior_nodes->emplace_back(prior_node); | |||
| } else { | |||
| return false; | |||
| } | |||
| if (behind_node->isa<Parameter>()) { | |||
| for (auto &user : node_users[behind_node]) { | |||
| auto cnode = user.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||
| depend_nodes->emplace_back(cnode); | |||
| } | |||
| } | |||
| } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) { | |||
| depend_nodes->emplace_back(behind_node); | |||
| } else { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges, | |||
| std::map<AnfNodePtr, size_t> *nodes_ref) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto input_cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(input_cnode); | |||
| auto prior_node = input_cnode->input(kControlDependPriorIndex); | |||
| auto depend_node = input_cnode->input(kControlDependBehindIndex); | |||
| MS_EXCEPTION_IF_NULL(prior_node); | |||
| MS_EXCEPTION_IF_NULL(depend_node); | |||
| auto prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim_ptr); | |||
| ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); | |||
| int64_t depend_mode = 0; | |||
| if (mode_ptr != nullptr) { | |||
| depend_mode = GetValue<int64_t>(mode_ptr); | |||
| } | |||
| if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) { | |||
| return; | |||
| } | |||
| std::vector<AnfNodePtr> prior_nodes; | |||
| std::vector<AnfNodePtr> behind_nodes; | |||
| if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) { | |||
| return; | |||
| } | |||
| for (auto &first_node : prior_nodes) { | |||
| for (auto &second_node : behind_nodes) { | |||
| MS_EXCEPTION_IF_NULL(first_node); | |||
| MS_EXCEPTION_IF_NULL(second_node); | |||
| auto iter = control_edges->find(second_node); | |||
| if (iter == control_edges->end()) { | |||
| (void)control_edges->insert( | |||
| std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node})); | |||
| } else { | |||
| iter->second.emplace_back(first_node); | |||
| } | |||
| auto ref_iter = nodes_ref->find(first_node); | |||
| if (ref_iter != nodes_ref->end()) { | |||
| ref_iter->second++; | |||
| } else { | |||
| (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref, | |||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) { | |||
| @@ -149,9 +67,6 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (auto &input : cnode->inputs()) { | |||
| if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) { | |||
| AddControlEdge(graph, input, control_edges, nodes_ref); | |||
| } | |||
| auto iter = nodes_ref->find(input); | |||
| if (iter != nodes_ref->end()) { | |||
| iter->second++; | |||
| @@ -479,11 +394,9 @@ void AddSegmentDependency(const FuncGraphPtr &graph, const std::string &default_ | |||
| node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); | |||
| } | |||
| GraphSegmentPtr node_segment{nullptr}; | |||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||
| auto node_iter = node_to_segment.find(node); | |||
| if (node_iter != node_to_segment.end()) { | |||
| node_segment = node_iter->second; | |||
| } | |||
| auto node_iter = node_to_segment.find(node); | |||
| if (node_iter != node_to_segment.end()) { | |||
| node_segment = node_iter->second; | |||
| } | |||
| for (auto &input : node_inputs) { | |||
| if (node_segment != nullptr && !node_segment->is_cut_ && input->isa<CNode>()) { | |||
| @@ -615,18 +528,14 @@ void SplitDynamicNodeSegment(const std::vector<AnfNodePtr> &segment_nodes, std:: | |||
| std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment, | |||
| const std::set<AnfNodePtr> &dynamic_nodes_set) { | |||
| SplitDynamicNodesHelper helper; | |||
| bool is_last_node_dynamic = false; | |||
| for (auto &node : segment_nodes) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||
| helper.AddNode(node, is_last_node_dynamic); | |||
| continue; | |||
| } | |||
| auto &inputs = cnode->inputs(); | |||
| bool has_dynamic_shape = dynamic_nodes_set.find(node) != dynamic_nodes_set.end(); | |||
| bool depend_common_node = false; | |||
| bool depend_dynamic_node = false; | |||
| bool is_last_node_dynamic = false; | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| if (dynamic_nodes_set.find(inputs[i]) != dynamic_nodes_set.end()) { | |||
| has_dynamic_shape = true; | |||
| @@ -1,7 +1,7 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -87,26 +87,7 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo | |||
| if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) { | |||
| eqv[node] = node; | |||
| } else if (eqv.find(node) == eqv.end()) { | |||
| if (IsPrimitiveCNode(node, prim::kPrimControlDepend)) { | |||
| eqv[node] = NewValueNode(MakeValue(0)); | |||
| return eqv[node]; | |||
| } | |||
| bool ignore_make_tuple = false; | |||
| if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { | |||
| ignore_make_tuple = true; | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const auto &node_inputs = cnode->inputs(); | |||
| for (size_t i = 1; i < node_inputs.size(); ++i) { | |||
| if (!IsPrimitiveCNode(node_inputs[i], prim::kPrimControlDepend)) { | |||
| ignore_make_tuple = false; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (!ignore_make_tuple) { | |||
| inputs.push_back(node); | |||
| } | |||
| inputs.push_back(node); | |||
| eqv[node] = fg->add_parameter(); | |||
| eqv[node]->set_abstract(node->abstract()); | |||
| eqv[node]->set_kernel_info(node->kernel_info_ptr()); | |||
| @@ -148,14 +129,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||
| for (size_t i = 2; i < inps.size(); ++i) { | |||
| args.emplace_back(NewValueNode(MakeValue(0))); | |||
| } | |||
| } else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) { | |||
| for (size_t i = 1; i < inps.size(); ++i) { | |||
| if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) { | |||
| args.emplace_back(NewValueNode(MakeValue(static_cast<int>(i)))); | |||
| } else { | |||
| args.emplace_back(RefSubGraphNode(fg, inps[i], &inputs, &eqv)); | |||
| } | |||
| } | |||
| } else { | |||
| (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), | |||
| [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); | |||
| @@ -182,8 +182,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -188,23 +188,6 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||
| AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // args: Two objects of a subclass of AbstractBase | |||
| CheckArgsSize(primitive->name(), args_spec_list, 2); | |||
| auto arg_src = args_spec_list[0]; | |||
| auto arg_dst = args_spec_list[1]; | |||
| // control depend can not setup tuple of ops to tuple of ops dependency relation | |||
| if (arg_src->isa<AbstractTuple>() && arg_dst->isa<AbstractTuple>()) { | |||
| auto src_size = arg_src->cast<AbstractTuplePtr>()->size(); | |||
| auto dst_size = arg_src->cast<AbstractTuplePtr>()->size(); | |||
| if (src_size > 1 && dst_size > 1) { | |||
| MS_LOG(EXCEPTION) << "Control depend can not setup operator dependency relationship from tuple from tuple"; | |||
| } | |||
| } | |||
| return std::make_shared<AbstractScalar>(kAnyValue, kBool); | |||
| } | |||
| AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: two tensors and a tuple. | |||
| @@ -149,7 +149,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}}, | |||
| {prim::kPrimDepend, {InferImplDepend, nullptr, true}}, | |||
| {prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}}, | |||
| {prim::kPrimControlDepend, {InferImplControlDepend, nullptr, true}}, | |||
| // Debug | |||
| {prim::kPrimDebug, {InferImplDebug, nullptr, true}}, | |||
| // Dynamic shape testing | |||
| @@ -453,7 +453,6 @@ inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookB | |||
| inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType"); | |||
| inline const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape"); | |||
| inline const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print"); | |||
| inline const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend"); | |||
| inline const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_"); | |||
| inline const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not"); | |||
| inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); | |||
| @@ -1,7 +1,7 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -399,8 +399,7 @@ std::string GetAttrTarget(const PrimitivePtr &primitive, const ValuePtr &att_tar | |||
| if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || | |||
| IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || | |||
| IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || | |||
| IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || | |||
| IsPrimitive(attr_input, prim::kPrimPartial)) { | |||
| IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) { | |||
| primitive->EraseAttr(primitive_target); | |||
| return default_target; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -23,7 +23,6 @@ | |||
| #include "ir/dtype/tensor_type.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "ops/control_depend.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -23,7 +23,6 @@ | |||
| #include "ops/expand_dims.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "ops/control_depend.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -30,7 +30,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, | |||
| "identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key", | |||
| "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary", | |||
| "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs", | |||
| "InvertPermutation", "ControlDepend", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", | |||
| "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", | |||
| "stop_gradient", "Send", "UpdateState", "Load"}; | |||
| // clang-format on | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -27,7 +27,6 @@ | |||
| #include "abstract/abstract_value.h" | |||
| #include "mindspore/core/ir/primitive.h" | |||
| #include "ops/fusion/partial_fusion.h" | |||
| #include "ops/control_depend.h" | |||
| #include "ops/depend.h" | |||
| #include "ops/make_tuple.h" | |||
| #include "ops/quant_dtype_cast.h" | |||
| @@ -213,8 +212,7 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { | |||
| MS_LOG(ERROR) << "value node is invalid."; | |||
| return; | |||
| } | |||
| if (value_node->value() != nullptr && (opt::CheckPrimitiveType(depend_node, prim::kPrimDepend) || | |||
| opt::CheckPrimitiveType(depend_node, prim::kPrimControlDepend))) { | |||
| if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) { | |||
| has_depend = true; | |||
| bool mask_out = (depend_node->inputs().size() == 3); | |||
| for (size_t j = 1; j < depend_node->inputs().size(); ++j) { | |||
| @@ -466,8 +464,8 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc | |||
| } | |||
| RemoveIfDepend(cnode); | |||
| if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend || | |||
| prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) { | |||
| if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameTupleGetItem || | |||
| prim->name() == mindspore::ops::kNameMakeTuple) { | |||
| continue; | |||
| } | |||
| if (prim->name() == "make_tuple") { | |||
| @@ -57,8 +57,8 @@ bool IsRealKernel(const AnfNodePtr &node) { | |||
| IsPrimitive(input, prim::kPrimTensorSummary) || | |||
| IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || | |||
| IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || | |||
| IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) || | |||
| IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial); | |||
| IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimReturn) || | |||
| IsPrimitive(input, prim::kPrimPartial); | |||
| return !is_virtual_node; | |||
| } | |||
| @@ -43,8 +43,8 @@ tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) { | |||
| bool IsSpecialType(const CNodePtr &cnode) { | |||
| if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) || | |||
| CheckPrimitiveType(cnode, prim::kPrimControlDepend) || CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || | |||
| CheckPrimitiveType(cnode, prim::kPrimReturn) || CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) || | |||
| CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimReturn) || | |||
| CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) || | |||
| CheckPrimitiveType(cnode, std::make_shared<Primitive>("If"))) { | |||
| return true; | |||
| } | |||
| @@ -58,13 +58,6 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| } | |||
| if (CheckPrimitiveType(anf_node, prim::kPrimControlDepend)) { | |||
| if (cnode->size() != InputDoubleNum) { | |||
| MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; | |||
| remove_cnode_.insert(anf_node); | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| } | |||
| bool replace_succ = manager->Replace(anf_node, cnode->input(1)); | |||
| if (!replace_succ) { | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -96,7 +96,6 @@ class GPT2FinetuneCell(nn.Cell): | |||
| self.get_status = P.NPUGetFloatStatus() | |||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | |||
| self.depend_parameter_use = P.ControlDepend(depend_mode=1) | |||
| self.base = Tensor(1, mstype.float32) | |||
| self.less_equal = P.LessEqual() | |||
| self.hyper_map = C.HyperMap() | |||
| @@ -132,8 +131,8 @@ class GPT2FinetuneCell(nn.Cell): | |||
| if not self.gpu_target: | |||
| init = self.alloc_status() | |||
| init = F.depend(init, loss) | |||
| clear_before_grad = self.clear_before_grad(init) | |||
| F.control_depend(loss, init) | |||
| self.depend_parameter_use(clear_before_grad, scaling_sens) | |||
| grads = self.grad(self.network, weights)(input_ids, | |||
| input_mask, | |||
| @@ -145,10 +144,10 @@ class GPT2FinetuneCell(nn.Cell): | |||
| if self.reducer_flag: | |||
| grads = self.grad_reducer(grads) | |||
| if not self.gpu_target: | |||
| init = F.depend(init, grads) | |||
| flag = self.get_status(init) | |||
| init = F.depend(init, flag) | |||
| flag_sum = self.reduce_sum(init, (0,)) | |||
| F.control_depend(grads, flag) | |||
| F.control_depend(flag, flag_sum) | |||
| else: | |||
| flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) | |||
| flag_sum = self.addn(flag_sum) | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -74,9 +74,9 @@ TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence_with_make_tupl | |||
| /* | |||
| * def before(x, y, a, b): | |||
| * z = make_tuple(TransData(a), TransData(b)) | |||
| * depend_intput = control_depend(y, z) | |||
| * sum = add(x, depend_intput) | |||
| * return sum | |||
| * depend_intput = depend(y, z) | |||
| * sum_add = add(x, depend_intput) | |||
| * return sum_add | |||
| */ | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence_with_make_tuple", "before"); | |||
| @@ -93,11 +93,11 @@ TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence_with_make_tupl | |||
| TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence) { | |||
| /* | |||
| * def before(x, y, a, b): | |||
| * z = make_tuple(TransData(a), TransData(b)) | |||
| * depend_intput = control_depend(y, z) | |||
| * sum = add(x, depend_intput) | |||
| * return sum | |||
| * def before(x, y, z): | |||
| * new_z = TransData(z) | |||
| * depend_intput = depend(y, new_z) | |||
| * sum_add = add(x, depend_intput) | |||
| * return sum_add | |||
| */ | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence", "before"); | |||