From d28fcbf5989bed7aafda1af0b6c394b0984c0b9c Mon Sep 17 00:00:00 2001 From: chenfei Date: Tue, 23 Feb 2021 16:37:13 +0800 Subject: [PATCH] match depend of eliminate cast --- .../optimizer/irpass/cast_eliminate.cc | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc index 069012f77f..ee20ababab 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc @@ -25,10 +25,20 @@ namespace mindspore { namespace opt { namespace irpass { +AnfNodePtr TransThroughDepend(const AnfNodePtr &node) { + auto cur_node = node; + while (IsPrimitiveCNode(cur_node, prim::kPrimDepend)) { + cur_node = cur_node->cast()->input(1); + } + return cur_node; +} + +bool IsValueNode(const AnfNodePtr &node) { return IsVNode(TransThroughDepend(node)); } + // {prim::kPrimCast, X, T} AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { Reset(); - AnfVisitor::Match(prim::kPrimCast, {IsNode, IsVNode})(node); + AnfVisitor::Match(prim::kPrimCast, {IsNode, IsValueNode})(node); // check pattern match if (tgt_ == nullptr) { @@ -50,6 +60,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod } if (src_type->type_id() == tgt_type->type_id()) { + if (IsPrimitiveCNode(node->cast()->input(2), prim::kPrimDepend)) { + auto new_depend = + node->func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), src_, node->cast()->input(2)}); + return new_depend; + } return src_; } @@ -57,10 +72,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod } void CastSameTypeEliminater::Visit(const AnfNodePtr &node) { + auto cur_node = TransThroughDepend(node); if (src_ == nullptr) { - src_ = node; + src_ = cur_node; } else { - tgt_ = node; + tgt_ = cur_node; } }