GitOrigin-RevId: a15f17d616
tags/v1.11.0
| @@ -457,6 +457,7 @@ def module_stats( | |||
| log_activations = False | |||
| disable_receptive_field() | |||
| recorded_parameters = set() | |||
| def module_stats_hook(module, inputs, outputs, name=""): | |||
| class_name = str(module.__class__).split(".")[-1].split("'")[0] | |||
| @@ -468,17 +469,27 @@ def module_stats( | |||
| flops.append(flops_stats) | |||
| if cal_params: | |||
| if hasattr(module, "weight") and module.weight is not None: | |||
| if ( | |||
| hasattr(module, "weight") | |||
| and (module.weight is not None) | |||
| and module.weight not in recorded_parameters | |||
| ): | |||
| w = module.weight | |||
| param_stats = get_param_stats(w) | |||
| param_stats["name"] = name + "-w" | |||
| params.append(param_stats) | |||
| recorded_parameters.add(w) | |||
| if hasattr(module, "bias") and module.bias is not None: | |||
| if ( | |||
| hasattr(module, "bias") | |||
| and module.bias is not None | |||
| and module.bias not in recorded_parameters | |||
| ): | |||
| b = module.bias | |||
| param_stats = get_param_stats(b) | |||
| param_stats["name"] = name + "-b" | |||
| params.append(param_stats) | |||
| recorded_parameters.add(b) | |||
| if cal_activations: | |||
| if not isinstance(outputs, (tuple, list)): | |||
| @@ -504,7 +515,6 @@ def module_stats( | |||
| hooks.append( | |||
| module.register_forward_hook(partial(module_stats_hook, name=name)) | |||
| ) | |||
| with set_module_mode_safe(model, training=False) as model: | |||
| model(*inputs) | |||
| @@ -42,6 +42,65 @@ def test_other_input_module_state(): | |||
| net(_nt) | |||
| @pytest.mark.skipif( | |||
| use_symbolic_shape(), reason="This test do not support symbolic shape.", | |||
| ) | |||
| def test_duplicated_module(): | |||
| input_shape = (1, 3, 224, 224) | |||
| net0 = TestNet0() | |||
| net0_stats, _ = module_stats(net0, input_shapes=input_shape) | |||
| net1 = TestNet1() | |||
| net1_stats, _ = module_stats(net1, input_shapes=input_shape) | |||
| net2 = TestNet2() | |||
| net2_stats, _ = module_stats(net2, input_shapes=input_shape) | |||
| assert net0_stats.param_dims == net1_stats.param_dims | |||
| assert net0_stats.param_size == net1_stats.param_size | |||
| assert net0_stats.param_dims == net2_stats.param_dims | |||
| assert net0_stats.param_size == net2_stats.param_size | |||
| class TestNet0(M.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.conv = M.Conv2d(3, 3, 3, padding=(1, 1)) | |||
| self.conv.bias = mge.Parameter( | |||
| np.random.random(self.conv.bias.shape).astype(np.float32) | |||
| ) | |||
| def forward(self, x): | |||
| x = self.conv(x) | |||
| return x | |||
| class TestNet1(TestNet0): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.conv1 = self.conv | |||
| def forward(self, x): | |||
| x = self.conv(x) | |||
| x = self.conv1(x) | |||
| return x | |||
| class TestNet2(TestNet0): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.conv1 = M.Conv2d(3, 3, 3, padding=(1, 1)) | |||
| self.conv1.weight = self.conv.weight | |||
| self.conv1.bias = self.conv.bias | |||
| def forward(self, x): | |||
| x = self.conv(x) | |||
| x = self.conv1(x) | |||
| return x | |||
| class FakeNet(M.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||