From: @huangbingjian Reviewed-by: Signed-off-by:pull/14612/MERGE
| @@ -108,7 +108,7 @@ bool InputCheck(const AnfNodePtr &node) { | |||||
| MS_LOG(INFO) << "Data->TransData->split, can not optimizer."; | MS_LOG(INFO) << "Data->TransData->split, can not optimizer."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (in_node_name == prim::kPrimControlDepend->name() || in_node_name == prim::kPrimDepend->name()) { | |||||
| if (in_node_name == prim::kPrimDepend->name()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| if ((AnfAlgo::HasNodeAttr("non_task", in_cnode) && AnfAlgo::GetNodeAttr<bool>(in_node, "non_task")) || | if ((AnfAlgo::HasNodeAttr("non_task", in_cnode) && AnfAlgo::GetNodeAttr<bool>(in_node, "non_task")) || | ||||
| @@ -131,7 +131,7 @@ bool OutputCheck(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| for (const auto &item : outputs) { | for (const auto &item : outputs) { | ||||
| if (IsPrimitiveCNode(item, prim::kPrimControlDepend) || IsPrimitiveCNode(item, prim::kPrimDepend)) { | |||||
| if (IsPrimitiveCNode(item, prim::kPrimDepend)) { | |||||
| MS_LOG(INFO) << "Split has control edge, can not optimizer."; | MS_LOG(INFO) << "Split has control edge, can not optimizer."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -168,7 +168,7 @@ const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, c | |||||
| (void)manager->Replace(output, bn_training_update_v2_outputs[output_index]); | (void)manager->Replace(output, bn_training_update_v2_outputs[output_index]); | ||||
| output_index++; | output_index++; | ||||
| } | } | ||||
| // Return the new node for control depends. | |||||
| // Return the new node. | |||||
| return bn_training_update_v2; | return bn_training_update_v2; | ||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -201,20 +201,6 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| if (dropout_do_mask1 != nullptr) { | |||||
| // Dropout is used by ControlDepend in some situation, need to replace ControlDepend. | |||||
| auto &users = manager->node_users(); | |||||
| iter = users.find(dropout_node); | |||||
| if (iter != users.end()) { | |||||
| for (auto &node_index : iter->second) { | |||||
| auto used_node = node_index.first; | |||||
| if (AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimControlDepend)) { | |||||
| (void)manager->Replace(used_node, dropout_do_mask1); | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| // CreateDropoutDoMask-backward | // CreateDropoutDoMask-backward | ||||
| if (equiv->find(grad_input_) == equiv->end()) { | if (equiv->find(grad_input_) == equiv->end()) { | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -426,9 +426,6 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOu | |||||
| } | } | ||||
| auto output_info_list = iter->second; | auto output_info_list = iter->second; | ||||
| for (const auto &output_info : output_info_list) { | for (const auto &output_info : output_info_list) { | ||||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) { | |||||
| continue; | |||||
| } | |||||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && | if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && | ||||
| output_info.second == kDependAttachNodeIndex) { | output_info.second == kDependAttachNodeIndex) { | ||||
| continue; | continue; | ||||
| @@ -908,16 +905,12 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| auto manager = graph->manager(); | auto manager = graph->manager(); | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| // find BatchNorm's output which is a Depend or ControlDepend | |||||
| // find BatchNorm's output which is a Depend | |||||
| for (const auto &node_index : manager->node_users()[old_node]) { | for (const auto &node_index : manager->node_users()[old_node]) { | ||||
| AnfNodePtr output = node_index.first; | AnfNodePtr output = node_index.first; | ||||
| size_t index = IntToSize(node_index.second); | size_t index = IntToSize(node_index.second); | ||||
| MS_EXCEPTION_IF_NULL(output); | MS_EXCEPTION_IF_NULL(output); | ||||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { | |||||
| auto control_depend = output->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(control_depend); | |||||
| control_depend->set_input(index, new_node); | |||||
| } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) { | |||||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) { | |||||
| auto depend = output->cast<CNodePtr>(); | auto depend = output->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(depend); | MS_EXCEPTION_IF_NULL(depend); | ||||
| depend->set_input(index, new_node); | depend->set_input(index, new_node); | ||||
| @@ -210,7 +210,7 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor | |||||
| // Create a new value node of func graph,not kernel graph | // Create a new value node of func graph,not kernel graph | ||||
| ValueNodePtr MakeValueNode(const ValueNodePtr &value_node); | ValueNodePtr MakeValueNode(const ValueNodePtr &value_node); | ||||
| // Transfer depend or control_depend to the new node | |||||
| // Transfer depend to the new node | |||||
| void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node); | void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node); | ||||
| AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list); | AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list); | ||||
| @@ -327,7 +327,7 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node, | |||||
| void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, | void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, | ||||
| const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) { | const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) { | ||||
| // Create depend node to hold new control depend node. | |||||
| // Create depend node to hold execution order. | |||||
| AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node}; | AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node}; | ||||
| auto depend_cnode = main_graph->NewCNode(d_inputs); | auto depend_cnode = main_graph->NewCNode(d_inputs); | ||||
| depend_cnode->set_abstract(clean_node->abstract()); | depend_cnode->set_abstract(clean_node->abstract()); | ||||
| @@ -501,18 +501,17 @@ bool AtomicCleanInsertter::IsExistStructuralObstacle(const KernelGraphPtr &main_ | |||||
| const FuncGraphManagerPtr &mng) { | const FuncGraphManagerPtr &mng) { | ||||
| auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false); | auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false); | ||||
| // If reduce user is MakeTuple and not last node, there is no cheap method to set right running order between reduce | // If reduce user is MakeTuple and not last node, there is no cheap method to set right running order between reduce | ||||
| // node and user node. If reduce is Depend or ControlDepend node, the origin node may be wrong! | |||||
| return std::all_of(reduce_users.cbegin(), reduce_users.cend(), | |||||
| [&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool { | |||||
| auto &user = user_info.first; | |||||
| if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend) || | |||||
| IsPrimitiveCNode(user, prim::kPrimControlDepend)) && | |||||
| !(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) { | |||||
| return false; | |||||
| } else { | |||||
| return true; | |||||
| } | |||||
| }); | |||||
| // node and user node. If reduce is Depend node, the origin node may be wrong! | |||||
| return std::all_of( | |||||
| reduce_users.cbegin(), reduce_users.cend(), [&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool { | |||||
| auto &user = user_info.first; | |||||
| if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend)) && | |||||
| !(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) { | |||||
| return false; | |||||
| } else { | |||||
| return true; | |||||
| } | |||||
| }); | |||||
| } | } | ||||
| bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { | bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { | ||||
| @@ -123,9 +123,9 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr | |||||
| bool changed = false; | bool changed = false; | ||||
| auto mng = kernel_graph->manager(); | auto mng = kernel_graph->manager(); | ||||
| // depend_prior[depend] = pair(prior, controlDependNode) | |||||
| // depend_prior[depend] = pair(prior, behind) | |||||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior; | std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior; | ||||
| InitDependPrior(todos, &depend_prior); | |||||
| // InitDependPrior(todos, &depend_prior); | |||||
| for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { | for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { | ||||
| auto node = (*iter)->cast<CNodePtr>(); | auto node = (*iter)->cast<CNodePtr>(); | ||||
| @@ -657,76 +657,6 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) { | |||||
| #endif | #endif | ||||
| } | } | ||||
| void InitDependPrior(const std::vector<AnfNodePtr> &todos, | |||||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) { | |||||
| for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { | |||||
| auto cnode = (*iter)->cast<CNodePtr>(); | |||||
| if (cnode == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimControlDepend)) { | |||||
| continue; | |||||
| } | |||||
| auto prior_node = cnode->input(kControlDependPriorIndex); | |||||
| auto depend_node = cnode->input(kControlDependBehindIndex); | |||||
| MS_EXCEPTION_IF_NULL(prior_node); | |||||
| MS_EXCEPTION_IF_NULL(depend_node); | |||||
| std::vector<AnfNodePtr> prior_nodes = {prior_node}; | |||||
| std::vector<AnfNodePtr> depend_nodes = {depend_node}; | |||||
| int depend_mode = 0; | |||||
| if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { | |||||
| depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode); | |||||
| } | |||||
| auto GetOutputNodes = [cnode](const AnfNodePtr ¶m) -> std::vector<AnfNodePtr> { | |||||
| std::vector<AnfNodePtr> out_nodes; | |||||
| auto user_set = param->func_graph()->manager()->node_users()[param]; | |||||
| for (auto iter = user_set.cbegin(); iter != user_set.cend(); ++iter) { | |||||
| if (iter->first != cnode) { | |||||
| out_nodes.push_back(iter->first); | |||||
| } | |||||
| } | |||||
| return out_nodes; | |||||
| }; | |||||
| if (prior_node->isa<Parameter>() && depend_mode == 1) { | |||||
| prior_nodes = GetOutputNodes(prior_node); | |||||
| } | |||||
| if (depend_node->isa<Parameter>()) { | |||||
| depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{}; | |||||
| } | |||||
| std::vector<AnfNodePtr> real_prior_nodes; | |||||
| std::set<AnfNodePtr> prior_visited; | |||||
| for (const auto &tmp : prior_nodes) { | |||||
| AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); | |||||
| } | |||||
| prior_visited.clear(); | |||||
| std::vector<AnfNodePtr> real_depend_nodes; | |||||
| std::set<AnfNodePtr> depend_visited; | |||||
| for (const auto &tmp : depend_nodes) { | |||||
| AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); | |||||
| } | |||||
| depend_visited.clear(); | |||||
| for (auto &prior : real_prior_nodes) { | |||||
| if (AnfAlgo::CheckPrimitiveType(prior, prim::kPrimControlDepend)) { | |||||
| continue; | |||||
| } | |||||
| for (auto &depend : real_depend_nodes) { | |||||
| if (AnfAlgo::CheckPrimitiveType(depend, prim::kPrimControlDepend)) { | |||||
| continue; | |||||
| } | |||||
| depend_prior->insert({depend, std::make_pair(prior, cnode)}); | |||||
| } | |||||
| } | |||||
| real_prior_nodes.clear(); | |||||
| real_depend_nodes.clear(); | |||||
| } | |||||
| } | |||||
| void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior, | void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior, | ||||
| const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) { | const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) { | ||||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri; | std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri; | ||||
| @@ -75,8 +75,6 @@ std::vector<PrimitivePtr> GetFusibleOpList(); | |||||
| bool IsBasicFuseOp(const AnfNodePtr &node); | bool IsBasicFuseOp(const AnfNodePtr &node); | ||||
| bool IsFusibleOp(const AnfNodePtr &node); | bool IsFusibleOp(const AnfNodePtr &node); | ||||
| void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); | void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); | ||||
| void InitDependPrior(const std::vector<AnfNodePtr> &todos, | |||||
| std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior); | |||||
| void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior, | void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior, | ||||
| const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs); | const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs); | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -55,7 +55,7 @@ CNodePtr GetRealPrevCNode(const AnfNodePtr &node, size_t index, std::vector<Kern | |||||
| auto item_idx = GetValue<int64_t>(value_node->value()); | auto item_idx = GetValue<int64_t>(value_node->value()); | ||||
| pass_vector->push_back(make_pair(cnode, IntToSize(1))); | pass_vector->push_back(make_pair(cnode, IntToSize(1))); | ||||
| return GetRealPrevCNode(cnode->input(1), LongToSize(item_idx), pass_vector); | return GetRealPrevCNode(cnode->input(1), LongToSize(item_idx), pass_vector); | ||||
| } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { | |||||
| } else if (IsPrimitive(input0, prim::kPrimDepend)) { | |||||
| pass_vector->push_back(make_pair(cnode, IntToSize(1))); | pass_vector->push_back(make_pair(cnode, IntToSize(1))); | ||||
| return GetRealPrevCNode(cnode->input(1), 0, pass_vector); | return GetRealPrevCNode(cnode->input(1), 0, pass_vector); | ||||
| } else if (IsPrimitive(input0, prim::kPrimUpdateState)) { | } else if (IsPrimitive(input0, prim::kPrimUpdateState)) { | ||||
| @@ -92,8 +92,7 @@ const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNode | |||||
| auto pass_size = pass_vector->size(); | auto pass_size = pass_vector->size(); | ||||
| for (size_t idx = 1; idx <= pass_size - 1; ++idx) { | for (size_t idx = 1; idx <= pass_size - 1; ++idx) { | ||||
| auto nd = (*pass_vector)[idx].first; | auto nd = (*pass_vector)[idx].first; | ||||
| if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend) || | |||||
| AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { | |||||
| if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend)) { | |||||
| has_depend_node = true; | has_depend_node = true; | ||||
| } | } | ||||
| if (users[nd].size() >= 2) { | if (users[nd].size() >= 2) { | ||||
| @@ -248,7 +248,7 @@ class AnfRuntimeAlgorithm { | |||||
| static void InferShape(const CNodePtr &node); | static void InferShape(const CNodePtr &node); | ||||
| static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); | static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); | ||||
| static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); | static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); | ||||
| // Find control_depend real input nodes. | |||||
| // Find real input nodes. | |||||
| static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, | static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, | ||||
| std::set<AnfNodePtr> *visited); | std::set<AnfNodePtr> *visited); | ||||
| }; | }; | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -534,14 +534,17 @@ void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNul | |||||
| return_node->set_input(kFirstDataInputIndex, depend_node); | return_node->set_input(kFirstDataInputIndex, depend_node); | ||||
| } | } | ||||
| void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node, | |||||
| NotNull<AnfNodePtr> second_node) { | |||||
| MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() | |||||
| << ", the second node is " << second_node->DebugString(); | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())), | |||||
| first_node, second_node}; | |||||
| auto control_depend = kg->NewCNode(inputs); | |||||
| InsertDependToGraph(kg, NOT_NULL(control_depend)); | |||||
| void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> prior_node, | |||||
| NotNull<AnfNodePtr> behind_node) { | |||||
| MS_LOG(INFO) << "Insert control dependence at the end of graph, the prior node is " << prior_node->DebugString() | |||||
| << ", the behind node is " << behind_node->DebugString(); | |||||
| auto manager = kg->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| AnfNodePtrList inputs = {NewValueNode(prim::kPrimDepend), behind_node, prior_node}; | |||||
| auto depend_cnode = kg->NewCNode(inputs); | |||||
| if (!manager->Replace(behind_node, depend_cnode)) { | |||||
| MS_LOG(EXCEPTION) << behind_node->DebugString() << ", replace node failed."; | |||||
| } | |||||
| } | } | ||||
| void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -422,7 +422,7 @@ void KernelGraph::CheckLoop() { | |||||
| none_zero_nodes[it.first] = it.second; | none_zero_nodes[it.first] = it.second; | ||||
| } | } | ||||
| } | } | ||||
| // if don't consider control depend and loop exit,a exception will be throw | |||||
| // if don't consider loop exit,a exception will be throw | |||||
| if (!none_zero_nodes.empty()) { | if (!none_zero_nodes.empty()) { | ||||
| MS_LOG(WARNING) << "Nums of loop:" << GetLoopNum(none_zero_nodes); | MS_LOG(WARNING) << "Nums of loop:" << GetLoopNum(none_zero_nodes); | ||||
| MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); | MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); | ||||
| @@ -815,61 +815,10 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) { | |||||
| return output_nodes; | return output_nodes; | ||||
| } | } | ||||
| // update the depend relations of control depend | |||||
| void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) { | |||||
| for (const auto &node : depends) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!node->isa<CNode>()) { | |||||
| return; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { | |||||
| MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend"; | |||||
| } | |||||
| auto prior_node = cnode->input(kControlDependPriorIndex); | |||||
| auto depend_node = cnode->input(kControlDependBehindIndex); | |||||
| MS_EXCEPTION_IF_NULL(prior_node); | |||||
| MS_EXCEPTION_IF_NULL(depend_node); | |||||
| std::vector<AnfNodePtr> prior_nodes = {prior_node}; | |||||
| std::vector<AnfNodePtr> depend_nodes = {depend_node}; | |||||
| int depend_mode = 0; | |||||
| if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { | |||||
| depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode); | |||||
| } | |||||
| MS_LOG(DEBUG) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString() | |||||
| << "], depend_mode :" << depend_mode << "."; | |||||
| if (prior_node->isa<Parameter>() && depend_mode == 1) { | |||||
| prior_nodes = GetOutputNodes(prior_node); | |||||
| } | |||||
| if (depend_node->isa<Parameter>()) { | |||||
| depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{}; | |||||
| } | |||||
| std::vector<AnfNodePtr> real_prior_nodes; | |||||
| std::set<AnfNodePtr> prior_visited; | |||||
| for (const auto &tmp : prior_nodes) { | |||||
| AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); | |||||
| } | |||||
| std::vector<AnfNodePtr> real_depend_nodes; | |||||
| std::set<AnfNodePtr> depend_visited; | |||||
| for (const auto &tmp : depend_nodes) { | |||||
| AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); | |||||
| } | |||||
| UpdateNodeInputOutputEdges(real_prior_nodes, real_depend_nodes); | |||||
| } | |||||
| } | |||||
| void KernelGraph::UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes, | void KernelGraph::UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes, | ||||
| const std::vector<AnfNodePtr> &real_depend_nodes) { | const std::vector<AnfNodePtr> &real_depend_nodes) { | ||||
| for (auto &first_node : real_prior_nodes) { | for (auto &first_node : real_prior_nodes) { | ||||
| if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { | |||||
| continue; | |||||
| } | |||||
| for (auto &second_node : real_depend_nodes) { | for (auto &second_node : real_depend_nodes) { | ||||
| if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { | |||||
| continue; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(first_node); | MS_EXCEPTION_IF_NULL(first_node); | ||||
| MS_EXCEPTION_IF_NULL(second_node); | MS_EXCEPTION_IF_NULL(second_node); | ||||
| MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); | MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); | ||||
| @@ -878,35 +827,6 @@ void KernelGraph::UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real | |||||
| } | } | ||||
| } | } | ||||
| bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, | |||||
| std::unordered_set<AnfNodePtr> *visited_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| MS_EXCEPTION_IF_NULL(que); | |||||
| MS_EXCEPTION_IF_NULL(visited_nodes); | |||||
| if (!node->isa<CNode>()) { | |||||
| return false; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { | |||||
| return false; | |||||
| } | |||||
| // set the control depend visited but don't push it into the que | |||||
| if (visited_nodes->find(node) != visited_nodes->end()) { | |||||
| return true; | |||||
| } | |||||
| (void)visited_nodes->insert(cnode); | |||||
| // add a 0 depend num to keep the link relations to prepare for finding zero output nodes | |||||
| auto prior_node = cnode->input(kControlDependPriorIndex); | |||||
| auto depend_node = cnode->input(kControlDependBehindIndex); | |||||
| for (const auto &input : cnode->inputs()) { | |||||
| AddDependEdge(node, input, 0); | |||||
| } | |||||
| PushNoVisitedNode(depend_node, que, visited_nodes); | |||||
| PushNoVisitedNode(prior_node, que, visited_nodes); | |||||
| return true; | |||||
| } | |||||
| void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) { | void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) { | ||||
| MS_EXCEPTION_IF_NULL(seed_nodes); | MS_EXCEPTION_IF_NULL(seed_nodes); | ||||
| node_output_edges_.clear(); | node_output_edges_.clear(); | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -286,15 +286,11 @@ class KernelGraph : public FuncGraph { | |||||
| std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true); | std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true); | ||||
| // update node edge list | // update node edge list | ||||
| void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes); | void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes); | ||||
| // add node depend edge by data edge or control depend | |||||
| // add node depend edge by data edge | |||||
| void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); | void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); | ||||
| void UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes, | void UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes, | ||||
| const std::vector<AnfNodePtr> &real_depend_nodes); | const std::vector<AnfNodePtr> &real_depend_nodes); | ||||
| // handle control depend | |||||
| std::vector<AnfNodePtr> GetOutputNodes(const AnfNodePtr &node); | std::vector<AnfNodePtr> GetOutputNodes(const AnfNodePtr &node); | ||||
| bool HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, | |||||
| std::unordered_set<AnfNodePtr> *visited_nodes); | |||||
| void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends); | |||||
| AnfNodePtr TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value); | AnfNodePtr TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value); | ||||
| AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract); | AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract); | ||||
| AnfNodePtr TransCNodeTuple(const CNodePtr &node); | AnfNodePtr TransCNodeTuple(const CNodePtr &node); | ||||
| @@ -223,10 +223,8 @@ BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &gra | |||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| VectorRef ret; | VectorRef ret; | ||||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | for (size_t i = 1; i < cnode->inputs().size(); ++i) { | ||||
| if (!AnfAlgo::CheckPrimitiveType(cnode->input(i), prim::kPrimControlDepend)) { | |||||
| auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node); | |||||
| ret.push_back(out); | |||||
| } | |||||
| auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node); | |||||
| ret.push_back(out); | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -386,22 +384,6 @@ bool ExistSummaryNode(const KernelGraph *graph) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { | |||||
| return false; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| const auto &node_inputs = cnode->inputs(); | |||||
| for (size_t i = 1; i < node_inputs.size(); ++i) { | |||||
| if (!AnfAlgo::CheckPrimitiveType(node_inputs[i], prim::kPrimControlDepend)) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs, | void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs, | ||||
| std::map<AnfNodePtr, size_t> *parameter_index) { | std::map<AnfNodePtr, size_t> *parameter_index) { | ||||
| size_t index = 0; | size_t index = 0; | ||||
| @@ -692,9 +674,6 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const | |||||
| AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) { | AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| if (IgnoreCreateParameterForMakeTuple(node)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); | auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); | ||||
| auto parameters = AnfAlgo::GetAllOutput(new_parameter); | auto parameters = AnfAlgo::GetAllOutput(new_parameter); | ||||
| std::vector<AnfNodePtr> pre_graph_out = {node}; | std::vector<AnfNodePtr> pre_graph_out = {node}; | ||||
| @@ -1872,9 +1851,6 @@ std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr | |||||
| auto &users = front_func_graph_manager->node_users()[front_node]; | auto &users = front_func_graph_manager->node_users()[front_node]; | ||||
| std::vector<AnfNodePtr> result; | std::vector<AnfNodePtr> result; | ||||
| for (auto &user : users) { | for (auto &user : users) { | ||||
| if (IsPrimitiveCNode(user.first, prim::kPrimControlDepend)) { | |||||
| continue; | |||||
| } | |||||
| if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) { | if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) { | ||||
| auto depend_cnode = user.first->cast<CNodePtr>(); | auto depend_cnode = user.first->cast<CNodePtr>(); | ||||
| if (depend_cnode == nullptr) { | if (depend_cnode == nullptr) { | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -84,7 +84,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { | |||||
| } | } | ||||
| } | } | ||||
| std::vector<PrimitivePtr> adapter_convert_ops = {prim::kPrimDepend, prim::kPrimControlDepend, prim::kPrimLoad}; | |||||
| std::vector<PrimitivePtr> adapter_convert_ops = {prim::kPrimDepend, prim::kPrimLoad}; | |||||
| for (auto &item : adapter_convert_ops) { | for (auto &item : adapter_convert_ops) { | ||||
| if (IsPrimitiveCNode(node, item)) { | if (IsPrimitiveCNode(node, item)) { | ||||
| return true; | return true; | ||||
| @@ -243,8 +243,7 @@ CNodePtr MergeNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, int64_t sw | |||||
| return merge_op; | return merge_op; | ||||
| } | } | ||||
| // construct a depend node with merge output node, merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data)) | |||||
| // control_depend(output_node, square_op) | |||||
| // merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data)) | |||||
| AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &output_node, | AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &output_node, | ||||
| int64_t switch_idx) { | int64_t switch_idx) { | ||||
| tensor::TensorPtr const_data = GetConstData(); | tensor::TensorPtr const_data = GetConstData(); | ||||
| @@ -259,54 +258,21 @@ AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr | |||||
| SetSquareOp(switch_idx, square_op); | SetSquareOp(switch_idx, square_op); | ||||
| } | } | ||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| AnfNodePtrList inputs = {NewValueNode(prim::kPrimDepend), square_op, output_node}; | |||||
| auto depend_cnode = graph->NewCNode(inputs); | |||||
| if (!manager->Replace(square_op, depend_cnode)) { | |||||
| MS_LOG(EXCEPTION) << square_op->DebugString() << ", replace node failed."; | |||||
| } | |||||
| CNodePtr merge_op = GetMergeOp(switch_idx); | CNodePtr merge_op = GetMergeOp(switch_idx); | ||||
| if (merge_op == nullptr) { | if (merge_op == nullptr) { | ||||
| merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op); | merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op); | ||||
| SetMergeOp(switch_idx, merge_op); | SetMergeOp(switch_idx, merge_op); | ||||
| } | } | ||||
| std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), output_node, square_op}; | |||||
| auto control_depend_op = graph->NewCNode(control_depend_nodes); | |||||
| std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), merge_op, control_depend_op}; | |||||
| auto depend_op = graph->NewCNode(depend_nodes); | |||||
| return depend_op; | |||||
| } | |||||
| // construct a merge output and add dependency with the netoutput node from control_depend | |||||
| // we need to reserve the control_depend node, besides the generated merge node and control_depend node | |||||
| CNodePtr GenerateSwitchControlDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, | |||||
| const AnfNodePtr &ctrl_dep_node, const AnfNodePtr &ctrl_depend_dst, | |||||
| int64_t switch_idx) { | |||||
| auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>(); | |||||
| auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast<PrimitivePtr>(); | |||||
| std::vector<int64_t> shp = {1}; | |||||
| tensor::TensorPtr const_data = std::make_shared<tensor::Tensor>(kInt64->type_id(), shp); | |||||
| auto *val = static_cast<int64_t *>(const_data->data_c()); | |||||
| *val = 0; | |||||
| // for the control_depend netoutput node , add two const data to merge the flow ,one for depended node with same | |||||
| // switch the other use the opposite | |||||
| auto ctrl_data = NewValueNode(const_data); | |||||
| auto oppsite_ctrl_data = NewValueNode(const_data); | |||||
| auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx); | |||||
| auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx); | |||||
| std::vector<AnfNodePtr> square_nodes{NewValueNode(PrimSquare), ctrl_node}; | |||||
| auto square_op = graph->NewCNode(square_nodes); | |||||
| std::vector<AnfNodePtr> merge_nodes; | |||||
| merge_nodes.push_back(NewValueNode(PrimMerge)); | |||||
| std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node}; | |||||
| merge_nodes.push_back(graph->NewCNode(make_tuple_nodes)); | |||||
| auto merge_output = graph->NewCNode(merge_nodes); | |||||
| std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), ctrl_depend_dst, square_op}; | |||||
| auto cond_dep_output = graph->NewCNode(control_depend_nodes); | |||||
| std::vector<AnfNodePtr> depended_make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), ctrl_dep_node, merge_output, | |||||
| cond_dep_output}; | |||||
| return graph->NewCNode(depended_make_tuple_nodes); | |||||
| return merge_op; | |||||
| } | } | ||||
| // generate switch nodes for true graph node inputs | // generate switch nodes for true graph node inputs | ||||
| @@ -321,26 +287,12 @@ AnfNodePtr GenerateSwitchDependFalseNode(const FuncGraphPtr &graph, const AnfNod | |||||
| return GenerateSwitchDependNode(graph, cond, data, 0); | return GenerateSwitchDependNode(graph, cond, data, 0); | ||||
| } | } | ||||
| // generate switch nodes for true graph node inputs | |||||
| CNodePtr GenerateSwitchControlDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, | |||||
| const AnfNodePtr &con_input, const AnfNodePtr &output) { | |||||
| // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch | |||||
| return GenerateSwitchControlDependNode(graph, cond, con_input, output, 1); | |||||
| } | |||||
| // generate switch nodes for false graph node inputs | |||||
| CNodePtr GenerateSwitchControlDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, | |||||
| const AnfNodePtr &con_input, const AnfNodePtr &output) { | |||||
| // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch | |||||
| return GenerateSwitchControlDependNode(graph, cond, con_input, output, 0); | |||||
| } | |||||
| // to judge if the node used in ControlDepend is a net output node | |||||
| // to judge if the node used in Depend is a net output node | |||||
| bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { | bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { | ||||
| auto uses = manager->node_users()[node]; | auto uses = manager->node_users()[node]; | ||||
| bool is_output_node = true; | bool is_output_node = true; | ||||
| for (auto &item : uses) { | for (auto &item : uses) { | ||||
| if (IsPrimitiveCNode(item.first, prim::kPrimControlDepend) || IsPrimitiveCNode(item.first, prim::kPrimDepend)) { | |||||
| if (IsPrimitiveCNode(item.first, prim::kPrimDepend)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| is_output_node = false; | is_output_node = false; | ||||
| @@ -353,8 +305,7 @@ bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) | |||||
| void GenerateReplNodeForDependMakeTuple( | void GenerateReplNodeForDependMakeTuple( | ||||
| const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond, | const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond, | ||||
| const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node, | const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node, | ||||
| const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func, | |||||
| const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) { | |||||
| const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) { | |||||
| MS_EXCEPTION_IF_NULL(graph->manager()); | MS_EXCEPTION_IF_NULL(graph->manager()); | ||||
| auto make_tuple_inputs = depended_node->cast<CNodePtr>()->inputs(); | auto make_tuple_inputs = depended_node->cast<CNodePtr>()->inputs(); | ||||
| @@ -368,26 +319,6 @@ void GenerateReplNodeForDependMakeTuple( | |||||
| new_make_tuple_nodes.push_back(depended_tuple_input_node); | new_make_tuple_nodes.push_back(depended_tuple_input_node); | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsPrimitiveCNode(depended_tuple_input_node->cast<CNodePtr>(), prim::kPrimControlDepend)) { | |||||
| // only when the control depend input is not square op (the op to use as merge output) | |||||
| auto control_inputs = depended_tuple_input_node->cast<CNodePtr>()->inputs(); | |||||
| if (control_inputs.size() != 3) { | |||||
| MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); | |||||
| } | |||||
| // control inputs: primitive, src, dst | |||||
| auto dst_node = control_inputs[2]; | |||||
| if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { | |||||
| auto gen_node = gen_ctl_depd_func(graph, cond, make_tuple_inputs[idx], dst_node); | |||||
| MS_EXCEPTION_IF_NULL(gen_node); | |||||
| auto tuple_inputs = gen_node->inputs(); | |||||
| // add depended tuple inputs to new_make_tuple directly | |||||
| for (size_t i = 1; i < tuple_inputs.size(); i++) { | |||||
| new_make_tuple_nodes.push_back(tuple_inputs[i]); | |||||
| } | |||||
| } | |||||
| replace_make_tuple = true; | |||||
| continue; | |||||
| } | |||||
| if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) { | if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) { | ||||
| auto gen_node = generate_func(graph, cond, depended_tuple_input_node); | auto gen_node = generate_func(graph, cond, depended_tuple_input_node); | ||||
| @@ -408,8 +339,7 @@ void GenerateReplNodeForDependMakeTuple( | |||||
| void GenerateRepDepend( | void GenerateRepDepend( | ||||
| const CNodePtr &node, const FuncGraphPtr &graph, const AnfNodePtr &cond, | const CNodePtr &node, const FuncGraphPtr &graph, const AnfNodePtr &cond, | ||||
| const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node, | const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node, | ||||
| const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func, | |||||
| const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) { | |||||
| const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) { | |||||
| auto inputs = node->inputs(); | auto inputs = node->inputs(); | ||||
| if (inputs.size() != 3) { | if (inputs.size() != 3) { | ||||
| MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node]."; | MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node]."; | ||||
| @@ -422,19 +352,7 @@ void GenerateRepDepend( | |||||
| new_depened_inputs.push_back(inputs[1]); | new_depened_inputs.push_back(inputs[1]); | ||||
| // depended node should be make_tuple or a single depended node | // depended node should be make_tuple or a single depended node | ||||
| if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) { | if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) { | ||||
| GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func, gen_ctl_depd_func); | |||||
| } else if (IsPrimitiveCNode(depended_node, prim::kPrimControlDepend)) { | |||||
| // only when the control depend input is not square op (the op to use as merge output) | |||||
| auto control_inputs = depended_node->cast<CNodePtr>()->inputs(); | |||||
| // control inputs: primitive, src, dst | |||||
| if (control_inputs.size() != 3) { | |||||
| MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); | |||||
| } | |||||
| auto dst_node = control_inputs[2]; | |||||
| if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { | |||||
| auto gen_node = gen_ctl_depd_func(graph, cond, depended_node, dst_node); | |||||
| (*repl_node)[depended_node] = gen_node; | |||||
| } | |||||
| GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func); | |||||
| } else { | } else { | ||||
| // Check if there is only single user for depend_node. | // Check if there is only single user for depend_node. | ||||
| if (graph->manager()->node_users()[depended_node].size() == 1) { | if (graph->manager()->node_users()[depended_node].size() == 1) { | ||||
| @@ -448,11 +366,9 @@ void GenerateRepDepend( | |||||
| // generate depend node for netoutput node, to resolve the stream synchronize problem of ge | // generate depend node for netoutput node, to resolve the stream synchronize problem of ge | ||||
| // traverse all nodes of depend node, find the graph output node , generaete a merge node of (square, const) | // traverse all nodes of depend node, find the graph output node , generaete a merge node of (square, const) | ||||
| // and add control_depend of graph output node and square node. | |||||
| FuncGraphPtr TransformGraphDependNode( | FuncGraphPtr TransformGraphDependNode( | ||||
| const FuncGraphPtr &graph, const AnfNodePtr &cond, | const FuncGraphPtr &graph, const AnfNodePtr &cond, | ||||
| const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &gen_depend_func, | |||||
| const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) { | |||||
| const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &gen_depend_func) { | |||||
| auto manager = graph->manager(); | auto manager = graph->manager(); | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| @@ -478,7 +394,7 @@ FuncGraphPtr TransformGraphDependNode( | |||||
| if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) { | if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func, gen_ctl_depd_func); | |||||
| GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func); | |||||
| } | } | ||||
| } | } | ||||
| ResetSharedOp(); | ResetSharedOp(); | ||||
| @@ -494,12 +410,12 @@ FuncGraphPtr TransformGraphDependNode( | |||||
| FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { | FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { | ||||
| (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode); | (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode); | ||||
| return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode, GenerateSwitchControlDependTrueNode); | |||||
| return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode); | |||||
| } | } | ||||
| FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { | FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { | ||||
| (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode); | (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode); | ||||
| return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode, GenerateSwitchControlDependFalseNode); | |||||
| return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode); | |||||
| } | } | ||||
| // judge if the true and false graph output is compatible(they shall have same tuple size) | // judge if the true and false graph output is compatible(they shall have same tuple size) | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -218,10 +218,10 @@ bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) { | |||||
| if (output_set_iter == node_users.end()) { | if (output_set_iter == node_users.end()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| for (const auto &node_index_set : output_set_iter->second) { | |||||
| if (!IsBpropNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) { | |||||
| return true; | |||||
| } | |||||
| if (std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(), | |||||
| [](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); })) { | |||||
| return true; | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -367,7 +367,6 @@ constexpr char HISTOGRAMSUMMARY[] = "HistogramSummary"; | |||||
| constexpr char DEBUG[] = "Debug"; | constexpr char DEBUG[] = "Debug"; | ||||
| constexpr char BROADCASTGRADIENTARGS[] = "BroadcastGradientArgs"; | constexpr char BROADCASTGRADIENTARGS[] = "BroadcastGradientArgs"; | ||||
| constexpr char INVERTPERMUTATION[] = "InvertPermutation"; | constexpr char INVERTPERMUTATION[] = "InvertPermutation"; | ||||
| constexpr char CONTROLDEPEND[] = "ControlDepend"; | |||||
| constexpr char DOT[] = "dot"; | constexpr char DOT[] = "dot"; | ||||
| constexpr char IM2COL[] = "im2col"; | constexpr char IM2COL[] = "im2col"; | ||||
| constexpr char COL2IM[] = "col2im"; | constexpr char COL2IM[] = "col2im"; | ||||
| @@ -259,10 +259,8 @@ BaseRef CreateOutputTensors(const AnfNodePtr &output_node, const KernelGraphPtr | |||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| VectorRef ret; | VectorRef ret; | ||||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | for (size_t i = 1; i < cnode->inputs().size(); ++i) { | ||||
| if (!AnfAlgo::CheckPrimitiveType(cnode->input(i), prim::kPrimControlDepend)) { | |||||
| auto out = CreateOutputTensors(cnode->input(i), graph, input_tensors); | |||||
| ret.push_back(out); | |||||
| } | |||||
| auto out = CreateOutputTensors(cnode->input(i), graph, input_tensors); | |||||
| ret.push_back(out); | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -1044,7 +1044,7 @@ bool DfGraphConvertor::IsControlEdgeNode(const AnfNodePtr &node) { | |||||
| OperatorPtr DfGraphConvertor::ToOperatorPtr(const AnfNodePtr &node) { | OperatorPtr DfGraphConvertor::ToOperatorPtr(const AnfNodePtr &node) { | ||||
| auto op = Convert(GetRealOpNode(node)); | auto op = Convert(GetRealOpNode(node)); | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "Convert control depend node to operator failed, " << node->ToString(); | |||||
| MS_LOG(ERROR) << "Convert real op node to operator failed, " << node->ToString(); | |||||
| error_ = FAILED; | error_ = FAILED; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -1170,13 +1170,13 @@ void DfGraphConvertor::AutoMonadSetControlInput(const AnfNodePtr &node) { | |||||
| void DfGraphConvertor::SetOpControlInput(const AnfNodePtr &node) { | void DfGraphConvertor::SetOpControlInput(const AnfNodePtr &node) { | ||||
| AutoMonadSetControlInput(node); | AutoMonadSetControlInput(node); | ||||
| if (control_depend_cache_.find(node.get()) == control_depend_cache_.end()) { | |||||
| if (control_edge_cache_.find(node.get()) == control_edge_cache_.end()) { | |||||
| return; | return; | ||||
| } | } | ||||
| std::vector<ControlEdge> control_edges = control_depend_cache_[node.get()]; | |||||
| std::vector<ControlEdge> control_edges = control_edge_cache_[node.get()]; | |||||
| if ((control_edges.empty())) { | if ((control_edges.empty())) { | ||||
| MS_LOG(ERROR) << "Get control depend node's src or dest operator failed"; | |||||
| MS_LOG(ERROR) << "Get control edge node's src or dest operator failed"; | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -1600,7 +1600,7 @@ std::vector<OperatorPtr> DfGraphConvertor::ConvertDependNode(const AnfNodePtr no | |||||
| for (size_t index = 1; index < node_inputs.size(); index++) { | for (size_t index = 1; index < node_inputs.size(); index++) { | ||||
| auto op = Convert(GetRealOpNode(node_inputs[index])); | auto op = Convert(GetRealOpNode(node_inputs[index])); | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "Convert control depend node to operator failed"; | |||||
| MS_LOG(ERROR) << "Convert real op node to operator failed"; | |||||
| error_ = FAILED; | error_ = FAILED; | ||||
| return std::vector<OperatorPtr>({}); | return std::vector<OperatorPtr>({}); | ||||
| } | } | ||||
| @@ -1611,194 +1611,13 @@ std::vector<OperatorPtr> DfGraphConvertor::ConvertDependNode(const AnfNodePtr no | |||||
| auto op = Convert(GetRealOpNode(node)); | auto op = Convert(GetRealOpNode(node)); | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "Convert control depend node to operator failed"; | |||||
| MS_LOG(ERROR) << "Convert real op node to operator failed"; | |||||
| error_ = FAILED; | error_ = FAILED; | ||||
| return std::vector<OperatorPtr>({}); | return std::vector<OperatorPtr>({}); | ||||
| } | } | ||||
| return std::vector<OperatorPtr>({op}); | return std::vector<OperatorPtr>({op}); | ||||
| } | } | ||||
| // get the anf node list for depend | |||||
| std::vector<AnfNodePtr> DfGraphConvertor::GetDependNodes(const AnfNodePtr &node) { | |||||
| std::vector<AnfNodePtr> nodes; | |||||
| // for make tuple, should control depend on the tuple items | |||||
| if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { | |||||
| auto node_inputs = node->cast<CNodePtr>()->inputs(); | |||||
| for (size_t index = 1; index < node_inputs.size(); index++) { | |||||
| nodes.push_back(GetRealOpNode(node_inputs[index])); | |||||
| } | |||||
| return nodes; | |||||
| } | |||||
| // for parameter ,find the apply that used the parameter as the control depended node | |||||
| if (node->isa<Parameter>()) { | |||||
| auto uses = node->func_graph()->manager()->node_users()[node]; | |||||
| for (auto &use : uses) { | |||||
| auto use_node = use.first; | |||||
| if ((use_node->isa<CNode>()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { | |||||
| nodes.push_back(GetRealOpNode(use_node)); | |||||
| } | |||||
| } | |||||
| return nodes; | |||||
| } | |||||
| nodes.push_back(GetRealOpNode(node)); | |||||
| return nodes; | |||||
| } | |||||
| void DfGraphConvertor::DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node) { | |||||
| #ifdef DRAW_GE_GRAPH | |||||
| auto src_depend_nodes = GetDependNodes(src_node); | |||||
| auto dst_depend_nodes = GetDependNodes(dest_node); | |||||
| if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() > 1) { | |||||
| for (auto &item : dst_depend_nodes) { | |||||
| compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[item.get()] | |||||
| << "[style=\"dotted\"]" << endl; | |||||
| } | |||||
| } else if (src_depend_nodes.size() > 1 && dst_depend_nodes.size() == 1) { | |||||
| for (auto &item : src_depend_nodes) { | |||||
| compute_sout_ << op_draw_name_[item.get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] | |||||
| << "[style=\"dotted\"]" << endl; | |||||
| } | |||||
| } else if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() == 1) { | |||||
| compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] | |||||
| << "[style=\"dotted\"]" << endl; | |||||
| } | |||||
| #endif | |||||
| } | |||||
| void DfGraphConvertor::GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, | |||||
| const AnfNodePtr &dest_node, | |||||
| const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list, | |||||
| const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list) { | |||||
| if (src_node->isa<Parameter>()) { | |||||
| auto uses = node->func_graph()->manager()->node_users()[src_node]; | |||||
| for (auto &use : uses) { | |||||
| auto use_node = use.first; | |||||
| if ((use_node->isa<CNode>()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && | |||||
| (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { | |||||
| auto converted_list = ConvertDependNode(use_node); | |||||
| src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (dest_node->isa<Parameter>()) { | |||||
| auto uses = node->func_graph()->manager()->node_users()[dest_node]; | |||||
| for (auto &use : uses) { | |||||
| auto use_node = use.first; | |||||
| if ((use_node->isa<CNode>()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && | |||||
| (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { | |||||
| auto converted_list = ConvertDependNode(use_node); | |||||
| dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| bool DfGraphConvertor::GetControlDependList(const CNodePtr &node, | |||||
| const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list, | |||||
| const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list) { | |||||
| const int CONTROL_DEPEND_INDEX = 0; | |||||
| const int SRC_NODE_INDEX = 1; | |||||
| const int DEST_NODE_INDEX = 2; | |||||
| const int DEPEND_MODE_NORMAL_USE = 0; | |||||
| const int DEPEND_MODE_ON_PARAMETER_USE = 1; | |||||
| auto node_inputs = node->inputs(); | |||||
| if (node_inputs.size() <= DEST_NODE_INDEX) { | |||||
| MS_LOG(WARNING) << "Control depend node input size error"; | |||||
| return false; | |||||
| } | |||||
| auto src_node = node_inputs[SRC_NODE_INDEX]; | |||||
| auto dest_node = node_inputs[DEST_NODE_INDEX]; | |||||
| if ((src_node == nullptr) || (dest_node == nullptr)) { | |||||
| MS_LOG(ERROR) << "Control depend node miss src or dest node"; | |||||
| error_ = FAILED; | |||||
| return false; | |||||
| } | |||||
| AnfNodePtr fn = node_inputs[CONTROL_DEPEND_INDEX]; | |||||
| PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(fn); | |||||
| ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); | |||||
| int depend_mode = DEPEND_MODE_NORMAL_USE; | |||||
| if (mode_ptr != nullptr) { | |||||
| auto mode_int = mode_ptr->cast<Int64ImmPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(mode_int); | |||||
| depend_mode = mode_int->value(); | |||||
| MS_LOG(DEBUG) << "depend_mode = " << depend_mode; | |||||
| } | |||||
| if (depend_mode == DEPEND_MODE_ON_PARAMETER_USE) { | |||||
| GetDependOnParameterUse(node, src_node, dest_node, src_ops_list, dst_ops_list); | |||||
| } | |||||
| if (src_node->isa<CNode>()) { | |||||
| auto converted_list = ConvertDependNode(src_node); | |||||
| src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); | |||||
| } | |||||
| if (dest_node->isa<CNode>()) { | |||||
| auto converted_list = ConvertDependNode(dest_node); | |||||
| dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); | |||||
| } | |||||
| if (src_ops_list->empty() || dst_ops_list->empty()) { | |||||
| MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it"; | |||||
| error_ = SUCCESS; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) { | |||||
| const int SRC_NODE_INDEX = 1; | |||||
| const int DEST_NODE_INDEX = 2; | |||||
| if (control_depend_cache_.find(node.get()) != control_depend_cache_.end()) { | |||||
| return; | |||||
| } | |||||
| auto node_inputs = node->inputs(); | |||||
| if (node_inputs.size() <= DEST_NODE_INDEX) { | |||||
| MS_LOG(WARNING) << "Control depend node input size error"; | |||||
| return; | |||||
| } | |||||
| auto src_node = node_inputs[SRC_NODE_INDEX]; | |||||
| auto dest_node = node_inputs[DEST_NODE_INDEX]; | |||||
| if ((src_node == nullptr) || (dest_node == nullptr)) { | |||||
| MS_LOG(ERROR) << "Control depend node miss src or dest node"; | |||||
| error_ = FAILED; | |||||
| return; | |||||
| } | |||||
| std::shared_ptr<std::vector<OperatorPtr>> src_ops_list = std::make_shared<std::vector<OperatorPtr>>(); | |||||
| std::shared_ptr<std::vector<OperatorPtr>> dst_ops_list = std::make_shared<std::vector<OperatorPtr>>(); | |||||
| if (!GetControlDependList(node, src_ops_list, dst_ops_list)) { | |||||
| MS_LOG(ERROR) << "Get depend list failed"; | |||||
| error_ = FAILED; | |||||
| return; | |||||
| } | |||||
| std::vector<ControlEdge> control_edges; | |||||
| if (src_ops_list->size() == 1 && dst_ops_list->size() > 1) { | |||||
| (void)std::transform(dst_ops_list->begin(), dst_ops_list->end(), std::back_inserter(control_edges), | |||||
| [src_ops_list](const OperatorPtr &op) -> ControlEdge { | |||||
| return {(*src_ops_list)[0], op}; | |||||
| }); | |||||
| } else if (src_ops_list->size() > 1 && dst_ops_list->size() == 1) { | |||||
| (void)std::transform(src_ops_list->begin(), src_ops_list->end(), std::back_inserter(control_edges), | |||||
| [dst_ops_list](const OperatorPtr &op) -> ControlEdge { | |||||
| return {op, (*dst_ops_list)[0]}; | |||||
| }); | |||||
| } else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) { | |||||
| control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]}); | |||||
| } else if (src_ops_list->empty() || dst_ops_list->empty()) { | |||||
| MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it"; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size() | |||||
| << " -> dst:" << dst_ops_list->size(); | |||||
| error_ = FAILED; | |||||
| return; | |||||
| } | |||||
| control_depend_cache_[node.get()] = control_edges; | |||||
| #ifdef DRAW_GE_GRAPH | |||||
| DrawControlDepend(src_node, dest_node); | |||||
| #endif | |||||
| } | |||||
| bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) { | bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) { | ||||
| // ignore apply node of return | // ignore apply node of return | ||||
| if (name == "" || name == prim::kPrimReturn->name() || name == prim::kPrimDepend->name() || | if (name == "" || name == prim::kPrimReturn->name() || name == prim::kPrimDepend->name() || | ||||
| @@ -1818,12 +1637,6 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) | |||||
| return false; | return false; | ||||
| } | } | ||||
| // ControlDepend | |||||
| if (name == prim::kPrimControlDepend->name()) { | |||||
| ConvertControlDependNode(node); | |||||
| return false; | |||||
| } | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -145,19 +145,11 @@ class DfGraphConvertor { | |||||
| OperatorPtr ConvertCNode(CNodePtr node); | OperatorPtr ConvertCNode(CNodePtr node); | ||||
| std::vector<OperatorPtr> ConvertDependNode(AnfNodePtr node); | std::vector<OperatorPtr> ConvertDependNode(AnfNodePtr node); | ||||
| AnfNodePtr GetRealOpNode(AnfNodePtr node); | AnfNodePtr GetRealOpNode(AnfNodePtr node); | ||||
| std::vector<AnfNodePtr> GetDependNodes(const AnfNodePtr &node); | |||||
| OperatorPtr ConvertParameter(AnfNodePtr node); | OperatorPtr ConvertParameter(AnfNodePtr node); | ||||
| Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); | Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); | ||||
| OperatorPtr ConvertValueNode(ValueNodePtr node); | OperatorPtr ConvertValueNode(ValueNodePtr node); | ||||
| void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node); | void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node); | ||||
| void ConvertTupleGetItem(const CNodePtr node); | void ConvertTupleGetItem(const CNodePtr node); | ||||
| void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, | |||||
| const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list, | |||||
| const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list); | |||||
| bool GetControlDependList(const CNodePtr &node, const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list, | |||||
| const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list); | |||||
| void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node); | |||||
| void ConvertControlDependNode(const CNodePtr node); | |||||
| void ConvertMakeTuple(const CNodePtr node); | void ConvertMakeTuple(const CNodePtr node); | ||||
| bool CheckCNode(const std::string &name, const CNodePtr node); | bool CheckCNode(const std::string &name, const CNodePtr node); | ||||
| void TraceOutput(AnfNodePtr node); | void TraceOutput(AnfNodePtr node); | ||||
| @@ -195,7 +187,7 @@ class DfGraphConvertor { | |||||
| std::shared_ptr<DfGraph> broadcast_graph_{nullptr}; | std::shared_ptr<DfGraph> broadcast_graph_{nullptr}; | ||||
| std::unordered_map<AnfNode *, DfGraph> branches_map_; | std::unordered_map<AnfNode *, DfGraph> branches_map_; | ||||
| std::unordered_map<AnfNode *, OperatorPtr> op_cache_; | std::unordered_map<AnfNode *, OperatorPtr> op_cache_; | ||||
| std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_depend_cache_; | |||||
| std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_edge_cache_; | |||||
| std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>> monad_control_edge_cache_; | std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>> monad_control_edge_cache_; | ||||
| /* record "tuple_getitem"<->"out_handler" mapping */ | /* record "tuple_getitem"<->"out_handler" mapping */ | ||||
| std::unordered_map<AnfNode *, OutHandler> out_handle_cache_; | std::unordered_map<AnfNode *, OutHandler> out_handle_cache_; | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -51,88 +51,6 @@ std::string GetOtherTarget(const std::vector<AnfNodePtr> &nodes) { | |||||
| } | } | ||||
| return ""; | return ""; | ||||
| } | } | ||||
| bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node, | |||||
| std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(prior_node); | |||||
| MS_EXCEPTION_IF_NULL(behind_node); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto &node_users = manager->node_users(); | |||||
| if (prior_node->isa<Parameter>()) { | |||||
| for (auto &user : node_users[prior_node]) { | |||||
| auto cnode = user.first->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||||
| prior_nodes->emplace_back(cnode); | |||||
| } | |||||
| } | |||||
| } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) { | |||||
| prior_nodes->emplace_back(prior_node); | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| if (behind_node->isa<Parameter>()) { | |||||
| for (auto &user : node_users[behind_node]) { | |||||
| auto cnode = user.first->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||||
| depend_nodes->emplace_back(cnode); | |||||
| } | |||||
| } | |||||
| } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) { | |||||
| depend_nodes->emplace_back(behind_node); | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges, | |||||
| std::map<AnfNodePtr, size_t> *nodes_ref) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto input_cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(input_cnode); | |||||
| auto prior_node = input_cnode->input(kControlDependPriorIndex); | |||||
| auto depend_node = input_cnode->input(kControlDependBehindIndex); | |||||
| MS_EXCEPTION_IF_NULL(prior_node); | |||||
| MS_EXCEPTION_IF_NULL(depend_node); | |||||
| auto prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0)); | |||||
| MS_EXCEPTION_IF_NULL(prim_ptr); | |||||
| ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); | |||||
| int64_t depend_mode = 0; | |||||
| if (mode_ptr != nullptr) { | |||||
| depend_mode = GetValue<int64_t>(mode_ptr); | |||||
| } | |||||
| if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) { | |||||
| return; | |||||
| } | |||||
| std::vector<AnfNodePtr> prior_nodes; | |||||
| std::vector<AnfNodePtr> behind_nodes; | |||||
| if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) { | |||||
| return; | |||||
| } | |||||
| for (auto &first_node : prior_nodes) { | |||||
| for (auto &second_node : behind_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(first_node); | |||||
| MS_EXCEPTION_IF_NULL(second_node); | |||||
| auto iter = control_edges->find(second_node); | |||||
| if (iter == control_edges->end()) { | |||||
| (void)control_edges->insert( | |||||
| std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node})); | |||||
| } else { | |||||
| iter->second.emplace_back(first_node); | |||||
| } | |||||
| auto ref_iter = nodes_ref->find(first_node); | |||||
| if (ref_iter != nodes_ref->end()) { | |||||
| ref_iter->second++; | |||||
| } else { | |||||
| (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1)); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref, | void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref, | ||||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) { | std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) { | ||||
| @@ -149,9 +67,6 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n | |||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| for (auto &input : cnode->inputs()) { | for (auto &input : cnode->inputs()) { | ||||
| if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) { | |||||
| AddControlEdge(graph, input, control_edges, nodes_ref); | |||||
| } | |||||
| auto iter = nodes_ref->find(input); | auto iter = nodes_ref->find(input); | ||||
| if (iter != nodes_ref->end()) { | if (iter != nodes_ref->end()) { | ||||
| iter->second++; | iter->second++; | ||||
| @@ -479,11 +394,9 @@ void AddSegmentDependency(const FuncGraphPtr &graph, const std::string &default_ | |||||
| node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); | node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); | ||||
| } | } | ||||
| GraphSegmentPtr node_segment{nullptr}; | GraphSegmentPtr node_segment{nullptr}; | ||||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||||
| auto node_iter = node_to_segment.find(node); | |||||
| if (node_iter != node_to_segment.end()) { | |||||
| node_segment = node_iter->second; | |||||
| } | |||||
| auto node_iter = node_to_segment.find(node); | |||||
| if (node_iter != node_to_segment.end()) { | |||||
| node_segment = node_iter->second; | |||||
| } | } | ||||
| for (auto &input : node_inputs) { | for (auto &input : node_inputs) { | ||||
| if (node_segment != nullptr && !node_segment->is_cut_ && input->isa<CNode>()) { | if (node_segment != nullptr && !node_segment->is_cut_ && input->isa<CNode>()) { | ||||
| @@ -615,18 +528,14 @@ void SplitDynamicNodeSegment(const std::vector<AnfNodePtr> &segment_nodes, std:: | |||||
| std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment, | std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment, | ||||
| const std::set<AnfNodePtr> &dynamic_nodes_set) { | const std::set<AnfNodePtr> &dynamic_nodes_set) { | ||||
| SplitDynamicNodesHelper helper; | SplitDynamicNodesHelper helper; | ||||
| bool is_last_node_dynamic = false; | |||||
| for (auto &node : segment_nodes) { | for (auto &node : segment_nodes) { | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||||
| helper.AddNode(node, is_last_node_dynamic); | |||||
| continue; | |||||
| } | |||||
| auto &inputs = cnode->inputs(); | auto &inputs = cnode->inputs(); | ||||
| bool has_dynamic_shape = dynamic_nodes_set.find(node) != dynamic_nodes_set.end(); | bool has_dynamic_shape = dynamic_nodes_set.find(node) != dynamic_nodes_set.end(); | ||||
| bool depend_common_node = false; | bool depend_common_node = false; | ||||
| bool depend_dynamic_node = false; | bool depend_dynamic_node = false; | ||||
| bool is_last_node_dynamic = false; | |||||
| for (size_t i = 1; i < inputs.size(); ++i) { | for (size_t i = 1; i < inputs.size(); ++i) { | ||||
| if (dynamic_nodes_set.find(inputs[i]) != dynamic_nodes_set.end()) { | if (dynamic_nodes_set.find(inputs[i]) != dynamic_nodes_set.end()) { | ||||
| has_dynamic_shape = true; | has_dynamic_shape = true; | ||||
| @@ -1,7 +1,7 @@ | |||||
| /** | /** | ||||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | ||||
| * | * | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -87,26 +87,7 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo | |||||
| if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) { | if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) { | ||||
| eqv[node] = node; | eqv[node] = node; | ||||
| } else if (eqv.find(node) == eqv.end()) { | } else if (eqv.find(node) == eqv.end()) { | ||||
| if (IsPrimitiveCNode(node, prim::kPrimControlDepend)) { | |||||
| eqv[node] = NewValueNode(MakeValue(0)); | |||||
| return eqv[node]; | |||||
| } | |||||
| bool ignore_make_tuple = false; | |||||
| if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { | |||||
| ignore_make_tuple = true; | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| const auto &node_inputs = cnode->inputs(); | |||||
| for (size_t i = 1; i < node_inputs.size(); ++i) { | |||||
| if (!IsPrimitiveCNode(node_inputs[i], prim::kPrimControlDepend)) { | |||||
| ignore_make_tuple = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!ignore_make_tuple) { | |||||
| inputs.push_back(node); | |||||
| } | |||||
| inputs.push_back(node); | |||||
| eqv[node] = fg->add_parameter(); | eqv[node] = fg->add_parameter(); | ||||
| eqv[node]->set_abstract(node->abstract()); | eqv[node]->set_abstract(node->abstract()); | ||||
| eqv[node]->set_kernel_info(node->kernel_info_ptr()); | eqv[node]->set_kernel_info(node->kernel_info_ptr()); | ||||
| @@ -148,14 +129,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||||
| for (size_t i = 2; i < inps.size(); ++i) { | for (size_t i = 2; i < inps.size(); ++i) { | ||||
| args.emplace_back(NewValueNode(MakeValue(0))); | args.emplace_back(NewValueNode(MakeValue(0))); | ||||
| } | } | ||||
| } else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) { | |||||
| for (size_t i = 1; i < inps.size(); ++i) { | |||||
| if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) { | |||||
| args.emplace_back(NewValueNode(MakeValue(static_cast<int>(i)))); | |||||
| } else { | |||||
| args.emplace_back(RefSubGraphNode(fg, inps[i], &inputs, &eqv)); | |||||
| } | |||||
| } | |||||
| } else { | } else { | ||||
| (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), | (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), | ||||
| [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); | [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); | ||||
| @@ -182,8 +182,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -188,23 +188,6 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP | |||||
| return args_spec_list[0]->Broaden(); | return args_spec_list[0]->Broaden(); | ||||
| } | } | ||||
| AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // args: Two objects of a subclass of AbstractBase | |||||
| CheckArgsSize(primitive->name(), args_spec_list, 2); | |||||
| auto arg_src = args_spec_list[0]; | |||||
| auto arg_dst = args_spec_list[1]; | |||||
| // control depend can not setup tuple of ops to tuple of ops dependency relation | |||||
| if (arg_src->isa<AbstractTuple>() && arg_dst->isa<AbstractTuple>()) { | |||||
| auto src_size = arg_src->cast<AbstractTuplePtr>()->size(); | |||||
| auto dst_size = arg_src->cast<AbstractTuplePtr>()->size(); | |||||
| if (src_size > 1 && dst_size > 1) { | |||||
| MS_LOG(EXCEPTION) << "Control depend can not setup operator dependency relationship from tuple from tuple"; | |||||
| } | |||||
| } | |||||
| return std::make_shared<AbstractScalar>(kAnyValue, kBool); | |||||
| } | |||||
| AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| // Inputs: two tensors and a tuple. | // Inputs: two tensors and a tuple. | ||||
| @@ -149,7 +149,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}}, | {prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}}, | ||||
| {prim::kPrimDepend, {InferImplDepend, nullptr, true}}, | {prim::kPrimDepend, {InferImplDepend, nullptr, true}}, | ||||
| {prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}}, | {prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}}, | ||||
| {prim::kPrimControlDepend, {InferImplControlDepend, nullptr, true}}, | |||||
| // Debug | // Debug | ||||
| {prim::kPrimDebug, {InferImplDebug, nullptr, true}}, | {prim::kPrimDebug, {InferImplDebug, nullptr, true}}, | ||||
| // Dynamic shape testing | // Dynamic shape testing | ||||
| @@ -453,7 +453,6 @@ inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookB | |||||
| inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType"); | inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType"); | ||||
| inline const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape"); | inline const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape"); | ||||
| inline const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print"); | inline const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print"); | ||||
| inline const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend"); | |||||
| inline const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_"); | inline const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_"); | ||||
| inline const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not"); | inline const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not"); | ||||
| inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); | inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); | ||||
| @@ -1,7 +1,7 @@ | |||||
| /** | /** | ||||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | ||||
| * | * | ||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -399,8 +399,7 @@ std::string GetAttrTarget(const PrimitivePtr &primitive, const ValuePtr &att_tar | |||||
| if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || | if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || | ||||
| IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || | IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || | ||||
| IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || | IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || | ||||
| IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || | |||||
| IsPrimitive(attr_input, prim::kPrimPartial)) { | |||||
| IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) { | |||||
| primitive->EraseAttr(primitive_target); | primitive->EraseAttr(primitive_target); | ||||
| return default_target; | return default_target; | ||||
| } | } | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -23,7 +23,6 @@ | |||||
| #include "ir/dtype/tensor_type.h" | #include "ir/dtype/tensor_type.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| #include "abstract/primitive_infer_map.h" | #include "abstract/primitive_infer_map.h" | ||||
| #include "ops/control_depend.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -23,7 +23,6 @@ | |||||
| #include "ops/expand_dims.h" | #include "ops/expand_dims.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| #include "abstract/primitive_infer_map.h" | #include "abstract/primitive_infer_map.h" | ||||
| #include "ops/control_depend.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -30,7 +30,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, | |||||
| "identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key", | "identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key", | ||||
| "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary", | "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary", | ||||
| "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs", | "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs", | ||||
| "InvertPermutation", "ControlDepend", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", | |||||
| "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", | |||||
| "stop_gradient", "Send", "UpdateState", "Load"}; | "stop_gradient", "Send", "UpdateState", "Load"}; | ||||
| // clang-format on | // clang-format on | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -27,7 +27,6 @@ | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "mindspore/core/ir/primitive.h" | #include "mindspore/core/ir/primitive.h" | ||||
| #include "ops/fusion/partial_fusion.h" | #include "ops/fusion/partial_fusion.h" | ||||
| #include "ops/control_depend.h" | |||||
| #include "ops/depend.h" | #include "ops/depend.h" | ||||
| #include "ops/make_tuple.h" | #include "ops/make_tuple.h" | ||||
| #include "ops/quant_dtype_cast.h" | #include "ops/quant_dtype_cast.h" | ||||
| @@ -213,8 +212,7 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { | |||||
| MS_LOG(ERROR) << "value node is invalid."; | MS_LOG(ERROR) << "value node is invalid."; | ||||
| return; | return; | ||||
| } | } | ||||
| if (value_node->value() != nullptr && (opt::CheckPrimitiveType(depend_node, prim::kPrimDepend) || | |||||
| opt::CheckPrimitiveType(depend_node, prim::kPrimControlDepend))) { | |||||
| if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) { | |||||
| has_depend = true; | has_depend = true; | ||||
| bool mask_out = (depend_node->inputs().size() == 3); | bool mask_out = (depend_node->inputs().size() == 3); | ||||
| for (size_t j = 1; j < depend_node->inputs().size(); ++j) { | for (size_t j = 1; j < depend_node->inputs().size(); ++j) { | ||||
| @@ -466,8 +464,8 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc | |||||
| } | } | ||||
| RemoveIfDepend(cnode); | RemoveIfDepend(cnode); | ||||
| if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend || | |||||
| prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) { | |||||
| if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameTupleGetItem || | |||||
| prim->name() == mindspore::ops::kNameMakeTuple) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (prim->name() == "make_tuple") { | if (prim->name() == "make_tuple") { | ||||
| @@ -57,8 +57,8 @@ bool IsRealKernel(const AnfNodePtr &node) { | |||||
| IsPrimitive(input, prim::kPrimTensorSummary) || | IsPrimitive(input, prim::kPrimTensorSummary) || | ||||
| IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || | IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || | ||||
| IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || | IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || | ||||
| IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) || | |||||
| IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial); | |||||
| IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimReturn) || | |||||
| IsPrimitive(input, prim::kPrimPartial); | |||||
| return !is_virtual_node; | return !is_virtual_node; | ||||
| } | } | ||||
| @@ -43,8 +43,8 @@ tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) { | |||||
| bool IsSpecialType(const CNodePtr &cnode) { | bool IsSpecialType(const CNodePtr &cnode) { | ||||
| if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) || | if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) || | ||||
| CheckPrimitiveType(cnode, prim::kPrimControlDepend) || CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || | |||||
| CheckPrimitiveType(cnode, prim::kPrimReturn) || CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) || | |||||
| CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimReturn) || | |||||
| CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) || | |||||
| CheckPrimitiveType(cnode, std::make_shared<Primitive>("If"))) { | CheckPrimitiveType(cnode, std::make_shared<Primitive>("If"))) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -58,13 +58,6 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph | |||||
| return lite::RET_NO_CHANGE; | return lite::RET_NO_CHANGE; | ||||
| } | } | ||||
| } | } | ||||
| if (CheckPrimitiveType(anf_node, prim::kPrimControlDepend)) { | |||||
| if (cnode->size() != InputDoubleNum) { | |||||
| MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; | |||||
| remove_cnode_.insert(anf_node); | |||||
| return lite::RET_NO_CHANGE; | |||||
| } | |||||
| } | |||||
| bool replace_succ = manager->Replace(anf_node, cnode->input(1)); | bool replace_succ = manager->Replace(anf_node, cnode->input(1)); | ||||
| if (!replace_succ) { | if (!replace_succ) { | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -96,7 +96,6 @@ class GPT2FinetuneCell(nn.Cell): | |||||
| self.get_status = P.NPUGetFloatStatus() | self.get_status = P.NPUGetFloatStatus() | ||||
| self.clear_before_grad = P.NPUClearFloatStatus() | self.clear_before_grad = P.NPUClearFloatStatus() | ||||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | self.reduce_sum = P.ReduceSum(keep_dims=False) | ||||
| self.depend_parameter_use = P.ControlDepend(depend_mode=1) | |||||
| self.base = Tensor(1, mstype.float32) | self.base = Tensor(1, mstype.float32) | ||||
| self.less_equal = P.LessEqual() | self.less_equal = P.LessEqual() | ||||
| self.hyper_map = C.HyperMap() | self.hyper_map = C.HyperMap() | ||||
| @@ -132,8 +131,8 @@ class GPT2FinetuneCell(nn.Cell): | |||||
| if not self.gpu_target: | if not self.gpu_target: | ||||
| init = self.alloc_status() | init = self.alloc_status() | ||||
| init = F.depend(init, loss) | |||||
| clear_before_grad = self.clear_before_grad(init) | clear_before_grad = self.clear_before_grad(init) | ||||
| F.control_depend(loss, init) | |||||
| self.depend_parameter_use(clear_before_grad, scaling_sens) | self.depend_parameter_use(clear_before_grad, scaling_sens) | ||||
| grads = self.grad(self.network, weights)(input_ids, | grads = self.grad(self.network, weights)(input_ids, | ||||
| input_mask, | input_mask, | ||||
| @@ -145,10 +144,10 @@ class GPT2FinetuneCell(nn.Cell): | |||||
| if self.reducer_flag: | if self.reducer_flag: | ||||
| grads = self.grad_reducer(grads) | grads = self.grad_reducer(grads) | ||||
| if not self.gpu_target: | if not self.gpu_target: | ||||
| init = F.depend(init, grads) | |||||
| flag = self.get_status(init) | flag = self.get_status(init) | ||||
| init = F.depend(init, flag) | |||||
| flag_sum = self.reduce_sum(init, (0,)) | flag_sum = self.reduce_sum(init, (0,)) | ||||
| F.control_depend(grads, flag) | |||||
| F.control_depend(flag, flag_sum) | |||||
| else: | else: | ||||
| flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) | flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) | ||||
| flag_sum = self.addn(flag_sum) | flag_sum = self.addn(flag_sum) | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -74,9 +74,9 @@ TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence_with_make_tupl | |||||
| /* | /* | ||||
| * def before(x, y, a, b): | * def before(x, y, a, b): | ||||
| * z = make_tuple(TransData(a), TransData(b)) | * z = make_tuple(TransData(a), TransData(b)) | ||||
| * depend_intput = control_depend(y, z) | |||||
| * sum = add(x, depend_intput) | |||||
| * return sum | |||||
| * depend_intput = depend(y, z) | |||||
| * sum_add = add(x, depend_intput) | |||||
| * return sum_add | |||||
| */ | */ | ||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence_with_make_tuple", "before"); | FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence_with_make_tuple", "before"); | ||||
| @@ -93,11 +93,11 @@ TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence_with_make_tupl | |||||
| TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence) { | TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence) { | ||||
| /* | /* | ||||
| * def before(x, y, a, b): | |||||
| * z = make_tuple(TransData(a), TransData(b)) | |||||
| * depend_intput = control_depend(y, z) | |||||
| * sum = add(x, depend_intput) | |||||
| * return sum | |||||
| * def before(x, y, z): | |||||
| * new_z = TransData(z) | |||||
| * depend_intput = depend(y, new_z) | |||||
| * sum_add = add(x, depend_intput) | |||||
| * return sum_add | |||||
| */ | */ | ||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence", "before"); | FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence", "before"); | ||||