
conv_transpose_bn.py 7.1 kB

from ...functional import ones, relu, sqrt, sum, zeros
from .. import conv_transpose_bn as Float
from .module import QATModule


class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule):
    def get_batch_mean_var(self, inp):
        def _sum_channel(inp, axis=0, keepdims=True):
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        # per-channel batch statistics, reduced over the N, H and W axes
        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.size / inp.shape[1]
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
        return batch_mean, batch_var

    def fold_weight_bias(self, bn_mean, bn_var):
        # fold the BatchNorm statistics and affine parameters into the
        # conv_transpose2d weight and bias
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

        conv_transpose2d_bias = self.conv_transpose2d.bias
        if conv_transpose2d_bias is None:
            conv_transpose2d_bias = zeros(
                self.conv_transpose2d._infer_bias_shape(), dtype="float32"
            )

        # w_fold = w * gamma / sqrt(var + eps)
        # b_fold = beta + gamma * (bias - mean) / sqrt(var + eps)
        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        scale_factor = gamma * bn_istd
        # scale along the output-channel axis of the transposed-conv weight
        # (axis 1 when dense, axis 2 when grouped), matching
        # calc_conv_transpose2d_bn_qat below
        if self.conv_transpose2d.groups == 1:
            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1)
        else:
            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(
                self.conv_transpose2d.groups, 1, -1, 1, 1
            )
        w_fold = self.apply_quant_weight(w_fold)
        b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd
        return w_fold, b_fold

    def update_running_mean_and_running_var(
        self, bn_mean, bn_var, num_elements_per_channel
    ):
        # update running mean and running var. no grad, use unbiased bn var
        bn_mean = bn_mean.detach()
        bn_var = (
            bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1)
        )
        exponential_average_factor = 1 - self.bn.momentum
        self.bn.running_mean *= self.bn.momentum
        self.bn.running_mean += exponential_average_factor * bn_mean
        self.bn.running_var *= self.bn.momentum
        self.bn.running_var += exponential_average_factor * bn_var

    def calc_conv_transpose2d_bn_qat(self, inp, approx=True):
        # during training with approx=True (default), run the transposed conv with
        # folded, fake-quantized weights, rescale the output back and let the real
        # BatchNorm recompute batch statistics; otherwise (approx=False or eval),
        # fold the statistics fully into the weight and bias.
        if self.training and not approx:
            conv_transpose2d = self.conv_transpose2d(inp)
            bn_mean, bn_var = self.get_batch_mean_var(conv_transpose2d)
            num_elements_per_channel = conv_transpose2d.size / conv_transpose2d.shape[1]
            self.update_running_mean_and_running_var(
                bn_mean, bn_var, num_elements_per_channel
            )
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

        # get gamma and beta in BatchNorm
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)
        # conv_transpose2d bias
        conv_transpose2d_bias = self.conv_transpose2d.bias
        if conv_transpose2d_bias is None:
            conv_transpose2d_bias = zeros(
                self.conv_transpose2d._infer_bias_shape(), dtype="float32"
            )

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        scale_factor = gamma * bn_istd
        if self.conv_transpose2d.groups == 1:
            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1)
        else:
            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(
                self.conv_transpose2d.groups, 1, -1, 1, 1
            )
        b_fold = None
        if not (self.training and approx):
            b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd

        w_qat = self.apply_quant_weight(w_fold)
        b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
        conv_transpose2d = self.conv_transpose2d.calc_conv_transpose2d(inp, w_qat, b_qat)
        if not (self.training and approx):
            return conv_transpose2d

        # rescale conv_transpose2d to get the original conv_transpose2d output
        orig_conv_transpose2d = conv_transpose2d / scale_factor.reshape(1, -1, 1, 1)
        if self.conv_transpose2d.bias is not None:
            orig_conv_transpose2d = orig_conv_transpose2d + self.conv_transpose2d.bias
        # calculate batch norm
        conv_transpose2d = self.bn(orig_conv_transpose2d)
        return conv_transpose2d

    @classmethod
    def from_float_module(cls, float_module: Float._ConvTransposeBnActivation2d):
        r"""Return a :class:`~.QATModule` instance converted from a float
        :class:`~.Module` instance.
        """
        qat_module = cls(
            float_module.conv_transpose2d.in_channels,
            float_module.conv_transpose2d.out_channels,
            float_module.conv_transpose2d.kernel_size,
            float_module.conv_transpose2d.stride,
            float_module.conv_transpose2d.padding,
            float_module.conv_transpose2d.output_padding,
            float_module.conv_transpose2d.dilation,
            float_module.conv_transpose2d.groups,
            float_module.conv_transpose2d.bias is not None,
            float_module.conv_transpose2d.conv_mode,
            float_module.conv_transpose2d.compute_mode,
            name=float_module.name,
        )
        qat_module.conv_transpose2d.weight = float_module.conv_transpose2d.weight
        qat_module.conv_transpose2d.bias = float_module.conv_transpose2d.bias
        qat_module.bn = float_module.bn
        return qat_module


class ConvTransposeBn2d(_ConvTransposeBnActivation2d):
    r"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d` and
    :class:`~.module.BatchNorm2d` with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(self.calc_conv_transpose2d_bn_qat(inp))


class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d):
    r"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d`,
    :class:`~.module.BatchNorm2d` and :func:`~.relu` with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(relu(self.calc_conv_transpose2d_bn_qat(inp)))
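
The BN-folding identity that fold_weight_bias and calc_conv_transpose2d_bn_qat rely on can be checked with a few lines of plain NumPy. The sketch below is illustrative only: the tensor y stands in for the pre-bias transposed-convolution output, and every shape and variable name in it is an assumption made for the demonstration rather than part of the module above.

# Minimal NumPy check of the folding algebra used above (not MegEngine code):
#   BN(y + bias) == y * gamma / sqrt(var + eps)
#                   + beta + gamma * (bias - mean) / sqrt(var + eps)
# where y is the pre-bias transposed-convolution output; scaling y per output
# channel is equivalent to scaling the corresponding slice of the weight.
import numpy as np

rng = np.random.default_rng(0)
N, C, H, W = 2, 4, 3, 3          # assumed toy sizes
eps = 1e-5

y = rng.normal(size=(N, C, H, W))                  # pre-bias deconv output
bias = rng.normal(size=(1, C, 1, 1))               # conv_transpose2d bias
gamma = rng.normal(size=(1, C, 1, 1))              # bn.weight
beta = rng.normal(size=(1, C, 1, 1))               # bn.bias
mean = rng.normal(size=(1, C, 1, 1))               # bn mean
var = rng.uniform(0.5, 2.0, size=(1, C, 1, 1))     # bn var

# reference path: add the bias, then apply BatchNorm
bn_out = gamma * ((y + bias) - mean) / np.sqrt(var + eps) + beta

# folded path, mirroring fold_weight_bias
bn_istd = 1.0 / np.sqrt(var + eps)
scale_factor = gamma * bn_istd
b_fold = beta + gamma * (bias - mean) * bn_istd
folded_out = y * scale_factor + b_fold

assert np.allclose(bn_out, folded_out)

In the module itself the same scale_factor is folded into the transposed-convolution weight (broadcast over its output-channel axis) rather than multiplied into the output tensor, and the approx=True training path later divides the output by scale_factor so that the real BatchNorm still sees the unfolded activations.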