
fix(mge/quantization): `disable_fake_quant` does not work correctly

GitOrigin-RevId: 0c568d3335
tags/v1.2.0
Megvii Engine Team, 5 years ago
commit 6c4841e807
10 changed files with 62 additions and 15 deletions
  1. imperative/python/megengine/module/module.py (+1 / -1)
  2. imperative/python/megengine/module/qat/conv.py (+4 / -1)
  3. imperative/python/megengine/module/qat/conv_bn.py (+4 / -1)
  4. imperative/python/megengine/module/qat/linear.py (+4 / -1)
  5. imperative/python/megengine/quantization/fake_quant.py (+1 / -1)
  6. imperative/python/megengine/quantization/observer.py (+6 / -6)
  7. imperative/python/megengine/quantization/utils.py (+1 / -0)
  8. imperative/python/megengine/utils/profiler.py (+2 / -1)
  9. imperative/python/test/unit/core/test_imperative_rt.py (+2 / -2)
  10. imperative/python/test/unit/quantization/quantize.py (+37 / -1)

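The core of the fix is the same guard repeated in conv.py, conv_bn.py and linear.py below: the bias is only fake-quantized while weight fake quantization is enabled. The following self-contained sketch illustrates that mechanism with plain NumPy; the FakeQuant class, the disable_fake_quant_sketch helper and the 0.1 scale are illustrative stand-ins, with only the `enabled` flag taken from the diffs.

import numpy as np

class FakeQuant:
    """Toy stand-in for a fake-quant module with the `enabled` switch
    that the new `weight_fake_quant.enabled` guards check."""

    def __init__(self, scale=0.1):
        self.scale = scale
        self.enabled = True

    def __call__(self, x):
        if not self.enabled:
            return x  # disabled: pass values through untouched
        return np.round(x / self.scale) * self.scale  # quantize-dequantize

def disable_fake_quant_sketch(fake_quants):
    # Clear every `enabled` flag; the guards added below then skip
    # fake_quant_bias() and return the float bias unchanged.
    for fq in fake_quants:
        fq.enabled = False

weight_fake_quant = FakeQuant()
bias = np.array([0.123, -0.456, 0.789], dtype=np.float32)
assert not np.allclose(weight_fake_quant(bias), bias)  # fake quant perturbs the bias
disable_fake_quant_sketch([weight_fake_quant])
assert np.allclose(weight_fake_quant(bias), bias)      # disabled: bias intact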
imperative/python/megengine/module/module.py (+1 / -1)

@@ -57,7 +57,7 @@ def _is_module(obj):

 def _get_XNorm_typeclass():
     from .batchnorm import _BatchNorm
-    from .normalization import GroupNorm, LayerNorm, InstanceNorm
+    from .normalization import GroupNorm, InstanceNorm, LayerNorm

     XNorm_types = (_BatchNorm, GroupNorm, LayerNorm, InstanceNorm)
     return XNorm_types


imperative/python/megengine/module/qat/conv.py (+4 / -1)

@@ -19,7 +19,10 @@ class Conv2d(Float.Conv2d, QATModule):

     def calc_conv_qat(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
-        b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        if self.weight_fake_quant and self.weight_fake_quant.enabled:
+            b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        else:
+            b_qat = self.bias
         conv = self.calc_conv(inp, w_qat, b_qat)
         return conv



imperative/python/megengine/module/qat/conv_bn.py (+4 / -1)

@@ -122,7 +122,10 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
         b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

         w_qat = self.apply_quant_weight(w_fold)
-        b_qat = fake_quant_bias(b_fold, inp, w_qat)
+        if self.weight_fake_quant and self.weight_fake_quant.enabled:
+            b_qat = fake_quant_bias(b_fold, inp, w_qat)
+        else:
+            b_qat = b_fold
         conv = self.conv.calc_conv(inp, w_qat, b_qat)
         if not (self.training and approx):
             return conv


imperative/python/megengine/module/qat/linear.py (+4 / -1)

@@ -24,7 +24,10 @@ class Linear(Float.Linear, QATModule):

     def forward(self, x):
         w_qat = self.apply_quant_weight(self.weight)
-        b_qat = fake_quant_bias(self.bias, x, w_qat)
+        if self.weight_fake_quant and self.weight_fake_quant.enabled:
+            b_qat = fake_quant_bias(self.bias, x, w_qat)
+        else:
+            b_qat = self.bias
         return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))

     @classmethod


imperative/python/megengine/quantization/fake_quant.py (+1 / -1)

@@ -116,7 +116,7 @@ class TQT(_FakeQuantize):

     def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
         super().__init__(dtype, narrow_range, enable)
-        self.scale = Parameter(0.0, dtype=np.float32)
+        self.scale = Parameter([0.0], dtype=np.float32)

     def fake_quant_forward(self, inp, q_dict=None):
         # when enable, TQT will do fakequant forward, finetune the scale

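The TQT change replaces a zero-dimensional scale with a one-element one. For context, the shape difference in plain NumPy terms (a sketch of array semantics only; that MegEngine's Parameter mirrors them exactly is an assumption):

import numpy as np

scalar = np.array(0.0, dtype=np.float32)    # shape (): zero-dimensional
vector = np.array([0.0], dtype=np.float32)  # shape (1,): one element

print(scalar.shape, vector.shape)  # () (1,)
print(vector[0])                   # 0.0 -- a (1,) array supports [0] indexing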

imperative/python/megengine/quantization/observer.py (+6 / -6)

@@ -219,8 +219,8 @@ class HistogramObserver(MinMaxObserver):
         By selecting new min/max, we filter out outliers in input distribution.
         """

-        np_min_val = self.min_val.numpy()[0]
-        np_max_val = self.max_val.numpy()[0]
+        np_min_val = self.min_val.numpy()
+        np_max_val = self.max_val.numpy()
         np_histogram = self.histogram.numpy()
         assert len(np_histogram) == self.bins, "bins mismatch"
         bin_width = (np_max_val - np_min_val) / self.bins
@@ -386,8 +386,8 @@ class HistogramObserver(MinMaxObserver):
         # This allows us to have a common grid of resolution s, where we can align
         # the input histogram
         # start_idx maps min_val to the histogram bin index.
-        np_min_val = self.min_val.numpy()[0]
-        np_max_val = self.max_val.numpy()[0]
+        np_min_val = self.min_val.numpy()
+        np_max_val = self.max_val.numpy()

         hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
         downsample_rate = int(
@@ -404,8 +404,8 @@ class HistogramObserver(MinMaxObserver):

     def sideeffect_forward(self, x_orig):
         x = x_orig.numpy()
-        min_val = self.min_val.numpy()[0]
-        max_val = self.max_val.numpy()[0]
+        min_val = self.min_val.numpy()
+        max_val = self.max_val.numpy()
         histogram = self.histogram.numpy()
         new_min = x.min()
         new_max = x.max()

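One plausible reading of these six observer changes, consistent with the TQT change above: min_val and max_val are now stored as zero-dimensional tensors, so .numpy() already yields a scalar-shaped array and the removed [0] indexing would fail. A plain-NumPy illustration (the tensor layout and the bin count are assumptions, not stated in the diff):

import numpy as np

min_val = np.array(-1.5, dtype=np.float32)  # 0-d stand-in for self.min_val.numpy()
max_val = np.array(2.5, dtype=np.float32)

bin_width = (max_val - min_val) / 2048  # arithmetic works fine on 0-d arrays
try:
    min_val[0]                          # the removed pattern
except IndexError as e:
    print(e)  # too many indices: 0-d arrays cannot be indexed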

imperative/python/megengine/quantization/utils.py (+1 / -0)

@@ -125,5 +125,6 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
         qmax = _metadata_dict["qint32"].qmax
         qmin = _metadata_dict["qint32"].qmin
         b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
+        b_qat.q_dict.update(b_dict)

     return b_qat

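The one added line keeps the bias quantization metadata attached to the tensor that fake_quant_tensor returns. A sketch with a plain dict standing in for q_dict (the real q_dict keys and the return behavior of fake_quant_tensor are assumptions, not shown in this diff):

import numpy as np

def fake_quant_tensor_sketch(x, qmin, qmax, q_dict):
    # Stand-in: fake-quantize x and return it with *empty* metadata,
    # modelling why the caller must re-attach b_dict afterwards.
    scale = q_dict["scale"]
    y = np.clip(np.round(x / scale), qmin, qmax) * scale
    return y, {}

b_dict = {"scale": 0.05}
b_qat, q_meta = fake_quant_tensor_sketch(
    np.array([0.3, -0.7], dtype=np.float32), -(2 ** 31), 2 ** 31 - 1, b_dict
)
q_meta.update(b_dict)  # mirrors the added b_qat.q_dict.update(b_dict)
assert q_meta["scale"] == 0.05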
imperative/python/megengine/utils/profiler.py (+2 / -1)

@@ -115,9 +115,10 @@ def _dump_compatible(entries: List[ProfileEntry], path: str):


 def _dump_graphviz(entries: List[ProfileEntry], path: str):
-    import graphviz
     import json

+    import graphviz
+
     graph = graphviz.Digraph()
     graph.graph_attr["ordering"] = "out"
     var_cache = {}


imperative/python/test/unit/core/test_imperative_rt.py (+2 / -2)

@@ -14,8 +14,8 @@ from megengine.core.tensor.core import apply


 def elemwise(*args, mode):
-    from megengine.core.ops.builtin import Elemwise
     from megengine.core._imperative_rt.imperative import apply_op
+    from megengine.core.ops.builtin import Elemwise

     return apply_op(Elemwise(mode), args)

@@ -61,8 +61,8 @@ def test_tensor_on_device():


 def test_raw_tensor():
-    from megengine.core.tensor.raw_tensor import as_raw_tensor
     from megengine.core.ops.builtin import Elemwise
+    from megengine.core.tensor.raw_tensor import as_raw_tensor

     x = np.random.rand(10).astype("float32")
     xx = as_raw_tensor(x)


imperative/python/test/unit/quantization/quantize.py (+37 / -1)

@@ -5,9 +5,18 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+import pytest
+
 from megengine import module as Float
+from megengine import tensor
 from megengine.module import qat as QAT
-from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat
+from megengine.quantization import min_max_fakequant_qconfig
+from megengine.quantization.quantize import (
+    _get_quantable_module_names,
+    disable_fake_quant,
+    quantize_qat,
+)


 def test_get_quantable_module_names():
@@ -78,3 +87,30 @@ def test_convert_with_custom_mapping():
     net = Net()
     qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
     assert isinstance(qat_net.example, QATExample)
+
+
+def test_disable_fake_quant():
+    class Net(Float.Module):
+        def __init__(self):
+            super().__init__()
+            self.quant = Float.QuantStub()
+            self.linear = Float.Linear(3, 3)
+            self.dequant = Float.DequantStub()
+            self.linear.bias.set_value(np.random.rand(3))
+
+        def forward(self, x):
+            x = self.quant(x)
+            x = self.linear(x)
+            x = self.dequant(x)
+            return x
+
+    x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
+    net = Net()
+    y1 = net(x).numpy()
+    net = quantize_qat(net, min_max_fakequant_qconfig)
+    y2 = net(x).numpy()
+    disable_fake_quant(net)
+    y3 = net(x).numpy()
+    np.testing.assert_allclose(y1, y3)
+    with pytest.raises(AssertionError):
+        np.testing.assert_allclose(y2, y3)
