@@ -8,6 +8,7 @@
from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from . import compat
from ._passes import optimize
from .traced_module import (
    TracedModule,
    _register_all_builtin_module,
@@ -19,3 +20,11 @@ from .traced_module import (
_register_all_builtin_module()
set_cpp_apply_module_trace(cpp_apply_module_trace)
__all__ = [
    "register_as_builtin",
    "trace_module",
    "wrap",
    "TracedModule",
    "optimize",
]
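For orientation, a minimal usage sketch of the exported API (the toy module is hypothetical and purely illustrative; the call pattern mirrors the tests added below):

    import numpy as np
    import megengine as mge
    import megengine.module as M
    import megengine.traced_module as tm

    # any M.Module works; a bare conv is just for illustration
    net = M.Conv2d(3, 3, 1)
    inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
    traced = tm.trace_module(net, inp)                # -> TracedModule
    optimized = tm.optimize(traced.flatten(), "FuseConvBn")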
@@ -0,0 +1,12 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from . import const_pass, fold_scale_pass, fuse_pass
from .optimization import optimize

__all__ = ["optimize"]
@@ -0,0 +1,70 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from typing import List

from ...logger import get_logger
from ..traced_module import TracedModule
from .pass_base import get_default_pass_context, get_registered_pass

logger = get_logger(__name__)
def optimize(
    module: TracedModule, enabled_pass: List[str] = ["FuseConvBn"],
) -> TracedModule:
    r"""Performs a set of optimization passes to optimize a `TracedModule` for inference.

    The following passes are currently supported:

    * FuseConvBn: fuse BatchNorm layers into the preceding conv2d
    * FuseAddMul: fold adjacent constant add or mul binary operations
    * BackwardFoldScale: fold constant scaling backward into the weights of conv2d

    Args:
        module: the :class:`TracedModule` to be optimized.
        enabled_pass: optimization passes to be enabled during optimization.
            Default: ["FuseConvBn"]

    Returns:
        the optimized :class:`TracedModule`.
    """
    default_passes_list = [
        "FuseConvBn",
        "FuseAddMul",
    ]
    if isinstance(enabled_pass, str):
        enabled_pass = [enabled_pass]
    else:
        # copy so that the append below never mutates the caller's list or
        # the shared default argument
        enabled_pass = list(enabled_pass)
    if "BackwardFoldScale" in enabled_pass:
        if "FuseConvBn" not in enabled_pass:
            logger.warning(
                "Since BackwardFoldScale requires FuseConvBn"
                ", FuseConvBn will be enabled."
            )
            enabled_pass.append("FuseConvBn")
        default_passes_list.extend(["BackwardFoldScale", "FuseAddMul"])
    pass_ctx = get_default_pass_context()

    def run_pass(mod: TracedModule):
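        # apply only the passes the user enabled, but always in the fixed
        # default order so that dependencies are respected (e.g. FuseConvBn
        # runs before BackwardFoldScale)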
        for pass_name in default_passes_list:
            if pass_name in enabled_pass:
                pass_func = get_registered_pass(pass_name)()
                mod = pass_func(mod, pass_ctx)
        return mod

    module = deepcopy(module)
    module = run_pass(module)
    return module
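To make the scheduling concrete, the effective pass order for a few calls, as implied by the code above:

    # optimize(m)                       runs: FuseConvBn
    # optimize(m, "BackwardFoldScale")  runs: FuseConvBn, BackwardFoldScale
    #   (FuseConvBn is auto-enabled; FuseAddMul appears in the order list
    #    but is skipped because it was not explicitly enabled)
    # optimize(m, ["BackwardFoldScale", "FuseAddMul"])
    #   runs: FuseConvBn, FuseAddMul, BackwardFoldScale, FuseAddMul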
@@ -0,0 +1,106 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.traced_module as tm
class myconv(M.Conv2d):
    pass


class mybn(M.BatchNorm2d):
    pass


class MyBlock(M.Module):
    def __init__(self, conv_cls, bn_cls):
        super().__init__()
        self.conv = conv_cls(3, 3, 1, 1, 0)
        self.bn = bn_cls(3)
        self.conv2 = conv_cls(3, 3, 1, 1, 0)
        self.bn2 = bn_cls(3)
        self.scale = mge.Tensor([3, 4])

    def forward(self, x):
        x1 = self.conv(x)
        x1 = self.bn(x1)
        x1 = F.relu(x1)
        x1 = x1 * self.scale[0]
        x2 = self.conv2(x)
        x2 = self.bn2(x2)
        x2 = F.relu(x2)
        x2 = x2 * self.scale[1]
        y = x1 + x2
        y = y + 4
        y = self.scale[0] + y
        y = F.relu(y) * 3
        return y


class MyModule(M.Module):
    def __init__(self, conv_cls, bn_cls):
        super().__init__()
        self.block_0 = MyBlock(conv_cls, bn_cls)
        self.block_1 = MyBlock(conv_cls, bn_cls)

    def forward(self, x):
        x1 = self.block_0(x)
        x2 = self.block_1(x)
        y = x1 + x2
        y = F.reshape(y, (-1))
        y = y * 3
        return y
| @pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv]) | |||
| @pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn]) | |||
| def test_backward_fold_scale(conv_cls, bn_cls): | |||
| module = MyModule(conv_cls, bn_cls) | |||
| module.eval() | |||
| inp = mge.Tensor(np.random.random((1, 3, 32, 32))) | |||
| desired = module(inp) | |||
| traced_net = tm.trace_module(module, inp) | |||
| traced_net = traced_net.flatten() | |||
| optimized_net = tm.optimize(traced_net, "BackwardFoldScale") | |||
| actual = optimized_net(inp) | |||
| np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4) | |||
| # fuse all mul to conv | |||
| mul_list = optimized_net.graph.get_method_by_type("__mul__").as_list() | |||
| assert len(mul_list) == 0 | |||
| @pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv]) | |||
| @pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn]) | |||
| def test_fuse_bn(conv_cls, bn_cls): | |||
| module = MyModule(conv_cls, bn_cls) | |||
| module.eval() | |||
| inp = mge.Tensor(np.random.random((1, 3, 32, 32))) | |||
| desired = module(inp) | |||
| traced_net = tm.trace_module(module, inp) | |||
| traced_net = traced_net.flatten() | |||
| optimized_net = tm.optimize(traced_net, "FuseConvBn") | |||
| actual = optimized_net(inp) | |||
| np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4) | |||
| # fuse all mul to conv | |||
| bn_list = optimized_net.graph.get_function_by_type(F.batch_norm).as_list() | |||
| assert len(bn_list) == 0 | |||
| bn_list = optimized_net.graph.get_module_by_type(M.BatchNorm2d).as_list() | |||
| assert len(bn_list) == 0 | |||
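The three graph lookups used above target different node kinds; a quick summary as a sketch (names taken from the tests, semantics inferred from their usage here):

    g = optimized_net.graph
    g.get_method_by_type("__mul__")       # expressions calling a tensor method, e.g. x * s
    g.get_function_by_type(F.batch_norm)  # expressions calling a megengine.functional function
    g.get_module_by_type(M.BatchNorm2d)   # call sites of child modules, matched by class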