
bug fix in fake quant

tags/v0.5.0-beta
chenzomi committed 5 years ago
parent commit 4da1e21f45
5 changed files with 58 additions and 146 deletions

  1. mindspore/nn/layer/quant.py (+39, -128)
  2. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py (+3, -2)
  3. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py (+3, -2)
  4. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py (+1, -1)
  5. mindspore/ops/operations/_quant_ops.py (+12, -13)

mindspore/nn/layer/quant.py (+39, -128)

@@ -102,13 +102,13 @@ class BatchNormFoldCell(Cell):
         return batch_mean, batch_std, running_mean, running_std


-class FakeQuantWithMinMaxAscend(Cell):
+class FakeQuantWithMinMax(Cell):
     r"""
     Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.

     Args:
-        min_init (int, list): The dimension of channel or 1(layer). Default: -6.
-        max_init (int, list): The dimension of channel or 1(layer). Default: 6.
+        min_init (int, float): The dimension of channel or 1(layer). Default: -6.
+        max_init (int, float): The dimension of channel or 1(layer). Default: 6.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
         ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
@@ -143,32 +143,33 @@ class FakeQuantWithMinMaxAscend(Cell):
                  symmetric=False,
                  narrow_range=False,
                  training=True):
-        """init FakeQuantWithMinMaxAscend layer"""
-        super(FakeQuantWithMinMaxAscend, self).__init__()
+        """init FakeQuantWithMinMax layer"""
+        super(FakeQuantWithMinMax, self).__init__()
         self.min_init = min_init
         self.max_init = max_init
         self.num_bits = num_bits
         self.ema = ema
         self.ema_decay = ema_decay
         self.per_channel = per_channel
+        self.out_channels = out_channels
         self.channel_axis = channel_axis
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range
         self.training = training
+        self.is_ascend = context.get_context('device_target') == "Ascend"

         # init tensor min and max for fake quant op
-        if isinstance(min_init, int):
-            min_array = np.array([min_init]).reshape(1).astype(np.float32)
-            max_array = np.array([max_init]).reshape(1).astype(np.float32)
-        elif isinstance(min_init, list):
-            min_array = np.array([self.min_init for i in range(
-                0, self.out_channels)]).astype(np.float32)
-            max_array = np.array([self.max_init for i in range(
-                0, self.out_channels)]).astype(np.float32)
+        if self.per_channel:
+            min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
+            max_array = np.array([self.max_init for i in range(0, self.out_channels)]).astype(np.float32)
+        else:
+            min_array = np.array([self.min_init]).reshape(1).astype(np.float32)
+            max_array = np.array([self.max_init]).reshape(1).astype(np.float32)
         self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
         self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)

+        # init fake quant relative op
         if per_channel:
             quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
             ema_fun = partial(P.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis)
@@ -176,28 +177,36 @@ class FakeQuantWithMinMaxAscend(Cell):
             quant_fun = P.FakeQuantPerLayer
             ema_fun = P.FakeQuantMinMaxPerLayerUpdate

-        self.fake_quant = quant_fun(num_bits=self.num_bits,
-                                    ema=self.ema,
-                                    ema_decay=self.ema_decay,
-                                    quant_delay=self.quant_delay,
-                                    symmetric=self.symmetric,
-                                    narrow_range=self.narrow_range,
-                                    training=self.training)
-        self.ema_update = ema_fun(num_bits=self.num_bits,
-                                  ema=self.ema,
-                                  ema_decay=self.ema_decay,
-                                  symmetric=self.symmetric,
-                                  narrow_range=self.narrow_range,
-                                  training=self.training)
+        if self.is_ascend:
+            self.fake_quant = quant_fun(num_bits=self.num_bits,
+                                        symmetric=self.symmetric,
+                                        narrow_range=self.narrow_range,
+                                        training=self.training)
+        else:
+            self.fake_quant = quant_fun(num_bits=self.num_bits,
+                                        ema=self.ema,
+                                        ema_decay=ema_decay,
+                                        quant_delay=quant_delay,
+                                        symmetric=self.symmetric,
+                                        narrow_range=self.narrow_range,
+                                        training=self.training)
+        if self.ema:
+            self.ema_update = ema_fun(num_bits=self.num_bits,
+                                      ema=self.ema,
+                                      ema_decay=self.ema_decay,
+                                      symmetric=self.symmetric,
+                                      narrow_range=self.narrow_range,
+                                      training=self.training)

     def extend_repr(self):
-        s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
-            self.min_init, self.max_init, self.ema, self.ema_decay,
-            self.per_channel, self.quant_delay, self.channel_axis)
+        s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
+            'quant_delay={}, min_init={}, max_init={}'.format(
+                self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel,
+                self.channel_axis, self.out_channels, self.quant_delay, self.min_init, self.max_init)
         return s

     def construct(self, x):
-        if self.update:
+        if self.ema and self.is_ascend:
             min_up, max_up = self.ema_update(x, self.minq, self.maxq)
             out = self.fake_quant(x, min_up, max_up)
             P.Assign()(self.minq, min_up)
@@ -207,104 +216,6 @@ class FakeQuantWithMinMaxAscend(Cell):
         return out


-class FakeQuantWithMinMaxGPU(Cell):
-    r"""
-    Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
-
-    Args:
-        min_init (int, list): The dimension of channel or 1(layer). Default: -6.
-        max_init (int, list): The dimension of channel or 1(layer). Default: 6.
-        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
-        ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
-        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
-        per_channel (bool): Quantization by layer or channel. Default: False.
-        out_channels (int): declarate the min and max channel size, Default: 1.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
-        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
-        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
-
-    Inputs:
-        - **x** (Tensor) - The input of FakeQuantWithMinMax.
-
-    Outputs:
-        Tensor, with the same type and shape as the `x`.
-
-    Examples:
-        >>> fake_quant = FakeQuantWithMinMax()
-        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
-        >>> result = fake_quant(input_x)
-    """
-
-    def __init__(self,
-                 min_init=-6,
-                 max_init=6,
-                 num_bits=8,
-                 ema=False,
-                 ema_decay=0.999,
-                 per_channel=False,
-                 channel_axis=1,
-                 out_channels=1,
-                 quant_delay=0,
-                 symmetric=False,
-                 narrow_range=False,
-                 training=True):
-        super(FakeQuantWithMinMaxGPU, self).__init__()
-        self.min_init = min_init
-        self.max_init = max_init
-        self.num_bits = num_bits
-        self.ema = ema
-        self.ema_decay = ema_decay
-        self.per_channel = per_channel
-        self.channel_axis = channel_axis
-        self.quant_delay = quant_delay
-        self.symmetric = symmetric
-        self.narrow_range = narrow_range
-        self.training = training
-
-        # init tensor min and max for fake quant op
-        if isinstance(min_init, int):
-            min_array = np.array([min_init]).reshape(1).astype(np.float32)
-            max_array = np.array([max_init]).reshape(1).astype(np.float32)
-        elif isinstance(min_init, list):
-            min_array = np.array([self.min_init for i in range(
-                0, self.out_channels)]).astype(np.float32)
-            max_array = np.array([self.max_init for i in range(
-                0, self.out_channels)]).astype(np.float32)
-        self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
-        self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
-
-        if per_channel:
-            quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
-        else:
-            quant_fun = P.FakeQuantPerLayer
-        self.fake_quant = quant_fun(num_bits=self.num_bits,
-                                    ema=self.ema,
-                                    ema_decay=ema_decay,
-                                    quant_delay=quant_delay,
-                                    symmetric=self.symmetric,
-                                    narrow_range=self.narrow_range,
-                                    training=self.training)
-
-    def extend_repr(self):
-        s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
-            self.min_init, self.max_init, self.ema, self.ema_decay,
-            self.per_channel, self.quant_delay, self.channel_axis)
-        return s
-
-    def construct(self, x):
-        out = self.fake_quant(x, self.minq, self.maxq)
-        return out
-
-
-def FakeQuantWithMinMax(**kwargs):
-    if context.get_context('device_target') == "Ascend":
-        out = FakeQuantWithMinMaxAscend(**kwargs)
-    if context.get_context('device_target') == "GPU":
-        out = FakeQuantWithMinMaxGPU(**kwargs)
-    else:
-        raise ValueError("Not support platform or channel mode.")
-    return out
-
 class Conv2dBatchNormQuant(Cell):
     r"""
     2D convolution with BatchNormal op folded layer.
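
The diff above folds FakeQuantWithMinMaxAscend and FakeQuantWithMinMaxGPU into a single FakeQuantWithMinMax cell that branches on the device target at construction time, builds the EMA update op only when ema is enabled, and sizes the min/max parameters from per_channel and out_channels rather than from the Python type of min_init. The following plain-Python sketch mirrors that construction logic outside of MindSpore; build_min_max and select_op_args are illustrative helpers, not functions from this commit.

import numpy as np

def build_min_max(min_init, max_init, per_channel, out_channels):
    # The bug fix: size the min/max buffers from the quantization mode,
    # not from whether min_init happens to be an int or a list.
    if per_channel:
        min_array = np.full(out_channels, min_init, dtype=np.float32)
        max_array = np.full(out_channels, max_init, dtype=np.float32)
    else:
        min_array = np.array([min_init], dtype=np.float32)
        max_array = np.array([max_init], dtype=np.float32)
    return min_array, max_array

def select_op_args(device_target, ema, num_bits, ema_decay, quant_delay,
                   symmetric, narrow_range, training):
    # Ascend fake-quant kernels take a reduced argument set; other backends
    # receive the full set, mirroring the `if self.is_ascend` branch above.
    if device_target == "Ascend":
        fake_quant_args = dict(num_bits=num_bits, symmetric=symmetric,
                               narrow_range=narrow_range, training=training)
    else:
        fake_quant_args = dict(num_bits=num_bits, ema=ema, ema_decay=ema_decay,
                               quant_delay=quant_delay, symmetric=symmetric,
                               narrow_range=narrow_range, training=training)
    # The EMA min/max update op is only constructed when ema=True.
    ema_update_args = dict(num_bits=num_bits, ema=ema, ema_decay=ema_decay,
                           symmetric=symmetric, narrow_range=narrow_range,
                           training=training) if ema else None
    return fake_quant_args, ema_update_args

mins, maxs = build_min_max(-6, 6, per_channel=True, out_channels=12)
fq_args, ema_args = select_op_args("Ascend", ema=True, num_bits=8, ema_decay=0.999,
                                   quant_delay=0, symmetric=False, narrow_range=False,
                                   training=True)
print(mins.shape, maxs.shape)            # (12,) (12,)
print(sorted(fq_args), ema_args is not None)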


mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py (+3, -2)

@@ -92,6 +92,7 @@ def fake_quant_perchannel(x, min_val, max_val, y,
                           kernel_name="fake_quant_perchannel"):
     """FakeQuantPerChannel"""
     x_shape = x.get("shape")
+    x_shape_ = x.get("ori_shape")
     x_format = x.get("format")
     x_dtype = x.get("dtype")
     min_shape = min_val.get("ori_shape")
@@ -101,8 +102,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,

     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
+    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
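
The only functional change here (and the matching one in the grad kernel below) is validating the per-channel min/max lengths against ori_shape instead of shape. A minimal sketch of that invariant, assuming the device-side shape can be a padded 5D layout (e.g. NC1HWC0) while ori_shape keeps the user-facing NCHW shape; the dictionaries are stand-ins for the TBE kernel's input descriptors, not real kernel arguments.

def check_min_max_shape(x, min_val, max_val, channel_axis=1):
    # Read the channel count from the original (user-facing) shape.
    expected = x["ori_shape"][channel_axis]
    for name, tensor in (("min_val", min_val), ("max_val", max_val)):
        if list(tensor["ori_shape"]) != [expected]:
            raise ValueError(f"{name} must have shape [{expected}], got {tensor['ori_shape']}")

# Device layout is 5D while ori_shape is the 4D NCHW shape the user authored.
x = {"shape": (2, 1, 4, 4, 16), "ori_shape": (2, 12, 4, 4)}
min_val = {"ori_shape": (12,)}
max_val = {"ori_shape": (12,)}
check_min_max_shape(x, min_val, max_val)   # passes: 12 channels on axis 1 of ori_shape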


mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py (+3, -2)

@@ -117,6 +117,7 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
                                kernel_name="fake_quant_perchannel_grad"):
     """FakeQuantPerChannelGrad"""
     x_shape = x.get("shape")
+    x_shape_ = x.get("ori_shape")
     x_format = x.get("format")
     x_dtype = x.get("dtype")
     min_shape = min_val.get("ori_shape")
@@ -126,8 +127,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,

     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
+    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)


mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py (+1, -1)

@@ -43,7 +43,7 @@ fake_quant_per_layer_op_info = TBERegOp("FakeQuantPerLayer") \
     .get_op_info()


-@op_info_register(fake_quant_op_info)
+@op_info_register(fake_quant_per_layer_op_info)
 def _fake_quant_per_layer_tbe():
     """FakeQuantPerLayer TBE register"""
     return
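
This file only fixes the registration decorator: it must be passed the op-info object that was actually built above (fake_quant_per_layer_op_info), not a stale name. A toy registry, with made-up names, shows why the decorator argument matters; it is not the MindSpore registration machinery.

OP_REGISTRY = {}

def op_info_register(op_info):
    # Register the decorated kernel under the op name carried by op_info.
    def wrap(func):
        OP_REGISTRY[op_info["op_name"]] = func
        return func
    return wrap

fake_quant_per_layer_op_info = {"op_name": "FakeQuantPerLayer"}

@op_info_register(fake_quant_per_layer_op_info)   # passing an undefined name such as
def _fake_quant_per_layer_tbe():                  # `fake_quant_op_info` here would fail
    """FakeQuantPerLayer TBE register"""          # at import time with a NameError
    return

print("FakeQuantPerLayer" in OP_REGISTRY)   # True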


mindspore/ops/operations/_quant_ops.py (+12, -13)

@@ -110,10 +110,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):

     def infer_shape(self, x_shape, min_shape, max_shape):
         validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
-        validator.check("min shape", min_shape, "max shape",
-                        max_shape, Rel.EQ, self.name)
-        validator.check_integer("min rank", len(
-            min_shape), 1, Rel.EQ, self.name)
+        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
+        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
         return x_shape

     def infer_dtype(self, x_type, min_type, max_type):
@@ -168,7 +166,7 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
                         x_shape, Rel.EQ, self.name)
         validator.check("min shape", min_shape, "max shape",
                         max_shape, Rel.EQ, self.name)
-        validator.check_integer("min rank", len(
+        validator.check_integer("min shape", len(
             min_shape), 1, Rel.EQ, self.name)
         return dout_shape


@@ -255,10 +253,11 @@ class FakeQuantPerChannel(PrimitiveWithInfer):

     def infer_shape(self, x_shape, min_shape, max_shape):
         validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
         validator.check_integer(
-            "min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
+            "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
         validator.check_integer(
-            "max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
+            "max shape", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
         return x_shape

     def infer_dtype(self, x_type, min_type, max_type):
@@ -379,7 +378,7 @@ class BatchNormFold(PrimitiveWithInfer):
     def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
         validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
         validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
-        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
+        validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
         return mean_shape, mean_shape, mean_shape, mean_shape

     def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
@@ -426,7 +425,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
                         "batch_std shape", batch_std_shape, Rel.EQ, self.name)
         validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
                         "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
-        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
+        validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
         return x_shape

     def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
@@ -568,7 +567,7 @@ class BatchNormFold2(PrimitiveWithInfer):
         validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
         validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
                         Rel.EQ, self.name)
-        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
+        validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
         return x_shape

     def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
@@ -616,7 +615,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
         validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
         validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
                         Rel.EQ, self.name)
-        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
+        validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
         return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape

     def infer_dtype(self, dout_type, x_type, gamma_type,
@@ -887,7 +886,7 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
         validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
         validator.check("min shape", min_shape, "max shape",
                         max_shape, Rel.EQ, self.name)
-        validator.check_integer("min rank", len(
+        validator.check_integer("min shape", len(
             min_shape), 1, Rel.EQ, self.name)
         return min_shape, max_shape


@@ -963,7 +962,7 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
         validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
         validator.check("min shape", min_shape, "max shape",
                         max_shape, Rel.EQ, self.name)
-        validator.check_integer("min rank", len(
+        validator.check_integer("min shape", len(
             min_shape), 1, Rel.EQ, self.name)
         return min_shape, max_shape
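
Taken together, the _quant_ops.py changes relabel and tighten the shape checks: per-layer ops require rank-1 min/max tensors of equal shape, and FakeQuantPerChannel now also verifies that min and max agree with each other and match the channel dimension of x. A plain-Python restatement of those invariants follows; it is not MindSpore validator code, just the same checks spelled out.

def infer_per_layer(x_shape, min_shape, max_shape):
    assert len(x_shape) >= 1, "x rank must be >= 1"
    assert min_shape == max_shape, "min and max must have the same shape"
    assert len(min_shape) == 1, "per-layer min/max must be rank-1"
    return x_shape

def infer_per_channel(x_shape, min_shape, max_shape, channel_axis=1):
    assert len(x_shape) >= 1, "x rank must be >= 1"
    assert min_shape == max_shape, "min and max must have the same shape"  # check added by this commit
    assert min_shape[0] == x_shape[channel_axis], "min length must equal the channel size"
    assert max_shape[0] == x_shape[channel_axis], "max length must equal the channel size"
    return x_shape

print(infer_per_layer((2, 12, 4, 4), (1,), (1,)))      # (2, 12, 4, 4)
print(infer_per_channel((2, 12, 4, 4), (12,), (12,)))  # (2, 12, 4, 4)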



