GitOrigin-RevId: 0af7b076e6
tags/v1.7.0
| @@ -0,0 +1,183 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from collections import OrderedDict, defaultdict | |||
| from functools import partial | |||
| from ...logger import get_logger | |||
| from ..expr import ( | |||
| Expr, | |||
| is_apply_def, | |||
| is_call_function, | |||
| is_call_module, | |||
| is_call_tensor_method, | |||
| is_constant, | |||
| ) | |||
| from .pattern import ( | |||
| AnyPattern, | |||
| ApplyDefPattern, | |||
| CallPattern, | |||
| ConstantPattern, | |||
| ExprPattern, | |||
| FunctionPattern, | |||
| ModulePattern, | |||
| OrPattern, | |||
| TensorMethodPattern, | |||
| VarPattern, | |||
| ) | |||
| from .utils import register_obj | |||
| logger = get_logger(__name__) | |||
| class PatternMatcher: | |||
| method_dict = {} | |||
| register_visiter_func = partial(register_obj, _dict=method_dict) | |||
| def __init__(self) -> None: | |||
| self.matched_patterns = [] | |||
| self.matched_exprs = OrderedDict() | |||
| def match(self, pattern: ExprPattern, expr: Expr) -> bool: | |||
| self.matched_exprs.clear() | |||
| self.matched_patterns.clear() | |||
| pattern.check_users(False) | |||
| res = self.visit_pattern(pattern, expr) | |||
| if res and not self._check_users(): | |||
| self.clear_map(0) | |||
| res = False | |||
| self._clear_pattern_users() | |||
| return res | |||
| def clear_map(self, mark): | |||
| for _ in range(len(self.matched_patterns) - mark): | |||
| p = self.matched_patterns.pop() | |||
| self.matched_exprs.pop(p) | |||
| p._clear_users() | |||
| def _clear_pattern_users(self): | |||
| for p in self.matched_patterns: | |||
| p._clear_users() | |||
| def _check_users(self) -> bool: | |||
| for pat, expr in self.matched_exprs.items(): | |||
| if pat._check_users: | |||
| pattern_users = pat._users | |||
| if len(expr.outputs) != 1: | |||
| logger.warning( | |||
| "only support single output, and the matching " | |||
| "result may be wrong" | |||
| ) | |||
| continue | |||
| expr_users = expr.outputs[0].users | |||
| if len(pattern_users) != len(expr_users): | |||
| return False | |||
| for pat, expr in zip(pattern_users, expr_users): | |||
| if self.matched_exprs[pat] != expr: | |||
| return False | |||
| return True | |||
| def visit_pattern(self, pattern: ExprPattern, expr: Expr) -> bool: | |||
| if pattern in self.matched_exprs: | |||
| if self.matched_exprs[pattern] is expr: | |||
| if isinstance(pattern, (OrPattern)): | |||
| assert self._visit_or_pattern(pattern, expr) == True | |||
| return True | |||
| else: | |||
| return False | |||
| else: | |||
| mark = len(self.matched_patterns) | |||
| visiter = self.method_dict.get(type(pattern)) | |||
| matched = visiter(self, pattern, expr) | |||
| if matched: | |||
| self.matched_patterns.append(pattern) | |||
| self.matched_exprs[pattern] = expr | |||
| else: | |||
| self.clear_map(mark) | |||
| return matched | |||
| @register_visiter_func(OrPattern) | |||
| def _visit_or_pattern(self, pattern: OrPattern, expr: Expr) -> bool: | |||
| if self.visit_pattern(pattern.left, expr): | |||
| if pattern._users: | |||
| pattern.left._add_users(pattern._users[-1]) | |||
| return True | |||
| if self.visit_pattern(pattern.right, expr): | |||
| if pattern._users: | |||
| pattern.right._add_users(pattern._users[-1]) | |||
| return True | |||
| return False | |||
| @register_visiter_func(CallPattern) | |||
| def _visit_call_pattern(self, pattern: CallPattern, expr: Expr) -> bool: | |||
| mark = len(self.matched_patterns) | |||
| match_res = self.visit_pattern(pattern.op, expr) | |||
| if not match_res: | |||
| self.clear_map(mark) | |||
| return False | |||
| inputs = expr.inputs | |||
| if isinstance(pattern.op, ModulePattern): | |||
| inputs = inputs[1:] | |||
| if (pattern._match_all_args and len(pattern.args) != len(inputs)) or ( | |||
| not pattern._match_all_args and len(pattern.args) > len(inputs) | |||
| ): | |||
| self.clear_map(mark) | |||
| return False | |||
| for i, pat in enumerate(pattern.args): | |||
| pat._add_users(pattern) | |||
| match_res = self.visit_pattern(pat, inputs[i].expr) | |||
| if not match_res: | |||
| pat._clear_users() | |||
| self.clear_map(mark) | |||
| return False | |||
| return True | |||
| @register_visiter_func(ModulePattern) | |||
| def _visit_module_pattern(self, pattern: ModulePattern, expr: Expr) -> bool: | |||
| if not is_call_module(expr, pattern.target): | |||
| return False | |||
| module = expr.inputs[0].owner | |||
| for key, target in pattern.attrs.items(): | |||
| value = getattr(module, key, None) | |||
| if target != value: | |||
| return False | |||
| return True | |||
| @register_visiter_func(FunctionPattern) | |||
| def _visit_function_pattern(self, pattern: FunctionPattern, expr: Expr) -> bool: | |||
| if not is_call_function(expr, pattern.target): | |||
| return False | |||
| kwargs = expr.kwargs | |||
| for key, target in pattern.params.items(): | |||
| value = kwargs.get(key, None) | |||
| if target != value: | |||
| return False | |||
| return True | |||
| @register_visiter_func(TensorMethodPattern) | |||
| def _visit_tensor_method_pattern( | |||
| self, pattern: TensorMethodPattern, expr: Expr | |||
| ) -> bool: | |||
| return is_call_tensor_method(expr, pattern.target) | |||
| @register_visiter_func(ApplyDefPattern) | |||
| def _visit_apply_pattern(self, pattern: ApplyDefPattern, expr: Expr) -> bool: | |||
| return is_apply_def(expr, pattern.target) | |||
| @register_visiter_func(ConstantPattern) | |||
| def _visit_const_pattern(self, pattern: ConstantPattern, expr: Expr) -> bool: | |||
| return is_constant(expr) | |||
| @register_visiter_func(VarPattern) | |||
| def _visit_var_pattern(self, pattern: VarPattern, expr: Expr) -> bool: | |||
| return not is_constant(expr) | |||
| @register_visiter_func(AnyPattern) | |||
| def _visit_any_pattern(self, pattern: AnyPattern, expr: Expr) -> bool: | |||
| return True | |||
| @@ -0,0 +1,252 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import abstractmethod | |||
| from typing import Any, Callable, Dict, List | |||
| from ...core._imperative_rt import OpDef | |||
| from ...logger import get_logger | |||
| from ...module import Module | |||
| from ..expr import Expr | |||
| from ..node import Node | |||
| logger = get_logger(__name__) | |||
| class ExprPattern: | |||
| def __init__(self): | |||
| self._check_users = True | |||
| self._users = [] | |||
| def __call__(self, *args): | |||
| args = list(args) | |||
| if len(args) == 1 and args[0] is None: | |||
| args = None | |||
| return CallPattern(self, *args) | |||
| def __add__(self, other): | |||
| return is_op("__add__")(self, other) | |||
| def __iadd__(self, other): | |||
| return is_op("__iadd__")(self, other) | |||
| def __radd__(self, other): | |||
| return is_op("__radd__")(self, other) | |||
| def __sub__(self, other): | |||
| return is_op("__sub__")(self, other) | |||
| def __isub__(self, other): | |||
| return is_op("__isub__")(self, other) | |||
| def __rsub__(self, other): | |||
| return is_op("__rsub__")(self, other) | |||
| def __mul__(self, other): | |||
| return is_op("__mul__")(self, other) | |||
| def __imul__(self, other): | |||
| return is_op("__imul__")(self, other) | |||
| def __rmul__(self, other): | |||
| return is_op("__rmul__")(self, other) | |||
| def __truediv__(self, other): | |||
| return is_op("__truediv__")(self, other) | |||
| def __itruediv__(self, other): | |||
| return is_op("__itruediv__")(self, other) | |||
| def __rtruediv__(self, other): | |||
| return is_op("__rtruediv__")(self, other) | |||
| def __or__(self, other): | |||
| assert isinstance(other, ExprPattern) | |||
| return OrPattern(self, other) | |||
| def get_output(self, index): | |||
| raise NotImplementedError | |||
| def check_users(self, check: bool = True): | |||
| self._check_users = check | |||
| return self | |||
| def _add_users(self, pattern: "ExprPattern"): | |||
| self._users.append(pattern) | |||
| def _clear_users(self,): | |||
| self._users.clear() | |||
| def __getitem__(self, index): | |||
| return is_op("__getitem__")(self, index) | |||
| def has_attr(self, **attrs): | |||
| logger.warning("has_param only support ModulePattern") | |||
| return self | |||
| def has_param(self, **params): | |||
| logger.warning("has_param only support FunctionPattern") | |||
| return self | |||
| @abstractmethod | |||
| def __repr__(self) -> str: | |||
| raise NotImplementedError | |||
| class CallPattern(ExprPattern): | |||
| def __init__(self, op: ExprPattern, *args: List[ExprPattern]): | |||
| super().__init__() | |||
| self.op = op | |||
| self.args = list(filter(lambda x: isinstance(x, ExprPattern), args)) | |||
| self._match_all_args = True | |||
| def __repr__(self) -> str: | |||
| return "{}({})".format(self.op, ",".join(str(x) for x in self.args)) | |||
| def not_all_args(self): | |||
| self._match_all_args = False | |||
| def check_users(self, check: bool = True): | |||
| self._check_users = check | |||
| self.op.check_users(check) | |||
| return self | |||
| def _add_users(self, pattern: "ExprPattern"): | |||
| self._users.append(pattern) | |||
| self.op._add_users(pattern) | |||
| def _clear_users(self): | |||
| self._users.clear() | |||
| self.op._clear_users() | |||
| class OrPattern(ExprPattern): | |||
| def __init__(self, left: ExprPattern, right: ExprPattern): | |||
| super().__init__() | |||
| self.left = left | |||
| self.right = right | |||
| def __repr__(self) -> str: | |||
| return "({}|{})".format(self.left, self.right) | |||
| def check_users(self, check: bool = True): | |||
| self._check_users = check | |||
| self.left.check_users(check) | |||
| self.right.check_users(check) | |||
| return self | |||
| def _clear_users(self): | |||
| self._users.clear() | |||
| self.left._clear_users() | |||
| self.right._clear_users() | |||
| class GetOutputPaterrn(ExprPattern): | |||
| def __init__(self, op, index): | |||
| super().__init__() | |||
| self.op = op | |||
| self.index = index | |||
| def __repr__(self) -> str: | |||
| return "{}[{}]".format(self.op, self.index) | |||
| class ModulePattern(ExprPattern): | |||
| def __init__(self, module_cls: Module) -> None: | |||
| super().__init__() | |||
| self.attrs = {} | |||
| self.target = module_cls | |||
| def has_attr(self, **attrs): | |||
| self.attrs.update(attrs) | |||
| return self | |||
| def __repr__(self) -> str: | |||
| return "{}".format(self.target.__name__) | |||
| class FunctionPattern(ExprPattern): | |||
| def __init__(self, func: Callable): | |||
| super().__init__() | |||
| self.params = {} | |||
| self.target = func | |||
| def has_params(self, **params): | |||
| self.params.update(params) | |||
| return self | |||
| def __repr__(self) -> str: | |||
| return "{}".format(self.target.__name__) | |||
| class TensorMethodPattern(ExprPattern): | |||
| def __init__(self, method: str): | |||
| super().__init__() | |||
| self.target = method | |||
| def __repr__(self) -> str: | |||
| return self.target | |||
| class ApplyDefPattern(ExprPattern): | |||
| def __init__(self, opdef: OpDef): | |||
| super().__init__() | |||
| self.target = opdef | |||
| def __repr__(self) -> str: | |||
| return "{}".format(self.target.__name__) | |||
| class VarPattern(ExprPattern): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def __repr__(self) -> str: | |||
| return "var" | |||
| class ConstantPattern(ExprPattern): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def __repr__(self) -> str: | |||
| return "const" | |||
| class AnyPattern(ExprPattern): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def __repr__(self) -> str: | |||
| return "any" | |||
| def is_op(target): | |||
| if isinstance(target, type): | |||
| if issubclass(target, Module): | |||
| return ModulePattern(target) | |||
| if issubclass(target, OpDef): | |||
| return ApplyDefPattern(target) | |||
| elif callable(target): | |||
| return FunctionPattern(target) | |||
| elif isinstance(target, str): | |||
| return TensorMethodPattern(target) | |||
| else: | |||
| raise ValueError("not support") | |||
| def is_const(): | |||
| return ConstantPattern().check_users(False) | |||
| def any_node(): | |||
| return AnyPattern() | |||
| def is_var(): | |||
| return VarPattern() | |||
| @@ -0,0 +1,38 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import copy | |||
| from typing import Any, Dict, List | |||
| from ..expr import Expr, is_constant, is_getattr | |||
| from ..node import Node, TensorNode | |||
| def register_obj(objs: List[Any], _dict: Dict): | |||
| if not isinstance(objs, List): | |||
| objs = [objs] | |||
| def _register(any_obj: Any): | |||
| for obj in objs: | |||
| _dict[obj] = any_obj | |||
| return any_obj | |||
| return _register | |||
| def get_const_value(expr: Expr, fall_back: Any = None): | |||
| value = fall_back | |||
| if isinstance(expr, Node): | |||
| expr = expr.expr | |||
| if is_getattr(expr) and isinstance(expr.outputs[0], TensorNode): | |||
| module = expr.inputs[0].owner | |||
| assert module is not None | |||
| value = copy.deepcopy(expr.interpret(module)[0]) | |||
| elif is_constant(expr): | |||
| value = copy.deepcopy(expr.interpret()[0]) | |||
| return value | |||