|
|
|
@@ -25,10 +25,20 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace irpass { |
|
|
|
AnfNodePtr TransThroughDepend(const AnfNodePtr &node) { |
|
|
|
auto cur_node = node; |
|
|
|
while (IsPrimitiveCNode(cur_node, prim::kPrimDepend)) { |
|
|
|
cur_node = cur_node->cast<CNodePtr>()->input(1); |
|
|
|
} |
|
|
|
return cur_node; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsValueNode(const AnfNodePtr &node) { return IsVNode(TransThroughDepend(node)); } |
|
|
|
|
|
|
|
// {prim::kPrimCast, X, T} |
|
|
|
AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { |
|
|
|
Reset(); |
|
|
|
AnfVisitor::Match(prim::kPrimCast, {IsNode, IsVNode})(node); |
|
|
|
AnfVisitor::Match(prim::kPrimCast, {IsNode, IsValueNode})(node); |
|
|
|
|
|
|
|
// check pattern match |
|
|
|
if (tgt_ == nullptr) { |
|
|
|
@@ -50,6 +60,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod |
|
|
|
} |
|
|
|
|
|
|
|
if (src_type->type_id() == tgt_type->type_id()) { |
|
|
|
if (IsPrimitiveCNode(node->cast<CNodePtr>()->input(2), prim::kPrimDepend)) { |
|
|
|
auto new_depend = |
|
|
|
node->func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), src_, node->cast<CNodePtr>()->input(2)}); |
|
|
|
return new_depend; |
|
|
|
} |
|
|
|
return src_; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -57,10 +72,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod |
|
|
|
} |
|
|
|
|
|
|
|
void CastSameTypeEliminater::Visit(const AnfNodePtr &node) { |
|
|
|
auto cur_node = TransThroughDepend(node); |
|
|
|
if (src_ == nullptr) { |
|
|
|
src_ = node; |
|
|
|
src_ = cur_node; |
|
|
|
} else { |
|
|
|
tgt_ = node; |
|
|
|
tgt_ = cur_node; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|