diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 0bb69b4476..31a25a65b7 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -1379,3 +1379,69 @@ class QuantBlock(Cell):
             str_info = str_info + f', activation={self.activation}'
         str_info = str_info + f', dequant={self.dequant}'
         return str_info
+
+
+class QuantMindirBlock(Cell):
+    """A quant binary block of Conv/Dense, activation layer for export MINDIR model.
+
+    Args:
+        core_op (Cell): The operation cell.
+        weight (Tensor): The weight of the cell.
+        bias (Tensor): The bias of the cell. Default: None.
+        activation (Cell): The activation cell applied to the output of the layer, e.g. `ReLU`. Default: None.
+        param_dict (dict): The quantization information of the cell. It must not be None.
+    """
+
+    def __init__(self,
+                 core_op,
+                 weight,
+                 bias=None,
+                 activation=None,
+                 param_dict=None):
+
+        super(QuantMindirBlock, self).__init__()
+        if param_dict is None:
+            raise ValueError("For 'QuantMindirBlock', 'param_dict' must be provided.")
+        self.core_op = core_op
+        if activation is not None:
+            self.core_op.add_prim_attr("activation_name", activation.__class__.__name__)
+        self.core_op.add_prim_attr("filter_maxq", Tensor(param_dict["filter_maxq"]))
+        self.core_op.add_prim_attr("filter_minq", Tensor(param_dict["filter_minq"]))
+        self.core_op.add_prim_attr("output_maxq", Tensor(param_dict["output_maxq"]))
+        self.core_op.add_prim_attr("output_minq", Tensor(param_dict["output_minq"]))
+        self.core_op.add_prim_attr("symmetric", Tensor(param_dict["symmetric"]))
+        if hasattr(core_op, 'pad_mode'):
+            self.core_op.add_prim_attr("pad_mode", core_op.pad_mode)
+        self.core_op.add_prim_attr("num_bits", Tensor(8))
+        self.core_op.add_prim_attr("narrow_range", Tensor(False))
+        if param_dict["input_maxq"] is not None:
+            self.core_op.add_prim_attr("input_maxq", Tensor(param_dict["input_maxq"]))
+            self.core_op.add_prim_attr("input_minq", Tensor(param_dict["input_minq"]))
+        else:
+            self.core_op.add_prim_attr("mean", Tensor(param_dict["mean"]))
+            self.core_op.add_prim_attr("std_dev", Tensor(param_dict["std_dev"]))
+        self.weight = weight
+        self.bias = bias
+        self.has_bias = bias is not None
+        self.activation = activation
+        self.has_act = activation is not None
+        # ReLU is only recorded through the "activation_name" attr above (presumably fused into core_op by the backend).
+        if isinstance(activation, ReLU):
+            self.activation = None
+            self.has_act = False
+        self.bias_add = P.BiasAdd()
+
+    def construct(self, x):
+        if self.has_bias:
+            x = self.core_op(x, self.weight, self.bias)
+        else:
+            x = self.core_op(x, self.weight)
+        return x
+
+    def extend_repr(self):
+        str_info = f'core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
+        if self.has_bias:
+            str_info = str_info + f', bias=shape[{self.bias.shape}]'
+        if self.has_act:
+            str_info = str_info + f', activation={self.activation}'
+        return str_info
diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py
index c112f1e711..4499c59bdf 100644
--- a/mindspore/train/quant/quant.py
+++ b/mindspore/train/quant/quant.py
@@ -304,13 +304,14 @@ class ExportToQuantInferNetwork:
         inputs (Tensor): Input tensors of the `quantization aware training network`.
         mean (int): Input data mean. Default: 127.5.
         std_dev (int, float): Input data variance. Default: 127.5.
+        is_mindir (bool): Whether to export in MINDIR format. Default: False.
 
     Returns:
         Cell, Infer network.
     """
     __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
 
-    def __init__(self, network, mean, std_dev, *inputs):
+    def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
         network = validator.check_isinstance('network', network, (nn.Cell,))
         # quantize for inputs: q = f / scale + zero_point
         # dequantize for outputs: f = (q - zero_point) * scale
@@ -320,6 +321,9 @@ class ExportToQuantInferNetwork:
         self.network = copy.deepcopy(network)
         self.all_parameters = {p.name: p for p in self.network.get_parameters()}
         self.get_inputs_table(inputs)
+        self.mean = mean
+        self.std_dev = std_dev
+        self.is_mindir = is_mindir
 
     def get_inputs_table(self, inputs):
         """Get the support info for quant export."""
@@ -341,8 +345,25 @@ class ExportToQuantInferNetwork:
         # Calculate the scale and zero point
         w_minq_name = cell_core.fake_quant_weight.minq.name
         np_type = mstype.dtype_to_nptype(self.data_type)
-        scale_w, zp_w = quant_utils.scale_zp_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
-        scale_a_out, _ = quant_utils.scale_zp_from_fack_quant_cell(fake_quant_a_out, np_type)
+        # quantization info handed to `QuantMindirBlock`; the max/min entries are only filled for MINDIR export
+        param_dict = dict()
+        param_dict["filter_maxq"] = None
+        param_dict["filter_minq"] = None
+        param_dict["output_maxq"] = None
+        param_dict["output_minq"] = None
+        param_dict["input_maxq"] = None
+        param_dict["input_minq"] = None
+        param_dict["mean"] = self.mean
+        param_dict["std_dev"] = self.std_dev
+        param_dict["symmetric"] = fake_quant_a_out.symmetric
+        if self.is_mindir:
+            scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
+                quant_utils.scale_zp_max_min_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
+            scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \
+                quant_utils.scale_zp_max_min_from_fack_quant_cell(fake_quant_a_out, np_type)
+        else:
+            scale_w, zp_w = quant_utils.scale_zp_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
+            scale_a_out, _ = quant_utils.scale_zp_from_fack_quant_cell(fake_quant_a_out, np_type)
         info = self.quant_info_table.get(w_minq_name, None)
         if info:
             fack_quant_a_in_op, minq_name = info
@@ -351,7 +372,11 @@ class ExportToQuantInferNetwork:
             else:
                 maxq = self.all_parameters[minq_name[:-4] + "maxq"]
                 minq = self.all_parameters[minq_name]
-                scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, minq, maxq, np_type)
+                if self.is_mindir:
+                    scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
+                        quant_utils.scale_zp_max_min_from_data(fack_quant_a_in_op, minq, maxq, np_type)
+                else:
+                    scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, minq, maxq, np_type)
         else:
             logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
             return None
@@ -377,7 +402,9 @@ class ExportToQuantInferNetwork:
             weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
         elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant):
             weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core)
-
+        # keep the float weight and bias (before `weight2int`) for the MINDIR block
+        weight_b = weight
+        bias_b = bias
         # apply the quant
         weight = quant_utils.weight2int(weight, scale_w, zp_w)
         if bias is not None:
@@ -398,10 +425,17 @@ class ExportToQuantInferNetwork:
         if isinstance(cell_core, quant.DenseQuant):
             op_core = P.MatMul()
             weight = np.transpose(weight)
+            weight_b = np.transpose(weight_b)
         else:
             op_core = cell_core.conv
         weight = Tensor(weight, self.data_type)
-        block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
+        weight_b = Tensor(weight_b)
+        if bias_b is not None:
+            bias_b = Tensor(bias_b, mstype.float32)
+        if self.is_mindir:
+            block = quant.QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
+        else:
+            block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
         return block
 
     def _convert_quant2deploy(self, network):
@@ -475,8 +509,10 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='
         raise ValueError('Illegal file format {}.'.format(file_format))
 
     network.set_train(False)
-
-    exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
+    if file_format == "MINDIR":
+        exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
+    else:
+        exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
     deploy_net = exporter.run()
     serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
 
diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py
index e7120d35be..4146d8195e 100644
--- a/mindspore/train/quant/quant_utils.py
+++ b/mindspore/train/quant/quant_utils.py
@@ -146,6 +146,20 @@ def scale_zp_from_fack_quant_cell(cell, data_type):
     return scale, zp
 
 
+def scale_zp_max_min_from_fack_quant_cell(cell, data_type):
+    """Calculate the scale and zero point of `cell` and return them with its max and min as numpy arrays."""
+    minq = cell.minq.data.asnumpy()
+    maxq = cell.maxq.data.asnumpy()
+    op = cell.fake_quant_infer
+
+    scale, zp = cal_quantization_params(
+        minq, maxq, data_type,
+        num_bits=op.num_bits,
+        symmetric=op.symmetric,
+        narrow_range=op.narrow_range)
+    return scale, zp, maxq, minq
+
+
 def scale_zp_from_data(op, minq, maxq, data_type):
     r"""
     Get calculate quantization params for scale and zero point.
@@ -174,6 +188,19 @@ def scale_zp_from_data(op, minq, maxq, data_type):
     return scale, zp
 
 
+def scale_zp_max_min_from_data(op, minq, maxq, data_type):
+    """Calculate the scale and zero point from `minq`/`maxq` and return them with max and min as numpy arrays."""
+    minq = minq.data.asnumpy()
+    maxq = maxq.data.asnumpy()
+
+    scale, zp = cal_quantization_params(
+        minq, maxq, data_type,
+        num_bits=op.num_bits,
+        symmetric=op.symmetric,
+        narrow_range=op.narrow_range)
+    return scale, zp, maxq, minq
+
+
 def fold_batchnorm(weight, cell_quant):
     r"""
     Fold the batchnorm in `Conv2dBnFoldQuant` to weight.