|
|
|
@@ -32,19 +32,23 @@ using mindspore::abstract::AbstractBase; |
|
|
|
using mindspore::abstract::AbstractFunction; |
|
|
|
using mindspore::abstract::AbstractFunctionPtr; |
|
|
|
|
|
|
|
BasePtr AbsOf(const AnfNodePtr &node) { |
|
|
|
BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto node_abs = node->abstract(); |
|
|
|
// in testcase: TestOptOpt.CSE, node->abstract() is null; |
|
|
|
// In testcase: TestOptOpt.CSE, node->abstract() is null. |
|
|
|
if (node_abs == nullptr) { |
|
|
|
return kAnyValue; |
|
|
|
} |
|
|
|
// Ignore the tracking_id and prim pointer hash; |
|
|
|
if (node_abs->isa<abstract::PrimitiveAbstractClosure>()) { |
|
|
|
// Ignore the tracking_id and prim pointer hash. |
|
|
|
auto prim_abs = node_abs->cast<abstract::PrimitiveAbstractClosurePtr>(); |
|
|
|
return prim_abs->prim(); |
|
|
|
} else if (ignore_fg_abs_tracking_id && node_abs->isa<abstract::FuncGraphAbstractClosure>()) { |
|
|
|
// Ignore the tracking_id. |
|
|
|
auto new_fg_abs = node_abs->cast<abstract::AbstractFunctionPtr>()->Copy(); |
|
|
|
new_fg_abs->set_tracking_id(nullptr); |
|
|
|
return new_fg_abs; |
|
|
|
} |
|
|
|
|
|
|
|
return node_abs; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -68,7 +72,7 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { |
|
|
|
ValueNodePtr value_node = node->cast<ValueNodePtr>(); |
|
|
|
auto value = value_node->value(); |
|
|
|
MS_EXCEPTION_IF_NULL(value); |
|
|
|
h = hash_combine(value->hash(), (AbsOf(value_node)->hash())); |
|
|
|
h = hash_combine(value->hash(), (AbsOf(value_node, true)->hash())); |
|
|
|
} else if (node->isa<CNode>()) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
@@ -134,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool chec |
|
|
|
if (main->isa<ValueNode>() && node->isa<ValueNode>()) { |
|
|
|
auto main_value = GetValueNode(main); |
|
|
|
auto node_value = GetValueNode(node); |
|
|
|
return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); |
|
|
|
return (AbsOf(main, true) == AbsOf(node, true)) && (*main_value == *node_value); |
|
|
|
} else if (main->isa<CNode>() && node->isa<CNode>()) { |
|
|
|
auto c_main = main->cast<CNodePtr>(); |
|
|
|
auto c_node = node->cast<CNodePtr>(); |
|
|
|
|