
quant.py 7.4 kB

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Quantization function."""
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import nn
class QuantizeWeightCell(nn.Cell):
    """
    The fake-quant cell for weight: ternary quantization when num_bits is 2,
    uniform min-max quantization otherwise.

    Args:
        num_bits (int): The bit number of quantization, supporting 2 to 8 bits. Default: 8.
        compute_type (:class:`mindspore.dtype`): Compute type in QuantizeWeightCell. Default: mstype.float32.
        clip_value (float): Clips weight to be in [-clip_value, clip_value]. Default: 1.0.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.

    Inputs:
        - **weight** (Parameter) - Parameter of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Parameter of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
    """

    def __init__(self, num_bits=8, compute_type=mstype.float32, clip_value=1.0, per_channel=False):
        super(QuantizeWeightCell, self).__init__()
        self.num_bits = num_bits
        self.compute_type = compute_type
        self.clip_value = clip_value
        self.per_channel = per_channel
        self.clamp = C.clip_by_value
        self.abs = P.Abs()
        self.sum = P.ReduceSum()
        self.nelement = F.size
        self.div = P.Div()
        self.cast = P.Cast()
        self.max = P.ReduceMax()
        self.min = P.ReduceMin()
        self.round = P.Round()
        self.reshape = P.Reshape()  # used by the per-channel branch; missing in the original
    def construct(self, weight):
        """Quantize the weight tensor."""
        tensor = self.clamp(weight, -self.clip_value, self.clip_value)
        if self.num_bits == 2:
            if self.per_channel:
                # Ternary quantization, per channel: threshold is 0.7 * mean(|w|).
                n = self.nelement(tensor[0])
                m = self.div(self.sum(self.abs(tensor), 1), n)
                thres = 0.7 * m
                pos = self.cast(tensor[:] > thres[0], self.compute_type)
                neg = self.cast(tensor[:] < -thres[0], self.compute_type)
                mask = self.cast(self.abs(tensor)[:] > thres[0], self.compute_type)
                # Scale alpha is the mean magnitude of the weights above the threshold.
                alpha = self.reshape(self.sum(self.abs(mask * tensor), 1) / self.sum(mask, 1), (-1, 1))
                output = alpha * pos - alpha * neg
            else:
                # Ternary quantization over the whole layer.
                n = self.nelement(tensor)
                m = self.div(self.sum(self.abs(tensor)), n)
                thres = 0.7 * m
                pos = self.cast(tensor > thres, self.compute_type)
                neg = self.cast(tensor < -thres, self.compute_type)
                mask = self.cast(self.abs(tensor) > thres, self.compute_type)
                alpha = self.sum(self.abs(mask * self.cast(tensor, self.compute_type))) / self.sum(mask)
                output = alpha * pos - alpha * neg
        else:
            # Uniform min-max quantization onto 2**num_bits evenly spaced levels.
            tensor_max = self.cast(self.max(tensor), self.compute_type)
            tensor_min = self.cast(self.min(tensor), self.compute_type)
            s = (tensor_max - tensor_min) / (2 ** self.cast(self.num_bits, self.compute_type) - 1)
            output = self.round(self.div(tensor - tensor_min, s)) * s + tensor_min
        return output
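
# Illustrative note (not part of the original file): for num_bits > 2 the cell
# snaps each weight onto 2**num_bits evenly spaced levels between the tensor's
# min and max. With num_bits = 4 and weights clipped to [-1, 1]:
#
#     s = (1.0 - (-1.0)) / (2**4 - 1)       # step size = 2 / 15 ~= 0.1333
#     w_q = round((w + 1.0) / s) * s - 1.0  # nearest of 16 levels
#
# e.g. w = 0.5 maps to round(1.5 / 0.1333) * 0.1333 - 1.0 = 11 * 0.1333 - 1.0 ~= 0.4667.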
class QuantizeWeight:
    """
    Quantize weight into the specified number of bits.

    Args:
        num_bits (int): The bit number of quantization, supporting 2 to 8 bits. Default: 2.
        compute_type (:class:`mindspore.dtype`): Compute type in QuantizeWeight. Default: mstype.float32.
        clip_value (float): Clips weight to be in [-clip_value, clip_value]. Default: 1.0.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.

    Inputs:
        - **weight** (Parameter) - Parameter of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Parameter of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
    """

    def __init__(self, num_bits=2, compute_type=mstype.float32, clip_value=1.0, per_channel=False):
        self.num_bits = num_bits
        self.compute_type = compute_type
        self.clip_value = clip_value
        self.per_channel = per_channel
        self.clamp = C.clip_by_value
        self.abs = P.Abs()
        self.sum = P.ReduceSum()
        self.nelement = F.size
        self.div = P.Div()
        self.cast = P.Cast()
        self.max = P.ReduceMax()
        self.min = P.ReduceMin()
        self.floor = P.Floor()
        self.reshape = P.Reshape()  # used by the per-channel branch; missing in the original
    def construct(self, weight):
        """Quantize the weight tensor."""
        tensor = self.clamp(weight, -self.clip_value, self.clip_value)
        if self.num_bits == 2:
            if self.per_channel:
                # Ternary quantization, per channel: threshold is 0.7 * mean(|w|).
                n = self.nelement(tensor[0])
                m = self.div(self.sum(self.abs(tensor), 1), n)
                thres = 0.7 * m
                pos = self.cast(tensor[:] > thres[0], self.compute_type)
                neg = self.cast(tensor[:] < -thres[0], self.compute_type)
                mask = self.cast(self.abs(tensor)[:] > thres[0], self.compute_type)
                # Scale alpha is the mean magnitude of the weights above the threshold.
                alpha = self.reshape(self.sum(self.abs(mask * tensor), 1) / self.sum(mask, 1), (-1, 1))
                output = alpha * pos - alpha * neg
            else:
                # Ternary quantization over the whole layer.
                n = self.nelement(tensor)
                m = self.div(self.sum(self.abs(tensor)), n)
                thres = 0.7 * m
                pos = self.cast(tensor > thres, self.compute_type)
                neg = self.cast(tensor < -thres, self.compute_type)
                mask = self.cast(self.abs(tensor) > thres, self.compute_type)
                alpha = self.sum(self.abs(mask * tensor)) / self.sum(mask)
                output = alpha * pos - alpha * neg
        else:
            # Uniform min-max quantization; floor(x + 0.5) is round-to-nearest.
            tensor_max = self.max(tensor)
            tensor_min = self.min(tensor)
            s = (tensor_max - tensor_min) / (2 ** self.num_bits - 1)
            output = self.floor(self.div((tensor - tensor_min), s) + 0.5) * s + tensor_min
        return output
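
# Illustrative note (not part of the original file): for num_bits == 2 both
# classes implement ternary quantization in the style of Ternary Weight
# Networks: threshold Delta = 0.7 * mean(|w|), and scale alpha = mean(|w_i|)
# over the weights with |w_i| > Delta. A NumPy sketch of the same math on a
# toy 1-D weight vector:
#
#     import numpy as np
#     w = np.array([0.9, -0.05, 0.4, -0.8])
#     thres = 0.7 * np.abs(w).mean()                    # 0.7 * 0.5375 = 0.37625
#     mask = np.abs(w) > thres                          # [True, False, True, True]
#     alpha = np.abs(w[mask]).mean()                    # (0.9 + 0.4 + 0.8) / 3 = 0.7
#     w_q = alpha * (w > thres) - alpha * (w < -thres)  # [0.7, 0.0, 0.7, -0.7]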
def convert_network(network, embedding_bits=2, weight_bits=2, clip_value=1.0):
    """Quantize the embedding and weight parameters of a network in place."""
    quantize_embedding = QuantizeWeight(num_bits=embedding_bits, clip_value=clip_value)
    quantize_weight = QuantizeWeight(num_bits=weight_bits, clip_value=clip_value)
    for name, param in network.parameters_and_names():
        if 'bert_embedding_lookup' in name and 'min' not in name and 'max' not in name:
            quantized_param = quantize_embedding.construct(param)
            param.set_data(quantized_param)
        elif 'weight' in name and 'dense_1' not in name:
            quantized_param = quantize_weight.construct(param)
            param.set_data(quantized_param)


def save_params(network):
    """Snapshot all parameters of a network into a dict keyed by name."""
    return {name: Parameter(param, 'saved_params') for name, param in network.parameters_and_names()}


def restore_params(network, params_dict):
    """Restore parameters previously saved by save_params."""
    for name, param in network.parameters_and_names():
        param.set_data(params_dict[name])
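
A minimal usage sketch (not part of the original file; `TinyNet` is a hypothetical stand-in for any network whose parameter names match the filters in `convert_network`): snapshot the full-precision weights, quantize in place for evaluation or export, then restore before training continues.

    from mindspore import nn

    class TinyNet(nn.Cell):
        """Toy network; its 'dense.weight' parameter matches the 'weight' filter."""
        def __init__(self):
            super().__init__()
            self.dense = nn.Dense(4, 4)

        def construct(self, x):
            return self.dense(x)

    net = TinyNet()
    fp_params = save_params(net)    # full-precision snapshot
    convert_network(net)            # ternarize matching weights in place
    # ... evaluate or export with quantized weights ...
    restore_params(net, fp_params)  # back to full precision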