Browse Source

!15063 Do not match when there is more than one primal call for J user

From: @huangbingjian
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
tags/v1.2.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f0bd06dbed
1 changed files with 15 additions and 6 deletions
  1. +15
    -6
      mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc

+ 15
- 6
mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc View File

@@ -824,7 +824,7 @@ CNodePtr GetJUser(const NodeUsersMap &node_user_map, const CNodePtr &cnode, int
static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
const FuncGraphPtr &primal_graph) {
std::vector<std::pair<CNodePtr, CNodePtr>> primal_j_pair;
std::map<FuncGraphPtr, CNodePtr> primal_users_map;
std::map<FuncGraphPtr, std::pair<CNodePtr, int>> primal_users_map;
const auto &node_user_map = manager->node_users();
// Search primal graph user cnodes.
for (auto &entry : primal_graph->func_graph_cnodes_index()) {
@@ -834,12 +834,12 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap
// To find real calling.
auto fg = cnode->func_graph();
MS_EXCEPTION_IF_NULL(fg);
if (primal_users_map.find(fg) != primal_users_map.end()) {
MS_LOG(WARNING) << "It is recommended to call the forward network only once. Func graph: " << fg->ToString()
<< ", cnode: " << cnode->DebugString() << ", trace: " << trace::DumpSourceLines(cnode);
auto iter = primal_users_map.find(fg);
if (iter != primal_users_map.end()) {
iter->second.second++;
continue;
}
primal_users_map[fg] = cnode;
primal_users_map[fg] = std::make_pair(cnode, 1);
} else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
// To find J user.
auto j_user = GetJUser(node_user_map, cnode, index);
@@ -856,13 +856,22 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap
<< ", J user: " << j_user->DebugString();
continue;
}

auto primal_count_pair = iter->second;
// Check input size.
auto primal = iter->second;
auto primal = primal_count_pair.first;
if (primal->size() != j_user->size()) {
MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal->DebugString() << " is "
<< primal->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size();
continue;
}
if (primal_count_pair.second != 1) {
MS_LOG(WARNING) << "It is recommended to call the forward network only once.";
MS_LOG(INFO) << "There is more than one primal call for J operation in the same graph. Func graph: "
<< graph->ToString() << ", primal call: " << primal->DebugString()
<< ", J user: " << j_user->DebugString() << ", trace: " << trace::DumpSourceLines(primal);
continue;
}

primal_user = primal;
MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()


Loading…
Cancel
Save