|
|
@@ -129,28 +129,7 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod |
|
|
if (!new_node->isa<CNode>()) { |
|
|
if (!new_node->isa<CNode>()) { |
|
|
MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << "."; |
|
|
MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << "."; |
|
|
} |
|
|
} |
|
|
auto c_node = node->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_node); |
|
|
|
|
|
auto inputs = c_node->inputs(); |
|
|
|
|
|
std::vector<AnfNodePtr> new_inputs; |
|
|
|
|
|
(void)std::transform( |
|
|
|
|
|
inputs.begin(), inputs.end(), std::back_inserter(new_inputs), [this](const AnfNodePtr &inp) -> AnfNodePtr { |
|
|
|
|
|
auto new_inp = ReplicateDisconnectedNode(inp); |
|
|
|
|
|
// Refer the comments in BuildReplacedNode. |
|
|
|
|
|
if (inp->isa<CNode>()) { |
|
|
|
|
|
auto c_inp = inp->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_inp); |
|
|
|
|
|
auto c_new_inp = new_inp->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_new_inp); |
|
|
|
|
|
MS_LOG(DEBUG) << "Replace in order, inp node: " << inp->DebugString() << " -> " << new_inp->DebugString(); |
|
|
|
|
|
c_new_inp->func_graph()->ReplaceInOrder(c_inp, c_new_inp); |
|
|
|
|
|
} |
|
|
|
|
|
return new_inp; |
|
|
|
|
|
}); |
|
|
|
|
|
|
|
|
|
|
|
auto c_new_node = new_node->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_new_node); |
|
|
|
|
|
c_new_node->set_inputs(new_inputs); |
|
|
|
|
|
|
|
|
UpdateNewCNodeInputs(node, new_node); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
iter = specializer->repl_node_->find(node); |
|
|
iter = specializer->repl_node_->find(node); |
|
|
@@ -164,6 +143,31 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod |
|
|
return new_node; |
|
|
return new_node; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node) { |
|
|
|
|
|
auto c_node = node->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_node); |
|
|
|
|
|
auto inputs = c_node->inputs(); |
|
|
|
|
|
std::vector<AnfNodePtr> new_inputs; |
|
|
|
|
|
(void)std::transform( |
|
|
|
|
|
inputs.begin(), inputs.end(), std::back_inserter(new_inputs), [this](const AnfNodePtr &inp) -> AnfNodePtr { |
|
|
|
|
|
auto new_inp = ReplicateDisconnectedNode(inp); |
|
|
|
|
|
// Refer the comments in BuildReplacedNode. |
|
|
|
|
|
if (inp->isa<CNode>()) { |
|
|
|
|
|
auto c_inp = inp->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_inp); |
|
|
|
|
|
auto c_new_inp = new_inp->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_new_inp); |
|
|
|
|
|
MS_LOG(DEBUG) << "Replace in order, inp node: " << inp->DebugString() << " -> " << new_inp->DebugString(); |
|
|
|
|
|
c_new_inp->func_graph()->ReplaceInOrder(c_inp, c_new_inp); |
|
|
|
|
|
} |
|
|
|
|
|
return new_inp; |
|
|
|
|
|
}); |
|
|
|
|
|
|
|
|
|
|
|
auto c_new_node = new_node->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_new_node); |
|
|
|
|
|
c_new_node->set_inputs(new_inputs); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) { |
|
|
AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) { |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
FuncGraphPtr fg = node->func_graph(); |
|
|
FuncGraphPtr fg = node->func_graph(); |
|
|
|