| @@ -29,7 +29,7 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor | |||||
| // To-DO the format may read from ME tensor | // To-DO the format may read from ME tensor | ||||
| MS_EXCEPTION_IF_NULL(value); | MS_EXCEPTION_IF_NULL(value); | ||||
| auto me_tensor = value->cast<MeTensorPtr>(); | auto me_tensor = value->cast<MeTensorPtr>(); | ||||
| auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_NCHW); | |||||
| auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_ND); | |||||
| return ge_tensor == nullptr ? GeTensor() : *ge_tensor; | return ge_tensor == nullptr ? GeTensor() : *ge_tensor; | ||||
| } | } | ||||
| @@ -388,7 +388,7 @@ class _Executor: | |||||
| dic = dict(zip(args_names, args_list)) | dic = dict(zip(args_names, args_list)) | ||||
| key = generate_key(phase, dic) | key = generate_key(phase, dic) | ||||
| self.phase_prefix = str(key[1]) | self.phase_prefix = str(key[1]) | ||||
| if phase == 'export': | |||||
| if 'export' in phase: | |||||
| phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time) | phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time) | ||||
| else: | else: | ||||
| phase = self.phase_prefix + phase + '.' + str(obj.create_time) | phase = self.phase_prefix + phase + '.' + str(obj.create_time) | ||||
| @@ -332,6 +332,7 @@ class Quant(PrimitiveWithInfer): | |||||
| self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) | self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) | ||||
| self.round_mode = validator.check_string("round_mode", round_mode, | self.round_mode = validator.check_string("round_mode", round_mode, | ||||
| ["Round", "Floor", "Ceil", "Trunc"], self.name) | ["Round", "Floor", "Ceil", "Trunc"], self.name) | ||||
| self.add_prim_attr("io_format", "ND") | |||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| return x_shape | return x_shape | ||||
| @@ -382,6 +383,7 @@ class Dequant(PrimitiveWithInfer): | |||||
| self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) | self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) | ||||
| self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name) | self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name) | ||||
| self.add_prim_attr("dtype", mstype.float16) | self.add_prim_attr("dtype", mstype.float16) | ||||
| self.add_prim_attr("io_format", "ND") | |||||
| def infer_shape(self, x_shape, deq_scale_shape): | def infer_shape(self, x_shape, deq_scale_shape): | ||||
| return x_shape | return x_shape | ||||
| @@ -258,6 +258,7 @@ class _Reduce(PrimitiveWithInfer): | |||||
| """init Reduce""" | """init Reduce""" | ||||
| validator.check_value_type('keep_dims', keep_dims, [bool], self.name) | validator.check_value_type('keep_dims', keep_dims, [bool], self.name) | ||||
| self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) | self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) | ||||
| self.add_prim_attr("io_format", "ND") | |||||
| def __call__(self, x, axis=()): | def __call__(self, x, axis=()): | ||||
| args = [x, axis] | args = [x, axis] | ||||
| @@ -626,6 +627,7 @@ class MatMul(PrimitiveWithInfer): | |||||
| cls_name = self.name | cls_name = self.name | ||||
| validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | ||||
| validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | ||||
| self.add_prim_attr("io_format", "ND") | |||||
| def check_shape_size(self, x, y): | def check_shape_size(self, x, y): | ||||
| if len(x) != 2 or len(y) != 2: | if len(x) != 2 or len(y) != 2: | ||||
| @@ -314,8 +314,8 @@ class ExportToQuantInferNetwork: | |||||
| network = validator.check_isinstance('network', network, (nn.Cell,)) | network = validator.check_isinstance('network', network, (nn.Cell,)) | ||||
| # quantize for inputs: q = f / scale + zero_point | # quantize for inputs: q = f / scale + zero_point | ||||
| # dequantize for outputs: f = (q - zero_point) * scale | # dequantize for outputs: f = (q - zero_point) * scale | ||||
| self.input_scale = round(mean) | |||||
| self.input_zero_point = 1 / std_dev | |||||
| self.input_scale = 1 / std_dev | |||||
| self.input_zero_point = round(mean) | |||||
| self.data_type = mstype.int8 | self.data_type = mstype.int8 | ||||
| self.network = copy.deepcopy(network) | self.network = copy.deepcopy(network) | ||||
| self.all_parameters = {p.name: p for p in self.network.get_parameters()} | self.all_parameters = {p.name: p for p in self.network.get_parameters()} | ||||
| @@ -351,20 +351,16 @@ class ExportToQuantInferNetwork: | |||||
| else: | else: | ||||
| maxq = self.all_parameters[minq_name[:-4] + "maxq"] | maxq = self.all_parameters[minq_name[:-4] + "maxq"] | ||||
| minq = self.all_parameters[minq_name] | minq = self.all_parameters[minq_name] | ||||
| scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type) | |||||
| scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, minq, maxq, np_type) | |||||
| else: | else: | ||||
| logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}") | logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}") | ||||
| return None | return None | ||||
| # Build the `Quant` `Dequant` op. | # Build the `Quant` `Dequant` op. | ||||
| # Quant only support perlayer version. Need check here. | # Quant only support perlayer version. Need check here. | ||||
| quant_op = inner.Quant(float(scale_a_in), float(zp_a_in)) | |||||
| sqrt_mode = False | |||||
| quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in)) | |||||
| scale_deq = scale_a_out * scale_w | scale_deq = scale_a_out * scale_w | ||||
| if (scale_deq < 2 ** -14).all(): | |||||
| scale_deq = np.sqrt(scale_deq) | |||||
| sqrt_mode = True | |||||
| dequant_op = inner.Dequant(sqrt_mode) | |||||
| dequant_op = inner.Dequant() | |||||
| if isinstance(activation, _AddFakeQuantAfterSubCell): | if isinstance(activation, _AddFakeQuantAfterSubCell): | ||||
| activation = activation.subcell | activation = activation.subcell | ||||
| @@ -385,8 +381,19 @@ class ExportToQuantInferNetwork: | |||||
| # apply the quant | # apply the quant | ||||
| weight = quant_utils.weight2int(weight, scale_w, zp_w) | weight = quant_utils.weight2int(weight, scale_w, zp_w) | ||||
| if bias is not None: | if bias is not None: | ||||
| bias = Tensor(scale_a_in * scale_w * bias, mstype.int32) | |||||
| scale_deq = Tensor(scale_deq, mstype.float16) | |||||
| bias = Tensor(bias / scale_a_in / scale_w, mstype.int32) | |||||
| # fuse parameter | |||||
| # |--------|47:40|--------|39:32|--------|31:0| | |||||
| # offset_w [8] shift_N [8] deq_scale [32] | |||||
| float32_deq_scale = scale_deq.astype(np.float32) | |||||
| uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32) | |||||
| scale_length = scale_deq.size # channel | |||||
| dequant_param = np.zeros(scale_length, dtype=np.uint64) | |||||
| for index in range(scale_length): | |||||
| dequant_param[index] += uint32_deq_scale[index] | |||||
| scale_deq = Tensor(dequant_param, mstype.uint64) | |||||
| # get op | # get op | ||||
| if isinstance(cell_core, quant.DenseQuant): | if isinstance(cell_core, quant.DenseQuant): | ||||
| op_core = P.MatMul() | op_core = P.MatMul() | ||||