|
|
|
@@ -90,6 +90,22 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { |
|
|
|
return changed; |
|
|
|
} |
|
|
|
|
|
|
|
bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const { |
|
|
|
bool has_random_effect = false; |
|
|
|
auto prim_main = GetCNodePrimitive(main); |
|
|
|
auto prim_node = GetCNodePrimitive(node); |
|
|
|
if (prim_main == prim_node) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (prim_main != nullptr) { |
|
|
|
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT); |
|
|
|
if (effect_val != nullptr && effect_val->isa<BoolImm>()) { |
|
|
|
has_random_effect = GetValue<bool>(effect_val); |
|
|
|
} |
|
|
|
} |
|
|
|
return has_random_effect; |
|
|
|
} |
|
|
|
|
|
|
|
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { |
|
|
|
MS_EXCEPTION_IF_NULL(main); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
@@ -122,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (IsPrimitiveCNode(c_main, prim::kPrimDropoutGenMask)) { |
|
|
|
if (CheckRandomEffect(c_main, c_node)) { |
|
|
|
appsame = false; |
|
|
|
} |
|
|
|
replace = appsame; |
|
|
|
|