|
|
|
@@ -34,6 +34,7 @@ from ...ops.operations import _quant_ops as Q

__all__ = [
    'FakeQuantWithMinMaxObserver',
    'Conv2dBnFoldQuantOneConv',
    'Conv2dBnFoldQuant',
    'Conv2dBnWithoutFoldQuant',
    'Conv2dQuant',
@@ -330,6 +331,220 @@ QuantConfig = namedtuple("QuantConfig", ['weight', 'activation'])

quant_config_default = QuantConfig(weight=FakeQuantWithMinMaxObserver, activation=FakeQuantWithMinMaxObserver)
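
# quant_config_default simply pairs the weight and activation observer constructors. A custom
# QuantConfig can be assembled the same way; illustrative sketch only (functools.partial and the
# per_channel observer argument are assumed, as in the standard observer API):
#     from functools import partial
#     custom_qconfig = QuantConfig(weight=partial(FakeQuantWithMinMaxObserver, per_channel=True),
#                                  activation=FakeQuantWithMinMaxObserver)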
|
|
|
|
|
|
|
|
|
|
|
class Conv2dBnFoldQuantOneConv(Cell):
    r"""
    2D convolution with the BatchNorm op folded into the convolution.

    This layer behaves like the Conv2d op, but folds the BatchNorm parameters into the convolution
    weight and fake-quantizes the folded weight, so only a single convolution is executed in each
    forward pass.

    Args:
        in_channels (int): The number of input channels :math:`C_{in}`.
        out_channels (int): The number of output channels :math:`C_{out}`.
        kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
        stride (int): Specifies the stride for all spatial dimensions with the same value. Default: 1.
        pad_mode (str): Specifies the padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): Specifies the dilation rate to use for dilated convolution. Default: 1.
        group (int): Splits the filter into groups; `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        eps (float): Epsilon added to the variance in the BatchNorm op. Default: 1e-5.
        momentum (float): Momentum used for the BatchNorm moving-average updates. Default: 0.997.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            convolution kernel. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            bias vector. Default: 'zeros'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            beta vector. Default: 'zeros'.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            gamma vector. Default: 'ones'.
        mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            mean vector. Default: 'zeros'.
        var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            variance vector. Default: 'ones'.
        fake (bool): Whether the cell adds a FakeQuantWithMinMaxObserver on the folded weight. Default: True.
        quant_config (QuantConfig): Configures the observer types of weight and activation. Default:
            both set to the default FakeQuantWithMinMaxObserver.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> qconfig = compression.quant.create_quant_config()
        >>> conv2d_bnfold = nn.Conv2dBnFoldQuantOneConv(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
        ...                                             quant_config=qconfig)
        >>> input = Tensor(np.random.randint(-2, 2, (2, 1, 3, 3)), mindspore.float32)
        >>> result = conv2d_bnfold(input)
        >>> result.shape
        (2, 6, 2, 2)
    """
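
    # Folding relation implemented in construct() below (a summary of what the code computes):
    #     running_std   = sqrt(moving_variance + eps)
    #     scale_factor  = gamma / running_std
    #     folded_weight = weight * scale_factor        (per channel; fake-quantized when fake=True)
    # The single convolution runs on folded_weight, and its output is divided by the same
    # scale_factor so the BatchNorm op still receives the un-folded convolution result.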
|
|
|
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 eps=1e-5,
                 momentum=0.997,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 beta_init='zeros',
                 gamma_init='ones',
                 mean_init='zeros',
                 var_init='ones',
                 fake=True,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        """Initialize Conv2dBnFoldQuantOneConv layer"""
        super(Conv2dBnFoldQuantOneConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = twice(kernel_size)
        self.stride = twice(stride)
        self.pad_mode = pad_mode
        self.padding = padding
        self.dilation = twice(dilation)
        self.group = group
        self.eps = eps
        self.momentum = momentum
        self.has_bias = has_bias
        self.fake = fake
        self.quant_config = quant_config
        self.quant_dtype = quant_dtype
        self.is_gpu = context.get_context('device_target') == "GPU"
        self.is_ascend = context.get_context('device_target') == "Ascend"
        # execution mode and data layout consulted when selecting the BatchNorm ops below
        # (NCHW is assumed here as the framework default layout)
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        self.format = "NCHW"
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
        else:
            self.is_ge_backend = False

        # initialize convolution op and Parameter
        if context.get_context('device_target') == "Ascend" and group > 1:
            Validator.check_equal_int(group, in_channels, 'group')
            Validator.check_equal_int(group, out_channels, 'group')
            self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                kernel_size=self.kernel_size,
                                                pad_mode=pad_mode,
                                                pad=padding,
                                                stride=self.stride,
                                                dilation=self.dilation)
            weight_shape = [1, in_channels, *self.kernel_size]
            channel_axis = 1
        else:
            self.conv = P.Conv2D(out_channel=out_channels,
                                 kernel_size=self.kernel_size,
                                 pad_mode=pad_mode,
                                 pad=padding,
                                 stride=self.stride,
                                 dilation=self.dilation,
                                 group=group)
            weight_shape = [out_channels, in_channels // group, *self.kernel_size]
            channel_axis = 0
        # keep the fake-quant channel axis on the cell; construct() reshapes scale_factor with it
        self.channel_axis = channel_axis
        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
        self.bias_add = P.BiasAdd()
        if Validator.check_bool(has_bias):
            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
        else:
            self.bias = None

        # initialize BatchNorm Parameter
        self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
        self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
        self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
        self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
                                         requires_grad=False)

        # initialize fake ops
        self.fake_quant_weight = quant_config.weight(min_init=-6,
                                                     max_init=6,
                                                     ema=False,
                                                     channel_axis=channel_axis,
                                                     num_channels=out_channels,
                                                     quant_dtype=quant_dtype)
        if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
            self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps)
        elif self.is_gpu:
            self.bn_train = P.FusedBatchNormEx(mode=1,
                                               epsilon=self.eps,
                                               momentum=self.momentum,
                                               data_format=self.format)
        else:
            self.bn_train = P.FusedBatchNorm(mode=1,
                                             epsilon=self.eps,
                                             momentum=self.momentum)
        self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
        data_parallel_strategy = ((1,), (1,))
        data_parallel_strategy_one = ((1,), ())
        self.sub_mean = P.Sub().shard(data_parallel_strategy)
        self.sub_var = P.Sub().shard(data_parallel_strategy)
        self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
        self.mul_var = P.Mul().shard(data_parallel_strategy_one)
        self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
        self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
        self.one = Tensor(1, mstype.int32)
        self.reshape = P.Reshape()

    def extend_repr(self):
        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
            'pad_mode={}, padding={}, dilation={}, group={}, ' \
            'fake={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
                                                          self.kernel_size, self.stride,
                                                          self.pad_mode, self.padding,
                                                          self.dilation, self.group,
                                                          self.fake, self.momentum,
                                                          self.fake_quant_weight.quant_delay)
        return s
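
    # In the non-GPU training branch of construct() below, the moving statistics are updated as
    #     moving <- moving - momentum * (moving - batch), i.e. (1 - momentum) * moving + momentum * batch,
    # and F.depend ties the AssignSub updates into the graph so they are not pruned.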
|
|
|
|
|
|
|
    def construct(self, x):
        running_std = P.Sqrt()(P.TensorAdd()(self.moving_variance, self.eps))
        scale_factor = self.gamma / running_std
        # reshape the per-channel scale to the weight layout before folding it into the weight
        if self.channel_axis:
            scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
        else:
            scale_factor = self.reshape(scale_factor, (-1, 1, 1, 1))
        weight = self.weight * scale_factor
        if self.fake:
            weight = self.fake_quant_weight(weight)
        conv = self.conv(x, weight)
        scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
        conv_orig = conv / scale_factor
        if self.training:
            if not self.is_gpu:
                out, batch_mean, batch_var, _, _ = self.bn_train(conv_orig, self.gamma, self.beta, None, None)

                mean_sub = self.sub_mean(self.moving_mean, batch_mean)
                temp_mean = self.mul_mean(mean_sub, self.momentum)
                mean_sub2 = self.sub_var(self.moving_variance, batch_var)
                temp_variance = self.mul_var(mean_sub2, self.momentum)
                out = F.depend(out, self.assign_sub_mean(self.moving_mean, temp_mean))
                out = F.depend(out, self.assign_sub_var(self.moving_variance, temp_variance))
            else:
                out = self.bn_train(conv_orig, self.gamma, self.beta,
                                    self.moving_mean, self.moving_variance)[0]
        else:
            out = self.bn_infer(conv_orig, self.gamma, self.beta,
                                self.moving_mean, self.moving_variance)[0]

        return out


class Conv2dBnFoldQuant(Cell):
    r"""
    2D convolution with the BatchNorm op folded construct.
@@ -627,7 +842,7 @@ class Conv2dBnWithoutFoldQuant(Cell):
                                                     channel_axis=channel_axis,
                                                     num_channels=out_channels,
                                                     quant_dtype=quant_dtype)
-        self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)
+        self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=1-momentum)

    def construct(self, x):
        weight = self.fake_quant_weight(self.weight)
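
# Note on the final hunk: the fold-quant layers in this file update their moving statistics as
# moving <- moving - momentum * (moving - batch), so `momentum` weights the new batch statistic,
# whereas nn.BatchNorm2d applies its momentum argument to the running statistic. Passing
# 1 - momentum presumably aligns the two conventions; this reading is an assumption, not stated
# in the diff itself.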