| @@ -83,9 +83,7 @@ def disable_receptive_field(): | |||||
| _receptive_field_enabled = False | _receptive_field_enabled = False | ||||
| @register_flops( | |||||
| M.Conv1d, M.Conv2d, M.Conv3d, M.LocalConv2d, M.DeformableConv2d | |||||
| ) | |||||
| @register_flops(M.Conv1d, M.Conv2d, M.Conv3d, M.LocalConv2d, M.DeformableConv2d) | |||||
| def flops_convNd(module: M.Conv2d, inputs, outputs): | def flops_convNd(module: M.Conv2d, inputs, outputs): | ||||
| bias = 1 if module.bias is not None else 0 | bias = 1 if module.bias is not None else 0 | ||||
| # N x Cout x H x W x (Cin x Kw x Kh + bias) | # N x Cout x H x W x (Cin x Kw x Kh + bias) | ||||
| @@ -93,13 +91,16 @@ def flops_convNd(module: M.Conv2d, inputs, outputs): | |||||
| float(module.in_channels // module.groups) * np.prod(module.kernel_size) + bias | float(module.in_channels // module.groups) * np.prod(module.kernel_size) + bias | ||||
| ) | ) | ||||
| @register_flops(M.ConvTranspose2d) | @register_flops(M.ConvTranspose2d) | ||||
| def flops_convNdTranspose(module: M.Conv2d, inputs, outputs): | def flops_convNdTranspose(module: M.Conv2d, inputs, outputs): | ||||
| bias = 1 if module.bias is not None else 0 | bias = 1 if module.bias is not None else 0 | ||||
| # N x Cout x H x W x (Cin x Kw x Kh + bias) | # N x Cout x H x W x (Cin x Kw x Kh + bias) | ||||
| return np.prod(inputs[0].shape) * ( | |||||
| module.out_channels // module.groups * np.prod(module.kernel_size) | |||||
| ) + np.prod(outputs[0].shape) * bias | |||||
| return ( | |||||
| np.prod(inputs[0].shape) | |||||
| * (module.out_channels // module.groups * np.prod(module.kernel_size)) | |||||
| + np.prod(outputs[0].shape) * bias | |||||
| ) | |||||
| @register_flops( | @register_flops( | ||||