|
|
|
@@ -819,11 +819,58 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr GetJUser(const NodeUsersMap &node_user_map, const CNodePtr &cnode, int index) { |
|
|
|
auto it = node_user_map.find(cnode); |
|
|
|
if (it == node_user_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "J CNode not used {" << cnode->DebugString(2) << "/" << index << "}"; |
|
|
|
} |
|
|
|
auto &j_users = it->second; |
|
|
|
auto size = j_users.size(); |
|
|
|
if (size != 1) { |
|
|
|
MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}"; |
|
|
|
} |
|
|
|
return j_users.begin()->first->cast<CNodePtr>(); |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr GetPrimalUser(const CNodePtr &j_user, const std::map<FuncGraphPtr, std::vector<CNodePtr>> &primal_map) { |
|
|
|
// Check if J operation has relevant primal call in the same graph. |
|
|
|
auto graph = j_user->func_graph(); |
|
|
|
auto iter = primal_map.find(graph); |
|
|
|
if (iter == primal_map.end()) { |
|
|
|
MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString() |
|
|
|
<< ", J user: " << j_user->DebugString(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// Check if there is only one primal call corresponding to the specified j user. |
|
|
|
auto primal_users = iter->second; |
|
|
|
if (primal_users.size() != 1) { |
|
|
|
MS_LOG(WARNING) << "It is recommended to call the forward network only once."; |
|
|
|
MS_LOG(INFO) << "There is " << primal_users.size() |
|
|
|
<< " primal calls for same J operation in the same graph. Func graph: " << graph->ToString() |
|
|
|
<< ", J operation: " << j_user->DebugString() << ", Primal call: "; |
|
|
|
size_t count = 0; |
|
|
|
for (const auto &user : primal_users) { |
|
|
|
MS_LOG(INFO) << "[ " << ++count << " ] : " << user->DebugString(2) << ", trace: " << trace::DumpSourceLines(user); |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// Check input size. |
|
|
|
auto primal_user = primal_users[0]; |
|
|
|
if (primal_user->size() != j_user->size()) { |
|
|
|
MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal_user->DebugString() << " is " |
|
|
|
<< primal_user->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return primal_user; |
|
|
|
} |
|
|
|
|
|
|
|
static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager, |
|
|
|
const FuncGraphPtr &primal_graph) { |
|
|
|
std::vector<std::pair<CNodePtr, CNodePtr>> primal_j_pair; |
|
|
|
std::map<FuncGraphPtr, CNodePtr> primal_users_map; |
|
|
|
auto &node_user_map = manager->node_users(); |
|
|
|
std::map<FuncGraphPtr, std::vector<CNodePtr>> primal_map; |
|
|
|
const auto &node_user_map = manager->node_users(); |
|
|
|
// Search primal graph user cnodes. |
|
|
|
for (auto &entry : primal_graph->func_graph_cnodes_index()) { |
|
|
|
auto cnode = entry.first->first->cast<CNodePtr>(); |
|
|
|
@@ -832,47 +879,26 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap |
|
|
|
// To find real calling. |
|
|
|
auto fg = cnode->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(fg); |
|
|
|
if (primal_users_map.find(fg) != primal_users_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "The forward network is only allowed to be called once. Func graph: " << fg->ToString() |
|
|
|
<< ", cnode: " << cnode->DebugString() << ", trace: " << trace::DumpSourceLines(cnode); |
|
|
|
auto iter = primal_map.find(fg); |
|
|
|
if (iter != primal_map.end()) { |
|
|
|
iter->second.push_back(cnode); |
|
|
|
continue; |
|
|
|
} |
|
|
|
primal_users_map[fg] = cnode; |
|
|
|
primal_map[fg] = {cnode}; |
|
|
|
} else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) { |
|
|
|
// To find J user. |
|
|
|
auto it = node_user_map.find(cnode); |
|
|
|
if (it == node_user_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "J CNode not used {" << cnode->DebugString(2) << "/" << index << "}"; |
|
|
|
} |
|
|
|
auto &j_users = it->second; |
|
|
|
auto size = j_users.size(); |
|
|
|
if (size != 1) { |
|
|
|
MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}"; |
|
|
|
} |
|
|
|
auto j_user = j_users.begin()->first->cast<CNodePtr>(); |
|
|
|
auto j_user = GetJUser(node_user_map, cnode, index); |
|
|
|
primal_j_pair.push_back({nullptr, j_user}); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &[primal_user, j_user] : primal_j_pair) { |
|
|
|
// Check if J operation has relevant primal call in the same graph |
|
|
|
auto graph = j_user->func_graph(); |
|
|
|
auto iter = primal_users_map.find(graph); |
|
|
|
if (iter == primal_users_map.end()) { |
|
|
|
MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString() |
|
|
|
<< ", J user: " << j_user->DebugString(); |
|
|
|
continue; |
|
|
|
} |
|
|
|
// Check input size. |
|
|
|
auto primal = iter->second; |
|
|
|
if (primal->size() != j_user->size()) { |
|
|
|
MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal->DebugString() << " is " |
|
|
|
<< primal->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size(); |
|
|
|
continue; |
|
|
|
auto primal = GetPrimalUser(j_user, primal_map); |
|
|
|
if (primal != nullptr) { |
|
|
|
MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString() |
|
|
|
<< " and J user is: " << j_user->DebugString(); |
|
|
|
primal_user = primal; |
|
|
|
} |
|
|
|
|
|
|
|
primal_user = primal; |
|
|
|
MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString() |
|
|
|
<< " and J user is: " << j_user->DebugString(); |
|
|
|
} |
|
|
|
return primal_j_pair; |
|
|
|
} |
|
|
|
|