Browse Source

!1833 Pattern Matcher class for optimizations

Merge pull request !1833 from Giancarlo/pattern_matcher
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
067616d0a5
2 changed files with 377 additions and 166 deletions
  1. +306
    -0
      mindspore/ccsrc/ir/pattern_matcher.h
  2. +71
    -166
      mindspore/ccsrc/optimizer/irpass/branch_culling.h

+ 306
- 0
mindspore/ccsrc/ir/pattern_matcher.h View File

@@ -0,0 +1,306 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
#define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_

#include <tuple>
#include <vector>

#include "ir/anf.h"
#include "operator/ops.h"

namespace mindspore {

///
/// Base class for all recognizable patterns.
/// We implement an Expression Template approach using static polymorphism based on
/// the Curiously Recurring Template Pattern (CRTP) which "achieves a similar effect
/// to the use of virtual functions without the costs..." as described in:
/// https://en.wikipedia.org/wiki/Expression_templates and
/// https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
/// The TryCapture function tries to capture the pattern with the given node.
/// The GetNode function builds a new node using the captured values.
///

template <typename T>
class PBase {
public:
const T &get_object() const { return *static_cast<const T *>(this); }

template <typename TN>
bool TryCapture(const TN &value) const {
get_object().Reset();
return get_object().TryCapture_(value);
}

using Internal = T;
};

template <typename T>
class PIsEqual {
public:
bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; }
};

template <typename T>
class PatternNode : public PBase<PatternNode<T> > {
public:
T GetNode(const AnfNodePtr &node) const {
if (!captured_) {
MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode.";
}
return captured_node_;
}

bool TryCapture_(const T &node) const {
if (!captured_) {
captured_node_ = node;
captured_ = true;
return true;
}
return PIsEqual<T>()(captured_node_, node);
}

void Reset() const { captured_ = false; }
using Internal = const PatternNode<T> &;

protected:
mutable T captured_node_;
mutable bool captured_{false};
};

template <typename T, typename T2>
class PBinOperation : public PBase<PBinOperation<T, T2> > {
public:
PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y) : prim_(prim), x_(x), y_(y) {}

AnfNodePtr GetNode(const AnfNodePtr &node) const {
AnfNodePtr lhs = x_.GetNode(node->func_graph());
AnfNodePtr rhs = y_.GetNode(node->func_graph());
AnfNodePtrList list = {prim_->cast<AnfNodePtr>(), lhs, rhs};
return NewCNode(list, node->func_graph());
}

bool TryCapture_(const AnfNodePtr &node) const {
if (IsPrimitiveCNode(node, prim_)) {
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
if (inputs.size() == 3) {
// Binary Prim assumes only two inputs
if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) {
return false;
}
return true;
}
}
return false;
}

void Reset() const {
x_.Reset();
y_.Reset();
}

private:
const PrimitivePtr prim_;
typename T::Internal x_;
typename T2::Internal y_;
};

///
/// Helper functions to apply a pattern function on all elements of a tuple
///
namespace tuple_utils {
template <bool stop, size_t Index, typename Func>
struct apply_func_tuple_item {
template <typename TTuple>
static void apply(Func *func, const TTuple &tuple) {
(*func)(Index, std::get<Index>(tuple));
apply_func_tuple_item<(Index + 1) == std::tuple_size<TTuple>::value, (Index + 1), Func>::apply(func, tuple);
}
};

template <size_t Index, typename Func>
struct apply_func_tuple_item<true, Index, Func> {
template <typename TTuple>
static void apply(Func *func, const TTuple &tuple) {}
};

template <typename Func, typename TTuple>
inline void apply_func_tuple(Func *func, const TTuple &tuple) {
apply_func_tuple_item<std::tuple_size<TTuple>::value == 0, 0, Func>::apply(func, tuple);
}

struct PTupleResetCapture {
template <typename T>
void operator()(size_t i, const T &pattern) const {
pattern.Reset();
}
};

struct PTupleCapture {
explicit PTupleCapture(const AnfNodePtrList tuple) : tuple_(tuple) {}

template <typename TPattern>
void operator()(size_t i, const TPattern &pattern) {
// Check if the first node is a Primitive
if (i == 0 && tuple_[i]->isa<Primitive>()) {
auto prim = tuple_[i]->cast<PrimitivePtr>();
if (tuple_[i] != pattern.GetNode(tuple_[i])) {
captured_ = false;
}
} else {
captured_ = captured_ && pattern.TryCapture_(tuple_[i]);
}
}

const AnfNodePtrList tuple_;
bool captured_{true};
};

struct PTupleGetNode {
explicit PTupleGetNode(const AnfNodePtr &node) : node_(node) {}

template <typename TPattern>
void operator()(size_t, const TPattern &pattern) {
args_.push_back(pattern.GetNode(node_));
}

const AnfNodePtr &node_;
std::vector<AnfNodePtr> args_;
};
} // namespace tuple_utils

template <typename... TArgs>
class PCNode : public PBase<PCNode<TArgs...> > {
public:
explicit PCNode(const TArgs &... args) : args_(args...) {}

AnfNodePtr GetNode(const AnfNodePtr &node) const {
tuple_utils::PTupleGetNode get_node(node);
tuple_utils::apply_func_tuple(&get_node, args_);
return NewCNode(get_node.args_, node->func_graph());
}

bool TryCapture_(const AnfNodePtr &node) const {
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
if (inputs.size() != sizeof...(TArgs)) {
return false;
}
tuple_utils::PTupleCapture capture_func(inputs);
tuple_utils::apply_func_tuple(&capture_func, args_);
return capture_func.captured_;
}

return false;
}

void Reset() const {
tuple_utils::PTupleResetCapture reset;
tuple_utils::apply_func_tuple(&reset, args_);
}

private:
std::tuple<typename TArgs::Internal...> args_;
};

template <typename... TArgs>
class PPrimitive : public PBase<PPrimitive<TArgs...> > {
public:
explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {}

AnfNodePtr GetNode(const AnfNodePtr &node) const {
tuple_utils::PTupleGetNode get_node(node);
tuple_utils::apply_func_tuple(&get_node, args_);
auto prim_cnode = get_node.args_;
prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_));
return NewCNode(prim_cnode, node->func_graph());
}

bool TryCapture_(const AnfNodePtr &node) const {
if (IsPrimitiveCNode(node, prim_)) {
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
if ((inputs.size() - 1) != sizeof...(TArgs)) {
return false;
}

AnfNodePtrList rest(inputs.begin() + 1, inputs.end());
tuple_utils::PTupleCapture capture_func(rest);
tuple_utils::apply_func_tuple(&capture_func, args_);

return capture_func.captured_;
}

return false;
}

void Reset() const {
tuple_utils::PTupleResetCapture reset;
tuple_utils::apply_func_tuple(&reset, args_);
}

private:
const PrimitivePtr prim_;
std::tuple<typename TArgs::Internal...> args_;
};

// Macro for binary operation functions
#define BIN_OPERATION_PATTERN(Operator, MSPrimitive) \
template <typename T, typename T2> \
inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) { \
return PBinOperation(MSPrimitive, x.get_object(), y.get_object()); \
}

// Arithmetic operations
BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd);
BIN_OPERATION_PATTERN(operator*, prim::kPrimMul);

// Macros for match and replace
#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \
if ((CaptureNode).TryCapture(OrigNode)) { \
return (ReplaceWith).GetNode(OrigNode); \
}

#define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \
if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \
return (ReplaceWith).GetNode(OrigNode); \
}

#define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \
if ((CaptureNode).TryCapture(OrigNode)) { \
if ((Condition)) { \
return (ReplaceWith).GetNode(OrigNode); \
} \
return (ElseNode).GetNode(OrigNode); \
}

#define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \
if ((CaptureNode).TryCapture(OrigNode)) { \
return (Lambda)(); \
}

#define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \
if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \
return (Lambda)(); \
}

} // namespace mindspore

#endif // #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_

+ 71
- 166
mindspore/ccsrc/optimizer/irpass/branch_culling.h View File

@@ -26,141 +26,61 @@
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "ir/pattern_matcher.h"


namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
// {prim::kPrimSwitch, true, X, Y} // {prim::kPrimSwitch, true, X, Y}
// {prim::kPrimSwitch, false, X, Y} // {prim::kPrimSwitch, false, X, Y}
class SwitchSimplify : public AnfVisitor {
class SwitchSimplify {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
auto getx = [this](const AnfNodePtr &node) -> bool {
this->x_ = node;
return true;
};
auto gety = [this](const AnfNodePtr &node) -> bool {
this->y_ = node;
return true;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
PatternNode<AnfNodePtr> cond, true_br, false_br;
auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
auto cond_value_ = GetValue<bool>(GetValueNode(cond.GetNode(node)));
if (cond_value_) {
return true_br.GetNode(node);
}
return false_br.GetNode(node);
}; };
AnfVisitor::Match(prim::kPrimSwitch, {IsValueNode<BoolImm>, getx, gety})(node);


// simplify the switch
if (is_match_) {
if (cond_) {
return x_;
}
return y_;
}
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda,
IsValueNode<BoolImm>(cond.GetNode(node)));


return nullptr; return nullptr;
} }

void Visit(const AnfNodePtr &node) override {
if (!is_match_ && IsValueNode<BoolImm>(node)) {
cond_ = GetValue<bool>(GetValueNode(node));
is_match_ = true;
}
}

void Reset() {
x_ = nullptr;
y_ = nullptr;
cond_ = false;
is_match_ = false;
}

private:
bool is_match_{false}, cond_{false};
AnfNodePtr x_{nullptr}, y_{nullptr};
}; };


// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} => // {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} =>
// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} // {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
class FloatTupleGetItemSwitch : public AnfVisitor {
class FloatTupleGetItemSwitch {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node);

auto fg = node->func_graph();
if (Xs_.empty() || c_ == nullptr || fg == nullptr) {
return nullptr;
}

auto true_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), Xs_[1], c_});
auto false_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), Xs_[2], c_});

return fg->NewCNode({NewValueNode(prim::kPrimSwitch), Xs_[0], true_node, false_node});
}

void Visit(const CNodePtr &cnode) override {
// {prim::kPrimSwith, X1, X2, X3}
if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch) || cnode->size() != 4) {
return;
}

// copy X1, X2, X3
auto &inputs = cnode->inputs();
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
}

void Visit(const ValueNodePtr &vnode) override { c_ = vnode; }

void Reset() {
Xs_.clear();
c_ = nullptr;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
PatternNode<AnfNodePtr> cond, true_br, false_br, x;
MATCH_REPLACE_IF(node,
PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x),
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x),
PPrimitive(prim::kPrimTupleGetItem, false_br, x)),
IsVNode(x.GetNode(node)));
return nullptr;
} }

private:
AnfNodePtr c_{nullptr};
std::vector<AnfNodePtr> Xs_{};
}; };


// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} => // {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}} // {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
class FloatEnvGetItemSwitch : public AnfVisitor {
class FloatEnvGetItemSwitch {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
is_match_ = false;
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsNode, IsNode})(node);
if (!is_match_) {
return nullptr;
}

// {prim::kPrimEnvGetItem, {...}, X4, X5}
auto cnode = node->cast<CNodePtr>();
auto sw_node = cnode->input(1)->cast<CNodePtr>();
auto x4 = cnode->input(2);
auto x5 = cnode->input(3);
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2;
MATCH_REPLACE_IF(node,
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2),
PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2)),
IsNode(x.GetNode(node)) && IsNode(x2.GetNode(node)));


is_match_ = false;
AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsNode, IsNode})(sw_node);
if (!is_match_) {
return nullptr;
}

// {prim::kPrimSwitch, X1, X2, X3}
auto x1 = sw_node->input(1);
auto x2 = sw_node->input(2);
auto x3 = sw_node->input(3);

auto fg = node->func_graph();
if (fg == nullptr) {
return nullptr;
}

auto true_node = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x2, x4, x5});
auto false_node = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x3, x4, x5});

return fg->NewCNode({NewValueNode(prim::kPrimSwitch), x1, true_node, false_node});
return nullptr;
} }

void Visit(const AnfNodePtr &) override { is_match_ = true; }

private:
bool is_match_{false};
}; };


namespace internal { namespace internal {
@@ -173,79 +93,64 @@ AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfN
} // namespace internal } // namespace internal


// {{prim::kPrimSwitch, X, G1, G2}, Xs} // {{prim::kPrimSwitch, X, G1, G2}, Xs}
class ConvertSwitchReplacement : public AnfVisitor {
class ConvertSwitchReplacement {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
if (!node->isa<CNode>() || node->func_graph() == nullptr) { if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr; return nullptr;
} }


Reset();
auto cnode = node->cast<CNodePtr>();
if (cnode->size() < 1) {
auto cnode_ = node->cast<CNodePtr>();
if (cnode_->size() < 1) {
return nullptr; return nullptr;
} }


// {prim::kPrimSwitch, X, G1, G2}
AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode<FuncGraph>, IsValueNode<FuncGraph>})(cnode->input(0));
if (g2_ == nullptr || g1_->output() == nullptr || g2_->output() == nullptr) {
return nullptr;
}
// for switch replace method, only graphs without graph inside can be replaced
for (auto &item : g1_->value_nodes()) {
auto value_node = item.first;
if (IsValueNode<FuncGraph>(value_node)) {
return nullptr;
auto node_ = cnode_->input(0);

PatternNode<AnfNodePtr> cond, true_br, false_br;

auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr {
auto g1_ = GetValueNode<FuncGraphPtr>(true_br.GetNode(node_));
auto g2_ = GetValueNode<FuncGraphPtr>(false_br.GetNode(node_));
auto x_ = cond.GetNode(node_);

// for switch replace method, only graphs without graph inside can be replaced
for (auto &item : g1_->value_nodes()) {
auto value_node = item.first;
if (IsValueNode<FuncGraph>(value_node)) {
return nullptr;
}
} }
}


for (auto &item : g2_->value_nodes()) {
auto value_node = item.first;
if (IsValueNode<FuncGraph>(value_node)) {
return nullptr;
for (auto &item : g2_->value_nodes()) {
auto value_node = item.first;
if (IsValueNode<FuncGraph>(value_node)) {
return nullptr;
}
} }
}


auto true_output = g1_->output()->abstract();
auto false_output = g2_->output()->abstract();
auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_);
auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_);

std::vector<AnfNodePtr> params;
auto fg = node->func_graph();
auto cloned_g1 = InlineClone(trans_g1, fg, params);
auto cloned_g2 = InlineClone(trans_g2, fg, params);
auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg);
return nnode;
}
auto true_output = g1_->output()->abstract();
auto false_output = g2_->output()->abstract();
auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_);
auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_);


void Visit(const AnfNodePtr &node) override {
if (x_ == nullptr) {
x_ = node;
return;
}
AnfVisitor::Visit(node);
}
std::vector<AnfNodePtr> params;
auto fg = node_->func_graph();
auto cloned_g1 = InlineClone(trans_g1, fg, params);
auto cloned_g2 = InlineClone(trans_g2, fg, params);
auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg);


void Visit(const ValueNodePtr &vnode) override {
auto g = GetValueNode<FuncGraphPtr>(vnode);
if (g1_ == nullptr) {
g1_ = g;
} else {
g2_ = g;
}
}
return nnode;
};


void Reset() {
x_ = nullptr;
g1_ = nullptr;
g2_ = nullptr;
}
MATCH_REPLACE_LAMBDA_IF(node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda,
IsNode(cond.GetNode(node_)) && IsValueNode<FuncGraph>(true_br.GetNode(node_)) &&
IsValueNode<FuncGraph>(false_br.GetNode(node_)));


private:
AnfNodePtr x_{nullptr};
FuncGraphPtr g1_{nullptr}, g2_{nullptr};
return nullptr;
}
}; };

} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore


Loading…
Cancel
Save