|
|
@@ -15,7 +15,6 @@ |
|
|
"""Aware quantization.""" |
|
|
"""Aware quantization.""" |
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import numpy as np |
|
|
import mindspore.nn as nn |
|
|
|
|
|
import mindspore.common.dtype as mstype |
|
|
import mindspore.common.dtype as mstype |
|
|
from mindspore.ops import operations as P |
|
|
from mindspore.ops import operations as P |
|
|
from mindspore.ops import functional as F |
|
|
from mindspore.ops import functional as F |
|
|
@@ -24,7 +23,6 @@ from mindspore.common.initializer import initializer |
|
|
from mindspore.common.tensor import Tensor |
|
|
from mindspore.common.tensor import Tensor |
|
|
from mindspore._checkparam import check_int_positive, check_bool, twice |
|
|
from mindspore._checkparam import check_int_positive, check_bool, twice |
|
|
from mindspore.nn.cell import Cell |
|
|
from mindspore.nn.cell import Cell |
|
|
from mindspore.nn.layer.conv import _Conv |
|
|
|
|
|
from mindspore.nn.layer.activation import get_activation |
|
|
from mindspore.nn.layer.activation import get_activation |
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
__all__ = [ |
|
|
@@ -37,6 +35,7 @@ __all__ = [ |
|
|
'HSwishQuant', |
|
|
'HSwishQuant', |
|
|
'HSigmoidQuant', |
|
|
'HSigmoidQuant', |
|
|
'TensorAddQuant', |
|
|
'TensorAddQuant', |
|
|
|
|
|
'MulQuant', |
|
|
] |
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -51,7 +50,7 @@ class FakeQuantWithMinMax(Cell): |
|
|
ema (bool): Exponential Moving Average algorithm update min and max. Default: False. |
|
|
ema (bool): Exponential Moving Average algorithm update min and max. Default: False. |
|
|
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.9999. |
|
|
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.9999. |
|
|
per_channel (bool): Quantization by layer or channel. Default: False. |
|
|
per_channel (bool): Quantization by layer or channel. Default: False. |
|
|
channel_size (int): declarate the min and max channel size, Default: 1. |
|
|
|
|
|
|
|
|
out_channels (int): declarate the min and max channel size, Default: 1. |
|
|
quant_delay (int): Quantization delay parameters according by global step. Default: 0. |
|
|
quant_delay (int): Quantization delay parameters according by global step. Default: 0. |
|
|
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. |
|
|
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. |
|
|
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. |
|
|
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. |
|
|
@@ -71,7 +70,7 @@ class FakeQuantWithMinMax(Cell): |
|
|
ema=False, |
|
|
ema=False, |
|
|
ema_decay=0.999, |
|
|
ema_decay=0.999, |
|
|
per_channel=False, |
|
|
per_channel=False, |
|
|
channel_size=1, |
|
|
|
|
|
|
|
|
out_channels=1, |
|
|
quant_delay=0, |
|
|
quant_delay=0, |
|
|
symmetric=False, |
|
|
symmetric=False, |
|
|
narrow_range=False): |
|
|
narrow_range=False): |
|
|
@@ -83,16 +82,16 @@ class FakeQuantWithMinMax(Cell): |
|
|
self.ema = ema |
|
|
self.ema = ema |
|
|
self.ema_decay = ema_decay |
|
|
self.ema_decay = ema_decay |
|
|
self.per_channel = per_channel |
|
|
self.per_channel = per_channel |
|
|
self.channel_size = channel_size |
|
|
|
|
|
|
|
|
self.out_channels = out_channels |
|
|
self.quant_delay = quant_delay |
|
|
self.quant_delay = quant_delay |
|
|
self.symmetric = symmetric |
|
|
self.symmetric = symmetric |
|
|
self.narrow_range = narrow_range |
|
|
self.narrow_range = narrow_range |
|
|
|
|
|
|
|
|
if per_channel: |
|
|
if per_channel: |
|
|
min_array = np.array([self.min_init for i in range( |
|
|
min_array = np.array([self.min_init for i in range( |
|
|
0, self.channel_size)]).astype(np.float32) |
|
|
|
|
|
|
|
|
0, self.out_channels)]).astype(np.float32) |
|
|
max_array = np.array([self.max_init for i in range( |
|
|
max_array = np.array([self.max_init for i in range( |
|
|
0, self.channel_size)]).astype(np.float32) |
|
|
|
|
|
|
|
|
0, self.out_channels)]).astype(np.float32) |
|
|
self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits, |
|
|
self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits, |
|
|
ema=self.ema, |
|
|
ema=self.ema, |
|
|
ema_decay=self.ema_decay, |
|
|
ema_decay=self.ema_decay, |
|
|
@@ -102,8 +101,8 @@ class FakeQuantWithMinMax(Cell): |
|
|
training=True) |
|
|
training=True) |
|
|
self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits, |
|
|
self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits, |
|
|
ema=self.ema, |
|
|
ema=self.ema, |
|
|
ema_decay=ema_decay, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
|
|
|
ema_decay=self.ema_decay, |
|
|
|
|
|
quant_delay=self.quant_delay, |
|
|
symmetric=self.symmetric, |
|
|
symmetric=self.symmetric, |
|
|
narrow_range=self.narrow_range, |
|
|
narrow_range=self.narrow_range, |
|
|
training=False) |
|
|
training=False) |
|
|
@@ -119,28 +118,27 @@ class FakeQuantWithMinMax(Cell): |
|
|
training=True) |
|
|
training=True) |
|
|
self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits, |
|
|
self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits, |
|
|
ema=self.ema, |
|
|
ema=self.ema, |
|
|
ema_decay=ema_decay, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
|
|
|
ema_decay=self.ema_decay, |
|
|
|
|
|
quant_delay=self.quant_delay, |
|
|
symmetric=self.symmetric, |
|
|
symmetric=self.symmetric, |
|
|
narrow_range=self.narrow_range, |
|
|
narrow_range=self.narrow_range, |
|
|
training=False) |
|
|
training=False) |
|
|
|
|
|
|
|
|
self.min = Parameter( |
|
|
|
|
|
|
|
|
self.minq = Parameter( |
|
|
Tensor(min_array), name='quant_min', requires_grad=False) |
|
|
Tensor(min_array), name='quant_min', requires_grad=False) |
|
|
self.max = Parameter( |
|
|
|
|
|
|
|
|
self.maxq = Parameter( |
|
|
Tensor(max_array), name='quant_max', requires_grad=False) |
|
|
Tensor(max_array), name='quant_max', requires_grad=False) |
|
|
|
|
|
|
|
|
def extend_repr(self): |
|
|
def extend_repr(self): |
|
|
s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format( |
|
|
|
|
|
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size, |
|
|
|
|
|
self.quant_delay) |
|
|
|
|
|
|
|
|
s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format( |
|
|
|
|
|
self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.quant_delay) |
|
|
return s |
|
|
return s |
|
|
|
|
|
|
|
|
def construct(self, x): |
|
|
def construct(self, x): |
|
|
if self.training: |
|
|
if self.training: |
|
|
out = self.fake_quant_train(x, self.min, self.max) |
|
|
|
|
|
|
|
|
out = self.fake_quant_train(x, self.minq, self.maxq) |
|
|
else: |
|
|
else: |
|
|
out = self.fake_quant_infer(x, self.min, self.max) |
|
|
|
|
|
|
|
|
out = self.fake_quant_infer(x, self.minq, self.maxq) |
|
|
return out |
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -188,13 +186,13 @@ class Conv2dBatchNormQuant(Cell): |
|
|
in_channels, |
|
|
in_channels, |
|
|
out_channels, |
|
|
out_channels, |
|
|
kernel_size, |
|
|
kernel_size, |
|
|
stride, |
|
|
|
|
|
pad_mode, |
|
|
|
|
|
|
|
|
stride=1, |
|
|
|
|
|
pad_mode='same', |
|
|
padding=0, |
|
|
padding=0, |
|
|
dilation=1, |
|
|
dilation=1, |
|
|
group=1, |
|
|
group=1, |
|
|
eps=1e-5, |
|
|
eps=1e-5, |
|
|
momentum=0.9, |
|
|
|
|
|
|
|
|
momentum=0.997, |
|
|
weight_init=None, |
|
|
weight_init=None, |
|
|
beta_init=None, |
|
|
beta_init=None, |
|
|
gamma_init=None, |
|
|
gamma_init=None, |
|
|
@@ -208,24 +206,25 @@ class Conv2dBatchNormQuant(Cell): |
|
|
symmetric=False, |
|
|
symmetric=False, |
|
|
narrow_range=False): |
|
|
narrow_range=False): |
|
|
super(Conv2dBatchNormQuant, self).__init__() |
|
|
super(Conv2dBatchNormQuant, self).__init__() |
|
|
_ = dilation |
|
|
|
|
|
self.stride = stride |
|
|
|
|
|
self.conv = P.Conv2D(out_channel=out_channels, |
|
|
|
|
|
kernel_size=kernel_size, |
|
|
|
|
|
mode=1, |
|
|
|
|
|
pad_mode=pad_mode, |
|
|
|
|
|
pad=padding, |
|
|
|
|
|
stride=stride, |
|
|
|
|
|
dilation=1, |
|
|
|
|
|
group=group) |
|
|
|
|
|
|
|
|
self.in_channels = in_channels |
|
|
|
|
|
self.out_channels = out_channels |
|
|
|
|
|
self.pad_mode = pad_mode |
|
|
|
|
|
self.padding = padding |
|
|
|
|
|
self.dilation = twice(dilation) |
|
|
|
|
|
self.stride = twice(stride) |
|
|
|
|
|
self.group = group |
|
|
self.fake = fake |
|
|
self.fake = fake |
|
|
self.freeze_bn = freeze_bn |
|
|
self.freeze_bn = freeze_bn |
|
|
|
|
|
self.momentum = momentum |
|
|
|
|
|
self.quant_delay = quant_delay |
|
|
if isinstance(kernel_size, int): |
|
|
if isinstance(kernel_size, int): |
|
|
kernel_size = (kernel_size, kernel_size) |
|
|
|
|
|
|
|
|
self.kernel_size = (kernel_size, kernel_size) |
|
|
|
|
|
else: |
|
|
|
|
|
self.kernel_size = kernel_size |
|
|
|
|
|
|
|
|
if weight_init is None: |
|
|
if weight_init is None: |
|
|
weight_init = initializer( |
|
|
weight_init = initializer( |
|
|
'normal', [out_channels, in_channels // group, *kernel_size]) |
|
|
|
|
|
|
|
|
'normal', [out_channels, in_channels // group, *self.kernel_size]) |
|
|
self.weight = Parameter(weight_init, name='weight') |
|
|
self.weight = Parameter(weight_init, name='weight') |
|
|
if gamma_init is None: |
|
|
if gamma_init is None: |
|
|
gamma_init = initializer('ones', [out_channels]) |
|
|
gamma_init = initializer('ones', [out_channels]) |
|
|
@@ -245,16 +244,23 @@ class Conv2dBatchNormQuant(Cell): |
|
|
self.step = Parameter(initializer( |
|
|
self.step = Parameter(initializer( |
|
|
'normal', [1], dtype=mstype.int32), name='step', requires_grad=False) |
|
|
'normal', [1], dtype=mstype.int32), name='step', requires_grad=False) |
|
|
|
|
|
|
|
|
self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
ema=False, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
per_channel=per_channel, |
|
|
|
|
|
channel_size=out_channels, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.conv = P.Conv2D(out_channel=self.out_channels, |
|
|
|
|
|
kernel_size=self.kernel_size, |
|
|
|
|
|
mode=1, |
|
|
|
|
|
pad_mode=self.pad_mode, |
|
|
|
|
|
pad=self.padding, |
|
|
|
|
|
stride=self.stride, |
|
|
|
|
|
dilation=self.dilation, |
|
|
|
|
|
group=self.group) |
|
|
|
|
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
ema=False, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
per_channel=per_channel, |
|
|
|
|
|
out_channels=out_channels, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
self.batchnorm_fold_train = P.BatchNormFold(epsilon=eps, |
|
|
self.batchnorm_fold_train = P.BatchNormFold(epsilon=eps, |
|
|
momentum=momentum, |
|
|
momentum=momentum, |
|
|
is_training=True, |
|
|
is_training=True, |
|
|
@@ -271,7 +277,12 @@ class Conv2dBatchNormQuant(Cell): |
|
|
self.assignadd = P.AssignAdd() |
|
|
self.assignadd = P.AssignAdd() |
|
|
|
|
|
|
|
|
def extend_repr(self): |
|
|
def extend_repr(self): |
|
|
s = 'fake={}, freeze_bn={}'.format(self.fake, self.freeze_bn) |
|
|
|
|
|
|
|
|
s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \ |
|
|
|
|
|
'pad_mode={}, padding={}, dilation={}, group={}, ' \ |
|
|
|
|
|
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format( |
|
|
|
|
|
self.in_channels, self.out_channels, self.kernel_size, self.stride, |
|
|
|
|
|
self.pad_mode, self.padding, self.dilation, self.group, |
|
|
|
|
|
self.fake, self.freeze_bn, self.momentum, self.quant_delay) |
|
|
return s |
|
|
return s |
|
|
|
|
|
|
|
|
def construct(self, x): |
|
|
def construct(self, x): |
|
|
@@ -295,9 +306,8 @@ class Conv2dBatchNormQuant(Cell): |
|
|
F.control_depend(out, self.assignadd(self.step, self.one)) |
|
|
F.control_depend(out, self.assignadd(self.step, self.one)) |
|
|
else: |
|
|
else: |
|
|
step = self.step |
|
|
step = self.step |
|
|
out_conv = self.conv(x, self.weight) |
|
|
|
|
|
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer( |
|
|
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer( |
|
|
out_conv, self.moving_mean, self.moving_variance, step) |
|
|
|
|
|
|
|
|
x, self.moving_mean, self.moving_variance, step) |
|
|
weight = self.correct_mul(self.weight, self.gamma, running_std) |
|
|
weight = self.correct_mul(self.weight, self.gamma, running_std) |
|
|
if self.fake: |
|
|
if self.fake: |
|
|
weight = self.fake_quant_weight(weight) |
|
|
weight = self.fake_quant_weight(weight) |
|
|
@@ -307,7 +317,7 @@ class Conv2dBatchNormQuant(Cell): |
|
|
return out |
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Conv2dQuant(_Conv): |
|
|
|
|
|
|
|
|
class Conv2dQuant(Cell): |
|
|
r""" |
|
|
r""" |
|
|
2D convolution with fake quant op layer. |
|
|
2D convolution with fake quant op layer. |
|
|
|
|
|
|
|
|
@@ -325,8 +335,8 @@ class Conv2dQuant(_Conv): |
|
|
divisible by the number of groups. Default: 1. |
|
|
divisible by the number of groups. Default: 1. |
|
|
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. |
|
|
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. |
|
|
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. |
|
|
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. |
|
|
Default: 'normal'. |
|
|
|
|
|
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. |
|
|
|
|
|
|
|
|
Default: None. |
|
|
|
|
|
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: None. |
|
|
quant_delay (int): Quantization delay parameters according by global step. Default: 0. |
|
|
quant_delay (int): Quantization delay parameters according by global step. Default: 0. |
|
|
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. |
|
|
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. |
|
|
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. |
|
|
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. |
|
|
@@ -351,40 +361,72 @@ class Conv2dQuant(_Conv): |
|
|
dilation=1, |
|
|
dilation=1, |
|
|
group=1, |
|
|
group=1, |
|
|
has_bias=False, |
|
|
has_bias=False, |
|
|
weight_init='normal', |
|
|
|
|
|
bias_init='zeros', |
|
|
|
|
|
|
|
|
weight_init=None, |
|
|
|
|
|
bias_init=None, |
|
|
quant_delay=0, |
|
|
quant_delay=0, |
|
|
num_bits=8, |
|
|
num_bits=8, |
|
|
per_channel=False, |
|
|
per_channel=False, |
|
|
symmetric=False, |
|
|
symmetric=False, |
|
|
narrow_range=False): |
|
|
narrow_range=False): |
|
|
kernel_size = twice(kernel_size) |
|
|
|
|
|
super(Conv2dQuant, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation, |
|
|
|
|
|
group, has_bias, weight_init, bias_init) |
|
|
|
|
|
self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size, mode=1, |
|
|
|
|
|
pad_mode=self.pad_mode, pad=self.padding, stride=self.stride, dilation=self.dilation, |
|
|
|
|
|
group=self.group) |
|
|
|
|
|
self.bias_add = P.BiasAdd() |
|
|
|
|
|
if pad_mode not in ('valid', 'same', 'pad'): |
|
|
|
|
|
raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed ' |
|
|
|
|
|
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.') |
|
|
|
|
|
self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
ema=False, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
per_channel=per_channel, |
|
|
|
|
|
channel_size=out_channels, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
|
|
|
super(Conv2dQuant, self).__init__() |
|
|
|
|
|
if isinstance(kernel_size, int): |
|
|
|
|
|
self.kernel_size = (kernel_size, kernel_size) |
|
|
|
|
|
else: |
|
|
|
|
|
self.kernel_size = kernel_size |
|
|
|
|
|
self.in_channels = check_int_positive(in_channels) |
|
|
|
|
|
self.out_channels = check_int_positive(out_channels) |
|
|
|
|
|
self.has_bias = has_bias |
|
|
|
|
|
self.stride = twice(stride) |
|
|
|
|
|
self.dilation = twice(dilation) |
|
|
|
|
|
self.pad_mode = pad_mode |
|
|
|
|
|
self.padding = padding |
|
|
|
|
|
self.group = group |
|
|
|
|
|
self.quant_delay = quant_delay |
|
|
|
|
|
|
|
|
|
|
|
if weight_init is None: |
|
|
|
|
|
weight_init = initializer( |
|
|
|
|
|
'normal', [out_channels, in_channels // group, *self.kernel_size]) |
|
|
|
|
|
self.weight = Parameter(weight_init, name='weight') |
|
|
|
|
|
if bias_init is None: |
|
|
|
|
|
bias_init = initializer('zeros', [out_channels]) |
|
|
|
|
|
if has_bias: |
|
|
|
|
|
self.bias = Parameter(bias_init, name='bias') |
|
|
|
|
|
self.bias_add = P.BiasAdd() |
|
|
|
|
|
|
|
|
|
|
|
self.conv = P.Conv2D(out_channel=self.out_channels, |
|
|
|
|
|
kernel_size=self.kernel_size, |
|
|
|
|
|
mode=1, |
|
|
|
|
|
pad_mode=self.pad_mode, |
|
|
|
|
|
pad=self.padding, |
|
|
|
|
|
stride=self.stride, |
|
|
|
|
|
dilation=self.dilation, |
|
|
|
|
|
group=self.group) |
|
|
|
|
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
ema=False, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
per_channel=per_channel, |
|
|
|
|
|
out_channels=out_channels, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
|
|
|
def construct(self, x): |
|
|
def construct(self, x): |
|
|
weight_q = self.fake_quant_weight(self.weight) |
|
|
|
|
|
out = self.conv2d(x, weight_q) |
|
|
|
|
|
|
|
|
weight = self.fake_quant_weight(self.weight) |
|
|
|
|
|
out = self.conv(x, weight) |
|
|
if self.has_bias: |
|
|
if self.has_bias: |
|
|
return self.bias_add(out, self.bias) |
|
|
return self.bias_add(out, self.bias) |
|
|
return out |
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
def extend_repr(self): |
|
|
|
|
|
s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \ |
|
|
|
|
|
'pad_mode={}, padding={}, dilation={}, group={}, ' \ |
|
|
|
|
|
'has_bias={}, quant_delay={}'.format( |
|
|
|
|
|
self.in_channels, self.out_channels, self.kernel_size, self.stride, |
|
|
|
|
|
self.pad_mode, self.padding, self.dilation, self.group, |
|
|
|
|
|
self.has_bias, self.quant_delay) |
|
|
|
|
|
return s |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DenseQuant(Cell): |
|
|
class DenseQuant(Cell): |
|
|
r""" |
|
|
r""" |
|
|
@@ -453,15 +495,15 @@ class DenseQuant(Cell): |
|
|
|
|
|
|
|
|
self.activation = get_activation(activation) |
|
|
self.activation = get_activation(activation) |
|
|
self.activation_flag = self.activation is not None |
|
|
self.activation_flag = self.activation is not None |
|
|
self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
ema=False, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
per_channel=per_channel, |
|
|
|
|
|
channel_size=out_channels, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
|
|
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
ema=False, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
per_channel=per_channel, |
|
|
|
|
|
out_channels=out_channels, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
|
|
|
def construct(self, x): |
|
|
def construct(self, x): |
|
|
"""Use operators to construct to Dense layer.""" |
|
|
"""Use operators to construct to Dense layer.""" |
|
|
@@ -511,13 +553,13 @@ class ReLUQuant(Cell): |
|
|
symmetric=False, |
|
|
symmetric=False, |
|
|
narrow_range=False): |
|
|
narrow_range=False): |
|
|
super(ReLUQuant, self).__init__() |
|
|
super(ReLUQuant, self).__init__() |
|
|
self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
|
|
|
self.fake_quant_act = FakeQuantWithMinMax(min_init=0, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
self.relu = P.ReLU() |
|
|
self.relu = P.ReLU() |
|
|
|
|
|
|
|
|
def construct(self, x): |
|
|
def construct(self, x): |
|
|
@@ -551,13 +593,13 @@ class ReLU6Quant(Cell): |
|
|
def __init__(self, num_bits=8, quant_delay=0, symmetric=False, |
|
|
def __init__(self, num_bits=8, quant_delay=0, symmetric=False, |
|
|
narrow_range=False): |
|
|
narrow_range=False): |
|
|
super(ReLU6Quant, self).__init__() |
|
|
super(ReLU6Quant, self).__init__() |
|
|
self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
|
|
|
self.fake_quant_act = FakeQuantWithMinMax(min_init=0, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
self.relu6 = P.ReLU6() |
|
|
self.relu6 = P.ReLU6() |
|
|
|
|
|
|
|
|
def construct(self, x): |
|
|
def construct(self, x): |
|
|
@@ -592,20 +634,20 @@ class HSwishQuant(Cell): |
|
|
symmetric=False, |
|
|
symmetric=False, |
|
|
narrow_range=False): |
|
|
narrow_range=False): |
|
|
super(HSwishQuant, self).__init__() |
|
|
super(HSwishQuant, self).__init__() |
|
|
self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
|
|
|
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
self.act = P.HSwish() |
|
|
self.act = P.HSwish() |
|
|
|
|
|
|
|
|
def construct(self, x): |
|
|
def construct(self, x): |
|
|
@@ -641,20 +683,20 @@ class HSigmoidQuant(Cell): |
|
|
symmetric=False, |
|
|
symmetric=False, |
|
|
narrow_range=False): |
|
|
narrow_range=False): |
|
|
super(HSigmoidQuant, self).__init__() |
|
|
super(HSigmoidQuant, self).__init__() |
|
|
self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
|
|
|
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
self.act = P.HSigmoid() |
|
|
self.act = P.HSigmoid() |
|
|
|
|
|
|
|
|
def construct(self, x): |
|
|
def construct(self, x): |
|
|
@@ -690,16 +732,57 @@ class TensorAddQuant(Cell): |
|
|
symmetric=False, |
|
|
symmetric=False, |
|
|
narrow_range=False): |
|
|
narrow_range=False): |
|
|
super(TensorAddQuant, self).__init__() |
|
|
super(TensorAddQuant, self).__init__() |
|
|
self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=-6, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
|
|
|
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
self.add = P.TensorAdd() |
|
|
self.add = P.TensorAdd() |
|
|
|
|
|
|
|
|
def construct(self, x1, x2): |
|
|
def construct(self, x1, x2): |
|
|
x = self.add(x1, x2) |
|
|
x = self.add(x1, x2) |
|
|
x = self.fake_quant_act(x) |
|
|
x = self.fake_quant_act(x) |
|
|
return x |
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MulQuant(Cell): |
|
|
|
|
|
r""" |
|
|
|
|
|
Add Fake Quant OP after Mul OP. |
|
|
|
|
|
|
|
|
|
|
|
For a more Detailed overview of Mul op. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. |
|
|
|
|
|
quant_delay (int): Quantization delay parameters according by global step. Default: 0. |
|
|
|
|
|
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. |
|
|
|
|
|
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. |
|
|
|
|
|
|
|
|
|
|
|
Inputs: |
|
|
|
|
|
- **x** (Tensor) - The input of MulQuant. |
|
|
|
|
|
|
|
|
|
|
|
Outputs: |
|
|
|
|
|
Tensor, with the same type and shape as the `x`. |
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
|
|
num_bits=8, |
|
|
|
|
|
quant_delay=0, |
|
|
|
|
|
symmetric=False, |
|
|
|
|
|
narrow_range=False): |
|
|
|
|
|
super(MulQuant, self).__init__() |
|
|
|
|
|
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, |
|
|
|
|
|
max_init=6, |
|
|
|
|
|
num_bits=num_bits, |
|
|
|
|
|
quant_delay=quant_delay, |
|
|
|
|
|
ema=True, |
|
|
|
|
|
symmetric=symmetric, |
|
|
|
|
|
narrow_range=narrow_range) |
|
|
|
|
|
self.mul = P.Mul() |
|
|
|
|
|
|
|
|
|
|
|
def construct(self, x1, x2): |
|
|
|
|
|
x = self.mul(x1, x2) |
|
|
|
|
|
x = self.fake_quant_act(x) |
|
|
|
|
|
return x |