|
|
|
@@ -824,7 +824,7 @@ CNodePtr GetJUser(const NodeUsersMap &node_user_map, const CNodePtr &cnode, int |
|
|
|
static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager, |
|
|
|
const FuncGraphPtr &primal_graph) { |
|
|
|
std::vector<std::pair<CNodePtr, CNodePtr>> primal_j_pair; |
|
|
|
std::map<FuncGraphPtr, CNodePtr> primal_users_map; |
|
|
|
std::map<FuncGraphPtr, std::pair<CNodePtr, int>> primal_users_map; |
|
|
|
const auto &node_user_map = manager->node_users(); |
|
|
|
// Search primal graph user cnodes. |
|
|
|
for (auto &entry : primal_graph->func_graph_cnodes_index()) { |
|
|
|
@@ -834,12 +834,12 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap |
|
|
|
// To find real calling. |
|
|
|
auto fg = cnode->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(fg); |
|
|
|
if (primal_users_map.find(fg) != primal_users_map.end()) { |
|
|
|
MS_LOG(WARNING) << "It is recommended to call the forward network only once. Func graph: " << fg->ToString() |
|
|
|
<< ", cnode: " << cnode->DebugString() << ", trace: " << trace::DumpSourceLines(cnode); |
|
|
|
auto iter = primal_users_map.find(fg); |
|
|
|
if (iter != primal_users_map.end()) { |
|
|
|
iter->second.second++; |
|
|
|
continue; |
|
|
|
} |
|
|
|
primal_users_map[fg] = cnode; |
|
|
|
primal_users_map[fg] = std::make_pair(cnode, 1); |
|
|
|
} else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) { |
|
|
|
// To find J user. |
|
|
|
auto j_user = GetJUser(node_user_map, cnode, index); |
|
|
|
@@ -856,13 +856,22 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap |
|
|
|
<< ", J user: " << j_user->DebugString(); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
auto primal_count_pair = iter->second; |
|
|
|
// Check input size. |
|
|
|
auto primal = iter->second; |
|
|
|
auto primal = primal_count_pair.first; |
|
|
|
if (primal->size() != j_user->size()) { |
|
|
|
MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal->DebugString() << " is " |
|
|
|
<< primal->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size(); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (primal_count_pair.second != 1) { |
|
|
|
MS_LOG(WARNING) << "It is recommended to call the forward network only once."; |
|
|
|
MS_LOG(INFO) << "There is more than one primal call for J operation in the same graph. Func graph: " |
|
|
|
<< graph->ToString() << ", primal call: " << primal->DebugString() |
|
|
|
<< ", J user: " << j_user->DebugString() << ", trace: " << trace::DumpSourceLines(primal); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
primal_user = primal; |
|
|
|
MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString() |
|
|
|
|