|
|
|
@@ -108,6 +108,18 @@ namespace { |
|
|
|
// Isomorphism |
|
|
|
bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, |
|
|
|
NodeMapEquiv *const equiv_node); |
|
|
|
|
|
|
|
bool SameValueNode(const AnfNodePtr &node1, const AnfNodePtr &node2) { |
|
|
|
auto a1 = GetValueNode(node1); |
|
|
|
auto a2 = GetValueNode(node2); |
|
|
|
if (a1->isa<Primitive>() && a2->isa<Primitive>()) { |
|
|
|
return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name(); |
|
|
|
} else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) { |
|
|
|
return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>())); |
|
|
|
} |
|
|
|
return *a1 == *a2; |
|
|
|
} |
|
|
|
|
|
|
|
bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, |
|
|
|
NodeMapEquiv *const equiv_node) { |
|
|
|
if (equiv_node == nullptr) { |
|
|
|
@@ -122,15 +134,7 @@ bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraph |
|
|
|
equiv_node); |
|
|
|
} |
|
|
|
if (node1->isa<ValueNode>() && node2->isa<ValueNode>()) { |
|
|
|
auto a1 = GetValueNode(node1); |
|
|
|
auto a2 = GetValueNode(node2); |
|
|
|
if (a1->isa<Primitive>() && a2->isa<Primitive>()) { |
|
|
|
return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name(); |
|
|
|
} else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) { |
|
|
|
return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>())); |
|
|
|
} else { |
|
|
|
return *a1 == *a2; |
|
|
|
} |
|
|
|
return SameValueNode(node1, node2); |
|
|
|
} |
|
|
|
if (node1->isa<Parameter>() && node2->isa<Parameter>()) { |
|
|
|
auto para1 = node1->cast<ParameterPtr>(); |
|
|
|
|