Merge pull request !4658 from chengxb7532/master (tag: v0.7.0-beta)
```diff
@@ -607,7 +607,6 @@ class Conv2dBnWithoutFoldQuant(Cell):
         group (int): Split filter into groups, `in_channels` and `out_channels` should be
             divisible by the number of groups. Default: 1.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
-        has_bn (bool): Specifies whether to use batchnorm. Default: False.
         eps (float): Parameter for the BatchNorm2d layer. Default: 1e-5.
         momentum (float): Momentum for the BatchNorm2d layer. Default: 0.997.
         weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
```
```diff
@@ -641,7 +640,6 @@ class Conv2dBnWithoutFoldQuant(Cell):
                  dilation=1,
                  group=1,
                  has_bias=False,
-                 has_bn=True,
                  eps=1e-5,
                  momentum=0.997,
                  weight_init='normal',
```
```diff
@@ -693,17 +691,14 @@ class Conv2dBnWithoutFoldQuant(Cell):
                                                   symmetric=symmetric,
                                                   narrow_range=narrow_range,
                                                   quant_delay=quant_delay)
-        self.has_bn = validator.check_bool("has_bn", has_bn)
-        if has_bn:
-            self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)
+        self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)

     def construct(self, x):
         weight = self.fake_quant_weight(self.weight)
         out = self.conv(x, weight)
         if self.has_bias:
             out = self.bias_add(out, self.bias)
-        if self.has_bn:
-            out = self.batchnorm(out)
+        out = self.batchnorm(out)
         return out

     def extend_repr(self):
```
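With `has_bn` gone, the layer now always runs batchnorm after the convolution (and the optional bias add). A minimal usage sketch of the simplified cell, assuming the 0.7-era defaults shown above and that the cell is reachable under `mindspore.nn`; the channel counts and input shape are made up:

```python
import numpy as np
from mindspore import Tensor, nn

# Batchnorm is now unconditional: conv -> optional bias_add -> batchnorm,
# with fake quantization applied to the weight on every forward pass.
conv_bn = nn.Conv2dBnWithoutFoldQuant(in_channels=3, out_channels=16, kernel_size=3,
                                      eps=1e-5, momentum=0.997)
x = Tensor(np.random.randn(2, 3, 32, 32).astype(np.float32))
out = conv_bn(x)  # (2, 16, 32, 32) with the default 'same' pad mode
```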
```diff
@@ -208,7 +208,6 @@ class ConvertToQuantNetwork:
                                            group=conv_inner.group,
                                            eps=bn_inner.eps,
                                            momentum=bn_inner.momentum,
-                                           has_bn=True,
                                            quant_delay=self.weight_qdelay,
                                            per_channel=self.weight_channel,
                                            num_bits=self.weight_bits,
```
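For context, `ConvertToQuantNetwork` is what rewrites a float `Conv2d` + `BatchNorm2d` pair into the cell above, carrying `bn_inner`'s `eps` and `momentum` across; since the constructor no longer takes `has_bn`, the conversion simply stops passing it. A hedged sketch of driving this path, assuming the 0.7-era `convert_quant_network` entry point and its `bn_fold` flag:

```python
from mindspore import nn
from mindspore.train.quant import quant as qat

class TinyNet(nn.Cell):
    """Minimal float network with a Conv2d + BatchNorm2d pair to convert."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(3, 16, 3)
        self.bn = nn.BatchNorm2d(16)

    def construct(self, x):
        return self.bn(self.conv(x))

# bn_fold=False (flag name assumed) takes the Conv2dBnWithoutFoldQuant route,
# so the original BN's eps/momentum (bn_inner above) survive the rewrite.
quant_net = qat.convert_quant_network(TinyNet(), bn_fold=False)
```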
```diff
@@ -378,8 +377,10 @@ class ExportToQuantInferNetwork:
         if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
             if cell_core.has_bias:
                 bias = cell_core.bias.data.asnumpy()
-        elif isinstance(cell_core, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant)):
+        elif isinstance(cell_core, quant.Conv2dBnFoldQuant):
             weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
+        elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant):
+            weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core)

         # apply the quant
         weight = quant_utils.weight2int(weight, scale_w, zp_w)
```
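After either branch, `weight` and `bias` are plain folded float arrays, and `weight2int` maps the weight onto the integer grid using the fake-quant scale and zero point. The repo's exact implementation isn't part of this diff; as a rough illustration only, a standard affine quantization with assumed names looks like:

```python
import numpy as np

def weight2int_sketch(weight, scale, zero_point, num_bits=8):
    """Illustrative affine quantization; not the quant_utils implementation."""
    qmin = -2 ** (num_bits - 1)          # e.g. -128 for 8 bits
    qmax = 2 ** (num_bits - 1) - 1       # e.g.  127 for 8 bits
    q = np.round(weight / scale + zero_point)
    return np.clip(q, qmin, qmax).astype(np.int8)
```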
```diff
@@ -211,3 +211,42 @@ def fold_batchnorm(weight, cell_quant):
     weight = weight * _gamma / _sigma
     bias = beta - gamma * mean / sigma
     return weight, bias
+
+
+def without_fold_batchnorm(weight, cell_quant):
+    r"""
+    Fold the batchnorm in `Conv2dBnWithoutFoldQuant` into the weight.
+
+    The folded weight and bias are computed from the layer's `BatchNorm2d`
+    statistics (`moving_mean`, `moving_variance`, `gamma`, `beta`).
+
+    Args:
+        weight (numpy.ndarray): Weight of `cell_quant`.
+        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`.
+
+    Returns:
+        weight (numpy.ndarray): Folded weight.
+        bias (numpy.ndarray): Folded bias.
+    """
+    variance = cell_quant.batchnorm.moving_variance.data.asnumpy()
+    mean = cell_quant.batchnorm.moving_mean.data.asnumpy()
+    gamma = cell_quant.batchnorm.gamma.data.asnumpy()
+    beta = cell_quant.batchnorm.beta.data.asnumpy()
+    epsilon = cell_quant.batchnorm.eps
+    sigma = np.sqrt(variance + epsilon)
+
+    if gamma.shape[0] == weight.shape[0]:
+        # `Conv2d` or `Dense` op weight: channels sit on the first axis
+        shape_list = [-1] + [1] * len(weight.shape[1:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    elif gamma.shape[0] == weight.shape[1]:
+        # `DepthwiseConv2d` op weight: channels sit on the second axis
+        shape_list = [1, -1] + [1] * len(weight.shape[2:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    else:
+        raise ValueError("Unsupported weight shape({})".format(weight.shape))
+
+    weight = weight * _gamma / _sigma
+    bias = beta - gamma * mean / sigma
+    return weight, bias
```
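The algebra here is the standard batchnorm fold: `BN(conv(x, W)) = conv(x, W * gamma / sigma) + (beta - gamma * mean / sigma)`, which holds because convolution is linear in the weight. A self-contained numpy check of that identity, viewed per output channel with made-up statistics:

```python
import numpy as np

rng = np.random.default_rng(0)
out_ch = 4
# One conv-output value per output channel; scaling the weight by gamma/sigma
# scales these outputs by gamma/sigma, channel by channel.
conv_out = rng.standard_normal(out_ch)
gamma = rng.standard_normal(out_ch)
beta = rng.standard_normal(out_ch)
mean = rng.standard_normal(out_ch)
sigma = np.sqrt(rng.random(out_ch) + 0.1 + 1e-5)

bn_out = gamma * (conv_out - mean) / sigma + beta                    # BN after conv
folded = conv_out * (gamma / sigma) + (beta - gamma * mean / sigma)  # folded form
assert np.allclose(bn_out, folded)
```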