Browse Source

fix some quant op bug

tags/v0.5.0-beta
wangdongxu6 王东旭 5 years ago
parent
commit
9eee157c58
4 changed files with 69 additions and 216 deletions
  1. +64
    -215
      mindspore/nn/layer/quant.py
  2. +1
    -1
      mindspore/ops/_grad/grad_quant_ops.py
  3. +1
    -0
      mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py
  4. +3
    -0
      mindspore/ops/operations/_quant_ops.py

+ 64
- 215
mindspore/nn/layer/quant.py View File

@@ -27,10 +27,8 @@ from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
import mindspore.context as context


__all__ = [
'FakeQuantWithMinMax',
'DepthwiseConv2dBatchNormQuant',
'Conv2dBatchNormQuant',
'Conv2dQuant',
'DenseQuant',
@@ -113,7 +111,7 @@ class FakeQuantWithMinMaxD(Cell):
max_init (int, list): The dimension of channel or 1(layer). Default: 6.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.9999.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization by layer or channel. Default: False.
out_channels (int): declarate the min and max channel size, Default: 1.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
@@ -131,6 +129,7 @@ class FakeQuantWithMinMaxD(Cell):
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)
"""

def __init__(self,
min_init=-6,
max_init=6,
@@ -215,7 +214,7 @@ class FakeQuantWithMinMax(Cell):
max_init (int, list): The dimension of channel or 1(layer). Default: 6.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.9999.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization by layer or channel. Default: False.
out_channels (int): declarate the min and max channel size, Default: 1.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
@@ -335,191 +334,6 @@ class FakeQuantWithMinMax(Cell):
return out


class DepthwiseConv2dBatchNormQuant(Cell):
r"""
2D depthwise convolution with BatchNormal op folded layer.

For a more Detailed overview of Conv2d op.

Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
stride (int): Specifies stride for all spatial dimensions with the same value.
pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding: (int): Implicit paddings on both sides of the input. Default: 0.
eps (int): Parameters for BatchNormal. Default: 1e-5.
momentum (int): Parameters for BatchNormal op. Default: 0.9.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
convolution kernel. Default: 'None'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
beta vector. Default: 'None'.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
gamma vector. Default: 'None'.
mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
mean vector. Default: 'None'.
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
variance vector. Default: 'None'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.

Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

Examples:
>>> quant = nn.DepthwiseConv2dBatchNormQuant(1, 6,
kernel_size= (2, 2),
stride=(1, 1),
pad_mode="valid",
>>> dilation=(1, 1))
>>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mindspore.float32)
>>> result = quant(input_x)
"""

def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
eps=1e-5,
momentum=0.997,
weight_init=None,
beta_init=None,
gamma_init=None,
mean_init=None,
var_init=None,
quant_delay=0,
freeze_bn=100000,
fake=True,
num_bits=8,
per_channel=False,
symmetric=False,
narrow_range=False):
"""init DepthwiseConv2dBatchNormQuant layer"""
super(DepthwiseConv2dBatchNormQuant, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.pad_mode = pad_mode
self.padding = padding
self.dilation = twice(dilation)
self.stride = twice(stride)
self.group = group
self.fake = fake
self.freeze_bn = freeze_bn
self.momentum = momentum
self.quant_delay = quant_delay
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
else:
self.kernel_size = kernel_size
if group > 1:
validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant')
validator.check_integer('group', group, 'in_channels', out_channels, 'Conv2dBatchNormQuant')
self.is_depthwise = group > 1

channel_multiplier = out_channels // in_channels
self.conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
kernel_size=kernel_size,
stride=stride,
pad_mode=pad_mode,
pad=padding)

if weight_init is None:
weight_init = initializer('normal', [channel_multiplier, in_channels, *kernel_size])
self.weight = Parameter(weight_init, name='weight')
if gamma_init is None:
gamma_init = initializer('ones', [out_channels])
self.gamma = Parameter(gamma_init, name='gamma')
if beta_init is None:
beta_init = initializer('zeros', [out_channels])
self.beta = Parameter(beta_init, name='beta')
if mean_init is None:
mean_init = initializer('zeros', [out_channels])
self.moving_mean = Parameter(
mean_init, name='moving_mean', requires_grad=False)
if var_init is None:
var_init = initializer('ones', [out_channels])
self.moving_variance = Parameter(
var_init, name='moving_variance', requires_grad=False)

self.step = Parameter(initializer(
'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)

self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
symmetric=symmetric,
narrow_range=narrow_range)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)

self.correct_mul = P.CorrectionMul(self.is_depthwise)
if context.get_context('device_target') == "Ascend":
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
elif context.get_context('device_target') == "GPU":
self.batchnorm_fold2_train = P.BatchNormFold2(freeze_bn=freeze_bn)
self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0)
else:
raise ValueError("Not support platform.")
self.one = Tensor(1, mstype.int32)
self.assignadd = P.AssignAdd()
self.is_gpu = context.get_context('device_target') == "GPU"

def extend_repr(self):
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.fake, self.freeze_bn, self.momentum, self.quant_delay)
return s

def construct(self, x):
out_conv = self.conv(x, self.weight)
# BN fold1
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv,
self.moving_mean,
self.moving_variance,
self.step)
# fake weight
weight = self.correct_mul(self.weight, self.gamma, running_std)
if self.fake:
weight = self.fake_quant_weight(weight)
out = self.conv(x, weight)
# BN fold2
if self.is_gpu:
if self.training:
out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
batch_std, batch_mean, running_std, running_mean, self.step)
F.control_depend(out, self.assignadd(self.step, self.one))
else:
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
batch_std, batch_mean, running_std, running_mean, self.step)
else:
if self.training:
out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
F.control_depend(out, self.assignadd(self.step, self.one))
else:
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
return out


class Conv2dBatchNormQuant(Cell):
r"""
2D convolution with BatchNormal op folded layer.
@@ -593,23 +407,47 @@ class Conv2dBatchNormQuant(Cell):
super(Conv2dBatchNormQuant, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = twice(kernel_size)
self.stride = twice(stride)
self.pad_mode = pad_mode
self.padding = padding
self.dilation = twice(dilation)
self.stride = twice(stride)
self.group = group
self.fake = fake
self.freeze_bn = freeze_bn
self.eps = eps
self.momentum = momentum
self.quant_delay = quant_delay
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
self.freeze_bn = freeze_bn
self.fake = fake
self.num_bits = num_bits
self.per_channel = per_channel
self.symmetric = symmetric
self.narrow_range = narrow_range

# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant')
validator.check_integer('group', group, 'in_channels', out_channels, 'Conv2dBatchNormQuant')
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
pad=padding,
stride=self.stride,
dilation=self.dilation)
if weight_init is None:
weight_init = initializer('normal', [1, in_channels, *self.kernel_size])
else:
self.kernel_size = kernel_size
if weight_init is None:
weight_init = initializer(
'normal', [out_channels, in_channels // group, *self.kernel_size])
self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
pad=padding,
stride=self.stride,
dilation=self.dilation,
group=group)
if weight_init is None:
weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size])
self.weight = Parameter(weight_init, name='weight')

# initialize batchnorm Parameter
if gamma_init is None:
gamma_init = initializer('ones', [out_channels])
self.gamma = Parameter(gamma_init, name='gamma')
@@ -618,16 +456,12 @@ class Conv2dBatchNormQuant(Cell):
self.beta = Parameter(beta_init, name='beta')
if mean_init is None:
mean_init = initializer('zeros', [out_channels])
self.moving_mean = Parameter(
mean_init, name='moving_mean', requires_grad=False)
self.moving_mean = Parameter(mean_init, name='moving_mean', requires_grad=False)
if var_init is None:
var_init = initializer('ones', [out_channels])
self.moving_variance = Parameter(
var_init, name='moving_variance', requires_grad=False)

self.step = Parameter(initializer(
'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
self.moving_variance = Parameter(var_init, name='moving_variance', requires_grad=False)

# initialize fake ops
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
@@ -638,14 +472,6 @@ class Conv2dBatchNormQuant(Cell):
symmetric=symmetric,
narrow_range=narrow_range)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=kernel_size,
mode=1,
pad_mode=pad_mode,
pad=padding,
stride=stride,
dilation=1,
group=group)
self.correct_mul = P.CorrectionMul()
if context.get_context('device_target') == "Ascend":
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
@@ -654,7 +480,8 @@ class Conv2dBatchNormQuant(Cell):
self.batchnorm_fold2_train = P.BatchNormFold2(freeze_bn=freeze_bn)
self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0)
else:
raise ValueError("Not support platform.")
raise ValueError("Unsupported platform: {}".format(context.get_context('device_target')))
self.step = Parameter(initializer('normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
self.one = Tensor(1, mstype.int32)
self.assignadd = P.AssignAdd()

@@ -926,6 +753,7 @@ class ReLUQuant(Cell):
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.

@@ -944,6 +772,7 @@ class ReLUQuant(Cell):
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
symmetric=False,
narrow_range=False):
super(ReLUQuant, self).__init__()
@@ -952,6 +781,7 @@ class ReLUQuant(Cell):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
self.relu = P.ReLU()
@@ -973,6 +803,7 @@ class ReLU6Quant(Cell):
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.

@@ -988,7 +819,11 @@ class ReLU6Quant(Cell):
>>> result = relu6_quant(input_x)
"""

def __init__(self, num_bits=8, quant_delay=0, symmetric=False,
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
symmetric=False,
narrow_range=False):
super(ReLU6Quant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
@@ -996,6 +831,7 @@ class ReLU6Quant(Cell):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
self.relu6 = P.ReLU6()
@@ -1015,6 +851,7 @@ class HSwishQuant(Cell):
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.

@@ -1033,6 +870,7 @@ class HSwishQuant(Cell):
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
symmetric=False,
narrow_range=False):
super(HSwishQuant, self).__init__()
@@ -1041,6 +879,7 @@ class HSwishQuant(Cell):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
@@ -1048,6 +887,7 @@ class HSwishQuant(Cell):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
self.act = P.HSwish()
@@ -1068,6 +908,7 @@ class HSigmoidQuant(Cell):
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.

@@ -1086,6 +927,7 @@ class HSigmoidQuant(Cell):
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
symmetric=False,
narrow_range=False):
super(HSigmoidQuant, self).__init__()
@@ -1101,6 +943,7 @@ class HSigmoidQuant(Cell):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
self.act = P.HSigmoid()
@@ -1121,6 +964,7 @@ class TensorAddQuant(Cell):
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.

@@ -1140,6 +984,7 @@ class TensorAddQuant(Cell):
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
symmetric=False,
narrow_range=False):
super(TensorAddQuant, self).__init__()
@@ -1148,6 +993,7 @@ class TensorAddQuant(Cell):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
self.add = P.TensorAdd()
@@ -1167,6 +1013,7 @@ class MulQuant(Cell):
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.

@@ -1181,6 +1028,7 @@ class MulQuant(Cell):
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
symmetric=False,
narrow_range=False):
super(MulQuant, self).__init__()
@@ -1189,6 +1037,7 @@ class MulQuant(Cell):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
self.mul = P.Mul()


+ 1
- 1
mindspore/ops/_grad/grad_quant_ops.py View File

@@ -85,7 +85,7 @@ def get_bprop_batchnorm_fold2(self):
@bprop_getters.register(P.BatchNormFoldD)
def get_bprop_BatchNormFold(self):
"""Generate bprop for BatchNormFold for Ascend"""
op = P.BatchNormFoldGrad_(self.epsilon, self.is_training, self.freeze_bn)
op = P.BatchNormFoldGradD(self.epsilon, self.is_training, self.freeze_bn)

def bprop(x, x_sum, x_square_sum, mean, variance, out, dout):
dx = op(dout[1], dout[2], x, out[1], out[2])


+ 1
- 0
mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py View File

@@ -16,6 +16,7 @@
"""_BatchNormFold op"""

from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
import te
from te import tvm
from topi import generic
from topi.cce import util


+ 3
- 0
mindspore/ops/operations/_quant_ops.py View File

@@ -31,8 +31,11 @@ __all__ = ["FakeQuantWithMinMax",
"BatchNormFold2",
"BatchNormFold2Grad",
"BatchNormFoldD",
"BatchNormFoldGradD",
"BNTrainingReduce",
"BatchNormFold2_D",
"BatchNormFold2GradD",
"BatchNormFold2GradReduce",
"FakeQuantWithMinMaxUpdate",
]



Loading…
Cancel
Save