|
|
|
@@ -65,7 +65,7 @@ class PIsEqual { |
|
|
|
template <typename T = AnfNodePtr> |
|
|
|
class PatternNode : public PBase<PatternNode<T> > { |
|
|
|
public: |
|
|
|
T GetNode(const AnfNodePtr &node) const { |
|
|
|
T GetNode(const AnfNodePtr &) const { |
|
|
|
if (!captured_) { |
|
|
|
MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode."; |
|
|
|
} |
|
|
|
@@ -108,11 +108,11 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > { |
|
|
|
auto inputs = cnode->inputs(); |
|
|
|
if (inputs.size() == 3) { |
|
|
|
// Binary Prim assumes only two inputs |
|
|
|
if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) { |
|
|
|
if (!x_.TryCapture(inputs[1]) || !y_.TryCapture(inputs[2])) { |
|
|
|
// If the operation is commutative, then check with inversed operands |
|
|
|
if (is_commutative_) { |
|
|
|
Reset(); |
|
|
|
if (!x_.TryCapture_(inputs[2]) || !y_.TryCapture_(inputs[1])) { |
|
|
|
if (!x_.TryCapture(inputs[2]) || !y_.TryCapture(inputs[1])) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
@@ -208,30 +208,77 @@ class PCNode : public PBase<PCNode<TArgs...> > { |
|
|
|
AnfNodePtr GetNode(const AnfNodePtr &node) const { |
|
|
|
tuple_utils::PTupleGetNode get_node(node); |
|
|
|
tuple_utils::apply_func_tuple(&get_node, args_); |
|
|
|
return NewCNode(get_node.args_, node->func_graph()); |
|
|
|
auto prim_cnode = get_node.args_; |
|
|
|
// In case this PCNode has captured extra nodes |
|
|
|
if (extra_nodes_.size() > 0) { |
|
|
|
prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end()); |
|
|
|
} |
|
|
|
return NewCNode(prim_cnode, node->func_graph()); |
|
|
|
} |
|
|
|
|
|
|
|
bool TryCapture_(const AnfNodePtr &node) const { |
|
|
|
if (node->isa<CNode>()) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
auto inputs = cnode->inputs(); |
|
|
|
if (inputs.size() != sizeof...(TArgs)) { |
|
|
|
|
|
|
|
auto pattern_arg_len = sizeof...(TArgs); |
|
|
|
// There aren't enough inputs in Node to fill up the Pattern |
|
|
|
if (inputs.size() < pattern_arg_len) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
tuple_utils::PTupleCapture capture_func(inputs); |
|
|
|
tuple_utils::apply_func_tuple(&capture_func, args_); |
|
|
|
return capture_func.captured_; |
|
|
|
} |
|
|
|
|
|
|
|
// Pattern must exactly match the number of Node inputs. |
|
|
|
if (!has_min_extra_nodes_) { |
|
|
|
// Inputs in Node perfectly match number of tokens in Pattern. |
|
|
|
if ((inputs.size() - 1) == pattern_arg_len) { |
|
|
|
AnfNodePtrList tokens(inputs.begin() + 1, inputs.end()); |
|
|
|
tuple_utils::PTupleCapture capture_func(tokens); |
|
|
|
tuple_utils::apply_func_tuple(&capture_func, args_); |
|
|
|
return capture_func.captured_; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// Pattern may accept extra (non specified) nodes at the end of the CNode |
|
|
|
// There must be at least `min_extra_nodes` additional nodes in the inputs. |
|
|
|
if ((inputs.size() - 1) >= pattern_arg_len + min_extra_nodes_) { |
|
|
|
AnfNodePtrList tokens(inputs.begin() + 1, inputs.begin() + 1 + pattern_arg_len); |
|
|
|
tuple_utils::PTupleCapture capture_func(tokens); |
|
|
|
tuple_utils::apply_func_tuple(&capture_func, args_); |
|
|
|
// If it could capture the initial set of nodes specified in the Pattern |
|
|
|
// and there are enough extra inputs to add |
|
|
|
if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) { |
|
|
|
extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); |
|
|
|
return true; |
|
|
|
} |
|
|
|
return capture_func.captured_; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
/// This function sets the PCNode object to capture at least `min_extra_nodes_` nodes after the last one |
|
|
|
/// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or |
|
|
|
/// more nodes after the last one specified when building the PCNode. |
|
|
|
const PCNode<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const { |
|
|
|
has_min_extra_nodes_ = true; |
|
|
|
min_extra_nodes_ = min_extra_nodes; |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
void Reset() const { |
|
|
|
tuple_utils::PTupleResetCapture reset; |
|
|
|
tuple_utils::apply_func_tuple(&reset, args_); |
|
|
|
has_min_extra_nodes_ = false; |
|
|
|
extra_nodes_.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
std::tuple<typename TArgs::Internal...> args_; |
|
|
|
mutable AnfNodePtrList extra_nodes_; |
|
|
|
mutable bool has_min_extra_nodes_{false}; |
|
|
|
mutable size_t min_extra_nodes_{0}; |
|
|
|
}; |
|
|
|
|
|
|
|
template <typename... TArgs> |
|
|
|
@@ -244,6 +291,11 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > { |
|
|
|
tuple_utils::apply_func_tuple(&get_node, args_); |
|
|
|
auto prim_cnode = get_node.args_; |
|
|
|
prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_)); |
|
|
|
|
|
|
|
// In case this PPrimitive has captured extra nodes |
|
|
|
if (extra_nodes_.size() > 0) { |
|
|
|
prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end()); |
|
|
|
} |
|
|
|
return NewCNode(prim_cnode, node->func_graph()); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -251,35 +303,66 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > { |
|
|
|
if (IsPrimitiveCNode(node, prim_)) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
auto inputs = cnode->inputs(); |
|
|
|
if ((inputs.size() - 1) != sizeof...(TArgs)) { |
|
|
|
// Number of arguments in Primitive Pattern (not including the Primitive node) |
|
|
|
auto pattern_arg_len = sizeof...(TArgs); |
|
|
|
// There aren't enough inputs in Node to fill up the Pattern |
|
|
|
if ((inputs.size() - 1) < pattern_arg_len) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtrList rest(inputs.begin() + 1, inputs.end()); |
|
|
|
tuple_utils::PTupleCapture capture_func(rest); |
|
|
|
tuple_utils::apply_func_tuple(&capture_func, args_); |
|
|
|
// Pattern must exactly match the number of Node inputs. |
|
|
|
if (!has_min_extra_nodes_) { |
|
|
|
// Inputs in Node perfectly match number of tokens in Pattern. |
|
|
|
if ((inputs.size() - 1) == pattern_arg_len) { |
|
|
|
AnfNodePtrList tokens(inputs.begin() + 1, inputs.end()); |
|
|
|
tuple_utils::PTupleCapture capture_func(tokens); |
|
|
|
tuple_utils::apply_func_tuple(&capture_func, args_); |
|
|
|
return capture_func.captured_; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
return capture_func.captured_; |
|
|
|
// Pattern may accept extra (non specified) nodes at the end of the Primitive |
|
|
|
// There must be at least `min_extra_nodes` additional nodes in the inputs. |
|
|
|
if ((inputs.size() - 1) >= pattern_arg_len + min_extra_nodes_) { |
|
|
|
AnfNodePtrList tokens(inputs.begin() + 1, inputs.begin() + 1 + pattern_arg_len); |
|
|
|
tuple_utils::PTupleCapture capture_func(tokens); |
|
|
|
tuple_utils::apply_func_tuple(&capture_func, args_); |
|
|
|
// If it could capture the initial set of nodes specified in the Pattern |
|
|
|
// and there are enough extra inputs to add |
|
|
|
if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) { |
|
|
|
extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); |
|
|
|
return true; |
|
|
|
} |
|
|
|
return capture_func.captured_; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// If set to true, TryCapture will try to capture the nodes in iversed nodes as well (only for two input case) |
|
|
|
const PPrimitive<TArgs...> &Commutative(const bool &is_commutative = true) const { |
|
|
|
is_commutative_ = is_commutative; |
|
|
|
/// This function sets the PPrimitive object to capture at least `min_extra_nodes_` nodes after the last one |
|
|
|
/// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or |
|
|
|
/// more nodes after the last one specified when building the PPrimitive. |
|
|
|
const PPrimitive<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const { |
|
|
|
has_min_extra_nodes_ = true; |
|
|
|
min_extra_nodes_ = min_extra_nodes; |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
void Reset() const { |
|
|
|
tuple_utils::PTupleResetCapture reset; |
|
|
|
tuple_utils::apply_func_tuple(&reset, args_); |
|
|
|
has_min_extra_nodes_ = false; |
|
|
|
extra_nodes_.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
const PrimitivePtr prim_; |
|
|
|
std::tuple<typename TArgs::Internal...> args_; |
|
|
|
mutable bool is_commutative_{false}; |
|
|
|
mutable AnfNodePtrList extra_nodes_; |
|
|
|
mutable bool has_min_extra_nodes_{false}; |
|
|
|
mutable size_t min_extra_nodes_{0}; |
|
|
|
}; |
|
|
|
|
|
|
|
/// |
|
|
|
|