diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index 5e39cd4096..9e04c26eec 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -824,7 +824,7 @@ CNodePtr GetJUser(const NodeUsersMap &node_user_map, const CNodePtr &cnode, int static std::vector> FindPrimalJPair(const FuncGraphManagerPtr &manager, const FuncGraphPtr &primal_graph) { std::vector> primal_j_pair; - std::map primal_users_map; + std::map> primal_users_map; const auto &node_user_map = manager->node_users(); // Search primal graph user cnodes. for (auto &entry : primal_graph->func_graph_cnodes_index()) { @@ -834,12 +834,12 @@ static std::vector> FindPrimalJPair(const FuncGrap // To find real calling. auto fg = cnode->func_graph(); MS_EXCEPTION_IF_NULL(fg); - if (primal_users_map.find(fg) != primal_users_map.end()) { - MS_LOG(WARNING) << "It is recommended to call the forward network only once. Func graph: " << fg->ToString() - << ", cnode: " << cnode->DebugString() << ", trace: " << trace::DumpSourceLines(cnode); + auto iter = primal_users_map.find(fg); + if (iter != primal_users_map.end()) { + iter->second.second++; continue; } - primal_users_map[fg] = cnode; + primal_users_map[fg] = std::make_pair(cnode, 1); } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) { // To find J user. auto j_user = GetJUser(node_user_map, cnode, index); @@ -856,13 +856,22 @@ static std::vector> FindPrimalJPair(const FuncGrap << ", J user: " << j_user->DebugString(); continue; } + + auto primal_count_pair = iter->second; // Check input size. - auto primal = iter->second; + auto primal = primal_count_pair.first; if (primal->size() != j_user->size()) { MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal->DebugString() << " is " << primal->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size(); continue; } + if (primal_count_pair.second != 1) { + MS_LOG(WARNING) << "It is recommended to call the forward network only once."; + MS_LOG(INFO) << "There is more than one primal call for J operation in the same graph. Func graph: " + << graph->ToString() << ", primal call: " << primal->DebugString() + << ", J user: " << j_user->DebugString() << ", trace: " << trace::DumpSourceLines(primal); + continue; + } primal_user = primal; MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()