Browse Source

!16362 Quant end to end accuracy testing

From: @zhang__sss
Reviewed-by: @zh_qh,@zlq2020,@liangchenghui
Signed-off-by: @zh_qh
tags/v1.3.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f8f1f0f84d
9 changed files with 51 additions and 55 deletions
  1. +19
    -14
      mindspore/compression/export/quant_export.py
  2. +1
    -1
      mindspore/compression/quant/qat.py
  3. +3
    -4
      mindspore/nn/layer/quant.py
  4. +2
    -0
      mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py
  5. +2
    -0
      mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py
  6. +2
    -0
      mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py
  7. +2
    -0
      mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py
  8. +17
    -35
      model_zoo/official/cv/resnet50_quant/models/resnet_quant.py
  9. +3
    -1
      model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py

+ 19
- 14
mindspore/compression/export/quant_export.py View File

@@ -27,7 +27,7 @@ from ...nn.layer import quant
from ...ops import operations as P
from ...ops.operations import _inner_ops as inner
from ..quant import quant_utils
from ..quant.qat import QuantizationAwareTraining, _AddFakeQuantInput, _AddFakeQuantAfterSubCell
from ..quant.qat import _AddFakeQuantInput, _AddFakeQuantAfterSubCell


__all__ = ["ExportToQuantInferNetwork"]
@@ -184,10 +184,11 @@ class ExportToQuantInferNetwork:
def _add_output_min_max_for_op(self, origin_op, fake_quant_cell):
"""add output quant info for quant op for export mindir."""
if self.is_mindir:
np_type = mstype.dtype_to_nptype(self.data_type)
_, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_cell, np_type)
origin_op.add_prim_attr('output_maxq', Tensor(maxq))
origin_op.add_prim_attr('output_minq', Tensor(minq))
if isinstance(origin_op, ops.Primitive) and not hasattr(origin_op, 'output_minq'):
np_type = mstype.dtype_to_nptype(self.data_type)
_, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_cell, np_type)
origin_op.add_prim_attr('output_maxq', Tensor(maxq))
origin_op.add_prim_attr('output_minq', Tensor(minq))

def _convert_quant2deploy(self, network):
"""Convert network's all quant subcell to deploy subcell."""
@@ -205,9 +206,13 @@ class ExportToQuantInferNetwork:
quant.Conv2dBnWithoutFoldQuant, quant.Conv2dQuant, quant.DenseQuant)):
network, change = self._convert_subcell(network, change, name, subcell, core=False)
elif isinstance(subcell, nn.ActQuant) and hasattr(subcell, "get_origin"):
activation = subcell.get_origin()
if isinstance(activation, nn.ReLU):
self._add_output_min_max_for_op(activation.relu, subcell.fake_quant_act)
elif isinstance(activation, nn.ReLU6):
self._add_output_min_max_for_op(activation.relu6, subcell.fake_quant_act)
if self.upcell:
self._add_output_min_max_for_op(self.upcell.core_op, subcell.fake_quant_act)
activation = subcell.get_origin()
network.insert_child_to_cell(name, activation)
change = True
elif isinstance(subcell, nn.TensorAddQuant):
@@ -216,8 +221,7 @@ class ExportToQuantInferNetwork:
subcell.__delattr__("add")
subcell.__setattr__("add", add_op)
add_op = subcell.add
if add_op:
self._add_output_min_max_for_op(add_op, subcell.fake_quant_act)
self._add_output_min_max_for_op(add_op, subcell.fake_quant_act)
subcell.__delattr__("fake_quant_act")
subcell.__setattr__("fake_quant_act", P.identity())
elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver):
@@ -227,11 +231,10 @@ class ExportToQuantInferNetwork:
network.__setattr__(name, P.identity())
elif isinstance(subcell, _AddFakeQuantAfterSubCell):
op = subcell.subcell
if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive):
self._add_output_min_max_for_op(op, subcell.fake_quant_act)
network.__delattr__(name)
network.__setattr__(name, op)
change = True
self._add_output_min_max_for_op(op, subcell.fake_quant_act)
network.__delattr__(name)
network.__setattr__(name, op)
change = True
else:
self.upcell, self.upname = None, None
self._convert_quant2deploy(subcell)
@@ -246,7 +249,9 @@ class ExportToQuantInferNetwork:
if core:
cell_core = subcell.conv if conv else subcell.dense
activation = subcell.activation
if hasattr(activation, 'fake_quant_act'):
if hasattr(activation, 'fake_quant_act_before'):
fake_quant_act = activation.fake_quant_act_before
elif hasattr(activation, 'fake_quant_act'):
fake_quant_act = activation.fake_quant_act
else:
cell_core = subcell


+ 1
- 1
mindspore/compression/quant/qat.py View File

@@ -193,7 +193,7 @@ class QuantizationAwareTraining(Quantizer):
>>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False])
>>> net_qat = quantizer.quantize(net)
"""
__quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]
__quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv", "ReduceMean"]

def __init__(self,
bn_fold=True,


+ 3
- 4
mindspore/nn/layer/quant.py View File

@@ -28,7 +28,7 @@ from mindspore._checkparam import Validator, twice
from mindspore.compression.common import QuantDtype
import mindspore.context as context
from .normalization import BatchNorm2d
from .activation import get_activation, ReLU
from .activation import get_activation
from ..cell import Cell
from ...ops.operations import _quant_ops as Q

@@ -1601,9 +1601,6 @@ class QuantMindirBlock(Cell):
self.activation = activation
self.has_act = activation is not None
self.bias_add = P.BiasAdd()
if isinstance(activation, ReLU):
self.activation = None
self.has_act = False

def construct(self, x):
if self.has_bias:
@@ -1611,6 +1608,8 @@ class QuantMindirBlock(Cell):
x = self.bias_add(x, self.bias)
else:
x = self.core_op(x, self.weight)
if self.has_act:
x = self.activation(x)
return x

def extend_repr(self):


+ 2
- 0
mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py View File

@@ -52,6 +52,8 @@ def fake_learned_scale_quant_perchannel_compute(input_data, alpha_data, quant_ma
kernel_name="fake_learned_scale_quant_perchannel"):
"""FakeLearnedScaleQuantPerChannel"""
input_shape = te.lang.cce.util.shape_to_list(input_data.shape)
eps = tvm.const(1e-6, input_data.dtype)
alpha_data = te.lang.cce.vcmpsel(te.lang.cce.vabs(alpha_data), eps, 'ge', alpha_data, eps)
alpha_data = te.lang.cce.broadcast(alpha_data, input_shape, input_data.dtype)
quant_max_data = te.lang.cce.broadcast(quant_max_data, input_shape, input_data.dtype)



+ 2
- 0
mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py View File

@@ -63,6 +63,8 @@ def fake_learned_scale_quant_perchannel_grad_d_compute(dout, input_data, alpha_d
kernel_name="fake_learned_scale_quant_perchannel_grad_d"):
"""FakeLearnedScaleQuantPerChannelGradD"""
input_shape = te.lang.cce.util.shape_to_list(input_data.shape)
eps = tvm.const(1e-6, input_data.dtype)
alpha_data = te.lang.cce.vcmpsel(te.lang.cce.vabs(alpha_data), eps, 'ge', alpha_data, eps)
alpha_data = te.lang.cce.broadcast(alpha_data, input_shape, input_data.dtype)
quant_max_data = te.lang.cce.broadcast(quant_max_data, input_shape, input_data.dtype)



+ 2
- 0
mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py View File

@@ -52,6 +52,8 @@ def fake_learned_scale_quant_perlayer_compute(input_data, alpha_data, quant_max_
kernel_name="fake_learned_scale_quant_perlayer"):
"""FakeLearnedScaleQuantPerLayer"""
input_shape = te.lang.cce.util.shape_to_list(input_data.shape)
eps = tvm.const(1e-6, input_data.dtype)
alpha_data = te.lang.cce.vcmpsel(te.lang.cce.vabs(alpha_data), eps, 'ge', alpha_data, eps)
alpha_data = te.lang.cce.broadcast(alpha_data, input_shape, input_data.dtype)
quant_max_data = te.lang.cce.broadcast(quant_max_data, input_shape, input_data.dtype)



+ 2
- 0
mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py View File

@@ -64,6 +64,8 @@ def fake_learned_scale_quant_perlayer_grad_d_compute(dout, input_data, alpha_dat
kernel_name="fake_learned_scale_quant_perlayer_grad_d"):
"""FakeLearnedScaleQuantPerLayerGradD"""
input_shape = te.lang.cce.util.shape_to_list(input_data.shape)
eps = tvm.const(1e-6, input_data.dtype)
alpha_data = te.lang.cce.vcmpsel(te.lang.cce.vabs(alpha_data), eps, 'ge', alpha_data, eps)
alpha_data = te.lang.cce.broadcast(alpha_data, input_shape, input_data.dtype)
quant_max_data = te.lang.cce.broadcast(quant_max_data, input_shape, input_data.dtype)



+ 17
- 35
model_zoo/official/cv/resnet50_quant/models/resnet_quant.py View File

@@ -13,39 +13,24 @@
# limitations under the License.
# ============================================================================
"""ResNet."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor

def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)

class ConvBNReLU(nn.Cell):
def ConvBNReLU(in_channel, out_channel, kernel_size, stride=1):
"""
Convolution/Depthwise fused with Batchnorm and ReLU block definition.

Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
kernel_size (int): Input kernel size.
stride (int): Stride size for the first convolutional layer. Default: 1.
groups (int): channel group. Convolution is 1 while Depthwise is input channel. Default: 1.

Returns:
Tensor, output tensor.

Examples:
>>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
"""

def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
group=groups, has_bn=True, activation='relu')
self.features = conv

def construct(self, x):
output = self.features(x)
return output

weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
weight = _weight_variable(weight_shape)
padding = (kernel_size - 1) // 2
return nn.Conv2dBnAct(in_channel, out_channel, kernel_size, stride, weight_init=weight,
pad_mode='pad', padding=padding, has_bn=True, activation='relu')

class ResidualBlock(nn.Cell):
"""
@@ -73,8 +58,7 @@ class ResidualBlock(nn.Cell):
channel = out_channel // self.expansion
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
self.conv3 = nn.Conv2dBnAct(channel, out_channel, kernel_size=1, stride=1, pad_mode='same', padding=0,
has_bn=True, activation='relu')
self.conv3 = nn.Conv2dBnAct(channel, out_channel, kernel_size=1, stride=1, pad_mode='same', has_bn=True)

self.down_sample = False
if stride != 1 or in_channel != out_channel:
@@ -82,9 +66,7 @@ class ResidualBlock(nn.Cell):
self.down_sample_layer = None

if self.down_sample:
self.down_sample_layer = nn.Conv2dBnAct(in_channel, out_channel,
kernel_size=1, stride=stride,
pad_mode='same', padding=0, has_bn=True, activation='relu')
self.down_sample_layer = nn.Conv2dBnAct(in_channel, out_channel, 1, stride, has_bn=True)
self.add = P.Add()
self.relu = P.ReLU()

@@ -164,7 +146,7 @@ class ResNet(nn.Cell):

self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = nn.DenseBnAct(out_channels[3], num_classes, has_bias=True, has_bn=False)
self.end_point = nn.DenseBnAct(out_channels[3], num_classes, has_bn=False)

def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
"""
@@ -211,7 +193,7 @@ class ResNet(nn.Cell):

def resnet50_quant(class_num=10):
"""
Get ResNet50 neural network.
Get ResNet50_quant neural network.

Args:
class_num (int): Class number.
@@ -232,7 +214,7 @@ def resnet50_quant(class_num=10):

def resnet101_quant(class_num=1001):
"""
Get ResNet101 neural network.
Get ResNet101_quant neural network.

Args:
class_num (int): Class number.
@@ -241,7 +223,7 @@ def resnet101_quant(class_num=1001):
Cell, cell instance of ResNet101 neural network.

Examples:
>>> net = resnet101(1001)
>>> net = resnet101_quant(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 23, 3],


+ 3
- 1
model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py View File

@@ -156,7 +156,7 @@ class ResidualBlock(nn.Cell):
pad_mode='same',
padding=0)
self.add = nn.TensorAddQuant()
self.relu = P.ReLU()
self.relu = nn.ActQuant(nn.ReLU())

def construct(self, x):
identity = x
@@ -233,6 +233,7 @@ class ResNet(nn.Cell):
stride=strides[3])

self.mean = P.ReduceMean(keep_dims=True)
self.reduce_fake = nn.FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay)
self.flatten = nn.Flatten()
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, quant_config=_quant_config)
self.output_fake = nn.FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay)
@@ -275,6 +276,7 @@ class ResNet(nn.Cell):
c5 = self.layer4(c4)

out = self.mean(c5, (2, 3))
out = self.reduce_fake(out)
out = self.flatten(out)
out = self.end_point(out)
out = self.output_fake(out)


Loading…
Cancel
Save