| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -822,10 +822,10 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| static std::pair<CNodePtr, CNodePtr> FindPrimalJPair(const FuncGraphManagerPtr &manager, | |||||
| const FuncGraphPtr &primal_graph) { | |||||
| CNodePtr primal_user = nullptr; | |||||
| CNodePtr j_user = nullptr; | |||||
| static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager, | |||||
| const FuncGraphPtr &primal_graph) { | |||||
| std::vector<std::pair<CNodePtr, CNodePtr>> primal_j_pair; | |||||
| std::map<FuncGraphPtr, CNodePtr> primal_users_map; | |||||
| auto &node_user_map = manager->node_users(); | auto &node_user_map = manager->node_users(); | ||||
| // Search primal graph user cnodes. | // Search primal graph user cnodes. | ||||
| for (auto &entry : primal_graph->func_graph_cnodes_index()) { | for (auto &entry : primal_graph->func_graph_cnodes_index()) { | ||||
| @@ -833,7 +833,13 @@ static std::pair<CNodePtr, CNodePtr> FindPrimalJPair(const FuncGraphManagerPtr & | |||||
| auto index = entry.first->second; | auto index = entry.first->second; | ||||
| if (index == 0) { | if (index == 0) { | ||||
| // To find real calling. | // To find real calling. | ||||
| primal_user = cnode; | |||||
| auto fg = cnode->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(fg); | |||||
| if (primal_users_map.find(fg) != primal_users_map.end()) { | |||||
| MS_LOG(EXCEPTION) << "The forward network is only allowed to be called once. Func graph: " << fg->ToString() | |||||
| << ", cnode: " << cnode->DebugString() << ", trace: " << trace::DumpSourceLines(cnode); | |||||
| } | |||||
| primal_users_map[fg] = cnode; | |||||
| } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) { | } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) { | ||||
| // To find J user. | // To find J user. | ||||
| auto it = node_user_map.find(cnode); | auto it = node_user_map.find(cnode); | ||||
| @@ -845,13 +851,33 @@ static std::pair<CNodePtr, CNodePtr> FindPrimalJPair(const FuncGraphManagerPtr & | |||||
| if (size != 1) { | if (size != 1) { | ||||
| MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}"; | MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}"; | ||||
| } | } | ||||
| j_user = j_users.begin()->first->cast<CNodePtr>(); | |||||
| auto j_user = j_users.begin()->first->cast<CNodePtr>(); | |||||
| primal_j_pair.push_back({nullptr, j_user}); | |||||
| } | } | ||||
| if (j_user != nullptr && primal_user != nullptr) { | |||||
| break; | |||||
| } | |||||
| for (auto &[primal_user, j_user] : primal_j_pair) { | |||||
| // Check if J operation has relevant primal call in the same graph | |||||
| auto graph = j_user->func_graph(); | |||||
| auto iter = primal_users_map.find(graph); | |||||
| if (iter == primal_users_map.end()) { | |||||
| MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString() | |||||
| << ", J user: " << j_user->DebugString(); | |||||
| continue; | |||||
| } | } | ||||
| // Check input size. | |||||
| auto primal = iter->second; | |||||
| if (primal->size() != j_user->size()) { | |||||
| MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal->DebugString() << " is " | |||||
| << primal->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size(); | |||||
| continue; | |||||
| } | |||||
| primal_user = primal; | |||||
| MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString() | |||||
| << " and J user is: " << j_user->DebugString(); | |||||
| } | } | ||||
| return {primal_user, j_user}; | |||||
| return primal_j_pair; | |||||
| } | } | ||||
| static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) { | static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) { | ||||
| @@ -915,40 +941,37 @@ void DFunctor::EliminatePrimalGraph() { | |||||
| // Find primal user and paired J user cnodes. | // Find primal user and paired J user cnodes. | ||||
| auto manager = primal_graph_->manager(); | auto manager = primal_graph_->manager(); | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| auto [primal_user, j_user] = FindPrimalJPair(manager, primal_graph_); | |||||
| if (primal_user == nullptr || j_user == nullptr) { | |||||
| // Skip if one of them not found. | |||||
| return; | |||||
| } | |||||
| // Check input size. | |||||
| if (primal_user->size() != j_user->size()) { | |||||
| MS_LOG(WARNING) << "Input size incorrect, primal:" << primal_user->DebugString() | |||||
| << " juser:" << j_user->DebugString(); | |||||
| return; | |||||
| auto prim_j_pair = FindPrimalJPair(manager, primal_graph_); | |||||
| for (auto &[primal_user, j_user] : prim_j_pair) { | |||||
| if (primal_user == nullptr || j_user == nullptr) { | |||||
| // Skip if one of them not found. | |||||
| return; | |||||
| } | |||||
| // Replace primal graph with k graph. | |||||
| auto k_vnode = NewValueNode(k_graph_); | |||||
| auto primal_abs = primal_user->abstract(); | |||||
| primal_user->set_input(0, k_vnode); | |||||
| primal_user->set_abstract(j_user->abstract()); | |||||
| // If both inputs are same except monads, we copy primal monad args to k graph | |||||
| // so that they can be combined in CSE (common subexpression elimination) pass. | |||||
| const bool has_monad = CopyMonadArguments(primal_user, j_user); | |||||
| // Remove the UpdateState nodes after primal_user if need. | |||||
| if (has_monad) { | |||||
| RemovePrimalUpdateStates(manager, primal_user); | |||||
| } | |||||
| // Insert tuple_getitem after primal user cnode. | |||||
| auto construct_wrapper = primal_user->func_graph(); | |||||
| auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem); | |||||
| auto imm0 = std::make_shared<Int64Imm>(0); | |||||
| auto idx0 = NewValueNode(SizeToLong(0)); | |||||
| idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0)); | |||||
| auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0}); | |||||
| getitem0->set_abstract(primal_abs); | |||||
| manager->Replace(primal_user, getitem0); | |||||
| } | } | ||||
| // Replace primal graph with k graph. | |||||
| auto k_vnode = NewValueNode(k_graph_); | |||||
| auto primal_abs = primal_user->abstract(); | |||||
| primal_user->set_input(0, k_vnode); | |||||
| primal_user->set_abstract(j_user->abstract()); | |||||
| // If both inputs are same except monads, we copy primal monad args to k graph | |||||
| // so that they can be combined in CSE (common subexpression elimination) pass. | |||||
| const bool has_monad = CopyMonadArguments(primal_user, j_user); | |||||
| // Remove the UpdateState nodes after primal_user if need. | |||||
| if (has_monad) { | |||||
| RemovePrimalUpdateStates(manager, primal_user); | |||||
| } | |||||
| // Insert tuple_getitem after primal user cnode. | |||||
| auto construct_wrapper = primal_user->func_graph(); | |||||
| auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem); | |||||
| auto imm0 = std::make_shared<Int64Imm>(0); | |||||
| auto idx0 = NewValueNode(SizeToLong(0)); | |||||
| idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0)); | |||||
| auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0}); | |||||
| getitem0->set_abstract(primal_abs); | |||||
| manager->Replace(primal_user, getitem0); | |||||
| } | } | ||||
| } // namespace ad | } // namespace ad | ||||
| } // namespace mindspore | } // namespace mindspore | ||||