
!1220 quantization aware training frontend operators bug fix.

Merge pull request !1220 from SanjayChan/per
tags/v0.3.0-alpha
mindspore-ci-bot committed 5 years ago
parent commit f73f9fb2cf
2 changed files with 215 additions and 132 deletions
  1. mindspore/nn/cell.py: +2 -2
  2. mindspore/nn/layer/quant.py: +213 -130

mindspore/nn/cell.py  (+2, -2)

@@ -97,9 +97,9 @@ class Cell:

After invoked, can get all the cell's children's name prefix by '_param_prefix'.
"""
cells = self.cells_and_names()
cells_name = self.cells_and_names()

for cell_name, cell in cells:
for cell_name, cell in cells_name:
cell._param_prefix = cell_name

@cell_init_args.setter


mindspore/nn/layer/quant.py  (+213, -130)

@@ -15,7 +15,6 @@
"""Aware quantization."""

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
@@ -24,7 +23,6 @@ from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import check_int_positive, check_bool, twice
from mindspore.nn.cell import Cell
from mindspore.nn.layer.conv import _Conv
from mindspore.nn.layer.activation import get_activation

__all__ = [
@@ -37,6 +35,7 @@ __all__ = [
'HSwishQuant',
'HSigmoidQuant',
'TensorAddQuant',
'MulQuant',
]


@@ -51,7 +50,7 @@ class FakeQuantWithMinMax(Cell):
ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.9999.
per_channel (bool): Quantization by layer or channel. Default: False.
channel_size (int): declarate the min and max channel size, Default: 1.
out_channels (int): Declares the channel size of the min and max arrays. Default: 1.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@@ -71,7 +70,7 @@ class FakeQuantWithMinMax(Cell):
ema=False,
ema_decay=0.999,
per_channel=False,
channel_size=1,
out_channels=1,
quant_delay=0,
symmetric=False,
narrow_range=False):
@@ -83,16 +82,16 @@ class FakeQuantWithMinMax(Cell):
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.channel_size = channel_size
self.out_channels = out_channels
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range

if per_channel:
min_array = np.array([self.min_init for i in range(
0, self.channel_size)]).astype(np.float32)
0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(
0, self.channel_size)]).astype(np.float32)
0, self.out_channels)]).astype(np.float32)
self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
@@ -102,8 +101,8 @@ class FakeQuantWithMinMax(Cell):
training=True)
self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
@@ -119,28 +118,27 @@ class FakeQuantWithMinMax(Cell):
training=True)
self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
ema_decay=self.ema_decay,
quant_delay=self.quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)

self.min = Parameter(
self.minq = Parameter(
Tensor(min_array), name='quant_min', requires_grad=False)
self.max = Parameter(
self.maxq = Parameter(
Tensor(max_array), name='quant_max', requires_grad=False)

def extend_repr(self):
s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size,
self.quant_delay)
s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format(
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.quant_delay)
return s

def construct(self, x):
if self.training:
out = self.fake_quant_train(x, self.min, self.max)
out = self.fake_quant_train(x, self.minq, self.maxq)
else:
out = self.fake_quant_infer(x, self.min, self.max)
out = self.fake_quant_infer(x, self.minq, self.maxq)
return out
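
For reference, a minimal usage sketch of the renamed cell, assuming the post-patch signature shown above; the channel count is illustrative:

    # Hedged sketch: per-channel fake-quant cell after the channel_size -> out_channels rename.
    # The argument values mirror the weight fake-quant cell built by Conv2dBatchNormQuant below.
    from mindspore.nn.layer.quant import FakeQuantWithMinMax

    weight_fq = FakeQuantWithMinMax(min_init=-6,
                                    max_init=6,
                                    ema=False,
                                    num_bits=8,
                                    per_channel=True,
                                    out_channels=16,
                                    symmetric=False,
                                    narrow_range=False)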


@@ -188,13 +186,13 @@ class Conv2dBatchNormQuant(Cell):
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
eps=1e-5,
momentum=0.9,
momentum=0.997,
weight_init=None,
beta_init=None,
gamma_init=None,
@@ -208,24 +206,25 @@ class Conv2dBatchNormQuant(Cell):
symmetric=False,
narrow_range=False):
super(Conv2dBatchNormQuant, self).__init__()
_ = dilation
self.stride = stride
self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=kernel_size,
mode=1,
pad_mode=pad_mode,
pad=padding,
stride=stride,
dilation=1,
group=group)
self.in_channels = in_channels
self.out_channels = out_channels
self.pad_mode = pad_mode
self.padding = padding
self.dilation = twice(dilation)
self.stride = twice(stride)
self.group = group
self.fake = fake
self.freeze_bn = freeze_bn
self.momentum = momentum
self.quant_delay = quant_delay
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
self.kernel_size = (kernel_size, kernel_size)
else:
self.kernel_size = kernel_size

if weight_init is None:
weight_init = initializer(
'normal', [out_channels, in_channels // group, *kernel_size])
'normal', [out_channels, in_channels // group, *self.kernel_size])
self.weight = Parameter(weight_init, name='weight')
if gamma_init is None:
gamma_init = initializer('ones', [out_channels])
@@ -245,16 +244,23 @@ class Conv2dBatchNormQuant(Cell):
self.step = Parameter(initializer(
'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)

self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
channel_size=out_channels,
symmetric=symmetric,
narrow_range=narrow_range)

self.conv = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size,
mode=1,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group)
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
symmetric=symmetric,
narrow_range=narrow_range)
self.batchnorm_fold_train = P.BatchNormFold(epsilon=eps,
momentum=momentum,
is_training=True,
@@ -271,7 +277,12 @@ class Conv2dBatchNormQuant(Cell):
self.assignadd = P.AssignAdd()

def extend_repr(self):
s = 'fake={}, freeze_bn={}'.format(self.fake, self.freeze_bn)
s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.fake, self.freeze_bn, self.momentum, self.quant_delay)
return s

def construct(self, x):
@@ -295,9 +306,8 @@ class Conv2dBatchNormQuant(Cell):
F.control_depend(out, self.assignadd(self.step, self.one))
else:
step = self.step
out_conv = self.conv(x, self.weight)
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer(
out_conv, self.moving_mean, self.moving_variance, step)
x, self.moving_mean, self.moving_variance, step)
weight = self.correct_mul(self.weight, self.gamma, running_std)
if self.fake:
weight = self.fake_quant_weight(weight)
@@ -307,7 +317,7 @@ class Conv2dBatchNormQuant(Cell):
return out
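
A usage sketch of Conv2dBatchNormQuant with the new defaults (stride=1, pad_mode='same', momentum=0.997), assuming the post-patch constructor shown above; the layer sizes and the freeze_bn value are illustrative:

    # Hedged sketch: folded Conv2d + BatchNorm cell with per-channel weight fake quantization.
    from mindspore.nn.layer.quant import Conv2dBatchNormQuant

    conv_bn = Conv2dBatchNormQuant(in_channels=3,
                                   out_channels=16,
                                   kernel_size=3,
                                   num_bits=8,
                                   per_channel=True,
                                   fake=True,
                                   freeze_bn=100000)  # freeze step is illustrative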


class Conv2dQuant(_Conv):
class Conv2dQuant(Cell):
r"""
2D convolution with fake quant op layer.

@@ -325,8 +335,8 @@ class Conv2dQuant(_Conv):
divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
Default: None.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: None.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
@@ -351,40 +361,72 @@ class Conv2dQuant(_Conv):
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros',
weight_init=None,
bias_init=None,
quant_delay=0,
num_bits=8,
per_channel=False,
symmetric=False,
narrow_range=False):
kernel_size = twice(kernel_size)
super(Conv2dQuant, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation,
group, has_bias, weight_init, bias_init)
self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size, mode=1,
pad_mode=self.pad_mode, pad=self.padding, stride=self.stride, dilation=self.dilation,
group=self.group)
self.bias_add = P.BiasAdd()
if pad_mode not in ('valid', 'same', 'pad'):
raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
channel_size=out_channels,
symmetric=symmetric,
narrow_range=narrow_range)
super(Conv2dQuant, self).__init__()
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
else:
self.kernel_size = kernel_size
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = has_bias
self.stride = twice(stride)
self.dilation = twice(dilation)
self.pad_mode = pad_mode
self.padding = padding
self.group = group
self.quant_delay = quant_delay

if weight_init is None:
weight_init = initializer(
'normal', [out_channels, in_channels // group, *self.kernel_size])
self.weight = Parameter(weight_init, name='weight')
if bias_init is None:
bias_init = initializer('zeros', [out_channels])
if has_bias:
self.bias = Parameter(bias_init, name='bias')
self.bias_add = P.BiasAdd()

self.conv = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size,
mode=1,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group)
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
symmetric=symmetric,
narrow_range=narrow_range)

def construct(self, x):
weight_q = self.fake_quant_weight(self.weight)
out = self.conv2d(x, weight_q)
weight = self.fake_quant_weight(self.weight)
out = self.conv(x, weight)
if self.has_bias:
return self.bias_add(out, self.bias)
return out

def extend_repr(self):
s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'has_bias={}, quant_delay={}'.format(
self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.has_bias, self.quant_delay)
return s
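
A usage sketch of the rewritten Conv2dQuant, assuming the post-patch constructor shown above (weight_init and bias_init now default to None and are built internally); the sizes are illustrative:

    # Hedged sketch: quantization aware 2D convolution. The weight passes through the
    # fake-quant cell before the Conv2D primitive; bias is added afterwards when has_bias=True.
    from mindspore.nn.layer.quant import Conv2dQuant

    conv = Conv2dQuant(in_channels=3,
                       out_channels=16,
                       kernel_size=3,
                       stride=1,
                       pad_mode='same',
                       has_bias=True,
                       num_bits=8,
                       quant_delay=0)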


class DenseQuant(Cell):
r"""
@@ -453,15 +495,15 @@ class DenseQuant(Cell):

self.activation = get_activation(activation)
self.activation_flag = self.activation is not None
self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
channel_size=out_channels,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
symmetric=symmetric,
narrow_range=narrow_range)

def construct(self, x):
"""Use operators to construct to Dense layer."""
@@ -511,13 +553,13 @@ class ReLUQuant(Cell):
symmetric=False,
narrow_range=False):
super(ReLUQuant, self).__init__()
self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.relu = P.ReLU()

def construct(self, x):
@@ -551,13 +593,13 @@ class ReLU6Quant(Cell):
def __init__(self, num_bits=8, quant_delay=0, symmetric=False,
narrow_range=False):
super(ReLU6Quant, self).__init__()
self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.relu6 = P.ReLU6()

def construct(self, x):
@@ -592,20 +634,20 @@ class HSwishQuant(Cell):
symmetric=False,
narrow_range=False):
super(HSwishQuant, self).__init__()
self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.act = P.HSwish()

def construct(self, x):
@@ -641,20 +683,20 @@ class HSigmoidQuant(Cell):
symmetric=False,
narrow_range=False):
super(HSigmoidQuant, self).__init__()
self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.act = P.HSigmoid()

def construct(self, x):
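
HSwishQuant and HSigmoidQuant now initialize both of their fake-quant cells with min_init=-6 instead of 0; construction itself is unchanged. A minimal sketch, assuming the constructors shown above:

    # Hedged sketch: activation cells wrapped with fake quantization.
    from mindspore.nn.layer.quant import ReLU6Quant, HSwishQuant

    relu6_q = ReLU6Quant(num_bits=8, quant_delay=0)
    hswish_q = HSwishQuant(num_bits=8, quant_delay=0)
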
@@ -690,16 +732,57 @@ class TensorAddQuant(Cell):
symmetric=False,
narrow_range=False):
super(TensorAddQuant, self).__init__()
self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.add = P.TensorAdd()

def construct(self, x1, x2):
x = self.add(x1, x2)
x = self.fake_quant_act(x)
return x


class MulQuant(Cell):
r"""
Add Fake Quant OP after Mul OP.

For a more detailed overview, see the Mul operator.

Args:
num_bits (int): Number of bits used for quantization; supports 4 and 8. Default: 8.
quant_delay (int): Number of global steps after which quantization starts. Default: 0.
symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.

Inputs:
- **x1** (Tensor) - The first input of MulQuant.
- **x2** (Tensor) - The second input of MulQuant.

Outputs:
Tensor, with the same type and shape as `x1`.

"""

def __init__(self,
num_bits=8,
quant_delay=0,
symmetric=False,
narrow_range=False):
super(MulQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
symmetric=symmetric,
narrow_range=narrow_range)
self.mul = P.Mul()

def construct(self, x1, x2):
x = self.mul(x1, x2)
x = self.fake_quant_act(x)
return x
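
MulQuant mirrors TensorAddQuant: the element-wise product of the two inputs is passed through a fake-quant cell. A minimal sketch, assuming the constructor shown above; the tensor shapes are illustrative:

    # Hedged sketch: fake quantization applied after an element-wise multiply.
    import numpy as np
    from mindspore.common.tensor import Tensor
    from mindspore.nn.layer.quant import MulQuant

    mul_q = MulQuant(num_bits=8, quant_delay=0)
    x1 = Tensor(np.ones((1, 3, 4, 4)).astype(np.float32))
    x2 = Tensor(np.ones((1, 3, 4, 4)).astype(np.float32))
    out = mul_q(x1, x2)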
