| @@ -0,0 +1,183 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from ... import functional as F | |||||
| from ... import module as M | |||||
| from ...core.ops.builtin import GetVarShape | |||||
| from ...logger import get_logger | |||||
| from ...tensor import Tensor | |||||
| from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr | |||||
| from ..node import Node, TensorNode | |||||
| from .matcher import PatternMatcher | |||||
| from .pass_base import BackwardPass, ForwardPass, register_pass | |||||
| from .pattern import is_op | |||||
| from .utils import get_const_value | |||||
| logger = get_logger(__name__) | |||||
| @register_pass("AttrToConstant") | |||||
| class AttrToConstant(BackwardPass): | |||||
| r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr.""" | |||||
| name = "AttrToConstant" | |||||
| run_once = True | |||||
| def run_transform(self, expr: Expr): | |||||
| if not (is_getattr(expr) and isinstance(expr.outputs[0], TensorNode)): | |||||
| return expr | |||||
| graph = expr.top_graph | |||||
| value = get_const_value(expr) | |||||
| orig_node = expr.outputs[0] | |||||
| name = orig_node.name | |||||
| with graph.insert_exprs(expr): | |||||
| const_node = Constant.make(value, name=name) | |||||
| graph.replace_node({orig_node: const_node}) | |||||
| graph.compile() | |||||
| name = orig_node.name | |||||
| return const_node.expr | |||||
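A minimal sketch of the effect of this pass (the toy ``Scale`` module and its ``scale`` attribute are hypothetical; ``trace_module`` is MegEngine's tracing entry point):

.. code-block:: python

    import megengine as mge
    import megengine.module as M
    from megengine.traced_module import trace_module

    class Scale(M.Module):
        def __init__(self):
            super().__init__()
            self.scale = mge.Tensor(2.0)  # read through a GetAttr expr after tracing

        def forward(self, x):
            return x * self.scale

    tm = trace_module(Scale(), mge.Tensor([1.0]))
    # Before the pass, tm.graph contains a GetAttr expr producing `scale`.
    # AttrToConstant replaces it with a Constant expr holding the tensor value,
    # so downstream passes (e.g. FlodConstant) can treat it as a known constant.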
| @register_pass("FixInputShape") | |||||
| class FixInputShape(BackwardPass): | |||||
| name = "FixInputShape" | |||||
| run_once = True | |||||
| def run_transform(self, expr: Expr): | |||||
| if not is_apply_def(expr, GetVarShape): | |||||
| return expr | |||||
| shape = Tensor(expr.inputs[0].shape, dtype="int32") | |||||
| graph = expr.top_graph | |||||
| with graph.insert_exprs(expr): | |||||
| const_shape = Constant.make(shape) | |||||
| graph.replace_node({expr.outputs[0]: const_shape}) | |||||
| graph.compile() | |||||
| const_shape.name = expr.outputs[0].name | |||||
| return const_shape.expr | |||||
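A hedged sketch of what this pass does when the traced graph queries a tensor's shape (recorded as an ``Apply(GetVarShape)`` expr):

.. code-block:: python

    # Before: shape = x.shape             # Apply(GetVarShape), evaluated at runtime
    #         y = x.reshape(shape[0], -1)
    # After FixInputShape: the GetVarShape output is replaced by a Constant int32
    # tensor holding the shape seen at trace time.  This pins the graph to fixed
    # input shapes, but lets later passes fold the shape arithmetic away.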
| @register_pass("FlodConstant") | |||||
| class FlodConstant(ForwardPass): | |||||
| r"""Constant folding.""" | |||||
| name = "FlodConstant" | |||||
| required_pass = ["AttrToConstant"] | |||||
| run_once = False | |||||
| def run_transform(self, expr: Expr): | |||||
| if len(expr.inputs) == 0 or any(not is_constant(n.expr) for n in expr.inputs): | |||||
| return expr | |||||
| const_var = expr.interpret(*[get_const_value(n.expr) for n in expr.inputs])[0] | |||||
| graph = expr.top_graph | |||||
| with graph.insert_exprs(expr): | |||||
| const_node = Constant.make(const_var) | |||||
| graph.replace_node({expr.outputs[0]: const_node}) | |||||
| graph.compile() | |||||
| const_node.name = expr.outputs[0].name | |||||
| return const_node.expr | |||||
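A hedged sketch of the folding, assuming every input of an expr is already produced by a Constant expr (for example after AttrToConstant has run):

.. code-block:: python

    # Before: a = Constant(2.0); b = Constant(3.0); c = a * b   # evaluated at runtime
    # After FlodConstant: the multiply is interpreted once during the pass and the
    # whole subexpression collapses to Constant(6.0).  Because run_once is False,
    # the pass keeps iterating until no expr with all-constant inputs remains.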
| @register_pass("NormElemWise") | |||||
| class NormElemWise(BackwardPass): | |||||
| r"""Transform add/sub or mul/div expr to add-only or mul-only chains. | |||||
| For example, the following code | |||||
| .. code-block:: | |||||
| b = 1 - a | |||||
| c = 2 * b | |||||
| d = 1 / c | |||||
| will be changed to | |||||
| .. code-block:: | |||||
| a1 = F.neg(a) | |||||
| b = a1 + 1 | |||||
| c = b * 2 | |||||
| d = F.pow(c, -1) | |||||
| """ | |||||
| name = "NormElemWise" | |||||
| required_pass = ["FlodConstant"] | |||||
| run_once = False | |||||
| def __init__(self,): | |||||
| super().__init__() | |||||
| self.pattern = is_op(F.add) | |||||
| for op in [F.sub, F.mul, F.div]: | |||||
| self.pattern |= is_op(op) | |||||
| for op in ["__add__", "__iadd__", "__radd__"]: | |||||
| self.pattern |= is_op(op) | |||||
| for op in ["__sub__", "__isub__", "__rsub__"]: | |||||
| self.pattern |= is_op(op) | |||||
| for op in ["__mul__", "__imul__", "__rmul__"]: | |||||
| self.pattern |= is_op(op) | |||||
| for op in ["__truediv__", "__itruediv__", "__rtruediv__"]: | |||||
| self.pattern |= is_op(op) | |||||
| def run_transform(self, expr: Expr): | |||||
| matcher = PatternMatcher() | |||||
| if not matcher.match(self.pattern, expr): | |||||
| return expr | |||||
| pattern = matcher.matched_patterns[0] | |||||
| target = pattern.target | |||||
| cofee, left_node, right_node = 1, None, None | |||||
| if len(expr.inputs) == 1 and target not in ["__add__", "__mul__"]: | |||||
| left_node = expr.inputs[0] | |||||
| right_node = expr.const_val[0][-1] | |||||
| if target in ["__rsub__", "__rtruediv__"]: | |||||
| cofee = -1 | |||||
| if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]: | |||||
| cofee = -1 | |||||
| elif len(expr.inputs) == 2 and ( | |||||
| target not in ["__add__", "__mul__"] or is_constant(expr.inputs[0].expr) | |||||
| ): | |||||
| left_node, right_node = expr.inputs | |||||
| if target in ["__rsub__", "__rtruediv__"]: | |||||
| left_node, right_node = right_node, left_node | |||||
| if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]: | |||||
| left_node, right_node = right_node, left_node | |||||
| if is_constant(left_node.expr): | |||||
| left_node, right_node = right_node, left_node | |||||
| cofee = -1 | |||||
| if left_node is None: | |||||
| return expr | |||||
| if isinstance(right_node, TensorNode): | |||||
| right_node = get_const_value(right_node.expr, right_node) | |||||
| graph = expr.top_graph | |||||
| with graph.insert_exprs(): | |||||
| if target in ["__mul__", "__imul__", "__rmul__", F.mul]: | |||||
| out_node = left_node * right_node | |||||
| elif target in ["__add__", "__iadd__", "__radd__", F.add]: | |||||
| out_node = left_node + right_node | |||||
| elif target in ["__sub__", "__isub__", "__rsub__", F.sub]: | |||||
| if cofee == -1: | |||||
| left_node = F.neg(left_node) | |||||
| else: | |||||
| if isinstance(right_node, TensorNode): | |||||
| right_node = F.neg(right_node) | |||||
| else: | |||||
| right_node = -1 * right_node | |||||
| out_node = left_node + right_node | |||||
| elif target in ["__truediv__", "__itruediv__", "__rtruediv__", F.div]: | |||||
| if cofee == -1: | |||||
| left_node = F.pow(left_node, -1) | |||||
| else: | |||||
| if isinstance(right_node, TensorNode): | |||||
| right_node = F.pow(right_node, -1) | |||||
| else: | |||||
| right_node = 1 / right_node | |||||
| out_node = left_node * right_node | |||||
| graph.replace_node({expr.outputs[0]: out_node}) | |||||
| graph.compile() | |||||
| return out_node.expr | |||||
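For reference, a hedged summary of the canonical forms this pass produces (later passes such as FuseAddMul rely on the variable being the left operand and the constant the right one):

.. code-block:: python

    # 1 - a  ->  F.neg(a) + 1          2 * b  ->  b * 2
    # a - 2  ->  a + (-2)              1 / c  ->  F.pow(c, -1) * 1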
| @@ -0,0 +1,298 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from collections import OrderedDict, defaultdict | |||||
| from copy import deepcopy | |||||
| from typing import Any, Dict, List, Set | |||||
| from ... import functional as F | |||||
| from ... import module as M | |||||
| from ...core.ops.builtin import GetVarShape | |||||
| from ...logger import get_logger | |||||
| from ...tensor import Parameter, Tensor | |||||
| from ..expr import ( | |||||
| Expr, | |||||
| is_apply_def, | |||||
| is_call_function, | |||||
| is_call_module, | |||||
| is_call_tensor_method, | |||||
| is_constant, | |||||
| is_getattr, | |||||
| ) | |||||
| from ..traced_module import InternalGraph | |||||
| from ..utils import assign_attr, get_subattr | |||||
| from .matcher import PatternMatcher | |||||
| from .pass_base import BackwardPass, register_pass | |||||
| from .pattern import is_const, is_op, is_var | |||||
| from .utils import get_const_value | |||||
| logger = get_logger(__name__) | |||||
| @register_pass("BackwardFoldScale") | |||||
| class BackwardFoldScale(BackwardPass): | |||||
| r"""Backward fold const scaling into weights of conv2d. | |||||
| For example, the following code | |||||
| .. code-block:: | |||||
| x = conv(x, w, b) | |||||
| x = relu(x) | |||||
| x1 = x + 3 | |||||
| x2 = x + 4 | |||||
| y = (x1 + x2) * 3 | |||||
| will be changed to | |||||
| .. code-block:: | |||||
| x = conv(x, w * 3, b * 3) | |||||
| x = relu(x) | |||||
| x1 = x + 9 | |||||
| x2 = x + 12 | |||||
| y = x1 + x2 | |||||
| """ | |||||
| name = "BackwardFoldScale" | |||||
| required_pass = ["AttrToConstant", "NormElemWise"] | |||||
| run_once = True | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| # todo: support more axes | |||||
| self.scale_message = OrderedDict() | |||||
| self.used_names = defaultdict(int) | |||||
| def run_transform(self, expr: Expr) -> Expr: | |||||
| if expr not in self.scale_message: | |||||
| return expr | |||||
| var = is_var().check_users(False) | |||||
| mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg) | |||||
| add_const_pattern = var + is_const() | var + "*" | |||||
| conv_pattern = is_op(F.conv2d) | is_op(M.Conv2d) | |||||
| pattern = conv_pattern | add_const_pattern | mul_const_pattern | |||||
| matcher = PatternMatcher() | |||||
| if not matcher.match(pattern, expr): | |||||
| return expr | |||||
| matcher_exprs = matcher.matched_exprs | |||||
| if conv_pattern in matcher_exprs: | |||||
| return self.fold_conv_mul(expr) | |||||
| if mul_const_pattern in matcher_exprs: | |||||
| return self.fold_mul(expr) | |||||
| if add_const_pattern in matcher_exprs: | |||||
| return self.fold_add_mul(expr) | |||||
| return expr | |||||
| def fold_add_mul(self, expr: Expr): | |||||
| if self.scale_message[expr] is None: | |||||
| return expr | |||||
| scale = self.scale_message[expr] | |||||
| if len(expr.inputs) == 1: | |||||
| const = expr.const_val[0][-1] | |||||
| else: | |||||
| const = get_const_value(expr.inputs[1]) | |||||
| const = const * scale | |||||
| inp_node = expr.inputs[0] | |||||
| graph = expr.top_graph | |||||
| with graph.insert_exprs(): | |||||
| add_node = inp_node + const | |||||
| graph.replace_node({expr.outputs[0]: add_node}) | |||||
| graph.compile() | |||||
| add_node.name = expr.outputs[0].name | |||||
| return add_node.expr | |||||
| def fold_mul(self, expr: Expr): | |||||
| if self.scale_message[expr] is None: | |||||
| return expr | |||||
| graph = expr.top_graph | |||||
| graph.replace_node({expr.outputs[0]: expr.inputs[0]}) | |||||
| graph.compile() | |||||
| return expr | |||||
| def fold_conv_mul(self, expr: Expr): | |||||
| graph = expr.top_graph | |||||
| scale = self.scale_message[expr] | |||||
| if scale is None: | |||||
| return expr | |||||
| if is_call_function(expr, F.conv2d): | |||||
| named_args = expr.named_args | |||||
| weight = get_const_value(named_args["weight"], named_args["weight"]) * scale | |||||
| bias = get_const_value(named_args["bias"], named_args["bias"]) * scale | |||||
| named_args["weight"] = weight | |||||
| named_args["bias"] = bias | |||||
| with graph.insert_exprs(): | |||||
| out_node = F.conv2d(**named_args) | |||||
| graph.replace_node({expr.outputs[0]: out_node}) | |||||
| graph.compile() | |||||
| out_node.name = expr.outputs[0].name | |||||
| return out_node.expr | |||||
| else: | |||||
| mnode = expr.inputs[0] | |||||
| attr_name = expr.inputs[0].expr.name | |||||
| graph = expr.top_graph | |||||
| if len(mnode.users) > 1: | |||||
| self.used_names[mnode.qualname] += 1 | |||||
| attr_name = "{}_{}".format(attr_name, self.used_names[mnode.qualname]) | |||||
| logger.warning( | |||||
| "{} is used {} times and its name will be reset to {}.{}".format( | |||||
| mnode.qualname, len(mnode.users), graph.qualname, attr_name | |||||
| ) | |||||
| ) | |||||
| conv_module = mnode.owner | |||||
| if len(mnode.users) > 1: | |||||
| conv_module = deepcopy(conv_module) | |||||
| conv_module._name = None | |||||
| conv_module.weight = Parameter(conv_module.weight * scale) | |||||
| if conv_module.bias is not None: | |||||
| conv_module.bias = Parameter(conv_module.bias * scale) | |||||
| if len(mnode.users) > 1: | |||||
| self_node = mnode.expr.inputs[0] | |||||
| assign_attr(conv_module, self_node.owner, attr_name) | |||||
| with graph.insert_exprs(mnode.expr): | |||||
| new_conv_node = get_subattr(self_node, attr_name) | |||||
| expr.replace_inputs({mnode: new_conv_node}) | |||||
| return expr | |||||
| def reset_expr_message_to_none( | |||||
| self, expr: Expr, scale_message: Dict[Expr, Any], skip_exprs: Set[Expr], | |||||
| ): | |||||
| if expr in skip_exprs: | |||||
| return | |||||
| scale_message[expr] = None | |||||
| if is_call_function(expr, F.conv2d) or is_call_module(expr, M.Conv2d): | |||||
| return | |||||
| for out_node in expr.outputs: | |||||
| for user in out_node.users: | |||||
| if user in scale_message: | |||||
| self.reset_expr_message_to_none(user, scale_message, skip_exprs) | |||||
| def before_visit_graph(self, graph: InternalGraph): | |||||
| var = is_var().check_users(False) | |||||
| mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg) | |||||
| relu_pattern = ( | |||||
| is_op(F.relu) | is_op(M.ReLU) | is_op(F.leaky_relu) | is_op(M.LeakyReLU) | |||||
| ) | |||||
| # The parameters of conv must be const; dynamic conv is not supported. | |||||
| conv_pattern = ( | |||||
| is_op(F.conv2d)(var, is_const(), is_const()) | |||||
| | is_op(F.conv2d)(var, is_const()) | |||||
| | is_op(M.Conv2d) | |||||
| ) | |||||
| pattern = mul_const_pattern | relu_pattern | conv_pattern | |||||
| for op in [ | |||||
| "__add__", | |||||
| F.reshape, | |||||
| "reshape", | |||||
| F.transpose, | |||||
| "tranpose", | |||||
| F.min, | |||||
| "min", | |||||
| F.max, | |||||
| "max", | |||||
| F.max_pool2d, | |||||
| M.MaxPool2d, | |||||
| F.avg_pool2d, | |||||
| M.AvgPool2d, | |||||
| F.adaptive_avg_pool2d, | |||||
| M.AdaptiveAvgPool2d, | |||||
| F.adaptive_max_pool2d, | |||||
| M.AdaptiveMaxPool2d, | |||||
| F.expand_dims, | |||||
| F.concat, | |||||
| "__getitem__", | |||||
| ]: | |||||
| pattern |= is_op(op) | |||||
| matcher = PatternMatcher() | |||||
| scale_message = OrderedDict() | |||||
| mem_conv_scale_message = OrderedDict() | |||||
| skip_exprs = self.init_skip_exprs(graph) | |||||
| for expr in reversed(graph._exprs): | |||||
| if expr in skip_exprs: | |||||
| continue | |||||
| if len(expr.outputs) > 1 or not matcher.match(pattern, expr): | |||||
| self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||||
| if is_call_function(expr, F.conv2d): | |||||
| for user in expr.outputs[0].users: | |||||
| self.reset_expr_message_to_none(user, scale_message, skip_exprs) | |||||
| continue | |||||
| matched_exprs = matcher.matched_exprs | |||||
| const = None | |||||
| if mul_const_pattern in matched_exprs: | |||||
| if is_call_function(expr, F.neg): | |||||
| const = -1 | |||||
| elif len(expr.inputs) == 1: | |||||
| const = expr.const_val[0][-1] | |||||
| else: | |||||
| const = get_const_value(expr.inputs[1]) | |||||
| if isinstance(const, Tensor) and const._tuple_shape not in [(1,), tuple()]: | |||||
| self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||||
| continue | |||||
| users_const = [ | |||||
| scale_message[e] for e in expr.outputs[0].users if e not in skip_exprs | |||||
| ] | |||||
| if len(users_const) == 0: | |||||
| scale_message[expr] = const | |||||
| continue | |||||
| if any(c is None or c != users_const[0] for c in users_const): | |||||
| self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||||
| scale_message[expr] = const | |||||
| continue | |||||
| const = 1 if const is None else const | |||||
| const = const * users_const[0] | |||||
| if relu_pattern in matched_exprs and const < 0: | |||||
| self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||||
| continue | |||||
| if conv_pattern in matched_exprs: | |||||
| self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||||
| mem_conv_scale_message[expr] = const | |||||
| continue | |||||
| scale_message[expr] = const | |||||
| self.scale_message.update(scale_message) | |||||
| self.scale_message.update(mem_conv_scale_message) | |||||
| def init_skip_exprs(self, graph: InternalGraph): | |||||
| skip_exprs = set() | |||||
| for expr in graph._exprs: | |||||
| if is_apply_def(expr, GetVarShape): | |||||
| skip_exprs.add(expr) | |||||
| elif is_call_tensor_method(expr, "__getitem__") and expr in skip_exprs: | |||||
| skip_exprs.add(expr) | |||||
| elif is_getattr(expr): | |||||
| skip_exprs.add(expr) | |||||
| elif is_constant(expr): | |||||
| skip_exprs.add(expr) | |||||
| elif all(n.expr in skip_exprs for n in expr.inputs): | |||||
| skip_exprs.add(expr) | |||||
| return skip_exprs | |||||
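A quick arithmetic check of the docstring example above (an illustration, not output of the pass), writing o = relu(conv(x, w, b)):

.. code-block:: python

    # Original:  y = ((o + 3) + (o + 4)) * 3 = 6*o + 21
    # Folded:    relu(conv(x, 3*w, 3*b)) = 3*o, since relu commutes with a positive
    #            scale, so y = (3*o + 9) + (3*o + 12) = 6*o + 21, the same result.
    # A negative scale would not commute with relu, which is why propagation stops
    # whenever the accumulated constant is negative (`const < 0` above).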
| @@ -0,0 +1,248 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import operator | |||||
| from collections import defaultdict | |||||
| from typing import Any, Callable, List | |||||
| from ... import functional as F | |||||
| from ... import module as M | |||||
| from ...logger import get_logger | |||||
| from ...tensor import Parameter, Tensor | |||||
| from ...utils.bn_fusion import fold_weight_bias | |||||
| from ..expr import Expr, is_call_function | |||||
| from ..utils import assign_attr, get_subattr | |||||
| from .matcher import PatternMatcher | |||||
| from .pass_base import BackwardPass, register_pass | |||||
| from .pattern import ExprPattern, any_node, is_const, is_op, is_var | |||||
| from .utils import get_const_value, register_obj | |||||
| logger = get_logger(__name__) | |||||
| @register_pass("FuseAddMul") | |||||
| class FuseAddMul(BackwardPass): | |||||
| """Fold adjacent const add or mul binary operations. | |||||
| For example, the following code | |||||
| .. code-block:: | |||||
| x = x + 1 | |||||
| x = 2 + x | |||||
| x = x * 4 | |||||
| x = x * 0.25 | |||||
| will be changed to | |||||
| .. code-block:: | |||||
| x = x + 3 | |||||
| """ | |||||
| name = "FuseAddMul" | |||||
| required_pass = ["NormElemWise"] | |||||
| run_once = False | |||||
| def __init__(self,): | |||||
| super().__init__() | |||||
| def _make_pattern(op_0, op_1) -> ExprPattern: | |||||
| x = is_var().check_users(False) | |||||
| if op_0 not in [operator.add, operator.mul]: | |||||
| op_0 = is_op(op_0) | |||||
| if op_1 not in [operator.add, operator.mul]: | |||||
| op_1 = is_op(op_1) | |||||
| pattern = op_0(x, is_const()) | op_0(x, "*") | |||||
| pattern = op_1(pattern, is_const()) | op_1(pattern, "*") | |||||
| return pattern | |||||
| self.pattern_dict = {} | |||||
| for op, func in zip([operator.add, F.pow], [self.fold_add, self.fold_pow],): | |||||
| self.pattern_dict[_make_pattern(op, op)] = func | |||||
| for op_0 in [F.neg, operator.mul]: | |||||
| for op_1 in [F.neg, operator.mul]: | |||||
| self.pattern_dict[_make_pattern(op_0, op_1)] = self.fold_mul | |||||
| def run_transform(self, expr: Expr): | |||||
| matcher = PatternMatcher() | |||||
| for pattern, func in self.pattern_dict.items(): | |||||
| res = matcher.match(pattern, expr) | |||||
| if res: | |||||
| break | |||||
| if not res: | |||||
| return expr | |||||
| return func(expr) | |||||
| def _fold_helper(self, expr: Expr, op_c: Callable, op_t: Callable): | |||||
| const_0 = self.get_const_value(expr) | |||||
| # todo: support more shapes | |||||
| if isinstance(const_0, Tensor) and const_0._tuple_shape not in [(1,), tuple()]: | |||||
| return expr | |||||
| const_1 = self.get_const_value(expr.inputs[0].expr) | |||||
| if isinstance(const_1, Tensor) and const_1._tuple_shape not in [(1,), tuple()]: | |||||
| return expr | |||||
| inp_node = expr.inputs[0].expr.inputs[0] | |||||
| const = op_c(const_0, const_1) | |||||
| graph = expr.top_graph | |||||
| if (const == 1 and op_t in [operator.pow, operator.mul]) or ( | |||||
| const == 0 and op_t in [operator.add] | |||||
| ): | |||||
| graph.replace_node({expr.outputs[0]: inp_node}) | |||||
| graph.compile() | |||||
| return expr | |||||
| with expr.top_graph.insert_exprs(): | |||||
| out_node = op_t(inp_node, const) | |||||
| graph.replace_node({expr.outputs[0]: out_node}) | |||||
| graph.compile() | |||||
| return out_node.expr | |||||
| def fold_add(self, expr: Expr): | |||||
| return self._fold_helper(expr, operator.add, operator.add) | |||||
| def fold_mul(self, expr): | |||||
| return self._fold_helper(expr, operator.mul, operator.mul) | |||||
| def fold_pow(self, expr): | |||||
| return self._fold_helper(expr, operator.mul, F.pow) | |||||
| def get_const_value(self, expr: Expr): | |||||
| if is_call_function(expr, F.neg): | |||||
| return -1 | |||||
| if len(expr.inputs) == 2: | |||||
| value = get_const_value(expr.inputs[1].expr, None) | |||||
| assert value is not None, "the second input of {} must be a constant".format(expr) | |||||
| return value | |||||
| value = expr.const_val[0][-1] | |||||
| return value | |||||
| @register_pass("FuseConvBn") | |||||
| class FuseConvBn(BackwardPass): | |||||
| r"""Fuse BN layers into conv2d.""" | |||||
| name = "FuseConvBn" | |||||
| required_pass = ["AttrToConstant"] | |||||
| run_once = True | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.used_name = defaultdict(int) | |||||
| def run_transform(self, expr: Expr): | |||||
| conv_pat_0 = is_op(M.Conv2d) | |||||
| conv_pat_1 = is_op(F.conv2d) | |||||
| bn_pat_0 = is_op(M.BatchNorm2d)(conv_pat_0 | conv_pat_1) | |||||
| bn_pat_1 = is_op(F.batch_norm) | |||||
| # inp, running_mean, running_var, weight, bias | |||||
| bn_inps = ( | |||||
| conv_pat_0 | conv_pat_1, | |||||
| is_const(), | |||||
| is_const(), | |||||
| is_const(), | |||||
| is_const(), | |||||
| ) | |||||
| bn_pat = ( | |||||
| (bn_pat_1(*bn_inps[:3])) | |||||
| | (bn_pat_1(*bn_inps[:4])) | |||||
| | (bn_pat_1(*bn_inps)) | |||||
| | bn_pat_0 | |||||
| ) | |||||
| matcher = PatternMatcher() | |||||
| if not matcher.match(bn_pat, expr): | |||||
| return expr | |||||
| matched_exprs = matcher.matched_exprs | |||||
| if conv_pat_0 in matched_exprs: | |||||
| return self.fold_convm_bn(matched_exprs[conv_pat_0], matched_exprs[bn_pat]) | |||||
| else: | |||||
| return self.fold_convf_bn(matched_exprs[conv_pat_1], matched_exprs[bn_pat]) | |||||
| def fold_convm_bn(self, conv: Expr, bn: Expr): | |||||
| mnode, inp_node = conv.inputs[:2] | |||||
| self_node = mnode.expr.inputs[0] | |||||
| attr_name = conv.inputs[0].expr.name | |||||
| graph = conv.top_graph | |||||
| if len(mnode.users) > 1: | |||||
| self.used_name[mnode.qualname] += 1 | |||||
| attr_name = "{}_{}".format(attr_name, self.used_name[mnode.qualname]) | |||||
| logger.warning( | |||||
| "{} is used {} times and its name will be reset to {}.{}".format( | |||||
| mnode.qualname, len(mnode.users), graph.qualname, attr_name | |||||
| ) | |||||
| ) | |||||
| conv_module = mnode.owner | |||||
| weight, bias = conv_module.weight, conv_module.bias | |||||
| mean, var, gamma, beta, eps = self.get_bn_params(bn) | |||||
| weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps) | |||||
| new_conv = M.Conv2d( | |||||
| in_channels=conv_module.in_channels, | |||||
| out_channels=conv_module.out_channels, | |||||
| kernel_size=conv_module.kernel_size, | |||||
| stride=conv_module.stride, | |||||
| padding=conv_module.padding, | |||||
| dilation=conv_module.dilation, | |||||
| groups=conv_module.groups, | |||||
| bias=conv_module.bias is not None, | |||||
| conv_mode=conv_module.conv_mode, | |||||
| compute_mode=conv_module.compute_mode, | |||||
| name=conv_module.name, | |||||
| ) | |||||
| new_conv.weight = Parameter(weight) | |||||
| new_conv.bias = Parameter(bias) | |||||
| new_conv.training = conv_module.training | |||||
| assign_attr(new_conv, self_node.owner, attr_name) | |||||
| with graph.insert_exprs(mnode.expr): | |||||
| out_node = get_subattr(self_node, attr_name)(inp_node) | |||||
| graph.replace_node({bn.outputs[0]: out_node}) | |||||
| graph.compile() | |||||
| out_node.name = conv.outputs[0].name | |||||
| return out_node.expr | |||||
| def fold_convf_bn(self, conv: Expr, bn: Expr): | |||||
| named_args = conv.named_args | |||||
| weight = get_const_value(named_args["weight"], named_args["weight"]) | |||||
| bias = get_const_value(named_args["bias"], named_args["bias"]) | |||||
| mean, var, gamma, beta, eps = self.get_bn_params(bn) | |||||
| weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps) | |||||
| named_args["weight"] = weight | |||||
| named_args["bias"] = bias | |||||
| graph = conv.top_graph | |||||
| with graph.insert_exprs(): | |||||
| out_node = F.conv2d(**named_args) | |||||
| graph.replace_node({bn.outputs[0]: out_node}) | |||||
| graph.compile() | |||||
| out_node.name = conv.outputs[0].name | |||||
| return out_node.expr | |||||
| def get_bn_params(self, bn: Expr): | |||||
| if is_call_function(bn): | |||||
| named_args = bn.named_args | |||||
| mean = get_const_value( | |||||
| named_args["running_mean"], named_args["running_mean"] | |||||
| ) | |||||
| var = get_const_value(named_args["running_var"], named_args["running_var"]) | |||||
| gamma = get_const_value(named_args["weight"], named_args["weight"]) | |||||
| beta = get_const_value(named_args["bias"], named_args["bias"]) | |||||
| eps = named_args["eps"] | |||||
| return mean, var, gamma, beta, eps | |||||
| else: | |||||
| bn_module = bn.inputs[0].owner | |||||
| mean = bn_module.running_mean | |||||
| var = bn_module.running_var | |||||
| gamma = bn_module.weight | |||||
| beta = bn_module.bias | |||||
| eps = bn_module.eps | |||||
| return mean, var, gamma, beta, eps | |||||
| @@ -0,0 +1,190 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import copy | |||||
| from abc import abstractmethod | |||||
| from collections import OrderedDict, namedtuple | |||||
| from functools import partial | |||||
| from typing import Any, Callable, Dict, Iterable, List, Union | |||||
| from ...logger import get_logger | |||||
| from ..expr import Expr | |||||
| from ..traced_module import InternalGraph, TracedModule | |||||
| from .utils import register_obj | |||||
| logger = get_logger(__name__) | |||||
| class PassContext: | |||||
| def __init__( | |||||
| self, disabled_pass: Iterable[str] = None, pass_config: Dict[str, Any] = None | |||||
| ): | |||||
| self._disabled_pass = set() | |||||
| self._config = pass_config | |||||
| self._handle = None | |||||
| if disabled_pass: | |||||
| self.add_disabled_pass(disabled_pass) | |||||
| def add_disabled_pass(self, passes: Iterable[str]): | |||||
| if isinstance(passes, str): | |||||
| passes = [passes] | |||||
| for pas in passes: | |||||
| self._disabled_pass.add(pas) | |||||
| def pass_enabled(self, pas: Union["BasePass", str]): | |||||
| pass_name = pas.name if isinstance(pas, BasePass) else pas | |||||
| return pass_name not in self._disabled_pass | |||||
| _default_context = PassContext() | |||||
| def get_default_pass_context(): | |||||
| return _default_context | |||||
| _pass_dict = OrderedDict() | |||||
| register_pass = partial(register_obj, _dict=_pass_dict) | |||||
| def get_registered_pass(pass_name: str): | |||||
| pas = _pass_dict.get(pass_name, None) | |||||
| assert ( | |||||
| pas is not None | |||||
| ), "{} is not found, please call `register_pass` to register it".format(pass_name) | |||||
| return pas | |||||
| class BasePass: | |||||
| run_once = True # bool | |||||
| required_pass = [] # Iterable[str] | |||||
| name = "" # str | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| def __call__( | |||||
| self, mod: TracedModule, pass_ctx: PassContext = get_default_pass_context() | |||||
| ) -> TracedModule: | |||||
| assert isinstance(pass_ctx, PassContext) | |||||
| return self.apply_optimization(mod, pass_ctx) | |||||
| def apply_optimization( | |||||
| self, mod: TracedModule, pass_ctx: PassContext | |||||
| ) -> TracedModule: | |||||
| new_mod = mod | |||||
| for pass_name in self.required_pass + [self.name]: | |||||
| if not pass_ctx.pass_enabled(pass_name): | |||||
| logger.warning( | |||||
| "Since {} is disabled, {} will skipped".format(pass_name, self.name) | |||||
| ) | |||||
| return mod | |||||
| for pass_name in self.required_pass: | |||||
| pass_func = get_registered_pass(pass_name)() | |||||
| new_mod = pass_func(new_mod, pass_ctx) | |||||
| iter_num = 1 | |||||
| graph_changed = self.visit_graph(new_mod.graph) | |||||
| while not self.run_once and graph_changed: | |||||
| graph_changed = self.visit_graph(new_mod.graph) | |||||
| iter_num += 1 | |||||
| if iter_num == 100: | |||||
| break | |||||
| assert iter_num < 100, "{} was run 100 times, please check for pass conflicts.".format(self.name) | |||||
| return new_mod | |||||
| @abstractmethod | |||||
| def visit_graph(self, graph: InternalGraph): | |||||
| raise NotImplementedError | |||||
| def before_visit_graph(self, graph: InternalGraph): | |||||
| pass | |||||
| def run_transform(self, expr: Expr) -> Expr: | |||||
| return expr | |||||
| def __repr__(self) -> str: | |||||
| return self.name | |||||
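A minimal usage sketch (the import path is assumed from this patch's layout, and ``tm`` stands for a ``TracedModule`` produced by ``trace_module``):

.. code-block:: python

    from megengine.traced_module._passes.pass_base import (  # path assumed
        PassContext,
        get_registered_pass,
    )

    ctx = PassContext(disabled_pass=["FixInputShape"])  # keep shapes dynamic
    fuse = get_registered_pass("FuseConvBn")()          # its required passes run first
    tm = fuse(tm, ctx)                                  # returns the rewritten module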
| class ForwardPass(BasePass): | |||||
| def visit_graph(self, graph: InternalGraph): | |||||
| class Item: | |||||
| def __init__(self, expr: Expr, child_expanded: bool = False): | |||||
| self.expr = expr | |||||
| self.child_expanded = child_expanded | |||||
| self.before_visit_graph(graph) | |||||
| graph_changed = False | |||||
| queue = [Item(n.expr) for n in graph.outputs] | |||||
| visited_expr, visited_graph = set(), set() | |||||
| while queue: | |||||
| item = queue[-1] | |||||
| if item.expr in visited_expr: | |||||
| queue.pop() | |||||
| elif item.child_expanded: | |||||
| if item.expr not in graph._exprs: | |||||
| queue.pop() | |||||
| continue | |||||
| new_expr = self.run_transform(item.expr) | |||||
| if new_expr is not item.expr: | |||||
| graph_changed = True | |||||
| assert new_expr not in visited_expr | |||||
| queue.append(Item(new_expr)) | |||||
| continue | |||||
| if ( | |||||
| hasattr(item.expr, "graph") | |||||
| and item.expr.graph is not None | |||||
| and item.expr.graph not in visited_graph | |||||
| ): | |||||
| graph_changed |= self.visit_graph(item.expr.graph) | |||||
| visited_graph.add(item.expr.graph) | |||||
| visited_expr.add(item.expr) | |||||
| else: | |||||
| item.child_expanded = True | |||||
| for i in item.expr.inputs: | |||||
| expr = i.expr | |||||
| if expr not in queue and expr not in visited_expr: | |||||
| queue.append(Item(expr)) | |||||
| return graph_changed | |||||
| class BackwardPass(BasePass): | |||||
| def visit_graph(self, graph: InternalGraph): | |||||
| self.before_visit_graph(graph) | |||||
| graph_changed = False | |||||
| queue = [n.expr for n in graph.outputs] | |||||
| visited_expr, visited_graph = set(), set() | |||||
| while queue: | |||||
| expr = queue.pop() | |||||
| if expr not in graph._exprs: | |||||
| continue | |||||
| new_expr = self.run_transform(expr) | |||||
| if new_expr is not expr: | |||||
| graph_changed = True | |||||
| queue.append(new_expr) | |||||
| continue | |||||
| else: | |||||
| visited_expr.add(expr) | |||||
| if ( | |||||
| hasattr(expr, "graph") | |||||
| and expr.graph is not None | |||||
| and expr.graph not in visited_graph | |||||
| ): | |||||
| graph_changed |= self.visit_graph(expr.graph) | |||||
| visited_graph.add(expr.graph) | |||||
| for i in expr.inputs: | |||||
| expr = i.expr | |||||
| if expr not in queue and expr not in visited_expr: | |||||
| queue.append(expr) | |||||
| return graph_changed | |||||
| @@ -13,7 +13,7 @@ import inspect | |||||
| import re | |||||
| import weakref | |||||
| from importlib import import_module | |||||
| from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union | |||||
| from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union | |||||
| from ..core._imperative_rt import OpDef | |||||
| from ..core._imperative_rt.core2 import Tensor as RawTensor | |||||
| @@ -50,20 +50,30 @@ def get_suffix_name(prefix: str, name: str): | |||||
| return matchd.group(1) | |||||
| def is_call_module(expr): | |||||
| def is_call_module(expr, module_cls: Module = None): | |||||
| return ( | |||||
| isinstance(expr, CallMethod) | |||||
| and isinstance(expr.inputs[0], ModuleNode) | |||||
| and expr.method == "__call__" | |||||
| ) | |||||
| ) and (module_cls is None or isinstance(expr.inputs[0].owner, module_cls)) | |||||
| def is_call_tensor_method(expr): | |||||
| return isinstance(expr, CallMethod) and not is_call_module(expr) | |||||
| def is_call_tensor_method(expr, method: Iterable[str] = None): | |||||
| if method and isinstance(method, str): | |||||
| method = (method,) | |||||
| return ( | |||||
| isinstance(expr, CallMethod) | |||||
| and not is_call_module(expr) | |||||
| and (method is None or any(expr.method == f for f in method)) | |||||
| ) | |||||
| def is_call_function(expr): | |||||
| return isinstance(expr, CallFunction) | |||||
| def is_call_function(expr, func: Iterable[Callable] = None): | |||||
| if func and not isinstance(func, Iterable): | |||||
| func = (func,) | |||||
| return isinstance(expr, CallFunction) and ( | |||||
| func is None or any(expr.func == f for f in func) | |||||
| ) | |||||
| def is_constant(expr): | |||||
| @@ -74,8 +84,8 @@ def is_getattr(expr): | |||||
| return isinstance(expr, GetAttr) | |||||
| def is_apply_def(expr): | |||||
| return isinstance(expr, Apply) | |||||
| def is_apply_def(expr, opdef=None): | |||||
| return isinstance(expr, Apply) and (opdef is None or isinstance(expr.opdef, opdef)) | |||||
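Hedged examples of the extended predicates (``expr`` is assumed to come from a traced graph; ``F`` and ``M`` stand for ``megengine.functional`` and ``megengine.module``):

.. code-block:: python

    is_call_function(expr, F.conv2d)         # only F.conv2d calls
    is_call_function(expr, (F.relu, F.neg))  # any function in the iterable
    is_call_tensor_method(expr, "reshape")   # a single method name is wrapped
    is_call_module(expr, M.Conv2d)           # CallMethod on a Conv2d instance
    is_apply_def(expr, GetVarShape)          # Apply expr of the given OpDef type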
| def is_input(expr): | |||||
| @@ -78,6 +78,7 @@ class Node: | |||||
| "The name(%s) is already in use. Please try a different one again." | "The name(%s) is already in use. Please try a different one again." | ||||
| % (new_name) | % (new_name) | ||||
| ) | ) | ||||
| graph._namespace.unassociate_name_with_obj(self) | |||||
| self._name = graph._namespace.create_unique_name(new_name, self) | |||||
| @property | |||||
| @@ -14,6 +14,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Uni | |||||
| from .. import get_logger | |||||
| from ..module import Module | |||||
| from ..tensor import Parameter, Tensor | |||||
| logger = get_logger(__name__) | |||||
| @@ -301,3 +302,26 @@ class _ModuleDict(Module, MutableMapping): | |||||
| def forward(self): | |||||
| raise RuntimeError("ModuleList is not callable") | |||||
| def assign_attr(obj: Union[Module, Tensor], module: Module, target: str): | |||||
| *prefix, name = target.split(".") | |||||
| for item in prefix: | |||||
| module = getattr(module, item) | |||||
| if not isinstance(module, Module): | |||||
| raise AttributeError("`{}` is not a Module".format(item)) | |||||
| setattr(module, name, obj) | |||||
| def get_subattr(module: Module, target: str): | |||||
| # todo : remove this import | |||||
| from .node import ModuleNode | |||||
| if target == "": | |||||
| return module | |||||
| *prefix, name = target.split(".") | |||||
| for item in prefix: | |||||
| module = getattr(module, item) | |||||
| if not isinstance(module, (Module, ModuleNode)): | |||||
| raise AttributeError("`{}` is not a Module".format(item)) | |||||
| return getattr(module, name) | |||||
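A usage sketch of the two helpers (the network ``net``, the dotted path, and the replacement module are hypothetical; ``M`` stands for ``megengine.module``):

.. code-block:: python

    conv = get_subattr(net, "backbone.stage1.conv")    # walk the dotted path
    wider = M.Conv2d(conv.in_channels, 2 * conv.out_channels, conv.kernel_size)
    assign_attr(wider, net, "backbone.stage1.conv")    # setattr on the parent module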
| @@ -0,0 +1,86 @@ | |||||
| from copy import deepcopy | |||||
| from ..functional import ones, sqrt, zeros | |||||
| from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU | |||||
| from ..tensor import Parameter | |||||
| _MAP_TO_FUSED_MODULE = { | |||||
| (Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d, | |||||
| (Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d, | |||||
| (Conv2d, BatchNorm2d, False): Conv2d, | |||||
| (Conv2d, BatchNorm2d, True): ConvBn2d, | |||||
| (Conv2d, ReLU): ConvRelu2d, | |||||
| } | |||||
| def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5): | |||||
| # get fold bn conv param | |||||
| kernel_shape = weight.shape | |||||
| if len(kernel_shape) == 5: | |||||
| groups, num_features = kernel_shape[0], kernel_shape[1] | |||||
| else: | |||||
| groups, num_features = 1, kernel_shape[0] | |||||
| if gamma is None: | |||||
| gamma = ones((num_features), dtype="float32") | |||||
| gamma = gamma.reshape(1, -1, 1, 1) | |||||
| if beta is None: | |||||
| beta = zeros((num_features), dtype="float32") | |||||
| beta = beta.reshape(1, -1, 1, 1) | |||||
| if bn_mean is None: | |||||
| bn_mean = zeros((1, num_features, 1, 1), dtype="float32") | |||||
| if bn_var is None: | |||||
| bn_var = ones((1, num_features, 1, 1), dtype="float32") | |||||
| if bias is None: | |||||
| bias = zeros((1, num_features, 1, 1), dtype="float32") | |||||
| bn_istd = 1.0 / sqrt(bn_var + eps) | |||||
| scale_factor = gamma * bn_istd | |||||
| if groups == 1: | |||||
| w_fold = weight * scale_factor.reshape(-1, 1, 1, 1) | |||||
| else: | |||||
| w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1) | |||||
| b_fold = beta + gamma * (bias - bn_mean) * bn_istd | |||||
| return w_fold, b_fold | |||||
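The fold implements ``w' = w * gamma / sqrt(var + eps)`` and ``b' = beta + gamma * (b - mean) / sqrt(var + eps)``. A small sanity-check sketch, using the ``ones``/``zeros`` imported at the top of this file and identity statistics, which should leave the parameters unchanged:

.. code-block:: python

    w = ones((8, 3, 3, 3), dtype="float32")        # OIHW weight, groups == 1
    b = zeros((1, 8, 1, 1), dtype="float32")
    gamma = ones((8,), dtype="float32")
    beta = zeros((8,), dtype="float32")
    mean = zeros((1, 8, 1, 1), dtype="float32")
    var = ones((1, 8, 1, 1), dtype="float32")

    w_fold, b_fold = fold_weight_bias(w, b, gamma, beta, mean, var, eps=0.0)
    # With gamma=1, beta=0, mean=0, var=1 and eps=0 the fold is a no-op:
    # w_fold equals w and b_fold equals b.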
| def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU): | |||||
| module_key = tuple([type(m) for m in [conv, bn, relu] if m]) | |||||
| if bn: | |||||
| assert ( | |||||
| conv.training == bn.training | |||||
| ), "Conv and BN both must be in the same mode (train or eval)." | |||||
| assert ( | |||||
| bn.num_features == conv.out_channels | |||||
| ), "Output channel of Conv2d must match num_features of BatchNorm2d" | |||||
| module_key = module_key + (conv.training,) | |||||
| module = _MAP_TO_FUSED_MODULE[module_key]( | |||||
| in_channels=conv.in_channels, | |||||
| out_channels=conv.out_channels, | |||||
| kernel_size=conv.kernel_size, | |||||
| stride=conv.stride, | |||||
| padding=conv.padding, | |||||
| dilation=conv.dilation, | |||||
| groups=conv.groups, | |||||
| bias=conv.bias is not None, | |||||
| conv_mode=conv.conv_mode, | |||||
| compute_mode=conv.compute_mode, | |||||
| name=conv.name, | |||||
| ) | |||||
| new_conv = module if bn is None or not conv.training else module.conv | |||||
| weight, bias = conv.weight, conv.bias | |||||
| if not conv.training and bn is not None: | |||||
| weight, bias = fold_weight_bias( | |||||
| weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps, | |||||
| ) | |||||
| new_conv.weight = Parameter(weight) | |||||
| if bias is not None: | |||||
| new_conv.bias = Parameter(bias) | |||||
| if bn is not None and conv.training: | |||||
| module.bn = deepcopy(bn) | |||||
| new_conv.training = conv.training | |||||
| return module | |||||
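A usage sketch for the module-level fusion helper (the import path and module configuration are assumptions; in eval mode the BN statistics are folded into the conv parameters):

.. code-block:: python

    import megengine.module as M
    from megengine.utils.bn_fusion import fuse_conv_bn_relu_module  # path assumed

    conv, bn = M.Conv2d(3, 8, 3), M.BatchNorm2d(8)
    conv.eval()
    bn.eval()
    fused = fuse_conv_bn_relu_module(conv, bn, None)
    # (Conv2d, BatchNorm2d, training=False) maps to a plain Conv2d with folded weights.
    assert isinstance(fused, M.Conv2d)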
| @@ -13,20 +13,7 @@ import megengine.quantization as Q | |||||
| from megengine import Tensor | |||||
| from megengine.module.qat.module import QATModule | |||||
| from megengine.traced_module import TracedModule, trace_module | |||||
| def get_subattr(self: M.Module, name: str): | |||||
| if name == "": | |||||
| return self | |||||
| module_path, _, name = name.rpartition(".") | |||||
| if module_path == "": | |||||
| return getattr(self, name) | |||||
| module_names = module_path.split(".") | |||||
| for item in module_names: | |||||
| self = getattr(self, item) | |||||
| if not isinstance(self, M.Module): | |||||
| raise AttributeError("`{}` is not an Module".format(item)) | |||||
| return getattr(self, name) | |||||
| from megengine.traced_module.utils import get_subattr | |||||
| class MyConvBnRelu2d(M.ConvBnRelu2d): | |||||