| @@ -28,8 +28,7 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| constexpr auto kSingleInputIndex = 1; | constexpr auto kSingleInputIndex = 1; | ||||
| namespace { | namespace { | ||||
| AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| return nullptr; | return nullptr; | ||||
| @@ -41,15 +40,6 @@ AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node | |||||
| if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { | if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| // Check whether the node has only one output node. | |||||
| if (manager->node_users().find(cnode) == manager->node_users().end()) { | |||||
| MS_LOG(EXCEPTION) << "The node should be used by at least another node's input"; | |||||
| } | |||||
| if (manager->node_users()[cnode].size() > 1) { | |||||
| return nullptr; | |||||
| } | |||||
| CheckCNodeInputSize(cnode, kSingleInputIndex + 1); | CheckCNodeInputSize(cnode, kSingleInputIndex + 1); | ||||
| return cnode->input(kSingleInputIndex); | return cnode->input(kSingleInputIndex); | ||||
| } | } | ||||
| @@ -63,7 +53,7 @@ bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||||
| std::vector<AnfNodePtr> new_make_tuple_inputs; | std::vector<AnfNodePtr> new_make_tuple_inputs; | ||||
| bool need_update = false; | bool need_update = false; | ||||
| for (const auto &input : cnode->inputs()) { | for (const auto &input : cnode->inputs()) { | ||||
| AnfNodePtr replace_input = GetReplaceNode(func_graph, input); | |||||
| AnfNodePtr replace_input = GetReplaceNode(input); | |||||
| // If replace input is not null, it will be the input of the TransData or Cast. | // If replace input is not null, it will be the input of the TransData or Cast. | ||||
| if (replace_input == nullptr) { | if (replace_input == nullptr) { | ||||
| new_make_tuple_inputs.push_back(input); | new_make_tuple_inputs.push_back(input); | ||||
| @@ -119,7 +109,7 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con | |||||
| if (ReplaceMakeTuple(func_graph, replacing_cnode)) { | if (ReplaceMakeTuple(func_graph, replacing_cnode)) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| AnfNodePtr replace_node = GetReplaceNode(func_graph, replacing_cnode); | |||||
| AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); | |||||
| if (replace_node == nullptr) { | if (replace_node == nullptr) { | ||||
| MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); | MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); | ||||
| return nullptr; | return nullptr; | ||||