@@ -0,0 +1,100 @@
import types
from functools import partial

import megengine.functional as F
import megengine.module as M
from megengine.functional.tensor import zeros
from megengine.utils.module_utils import set_module_mode_safe


def get_norm_mod_value(weight, norm_value):
    # Compute a power-of-two factor approximating norm_value / ||weight||,
    # detached from the autodiff graph.
    weight = weight.reshape(-1)
    norm = F.norm(weight)
    scale = norm_value / norm
    round_log = F.floor(F.log(scale) / F.log(2))
    rounded_scale = 2 ** round_log
    return rounded_scale.detach()
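

# Worked example (illustrative): for a weight whose flattened L2 norm is 3.0
# and norm_value 1.0, scale = 1/3 and floor(log2(1/3)) = -2, so the returned
# factor is 2 ** -2 = 0.25.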


def get_scaled_model(model, scale_submodel, input_shape=None):
    submodule_list = None
    scale_value = None
    accumulated_scale = 1.0

    def scale_calc(mod_calc_func):
        # Wrap a module's compute function so that, in training mode, the
        # weight and bias are multiplied by the scale factors attached to
        # the module by scale_module_structure below.
        def calcfun(self, inp, weight, bias):
            scaled_weight = weight
            scaled_bias = bias
            if self.training:
                scaled_weight = (
                    weight * self.weight_scale if weight is not None else None
                )
                scaled_bias = bias * self.bias_scale if bias is not None else None
            return mod_calc_func(inp, scaled_weight, scaled_bias)

        return calcfun

    def scale_module_structure(scale_list: list, scale_value: tuple):
        # Attach weight/bias scale factors to every Conv2d / Linear module in
        # scale_list and patch its compute method with the scaling wrapper.
        nonlocal accumulated_scale
        for key, mod in scale_list:
            w_scale_value = scale_value[1]
            if scale_value[0] != "CONST":
                # Non-constant mode: derive the scale from the weight norm.
                w_scale_value = get_norm_mod_value(mod.weight, scale_value[1])

            # Bias scales accumulate over consecutively scaled modules.
            accumulated_scale *= w_scale_value

            mod.weight_scale = w_scale_value
            mod.bias_scale = accumulated_scale

            if isinstance(mod, M.conv.Conv2d):
                mod.calc_conv = types.MethodType(scale_calc(mod.calc_conv), mod)
            else:
                mod._calc_linear = types.MethodType(
                    scale_calc(mod._calc_linear), mod
                )

    def forward_hook(submodel, inputs, outputs, modelname=""):
        nonlocal submodule_list
        nonlocal scale_value
        nonlocal accumulated_scale
        if modelname in scale_submodel:
            scale_value = scale_submodel[modelname]
            if isinstance(submodel, (M.conv.Conv2d, M.linear.Linear)):
                # A Conv2d / Linear module named directly in scale_submodel
                # is scaled on its own.
                scale_module_structure([(modelname, submodel)], scale_value)
            else:
                # A container module: start collecting its Conv2d / Linear
                # children until the next BatchNorm2d is reached.
                submodule_list = []

        if isinstance(submodel, (M.conv.Conv2d, M.linear.Linear)) and (
            submodule_list is not None
        ):
            submodule_list.append((modelname, submodel))

        if isinstance(submodel, M.batchnorm.BatchNorm2d) and (
            submodule_list is not None
        ):
            # A BatchNorm2d closes the current group: scale the collected
            # modules and reset the collection state.
            scale_module_structure(submodule_list, scale_value)
            submodule_list = None
            scale_value = None
            accumulated_scale = 1.0

    if input_shape is None:
        raise ValueError("input_shape is required for calculating scale value")

    # Dummy input used to run the model once and trigger the pre-hooks.
    inp = zeros(input_shape)

    hooks = []
    for modelname, submodel in model.named_modules():
        hooks.append(
            submodel.register_forward_pre_hook(
                partial(forward_hook, modelname=modelname, outputs=None)
            )
        )

    # One forward pass in eval mode patches all matching submodules; the
    # hooks are removed afterwards.
    with set_module_mode_safe(model, training=False) as model:
        model(inp)

    for hook in hooks:
        hook.remove()

    return model
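

# Minimal usage sketch (not part of the original diff): SimpleNet and the
# scale configuration below are hypothetical, assuming scale_submodel maps
# qualified module names to (mode, value) tuples, where mode "CONST" applies
# value directly and any other mode derives a power-of-two scale from the
# weight norm via get_norm_mod_value.
if __name__ == "__main__":

    class SimpleNet(M.Module):
        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(3, 8, 3, padding=1)
            self.bn = M.BatchNorm2d(8)
            self.fc = M.Linear(8 * 32 * 32, 10)

        def forward(self, x):
            x = self.bn(self.conv(x))
            return self.fc(F.flatten(x, 1))

    net = SimpleNet()
    # "conv" gets a constant scale; "fc" gets a norm-derived one.
    scaled = get_scaled_model(
        net,
        {"conv": ("CONST", 2.0), "fc": ("NORM", 1.0)},
        input_shape=(1, 3, 32, 32),
    )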