
!7118 [quant] quant evaluation export bugfix

Merge pull request !7118 from yuchaojie/quant
Tag: v1.1.0
mindspore-ci-bot (Gitee), 5 years ago
Parent commit: c5ce9fbebd

6 changed files with 42 additions and 32 deletions
  1. +2  -5   mindspore/nn/layer/quant.py
  2. +9  -7   mindspore/train/quant/quant.py
  3. +27 -17  mindspore/train/quant/quant_utils.py
  4. +1  -1   model_zoo/official/cv/lenet_quant/eval_quant.py
  5. +2  -1   model_zoo/official/cv/lenet_quant/export.py
  6. +1  -1   model_zoo/official/cv/lenet_quant/train_quant.py

+2 -5  mindspore/nn/layer/quant.py

@@ -1364,10 +1364,6 @@ class QuantBlock(Cell):
         self.has_bias = bias is not None
         self.activation = activation
         self.has_act = activation is not None
-        if isinstance(activation, ReLU):
-            self.activation = None
-            self.has_act = False
-            self.dequant.add_prim_attr("relu_flag", True)
         self.bias_add = P.BiasAdd()

     def construct(self, x):
@@ -1376,9 +1372,10 @@ class QuantBlock(Cell):
             x = self.core_op(x, self.weight, self.bias)
         else:
             x = self.core_op(x, self.weight)
+        x = self.dequant(x, self.dequant_scale)
+        x = F.cast(x, mstype.float32)
         if self.has_act:
             x = self.activation(x)
-        x = self.dequant(x, self.dequant_scale)
         return x

     def extend_repr(self):
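Note: previously `construct` ran the activation on the raw int32 output of `core_op` and dequantized afterwards, with ReLU special-cased by folding it into `Dequant` via `relu_flag`. With that folding removed, the fix dequantizes and casts to float32 first, so whatever activation remains operates on real-valued inputs. A minimal numpy sketch of the corrected order (names, shapes, and values are illustrative, not MindSpore API):

    import numpy as np

    def quant_block_infer(x_int8, w_int8, bias_int32, dequant_scale, act=None):
        # Integer matmul accumulates in int32, as the fused core_op does.
        acc = x_int8.astype(np.int32) @ w_int8.astype(np.int32).T
        if bias_int32 is not None:
            acc = acc + bias_int32
        # Dequantize and cast to float *before* the activation, matching
        # the order established by the hunk above.
        y = acc.astype(np.float32) * dequant_scale
        if act is not None:
            y = act(y)
        return y

    x = np.array([[3, -2]], dtype=np.int8)
    w = np.array([[1, 2]], dtype=np.int8)
    out = quant_block_infer(x, w, None, dequant_scale=0.05,
                            act=lambda v: np.maximum(v, 0.0))
    print(out)  # [[0.]] -- accumulator is -1, dequantized to -0.05, clipped by ReLU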


+9 -7  mindspore/train/quant/quant.py

@@ -360,12 +360,12 @@ class ExportToQuantInferNetwork:
         scale_w, zp_w, _, _ = \
             quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
-        scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \
+        _, _, param_dict["output_maxq"], param_dict["output_minq"] = \
             quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)

         info = self.quant_info_table.get(w_minq_name, None)
         if info:
-            fack_quant_a_in_op, minq_name = info
+            fake_quant_a_in_op, minq_name = info
             if minq_name == 'input':
                 scale_a_in, zp_a_in = self.input_scale, self.input_zero_point
             else:
@@ -373,17 +373,17 @@ class ExportToQuantInferNetwork:
                 minq = self.all_parameters[minq_name]
                 if self.is_mindir:
                     scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
-                        quant_utils.scale_zp_max_min_from_data(fack_quant_a_in_op, minq, maxq, np_type)
+                        quant_utils.scale_zp_max_min_from_data(fake_quant_a_in_op, minq, maxq, np_type)
                 else:
-                    scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, minq, maxq, np_type)
+                    scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fake_quant_a_in_op, minq, maxq, np_type)
         else:
-            logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
+            logger.warning(f"Can not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
             return None

         # Build the `Quant` `Dequant` op.
         # Quant only support perlayer version. Need check here.
         quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in))
-        scale_deq = scale_a_out * scale_w
+        scale_deq = scale_a_in * scale_w
         dequant_op = inner.Dequant()

         if isinstance(activation, _AddFakeQuantAfterSubCell):
@@ -407,7 +407,9 @@ class ExportToQuantInferNetwork:
             weight_b = weight
             bias_b = bias
         # apply the quant
-        weight = quant_utils.weight2int(weight, scale_w, zp_w)
+        fake_quant_weight_op = cell_core.fake_quant_weight.fake_quant_infer
+        weight = quant_utils.weight2int(weight, scale_w, zp_w, np_type, fake_quant_weight_op.num_bits,
+                                        fake_quant_weight_op.narrow_range)
         if bias is not None:
             bias = Tensor(bias / scale_a_in / scale_w, mstype.int32)
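Note: for an int32 accumulator acc = sum(x_int * w_int), the real value is approximately scale_a_in * scale_w * acc, so the dequantization scale must be the product of the input activation scale and the weight scale; scale_a_out belongs to the layer's output fake-quant and was the wrong factor. The same product appears in the bias quantization `bias / scale_a_in / scale_w` just above. A quick numpy check with made-up scales (zero points taken as 0 for brevity):

    import numpy as np

    scale_a_in, scale_w = 0.02, 0.005      # assumed example scales
    x = np.array([0.12, -0.30], dtype=np.float32)
    w = np.array([0.04, 0.01], dtype=np.float32)

    x_int = np.round(x / scale_a_in).astype(np.int32)
    w_int = np.round(w / scale_w).astype(np.int32)

    acc = int(x_int @ w_int)               # int32 accumulator: 18
    approx = acc * (scale_a_in * scale_w)  # dequant with the corrected scale
    exact = float(x @ w)                   # 0.0018
    assert abs(approx - exact) < 1e-6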




+27 -17  mindspore/train/quant/quant_utils.py

@@ -29,7 +29,7 @@ def cal_quantization_params(input_min,
     Args:
         input_min (numpy.ndarray): The dimension of channel or 1.
         input_max (numpy.ndarray): The dimension of channel or 1.
-        data_type (numpy type) : Can ben numpy int8, numpy uint8.
+        data_type (numpy type) : Can be numpy int8, numpy uint8.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
         symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
         narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
@@ -52,10 +52,12 @@ def cal_quantization_params(input_min,

     if data_type == np.int8:
         quant_min = 0 - 2 ** (num_bits - 1)
-        quant_max = 2 ** (num_bits - 1)
-    else:
+        quant_max = 2 ** (num_bits - 1) - 1
+    elif data_type == np.uint8:
         quant_min = 0
         quant_max = 2 ** num_bits - 1
+    else:
+        raise ValueError("Unsupported datatype({})".format(data_type))
     if narrow_range:
         quant_min = quant_min + 1

@@ -69,22 +71,13 @@ def cal_quantization_params(input_min,
     if symmetric:
         zp = np.zeros(input_min.shape)
     else:
-        zp_from_min = quant_min - input_min / scale
-        zp_from_max = quant_max - input_max / scale
-        zp_from_min_error = np.abs(quant_min) + np.abs(input_min / scale)
-        zp_from_max_error = np.abs(quant_max) + np.abs(input_max / scale)
-        zp_double = zp_from_min if zp_from_min_error < zp_from_max_error else zp_from_max
-        if zp_double < quant_min:
-            zp = quant_min
-        elif zp_double > quant_max:
-            zp = quant_max
-        else:
-            zp = np.floor(zp_double + 0.5)
+        zp_double = quant_min - input_min / scale
+        zp = np.floor(zp_double + 0.5)

     return scale, zp


-def weight2int(data, scale, zero_point):
+def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=False):
     r"""
     Calculate int8/uint8 weight from fp32. the formula is defined as:
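Note: together with the corrected upper bound (int8 now tops out at 2 ** (num_bits - 1) - 1, i.e. 127 for 8 bits), the asymmetric zero point reduces to the single constraint that input_min map to quant_min, hence zp = floor(quant_min - input_min / scale + 0.5); the old branch that compared min- and max-derived errors is gone. A small numpy sketch with illustrative values:

    import numpy as np

    # Illustrative float range and int8 range (num_bits=8, narrow_range=False).
    input_min, input_max = np.array([-1.0]), np.array([3.0])
    quant_min, quant_max = -128, 127

    scale = (input_max - input_min) / (quant_max - quant_min)
    zp = np.floor(quant_min - input_min / scale + 0.5)   # -> [-64.]

    # input_min should land exactly on quant_min after quantization.
    assert int(np.round(input_min / scale + zp)[0]) == quant_min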


@@ -95,6 +88,9 @@ def weight2int(data, scale, zero_point):
         data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
         scale (numpy.ndarray): The dimension of channel or 1.
         zero_point (numpy.ndarray): The dimension of channel or 1.
+        data_type (numpy type) : Can be numpy int8, numpy uint8.
+        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
+        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.

     Returns:
         weight (numpy.ndarray): The dimension of channel or 1.
@@ -118,7 +114,21 @@ def weight2int(data, scale, zero_point):
     else:
         raise ValueError("Unsupported weight shape({})".format(data.shape))

-    return np.round((data / scale) + zero_point)
+    if data_type == np.int8:
+        quant_min = 0 - 2 ** (num_bits - 1)
+        quant_max = 2 ** (num_bits - 1) - 1
+    elif data_type == np.uint8:
+        quant_min = 0
+        quant_max = 2 ** num_bits - 1
+    else:
+        raise ValueError("Unsupported weight datatype({})".format(data_type))
+    if narrow_range:
+        quant_min = quant_min + 1
+
+    weight_int = np.round((data / scale) + zero_point)
+    weight_int[weight_int > quant_max] = quant_max
+    weight_int[weight_int < quant_min] = quant_min
+    return weight_int


 def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
     """Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMax`."""
@@ -145,7 +155,7 @@ def scale_zp_from_data(op, minq, maxq, data_type):
             `mindspore.ops.operation.FakeQuantPerChannel`
         minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
         maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
-        data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.
+        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.

     Returns:
         scale (numpy.ndarray): quantization param.


+1 -1  model_zoo/official/cv/lenet_quant/eval_quant.py

@@ -48,7 +48,7 @@ if __name__ == "__main__":
     network = LeNet5Fusion(cfg.num_classes)
     # convert fusion network to quantization aware network
     network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
-                                          per_channel=[True, False])
+                                          per_channel=[True, False], symmetric=[True, False])

     # define loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")


+2 -1  model_zoo/official/cv/lenet_quant/export.py

@@ -44,7 +44,8 @@ if __name__ == "__main__":
     # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
     # convert fusion network to quantization aware network
-    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
+    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
+                                          per_channel=[True, False], symmetric=[True, False])
     # load quantization aware network checkpoint
     param_dict = load_checkpoint(args.ckpt_path)
     load_param_into_net(network, param_dict)


+1 -1  model_zoo/official/cv/lenet_quant/train_quant.py

@@ -60,7 +60,7 @@ if __name__ == "__main__":

     # convert fusion network to quantization aware network
     network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False],
-                                          symmetric=[False, False])
+                                          symmetric=[True, False])

     # define network loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
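Note: the train, eval, and export scripts must build the quantization-aware network with identical per_channel and symmetric settings; a checkpoint trained under symmetric weight quantization cannot be evaluated or exported under an asymmetric scheme, which is what these three one-line changes align. The shared call in eval_quant.py and export.py now reads:

    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
                                          per_channel=[True, False], symmetric=[True, False])

with quant_delay=900 during training, as in train_quant.py above.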

