|
|
@@ -177,7 +177,8 @@ void FusedBatchNormFusion::EliminateMonadNodes(const FuncGraphPtr &func_graph, c |
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
auto assign_sub1 = GetAnfNodeByVar(equiv, assign_sub1_var_); |
|
|
auto assign_sub1 = GetAnfNodeByVar(equiv, assign_sub1_var_); |
|
|
MS_EXCEPTION_IF_NULL(assign_sub1); |
|
|
MS_EXCEPTION_IF_NULL(assign_sub1); |
|
|
for (const auto &node_index : manager->node_users()[assign_sub1]) { |
|
|
|
|
|
|
|
|
auto users = manager->node_users()[assign_sub1]; |
|
|
|
|
|
for (const auto &node_index : users) { |
|
|
const AnfNodePtr &output = node_index.first; |
|
|
const AnfNodePtr &output = node_index.first; |
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) { |
|
|
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) { |
|
|
|