diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 5122813fa6..9ec1a2d578 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -1364,10 +1364,6 @@ class QuantBlock(Cell):
         self.has_bias = bias is not None
         self.activation = activation
         self.has_act = activation is not None
-        if isinstance(activation, ReLU):
-            self.activation = None
-            self.has_act = False
-            self.dequant.add_prim_attr("relu_flag", True)
         self.bias_add = P.BiasAdd()
 
     def construct(self, x):
@@ -1376,9 +1372,10 @@ class QuantBlock(Cell):
             x = self.core_op(x, self.weight, self.bias)
         else:
             x = self.core_op(x, self.weight)
+        x = self.dequant(x, self.dequant_scale)
+        x = F.cast(x, mstype.float32)
         if self.has_act:
             x = self.activation(x)
-        x = self.dequant(x, self.dequant_scale)
         return x
 
     def extend_repr(self):
diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py
index 78ffc00f94..c3ffb16e32 100644
--- a/mindspore/train/quant/quant.py
+++ b/mindspore/train/quant/quant.py
@@ -360,12 +360,12 @@ class ExportToQuantInferNetwork:
 
         scale_w, zp_w, _, _ = \
             quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
-        scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \
+        _, _, param_dict["output_maxq"], param_dict["output_minq"] = \
             quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
 
         info = self.quant_info_table.get(w_minq_name, None)
         if info:
-            fack_quant_a_in_op, minq_name = info
+            fake_quant_a_in_op, minq_name = info
             if minq_name == 'input':
                 scale_a_in, zp_a_in = self.input_scale, self.input_zero_point
             else:
@@ -373,17 +373,17 @@ class ExportToQuantInferNetwork:
                 minq = self.all_parameters[minq_name]
                 if self.is_mindir:
                     scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
-                        quant_utils.scale_zp_max_min_from_data(fack_quant_a_in_op, minq, maxq, np_type)
+                        quant_utils.scale_zp_max_min_from_data(fake_quant_a_in_op, minq, maxq, np_type)
                 else:
-                    scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, minq, maxq, np_type)
+                    scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fake_quant_a_in_op, minq, maxq, np_type)
         else:
-            logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
+            logger.warning(f"Can not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
             return None
 
         # Build the `Quant` `Dequant` op.
         # Quant only support perlayer version. Need check here.
         quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in))
-        scale_deq = scale_a_out * scale_w
+        scale_deq = scale_a_in * scale_w
         dequant_op = inner.Dequant()
 
         if isinstance(activation, _AddFakeQuantAfterSubCell):
@@ -407,7 +407,9 @@ class ExportToQuantInferNetwork:
             weight_b = weight
             bias_b = bias
 
        # apply the quant
-        weight = quant_utils.weight2int(weight, scale_w, zp_w)
+        fake_quant_weight_op = cell_core.fake_quant_weight.fake_quant_infer
+        weight = quant_utils.weight2int(weight, scale_w, zp_w, np_type, fake_quant_weight_op.num_bits,
+                                        fake_quant_weight_op.narrow_range)
         if bias is not None:
             bias = Tensor(bias / scale_a_in / scale_w, mstype.int32)
diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py
index 1ca8ae0bd3..21cb231fde 100644
--- a/mindspore/train/quant/quant_utils.py
+++ b/mindspore/train/quant/quant_utils.py
@@ -29,7 +29,7 @@ def cal_quantization_params(input_min,
     Args:
         input_min (numpy.ndarray): The dimension of channel or 1.
         input_max (numpy.ndarray): The dimension of channel or 1.
-        data_type (numpy type) : Can ben numpy int8, numpy uint8.
+        data_type (numpy type) : Can be numpy int8, numpy uint8.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
         symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
         narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
@@ -52,10 +52,12 @@ def cal_quantization_params(input_min,
 
     if data_type == np.int8:
         quant_min = 0 - 2 ** (num_bits - 1)
-        quant_max = 2 ** (num_bits - 1)
-    else:
+        quant_max = 2 ** (num_bits - 1) - 1
+    elif data_type == np.uint8:
         quant_min = 0
         quant_max = 2 ** num_bits - 1
+    else:
+        raise ValueError("Unsupported datatype({})".format(data_type))
     if narrow_range:
         quant_min = quant_min + 1
 
@@ -69,22 +71,13 @@ def cal_quantization_params(input_min,
     if symmetric:
         zp = np.zeros(input_min.shape)
     else:
-        zp_from_min = quant_min - input_min / scale
-        zp_from_max = quant_max - input_max / scale
-        zp_from_min_error = np.abs(quant_min) + np.abs(input_min / scale)
-        zp_from_max_error = np.abs(quant_max) + np.abs(input_max / scale)
-        zp_double = zp_from_min if zp_from_min_error < zp_from_max_error else zp_from_max
-        if zp_double < quant_min:
-            zp = quant_min
-        elif zp_double > quant_max:
-            zp = quant_max
-        else:
-            zp = np.floor(zp_double + 0.5)
+        zp_double = quant_min - input_min / scale
+        zp = np.floor(zp_double + 0.5)
 
     return scale, zp
 
 
-def weight2int(data, scale, zero_point):
+def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=False):
     r"""
     Calculate int8/uint8 weight from fp32. the formula is defined as:
 
@@ -95,6 +88,9 @@ def weight2int(data, scale, zero_point):
         data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
         scale (numpy.ndarray): The dimension of channel or 1.
         zero_point (numpy.ndarray): The dimension of channel or 1.
+        data_type (numpy type) : Can be numpy int8, numpy uint8.
+        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
+        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
 
     Returns:
         weight (numpy.ndarray): The dimension of channel or 1.
@@ -118,7 +114,21 @@ def weight2int(data, scale, zero_point):
     else:
         raise ValueError("Unsupported weight shape({})".format(data.shape))
 
-    return np.round((data / scale) + zero_point)
+    if data_type == np.int8:
+        quant_min = 0 - 2 ** (num_bits - 1)
+        quant_max = 2 ** (num_bits - 1) - 1
+    elif data_type == np.uint8:
+        quant_min = 0
+        quant_max = 2 ** num_bits - 1
+    else:
+        raise ValueError("Unsupported weight datatype({})".format(data_type))
+    if narrow_range:
+        quant_min = quant_min + 1
+
+    weight_int = np.round((data / scale) + zero_point)
+    weight_int[weight_int > quant_max] = quant_max
+    weight_int[weight_int < quant_min] = quant_min
+    return weight_int
 
 
 def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
     """Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMax`."""
@@ -145,7 +155,7 @@ def scale_zp_from_data(op, minq, maxq, data_type):
             `mindspore.ops.operation.FakeQuantPerChannel`
         minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
         maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
-        data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.
+        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.
 
     Returns:
         scale (numpy.ndarray): quantization param.
diff --git a/model_zoo/official/cv/lenet_quant/eval_quant.py b/model_zoo/official/cv/lenet_quant/eval_quant.py
index e1ac7b501b..fb44c01a91 100644
--- a/model_zoo/official/cv/lenet_quant/eval_quant.py
+++ b/model_zoo/official/cv/lenet_quant/eval_quant.py
@@ -48,7 +48,7 @@ if __name__ == "__main__":
     network = LeNet5Fusion(cfg.num_classes)
     # convert fusion network to quantization aware network
     network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
-                                          per_channel=[True, False])
+                                          per_channel=[True, False], symmetric=[True, False])
 
     # define loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
diff --git a/model_zoo/official/cv/lenet_quant/export.py b/model_zoo/official/cv/lenet_quant/export.py
index 4fad84c9eb..380541587c 100644
--- a/model_zoo/official/cv/lenet_quant/export.py
+++ b/model_zoo/official/cv/lenet_quant/export.py
@@ -44,7 +44,8 @@ if __name__ == "__main__":
     # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
     # convert fusion network to quantization aware network
-    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
+    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
+                                          per_channel=[True, False], symmetric=[True, False])
     # load quantization aware network checkpoint
     param_dict = load_checkpoint(args.ckpt_path)
     load_param_into_net(network, param_dict)
diff --git a/model_zoo/official/cv/lenet_quant/train_quant.py b/model_zoo/official/cv/lenet_quant/train_quant.py
index e24ca5dec4..9e43cbe58c 100644
--- a/model_zoo/official/cv/lenet_quant/train_quant.py
+++ b/model_zoo/official/cv/lenet_quant/train_quant.py
@@ -60,7 +60,7 @@ if __name__ == "__main__":
 
     # convert fusion network to quantization aware network
     network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False],
-                                          symmetric=[False, False])
+                                          symmetric=[True, False])
 
     # define network loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
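Reviewer note (not part of the patch): the quant_utils.py hunks fix two related bugs in one go. The int8 upper bound becomes 2 ** (num_bits - 1) - 1 (it was previously 2 ** (num_bits - 1), i.e. 128 for 8-bit, which overflows int8), and weight2int now clamps the rounded values into [quant_min, quant_max] instead of returning them unclamped. Below is a minimal standalone sanity check of that behavior, assuming only NumPy; the helper name weight2int_sketch and the sample scale/zero_point values are made up for illustration, and np.clip stands in for the patch's two masked assignments (the result is the same):

import numpy as np

def weight2int_sketch(data, scale, zero_point, data_type, num_bits=8, narrow_range=False):
    # Mirror of the patched weight2int: derive the valid integer range
    # from the target dtype, then quantize and clamp into that range.
    if data_type == np.int8:
        quant_min = -2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1   # 127 for 8-bit, not 128
    elif data_type == np.uint8:
        quant_min = 0
        quant_max = 2 ** num_bits - 1
    else:
        raise ValueError("Unsupported datatype({})".format(data_type))
    if narrow_range:
        quant_min += 1
    weight_int = np.round(data / scale + zero_point)
    # Equivalent to the patch's weight_int[weight_int > quant_max] = quant_max
    # and weight_int[weight_int < quant_min] = quant_min.
    return np.clip(weight_int, quant_min, quant_max)

w = np.array([-1.30, -0.02, 0.75, 1.40], dtype=np.float32)
print(weight2int_sketch(w, scale=0.01, zero_point=0.0, data_type=np.int8))
# -> [-128.   -2.   75.  127.]  (140 is clamped to 127, -130 to -128)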