Merge pull request !7118 from yuchaojie/quant
@@ -1364,10 +1364,6 @@ class QuantBlock(Cell):
         self.has_bias = bias is not None
         self.activation = activation
         self.has_act = activation is not None
-        if isinstance(activation, ReLU):
-            self.activation = None
-            self.has_act = False
-            self.dequant.add_prim_attr("relu_flag", True)
         self.bias_add = P.BiasAdd()

     def construct(self, x):
@@ -1376,9 +1372,10 @@ class QuantBlock(Cell):
             x = self.core_op(x, self.weight, self.bias)
         else:
             x = self.core_op(x, self.weight)
+        x = self.dequant(x, self.dequant_scale)
         x = F.cast(x, mstype.float32)
         if self.has_act:
             x = self.activation(x)
-        x = self.dequant(x, self.dequant_scale)
         return x

     def extend_repr(self):
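Note on the two QuantBlock hunks above: the old code fused ReLU into Dequant via relu_flag and otherwise ran the activation before dequantization. ReLU commutes with a positive scale (relu(s * x) = s * relu(x) for s > 0), so the fused path happened to be safe, but any other activation applied to the raw int32 accumulator is wrong. A minimal numpy sketch with illustrative values, not MindSpore code:

import numpy as np

def sigmoid(v):
    return 1 / (1 + np.exp(-v))

acc = np.array([120.0])  # int32 accumulator from the quantized core op
scale_deq = 0.05         # dequant scale, i.e. scale_a_in * scale_w

wrong = scale_deq * sigmoid(acc)  # old order: activation, then dequant
right = sigmoid(scale_deq * acc)  # new order: dequant, cast, then activation
print(wrong)  # ~[0.05]
print(right)  # ~[0.9975]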
@@ -360,12 +360,12 @@ class ExportToQuantInferNetwork:
         scale_w, zp_w, _, _ = \
             quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
-        scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \
+        _, _, param_dict["output_maxq"], param_dict["output_minq"] = \
             quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
         info = self.quant_info_table.get(w_minq_name, None)
         if info:
-            fack_quant_a_in_op, minq_name = info
+            fake_quant_a_in_op, minq_name = info
             if minq_name == 'input':
                 scale_a_in, zp_a_in = self.input_scale, self.input_zero_point
             else:
@@ -373,17 +373,17 @@ class ExportToQuantInferNetwork:
                 minq = self.all_parameters[minq_name]
                 if self.is_mindir:
                     scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
-                        quant_utils.scale_zp_max_min_from_data(fack_quant_a_in_op, minq, maxq, np_type)
+                        quant_utils.scale_zp_max_min_from_data(fake_quant_a_in_op, minq, maxq, np_type)
                 else:
-                    scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, minq, maxq, np_type)
+                    scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fake_quant_a_in_op, minq, maxq, np_type)
         else:
-            logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
+            logger.warning(f"Can not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
             return None

         # Build the `Quant` `Dequant` op.
         # Quant only support perlayer version. Need check here.
         quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in))
-        scale_deq = scale_a_out * scale_w
+        scale_deq = scale_a_in * scale_w
         dequant_op = inner.Dequant()

         if isinstance(activation, _AddFakeQuantAfterSubCell):
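The scale_deq fix is the core of this hunk: inside the quantized op, the activation is scaled by 1/scale_a_in and the weight by 1/scale_w, so the int32 accumulator carries a factor of 1/(scale_a_in * scale_w). Dequantization therefore needs scale_a_in * scale_w; scale_a_out belongs to the output fake-quant node and plays no role here. A quick numpy check with made-up values:

import numpy as np

x, w = np.array([0.8, -0.4]), np.array([0.5, 0.25])
scale_a_in, scale_w = 0.01, 0.005

x_q = np.round(x / scale_a_in)  # quantized activation
w_q = np.round(w / scale_w)     # quantized weight
acc = np.dot(x_q, w_q)          # what the int32 core op produces

print(acc * scale_a_in * scale_w)  # 0.3, matches the float result
print(np.dot(x, w))                # 0.3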
@@ -407,7 +407,9 @@ class ExportToQuantInferNetwork:
             weight_b = weight
             bias_b = bias
         # apply the quant
-        weight = quant_utils.weight2int(weight, scale_w, zp_w)
+        fake_quant_weight_op = cell_core.fake_quant_weight.fake_quant_infer
+        weight = quant_utils.weight2int(weight, scale_w, zp_w, np_type, fake_quant_weight_op.num_bits,
+                                        fake_quant_weight_op.narrow_range)
         if bias is not None:
             bias = Tensor(bias / scale_a_in / scale_w, mstype.int32)
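The same reasoning applies to the bias line kept above: the bias is added to the int32 accumulator, so it must be expressed in accumulator units of scale_a_in * scale_w. Illustrative sketch, values made up:

import numpy as np

bias = np.array([0.015])
scale_a_in, scale_w = 0.01, 0.005

bias_q = (bias / scale_a_in / scale_w).astype(np.int32)
print(bias_q)                         # [300]
print(bias_q * scale_a_in * scale_w)  # ~[0.015] again after dequant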
@@ -29,7 +29,7 @@ def cal_quantization_params(input_min,
     Args:
         input_min (numpy.ndarray): The dimension of channel or 1.
         input_max (numpy.ndarray): The dimension of channel or 1.
-        data_type (numpy type) : Can ben numpy int8, numpy uint8.
+        data_type (numpy type) : Can be numpy int8, numpy uint8.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
         symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
         narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
@@ -52,10 +52,12 @@ def cal_quantization_params(input_min,
     if data_type == np.int8:
         quant_min = 0 - 2 ** (num_bits - 1)
-        quant_max = 2 ** (num_bits - 1)
-    else:
+        quant_max = 2 ** (num_bits - 1) - 1
+    elif data_type == np.uint8:
         quant_min = 0
         quant_max = 2 ** num_bits - 1
+    else:
+        raise ValueError("Unsupported datatype({})".format(data_type))

     if narrow_range:
         quant_min = quant_min + 1
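The int8 branch previously set quant_max to 2 ** (num_bits - 1), i.e. 128, which overflows int8; the fix uses 127 and turns the silent catch-all else into an explicit error. The resulting ranges for num_bits=8, reproduced standalone as a sanity check:

import numpy as np

num_bits = 8
for data_type, narrow_range in [(np.int8, False), (np.int8, True), (np.uint8, False)]:
    if data_type == np.int8:
        quant_min, quant_max = -2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1
    else:  # np.uint8
        quant_min, quant_max = 0, 2 ** num_bits - 1
    if narrow_range:
        quant_min += 1
    print(data_type.__name__, narrow_range, quant_min, quant_max)
# int8  False -128 127
# int8  True  -127 127
# uint8 False    0 255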
@@ -69,22 +71,13 @@ def cal_quantization_params(input_min,
     if symmetric:
         zp = np.zeros(input_min.shape)
     else:
-        zp_from_min = quant_min - input_min / scale
-        zp_from_max = quant_max - input_max / scale
-        zp_from_min_error = np.abs(quant_min) + np.abs(input_min / scale)
-        zp_from_max_error = np.abs(quant_max) + np.abs(input_max / scale)
-        zp_double = zp_from_min if zp_from_min_error < zp_from_max_error else zp_from_max
-        if zp_double < quant_min:
-            zp = quant_min
-        elif zp_double > quant_max:
-            zp = quant_max
-        else:
-            zp = np.floor(zp_double + 0.5)
+        zp_double = quant_min - input_min / scale
+        zp = np.floor(zp_double + 0.5)

     return scale, zp


-def weight2int(data, scale, zero_point):
+def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=False):
     r"""
     Calculate int8/uint8 weight from fp32. the formula is defined as:
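The deleted arbitration block was redundant: with the asymmetric scale (input_max - input_min) / (quant_max - quant_min), the zero point derived from input_min is algebraically identical to the one derived from input_max, so keeping only zp = floor(quant_min - input_min / scale + 0.5) loses nothing. A worked check:

import numpy as np

input_min, input_max = np.array([-1.0]), np.array([3.0])
quant_min, quant_max = -128, 127

scale = (input_max - input_min) / (quant_max - quant_min)
zp_from_min = quant_min - input_min / scale
zp_from_max = quant_max - input_max / scale
print(np.floor(zp_from_min + 0.5))  # [-64.]
print(np.floor(zp_from_max + 0.5))  # [-64.], equal to zp_from_min before rounding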
@@ -95,6 +88,9 @@ def weight2int(data, scale, zero_point):
         data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
         scale (numpy.ndarray): The dimension of channel or 1.
         zero_point (numpy.ndarray): The dimension of channel or 1.
+        data_type (numpy type) : Can be numpy int8, numpy uint8.
+        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
+        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.

     Returns:
         weight (numpy.ndarray): The dimension of channel or 1.
@@ -118,7 +114,21 @@ def weight2int(data, scale, zero_point):
     else:
         raise ValueError("Unsupported weight shape({})".format(data.shape))

-    return np.round((data / scale) + zero_point)
+    if data_type == np.int8:
+        quant_min = 0 - 2 ** (num_bits - 1)
+        quant_max = 2 ** (num_bits - 1) - 1
+    elif data_type == np.uint8:
+        quant_min = 0
+        quant_max = 2 ** num_bits - 1
+    else:
+        raise ValueError("Unsupported weight datatype({})".format(data_type))
+    if narrow_range:
+        quant_min = quant_min + 1
+
+    weight_int = np.round((data / scale) + zero_point)
+    weight_int[weight_int > quant_max] = quant_max
+    weight_int[weight_int < quant_min] = quant_min
+    return weight_int


 def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
     """Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMax`."""
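Hedged usage sketch of the updated weight2int for the per-layer case (shape (1,) scale), assuming it is importable from the quant_utils module patched above: out-of-range weights are now clamped to the integer range instead of leaking through np.round unchanged.

import numpy as np
# assumption: weight2int comes from the quant_utils module patched above

data = np.array([-2.0, -1.0, 0.0, 1.5], dtype=np.float32)
scale = np.array([0.01])
zero_point = np.array([0])

w_int = weight2int(data, scale, zero_point, np.int8, num_bits=8, narrow_range=False)
print(w_int)  # [-128. -100.    0.  127.] -- raw -200 and 150 clamped to int8 range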
@@ -145,7 +155,7 @@ def scale_zp_from_data(op, minq, maxq, data_type):
             `mindspore.ops.operation.FakeQuantPerChannel`
         minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
         maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
-        data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.
+        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.

     Returns:
         scale (numpy.ndarray): quantization param.
@@ -48,7 +48,7 @@ if __name__ == "__main__":
     network = LeNet5Fusion(cfg.num_classes)
     # convert fusion network to quantization aware network
     network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
-                                          per_channel=[True, False])
+                                          per_channel=[True, False], symmetric=[True, False])

     # define loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
@@ -44,7 +44,8 @@ if __name__ == "__main__":
     # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
     # convert fusion network to quantization aware network
-    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
+    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
+                                          per_channel=[True, False], symmetric=[True, False])
     # load quantization aware network checkpoint
     param_dict = load_checkpoint(args.ckpt_path)
     load_param_into_net(network, param_dict)
@@ -60,7 +60,7 @@ if __name__ == "__main__":
     # convert fusion network to quantization aware network
     network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False],
-                                          symmetric=[False, False])
+                                          symmetric=[True, False])

     # define network loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
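With these three script hunks the train and eval entry points now agree on per_channel=[True, False] and symmetric=[True, False], i.e. weights quantized per-channel and symmetric, activations per-layer and asymmetric. The export path only reproduces training's fake-quant parameters when these arguments match, so one hedged suggestion is to share a single config dict between the scripts; QAT_KWARGS below is illustrative and not part of this PR:

# illustrative helper, not from the repo
QAT_KWARGS = dict(bn_fold=False, freeze_bn=10000,
                  per_channel=[True, False],  # [weights, activations]
                  symmetric=[True, False])

network = quant.convert_quant_network(network, quant_delay=0, **QAT_KWARGS)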