Browse Source

fix lsq quant bugs

tags/v1.3.0
Erpim 5 years ago
parent
commit
2e9c9a6d4f
3 changed files with 25 additions and 17 deletions
  1. +7
    -0
      mindspore/compression/quant/qat.py
  2. +7
    -1
      mindspore/compression/quant/quant_utils.py
  3. +11
    -16
      mindspore/nn/layer/quant.py

+ 7
- 0
mindspore/compression/quant/qat.py View File

@@ -493,8 +493,15 @@ class QuantizationAwareTraining(Quantizer):
"""
act_class = activation.__class__
act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid]
neg_trunc_act_list = [nn.ReLU, nn.ReLU6]
act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]

if act_class in neg_trunc_act_list and OptimizeOption.LEARNED_SCALE in self.optimize_option:
self.quant_config = self.quant_config._replace(
activation=self.quant_config.activation.partial_init(neg_trunc=True, narrow_range=False))
return quant.ActQuant(activation=activation,
quant_config=self.quant_config,
quant_dtype=self.act_dtype)
if act_class in act_list:
return quant.ActQuant(activation=activation,
quant_config=self.quant_config,


+ 7
- 1
mindspore/compression/quant/quant_utils.py View File

@@ -278,12 +278,16 @@ def compute_KL_threshold(data, bitwidth):
Tensor with shape 1, the threshold calculated for the data.
"""
bitwidth = bitwidth.num_bits

data_min = 0
data_max = np.abs(data).max()
if data_max < 1e-5:
return 1e-5
hist, bin_edges = np.histogram(np.abs(data), bins='sqrt', range=(data_min, data_max), density=True)
# For the sake of high efficiency, we limit the maximum number of bins to 1024 in `sqrt` mode. If it exceeds this
# largest size, we fall back to the default bins config.
largest_bin_size = 1024
if hist.shape[0] > largest_bin_size:
hist, bin_edges = np.histogram(np.abs(data), range=(data_min, data_max), density=True)
hist = hist / np.sum(hist)
cumsum = np.cumsum(hist)
bit_pow_range = pow(2, int(bitwidth) - 1)
@@ -353,6 +357,8 @@ def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_param
Returns:
None
"""
if quant_new_params is not None and not isinstance(quant_new_params, list):
raise TypeError("quant_new_params must be list or None.")
iterable_dict = {
'weight': iter(list(filter(lambda item: item[0].endswith('weight'), params_dict.items()))),
'bias': iter(list(filter(lambda item: item[0].endswith('bias'), params_dict.items()))),


+ 11
- 16
mindspore/nn/layer/quant.py View File

@@ -24,13 +24,12 @@ from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator, Rel, twice
from mindspore._checkparam import Validator, twice
from mindspore.compression.common import QuantDtype
import mindspore.context as context
from .normalization import BatchNorm2d
from .activation import get_activation, ReLU
from ..cell import Cell
from ... import nn
from ...ops.operations import _quant_ops as Q

__all__ = [
@@ -381,10 +380,6 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
num_channels=num_channels)
Validator.check_value_type("min_init", min_init, [int, float, list], type(self).__name__)
Validator.check_value_type("max_init", max_init, [int, float, list], type(self).__name__)
if isinstance(max_init, (int, float)) and isinstance(min_init, (int, float)):
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
elif not np.greater(max_init, min_init).all():
raise ValueError("`min_init` is not less than `max_init`, please reset the initial value.")
Validator.check_non_negative_int(quant_delay, 'quant_delay')
self.min_init = min_init
self.max_init = max_init
@@ -405,7 +400,8 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):

min_array = self._get_init_array(self.min_init)
max_array = self._get_init_array(self.max_init)

if not np.greater(max_array, min_array).all():
raise ValueError("`min_init` is not less than `max_init`, please reset the initial value.")
if self.mode == "DEFAULT":
# init tensor min and max for fake quantized operation
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
@@ -441,7 +437,9 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
raise ValueError("The 'LEARNED_SCALE' mode only support symmetric quant, please set symmetric to True.")
if self.neg_trunc:
min_array = self._get_init_array(0)
self.narrow_range = False
if self.narrow_range:
raise ValueError("The 'LEARNED_SCALE' mode only support the combination of "
"neg_trunc=True and narrow_range=False config scenario.")
elif not self.narrow_range:
raise ValueError("The 'LEARNED_SCALE' mode only support narrow_range=True config, "
"except for neg_trunc=True scenario.")
@@ -483,6 +481,8 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
self.max_init = max_init
min_array = self._get_init_array(self.min_init)
max_array = self._get_init_array(self.max_init)
if not np.greater(max_array, min_array).all():
raise ValueError("`min_init` is not less than `max_init`, please reset the initial value.")
self.minq.set_data(Tensor(min_array))
self.maxq.set_data(Tensor(max_array))
self.quant_max.set_data(Tensor(np.array([self._quant_max]).astype(np.float32)))
@@ -494,10 +494,10 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
Convert the initial value to array.
"""
if isinstance(init_date, list) and self.per_channel and len(init_date) != self.num_channels:
raise ValueError("The length of the min_init/max_init list shuold be equal to num_channels for "
raise ValueError("The length of the min_init/max_init list should be equal to num_channels for "
"perchannel quant scenario, but get {}".format(len(init_date)))
if isinstance(init_date, list) and not self.per_channel and len(init_date) != 1:
raise ValueError("The length of the min_init/max_init list shuold be 1 for perlayer quant "
raise ValueError("The length of the min_init/max_init list should be 1 for perlayer quant "
"scenario, but get {}".format(len(init_date)))

if isinstance(init_date, list):
@@ -1343,8 +1343,6 @@ class ActQuant(_QuantActivation):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(ActQuant, self).__init__()
act_class = activation.__class__
act_list = [nn.ReLU, nn.ReLU6]
self.act = Validator.check_isinstance("activation", activation, Cell)
self.fake_before = Validator.check_bool(fake_before, "fake_before")
if self.fake_before:
@@ -1353,14 +1351,11 @@ class ActQuant(_QuantActivation):
ema=ema,
ema_decay=ema_decay,
quant_dtype=quant_dtype)

neg_trunc = bool(act_class in act_list)
self.fake_quant_act = quant_config.activation(min_init=-6,
max_init=6,
ema=ema,
ema_decay=ema_decay,
quant_dtype=quant_dtype,
neg_trunc=neg_trunc)
quant_dtype=quant_dtype)

def construct(self, x):
if self.fake_before:


Loading…
Cancel
Save