Browse Source

fix(mge/quantization): `disable_fake_quant` does not work correctly

GitOrigin-RevId: 0c568d3335
tags/v1.2.0
Megvii Engine Team 5 years ago
parent
commit
6c4841e807
10 changed files with 62 additions and 15 deletions
  1. +1
    -1
      imperative/python/megengine/module/module.py
  2. +4
    -1
      imperative/python/megengine/module/qat/conv.py
  3. +4
    -1
      imperative/python/megengine/module/qat/conv_bn.py
  4. +4
    -1
      imperative/python/megengine/module/qat/linear.py
  5. +1
    -1
      imperative/python/megengine/quantization/fake_quant.py
  6. +6
    -6
      imperative/python/megengine/quantization/observer.py
  7. +1
    -0
      imperative/python/megengine/quantization/utils.py
  8. +2
    -1
      imperative/python/megengine/utils/profiler.py
  9. +2
    -2
      imperative/python/test/unit/core/test_imperative_rt.py
  10. +37
    -1
      imperative/python/test/unit/quantization/quantize.py

+ 1
- 1
imperative/python/megengine/module/module.py View File

@@ -57,7 +57,7 @@ def _is_module(obj):


def _get_XNorm_typeclass(): def _get_XNorm_typeclass():
from .batchnorm import _BatchNorm from .batchnorm import _BatchNorm
from .normalization import GroupNorm, LayerNorm, InstanceNorm
from .normalization import GroupNorm, InstanceNorm, LayerNorm


XNorm_types = (_BatchNorm, GroupNorm, LayerNorm, InstanceNorm) XNorm_types = (_BatchNorm, GroupNorm, LayerNorm, InstanceNorm)
return XNorm_types return XNorm_types


+ 4
- 1
imperative/python/megengine/module/qat/conv.py View File

@@ -19,7 +19,10 @@ class Conv2d(Float.Conv2d, QATModule):


def calc_conv_qat(self, inp): def calc_conv_qat(self, inp):
w_qat = self.apply_quant_weight(self.weight) w_qat = self.apply_quant_weight(self.weight)
b_qat = fake_quant_bias(self.bias, inp, w_qat)
if self.weight_fake_quant and self.weight_fake_quant.enabled:
b_qat = fake_quant_bias(self.bias, inp, w_qat)
else:
b_qat = self.bias
conv = self.calc_conv(inp, w_qat, b_qat) conv = self.calc_conv(inp, w_qat, b_qat)
return conv return conv




+ 4
- 1
imperative/python/megengine/module/qat/conv_bn.py View File

@@ -122,7 +122,10 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd


w_qat = self.apply_quant_weight(w_fold) w_qat = self.apply_quant_weight(w_fold)
b_qat = fake_quant_bias(b_fold, inp, w_qat)
if self.weight_fake_quant and self.weight_fake_quant.enabled:
b_qat = fake_quant_bias(b_fold, inp, w_qat)
else:
b_qat = b_fold
conv = self.conv.calc_conv(inp, w_qat, b_qat) conv = self.conv.calc_conv(inp, w_qat, b_qat)
if not (self.training and approx): if not (self.training and approx):
return conv return conv


+ 4
- 1
imperative/python/megengine/module/qat/linear.py View File

@@ -24,7 +24,10 @@ class Linear(Float.Linear, QATModule):


def forward(self, x): def forward(self, x):
w_qat = self.apply_quant_weight(self.weight) w_qat = self.apply_quant_weight(self.weight)
b_qat = fake_quant_bias(self.bias, x, w_qat)
if self.weight_fake_quant and self.weight_fake_quant.enabled:
b_qat = fake_quant_bias(self.bias, x, w_qat)
else:
b_qat = self.bias
return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat)) return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))


@classmethod @classmethod


+ 1
- 1
imperative/python/megengine/quantization/fake_quant.py View File

@@ -116,7 +116,7 @@ class TQT(_FakeQuantize):


def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
super().__init__(dtype, narrow_range, enable) super().__init__(dtype, narrow_range, enable)
self.scale = Parameter(0.0, dtype=np.float32)
self.scale = Parameter([0.0], dtype=np.float32)


def fake_quant_forward(self, inp, q_dict=None): def fake_quant_forward(self, inp, q_dict=None):
# when enable, TQT will do fakequant forward, finetune the scale # when enable, TQT will do fakequant forward, finetune the scale


+ 6
- 6
imperative/python/megengine/quantization/observer.py View File

@@ -219,8 +219,8 @@ class HistogramObserver(MinMaxObserver):
By selecting new min/max, we filter out outliers in input distribution. By selecting new min/max, we filter out outliers in input distribution.
""" """


np_min_val = self.min_val.numpy()[0]
np_max_val = self.max_val.numpy()[0]
np_min_val = self.min_val.numpy()
np_max_val = self.max_val.numpy()
np_histogram = self.histogram.numpy() np_histogram = self.histogram.numpy()
assert len(np_histogram) == self.bins, "bins mistmatch" assert len(np_histogram) == self.bins, "bins mistmatch"
bin_width = (np_max_val - np_min_val) / self.bins bin_width = (np_max_val - np_min_val) / self.bins
@@ -386,8 +386,8 @@ class HistogramObserver(MinMaxObserver):
# This allows us to have a common grid of resolution s, where we can align # This allows us to have a common grid of resolution s, where we can align
# the input histogram # the input histogram
# start_idx maps min_val to the histogram bin index. # start_idx maps min_val to the histogram bin index.
np_min_val = self.min_val.numpy()[0]
np_max_val = self.max_val.numpy()[0]
np_min_val = self.min_val.numpy()
np_max_val = self.max_val.numpy()


hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate) hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
downsample_rate = int( downsample_rate = int(
@@ -404,8 +404,8 @@ class HistogramObserver(MinMaxObserver):


def sideeffect_forward(self, x_orig): def sideeffect_forward(self, x_orig):
x = x_orig.numpy() x = x_orig.numpy()
min_val = self.min_val.numpy()[0]
max_val = self.max_val.numpy()[0]
min_val = self.min_val.numpy()
max_val = self.max_val.numpy()
histogram = self.histogram.numpy() histogram = self.histogram.numpy()
new_min = x.min() new_min = x.min()
new_max = x.max() new_max = x.max()


+ 1
- 0
imperative/python/megengine/quantization/utils.py View File

@@ -125,5 +125,6 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
qmax = _metadata_dict["qint32"].qmax qmax = _metadata_dict["qint32"].qmax
qmin = _metadata_dict["qint32"].qmin qmin = _metadata_dict["qint32"].qmin
b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict) b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
b_qat.q_dict.update(b_dict)


return b_qat return b_qat

+ 2
- 1
imperative/python/megengine/utils/profiler.py View File

@@ -115,9 +115,10 @@ def _dump_compatible(entries: List[ProfileEntry], path: str):




def _dump_graphviz(entries: List[ProfileEntry], path: str): def _dump_graphviz(entries: List[ProfileEntry], path: str):
import graphviz
import json import json


import graphviz

graph = graphviz.Digraph() graph = graphviz.Digraph()
graph.graph_attr["ordering"] = "out" graph.graph_attr["ordering"] = "out"
var_cache = {} var_cache = {}


+ 2
- 2
imperative/python/test/unit/core/test_imperative_rt.py View File

@@ -14,8 +14,8 @@ from megengine.core.tensor.core import apply




def elemwise(*args, mode): def elemwise(*args, mode):
from megengine.core.ops.builtin import Elemwise
from megengine.core._imperative_rt.imperative import apply_op from megengine.core._imperative_rt.imperative import apply_op
from megengine.core.ops.builtin import Elemwise


return apply_op(Elemwise(mode), args) return apply_op(Elemwise(mode), args)


@@ -61,8 +61,8 @@ def test_tensor_on_device():




def test_raw_tensor(): def test_raw_tensor():
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.core.ops.builtin import Elemwise from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.raw_tensor import as_raw_tensor


x = np.random.rand(10).astype("float32") x = np.random.rand(10).astype("float32")
xx = as_raw_tensor(x) xx = as_raw_tensor(x)


+ 37
- 1
imperative/python/test/unit/quantization/quantize.py View File

@@ -5,9 +5,18 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

from megengine import module as Float from megengine import module as Float
from megengine import tensor
from megengine.module import qat as QAT from megengine.module import qat as QAT
from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat
from megengine.quantization import min_max_fakequant_qconfig
from megengine.quantization.quantize import (
_get_quantable_module_names,
disable_fake_quant,
quantize_qat,
)




def test_get_quantable_module_names(): def test_get_quantable_module_names():
@@ -78,3 +87,30 @@ def test_convert_with_custom_mapping():
net = Net() net = Net()
qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample}) qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
assert isinstance(qat_net.example, QATExample) assert isinstance(qat_net.example, QATExample)


def test_disable_fake_quant():
class Net(Float.Module):
def __init__(self):
super().__init__()
self.quant = Float.QuantStub()
self.linear = Float.Linear(3, 3)
self.dequant = Float.DequantStub()
self.linear.bias.set_value(np.random.rand(3))

def forward(self, x):
x = self.quant(x)
x = self.linear(x)
x = self.dequant(x)
return x

x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
net = Net()
y1 = net(x).numpy()
net = quantize_qat(net, min_max_fakequant_qconfig)
y2 = net(x).numpy()
disable_fake_quant(net)
y3 = net(x).numpy()
np.testing.assert_allclose(y1, y3)
with pytest.raises(AssertionError):
np.testing.assert_allclose(y2, y3)

Loading…
Cancel
Save