|
|
|
@@ -21,109 +21,70 @@ |
|
|
|
|
|
|
|
#include "optimizer/optimizer.h" |
|
|
|
#include "optimizer/irpass.h" |
|
|
|
#include "ir/visitor.h" |
|
|
|
#include "operator/ops.h" |
|
|
|
#include "utils/graph_utils.h" |
|
|
|
#include "operator/composite/composite.h" |
|
|
|
#include "ir/pattern_matcher.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace irpass { |
|
|
|
// {prim::kPrimMakeRef, X, Y, Z} -> Y |
|
|
|
class MakeRefEliminater : public AnfVisitor { |
|
|
|
class MakeRefEliminater : public OptimizerCaller { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
y_ = nullptr; |
|
|
|
auto gety = [this](const AnfNodePtr &node) -> bool { |
|
|
|
this->y_ = node; |
|
|
|
return true; |
|
|
|
}; |
|
|
|
|
|
|
|
AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node); |
|
|
|
return y_; |
|
|
|
PatternNode<AnfNodePtr> x, y, z; |
|
|
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &) override {} |
|
|
|
|
|
|
|
private: |
|
|
|
AnfNodePtr y_{nullptr}; |
|
|
|
}; |
|
|
|
|
|
|
|
// {prim::kPrimGetRefValue, Parameter} -> Parameter |
|
|
|
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter |
|
|
|
class GetRefParamEliminater : public AnfVisitor { |
|
|
|
class GetRefParamEliminater : public OptimizerCaller { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
x_ = nullptr; |
|
|
|
AnfVisitor::Match(prim::kPrimGetRefOrigin, {IsParam})(node); |
|
|
|
if (x_ != nullptr) { |
|
|
|
return x_; |
|
|
|
} |
|
|
|
AnfVisitor::Match(prim::kPrimGetRefValue, {IsParam})(node); |
|
|
|
return x_; |
|
|
|
PatternNode<AnfNodePtr> x; |
|
|
|
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node)); |
|
|
|
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node)); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override { x_ = node; } |
|
|
|
|
|
|
|
private: |
|
|
|
AnfNodePtr x_{nullptr}; |
|
|
|
}; |
|
|
|
|
|
|
|
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X |
|
|
|
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y |
|
|
|
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z |
|
|
|
class GetMakeRefEliminater : public AnfVisitor { |
|
|
|
class GetMakeRefEliminater : public OptimizerCaller { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if (cnode == nullptr || cnode->size() != 2) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// {prim::kPrimGetRefKey/Value, {...}} |
|
|
|
auto ref = cnode->input(1)->cast<CNodePtr>(); |
|
|
|
if (ref == nullptr || !ref->IsApply(prim::kPrimMakeRef) || ref->size() != 4) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// {prim::kPrimMakeRef, X, Y, Z} |
|
|
|
if (cnode->IsApply(prim::kPrimGetRefKey)) { |
|
|
|
return ref->input(1); |
|
|
|
} |
|
|
|
|
|
|
|
if (cnode->IsApply(prim::kPrimGetRefValue)) { |
|
|
|
return ref->input(2); |
|
|
|
} |
|
|
|
|
|
|
|
if (cnode->IsApply(prim::kPrimGetRefOrigin)) { |
|
|
|
return ref->input(3); |
|
|
|
} |
|
|
|
PatternNode<AnfNodePtr> x, y, z; |
|
|
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); |
|
|
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); |
|
|
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z); |
|
|
|
|
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
// IsValueNode<RefKey> |
|
|
|
class ReplaceRefkeyByParam : public AnfVisitor { |
|
|
|
class ReplaceRefkeyByParam : public OptimizerCaller { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { |
|
|
|
if (!IsValueNode<RefKey>(node)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto refkey = GetValueNode<RefKeyPtr>(node); |
|
|
|
auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource()); |
|
|
|
MS_EXCEPTION_IF_NULL(resource); |
|
|
|
|
|
|
|
auto top_graph = resource->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(top_graph); |
|
|
|
|
|
|
|
for (const auto &tnode : top_graph->parameters()) { |
|
|
|
auto para = tnode->cast<ParameterPtr>(); |
|
|
|
if (para != nullptr && para->name() == refkey->tag()) { |
|
|
|
return para; |
|
|
|
auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr { |
|
|
|
auto refkey = GetValueNode<RefKeyPtr>(node); |
|
|
|
auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource()); |
|
|
|
MS_EXCEPTION_IF_NULL(resource); |
|
|
|
|
|
|
|
auto top_graph = resource->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(top_graph); |
|
|
|
|
|
|
|
for (const auto &tnode : top_graph->parameters()) { |
|
|
|
auto para = tnode->cast<ParameterPtr>(); |
|
|
|
if (para != nullptr && para->name() == refkey->tag()) { |
|
|
|
return para; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
}; |
|
|
|
PatternNode<AnfNodePtr> x; |
|
|
|
MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode<RefKey>, node)); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
}; |
|
|
|
|