| @@ -25,10 +25,20 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| AnfNodePtr TransThroughDepend(const AnfNodePtr &node) { | |||||
| auto cur_node = node; | |||||
| while (IsPrimitiveCNode(cur_node, prim::kPrimDepend)) { | |||||
| cur_node = cur_node->cast<CNodePtr>()->input(1); | |||||
| } | |||||
| return cur_node; | |||||
| } | |||||
| bool IsValueNode(const AnfNodePtr &node) { return IsVNode(TransThroughDepend(node)); } | |||||
| // {prim::kPrimCast, X, T} | // {prim::kPrimCast, X, T} | ||||
| AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | ||||
| Reset(); | Reset(); | ||||
| AnfVisitor::Match(prim::kPrimCast, {IsNode, IsVNode})(node); | |||||
| AnfVisitor::Match(prim::kPrimCast, {IsNode, IsValueNode})(node); | |||||
| // check pattern match | // check pattern match | ||||
| if (tgt_ == nullptr) { | if (tgt_ == nullptr) { | ||||
| @@ -50,6 +60,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod | |||||
| } | } | ||||
| if (src_type->type_id() == tgt_type->type_id()) { | if (src_type->type_id() == tgt_type->type_id()) { | ||||
| if (IsPrimitiveCNode(node->cast<CNodePtr>()->input(2), prim::kPrimDepend)) { | |||||
| auto new_depend = | |||||
| node->func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), src_, node->cast<CNodePtr>()->input(2)}); | |||||
| return new_depend; | |||||
| } | |||||
| return src_; | return src_; | ||||
| } | } | ||||
| @@ -57,10 +72,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod | |||||
| } | } | ||||
| void CastSameTypeEliminater::Visit(const AnfNodePtr &node) { | void CastSameTypeEliminater::Visit(const AnfNodePtr &node) { | ||||
| auto cur_node = TransThroughDepend(node); | |||||
| if (src_ == nullptr) { | if (src_ == nullptr) { | ||||
| src_ = node; | |||||
| src_ = cur_node; | |||||
| } else { | } else { | ||||
| tgt_ = node; | |||||
| tgt_ = cur_node; | |||||
| } | } | ||||
| } | } | ||||