diff --git a/imperative/python/megengine/experimental/__init__.py b/imperative/python/megengine/experimental/__init__.py
index 0253e7dc..19b1fc6b 100644
--- a/imperative/python/megengine/experimental/__init__.py
+++ b/imperative/python/megengine/experimental/__init__.py
@@ -6,3 +6,4 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .weight_scaler import get_scaled_model
diff --git a/imperative/python/megengine/experimental/weight_scaler.py b/imperative/python/megengine/experimental/weight_scaler.py
new file mode 100644
index 00000000..0908257b
--- /dev/null
+++ b/imperative/python/megengine/experimental/weight_scaler.py
@@ -0,0 +1,105 @@
+import types
+from functools import partial
+
+import megengine.functional as F
+import megengine.module as M
+from megengine.functional.tensor import zeros
+from megengine.utils.module_utils import set_module_mode_safe
+
+
+def get_norm_mod_value(weight, norm_value):
+    # Rescale factor norm_value / ||weight||_2, rounded down to a power of two.
+    weight = weight.reshape(-1)
+    norm = F.norm(weight)
+    scale = norm_value / norm
+    round_log = F.floor(F.log(scale) / F.log(2))
+    rounded_scale = 2 ** round_log
+    return rounded_scale.detach()
+
+
+def get_scaled_model(model, scale_submodel, input_shape=None):
+    submodule_list = None
+    scale_value = None
+    accumulated_scale = 1.0
+
+    # Wrap a module's compute function so weight/bias are scaled in training mode.
+    def scale_calc(mod_calc_func):
+        def calcfun(self, inp, weight, bias):
+            scaled_weight = weight
+            scaled_bias = bias
+            if self.training:
+                scaled_weight = (
+                    weight * self.weight_scale if weight is not None else None
+                )
+                scaled_bias = bias * self.bias_scale if bias is not None else None
+            return mod_calc_func(inp, scaled_weight, scaled_bias)
+
+        return calcfun
+
+    # Attach weight_scale/bias_scale to each module and patch its compute function.
+    def scale_module_structure(
+        scale_list: list = None, scale_value: tuple = None,
+    ):
+        nonlocal accumulated_scale
+        for i in range(len(scale_list)):
+            key, mod = scale_list[i]
+            w_scale_value = scale_value[1]
+            if scale_value[0] != "CONST":
+                w_scale_value = get_norm_mod_value(mod.weight, scale_value[1])
+
+            accumulated_scale *= w_scale_value
+
+            mod.weight_scale = w_scale_value
+            mod.bias_scale = accumulated_scale
+
+            if isinstance(mod, M.conv.Conv2d):
+                mod.calc_conv = types.MethodType(scale_calc(mod.calc_conv), mod)
+            else:
+                mod._calc_linear = types.MethodType(scale_calc(mod._calc_linear), mod)
+
+    # Pre-hook: scale matched modules directly; group children until a BatchNorm2d.
+    def forward_hook(submodel, inputs, outputs, modelname=""):
+        nonlocal submodule_list
+        nonlocal scale_value
+        nonlocal accumulated_scale
+        if modelname in scale_submodel:
+            scale_value = scale_submodel[modelname]
+            if isinstance(submodel, (M.conv.Conv2d, M.linear.Linear)):
+                scale_module_structure([(modelname, submodel)], scale_value)
+            else:
+                submodule_list = []
+
+        if isinstance(submodel, (M.conv.Conv2d, M.linear.Linear)) and (
+            submodule_list is not None
+        ):
+            submodule_list.append((modelname, submodel))
+
+        if isinstance(submodel, M.batchnorm.BatchNorm2d) and (
+            submodule_list is not None
+        ):
+            scale_module_structure(submodule_list, scale_value)
+            submodule_list = None
+            scale_value = None
+            accumulated_scale = 1.0
+
+    if input_shape is None:
+        raise ValueError("input_shape is required for calculating scale value")
+
+    input = zeros(input_shape)
+
+    # Register hooks, trigger them with one dummy eval-mode forward, then clean up.
+    hooks = []
+    for modelname, submodel in model.named_modules():
+        hooks.append(
+            submodel.register_forward_pre_hook(
+                partial(forward_hook, modelname=modelname, outputs=None)
+            )
+        )
+
+    with set_module_mode_safe(model, training=False) as model:
+        model(input)
+
+    for hook in hooks:
+        hook.remove()
+
+    return model
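
A note on the rounding in `get_norm_mod_value`: the raw rescale factor `norm_value / ||weight||_2` is snapped down to a power of two via `2 ** floor(log2(scale))`. For example, with `norm_value = 1.0` and a flattened weight whose L2 norm is 0.3, the raw scale is about 3.33, so the returned factor is `2 ** floor(log2(3.33)) = 2 ** 1 = 2`.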
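
Below is a minimal usage sketch, not part of the diff. The `ToyNet` model, the layer names `"conv"` and `"fc"`, and the `"NORM"` tag are illustrative assumptions: the code above only special-cases the `"CONST"` tag, so any other tag selects the norm-derived power-of-two factor. The spec format `{module_name: (tag, value)}` follows how `forward_hook` and `scale_module_structure` consume `scale_submodel`.

```python
import megengine.module as M

from megengine.experimental import get_scaled_model


class ToyNet(M.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.conv = M.Conv2d(3, 8, 3, padding=1)
        self.bn = M.BatchNorm2d(8)
        self.fc = M.Linear(8 * 32 * 32, 10)

    def forward(self, x):
        x = self.bn(self.conv(x))
        return self.fc(x.reshape(x.shape[0], -1))


model = ToyNet()
# "conv" gets a norm-derived power-of-two factor targeting an L2 norm of 1.0;
# "fc" gets the constant factor 2.0.
scaled = get_scaled_model(
    model,
    {"conv": ("NORM", 1.0), "fc": ("CONST", 2.0)},
    input_shape=(1, 3, 32, 32),
)
```

Note that the dummy forward pass runs in eval mode only to locate the targeted modules; the patched calc functions apply the scaling once the module is back in training mode.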