|
|
|
@@ -27,10 +27,8 @@ from mindspore.nn.cell import Cell |
|
|
|
from mindspore.nn.layer.activation import get_activation |
|
|
|
import mindspore.context as context |
|
|
|
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
'FakeQuantWithMinMax', |
|
|
|
'DepthwiseConv2dBatchNormQuant', |
|
|
|
'Conv2dBatchNormQuant', |
|
|
|
'Conv2dQuant', |
|
|
|
'DenseQuant', |
|
|
|
@@ -113,7 +111,7 @@ class FakeQuantWithMinMaxD(Cell): |
|
|
|
        max_init (int, list): Initial value of the quantization maximum, a single value per layer or a list with one value per channel. Default: 6.
        num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
        ema (bool): Whether to update min and max with an exponential moving average. Default: False.
        ema_decay (float): Decay factor for the exponential moving average. Default: 0.999.
        per_channel (bool): Quantize per channel if True, otherwise per layer. Default: False.
        out_channels (int): Declares the channel size of min and max. Default: 1.
        quant_delay (int): Number of global steps after which quantization begins. Default: 0.
|
|
|
@@ -131,6 +129,7 @@ class FakeQuantWithMinMaxD(Cell): |
|
|
|
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) |
|
|
|
>>> result = fake_quant(input_x) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
min_init=-6, |
|
|
|
max_init=6, |
|
|
|
@@ -215,7 +214,7 @@ class FakeQuantWithMinMax(Cell): |
|
|
|
        max_init (int, list): Initial value of the quantization maximum, a single value per layer or a list with one value per channel. Default: 6.
        num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
        ema (bool): Whether to update min and max with an exponential moving average. Default: False.
        ema_decay (float): Decay factor for the exponential moving average. Default: 0.999.
        per_channel (bool): Quantize per channel if True, otherwise per layer. Default: False.
        out_channels (int): Declares the channel size of min and max. Default: 1.
        quant_delay (int): Number of global steps after which quantization begins. Default: 0.
|
|
|
@@ -335,191 +334,6 @@ class FakeQuantWithMinMax(Cell): |
|
|
|
return out |
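# --- Editor's sketch -------------------------------------------------------
# A minimal NumPy illustration of the quantize-dequantize round trip that the
# FakeQuantWithMinMax cells above simulate during training. The affine
# scale/zero-point formula below is the standard one and is an assumption for
# illustration, not the operator's verbatim kernel.
import numpy as np

def _fake_quant_sketch(x, min_val=-6.0, max_val=6.0, num_bits=8, narrow_range=False):
    """Quantize x to num_bits integers over [min_val, max_val], then dequantize."""
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    scale = (max_val - min_val) / (quant_max - quant_min)
    zero_point = int(round(quant_min - min_val / scale))
    q = np.clip(np.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale  # output carries the quantization error

# _fake_quant_sketch(np.array([[1., 2., 1.], [-2., 0., -1.]])) stays within one
# quantization step (about 12/255 here) of the input, mirroring the docstring example.
# ---------------------------------------------------------------------------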
|
|
|
|
|
|
|
|
|
|
|
class DepthwiseConv2dBatchNormQuant(Cell): |
|
|
|
r""" |
|
|
|
    2D depthwise convolution layer with the BatchNorm op folded in.
|
|
|
|
|
|
|
    See Conv2d for a more detailed overview.
|
|
|
|
|
|
|
    Args:
        in_channels (int): The number of input channels :math:`C_{in}`.
        out_channels (int): The number of output channels :math:`C_{out}`.
        kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
        stride (int): Specifies the stride, the same value for all spatial dimensions. Default: 1.
        pad_mode (str): Specifies the padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): Specifies the dilation rate for dilated convolution. Default: 1.
        group (int): Splits the input into groups. Default: 1.
        eps (float): Epsilon value for BatchNorm. Default: 1e-5.
        momentum (float): Momentum for the BatchNorm moving average. Default: 0.997.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            convolution kernel. Default: None.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            beta vector. Default: None.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            gamma vector. Default: None.
        mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            mean vector. Default: None.
        var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            variance vector. Default: None.
        quant_delay (int): Number of global steps after which quantization begins. Default: 0.
        freeze_bn (int): Global step after which the BatchNorm statistics are frozen. Default: 100000.
        fake (bool): Whether to add a FakeQuantWithMinMax op for the weights. Default: True.
        num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
        per_channel (bool): Quantize weights per channel if True, otherwise per layer. Default: False.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range. Default: False.
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
        >>> quant = nn.DepthwiseConv2dBatchNormQuant(1, 6,
        ...                                          kernel_size=(2, 2),
        ...                                          stride=(1, 1),
        ...                                          pad_mode="valid",
        ...                                          dilation=(1, 1))
|
|
|
>>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mindspore.float32) |
|
|
|
>>> result = quant(input_x) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
in_channels, |
|
|
|
out_channels, |
|
|
|
kernel_size, |
|
|
|
stride=1, |
|
|
|
pad_mode='same', |
|
|
|
padding=0, |
|
|
|
dilation=1, |
|
|
|
group=1, |
|
|
|
eps=1e-5, |
|
|
|
momentum=0.997, |
|
|
|
weight_init=None, |
|
|
|
beta_init=None, |
|
|
|
gamma_init=None, |
|
|
|
mean_init=None, |
|
|
|
var_init=None, |
|
|
|
quant_delay=0, |
|
|
|
freeze_bn=100000, |
|
|
|
fake=True, |
|
|
|
num_bits=8, |
|
|
|
per_channel=False, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False): |
|
|
|
"""init DepthwiseConv2dBatchNormQuant layer""" |
|
|
|
super(DepthwiseConv2dBatchNormQuant, self).__init__() |
|
|
|
self.in_channels = in_channels |
|
|
|
self.out_channels = out_channels |
|
|
|
self.pad_mode = pad_mode |
|
|
|
self.padding = padding |
|
|
|
self.dilation = twice(dilation) |
|
|
|
self.stride = twice(stride) |
|
|
|
self.group = group |
|
|
|
self.fake = fake |
|
|
|
self.freeze_bn = freeze_bn |
|
|
|
self.momentum = momentum |
|
|
|
self.quant_delay = quant_delay |
|
|
|
if isinstance(kernel_size, int): |
|
|
|
self.kernel_size = (kernel_size, kernel_size) |
|
|
|
else: |
|
|
|
self.kernel_size = kernel_size |
|
|
|
if group > 1: |
|
|
|
            validator.check_integer('group', group, 'in_channels', in_channels, 'DepthwiseConv2dBatchNormQuant')
            validator.check_integer('group', group, 'out_channels', out_channels, 'DepthwiseConv2dBatchNormQuant')
|
|
|
self.is_depthwise = group > 1 |
|
|
|
|
|
|
|
channel_multiplier = out_channels // in_channels |
|
|
|
self.conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, |
|
|
|
kernel_size=kernel_size, |
|
|
|
stride=stride, |
|
|
|
pad_mode=pad_mode, |
|
|
|
pad=padding) |
|
|
|
|
|
|
|
if weight_init is None: |
|
|
|
            weight_init = initializer('normal', [channel_multiplier, in_channels, *self.kernel_size])
|
|
|
self.weight = Parameter(weight_init, name='weight') |
|
|
|
if gamma_init is None: |
|
|
|
gamma_init = initializer('ones', [out_channels]) |
|
|
|
self.gamma = Parameter(gamma_init, name='gamma') |
|
|
|
if beta_init is None: |
|
|
|
beta_init = initializer('zeros', [out_channels]) |
|
|
|
self.beta = Parameter(beta_init, name='beta') |
|
|
|
if mean_init is None: |
|
|
|
mean_init = initializer('zeros', [out_channels]) |
|
|
|
self.moving_mean = Parameter( |
|
|
|
mean_init, name='moving_mean', requires_grad=False) |
|
|
|
if var_init is None: |
|
|
|
var_init = initializer('ones', [out_channels]) |
|
|
|
self.moving_variance = Parameter( |
|
|
|
var_init, name='moving_variance', requires_grad=False) |
|
|
|
|
|
|
|
self.step = Parameter(initializer( |
|
|
|
'normal', [1], dtype=mstype.int32), name='step', requires_grad=False) |
|
|
|
|
|
|
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=False, |
|
|
|
num_bits=num_bits, |
|
|
|
quant_delay=quant_delay, |
|
|
|
per_channel=per_channel, |
|
|
|
out_channels=out_channels, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range) |
|
|
|
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) |
|
|
|
|
|
|
|
self.correct_mul = P.CorrectionMul(self.is_depthwise) |
|
|
|
if context.get_context('device_target') == "Ascend": |
|
|
|
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn) |
|
|
|
self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0) |
|
|
|
elif context.get_context('device_target') == "GPU": |
|
|
|
self.batchnorm_fold2_train = P.BatchNormFold2(freeze_bn=freeze_bn) |
|
|
|
self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0) |
|
|
|
else: |
|
|
|
raise ValueError("Not support platform.") |
|
|
|
self.one = Tensor(1, mstype.int32) |
|
|
|
self.assignadd = P.AssignAdd() |
|
|
|
self.is_gpu = context.get_context('device_target') == "GPU" |
|
|
|
|
|
|
|
def extend_repr(self): |
|
|
|
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ |
|
|
|
'pad_mode={}, padding={}, dilation={}, group={}, ' \ |
|
|
|
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format( |
|
|
|
self.in_channels, self.out_channels, self.kernel_size, self.stride, |
|
|
|
self.pad_mode, self.padding, self.dilation, self.group, |
|
|
|
self.fake, self.freeze_bn, self.momentum, self.quant_delay) |
|
|
|
return s |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
out_conv = self.conv(x, self.weight) |
|
|
|
# BN fold1 |
|
|
|
batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv, |
|
|
|
self.moving_mean, |
|
|
|
self.moving_variance, |
|
|
|
self.step) |
|
|
|
# fake weight |
|
|
|
weight = self.correct_mul(self.weight, self.gamma, running_std) |
|
|
|
if self.fake: |
|
|
|
weight = self.fake_quant_weight(weight) |
|
|
|
out = self.conv(x, weight) |
|
|
|
# BN fold2 |
|
|
|
if self.is_gpu: |
|
|
|
if self.training: |
|
|
|
out = self.batchnorm_fold2_train(out, self.beta, self.gamma, |
|
|
|
batch_std, batch_mean, running_std, running_mean, self.step) |
|
|
|
F.control_depend(out, self.assignadd(self.step, self.one)) |
|
|
|
else: |
|
|
|
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, |
|
|
|
batch_std, batch_mean, running_std, running_mean, self.step) |
|
|
|
else: |
|
|
|
if self.training: |
|
|
|
out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std) |
|
|
|
F.control_depend(out, self.assignadd(self.step, self.one)) |
|
|
|
else: |
|
|
|
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std) |
|
|
|
return out |
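# --- Editor's sketch -------------------------------------------------------
# The BN-fold pipeline above (batchnorm_fold -> CorrectionMul -> BatchNormFold2)
# relies on the standard algebra that BatchNorm applied after a convolution is
# equivalent to a convolution with rescaled weights plus a shifted bias; folding
# lets the weights be fake-quantized in their inference-time form. A minimal
# NumPy check of that identity (variable names here are illustrative):
import numpy as np

y = np.random.randn(4, 3)                      # pretend per-channel conv outputs
gamma, beta = np.full(3, 1.5), np.full(3, 0.2)
mean, var = y.mean(axis=0), y.var(axis=0)
std = np.sqrt(var + 1e-5)
bn_out = gamma * (y - mean) / std + beta       # BatchNorm applied explicitly
folded = y * (gamma / std) + (beta - gamma * mean / std)  # folded scale and bias
assert np.allclose(bn_out, folded)
# The weight correction is the same scale pushed into the kernel:
#     w_folded = w * gamma / running_std   (what CorrectionMul computes)
# ---------------------------------------------------------------------------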
|
|
|
|
|
|
|
|
|
|
|
class Conv2dBatchNormQuant(Cell): |
|
|
|
r""" |
|
|
|
    2D convolution layer with the BatchNorm op folded in.
|
|
|
@@ -593,23 +407,47 @@ class Conv2dBatchNormQuant(Cell): |
|
|
|
super(Conv2dBatchNormQuant, self).__init__() |
|
|
|
self.in_channels = in_channels |
|
|
|
self.out_channels = out_channels |
|
|
|
        self.kernel_size = twice(kernel_size)
        self.stride = twice(stride)
        self.pad_mode = pad_mode
        self.padding = padding
        self.dilation = twice(dilation)
        self.group = group
        self.eps = eps
        self.momentum = momentum
        self.quant_delay = quant_delay
        self.freeze_bn = freeze_bn
        self.fake = fake
        self.num_bits = num_bits
        self.per_channel = per_channel
        self.symmetric = symmetric
        self.narrow_range = narrow_range
|
|
|
|
|
|
|
# initialize convolution op and Parameter |
|
|
|
if context.get_context('device_target') == "Ascend" and group > 1: |
|
|
|
            validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant')
            validator.check_integer('group', group, 'out_channels', out_channels, 'Conv2dBatchNormQuant')
|
|
|
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, |
|
|
|
kernel_size=self.kernel_size, |
|
|
|
pad_mode=pad_mode, |
|
|
|
pad=padding, |
|
|
|
stride=self.stride, |
|
|
|
dilation=self.dilation) |
|
|
|
if weight_init is None: |
|
|
|
weight_init = initializer('normal', [1, in_channels, *self.kernel_size]) |
|
|
|
else: |
|
|
|
|
|
|
self.conv = P.Conv2D(out_channel=out_channels, |
|
|
|
kernel_size=self.kernel_size, |
|
|
|
pad_mode=pad_mode, |
|
|
|
pad=padding, |
|
|
|
stride=self.stride, |
|
|
|
dilation=self.dilation, |
|
|
|
group=group) |
|
|
|
if weight_init is None: |
|
|
|
weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size]) |
|
|
|
self.weight = Parameter(weight_init, name='weight') |
|
|
|
|
|
|
|
# initialize batchnorm Parameter |
|
|
|
if gamma_init is None: |
|
|
|
gamma_init = initializer('ones', [out_channels]) |
|
|
|
self.gamma = Parameter(gamma_init, name='gamma') |
|
|
|
@@ -618,16 +456,12 @@ class Conv2dBatchNormQuant(Cell): |
|
|
|
self.beta = Parameter(beta_init, name='beta') |
|
|
|
if mean_init is None: |
|
|
|
mean_init = initializer('zeros', [out_channels]) |
|
|
|
|
|
|
self.moving_mean = Parameter(mean_init, name='moving_mean', requires_grad=False) |
|
|
|
if var_init is None: |
|
|
|
var_init = initializer('ones', [out_channels]) |
|
|
|
|
|
|
self.moving_variance = Parameter(var_init, name='moving_variance', requires_grad=False) |
|
|
|
|
|
|
|
# initialize fake ops |
|
|
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, |
|
|
|
max_init=6, |
|
|
|
ema=False, |
|
|
|
@@ -638,14 +472,6 @@ class Conv2dBatchNormQuant(Cell): |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range) |
|
|
|
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) |
|
|
|
|
|
|
self.correct_mul = P.CorrectionMul() |
|
|
|
if context.get_context('device_target') == "Ascend": |
|
|
|
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn) |
|
|
|
@@ -654,7 +480,8 @@ class Conv2dBatchNormQuant(Cell): |
|
|
|
self.batchnorm_fold2_train = P.BatchNormFold2(freeze_bn=freeze_bn) |
|
|
|
self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0) |
|
|
|
else: |
|
|
|
raise ValueError("Not support platform.") |
|
|
|
raise ValueError("Unsupported platform: {}".format(context.get_context('device_target'))) |
|
|
|
self.step = Parameter(initializer('normal', [1], dtype=mstype.int32), name='step', requires_grad=False) |
|
|
|
self.one = Tensor(1, mstype.int32) |
|
|
|
self.assignadd = P.AssignAdd() |
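# --- Editor's sketch -------------------------------------------------------
# Why the Ascend branch above swaps P.Conv2D for P.DepthwiseConv2dNative: a fully
# grouped convolution (group == in_channels == out_channels) is a depthwise
# convolution with channel_multiplier = 1, and only the weight layout differs:
#     Conv2D weight:                (out_channels, in_channels // group, kh, kw) == (C, 1, kh, kw)
#     DepthwiseConv2dNative weight: (channel_multiplier, in_channels, kh, kw)    == (1, C, kh, kw)
# hence the [1, in_channels, *self.kernel_size] initializer in that branch.
import numpy as np

C, kh, kw = 3, 2, 2
w_grouped = np.random.randn(C, 1, kh, kw)          # Conv2D layout with group == C
w_depthwise = w_grouped.transpose(1, 0, 2, 3)      # DepthwiseConv2dNative layout
assert np.array_equal(w_grouped[:, 0], w_depthwise[0])  # same per-channel kernels
# ---------------------------------------------------------------------------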
|
|
|
|
|
|
|
@@ -926,6 +753,7 @@ class ReLUQuant(Cell): |
|
|
|
Args: |
|
|
|
        num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
        quant_delay (int): Number of global steps after which quantization begins. Default: 0.
        ema_decay (float): Decay factor for the exponential moving average. Default: 0.999.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range. Default: False.
|
|
|
|
|
|
|
@@ -944,6 +772,7 @@ class ReLUQuant(Cell): |
|
|
|
def __init__(self, |
|
|
|
num_bits=8, |
|
|
|
quant_delay=0, |
|
|
|
ema_decay=0.999, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False): |
|
|
|
super(ReLUQuant, self).__init__() |
|
|
|
@@ -952,6 +781,7 @@ class ReLUQuant(Cell): |
|
|
|
num_bits=num_bits, |
|
|
|
quant_delay=quant_delay, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range) |
|
|
|
self.relu = P.ReLU() |
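# --- Editor's sketch -------------------------------------------------------
# ReLUQuant (and the activation cells below) construct FakeQuantWithMinMax with
# ema=True, so the observed activation range is tracked by an exponential moving
# average rather than the raw per-batch extrema. A minimal sketch of the assumed
# standard update rule with the default decay of 0.999:
import numpy as np

def _ema_range_update(running_min, running_max, x, decay=0.999):
    running_min = running_min * decay + float(x.min()) * (1.0 - decay)
    running_max = running_max * decay + float(x.max()) * (1.0 - decay)
    return running_min, running_max

mn, mx = 0.0, 6.0
for _ in range(1000):  # the tracked range drifts slowly toward the observed stats
    mn, mx = _ema_range_update(mn, mx, np.random.randn(32))
# ---------------------------------------------------------------------------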
|
|
|
@@ -973,6 +803,7 @@ class ReLU6Quant(Cell): |
|
|
|
Args: |
|
|
|
        num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
        quant_delay (int): Number of global steps after which quantization begins. Default: 0.
        ema_decay (float): Decay factor for the exponential moving average. Default: 0.999.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range. Default: False.
|
|
|
|
|
|
|
@@ -988,7 +819,11 @@ class ReLU6Quant(Cell): |
|
|
|
>>> result = relu6_quant(input_x) |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
num_bits=8, |
|
|
|
quant_delay=0, |
|
|
|
ema_decay=0.999, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False): |
|
|
|
super(ReLU6Quant, self).__init__() |
|
|
|
self.fake_quant_act = FakeQuantWithMinMax(min_init=0, |
|
|
|
@@ -996,6 +831,7 @@ class ReLU6Quant(Cell): |
|
|
|
num_bits=num_bits, |
|
|
|
quant_delay=quant_delay, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range) |
|
|
|
self.relu6 = P.ReLU6() |
|
|
|
@@ -1015,6 +851,7 @@ class HSwishQuant(Cell): |
|
|
|
Args: |
|
|
|
        num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
        quant_delay (int): Number of global steps after which quantization begins. Default: 0.
        ema_decay (float): Decay factor for the exponential moving average. Default: 0.999.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range. Default: False.
|
|
|
|
|
|
|
@@ -1033,6 +870,7 @@ class HSwishQuant(Cell): |
|
|
|
def __init__(self, |
|
|
|
num_bits=8, |
|
|
|
quant_delay=0, |
|
|
|
ema_decay=0.999, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False): |
|
|
|
super(HSwishQuant, self).__init__() |
|
|
|
@@ -1041,6 +879,7 @@ class HSwishQuant(Cell): |
|
|
|
num_bits=num_bits, |
|
|
|
quant_delay=quant_delay, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range) |
|
|
|
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, |
|
|
|
@@ -1048,6 +887,7 @@ class HSwishQuant(Cell): |
|
|
|
num_bits=num_bits, |
|
|
|
quant_delay=quant_delay, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range) |
|
|
|
self.act = P.HSwish() |
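# --- Editor's sketch -------------------------------------------------------
# HSwishQuant keeps two fake-quant cells, one before and one after the activation,
# because hard-swish reshapes the value range: hswish(x) = x * relu6(x + 3) / 6
# maps inputs in [-6, 6] to outputs in roughly [-0.375, 6], so the input and the
# output need separately learned (min, max). Illustrative check:
import numpy as np

def _hswish(x):
    return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0

x = np.linspace(-6.0, 6.0, 1001)
out = _hswish(x)
assert out.min() > -0.4 and np.isclose(out.max(), 6.0)
# ---------------------------------------------------------------------------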
|
|
|
@@ -1068,6 +908,7 @@ class HSigmoidQuant(Cell): |
|
|
|
Args: |
|
|
|
        num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
        quant_delay (int): Number of global steps after which quantization begins. Default: 0.
        ema_decay (float): Decay factor for the exponential moving average. Default: 0.999.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range. Default: False.
|
|
|
|
|
|
|
@@ -1086,6 +927,7 @@ class HSigmoidQuant(Cell): |
|
|
|
def __init__(self, |
|
|
|
num_bits=8, |
|
|
|
quant_delay=0, |
|
|
|
ema_decay=0.999, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False): |
|
|
|
super(HSigmoidQuant, self).__init__() |
|
|
|
@@ -1101,6 +943,7 @@ class HSigmoidQuant(Cell): |
|
|
|
num_bits=num_bits, |
|
|
|
quant_delay=quant_delay, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range) |
|
|
|
self.act = P.HSigmoid() |
|
|
|
@@ -1121,6 +964,7 @@ class TensorAddQuant(Cell): |
|
|
|
Args: |
|
|
|
        num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
        quant_delay (int): Number of global steps after which quantization begins. Default: 0.
        ema_decay (float): Decay factor for the exponential moving average. Default: 0.999.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range. Default: False.
|
|
|
|
|
|
|
@@ -1140,6 +984,7 @@ class TensorAddQuant(Cell): |
|
|
|
def __init__(self, |
|
|
|
num_bits=8, |
|
|
|
quant_delay=0, |
|
|
|
ema_decay=0.999, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False): |
|
|
|
super(TensorAddQuant, self).__init__() |
|
|
|
@@ -1148,6 +993,7 @@ class TensorAddQuant(Cell): |
|
|
|
num_bits=num_bits, |
|
|
|
quant_delay=quant_delay, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range) |
|
|
|
self.add = P.TensorAdd() |
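# --- Editor's sketch -------------------------------------------------------
# TensorAddQuant re-quantizes the sum because the two operands may carry different
# quantization ranges, and the sum's range is generally neither of them. A hedged
# NumPy illustration using an assumed affine fake-quant helper:
import numpy as np

def _fq(x, lo, hi, bits=8):
    s = (hi - lo) / (2 ** bits - 1)
    return np.clip(np.round((x - lo) / s), 0, 2 ** bits - 1) * s + lo

a = _fq(np.array([0.5, 5.5]), 0.0, 6.0)    # e.g. a ReLU6 branch, range [0, 6]
b = _fq(np.array([-0.9, 0.9]), -1.0, 1.0)  # e.g. a narrow residual, range [-1, 1]
out = _fq(a + b, -1.0, 7.0)                # the sum gets its own learned range
# ---------------------------------------------------------------------------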
|
|
|
@@ -1167,6 +1013,7 @@ class MulQuant(Cell): |
|
|
|
Args: |
|
|
|
        num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
        quant_delay (int): Number of global steps after which quantization begins. Default: 0.
        ema_decay (float): Decay factor for the exponential moving average. Default: 0.999.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range. Default: False.
|
|
|
|
|
|
|
@@ -1181,6 +1028,7 @@ class MulQuant(Cell): |
|
|
|
def __init__(self, |
|
|
|
num_bits=8, |
|
|
|
quant_delay=0, |
|
|
|
ema_decay=0.999, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False): |
|
|
|
super(MulQuant, self).__init__() |
|
|
|
@@ -1189,6 +1037,7 @@ class MulQuant(Cell): |
|
|
|
num_bits=num_bits, |
|
|
|
quant_delay=quant_delay, |
|
|
|
ema=True, |
|
|
|
ema_decay=ema_decay, |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range) |
|
|
|
self.mul = P.Mul() |
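# --- Editor's sketch -------------------------------------------------------
# A hedged end-to-end usage example composing the cells in this module into one
# quantization-aware block (conv with folded BN, then a quantized activation).
# The argument values are illustrative only:
#     >>> conv_bn = nn.Conv2dBatchNormQuant(3, 16, kernel_size=3, stride=1,
#     ...                                   pad_mode="same", per_channel=True,
#     ...                                   quant_delay=900)
#     >>> act = nn.ReLUQuant(quant_delay=900)
#     >>> x = Tensor(np.random.randn(1, 3, 32, 32), mindspore.float32)
#     >>> y = act(conv_bn(x))
# ---------------------------------------------------------------------------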
|
|
|
|