GitOrigin-RevId: 0af7b076e6
tags/v1.7.0
| @@ -0,0 +1,183 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from collections import OrderedDict, defaultdict | |||||
| from functools import partial | |||||
| from ...logger import get_logger | |||||
| from ..expr import ( | |||||
| Expr, | |||||
| is_apply_def, | |||||
| is_call_function, | |||||
| is_call_module, | |||||
| is_call_tensor_method, | |||||
| is_constant, | |||||
| ) | |||||
| from .pattern import ( | |||||
| AnyPattern, | |||||
| ApplyDefPattern, | |||||
| CallPattern, | |||||
| ConstantPattern, | |||||
| ExprPattern, | |||||
| FunctionPattern, | |||||
| ModulePattern, | |||||
| OrPattern, | |||||
| TensorMethodPattern, | |||||
| VarPattern, | |||||
| ) | |||||
| from .utils import register_obj | |||||
| logger = get_logger(__name__) | |||||
| class PatternMatcher: | |||||
| method_dict = {} | |||||
| register_visiter_func = partial(register_obj, _dict=method_dict) | |||||
| def __init__(self) -> None: | |||||
| self.matched_patterns = [] | |||||
| self.matched_exprs = OrderedDict() | |||||
| def match(self, pattern: ExprPattern, expr: Expr) -> bool: | |||||
| self.matched_exprs.clear() | |||||
| self.matched_patterns.clear() | |||||
| pattern.check_users(False) | |||||
| res = self.visit_pattern(pattern, expr) | |||||
| if res and not self._check_users(): | |||||
| self.clear_map(0) | |||||
| res = False | |||||
| self._clear_pattern_users() | |||||
| return res | |||||
| def clear_map(self, mark): | |||||
| for _ in range(len(self.matched_patterns) - mark): | |||||
| p = self.matched_patterns.pop() | |||||
| self.matched_exprs.pop(p) | |||||
| p._clear_users() | |||||
| def _clear_pattern_users(self): | |||||
| for p in self.matched_patterns: | |||||
| p._clear_users() | |||||
| def _check_users(self) -> bool: | |||||
| for pat, expr in self.matched_exprs.items(): | |||||
| if pat._check_users: | |||||
| pattern_users = pat._users | |||||
| if len(expr.outputs) != 1: | |||||
| logger.warning( | |||||
| "only support single output, and the matching " | |||||
| "result may be wrong" | |||||
| ) | |||||
| continue | |||||
| expr_users = expr.outputs[0].users | |||||
| if len(pattern_users) != len(expr_users): | |||||
| return False | |||||
| for pat, expr in zip(pattern_users, expr_users): | |||||
| if self.matched_exprs[pat] != expr: | |||||
| return False | |||||
| return True | |||||
| def visit_pattern(self, pattern: ExprPattern, expr: Expr) -> bool: | |||||
| if pattern in self.matched_exprs: | |||||
| if self.matched_exprs[pattern] is expr: | |||||
| if isinstance(pattern, (OrPattern)): | |||||
| assert self._visit_or_pattern(pattern, expr) == True | |||||
| return True | |||||
| else: | |||||
| return False | |||||
| else: | |||||
| mark = len(self.matched_patterns) | |||||
| visiter = self.method_dict.get(type(pattern)) | |||||
| matched = visiter(self, pattern, expr) | |||||
| if matched: | |||||
| self.matched_patterns.append(pattern) | |||||
| self.matched_exprs[pattern] = expr | |||||
| else: | |||||
| self.clear_map(mark) | |||||
| return matched | |||||
| @register_visiter_func(OrPattern) | |||||
| def _visit_or_pattern(self, pattern: OrPattern, expr: Expr) -> bool: | |||||
| if self.visit_pattern(pattern.left, expr): | |||||
| if pattern._users: | |||||
| pattern.left._add_users(pattern._users[-1]) | |||||
| return True | |||||
| if self.visit_pattern(pattern.right, expr): | |||||
| if pattern._users: | |||||
| pattern.right._add_users(pattern._users[-1]) | |||||
| return True | |||||
| return False | |||||
| @register_visiter_func(CallPattern) | |||||
| def _visit_call_pattern(self, pattern: CallPattern, expr: Expr) -> bool: | |||||
| mark = len(self.matched_patterns) | |||||
| match_res = self.visit_pattern(pattern.op, expr) | |||||
| if not match_res: | |||||
| self.clear_map(mark) | |||||
| return False | |||||
| inputs = expr.inputs | |||||
| if isinstance(pattern.op, ModulePattern): | |||||
| inputs = inputs[1:] | |||||
| if (pattern._match_all_args and len(pattern.args) != len(inputs)) or ( | |||||
| not pattern._match_all_args and len(pattern.args) > len(inputs) | |||||
| ): | |||||
| self.clear_map(mark) | |||||
| return False | |||||
| for i, pat in enumerate(pattern.args): | |||||
| pat._add_users(pattern) | |||||
| match_res = self.visit_pattern(pat, inputs[i].expr) | |||||
| if not match_res: | |||||
| pat._clear_users() | |||||
| self.clear_map(mark) | |||||
| return False | |||||
| return True | |||||
| @register_visiter_func(ModulePattern) | |||||
| def _visit_module_pattern(self, pattern: ModulePattern, expr: Expr) -> bool: | |||||
| if not is_call_module(expr, pattern.target): | |||||
| return False | |||||
| module = expr.inputs[0].owner | |||||
| for key, target in pattern.attrs.items(): | |||||
| value = getattr(module, key, None) | |||||
| if target != value: | |||||
| return False | |||||
| return True | |||||
| @register_visiter_func(FunctionPattern) | |||||
| def _visit_function_pattern(self, pattern: FunctionPattern, expr: Expr) -> bool: | |||||
| if not is_call_function(expr, pattern.target): | |||||
| return False | |||||
| kwargs = expr.kwargs | |||||
| for key, target in pattern.params.items(): | |||||
| value = kwargs.get(key, None) | |||||
| if target != value: | |||||
| return False | |||||
| return True | |||||
| @register_visiter_func(TensorMethodPattern) | |||||
| def _visit_tensor_method_pattern( | |||||
| self, pattern: TensorMethodPattern, expr: Expr | |||||
| ) -> bool: | |||||
| return is_call_tensor_method(expr, pattern.target) | |||||
| @register_visiter_func(ApplyDefPattern) | |||||
| def _visit_apply_pattern(self, pattern: ApplyDefPattern, expr: Expr) -> bool: | |||||
| return is_apply_def(expr, pattern.target) | |||||
| @register_visiter_func(ConstantPattern) | |||||
| def _visit_const_pattern(self, pattern: ConstantPattern, expr: Expr) -> bool: | |||||
| return is_constant(expr) | |||||
| @register_visiter_func(VarPattern) | |||||
| def _visit_var_pattern(self, pattern: VarPattern, expr: Expr) -> bool: | |||||
| return not is_constant(expr) | |||||
| @register_visiter_func(AnyPattern) | |||||
| def _visit_any_pattern(self, pattern: AnyPattern, expr: Expr) -> bool: | |||||
| return True | |||||
| @@ -0,0 +1,252 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from abc import abstractmethod | |||||
| from typing import Any, Callable, Dict, List | |||||
| from ...core._imperative_rt import OpDef | |||||
| from ...logger import get_logger | |||||
| from ...module import Module | |||||
| from ..expr import Expr | |||||
| from ..node import Node | |||||
| logger = get_logger(__name__) | |||||
| class ExprPattern: | |||||
| def __init__(self): | |||||
| self._check_users = True | |||||
| self._users = [] | |||||
| def __call__(self, *args): | |||||
| args = list(args) | |||||
| if len(args) == 1 and args[0] is None: | |||||
| args = None | |||||
| return CallPattern(self, *args) | |||||
| def __add__(self, other): | |||||
| return is_op("__add__")(self, other) | |||||
| def __iadd__(self, other): | |||||
| return is_op("__iadd__")(self, other) | |||||
| def __radd__(self, other): | |||||
| return is_op("__radd__")(self, other) | |||||
| def __sub__(self, other): | |||||
| return is_op("__sub__")(self, other) | |||||
| def __isub__(self, other): | |||||
| return is_op("__isub__")(self, other) | |||||
| def __rsub__(self, other): | |||||
| return is_op("__rsub__")(self, other) | |||||
| def __mul__(self, other): | |||||
| return is_op("__mul__")(self, other) | |||||
| def __imul__(self, other): | |||||
| return is_op("__imul__")(self, other) | |||||
| def __rmul__(self, other): | |||||
| return is_op("__rmul__")(self, other) | |||||
| def __truediv__(self, other): | |||||
| return is_op("__truediv__")(self, other) | |||||
| def __itruediv__(self, other): | |||||
| return is_op("__itruediv__")(self, other) | |||||
| def __rtruediv__(self, other): | |||||
| return is_op("__rtruediv__")(self, other) | |||||
| def __or__(self, other): | |||||
| assert isinstance(other, ExprPattern) | |||||
| return OrPattern(self, other) | |||||
| def get_output(self, index): | |||||
| raise NotImplementedError | |||||
| def check_users(self, check: bool = True): | |||||
| self._check_users = check | |||||
| return self | |||||
| def _add_users(self, pattern: "ExprPattern"): | |||||
| self._users.append(pattern) | |||||
| def _clear_users(self,): | |||||
| self._users.clear() | |||||
| def __getitem__(self, index): | |||||
| return is_op("__getitem__")(self, index) | |||||
| def has_attr(self, **attrs): | |||||
| logger.warning("has_param only support ModulePattern") | |||||
| return self | |||||
| def has_param(self, **params): | |||||
| logger.warning("has_param only support FunctionPattern") | |||||
| return self | |||||
| @abstractmethod | |||||
| def __repr__(self) -> str: | |||||
| raise NotImplementedError | |||||
| class CallPattern(ExprPattern): | |||||
| def __init__(self, op: ExprPattern, *args: List[ExprPattern]): | |||||
| super().__init__() | |||||
| self.op = op | |||||
| self.args = list(filter(lambda x: isinstance(x, ExprPattern), args)) | |||||
| self._match_all_args = True | |||||
| def __repr__(self) -> str: | |||||
| return "{}({})".format(self.op, ",".join(str(x) for x in self.args)) | |||||
| def not_all_args(self): | |||||
| self._match_all_args = False | |||||
| def check_users(self, check: bool = True): | |||||
| self._check_users = check | |||||
| self.op.check_users(check) | |||||
| return self | |||||
| def _add_users(self, pattern: "ExprPattern"): | |||||
| self._users.append(pattern) | |||||
| self.op._add_users(pattern) | |||||
| def _clear_users(self): | |||||
| self._users.clear() | |||||
| self.op._clear_users() | |||||
| class OrPattern(ExprPattern): | |||||
| def __init__(self, left: ExprPattern, right: ExprPattern): | |||||
| super().__init__() | |||||
| self.left = left | |||||
| self.right = right | |||||
| def __repr__(self) -> str: | |||||
| return "({}|{})".format(self.left, self.right) | |||||
| def check_users(self, check: bool = True): | |||||
| self._check_users = check | |||||
| self.left.check_users(check) | |||||
| self.right.check_users(check) | |||||
| return self | |||||
| def _clear_users(self): | |||||
| self._users.clear() | |||||
| self.left._clear_users() | |||||
| self.right._clear_users() | |||||
| class GetOutputPaterrn(ExprPattern): | |||||
| def __init__(self, op, index): | |||||
| super().__init__() | |||||
| self.op = op | |||||
| self.index = index | |||||
| def __repr__(self) -> str: | |||||
| return "{}[{}]".format(self.op, self.index) | |||||
| class ModulePattern(ExprPattern): | |||||
| def __init__(self, module_cls: Module) -> None: | |||||
| super().__init__() | |||||
| self.attrs = {} | |||||
| self.target = module_cls | |||||
| def has_attr(self, **attrs): | |||||
| self.attrs.update(attrs) | |||||
| return self | |||||
| def __repr__(self) -> str: | |||||
| return "{}".format(self.target.__name__) | |||||
| class FunctionPattern(ExprPattern): | |||||
| def __init__(self, func: Callable): | |||||
| super().__init__() | |||||
| self.params = {} | |||||
| self.target = func | |||||
| def has_params(self, **params): | |||||
| self.params.update(params) | |||||
| return self | |||||
| def __repr__(self) -> str: | |||||
| return "{}".format(self.target.__name__) | |||||
| class TensorMethodPattern(ExprPattern): | |||||
| def __init__(self, method: str): | |||||
| super().__init__() | |||||
| self.target = method | |||||
| def __repr__(self) -> str: | |||||
| return self.target | |||||
| class ApplyDefPattern(ExprPattern): | |||||
| def __init__(self, opdef: OpDef): | |||||
| super().__init__() | |||||
| self.target = opdef | |||||
| def __repr__(self) -> str: | |||||
| return "{}".format(self.target.__name__) | |||||
| class VarPattern(ExprPattern): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| def __repr__(self) -> str: | |||||
| return "var" | |||||
| class ConstantPattern(ExprPattern): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| def __repr__(self) -> str: | |||||
| return "const" | |||||
| class AnyPattern(ExprPattern): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| def __repr__(self) -> str: | |||||
| return "any" | |||||
| def is_op(target): | |||||
| if isinstance(target, type): | |||||
| if issubclass(target, Module): | |||||
| return ModulePattern(target) | |||||
| if issubclass(target, OpDef): | |||||
| return ApplyDefPattern(target) | |||||
| elif callable(target): | |||||
| return FunctionPattern(target) | |||||
| elif isinstance(target, str): | |||||
| return TensorMethodPattern(target) | |||||
| else: | |||||
| raise ValueError("not support") | |||||
| def is_const(): | |||||
| return ConstantPattern().check_users(False) | |||||
| def any_node(): | |||||
| return AnyPattern() | |||||
| def is_var(): | |||||
| return VarPattern() | |||||
| @@ -0,0 +1,38 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import copy | |||||
| from typing import Any, Dict, List | |||||
| from ..expr import Expr, is_constant, is_getattr | |||||
| from ..node import Node, TensorNode | |||||
| def register_obj(objs: List[Any], _dict: Dict): | |||||
| if not isinstance(objs, List): | |||||
| objs = [objs] | |||||
| def _register(any_obj: Any): | |||||
| for obj in objs: | |||||
| _dict[obj] = any_obj | |||||
| return any_obj | |||||
| return _register | |||||
| def get_const_value(expr: Expr, fall_back: Any = None): | |||||
| value = fall_back | |||||
| if isinstance(expr, Node): | |||||
| expr = expr.expr | |||||
| if is_getattr(expr) and isinstance(expr.outputs[0], TensorNode): | |||||
| module = expr.inputs[0].owner | |||||
| assert module is not None | |||||
| value = copy.deepcopy(expr.interpret(module)[0]) | |||||
| elif is_constant(expr): | |||||
| value = copy.deepcopy(expr.interpret()[0]) | |||||
| return value | |||||