
!1220 quantization aware training frontend operators bug fix.

Merge pull request !1220 from SanjayChan/per
tags/v0.3.0-alpha
mindspore-ci-bot committed 5 years ago
parent commit f73f9fb2cf

2 changed files with 215 additions and 132 deletions:

  1. mindspore/nn/cell.py (+2, -2)
  2. mindspore/nn/layer/quant.py (+213, -130)

mindspore/nn/cell.py (+2, -2)

@@ -97,9 +97,9 @@ class Cell:

         After invoked, can get all the cell's children's name prefix by '_param_prefix'.
         """
-        cells = self.cells_and_names()
+        cells_name = self.cells_and_names()

-        for cell_name, cell in cells:
+        for cell_name, cell in cells_name:
             cell._param_prefix = cell_name

     @cell_init_args.setter


mindspore/nn/layer/quant.py (+213, -130)

@@ -15,7 +15,6 @@
 """Aware quantization."""

 import numpy as np
-import mindspore.nn as nn
 import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F

@@ -24,7 +23,6 @@ from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore._checkparam import check_int_positive, check_bool, twice
 from mindspore.nn.cell import Cell
-from mindspore.nn.layer.conv import _Conv
 from mindspore.nn.layer.activation import get_activation

 __all__ = [
@@ -37,6 +35,7 @@ __all__ = [
     'HSwishQuant',
     'HSigmoidQuant',
     'TensorAddQuant',
+    'MulQuant',
 ]

@@ -51,7 +50,7 @@ class FakeQuantWithMinMax(Cell):
         ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.9999.
         per_channel (bool): Quantization by layer or channel. Default: False.
-        channel_size (int): declarate the min and max channel size, Default: 1.
+        out_channels (int): declare the min and max channel size. Default: 1.
         quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@@ -71,7 +70,7 @@ class FakeQuantWithMinMax(Cell):
                  ema=False,
                  ema_decay=0.999,
                  per_channel=False,
-                 channel_size=1,
+                 out_channels=1,
                  quant_delay=0,
                  symmetric=False,
                  narrow_range=False):
@@ -83,16 +82,16 @@ class FakeQuantWithMinMax(Cell):
         self.ema = ema
         self.ema_decay = ema_decay
         self.per_channel = per_channel
-        self.channel_size = channel_size
+        self.out_channels = out_channels
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range

         if per_channel:
             min_array = np.array([self.min_init for i in range(
-                0, self.channel_size)]).astype(np.float32)
+                0, self.out_channels)]).astype(np.float32)
             max_array = np.array([self.max_init for i in range(
-                0, self.channel_size)]).astype(np.float32)
+                0, self.out_channels)]).astype(np.float32)
             self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
                                                                     ema=self.ema,
                                                                     ema_decay=self.ema_decay,
@@ -102,8 +101,8 @@ class FakeQuantWithMinMax(Cell):
                                                                     training=True)
             self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
                                                                     ema=self.ema,
-                                                                    ema_decay=ema_decay,
-                                                                    quant_delay=quant_delay,
+                                                                    ema_decay=self.ema_decay,
+                                                                    quant_delay=self.quant_delay,
                                                                     symmetric=self.symmetric,
                                                                     narrow_range=self.narrow_range,
                                                                     training=False)
@@ -119,28 +118,27 @@ class FakeQuantWithMinMax(Cell):
                                                           training=True)
             self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
                                                           ema=self.ema,
-                                                          ema_decay=ema_decay,
-                                                          quant_delay=quant_delay,
+                                                          ema_decay=self.ema_decay,
+                                                          quant_delay=self.quant_delay,
                                                           symmetric=self.symmetric,
                                                           narrow_range=self.narrow_range,
                                                           training=False)

-        self.min = Parameter(
+        self.minq = Parameter(
             Tensor(min_array), name='quant_min', requires_grad=False)
-        self.max = Parameter(
+        self.maxq = Parameter(
             Tensor(max_array), name='quant_max', requires_grad=False)

     def extend_repr(self):
-        s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format(
-            self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size,
-            self.quant_delay)
+        s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format(
+            self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.quant_delay)
         return s

     def construct(self, x):
         if self.training:
-            out = self.fake_quant_train(x, self.min, self.max)
+            out = self.fake_quant_train(x, self.minq, self.maxq)
         else:
-            out = self.fake_quant_infer(x, self.min, self.max)
+            out = self.fake_quant_infer(x, self.minq, self.maxq)
         return out
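
Note: a minimal usage sketch of FakeQuantWithMinMax after the rename of channel_size to out_channels. The tensor shape and values are hypothetical, and the import path simply mirrors this file's location:

    import numpy as np
    from mindspore import Tensor
    from mindspore.nn.layer.quant import FakeQuantWithMinMax

    # One (min, max) pair per output channel; out_channels must match the
    # leading dimension of the tensor being fake-quantized.
    fake_quant = FakeQuantWithMinMax(min_init=-6, max_init=6,
                                     per_channel=True, out_channels=16)
    weight = Tensor(np.random.randn(16, 3, 3, 3).astype(np.float32))
    out = fake_quant(weight)  # same shape, fake-quantized to 8 bits by default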




@@ -188,13 +186,13 @@ class Conv2dBatchNormQuant(Cell):
                  in_channels,
                  out_channels,
                  kernel_size,
-                 stride,
-                 pad_mode,
+                 stride=1,
+                 pad_mode='same',
                  padding=0,
                  dilation=1,
                  group=1,
                  eps=1e-5,
-                 momentum=0.9,
+                 momentum=0.997,
                  weight_init=None,
                  beta_init=None,
                  gamma_init=None,
@@ -208,24 +206,25 @@ class Conv2dBatchNormQuant(Cell):
                  symmetric=False,
                  narrow_range=False):
         super(Conv2dBatchNormQuant, self).__init__()
-        _ = dilation
-        self.stride = stride
-        self.conv = P.Conv2D(out_channel=out_channels,
-                             kernel_size=kernel_size,
-                             mode=1,
-                             pad_mode=pad_mode,
-                             pad=padding,
-                             stride=stride,
-                             dilation=1,
-                             group=group)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.pad_mode = pad_mode
+        self.padding = padding
+        self.dilation = twice(dilation)
+        self.stride = twice(stride)
+        self.group = group
         self.fake = fake
         self.freeze_bn = freeze_bn
+        self.momentum = momentum
+        self.quant_delay = quant_delay
         if isinstance(kernel_size, int):
-            kernel_size = (kernel_size, kernel_size)
+            self.kernel_size = (kernel_size, kernel_size)
+        else:
+            self.kernel_size = kernel_size

         if weight_init is None:
             weight_init = initializer(
-                'normal', [out_channels, in_channels // group, *kernel_size])
+                'normal', [out_channels, in_channels // group, *self.kernel_size])
         self.weight = Parameter(weight_init, name='weight')
         if gamma_init is None:
             gamma_init = initializer('ones', [out_channels])
@@ -245,16 +244,23 @@ class Conv2dBatchNormQuant(Cell):
         self.step = Parameter(initializer(
             'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)

-        self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
-                                                        max_init=6,
-                                                        ema=False,
-                                                        num_bits=num_bits,
-                                                        quant_delay=quant_delay,
-                                                        per_channel=per_channel,
-                                                        channel_size=out_channels,
-                                                        symmetric=symmetric,
-                                                        narrow_range=narrow_range)
-
+        self.conv = P.Conv2D(out_channel=self.out_channels,
+                             kernel_size=self.kernel_size,
+                             mode=1,
+                             pad_mode=self.pad_mode,
+                             pad=self.padding,
+                             stride=self.stride,
+                             dilation=self.dilation,
+                             group=self.group)
+        self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
+                                                     max_init=6,
+                                                     ema=False,
+                                                     num_bits=num_bits,
+                                                     quant_delay=quant_delay,
+                                                     per_channel=per_channel,
+                                                     out_channels=out_channels,
+                                                     symmetric=symmetric,
+                                                     narrow_range=narrow_range)
         self.batchnorm_fold_train = P.BatchNormFold(epsilon=eps,
                                                     momentum=momentum,
                                                     is_training=True,
@@ -271,7 +277,12 @@ class Conv2dBatchNormQuant(Cell):
         self.assignadd = P.AssignAdd()

     def extend_repr(self):
-        s = 'fake={}, freeze_bn={}'.format(self.fake, self.freeze_bn)
+        s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
+            'pad_mode={}, padding={}, dilation={}, group={}, ' \
+            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
+                self.in_channels, self.out_channels, self.kernel_size, self.stride,
+                self.pad_mode, self.padding, self.dilation, self.group,
+                self.fake, self.freeze_bn, self.momentum, self.quant_delay)
         return s

     def construct(self, x):
@@ -295,9 +306,8 @@ class Conv2dBatchNormQuant(Cell):
             F.control_depend(out, self.assignadd(self.step, self.one))
         else:
             step = self.step
-            out_conv = self.conv(x, self.weight)
             batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer(
-                out_conv, self.moving_mean, self.moving_variance, step)
+                x, self.moving_mean, self.moving_variance, step)
         weight = self.correct_mul(self.weight, self.gamma, running_std)
         if self.fake:
             weight = self.fake_quant_weight(weight)
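
Note: the inference branch above now feeds the raw input x to batchnorm_fold_infer instead of first running the convolution, and stride/pad_mode gained defaults. A hypothetical construction under the new signature (import path mirrors this file):

    import numpy as np
    from mindspore import Tensor
    from mindspore.nn.layer.quant import Conv2dBatchNormQuant

    # stride=1 and pad_mode='same' no longer have to be passed explicitly.
    conv_bn = Conv2dBatchNormQuant(in_channels=3, out_channels=16, kernel_size=3,
                                   num_bits=8, per_channel=True)
    x = Tensor(np.random.randn(1, 3, 32, 32).astype(np.float32))
    out = conv_bn(x)  # folded-batchnorm conv with fake-quantized weights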
@@ -307,7 +317,7 @@ class Conv2dBatchNormQuant(Cell):
         return out


-class Conv2dQuant(_Conv):
+class Conv2dQuant(Cell):
     r"""
     2D convolution with fake quant op layer.

@@ -325,8 +335,8 @@ class Conv2dQuant(_Conv):
             divisible by the number of groups. Default: 1.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
         weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
-            Default: 'normal'.
-        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
+            Default: None.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: None.
         quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
         per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
@@ -351,40 +361,72 @@ class Conv2dQuant(_Conv):
                  dilation=1,
                  group=1,
                  has_bias=False,
-                 weight_init='normal',
-                 bias_init='zeros',
+                 weight_init=None,
+                 bias_init=None,
                  quant_delay=0,
                  num_bits=8,
                  per_channel=False,
                  symmetric=False,
                  narrow_range=False):
-        kernel_size = twice(kernel_size)
-        super(Conv2dQuant, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation,
-                                          group, has_bias, weight_init, bias_init)
-        self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size, mode=1,
-                               pad_mode=self.pad_mode, pad=self.padding, stride=self.stride, dilation=self.dilation,
-                               group=self.group)
-        self.bias_add = P.BiasAdd()
-        if pad_mode not in ('valid', 'same', 'pad'):
-            raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
-                             + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
-        self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
-                                                        max_init=6,
-                                                        ema=False,
-                                                        num_bits=num_bits,
-                                                        quant_delay=quant_delay,
-                                                        per_channel=per_channel,
-                                                        channel_size=out_channels,
-                                                        symmetric=symmetric,
-                                                        narrow_range=narrow_range)
+        super(Conv2dQuant, self).__init__()
+        if isinstance(kernel_size, int):
+            self.kernel_size = (kernel_size, kernel_size)
+        else:
+            self.kernel_size = kernel_size
+        self.in_channels = check_int_positive(in_channels)
+        self.out_channels = check_int_positive(out_channels)
+        self.has_bias = has_bias
+        self.stride = twice(stride)
+        self.dilation = twice(dilation)
+        self.pad_mode = pad_mode
+        self.padding = padding
+        self.group = group
+        self.quant_delay = quant_delay
+
+        if weight_init is None:
+            weight_init = initializer(
+                'normal', [out_channels, in_channels // group, *self.kernel_size])
+        self.weight = Parameter(weight_init, name='weight')
+        if bias_init is None:
+            bias_init = initializer('zeros', [out_channels])
+        if has_bias:
+            self.bias = Parameter(bias_init, name='bias')
+            self.bias_add = P.BiasAdd()
+
+        self.conv = P.Conv2D(out_channel=self.out_channels,
+                             kernel_size=self.kernel_size,
+                             mode=1,
+                             pad_mode=self.pad_mode,
+                             pad=self.padding,
+                             stride=self.stride,
+                             dilation=self.dilation,
+                             group=self.group)
+        self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
+                                                     max_init=6,
+                                                     ema=False,
+                                                     num_bits=num_bits,
+                                                     quant_delay=quant_delay,
+                                                     per_channel=per_channel,
+                                                     out_channels=out_channels,
+                                                     symmetric=symmetric,
+                                                     narrow_range=narrow_range)

     def construct(self, x):
-        weight_q = self.fake_quant_weight(self.weight)
-        out = self.conv2d(x, weight_q)
+        weight = self.fake_quant_weight(self.weight)
+        out = self.conv(x, weight)
         if self.has_bias:
             return self.bias_add(out, self.bias)
         return out

+    def extend_repr(self):
+        s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
+            'pad_mode={}, padding={}, dilation={}, group={}, ' \
+            'has_bias={}, quant_delay={}'.format(
+                self.in_channels, self.out_channels, self.kernel_size, self.stride,
+                self.pad_mode, self.padding, self.dilation, self.group,
+                self.has_bias, self.quant_delay)
+        return s


 class DenseQuant(Cell):
     r"""
@@ -453,15 +495,15 @@ class DenseQuant(Cell):

         self.activation = get_activation(activation)
         self.activation_flag = self.activation is not None
-        self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
-                                                        max_init=6,
-                                                        ema=False,
-                                                        num_bits=num_bits,
-                                                        quant_delay=quant_delay,
-                                                        per_channel=per_channel,
-                                                        channel_size=out_channels,
-                                                        symmetric=symmetric,
-                                                        narrow_range=narrow_range)
+        self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
+                                                     max_init=6,
+                                                     ema=False,
+                                                     num_bits=num_bits,
+                                                     quant_delay=quant_delay,
+                                                     per_channel=per_channel,
+                                                     out_channels=out_channels,
+                                                     symmetric=symmetric,
+                                                     narrow_range=narrow_range)

     def construct(self, x):
         """Use operators to construct to Dense layer."""
@@ -511,13 +553,13 @@ class ReLUQuant(Cell):
                  symmetric=False,
                  narrow_range=False):
         super(ReLUQuant, self).__init__()
-        self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0,
-                                                     max_init=6,
-                                                     num_bits=num_bits,
-                                                     quant_delay=quant_delay,
-                                                     ema=True,
-                                                     symmetric=symmetric,
-                                                     narrow_range=narrow_range)
+        self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
+                                                  max_init=6,
+                                                  num_bits=num_bits,
+                                                  quant_delay=quant_delay,
+                                                  ema=True,
+                                                  symmetric=symmetric,
+                                                  narrow_range=narrow_range)
         self.relu = P.ReLU()

     def construct(self, x):
@@ -551,13 +593,13 @@ class ReLU6Quant(Cell):
     def __init__(self, num_bits=8, quant_delay=0, symmetric=False,
                  narrow_range=False):
         super(ReLU6Quant, self).__init__()
-        self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0,
-                                                     max_init=6,
-                                                     num_bits=num_bits,
-                                                     quant_delay=quant_delay,
-                                                     ema=True,
-                                                     symmetric=symmetric,
-                                                     narrow_range=narrow_range)
+        self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
+                                                  max_init=6,
+                                                  num_bits=num_bits,
+                                                  quant_delay=quant_delay,
+                                                  ema=True,
+                                                  symmetric=symmetric,
+                                                  narrow_range=narrow_range)
         self.relu6 = P.ReLU6()

     def construct(self, x):
@@ -592,20 +634,20 @@ class HSwishQuant(Cell):
                  symmetric=False,
                  narrow_range=False):
         super(HSwishQuant, self).__init__()
-        self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0,
-                                                            max_init=6,
-                                                            num_bits=num_bits,
-                                                            quant_delay=quant_delay,
-                                                            ema=True,
-                                                            symmetric=symmetric,
-                                                            narrow_range=narrow_range)
-        self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0,
-                                                           max_init=6,
-                                                           num_bits=num_bits,
-                                                           quant_delay=quant_delay,
-                                                           ema=True,
-                                                           symmetric=symmetric,
-                                                           narrow_range=narrow_range)
+        self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
+                                                         max_init=6,
+                                                         num_bits=num_bits,
+                                                         quant_delay=quant_delay,
+                                                         ema=True,
+                                                         symmetric=symmetric,
+                                                         narrow_range=narrow_range)
+        self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
+                                                        max_init=6,
+                                                        num_bits=num_bits,
+                                                        quant_delay=quant_delay,
+                                                        ema=True,
+                                                        symmetric=symmetric,
+                                                        narrow_range=narrow_range)
         self.act = P.HSwish()

     def construct(self, x):
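
Note: min_init for both activation fake-quant cells changes from 0 to -6 here (and likewise in the HSigmoidQuant hunk below), presumably because HSwish(x) = x * ReLU6(x + 3) / 6 is negative on (-3, 0), with a minimum of -0.375 at x = -1.5, so a quantization window starting at 0 would clip real activations. A plain-NumPy check, for illustration only:

    import numpy as np

    def hswish(x):
        # HSwish as defined above: x * ReLU6(x + 3) / 6
        return x * np.clip(x + 3, 0, 6) / 6

    x = np.linspace(-6, 6, 25)   # grid includes x = -1.5
    print(hswish(x).min())       # -0.375: negative outputs exist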
@@ -641,20 +683,20 @@ class HSigmoidQuant(Cell):
                  symmetric=False,
                  narrow_range=False):
         super(HSigmoidQuant, self).__init__()
-        self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0,
-                                                            max_init=6,
-                                                            num_bits=num_bits,
-                                                            quant_delay=quant_delay,
-                                                            ema=True,
-                                                            symmetric=symmetric,
-                                                            narrow_range=narrow_range)
-        self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0,
-                                                           max_init=6,
-                                                           num_bits=num_bits,
-                                                           quant_delay=quant_delay,
-                                                           ema=True,
-                                                           symmetric=symmetric,
-                                                           narrow_range=narrow_range)
+        self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
+                                                         max_init=6,
+                                                         num_bits=num_bits,
+                                                         quant_delay=quant_delay,
+                                                         ema=True,
+                                                         symmetric=symmetric,
+                                                         narrow_range=narrow_range)
+        self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
+                                                        max_init=6,
+                                                        num_bits=num_bits,
+                                                        quant_delay=quant_delay,
+                                                        ema=True,
+                                                        symmetric=symmetric,
+                                                        narrow_range=narrow_range)
         self.act = P.HSigmoid()

     def construct(self, x):
@@ -690,16 +732,57 @@ class TensorAddQuant(Cell):
                  symmetric=False,
                  narrow_range=False):
         super(TensorAddQuant, self).__init__()
-        self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=-6,
-                                                     max_init=6,
-                                                     num_bits=num_bits,
-                                                     quant_delay=quant_delay,
-                                                     ema=True,
-                                                     symmetric=symmetric,
-                                                     narrow_range=narrow_range)
+        self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
+                                                  max_init=6,
+                                                  num_bits=num_bits,
+                                                  quant_delay=quant_delay,
+                                                  ema=True,
+                                                  symmetric=symmetric,
+                                                  narrow_range=narrow_range)
         self.add = P.TensorAdd()

     def construct(self, x1, x2):
         x = self.add(x1, x2)
         x = self.fake_quant_act(x)
         return x
+
+
+class MulQuant(Cell):
+    r"""
+    Add Fake Quant OP after Mul OP.
+
+    For a more detailed overview of the Mul op.
+
+    Args:
+        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
+        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
+        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
+        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+
+    Inputs:
+        - **x1** (Tensor) - The first input of MulQuant.
+        - **x2** (Tensor) - The second input of MulQuant.
+
+    Outputs:
+        Tensor, with the same type and shape as the `x1`.
+
+    """
+
+    def __init__(self,
+                 num_bits=8,
+                 quant_delay=0,
+                 symmetric=False,
+                 narrow_range=False):
+        super(MulQuant, self).__init__()
+        self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
+                                                  max_init=6,
+                                                  num_bits=num_bits,
+                                                  quant_delay=quant_delay,
+                                                  ema=True,
+                                                  symmetric=symmetric,
+                                                  narrow_range=narrow_range)
+        self.mul = P.Mul()
+
+    def construct(self, x1, x2):
+        x = self.mul(x1, x2)
+        x = self.fake_quant_act(x)
+        return x
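
Note: a quick sketch of the new MulQuant cell; inputs are hypothetical and the import path mirrors this file:

    import numpy as np
    from mindspore import Tensor
    from mindspore.nn.layer.quant import MulQuant

    mul_q = MulQuant(num_bits=8)
    a = Tensor(np.random.randn(2, 3).astype(np.float32))
    b = Tensor(np.random.randn(2, 3).astype(np.float32))
    out = mul_q(a, b)  # elementwise product followed by fake quantization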
