|
|
|
@@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { |
|
|
|
|
|
|
|
return changed; |
|
|
|
} |
|
|
|
|
|
|
|
// The op like print, summary, or the op do not has true output, and always as a depend node input. |
|
|
|
static bool HasSideEffect(const AnfNodePtr &node) { |
|
|
|
auto prim = GetCNodePrimitive(node); |
|
|
|
if (prim == nullptr) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT); |
|
|
|
if (side_effect_v != nullptr && side_effect_v->isa<BoolImm>()) { |
|
|
|
return GetValue<bool>(side_effect_v); |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
// If true do not merge the node. |
|
|
|
bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const { |
|
|
|
bool has_random_effect = false; |
|
|
|
auto prim_main = GetCNodePrimitive(main); |
|
|
|
auto prim_node = GetCNodePrimitive(node); |
|
|
|
if (prim_main == prim_node) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
// if has random effect, when generate by different op (not same object), do not merge. |
|
|
|
if (prim_main != nullptr) { |
|
|
|
if (prim_main == prim_node) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT); |
|
|
|
if (effect_val != nullptr && effect_val->isa<BoolImm>()) { |
|
|
|
has_random_effect = GetValue<bool>(effect_val); |
|
|
|
@@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons |
|
|
|
return has_random_effect; |
|
|
|
} |
|
|
|
|
|
|
|
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { |
|
|
|
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const { |
|
|
|
MS_EXCEPTION_IF_NULL(main); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
|
|
|
|
bool replace = false; |
|
|
|
if (main->isa<ValueNode>() && node->isa<ValueNode>()) { |
|
|
|
auto main_value = GetValueNode(main); |
|
|
|
auto node_value = GetValueNode(node); |
|
|
|
replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); |
|
|
|
return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); |
|
|
|
} else if (main->isa<CNode>() && node->isa<CNode>()) { |
|
|
|
auto c_main = main->cast<CNodePtr>(); |
|
|
|
auto c_node = node->cast<CNodePtr>(); |
|
|
|
// When appsame is true, check if has side effect, do not merge. |
|
|
|
if (check_side_effect && HasSideEffect(main)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
const auto &inp1 = c_main->inputs(); |
|
|
|
const auto &inp2 = c_node->inputs(); |
|
|
|
if (inp1.size() == inp2.size()) { |
|
|
|
bool appsame = true; |
|
|
|
for (size_t j = 0; j < inp1.size(); j++) { |
|
|
|
MS_EXCEPTION_IF_NULL(inp1[j]); |
|
|
|
MS_EXCEPTION_IF_NULL(inp2[j]); |
|
|
|
if (!(*inp1[j] == *inp2[j])) { |
|
|
|
// Handle the case of two different Tensor, but with the same value |
|
|
|
if (IsValueNode<tensor::Tensor>(inp1[j]) && IsValueNode<tensor::Tensor>(inp2[j])) { |
|
|
|
auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1[j]); |
|
|
|
auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2[j]); |
|
|
|
if (tensor1->ValueEqual(*tensor2)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (inp1.size() != inp2.size()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
for (size_t j = 0; j < inp1.size(); j++) { |
|
|
|
auto inp1_j = inp1[j]; |
|
|
|
auto inp2_j = inp2[j]; |
|
|
|
MS_EXCEPTION_IF_NULL(inp1_j); |
|
|
|
MS_EXCEPTION_IF_NULL(inp2_j); |
|
|
|
if (!(*inp1_j == *inp2_j)) { |
|
|
|
// Handle the case of two different Tensor, but with the same value |
|
|
|
if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) { |
|
|
|
auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j); |
|
|
|
auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j); |
|
|
|
if (tensor1->ValueEqual(*tensor2)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
} else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) { |
|
|
|
// When the same side effect node as another two nodes' inputs, we still merge the node. |
|
|
|
// Because the node only can be the inputs of `depend`, when the `depend` is duplicated merge the depend the |
|
|
|
// node. |
|
|
|
if (CheckReplace(inp1_j, inp2_j, false)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
appsame = false; |
|
|
|
break; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (CheckRandomEffect(c_main, c_node)) { |
|
|
|
appsame = false; |
|
|
|
} |
|
|
|
replace = appsame; |
|
|
|
} |
|
|
|
// When appsame is true, check if has random effect do not merge |
|
|
|
if (CheckRandomEffect(c_main, c_node)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
return replace; |
|
|
|
// a parameter node. |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group, |
|
|
|
|