diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index 6e67f3f9a8..9a2863d3de 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -822,10 +822,10 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) { return true; } -static std::pair FindPrimalJPair(const FuncGraphManagerPtr &manager, - const FuncGraphPtr &primal_graph) { - CNodePtr primal_user = nullptr; - CNodePtr j_user = nullptr; +static std::vector> FindPrimalJPair(const FuncGraphManagerPtr &manager, + const FuncGraphPtr &primal_graph) { + std::vector> primal_j_pair; + std::map primal_users_map; auto &node_user_map = manager->node_users(); // Search primal graph user cnodes. for (auto &entry : primal_graph->func_graph_cnodes_index()) { @@ -833,7 +833,13 @@ static std::pair FindPrimalJPair(const FuncGraphManagerPtr & auto index = entry.first->second; if (index == 0) { // To find real calling. - primal_user = cnode; + auto fg = cnode->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + if (primal_users_map.find(fg) != primal_users_map.end()) { + MS_LOG(EXCEPTION) << "The forward network is only allowed to be called once. Func graph: " << fg->ToString() + << ", cnode: " << cnode->DebugString() << ", trace: " << trace::DumpSourceLines(cnode); + } + primal_users_map[fg] = cnode; } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) { // To find J user. auto it = node_user_map.find(cnode); @@ -845,13 +851,33 @@ static std::pair FindPrimalJPair(const FuncGraphManagerPtr & if (size != 1) { MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}"; } - j_user = j_users.begin()->first->cast(); + auto j_user = j_users.begin()->first->cast(); + primal_j_pair.push_back({nullptr, j_user}); } - if (j_user != nullptr && primal_user != nullptr) { - break; + } + + for (auto &[primal_user, j_user] : primal_j_pair) { + // Check if J operation has relevant primal call in the same graph + auto graph = j_user->func_graph(); + auto iter = primal_users_map.find(graph); + if (iter == primal_users_map.end()) { + MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString() + << ", J user: " << j_user->DebugString(); + continue; } + // Check input size. + auto primal = iter->second; + if (primal->size() != j_user->size()) { + MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal->DebugString() << " is " + << primal->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size(); + continue; + } + + primal_user = primal; + MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString() + << " and J user is: " << j_user->DebugString(); } - return {primal_user, j_user}; + return primal_j_pair; } static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) { @@ -915,40 +941,37 @@ void DFunctor::EliminatePrimalGraph() { // Find primal user and paired J user cnodes. auto manager = primal_graph_->manager(); MS_EXCEPTION_IF_NULL(manager); - auto [primal_user, j_user] = FindPrimalJPair(manager, primal_graph_); - if (primal_user == nullptr || j_user == nullptr) { - // Skip if one of them not found. - return; - } - // Check input size. - if (primal_user->size() != j_user->size()) { - MS_LOG(WARNING) << "Input size incorrect, primal:" << primal_user->DebugString() - << " juser:" << j_user->DebugString(); - return; + auto prim_j_pair = FindPrimalJPair(manager, primal_graph_); + for (auto &[primal_user, j_user] : prim_j_pair) { + if (primal_user == nullptr || j_user == nullptr) { + // Skip if one of them not found. + return; + } + + // Replace primal graph with k graph. + auto k_vnode = NewValueNode(k_graph_); + auto primal_abs = primal_user->abstract(); + primal_user->set_input(0, k_vnode); + primal_user->set_abstract(j_user->abstract()); + + // If both inputs are same except monads, we copy primal monad args to k graph + // so that they can be combined in CSE (common subexpression elimination) pass. + const bool has_monad = CopyMonadArguments(primal_user, j_user); + // Remove the UpdateState nodes after primal_user if need. + if (has_monad) { + RemovePrimalUpdateStates(manager, primal_user); + } + + // Insert tuple_getitem after primal user cnode. + auto construct_wrapper = primal_user->func_graph(); + auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem); + auto imm0 = std::make_shared(0); + auto idx0 = NewValueNode(SizeToLong(0)); + idx0->set_abstract(std::make_shared(imm0)); + auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0}); + getitem0->set_abstract(primal_abs); + manager->Replace(primal_user, getitem0); } - // Replace primal graph with k graph. - auto k_vnode = NewValueNode(k_graph_); - auto primal_abs = primal_user->abstract(); - primal_user->set_input(0, k_vnode); - primal_user->set_abstract(j_user->abstract()); - - // If both inputs are same except monads, we copy primal monad args to k graph - // so that they can be combined in CSE (common subexpression elimination) pass. - const bool has_monad = CopyMonadArguments(primal_user, j_user); - // Remove the UpdateState nodes after primal_user if need. - if (has_monad) { - RemovePrimalUpdateStates(manager, primal_user); - } - - // Insert tuple_getitem after primal user cnode. - auto construct_wrapper = primal_user->func_graph(); - auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem); - auto imm0 = std::make_shared(0); - auto idx0 = NewValueNode(SizeToLong(0)); - idx0->set_abstract(std::make_shared(imm0)); - auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0}); - getitem0->set_abstract(primal_abs); - manager->Replace(primal_user, getitem0); } } // namespace ad } // namespace mindspore