Browse Source

match depend of eliminate cast

tags/v1.2.0-rc1
chenfei 4 years ago
parent
commit
d28fcbf598
1 changed files with 19 additions and 3 deletions
  1. +19
    -3
      mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc

+ 19
- 3
mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc View File

@@ -25,10 +25,20 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
AnfNodePtr TransThroughDepend(const AnfNodePtr &node) {
auto cur_node = node;
while (IsPrimitiveCNode(cur_node, prim::kPrimDepend)) {
cur_node = cur_node->cast<CNodePtr>()->input(1);
}
return cur_node;
}

bool IsValueNode(const AnfNodePtr &node) { return IsVNode(TransThroughDepend(node)); }

// {prim::kPrimCast, X, T} // {prim::kPrimCast, X, T}
AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
Reset(); Reset();
AnfVisitor::Match(prim::kPrimCast, {IsNode, IsVNode})(node);
AnfVisitor::Match(prim::kPrimCast, {IsNode, IsValueNode})(node);


// check pattern match // check pattern match
if (tgt_ == nullptr) { if (tgt_ == nullptr) {
@@ -50,6 +60,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod
} }


if (src_type->type_id() == tgt_type->type_id()) { if (src_type->type_id() == tgt_type->type_id()) {
if (IsPrimitiveCNode(node->cast<CNodePtr>()->input(2), prim::kPrimDepend)) {
auto new_depend =
node->func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), src_, node->cast<CNodePtr>()->input(2)});
return new_depend;
}
return src_; return src_;
} }


@@ -57,10 +72,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod
} }


void CastSameTypeEliminater::Visit(const AnfNodePtr &node) { void CastSameTypeEliminater::Visit(const AnfNodePtr &node) {
auto cur_node = TransThroughDepend(node);
if (src_ == nullptr) { if (src_ == nullptr) {
src_ = node;
src_ = cur_node;
} else { } else {
tgt_ = node;
tgt_ = cur_node;
} }
} }




Loading…
Cancel
Save