| @@ -0,0 +1,183 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ... import functional as F | |||
| from ... import module as M | |||
| from ...core.ops.builtin import GetVarShape | |||
| from ...logger import get_logger | |||
| from ...tensor import Tensor | |||
| from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr | |||
| from ..node import Node, TensorNode | |||
| from .matcher import PatternMatcher | |||
| from .pass_base import BackwardPass, ForwardPass, register_pass | |||
| from .pattern import is_op | |||
| from .utils import get_const_value | |||
| logger = get_logger(__name__) | |||
| @register_pass("AttrToConstant") | |||
| class AttrToConstant(BackwardPass): | |||
| r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr.""" | |||
| name = "AttrToConstant" | |||
| run_once = True | |||
| def run_transform(self, expr: Expr): | |||
| if not (is_getattr(expr) and isinstance(expr.outputs[0], TensorNode)): | |||
| return expr | |||
| graph = expr.top_graph | |||
| value = get_const_value(expr) | |||
| orig_node = expr.outputs[0] | |||
| name = orig_node.name | |||
| with graph.insert_exprs(expr): | |||
| const_node = Constant.make(value, name=name) | |||
| graph.replace_node({orig_node: const_node}) | |||
| graph.compile() | |||
| const_node.name = orig_node.name | |||
| return const_node.expr | |||
| @register_pass("FixInputShape") | |||
| class FixInputShape(BackwardPass): | |||
| name = "FixInputShape" | |||
| run_once = True | |||
| def run_transform(self, expr: Expr): | |||
| if not is_apply_def(expr, GetVarShape): | |||
| return expr | |||
| shape = Tensor(expr.inputs[0].shape, dtype="int32") | |||
| graph = expr.top_graph | |||
| with graph.insert_exprs(expr): | |||
| const_shape = Constant.make(shape) | |||
| graph.replace_node({expr.outputs[0]: const_shape}) | |||
| graph.compile() | |||
| const_shape.name = expr.outputs[0].name | |||
| return const_shape.expr | |||
| @register_pass("FlodConstant") | |||
| class FlodConstant(ForwardPass): | |||
| r"""Constant folding.""" | |||
| name = "FlodConstant" | |||
| required_pass = ["AttrToConstant"] | |||
| run_once = False | |||
| def run_transform(self, expr: Expr): | |||
| if len(expr.inputs) == 0 or any(not is_constant(n.expr) for n in expr.inputs): | |||
| return expr | |||
| const_var = expr.interpret(*[get_const_value(n.expr) for n in expr.inputs])[0] | |||
| graph = expr.top_graph | |||
| with graph.insert_exprs(expr): | |||
| const_node = Constant.make(const_var) | |||
| graph.replace_node({expr.outputs[0]: const_node}) | |||
| graph.compile() | |||
| const_node.name = expr.outputs[0].name | |||
| return const_node.expr | |||
| @register_pass("NormElemWise") | |||
| class NormElemWise(BackwardPass): | |||
| r"""Transform add/sub or mul/div expr to add-only or mul-only chains. | |||
| For example, the following code | |||
| .. code-block:: | |||
| b = 1 - a | |||
| c = 2 * b | |||
| d = 1 / c | |||
| will be changed to | |||
| .. code-block:: | |||
| a1 = F.neg(a) | |||
| b = a1 + 1 | |||
| c = b * 2 | |||
| d = F.pow(c, -1) | |||
| """ | |||
| name = "NormElemWise" | |||
| required_pass = ["FlodConstant"] | |||
| run_once = False | |||
| def __init__(self,): | |||
| super().__init__() | |||
| self.pattern = is_op(F.add) | |||
| for op in [F.sub, F.mul, F.div]: | |||
| self.pattern |= is_op(op) | |||
| for op in ["__add__", "__iadd__", "__radd__"]: | |||
| self.pattern |= is_op(op) | |||
| for op in ["__sub__", "__isub__", "__rsub__"]: | |||
| self.pattern |= is_op(op) | |||
| for op in ["__mul__", "__imul__", "__rmul__"]: | |||
| self.pattern |= is_op(op) | |||
| for op in ["__truediv__", "__itruediv__", "__rtruediv__"]: | |||
| self.pattern |= is_op(op) | |||
| def run_transform(self, expr: Expr): | |||
| matcher = PatternMatcher() | |||
| if not matcher.match(self.pattern, expr): | |||
| return expr | |||
| pattern = matcher.matched_patterns[0] | |||
| target = pattern.target | |||
| coef, left_node, right_node = 1, None, None | |||
| if len(expr.inputs) == 1 and target not in ["__add__", "__mul__"]: | |||
| left_node = expr.inputs[0] | |||
| right_node = expr.const_val[0][-1] | |||
| if target in ["__rsub__", "__rtruediv__"]: | |||
| coef = -1 | |||
| if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]: | |||
| coef = -1 | |||
| elif len(expr.inputs) == 2 and ( | |||
| target not in ["__add__", "__mul__"] or is_constant(expr.inputs[0].expr) | |||
| ): | |||
| left_node, right_node = expr.inputs | |||
| if target in ["__rsub__", "__rtruediv__"]: | |||
| left_node, right_node = right_node, left_node | |||
| if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]: | |||
| left_node, right_node = right_node, left_node | |||
| if is_constant(left_node.expr): | |||
| left_node, right_node = right_node, left_node | |||
| coef = -1 | |||
| if left_node is None: | |||
| return expr | |||
| if isinstance(right_node, TensorNode): | |||
| right_node = get_const_value(right_node.expr, right_node) | |||
| graph = expr.top_graph | |||
| with graph.insert_exprs(): | |||
| if target in ["__mul__", "__imul__", "__rmul__", F.mul]: | |||
| out_node = left_node * right_node | |||
| elif target in ["__add__", "__iadd__", "__radd__", F.add]: | |||
| out_node = left_node + right_node | |||
| elif target in ["__sub__", "__isub__", "__rsub__", F.sub]: | |||
| if coef == -1: | |||
| left_node = F.neg(left_node) | |||
| else: | |||
| if isinstance(right_node, TensorNode): | |||
| right_node = F.neg(right_node) | |||
| else: | |||
| right_node = -1 * right_node | |||
| out_node = left_node + right_node | |||
| elif target in ["__truediv__", "__itruediv__", "__rtruediv__", F.div]: | |||
| if coef == -1: | |||
| left_node = F.pow(left_node, -1) | |||
| else: | |||
| if isinstance(right_node, TensorNode): | |||
| right_node = F.pow(right_node, -1) | |||
| else: | |||
| right_node = 1 / right_node | |||
| out_node = left_node * right_node | |||
| graph.replace_node({expr.outputs[0]: out_node}) | |||
| graph.compile() | |||
| return out_node.expr | |||
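| A quick numerical check of the rewrite rules described in the ``NormElemWise`` docstring above; this is only a sketch that assumes megengine and numpy are installed and does not exercise the pass machinery itself. | |||
| .. code-block:: python | |||
|     import numpy as np | |||
|     import megengine.functional as F | |||
|     from megengine import Tensor | |||
|     a = Tensor(np.array([0.25, 0.5], dtype="float32")) | |||
|     d = 1 / (2 * (1 - a))                   # original sub/div chain | |||
|     d_norm = F.pow((F.neg(a) + 1) * 2, -1)  # normalized add-only / mul-only form | |||
|     np.testing.assert_allclose(d.numpy(), d_norm.numpy(), rtol=1e-6, atol=1e-6) | |||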
| @@ -0,0 +1,298 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from collections import OrderedDict, defaultdict | |||
| from copy import deepcopy | |||
| from typing import Any, Dict, List, Set | |||
| from ... import functional as F | |||
| from ... import module as M | |||
| from ...core.ops.builtin import GetVarShape | |||
| from ...logger import get_logger | |||
| from ...tensor import Parameter, Tensor | |||
| from ..expr import ( | |||
| Expr, | |||
| is_apply_def, | |||
| is_call_function, | |||
| is_call_module, | |||
| is_call_tensor_method, | |||
| is_constant, | |||
| is_getattr, | |||
| ) | |||
| from ..traced_module import InternalGraph | |||
| from ..utils import assign_attr, get_subattr | |||
| from .matcher import PatternMatcher | |||
| from .pass_base import BackwardPass, register_pass | |||
| from .pattern import is_const, is_op, is_var | |||
| from .utils import get_const_value | |||
| logger = get_logger(__name__) | |||
| @register_pass("BackwardFoldScale") | |||
| class BackwardFoldScale(BackwardPass): | |||
| r"""Backward fold const scaling into weights of conv2d. | |||
| For example, the following code | |||
| .. code-block:: | |||
| x = conv(x, w, b) | |||
| x = relu(x) | |||
| x1 = x + 3 | |||
| x2 = x + 4 | |||
| y = (x1 + x2) * 3 | |||
| will be changed to | |||
| .. code-block:: | |||
| x = conv(x, w * 3, b * 3) | |||
| x = relu(x) | |||
| x1 = x + 9 | |||
| x2 = x + 12 | |||
| y = x1 + x2 | |||
| """ | |||
| name = "BackwardFoldScale" | |||
| required_pass = ["AttrToConstant", "NormElemWise"] | |||
| run_once = True | |||
| def __init__(self): | |||
| super().__init__() | |||
| # TODO: support more axes | |||
| self.scale_message = OrderedDict() | |||
| self.used_names = defaultdict(int) | |||
| def run_transform(self, expr: Expr) -> Expr: | |||
| if expr not in self.scale_message: | |||
| return expr | |||
| var = is_var().check_users(False) | |||
| mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg) | |||
| add_const_pattern = var + is_const() | var + "*" | |||
| conv_pattern = is_op(F.conv2d) | is_op(M.Conv2d) | |||
| pattern = conv_pattern | add_const_pattern | mul_const_pattern | |||
| matcher = PatternMatcher() | |||
| if not matcher.match(pattern, expr): | |||
| return expr | |||
| matched_exprs = matcher.matched_exprs | |||
| if conv_pattern in matched_exprs: | |||
| return self.fold_conv_mul(expr) | |||
| if mul_const_pattern in matched_exprs: | |||
| return self.fold_mul(expr) | |||
| if add_const_pattern in matched_exprs: | |||
| return self.fold_add_mul(expr) | |||
| return expr | |||
| def fold_add_mul(self, expr: Expr): | |||
| if self.scale_message[expr] is None: | |||
| return expr | |||
| scale = self.scale_message[expr] | |||
| if len(expr.inputs) == 1: | |||
| const = expr.const_val[0][-1] | |||
| else: | |||
| const = get_const_value(expr.inputs[1]) | |||
| const = const * scale | |||
| inp_node = expr.inputs[0] | |||
| graph = expr.top_graph | |||
| with graph.insert_exprs(): | |||
| add_node = inp_node + const | |||
| graph.replace_node({expr.outputs[0]: add_node}) | |||
| graph.compile() | |||
| add_node.name = expr.outputs[0].name | |||
| return add_node.expr | |||
| def fold_mul(self, expr: Expr): | |||
| if self.scale_message[expr] is None: | |||
| return expr | |||
| graph = expr.top_graph | |||
| graph.replace_node({expr.outputs[0]: expr.inputs[0]}) | |||
| graph.compile() | |||
| return expr | |||
| def fold_conv_mul(self, expr: Expr): | |||
| graph = expr.top_graph | |||
| scale = self.scale_message[expr] | |||
| if scale is None: | |||
| return expr | |||
| if is_call_function(expr, F.conv2d): | |||
| named_args = expr.named_args | |||
| weight = get_const_value(named_args["weight"], named_args["weight"]) * scale | |||
| bias = get_const_value(named_args["bias"], named_args["bias"]) * scale | |||
| named_args["weight"] = weight | |||
| named_args["bias"] = bias | |||
| with graph.insert_exprs(): | |||
| out_node = F.conv2d(**named_args) | |||
| graph.replace_node({expr.outputs[0]: out_node}) | |||
| graph.compile() | |||
| out_node.name = expr.outputs[0].name | |||
| return out_node.expr | |||
| else: | |||
| mnode = expr.inputs[0] | |||
| attr_name = expr.inputs[0].expr.name | |||
| graph = expr.top_graph | |||
| if len(mnode.users) > 1: | |||
| self.used_names[mnode.qualname] += 1 | |||
| attr_name = "{}_{}".format(attr_name, self.used_names[mnode.qualname]) | |||
| logger.warning( | |||
| "{} is used {} times and its name will be reset to {}.{}".format( | |||
| mnode.qualname, len(mnode.users), graph.qualname, attr_name | |||
| ) | |||
| ) | |||
| conv_module = mnode.owner | |||
| if len(mnode.users) > 1: | |||
| conv_module = deepcopy(conv_module) | |||
| conv_module._name = None | |||
| conv_module.weight = Parameter(conv_module.weight * scale) | |||
| if conv_module.bias is not None: | |||
| conv_module.bias = Parameter(conv_module.bias * scale) | |||
| if len(mnode.users) > 1: | |||
| self_node = mnode.expr.inputs[0] | |||
| assign_attr(conv_module, self_node.owner, attr_name) | |||
| with graph.insert_exprs(mnode.expr): | |||
| new_conv_node = get_subattr(self_node, attr_name) | |||
| expr.replace_inputs({mnode: new_conv_node}) | |||
| return expr | |||
| def reset_expr_message_to_none( | |||
| self, expr: Expr, scale_message: Dict[Expr, Any], skip_exprs: Set[Expr], | |||
| ): | |||
| if expr in skip_exprs: | |||
| return | |||
| scale_message[expr] = None | |||
| if is_call_function(expr, F.conv2d) or is_call_module(expr, M.Conv2d): | |||
| return | |||
| for out_node in expr.outputs: | |||
| for user in out_node.users: | |||
| if user in scale_message: | |||
| self.reset_expr_message_to_none(user, scale_message, skip_exprs) | |||
| def before_visit_graph(self, graph: InternalGraph): | |||
| var = is_var().check_users(False) | |||
| mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg) | |||
| relu_pattern = ( | |||
| is_op(F.relu) | is_op(M.ReLU) | is_op(F.leaky_relu) | is_op(M.LeakyReLU) | |||
| ) | |||
| # The conv parameters must be constants; dynamic conv is not supported | |||
| conv_pattern = ( | |||
| is_op(F.conv2d)(var, is_const(), is_const()) | |||
| | is_op(F.conv2d)(var, is_const()) | |||
| | is_op(M.Conv2d) | |||
| ) | |||
| pattern = mul_const_pattern | relu_pattern | conv_pattern | |||
| for op in [ | |||
| "__add__", | |||
| F.reshape, | |||
| "reshape", | |||
| F.transpose, | |||
| "tranpose", | |||
| F.min, | |||
| "min", | |||
| F.max, | |||
| "max", | |||
| F.max_pool2d, | |||
| M.MaxPool2d, | |||
| F.avg_pool2d, | |||
| M.AvgPool2d, | |||
| F.adaptive_avg_pool2d, | |||
| M.AdaptiveAvgPool2d, | |||
| F.adaptive_max_pool2d, | |||
| M.AdaptiveMaxPool2d, | |||
| F.expand_dims, | |||
| F.concat, | |||
| "__getitem__", | |||
| ]: | |||
| pattern |= is_op(op) | |||
| matcher = PatternMatcher() | |||
| scale_message = OrderedDict() | |||
| mem_conv_scale_message = OrderedDict() | |||
| skip_exprs = self.init_skip_exprs(graph) | |||
| for expr in reversed(graph._exprs): | |||
| if expr in skip_exprs: | |||
| continue | |||
| if len(expr.outputs) > 1 or not matcher.match(pattern, expr): | |||
| self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||
| if is_call_function(expr, F.conv2d): | |||
| for user in expr.outputs[0].users: | |||
| self.reset_expr_message_to_none(user, scale_message, skip_exprs) | |||
| continue | |||
| matched_exprs = matcher.matched_exprs | |||
| const = None | |||
| if mul_const_pattern in matched_exprs: | |||
| if is_call_function(expr, F.neg): | |||
| const = -1 | |||
| elif len(expr.inputs) == 1: | |||
| const = expr.const_val[0][-1] | |||
| else: | |||
| const = get_const_value(expr.inputs[1]) | |||
| if isinstance(const, Tensor) and const._tuple_shape not in [(1,), tuple()]: | |||
| self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||
| continue | |||
| users_const = [ | |||
| scale_message[e] for e in expr.outputs[0].users if e not in skip_exprs | |||
| ] | |||
| if len(users_const) == 0: | |||
| scale_message[expr] = const | |||
| continue | |||
| if any(c is None or c != users_const[0] for c in users_const): | |||
| self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||
| scale_message[expr] = const | |||
| continue | |||
| const = 1 if const is None else const | |||
| const = const * users_const[0] | |||
| if relu_pattern in matched_exprs and const < 0: | |||
| self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||
| continue | |||
| if conv_pattern in matched_exprs: | |||
| self.reset_expr_message_to_none(expr, scale_message, skip_exprs) | |||
| mem_conv_scale_message[expr] = const | |||
| continue | |||
| scale_message[expr] = const | |||
| self.scale_message.update(scale_message) | |||
| self.scale_message.update(mem_conv_scale_message) | |||
| def init_skip_exprs(self, graph: InternalGraph): | |||
| skip_exprs = set() | |||
| for expr in graph._exprs: | |||
| if is_apply_def(expr, GetVarShape): | |||
| skip_exprs.add(expr) | |||
| elif is_call_tensor_method(expr, "__getitem__") and expr in skip_exprs: | |||
| skip_exprs.add(expr) | |||
| elif is_getattr(expr): | |||
| skip_exprs.add(expr) | |||
| elif is_constant(expr): | |||
| skip_exprs.add(expr) | |||
| elif all(n.expr in skip_exprs for n in expr.inputs): | |||
| skip_exprs.add(expr) | |||
| return skip_exprs | |||
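| ``BackwardFoldScale`` relies on the linearity of ``conv2d`` and only pushes positive scales through relu (see the ``const < 0`` check above). A small sanity check of the underlying identity, again assuming only megengine and numpy: | |||
| .. code-block:: python | |||
|     import numpy as np | |||
|     import megengine.functional as F | |||
|     from megengine import Tensor | |||
|     x = Tensor(np.random.randn(1, 3, 8, 8).astype("float32")) | |||
|     w = Tensor(np.random.randn(4, 3, 3, 3).astype("float32")) | |||
|     b = Tensor(np.random.randn(1, 4, 1, 1).astype("float32")) | |||
|     y_scaled = F.relu(F.conv2d(x, w, b)) * 3      # scale applied after conv + relu | |||
|     y_folded = F.relu(F.conv2d(x, w * 3, b * 3))  # scale folded into weight and bias | |||
|     np.testing.assert_allclose(y_scaled.numpy(), y_folded.numpy(), rtol=1e-4, atol=1e-4) | |||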
| @@ -0,0 +1,248 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import operator | |||
| from collections import defaultdict | |||
| from typing import Any, Callable, List | |||
| from ... import functional as F | |||
| from ... import module as M | |||
| from ...logger import get_logger | |||
| from ...tensor import Parameter, Tensor | |||
| from ...utils.bn_fusion import fold_weight_bias | |||
| from ..expr import Expr, is_call_function | |||
| from ..utils import assign_attr, get_subattr | |||
| from .matcher import PatternMatcher | |||
| from .pass_base import BackwardPass, register_pass | |||
| from .pattern import ExprPattern, any_node, is_const, is_op, is_var | |||
| from .utils import get_const_value, register_obj | |||
| logger = get_logger(__name__) | |||
| @register_pass("FuseAddMul") | |||
| class FuseAddMul(BackwardPass): | |||
| """Fold adjacent const add or mul binary operations. | |||
| For example, the following code | |||
| .. code-block:: | |||
| x = x + 1 | |||
| x = 2 + x | |||
| x = x * 4 | |||
| x = x * 0.25 | |||
| will be changed to | |||
| .. code-block:: | |||
| x = x + 3 | |||
| """ | |||
| name = "FuseAddMul" | |||
| required_pass = ["NormElemWise"] | |||
| run_once = False | |||
| def __init__(self,): | |||
| super().__init__() | |||
| def _make_pattern(op_0, op_1) -> ExprPattern: | |||
| x = is_var().check_users(False) | |||
| if op_0 not in [operator.add, operator.mul]: | |||
| op_0 = is_op(op_0) | |||
| if op_1 not in [operator.add, operator.mul]: | |||
| op_1 = is_op(op_1) | |||
| pattern = op_0(x, is_const()) | op_0(x, "*") | |||
| pattern = op_1(pattern, is_const()) | op_1(pattern, "*") | |||
| return pattern | |||
| self.pattern_dict = {} | |||
| for op, func in zip([operator.add, F.pow], [self.fold_add, self.fold_pow],): | |||
| self.pattern_dict[_make_pattern(op, op)] = func | |||
| for op_0 in [F.neg, operator.mul]: | |||
| for op_1 in [F.neg, operator.mul]: | |||
| self.pattern_dict[_make_pattern(op_0, op_1)] = self.fold_mul | |||
| def run_transform(self, expr: Expr): | |||
| matcher = PatternMatcher() | |||
| for pattern, func in self.pattern_dict.items(): | |||
| res = matcher.match(pattern, expr) | |||
| if res: | |||
| break | |||
| if not res: | |||
| return expr | |||
| return func(expr) | |||
| def _fold_helper(self, expr: Expr, op_c: Callable, op_t: Callable): | |||
| const_0 = self.get_const_value(expr) | |||
| # TODO: support more shapes | |||
| if isinstance(const_0, Tensor) and const_0._tuple_shape not in [(1,), tuple()]: | |||
| return expr | |||
| const_1 = self.get_const_value(expr.inputs[0].expr) | |||
| if isinstance(const_1, Tensor) and const_1._tuple_shape not in [(1,), tuple()]: | |||
| return expr | |||
| inp_node = expr.inputs[0].expr.inputs[0] | |||
| const = op_c(const_0, const_1) | |||
| graph = expr.top_graph | |||
| if (const == 1 and op_t in [F.pow, operator.mul]) or ( | |||
| const == 0 and op_t in [operator.add] | |||
| ): | |||
| graph.replace_node({expr.outputs[0]: inp_node}) | |||
| graph.compile() | |||
| return expr | |||
| with expr.top_graph.insert_exprs(): | |||
| out_node = op_t(inp_node, const) | |||
| graph.replace_node({expr.outputs[0]: out_node}) | |||
| graph.compile() | |||
| return out_node.expr | |||
| def fold_add(self, expr: Expr): | |||
| return self._fold_helper(expr, operator.add, operator.add) | |||
| def fold_mul(self, expr): | |||
| return self._fold_helper(expr, operator.mul, operator.mul) | |||
| def fold_pow(self, expr): | |||
| return self._fold_helper(expr, operator.mul, F.pow) | |||
| def get_const_value(self, expr: Expr): | |||
| if is_call_function(expr, F.neg): | |||
| return -1 | |||
| if len(expr.inputs) == 2: | |||
| value = get_const_value(expr.inputs[1].expr, None) | |||
| assert value is not None, " " | |||
| return value | |||
| value = expr.const_val[0][-1] | |||
| return value | |||
| @register_pass("FuseConvBn") | |||
| class FuseConvBn(BackwardPass): | |||
| r"""Fuse BN layers into conv2d.""" | |||
| name = "FuseConvBn" | |||
| required_pass = ["AttrToConstant"] | |||
| run_once = True | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.used_name = defaultdict(int) | |||
| def run_transform(self, expr: Expr): | |||
| conv_pat_0 = is_op(M.Conv2d) | |||
| conv_pat_1 = is_op(F.conv2d) | |||
| bn_pat_0 = is_op(M.BatchNorm2d)(conv_pat_0 | conv_pat_1) | |||
| bn_pat_1 = is_op(F.batch_norm) | |||
| # inp, running_mean, running_var, weight, bias | |||
| bn_inps = ( | |||
| conv_pat_0 | conv_pat_1, | |||
| is_const(), | |||
| is_const(), | |||
| is_const(), | |||
| is_const(), | |||
| ) | |||
| bn_pat = ( | |||
| (bn_pat_1(*bn_inps[:3])) | |||
| | (bn_pat_1(*bn_inps[:4])) | |||
| | (bn_pat_1(*bn_inps)) | |||
| | bn_pat_0 | |||
| ) | |||
| matcher = PatternMatcher() | |||
| if not matcher.match(bn_pat, expr): | |||
| return expr | |||
| matched_exprs = matcher.matched_exprs | |||
| if conv_pat_0 in matched_exprs: | |||
| return self.fold_convm_bn(matched_exprs[conv_pat_0], matched_exprs[bn_pat]) | |||
| else: | |||
| return self.fold_convf_bn(matched_exprs[conv_pat_1], matched_exprs[bn_pat]) | |||
| def fold_convm_bn(self, conv: Expr, bn: Expr): | |||
| mnode, inp_node = conv.inputs[:2] | |||
| self_node = mnode.expr.inputs[0] | |||
| attr_name = conv.inputs[0].expr.name | |||
| graph = conv.top_graph | |||
| if len(mnode.users) > 1: | |||
| self.used_name[mnode.qualname] += 1 | |||
| attr_name = "{}_{}".format(attr_name, self.used_name[mnode.qualname]) | |||
| logger.warning( | |||
| "{} is used {} times and its name will be reset to {}.{}".format( | |||
| mnode.qualname, len(mnode.users), graph.qualname, attr_name | |||
| ) | |||
| ) | |||
| conv_module = mnode.owner | |||
| weight, bias = conv_module.weight, conv_module.bias | |||
| mean, var, gamma, beta, eps = self.get_bn_params(bn) | |||
| weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps) | |||
| new_conv = M.Conv2d( | |||
| in_channels=conv_module.in_channels, | |||
| out_channels=conv_module.out_channels, | |||
| kernel_size=conv_module.kernel_size, | |||
| stride=conv_module.stride, | |||
| padding=conv_module.padding, | |||
| dilation=conv_module.dilation, | |||
| groups=conv_module.groups, | |||
| bias=conv_module.bias is not None, | |||
| conv_mode=conv_module.conv_mode, | |||
| compute_mode=conv_module.compute_mode, | |||
| name=conv_module.name, | |||
| ) | |||
| new_conv.weight = Parameter(weight) | |||
| new_conv.bias = Parameter(bias) | |||
| new_conv.training = conv_module.training | |||
| assign_attr(new_conv, self_node.owner, attr_name) | |||
| with graph.insert_exprs(mnode.expr): | |||
| out_node = get_subattr(self_node, attr_name)(inp_node) | |||
| graph.replace_node({bn.outputs[0]: out_node}) | |||
| graph.compile() | |||
| out_node.name = conv.outputs[0].name | |||
| return out_node.expr | |||
| def fold_convf_bn(self, conv: Expr, bn: Expr): | |||
| named_args = conv.named_args | |||
| weight = get_const_value(named_args["weight"], named_args["weight"]) | |||
| bias = get_const_value(named_args["bias"], named_args["bias"]) | |||
| mean, var, gamma, beta, eps = self.get_bn_params(bn) | |||
| weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps) | |||
| named_args["weight"] = weight | |||
| named_args["bias"] = bias | |||
| graph = conv.top_graph | |||
| with graph.insert_exprs(): | |||
| out_node = F.conv2d(**named_args) | |||
| graph.replace_node({bn.outputs[0]: out_node}) | |||
| graph.compile() | |||
| out_node.name = conv.outputs[0].name | |||
| return out_node.expr | |||
| def get_bn_params(self, bn: Expr): | |||
| if is_call_function(bn): | |||
| named_args = bn.named_args | |||
| mean = get_const_value( | |||
| named_args["running_mean"], named_args["running_mean"] | |||
| ) | |||
| var = get_const_value(named_args["running_var"], named_args["running_var"]) | |||
| gamma = get_const_value(named_args["weight"], named_args["weight"]) | |||
| beta = get_const_value(named_args["bias"], named_args["bias"]) | |||
| eps = named_args["eps"] | |||
| return mean, var, gamma, beta, eps | |||
| else: | |||
| bn_module = bn.inputs[0].owner | |||
| mean = bn_module.running_mean | |||
| var = bn_module.running_var | |||
| gamma = bn_module.weight | |||
| beta = bn_module.bias | |||
| eps = bn_module.eps | |||
| return mean, var, gamma, beta, eps | |||
| @@ -0,0 +1,190 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import copy | |||
| from abc import abstractmethod | |||
| from collections import OrderedDict, namedtuple | |||
| from functools import partial | |||
| from typing import Any, Callable, Dict, Iterable, List, Union | |||
| from ...logger import get_logger | |||
| from ..expr import Expr | |||
| from ..traced_module import InternalGraph, TracedModule | |||
| from .utils import register_obj | |||
| logger = get_logger(__name__) | |||
| class PassContext: | |||
| def __init__( | |||
| self, disabled_pass: Iterable[str] = None, pass_config: Dict[str, Any] = None | |||
| ): | |||
| self._disabled_pass = set() | |||
| self._config = pass_config | |||
| self._handle = None | |||
| if disabled_pass: | |||
| self.add_disabled_pass(disabled_pass) | |||
| def add_disabled_pass(self, passes: Iterable[str]): | |||
| if isinstance(passes, str): | |||
| passes = [passes] | |||
| for pas in passes: | |||
| self._disabled_pass.add(pas) | |||
| def pass_enabled(self, pas: Union["BasePass", str]): | |||
| pass_name = pas.name if isinstance(pas, BasePass) else pas | |||
| return pass_name not in self._disabled_pass | |||
| _default_context = PassContext() | |||
| def get_default_pass_context(): | |||
| return _default_context | |||
| _pass_dict = OrderedDict() | |||
| register_pass = partial(register_obj, _dict=_pass_dict) | |||
| def get_registered_pass(pass_name: str): | |||
| pas = _pass_dict.get(pass_name, None) | |||
| assert ( | |||
| pas is not None | |||
| ), "{} is not found, please call `register_pass` to register it".format(pass_name) | |||
| return pas | |||
| class BasePass: | |||
| run_once = True # bool | |||
| required_pass = [] # Iterable[str] | |||
| name = "" # str | |||
| def __init__(self): | |||
| super().__init__() | |||
| def __call__( | |||
| self, mod: TracedModule, pass_ctx: PassContext = get_default_pass_context() | |||
| ) -> TracedModule: | |||
| assert isinstance(pass_ctx, PassContext) | |||
| return self.apply_optimization(mod, pass_ctx) | |||
| def apply_optimization( | |||
| self, mod: TracedModule, pass_ctx: PassContext | |||
| ) -> TracedModule: | |||
| new_mod = mod | |||
| for pass_name in self.required_pass + [self.name]: | |||
| if not pass_ctx.pass_enabled(pass_name): | |||
| logger.warning( | |||
| "Since {} is disabled, {} will skipped".format(pass_name, self.name) | |||
| ) | |||
| return mod | |||
| for pass_name in self.required_pass: | |||
| pass_func = get_registered_pass(pass_name)() | |||
| new_mod = pass_func(new_mod, pass_ctx) | |||
| iter_num = 1 | |||
| graph_changed = self.visit_graph(new_mod.graph) | |||
| while not self.run_once and graph_changed: | |||
| graph_changed = self.visit_graph(new_mod.graph) | |||
| iter_num += 1 | |||
| if iter_num == 100: | |||
| break | |||
| assert iter_num < 100, "{} was run 100 times, plase check for pass conflict." | |||
| return new_mod | |||
| @abstractmethod | |||
| def visit_graph(self, graph: InternalGraph): | |||
| raise NotImplementedError | |||
| def before_visit_graph(self, graph: InternalGraph): | |||
| pass | |||
| def run_transform(self, expr: Expr) -> Expr: | |||
| return expr | |||
| def __repr__(self) -> str: | |||
| return self.name | |||
| class ForwardPass(BasePass): | |||
| def visit_graph(self, graph: InternalGraph): | |||
| class Item: | |||
| def __init__(self, expr: Expr, child_expanded: bool = False): | |||
| self.expr = expr | |||
| self.child_expanded = child_expanded | |||
| self.before_visit_graph(graph) | |||
| graph_changed = False | |||
| queue = [Item(n.expr) for n in graph.outputs] | |||
| visited_expr, visited_graph = set(), set() | |||
| while queue: | |||
| item = queue[-1] | |||
| if item.expr in visited_expr: | |||
| queue.pop() | |||
| elif item.child_expanded: | |||
| if item.expr not in graph._exprs: | |||
| queue.pop() | |||
| continue | |||
| new_expr = self.run_transform(item.expr) | |||
| if new_expr is not item.expr: | |||
| graph_changed = True | |||
| assert new_expr not in visited_expr | |||
| queue.append(Item(new_expr)) | |||
| continue | |||
| if ( | |||
| hasattr(item.expr, "graph") | |||
| and item.expr.graph is not None | |||
| and item.expr.graph not in visited_graph | |||
| ): | |||
| graph_changed |= self.visit_graph(item.expr.graph) | |||
| visited_graph.add(item.expr.graph) | |||
| visited_expr.add(item.expr) | |||
| else: | |||
| item.child_expanded = True | |||
| for i in item.expr.inputs: | |||
| expr = i.expr | |||
| if expr not in queue and expr not in visited_expr: | |||
| queue.append(Item(expr)) | |||
| return graph_changed | |||
| class BackwardPass(BasePass): | |||
| def visit_graph(self, graph: InternalGraph): | |||
| self.before_visit_graph(graph) | |||
| graph_changed = False | |||
| queue = [n.expr for n in graph.outputs] | |||
| visited_expr, visited_graph = set(), set() | |||
| while queue: | |||
| expr = queue.pop() | |||
| if expr not in graph._exprs: | |||
| continue | |||
| new_expr = self.run_transform(expr) | |||
| if new_expr is not expr: | |||
| graph_changed = True | |||
| queue.append(new_expr) | |||
| continue | |||
| else: | |||
| visited_expr.add(expr) | |||
| if ( | |||
| hasattr(expr, "graph") | |||
| and expr.graph is not None | |||
| and expr.graph not in visited_graph | |||
| ): | |||
| graph_changed |= self.visit_graph(expr.graph) | |||
| visited_graph.add(expr.graph) | |||
| for i in expr.inputs: | |||
| expr = i.expr | |||
| if expr not in queue and expr not in visited_expr: | |||
| queue.append(expr) | |||
| return graph_changed | |||
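| A minimal sketch of how a registered pass is looked up and applied. The ``_passes`` subpackage name is an assumption (file paths are not shown in this diff); importing ``pass_base`` through it is expected to trigger the ``@register_pass`` decorators of the pass modules. | |||
| .. code-block:: python | |||
|     import numpy as np | |||
|     import megengine.module as M | |||
|     from megengine import Tensor | |||
|     from megengine.traced_module import trace_module | |||
|     # assumed location of this file -- adjust to the real package layout | |||
|     from megengine.traced_module._passes.pass_base import PassContext, get_registered_pass | |||
|     net = M.Sequential(M.Conv2d(3, 4, 3, padding=1), M.BatchNorm2d(4), M.ReLU()) | |||
|     net.eval() | |||
|     tm = trace_module(net, Tensor(np.random.randn(1, 3, 8, 8).astype("float32"))) | |||
|     ctx = PassContext(disabled_pass=["FixInputShape"])  # e.g. keep shape ops dynamic | |||
|     fuse = get_registered_pass("FuseConvBn")()          # look up and instantiate a pass | |||
|     tm = fuse(tm, ctx)                                   # required passes run first, then FuseConvBn | |||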
| @@ -13,7 +13,7 @@ import inspect | |||
| import re | |||
| import weakref | |||
| from importlib import import_module | |||
| from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union | |||
| from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union | |||
| from ..core._imperative_rt import OpDef | |||
| from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
| @@ -50,20 +50,30 @@ def get_suffix_name(prefix: str, name: str): | |||
| return matchd.group(1) | |||
| def is_call_module(expr): | |||
| def is_call_module(expr, module_cls: Module = None): | |||
| return ( | |||
| isinstance(expr, CallMethod) | |||
| and isinstance(expr.inputs[0], ModuleNode) | |||
| and expr.method == "__call__" | |||
| ) | |||
| ) and (module_cls is None or isinstance(expr.inputs[0].owner, module_cls)) | |||
| def is_call_tensor_method(expr): | |||
| return isinstance(expr, CallMethod) and not is_call_module(expr) | |||
| def is_call_tensor_method(expr, method: Iterable[str] = None): | |||
| if method and isinstance(method, str): | |||
| method = (method,) | |||
| return ( | |||
| isinstance(expr, CallMethod) | |||
| and not is_call_module(expr) | |||
| and (method is None or any(expr.method == f for f in method)) | |||
| ) | |||
| def is_call_function(expr): | |||
| return isinstance(expr, CallFunction) | |||
| def is_call_function(expr, func: Iterable[Callable] = None): | |||
| if func and not isinstance(func, Iterable): | |||
| func = (func,) | |||
| return isinstance(expr, CallFunction) and ( | |||
| func is None or any(expr.func == f for f in func) | |||
| ) | |||
| def is_constant(expr): | |||
| @@ -74,8 +84,8 @@ def is_getattr(expr): | |||
| return isinstance(expr, GetAttr) | |||
| def is_apply_def(expr): | |||
| return isinstance(expr, Apply) | |||
| def is_apply_def(expr, opdef=None): | |||
| return isinstance(expr, Apply) and (opdef is None or isinstance(expr.opdef, opdef)) | |||
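| The predicates now accept an optional filter (a module class for ``is_call_module``, functions for ``is_call_function``, method names for ``is_call_tensor_method``, an opdef for ``is_apply_def``). A small sketch of the intended use, assuming megengine is installed: | |||
| .. code-block:: python | |||
|     import numpy as np | |||
|     import megengine.functional as F | |||
|     import megengine.module as M | |||
|     from megengine import Tensor | |||
|     from megengine.traced_module import trace_module | |||
|     from megengine.traced_module.expr import is_call_function, is_call_module | |||
|     class Net(M.Module): | |||
|         def __init__(self): | |||
|             super().__init__() | |||
|             self.conv = M.Conv2d(3, 4, 3) | |||
|         def forward(self, x): | |||
|             return F.relu(self.conv(x)) | |||
|     tm = trace_module(Net(), Tensor(np.random.randn(1, 3, 8, 8).astype("float32"))) | |||
|     # pick out the conv module call and the functional relu call from the traced graph | |||
|     conv_exprs = [e for e in tm.graph._exprs if is_call_module(e, M.Conv2d)] | |||
|     relu_exprs = [e for e in tm.graph._exprs if is_call_function(e, F.relu)] | |||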
| def is_input(expr): | |||
| @@ -78,6 +78,7 @@ class Node: | |||
| "The name(%s) is already in use. Please try a different one again." | |||
| % (new_name) | |||
| ) | |||
| graph._namespace.unassociate_name_with_obj(self) | |||
| self._name = graph._namespace.create_unique_name(new_name, self) | |||
| @property | |||
| @@ -14,6 +14,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Uni | |||
| from .. import get_logger | |||
| from ..module import Module | |||
| from ..tensor import Parameter, Tensor | |||
| logger = get_logger(__name__) | |||
| @@ -301,3 +302,26 @@ class _ModuleDict(Module, MutableMapping): | |||
| def forward(self): | |||
| raise RuntimeError("ModuleList is not callable") | |||
| def assign_attr(obj: Union[Module, Tensor], module: Module, target: str): | |||
| *prefix, name = target.split(".") | |||
| for item in prefix: | |||
| module = getattr(module, item) | |||
| if not isinstance(module, Module): | |||
| raise AttributeError("`{}` is not an Module".format(item)) | |||
| setattr(module, name, obj) | |||
| def get_subattr(module: Module, target: str): | |||
| # TODO: remove this import | |||
| from .node import ModuleNode | |||
| if target == "": | |||
| return module | |||
| *prefix, name = target.split(".") | |||
| for item in prefix: | |||
| module = getattr(module, item) | |||
| if not isinstance(module, (Module, ModuleNode)): | |||
| raise AttributeError("`{}` is not an Module".format(item)) | |||
| return getattr(module, name) | |||
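| ``assign_attr`` and ``get_subattr`` walk dotted attribute paths on a module tree, which is how the passes above re-attach rewritten sub-modules. A small self-contained sketch (the import path matches the one used by the updated test below): | |||
| .. code-block:: python | |||
|     import megengine.module as M | |||
|     from megengine.traced_module.utils import assign_attr, get_subattr | |||
|     class Block(M.Module): | |||
|         def __init__(self): | |||
|             super().__init__() | |||
|             self.conv = M.Conv2d(3, 4, 3) | |||
|         def forward(self, x): | |||
|             return self.conv(x) | |||
|     class Net(M.Module): | |||
|         def __init__(self): | |||
|             super().__init__() | |||
|             self.block = Block() | |||
|         def forward(self, x): | |||
|             return self.block(x) | |||
|     net = Net() | |||
|     conv = get_subattr(net, "block.conv")               # same object as net.block.conv | |||
|     assign_attr(M.Conv2d(3, 4, 1), net, "block.conv")   # replace it in place | |||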
| @@ -0,0 +1,86 @@ | |||
| from copy import deepcopy | |||
| from ..functional import ones, sqrt, zeros | |||
| from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU | |||
| from ..tensor import Parameter | |||
| _MAP_TO_FUSED_MODULE = { | |||
| (Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d, | |||
| (Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d, | |||
| (Conv2d, BatchNorm2d, False): Conv2d, | |||
| (Conv2d, BatchNorm2d, True): ConvBn2d, | |||
| (Conv2d, ReLU): ConvRelu2d, | |||
| } | |||
| def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5): | |||
| # compute the conv weight and bias with the BN statistics folded in | |||
| kernel_shape = weight.shape | |||
| if len(kernel_shape) == 5: | |||
| groups, num_features = kernel_shape[0], kernel_shape[1] | |||
| else: | |||
| groups, num_features = 1, kernel_shape[0] | |||
| if gamma is None: | |||
| gamma = ones((num_features), dtype="float32") | |||
| gamma = gamma.reshape(1, -1, 1, 1) | |||
| if beta is None: | |||
| beta = zeros((num_features), dtype="float32") | |||
| beta = beta.reshape(1, -1, 1, 1) | |||
| if bn_mean is None: | |||
| bn_mean = zeros((1, num_features, 1, 1), dtype="float32") | |||
| if bn_var is None: | |||
| bn_var = ones((1, num_features, 1, 1), dtype="float32") | |||
| if bias is None: | |||
| bias = zeros((1, num_features, 1, 1), dtype="float32") | |||
| bn_istd = 1.0 / sqrt(bn_var + eps) | |||
| scale_factor = gamma * bn_istd | |||
| if groups == 1: | |||
| w_fold = weight * scale_factor.reshape(-1, 1, 1, 1) | |||
| else: | |||
| w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1) | |||
| b_fold = beta + gamma * (bias - bn_mean) * bn_istd | |||
| return w_fold, b_fold | |||
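| ``fold_weight_bias`` computes ``w_fold = w * gamma / sqrt(var + eps)`` and ``b_fold = beta + gamma * (b - mean) / sqrt(var + eps)``, so an inference-mode BN following a conv can be absorbed into the conv parameters. A small numerical check of that identity; it assumes megengine and numpy are installed and that this file lives at ``megengine/utils/bn_fusion.py`` (inferred from the relative imports elsewhere in this diff). | |||
| .. code-block:: python | |||
|     import numpy as np | |||
|     import megengine.functional as F | |||
|     from megengine import Tensor | |||
|     from megengine.utils.bn_fusion import fold_weight_bias | |||
|     oc, ic, eps = 4, 3, 1e-5 | |||
|     x = Tensor(np.random.randn(2, ic, 8, 8).astype("float32")) | |||
|     w = Tensor(np.random.randn(oc, ic, 3, 3).astype("float32")) | |||
|     b = Tensor(np.random.randn(1, oc, 1, 1).astype("float32")) | |||
|     gamma, beta, mean, var = ( | |||
|         Tensor((np.random.rand(1, oc, 1, 1) + 0.5).astype("float32")) for _ in range(4) | |||
|     ) | |||
|     y_ref = F.batch_norm(F.conv2d(x, w, b), mean, var, gamma, beta, training=False, eps=eps) | |||
|     w_fold, b_fold = fold_weight_bias(w, b, gamma, beta, mean, var, eps) | |||
|     np.testing.assert_allclose(y_ref.numpy(), F.conv2d(x, w_fold, b_fold).numpy(), rtol=1e-4, atol=1e-4) | |||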
| def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU): | |||
| module_key = tuple([type(m) for m in [conv, bn, relu] if m]) | |||
| if bn: | |||
| assert ( | |||
| conv.training == bn.training | |||
| ), "Conv and BN both must be in the same mode (train or eval)." | |||
| assert ( | |||
| bn.num_features == conv.out_channels | |||
| ), "Output channel of Conv2d must match num_features of BatchNorm2d" | |||
| module_key = module_key + (conv.training,) | |||
| module = _MAP_TO_FUSED_MODULE[module_key]( | |||
| in_channels=conv.in_channels, | |||
| out_channels=conv.out_channels, | |||
| kernel_size=conv.kernel_size, | |||
| stride=conv.stride, | |||
| padding=conv.padding, | |||
| dilation=conv.dilation, | |||
| groups=conv.groups, | |||
| bias=conv.bias is not None, | |||
| conv_mode=conv.conv_mode, | |||
| compute_mode=conv.compute_mode, | |||
| name=conv.name, | |||
| ) | |||
| new_conv = module if bn is None or not conv.training else module.conv | |||
| weight, bias = conv.weight, conv.bias | |||
| if not conv.training and bn is not None: | |||
| weight, bias = fold_weight_bias( | |||
| weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps, | |||
| ) | |||
| new_conv.weight = Parameter(weight) | |||
| if bias is not None: | |||
| new_conv.bias = Parameter(bias) | |||
| if bn is not None and conv.training: | |||
| module.bn = deepcopy(bn) | |||
| new_conv.training = conv.training | |||
| return module | |||
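| A minimal usage sketch of ``fuse_conv_bn_relu_module`` in eval mode; the ``megengine.utils.bn_fusion`` import path is inferred from this diff's relative imports rather than stated in it. | |||
| .. code-block:: python | |||
|     import numpy as np | |||
|     import megengine.module as M | |||
|     from megengine import Tensor | |||
|     from megengine.utils.bn_fusion import fuse_conv_bn_relu_module | |||
|     conv, bn, relu = M.Conv2d(3, 4, 3, padding=1), M.BatchNorm2d(4), M.ReLU() | |||
|     for m in (conv, bn, relu): | |||
|         m.eval() | |||
|     fused = fuse_conv_bn_relu_module(conv, bn, relu)  # -> ConvRelu2d with BN folded in | |||
|     x = Tensor(np.random.randn(1, 3, 8, 8).astype("float32")) | |||
|     np.testing.assert_allclose( | |||
|         relu(bn(conv(x))).numpy(), fused(x).numpy(), rtol=1e-4, atol=1e-4 | |||
|     ) | |||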
| @@ -13,20 +13,7 @@ import megengine.quantization as Q | |||
| from megengine import Tensor | |||
| from megengine.module.qat.module import QATModule | |||
| from megengine.traced_module import TracedModule, trace_module | |||
| def get_subattr(self: M.Module, name: str): | |||
| if name == "": | |||
| return self | |||
| module_path, _, name = name.rpartition(".") | |||
| if module_path == "": | |||
| return getattr(self, name) | |||
| module_names = module_path.split(".") | |||
| for item in module_names: | |||
| self = getattr(self, item) | |||
| if not isinstance(self, M.Module): | |||
| raise AttributeError("`{}` is not an Module".format(item)) | |||
| return getattr(self, name) | |||
| from megengine.traced_module.utils import get_subattr | |||
| class MyConvBnRelu2d(M.ConvBnRelu2d): | |||