diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index 86dc958bb8..a1370a4097 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -819,11 +819,58 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) { return true; } +CNodePtr GetJUser(const NodeUsersMap &node_user_map, const CNodePtr &cnode, int index) { + auto it = node_user_map.find(cnode); + if (it == node_user_map.end()) { + MS_LOG(EXCEPTION) << "J CNode not used {" << cnode->DebugString(2) << "/" << index << "}"; + } + auto &j_users = it->second; + auto size = j_users.size(); + if (size != 1) { + MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}"; + } + return j_users.begin()->first->cast(); +} + +CNodePtr GetPrimalUser(const CNodePtr &j_user, const std::map> &primal_map) { + // Check if J operation has relevant primal call in the same graph. + auto graph = j_user->func_graph(); + auto iter = primal_map.find(graph); + if (iter == primal_map.end()) { + MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString() + << ", J user: " << j_user->DebugString(); + return nullptr; + } + + // Check if there is only one primal call corresponding to the specified j user. + auto primal_users = iter->second; + if (primal_users.size() != 1) { + MS_LOG(WARNING) << "It is recommended to call the forward network only once."; + MS_LOG(INFO) << "There is " << primal_users.size() + << " primal calls for same J operation in the same graph. Func graph: " << graph->ToString() + << ", J operation: " << j_user->DebugString() << ", Primal call: "; + size_t count = 0; + for (const auto &user : primal_users) { + MS_LOG(INFO) << "[ " << ++count << " ] : " << user->DebugString(2) << ", trace: " << trace::DumpSourceLines(user); + } + return nullptr; + } + + // Check input size. + auto primal_user = primal_users[0]; + if (primal_user->size() != j_user->size()) { + MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal_user->DebugString() << " is " + << primal_user->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size(); + return nullptr; + } + return primal_user; +} + static std::vector> FindPrimalJPair(const FuncGraphManagerPtr &manager, const FuncGraphPtr &primal_graph) { std::vector> primal_j_pair; - std::map primal_users_map; - auto &node_user_map = manager->node_users(); + std::map> primal_map; + const auto &node_user_map = manager->node_users(); // Search primal graph user cnodes. for (auto &entry : primal_graph->func_graph_cnodes_index()) { auto cnode = entry.first->first->cast(); @@ -832,47 +879,26 @@ static std::vector> FindPrimalJPair(const FuncGrap // To find real calling. auto fg = cnode->func_graph(); MS_EXCEPTION_IF_NULL(fg); - if (primal_users_map.find(fg) != primal_users_map.end()) { - MS_LOG(EXCEPTION) << "The forward network is only allowed to be called once. Func graph: " << fg->ToString() - << ", cnode: " << cnode->DebugString() << ", trace: " << trace::DumpSourceLines(cnode); + auto iter = primal_map.find(fg); + if (iter != primal_map.end()) { + iter->second.push_back(cnode); + continue; } - primal_users_map[fg] = cnode; + primal_map[fg] = {cnode}; } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) { // To find J user. - auto it = node_user_map.find(cnode); - if (it == node_user_map.end()) { - MS_LOG(EXCEPTION) << "J CNode not used {" << cnode->DebugString(2) << "/" << index << "}"; - } - auto &j_users = it->second; - auto size = j_users.size(); - if (size != 1) { - MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}"; - } - auto j_user = j_users.begin()->first->cast(); + auto j_user = GetJUser(node_user_map, cnode, index); primal_j_pair.push_back({nullptr, j_user}); } } for (auto &[primal_user, j_user] : primal_j_pair) { - // Check if J operation has relevant primal call in the same graph - auto graph = j_user->func_graph(); - auto iter = primal_users_map.find(graph); - if (iter == primal_users_map.end()) { - MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString() - << ", J user: " << j_user->DebugString(); - continue; - } - // Check input size. - auto primal = iter->second; - if (primal->size() != j_user->size()) { - MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal->DebugString() << " is " - << primal->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size(); - continue; + auto primal = GetPrimalUser(j_user, primal_map); + if (primal != nullptr) { + MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString() + << " and J user is: " << j_user->DebugString(); + primal_user = primal; } - - primal_user = primal; - MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString() - << " and J user is: " << j_user->DebugString(); } return primal_j_pair; }