
bug fix in fake quant

tags/v0.5.0-beta
chenzomi, 5 years ago
parent commit 4da1e21f45
5 changed files with 58 additions and 146 deletions:
  1. mindspore/nn/layer/quant.py (+39, -128)
  2. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py (+3, -2)
  3. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py (+3, -2)
  4. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py (+1, -1)
  5. mindspore/ops/operations/_quant_ops.py (+12, -13)

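For orientation, the quant.py change below can be condensed as follows. This is a sketch of the logic visible in the diff, not a drop-in replacement for the class body; names such as per_channel, out_channels, min_init and max_init are the constructor arguments shown in the file.

    # Condensed sketch of the merged FakeQuantWithMinMax behaviour.
    is_ascend = context.get_context('device_target') == "Ascend"

    # Size the min/max buffers by per_channel, not by the Python type of min_init.
    if per_channel:
        min_array = np.array([min_init] * out_channels).astype(np.float32)
        max_array = np.array([max_init] * out_channels).astype(np.float32)
    else:
        min_array = np.array([min_init]).reshape(1).astype(np.float32)
        max_array = np.array([max_init]).reshape(1).astype(np.float32)

    # On Ascend the fake quant op is built without the ema/ema_decay/quant_delay
    # arguments, and the separate EMA min/max update op is created only when ema is True.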
mindspore/nn/layer/quant.py (+39, -128)

@@ -102,13 +102,13 @@ class BatchNormFoldCell(Cell):
return batch_mean, batch_std, running_mean, running_std


class FakeQuantWithMinMaxAscend(Cell):
class FakeQuantWithMinMax(Cell):
r"""
Quantization aware op. This OP provides a fake quantization observer function on data with min and max.

Args:
min_init (int, list): The dimension of channel or 1(layer). Default: -6.
max_init (int, list): The dimension of channel or 1(layer). Default: 6.
min_init (int, float): The dimension of channel or 1(layer). Default: -6.
max_init (int, float): The dimension of channel or 1(layer). Default: 6.
num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
ema (bool): Whether to update min and max with an Exponential Moving Average. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
@@ -143,32 +143,33 @@ class FakeQuantWithMinMaxAscend(Cell):
symmetric=False,
narrow_range=False,
training=True):
"""init FakeQuantWithMinMaxAscend layer"""
super(FakeQuantWithMinMaxAscend, self).__init__()
"""init FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__()
self.min_init = min_init
self.max_init = max_init
self.num_bits = num_bits
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.out_channels = out_channels
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.training = training
self.is_ascend = context.get_context('device_target') == "Ascend"

# init tensor min and max for fake quant op
if isinstance(min_init, int):
min_array = np.array([min_init]).reshape(1).astype(np.float32)
max_array = np.array([max_init]).reshape(1).astype(np.float32)
elif isinstance(min_init, list):
min_array = np.array([self.min_init for i in range(
0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(
0, self.out_channels)]).astype(np.float32)
if self.per_channel:
min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(0, self.out_channels)]).astype(np.float32)
else:
min_array = np.array([self.min_init]).reshape(1).astype(np.float32)
max_array = np.array([self.max_init]).reshape(1).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)

# init fake quant relative op
if per_channel:
quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
ema_fun = partial(P.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis)
@@ -176,28 +177,36 @@ class FakeQuantWithMinMaxAscend(Cell):
quant_fun = P.FakeQuantPerLayer
ema_fun = P.FakeQuantMinMaxPerLayerUpdate

self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
self.ema_update = ema_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
if self.is_ascend:
self.fake_quant = quant_fun(num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
else:
self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
if self.ema:
self.ema_update = ema_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)

def extend_repr(self):
s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay,
self.per_channel, self.quant_delay, self.channel_axis)
s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
'quant_delay={}, min_init={}, max_init={}'.format(
self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel,
self.channel_axis, self.out_channels, self.quant_delay, self.min_init, self.max_init)
return s

def construct(self, x):
if self.update:
if self.ema and self.is_ascend:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up)
@@ -207,104 +216,6 @@ class FakeQuantWithMinMaxAscend(Cell):
return out


class FakeQuantWithMinMaxGPU(Cell):
r"""
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.

Args:
min_init (int, list): The dimension of channel or 1(layer). Default: -6.
max_init (int, list): The dimension of channel or 1(layer). Default: 6.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization by layer or channel. Default: False.
out_channels (int): declarate the min and max channel size, Default: 1.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.

Inputs:
- **x** (Tensor) - The input of FakeQuantWithMinMax.

Outputs:
Tensor, with the same type and shape as the `x`.

Examples:
>>> fake_quant = FakeQuantWithMinMax()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)
"""

def __init__(self,
min_init=-6,
max_init=6,
num_bits=8,
ema=False,
ema_decay=0.999,
per_channel=False,
channel_axis=1,
out_channels=1,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True):
super(FakeQuantWithMinMaxGPU, self).__init__()
self.min_init = min_init
self.max_init = max_init
self.num_bits = num_bits
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.training = training

# init tensor min and max for fake quant op
if isinstance(min_init, int):
min_array = np.array([min_init]).reshape(1).astype(np.float32)
max_array = np.array([max_init]).reshape(1).astype(np.float32)
elif isinstance(min_init, list):
min_array = np.array([self.min_init for i in range(
0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(
0, self.out_channels)]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)

if per_channel:
quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
else:
quant_fun = P.FakeQuantPerLayer
self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)

def extend_repr(self):
s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay,
self.per_channel, self.quant_delay, self.channel_axis)
return s

def construct(self, x):
out = self.fake_quant(x, self.minq, self.maxq)
return out


def FakeQuantWithMinMax(**kwargs):
if context.get_context('device_target') == "Ascend":
out = FakeQuantWithMinMaxAscend(**kwargs)
if context.get_context('device_target') == "GPU":
out = FakeQuantWithMinMaxGPU(**kwargs)
else:
raise ValueError("Not support platform or channel mode.")
return out

class Conv2dBatchNormQuant(Cell):
r"""
2D convolution with BatchNormal op folded layer.


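A usage sketch consistent with the docstring example above; with the Ascend and GPU variants merged, the same cell is expected to run on either device target. The module path is taken from the file name in this diff and may differ from the public import path.

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.nn.layer.quant import FakeQuantWithMinMax

    # Per-layer fake quantization with the default [-6, 6] range and 8 bits.
    fake_quant = FakeQuantWithMinMax()
    input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
    result = fake_quant(input_x)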
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py (+3, -2)

@@ -92,6 +92,7 @@ def fake_quant_perchannel(x, min_val, max_val, y,
kernel_name="fake_quant_perchannel"):
"""FakeQuantPerChannel"""
x_shape = x.get("shape")
x_shape_ = x.get("ori_shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
@@ -101,8 +102,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,

util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)


mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py (+3, -2)

@@ -117,6 +117,7 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
kernel_name="fake_quant_perchannel_grad"):
"""FakeQuantPerChannelGrad"""
x_shape = x.get("shape")
x_shape_ = x.get("ori_shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
@@ -126,8 +127,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,

util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape)

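The two per-channel kernel fixes above read the channel size from ori_shape rather than shape. Assuming the usual TBE convention that x.get("shape") may be the device-format shape (for example NC1HWC0) while x.get("ori_shape") keeps the original layout, the min/max tensors have to be validated against the original channel dimension. A hypothetical standalone illustration of that constraint (not util.check_shape_rule itself):

    def validate_per_channel_minmax(x_ori_shape, min_shape, max_shape, channel_axis=1):
        # min/max are 1-D tensors with one entry per channel of the ORIGINAL layout.
        channels = x_ori_shape[channel_axis]
        if len(min_shape) != 1 or min_shape[0] != channels:
            raise ValueError("min tensor must be 1-D with length %d" % channels)
        if len(max_shape) != 1 or max_shape[0] != channels:
            raise ValueError("max tensor must be 1-D with length %d" % channels)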

mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py (+1, -1)

@@ -43,7 +43,7 @@ fake_quant_per_layer_op_info = TBERegOp("FakeQuantPerLayer") \
.get_op_info()


@op_info_register(fake_quant_op_info)
@op_info_register(fake_quant_per_layer_op_info)
def _fake_quant_per_layer_tbe():
"""FakeQuantPerLayer TBE register"""
return

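The per-layer registration fix above follows the pattern used by the custom TBE ops in this directory: the decorator must be handed the op-info object that is built immediately above it, here fake_quant_per_layer_op_info. A minimal sketch of that pattern with a placeholder op name (the op-info builder chain is elided):

    my_op_info = TBERegOp("MyOp") \
        .get_op_info()               # fusion type, inputs, outputs, dtypes elided

    @op_info_register(my_op_info)    # must reference the op-info defined above
    def _my_op_tbe():
        """MyOp TBE register"""
        return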

mindspore/ops/operations/_quant_ops.py (+12, -13)

@@ -110,10 +110,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):

def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
min_shape), 1, Rel.EQ, self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
return x_shape

def infer_dtype(self, x_type, min_type, max_type):
@@ -168,7 +166,7 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
x_shape, Rel.EQ, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return dout_shape

@@ -255,10 +253,11 @@ class FakeQuantPerChannel(PrimitiveWithInfer):

def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer(
"min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
"min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer(
"max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
"max shape", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
return x_shape

def infer_dtype(self, x_type, min_type, max_type):
@@ -379,7 +378,7 @@ class BatchNormFold(PrimitiveWithInfer):
def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
return mean_shape, mean_shape, mean_shape, mean_shape

def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
@@ -426,7 +425,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
"batch_std shape", batch_std_shape, Rel.EQ, self.name)
validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
"input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
return x_shape

def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
@@ -568,7 +567,7 @@ class BatchNormFold2(PrimitiveWithInfer):
validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
Rel.EQ, self.name)
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
return x_shape

def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
@@ -616,7 +615,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
Rel.EQ, self.name)
validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape

def infer_dtype(self, dout_type, x_type, gamma_type,
@@ -887,7 +886,7 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape

@@ -963,7 +962,7 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min rank", len(
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape

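To make the renamed validator labels concrete, a hypothetical shape example consistent with the checks above (values are illustrative only):

    # FakeQuantPerLayer: min and max must be rank-1 and have equal shapes.
    x_shape, min_shape, max_shape = [8, 16, 14, 14], [1], [1]

    # FakeQuantPerChannel with channel_axis=1: the min/max length must equal the
    # channel dimension of x, here x_shape[1] == 16.
    x_shape, min_shape, max_shape = [8, 16, 14, 14], [16], [16]

    # In both cases infer_shape returns x_shape unchanged; only the constraints
    # above are validated.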

