|
|
|
@@ -916,13 +916,14 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) { |
|
|
|
return new_value_node; |
|
|
|
} |
|
|
|
|
|
|
|
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) { |
|
|
|
void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(old_node); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto manager = graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
// Find BatchNorm's output which is a Depend or UpdateState. |
|
|
|
for (const auto &node_index : manager->node_users()[old_node]) { |
|
|
|
auto node_users = manager->node_users()[old_node]; |
|
|
|
for (const auto &node_index : node_users) { |
|
|
|
AnfNodePtr output = node_index.first; |
|
|
|
size_t index = IntToSize(node_index.second); |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
@@ -930,7 +931,7 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C |
|
|
|
AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) { |
|
|
|
auto depend = output->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(depend); |
|
|
|
depend->set_input(index, new_node); |
|
|
|
manager->SetEdge(depend, index, new_node); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|