| @@ -102,13 +102,13 @@ class BatchNormFoldCell(Cell): | |||||
| return batch_mean, batch_std, running_mean, running_std | return batch_mean, batch_std, running_mean, running_std | ||||
| class FakeQuantWithMinMaxAscend(Cell): | |||||
| class FakeQuantWithMinMax(Cell): | |||||
| r""" | r""" | ||||
| Aware Quantization op. This OP provide Fake quantization observer function on data with min and max. | Aware Quantization op. This OP provide Fake quantization observer function on data with min and max. | ||||
| Args: | Args: | ||||
| min_init (int, list): The dimension of channel or 1(layer). Default: -6. | |||||
| max_init (int, list): The dimension of channel or 1(layer). Default: 6. | |||||
| min_init (int, float): The dimension of channel or 1(layer). Default: -6. | |||||
| max_init (int, float): The dimension of channel or 1(layer). Default: 6. | |||||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | ||||
| ema (bool): Exponential Moving Average algorithm update min and max. Default: False. | ema (bool): Exponential Moving Average algorithm update min and max. Default: False. | ||||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | ||||
| @@ -143,32 +143,33 @@ class FakeQuantWithMinMaxAscend(Cell): | |||||
| symmetric=False, | symmetric=False, | ||||
| narrow_range=False, | narrow_range=False, | ||||
| training=True): | training=True): | ||||
| """init FakeQuantWithMinMaxAscend layer""" | |||||
| super(FakeQuantWithMinMaxAscend, self).__init__() | |||||
| """init FakeQuantWithMinMax layer""" | |||||
| super(FakeQuantWithMinMax, self).__init__() | |||||
| self.min_init = min_init | self.min_init = min_init | ||||
| self.max_init = max_init | self.max_init = max_init | ||||
| self.num_bits = num_bits | self.num_bits = num_bits | ||||
| self.ema = ema | self.ema = ema | ||||
| self.ema_decay = ema_decay | self.ema_decay = ema_decay | ||||
| self.per_channel = per_channel | self.per_channel = per_channel | ||||
| self.out_channels = out_channels | |||||
| self.channel_axis = channel_axis | self.channel_axis = channel_axis | ||||
| self.quant_delay = quant_delay | self.quant_delay = quant_delay | ||||
| self.symmetric = symmetric | self.symmetric = symmetric | ||||
| self.narrow_range = narrow_range | self.narrow_range = narrow_range | ||||
| self.training = training | self.training = training | ||||
| self.is_ascend = context.get_context('device_target') == "Ascend" | |||||
| # init tensor min and max for fake quant op | # init tensor min and max for fake quant op | ||||
| if isinstance(min_init, int): | |||||
| min_array = np.array([min_init]).reshape(1).astype(np.float32) | |||||
| max_array = np.array([max_init]).reshape(1).astype(np.float32) | |||||
| elif isinstance(min_init, list): | |||||
| min_array = np.array([self.min_init for i in range( | |||||
| 0, self.out_channels)]).astype(np.float32) | |||||
| max_array = np.array([self.max_init for i in range( | |||||
| 0, self.out_channels)]).astype(np.float32) | |||||
| if self.per_channel: | |||||
| min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32) | |||||
| max_array = np.array([self.max_init for i in range(0, self.out_channels)]).astype(np.float32) | |||||
| else: | |||||
| min_array = np.array([self.min_init]).reshape(1).astype(np.float32) | |||||
| max_array = np.array([self.max_init]).reshape(1).astype(np.float32) | |||||
| self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) | self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) | ||||
| self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) | self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) | ||||
| # init fake quant relative op | |||||
| if per_channel: | if per_channel: | ||||
| quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis) | quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis) | ||||
| ema_fun = partial(P.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis) | ema_fun = partial(P.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis) | ||||
| @@ -176,28 +177,36 @@ class FakeQuantWithMinMaxAscend(Cell): | |||||
| quant_fun = P.FakeQuantPerLayer | quant_fun = P.FakeQuantPerLayer | ||||
| ema_fun = P.FakeQuantMinMaxPerLayerUpdate | ema_fun = P.FakeQuantMinMaxPerLayerUpdate | ||||
| self.fake_quant = quant_fun(num_bits=self.num_bits, | |||||
| ema=self.ema, | |||||
| ema_decay=self.ema_decay, | |||||
| quant_delay=self.quant_delay, | |||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range, | |||||
| training=self.training) | |||||
| self.ema_update = ema_fun(num_bits=self.num_bits, | |||||
| ema=self.ema, | |||||
| ema_decay=self.ema_decay, | |||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range, | |||||
| training=self.training) | |||||
| if self.is_ascend: | |||||
| self.fake_quant = quant_fun(num_bits=self.num_bits, | |||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range, | |||||
| training=self.training) | |||||
| else: | |||||
| self.fake_quant = quant_fun(num_bits=self.num_bits, | |||||
| ema=self.ema, | |||||
| ema_decay=ema_decay, | |||||
| quant_delay=quant_delay, | |||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range, | |||||
| training=self.training) | |||||
| if self.ema: | |||||
| self.ema_update = ema_fun(num_bits=self.num_bits, | |||||
| ema=self.ema, | |||||
| ema_decay=self.ema_decay, | |||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range, | |||||
| training=self.training) | |||||
| def extend_repr(self): | def extend_repr(self): | ||||
| s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format( | |||||
| self.min_init, self.max_init, self.ema, self.ema_decay, | |||||
| self.per_channel, self.quant_delay, self.channel_axis) | |||||
| s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \ | |||||
| 'quant_delay={}, min_init={}, max_init={}'.format( | |||||
| self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel, | |||||
| self.channel_axis, self.out_channels, self.quant_delay, self.min_init, self.max_init) | |||||
| return s | return s | ||||
| def construct(self, x): | def construct(self, x): | ||||
| if self.update: | |||||
| if self.ema and self.is_ascend: | |||||
| min_up, max_up = self.ema_update(x, self.minq, self.maxq) | min_up, max_up = self.ema_update(x, self.minq, self.maxq) | ||||
| out = self.fake_quant(x, min_up, max_up) | out = self.fake_quant(x, min_up, max_up) | ||||
| P.Assign()(self.minq, min_up) | P.Assign()(self.minq, min_up) | ||||
| @@ -207,104 +216,6 @@ class FakeQuantWithMinMaxAscend(Cell): | |||||
| return out | return out | ||||
| class FakeQuantWithMinMaxGPU(Cell): | |||||
| r""" | |||||
| Aware Quantization op. This OP provide Fake quantization observer function on data with min and max. | |||||
| Args: | |||||
| min_init (int, list): The dimension of channel or 1(layer). Default: -6. | |||||
| max_init (int, list): The dimension of channel or 1(layer). Default: 6. | |||||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||||
| ema (bool): Exponential Moving Average algorithm update min and max. Default: False. | |||||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | |||||
| per_channel (bool): Quantization by layer or channel. Default: False. | |||||
| out_channels (int): declarate the min and max channel size, Default: 1. | |||||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||||
| Inputs: | |||||
| - **x** (Tensor) - The input of FakeQuantWithMinMax. | |||||
| Outputs: | |||||
| Tensor, with the same type and shape as the `x`. | |||||
| Examples: | |||||
| >>> fake_quant = FakeQuantWithMinMax() | |||||
| >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) | |||||
| >>> result = fake_quant(input_x) | |||||
| """ | |||||
| def __init__(self, | |||||
| min_init=-6, | |||||
| max_init=6, | |||||
| num_bits=8, | |||||
| ema=False, | |||||
| ema_decay=0.999, | |||||
| per_channel=False, | |||||
| channel_axis=1, | |||||
| out_channels=1, | |||||
| quant_delay=0, | |||||
| symmetric=False, | |||||
| narrow_range=False, | |||||
| training=True): | |||||
| super(FakeQuantWithMinMaxGPU, self).__init__() | |||||
| self.min_init = min_init | |||||
| self.max_init = max_init | |||||
| self.num_bits = num_bits | |||||
| self.ema = ema | |||||
| self.ema_decay = ema_decay | |||||
| self.per_channel = per_channel | |||||
| self.channel_axis = channel_axis | |||||
| self.quant_delay = quant_delay | |||||
| self.symmetric = symmetric | |||||
| self.narrow_range = narrow_range | |||||
| self.training = training | |||||
| # init tensor min and max for fake quant op | |||||
| if isinstance(min_init, int): | |||||
| min_array = np.array([min_init]).reshape(1).astype(np.float32) | |||||
| max_array = np.array([max_init]).reshape(1).astype(np.float32) | |||||
| elif isinstance(min_init, list): | |||||
| min_array = np.array([self.min_init for i in range( | |||||
| 0, self.out_channels)]).astype(np.float32) | |||||
| max_array = np.array([self.max_init for i in range( | |||||
| 0, self.out_channels)]).astype(np.float32) | |||||
| self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) | |||||
| self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) | |||||
| if per_channel: | |||||
| quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis) | |||||
| else: | |||||
| quant_fun = P.FakeQuantPerLayer | |||||
| self.fake_quant = quant_fun(num_bits=self.num_bits, | |||||
| ema=self.ema, | |||||
| ema_decay=ema_decay, | |||||
| quant_delay=quant_delay, | |||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range, | |||||
| training=self.training) | |||||
| def extend_repr(self): | |||||
| s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format( | |||||
| self.min_init, self.max_init, self.ema, self.ema_decay, | |||||
| self.per_channel, self.quant_delay, self.channel_axis) | |||||
| return s | |||||
| def construct(self, x): | |||||
| out = self.fake_quant(x, self.minq, self.maxq) | |||||
| return out | |||||
| def FakeQuantWithMinMax(**kwargs): | |||||
| if context.get_context('device_target') == "Ascend": | |||||
| out = FakeQuantWithMinMaxAscend(**kwargs) | |||||
| if context.get_context('device_target') == "GPU": | |||||
| out = FakeQuantWithMinMaxGPU(**kwargs) | |||||
| else: | |||||
| raise ValueError("Not support platform or channel mode.") | |||||
| return out | |||||
| class Conv2dBatchNormQuant(Cell): | class Conv2dBatchNormQuant(Cell): | ||||
| r""" | r""" | ||||
| 2D convolution with BatchNormal op folded layer. | 2D convolution with BatchNormal op folded layer. | ||||
| @@ -92,6 +92,7 @@ def fake_quant_perchannel(x, min_val, max_val, y, | |||||
| kernel_name="fake_quant_perchannel"): | kernel_name="fake_quant_perchannel"): | ||||
| """FakeQuantPerChannel""" | """FakeQuantPerChannel""" | ||||
| x_shape = x.get("shape") | x_shape = x.get("shape") | ||||
| x_shape_ = x.get("ori_shape") | |||||
| x_format = x.get("format") | x_format = x.get("format") | ||||
| x_dtype = x.get("dtype") | x_dtype = x.get("dtype") | ||||
| min_shape = min_val.get("ori_shape") | min_shape = min_val.get("ori_shape") | ||||
| @@ -101,8 +102,8 @@ def fake_quant_perchannel(x, min_val, max_val, y, | |||||
| util.check_kernel_name(kernel_name) | util.check_kernel_name(kernel_name) | ||||
| util.check_shape_rule(x_shape) | util.check_shape_rule(x_shape) | ||||
| util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis]) | |||||
| util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis]) | |||||
| util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis]) | |||||
| util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis]) | |||||
| util.check_tensor_shape_size(x_shape) | util.check_tensor_shape_size(x_shape) | ||||
| util.check_tensor_shape_size(min_shape) | util.check_tensor_shape_size(min_shape) | ||||
| util.check_tensor_shape_size(max_shape) | util.check_tensor_shape_size(max_shape) | ||||
| @@ -117,6 +117,7 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx, | |||||
| kernel_name="fake_quant_perchannel_grad"): | kernel_name="fake_quant_perchannel_grad"): | ||||
| """FakeQuantPerChannelGrad""" | """FakeQuantPerChannelGrad""" | ||||
| x_shape = x.get("shape") | x_shape = x.get("shape") | ||||
| x_shape_ = x.get("ori_shape") | |||||
| x_format = x.get("format") | x_format = x.get("format") | ||||
| x_dtype = x.get("dtype") | x_dtype = x.get("dtype") | ||||
| min_shape = min_val.get("ori_shape") | min_shape = min_val.get("ori_shape") | ||||
| @@ -126,8 +127,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx, | |||||
| util.check_kernel_name(kernel_name) | util.check_kernel_name(kernel_name) | ||||
| util.check_shape_rule(x_shape) | util.check_shape_rule(x_shape) | ||||
| util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis]) | |||||
| util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis]) | |||||
| util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis]) | |||||
| util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis]) | |||||
| util.check_tensor_shape_size(x_shape) | util.check_tensor_shape_size(x_shape) | ||||
| util.check_tensor_shape_size(min_shape) | util.check_tensor_shape_size(min_shape) | ||||
| util.check_tensor_shape_size(max_shape) | util.check_tensor_shape_size(max_shape) | ||||
| @@ -43,7 +43,7 @@ fake_quant_per_layer_op_info = TBERegOp("FakeQuantPerLayer") \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @op_info_register(fake_quant_op_info) | |||||
| @op_info_register(fake_quant_per_layer_op_info) | |||||
| def _fake_quant_per_layer_tbe(): | def _fake_quant_per_layer_tbe(): | ||||
| """FakeQuantPerLayer TBE register""" | """FakeQuantPerLayer TBE register""" | ||||
| return | return | ||||
| @@ -110,10 +110,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer): | |||||
| def infer_shape(self, x_shape, min_shape, max_shape): | def infer_shape(self, x_shape, min_shape, max_shape): | ||||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | ||||
| validator.check("min shape", min_shape, "max shape", | |||||
| max_shape, Rel.EQ, self.name) | |||||
| validator.check_integer("min rank", len( | |||||
| min_shape), 1, Rel.EQ, self.name) | |||||
| validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) | |||||
| validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name) | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type, min_type, max_type): | def infer_dtype(self, x_type, min_type, max_type): | ||||
| @@ -168,7 +166,7 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer): | |||||
| x_shape, Rel.EQ, self.name) | x_shape, Rel.EQ, self.name) | ||||
| validator.check("min shape", min_shape, "max shape", | validator.check("min shape", min_shape, "max shape", | ||||
| max_shape, Rel.EQ, self.name) | max_shape, Rel.EQ, self.name) | ||||
| validator.check_integer("min rank", len( | |||||
| validator.check_integer("min shape", len( | |||||
| min_shape), 1, Rel.EQ, self.name) | min_shape), 1, Rel.EQ, self.name) | ||||
| return dout_shape | return dout_shape | ||||
| @@ -255,10 +253,11 @@ class FakeQuantPerChannel(PrimitiveWithInfer): | |||||
| def infer_shape(self, x_shape, min_shape, max_shape): | def infer_shape(self, x_shape, min_shape, max_shape): | ||||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | ||||
| validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) | |||||
| validator.check_integer( | validator.check_integer( | ||||
| "min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) | |||||
| "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) | |||||
| validator.check_integer( | validator.check_integer( | ||||
| "max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) | |||||
| "max shape", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type, min_type, max_type): | def infer_dtype(self, x_type, min_type, max_type): | ||||
| @@ -379,7 +378,7 @@ class BatchNormFold(PrimitiveWithInfer): | |||||
| def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape): | def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape): | ||||
| validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name) | validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name) | ||||
| validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name) | validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name) | ||||
| validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) | |||||
| validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name) | |||||
| return mean_shape, mean_shape, mean_shape, mean_shape | return mean_shape, mean_shape, mean_shape, mean_shape | ||||
| def infer_dtype(self, x_type, mean_type, variance_type, global_step_type): | def infer_dtype(self, x_type, mean_type, variance_type, global_step_type): | ||||
| @@ -426,7 +425,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer): | |||||
| "batch_std shape", batch_std_shape, Rel.EQ, self.name) | "batch_std shape", batch_std_shape, Rel.EQ, self.name) | ||||
| validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], | validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], | ||||
| "input channel", x_shape[self.channel_axis], Rel.EQ, self.name) | "input channel", x_shape[self.channel_axis], Rel.EQ, self.name) | ||||
| validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) | |||||
| validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name) | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type, | def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type, | ||||
| @@ -568,7 +567,7 @@ class BatchNormFold2(PrimitiveWithInfer): | |||||
| validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name) | validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name) | ||||
| validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis], | validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis], | ||||
| Rel.EQ, self.name) | Rel.EQ, self.name) | ||||
| validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) | |||||
| validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name) | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type, | def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type, | ||||
| @@ -616,7 +615,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer): | |||||
| validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name) | validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name) | ||||
| validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis], | validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis], | ||||
| Rel.EQ, self.name) | Rel.EQ, self.name) | ||||
| validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) | |||||
| validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name) | |||||
| return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape | return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape | ||||
| def infer_dtype(self, dout_type, x_type, gamma_type, | def infer_dtype(self, dout_type, x_type, gamma_type, | ||||
| @@ -887,7 +886,7 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer): | |||||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | ||||
| validator.check("min shape", min_shape, "max shape", | validator.check("min shape", min_shape, "max shape", | ||||
| max_shape, Rel.EQ, self.name) | max_shape, Rel.EQ, self.name) | ||||
| validator.check_integer("min rank", len( | |||||
| validator.check_integer("min shape", len( | |||||
| min_shape), 1, Rel.EQ, self.name) | min_shape), 1, Rel.EQ, self.name) | ||||
| return min_shape, max_shape | return min_shape, max_shape | ||||
| @@ -963,7 +962,7 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer): | |||||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) | validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) | ||||
| validator.check("min shape", min_shape, "max shape", | validator.check("min shape", min_shape, "max shape", | ||||
| max_shape, Rel.EQ, self.name) | max_shape, Rel.EQ, self.name) | ||||
| validator.check_integer("min rank", len( | |||||
| validator.check_integer("min shape", len( | |||||
| min_shape), 1, Rel.EQ, self.name) | min_shape), 1, Rel.EQ, self.name) | ||||
| return min_shape, max_shape | return min_shape, max_shape | ||||