GitOrigin-RevId: 0c568d3335
tags/v1.2.0
@@ -57,7 +57,7 @@ def _is_module(obj):
 def _get_XNorm_typeclass():
     from .batchnorm import _BatchNorm
-    from .normalization import GroupNorm, LayerNorm, InstanceNorm
+    from .normalization import GroupNorm, InstanceNorm, LayerNorm

     XNorm_types = (_BatchNorm, GroupNorm, LayerNorm, InstanceNorm)
     return XNorm_types
@@ -19,7 +19,10 @@ class Conv2d(Float.Conv2d, QATModule):
     def calc_conv_qat(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
-        b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        if self.weight_fake_quant and self.weight_fake_quant.enabled:
+            b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        else:
+            b_qat = self.bias
         conv = self.calc_conv(inp, w_qat, b_qat)
         return conv
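Note: the same guard appears again in _ConvBnActivation2d and Linear in the next two hunks. The bias is fake-quantized only while weight fake-quantization is enabled; otherwise the float bias is passed through unchanged (e.g. after disable_fake_quant has been called). A minimal sketch of the shared logic, with a hypothetical helper name and only the attributes visible in these hunks assumed:

    def _maybe_fake_quant_bias(module, bias, inp, w_qat):
        # Fake-quantize the bias only while weight fake-quant is active;
        # otherwise return the float bias untouched.
        if module.weight_fake_quant and module.weight_fake_quant.enabled:
            return fake_quant_bias(bias, inp, w_qat)
        return bias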
@@ -122,7 +122,10 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
         b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
         w_qat = self.apply_quant_weight(w_fold)
-        b_qat = fake_quant_bias(b_fold, inp, w_qat)
+        if self.weight_fake_quant and self.weight_fake_quant.enabled:
+            b_qat = fake_quant_bias(b_fold, inp, w_qat)
+        else:
+            b_qat = b_fold
         conv = self.conv.calc_conv(inp, w_qat, b_qat)
         if not (self.training and approx):
             return conv
@@ -24,7 +24,10 @@ class Linear(Float.Linear, QATModule):
     def forward(self, x):
         w_qat = self.apply_quant_weight(self.weight)
-        b_qat = fake_quant_bias(self.bias, x, w_qat)
+        if self.weight_fake_quant and self.weight_fake_quant.enabled:
+            b_qat = fake_quant_bias(self.bias, x, w_qat)
+        else:
+            b_qat = self.bias
         return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))

     @classmethod
@@ -116,7 +116,7 @@ class TQT(_FakeQuantize):
     def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
         super().__init__(dtype, narrow_range, enable)
-        self.scale = Parameter(0.0, dtype=np.float32)
+        self.scale = Parameter([0.0], dtype=np.float32)

     def fake_quant_forward(self, inp, q_dict=None):
         # when enable, TQT will do fakequant forward, finetune the scale
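Note: the learnable TQT scale is now built from a 1-element list rather than a Python float, i.e. a shape-(1,) parameter instead of a 0-dim scalar. A minimal sketch of the difference, assuming megengine.Parameter follows numpy-style shape semantics (illustrative, not part of the patch):

    import numpy as np
    from megengine import Parameter

    scalar_scale = Parameter(0.0, dtype=np.float32)    # old: 0-dim scalar, shape ()
    vector_scale = Parameter([0.0], dtype=np.float32)  # new: one-element, shape (1,)
    assert vector_scale.numpy().shape == (1,)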
@@ -219,8 +219,8 @@ class HistogramObserver(MinMaxObserver):
         By selecting new min/max, we filter out outliers in input distribution.
         """
-        np_min_val = self.min_val.numpy()[0]
-        np_max_val = self.max_val.numpy()[0]
+        np_min_val = self.min_val.numpy()
+        np_max_val = self.max_val.numpy()
         np_histogram = self.histogram.numpy()
         assert len(np_histogram) == self.bins, "bins mistmatch"
         bin_width = (np_max_val - np_min_val) / self.bins
@@ -386,8 +386,8 @@ class HistogramObserver(MinMaxObserver):
         # This allows us to have a common grid of resolution s, where we can align
         # the input histogram
         # start_idx maps min_val to the histogram bin index.
-        np_min_val = self.min_val.numpy()[0]
-        np_max_val = self.max_val.numpy()[0]
+        np_min_val = self.min_val.numpy()
+        np_max_val = self.max_val.numpy()
         hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
         downsample_rate = int(
@@ -404,8 +404,8 @@ class HistogramObserver(MinMaxObserver):
     def sideeffect_forward(self, x_orig):
         x = x_orig.numpy()
-        min_val = self.min_val.numpy()[0]
-        max_val = self.max_val.numpy()[0]
+        min_val = self.min_val.numpy()
+        max_val = self.max_val.numpy()
         histogram = self.histogram.numpy()
         new_min = x.min()
         new_max = x.max()
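Note: all three HistogramObserver call sites drop the trailing [0], which suggests min_val and max_val are now held as 0-dim (scalar) tensors rather than shape-(1,) tensors; indexing a 0-dim array with [0] would raise. A small numpy-only sketch of the distinction (nothing beyond what the hunks show is assumed about the observer itself):

    import numpy as np

    one_elem = np.array([3.5], dtype=np.float32)  # shape (1,): old style, needs [0]
    scalar = np.array(3.5, dtype=np.float32)      # shape (): new style, used directly
    assert one_elem[0] == scalar                  # same value either way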
@@ -125,5 +125,6 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
         qmax = _metadata_dict["qint32"].qmax
         qmin = _metadata_dict["qint32"].qmin
         b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
+        b_qat.q_dict.update(b_dict)
     return b_qat
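Note: after fake-quantizing the bias, the computed quantization parameters in b_dict are now copied onto the returned tensor's q_dict, presumably so callers can read back the qint32 parameters that were actually applied. A hedged sketch of what a consumer might do (the "scale" key name is an assumption, not shown in this hunk):

    b_qat = fake_quant_bias(bias, inp, w_qat)
    bias_scale = b_qat.q_dict.get("scale")  # populated by the update() added above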
@@ -115,9 +115,10 @@ def _dump_compatible(entries: List[ProfileEntry], path: str):
 def _dump_graphviz(entries: List[ProfileEntry], path: str):
-    import graphviz
     import json
+
+    import graphviz

     graph = graphviz.Digraph()
     graph.graph_attr["ordering"] = "out"
     var_cache = {}
@@ -14,8 +14,8 @@ from megengine.core.tensor.core import apply
 def elemwise(*args, mode):
-    from megengine.core.ops.builtin import Elemwise
     from megengine.core._imperative_rt.imperative import apply_op
+    from megengine.core.ops.builtin import Elemwise

     return apply_op(Elemwise(mode), args)
@@ -61,8 +61,8 @@ def test_tensor_on_device():
 def test_raw_tensor():
-    from megengine.core.tensor.raw_tensor import as_raw_tensor
     from megengine.core.ops.builtin import Elemwise
+    from megengine.core.tensor.raw_tensor import as_raw_tensor

     x = np.random.rand(10).astype("float32")
     xx = as_raw_tensor(x)
@@ -5,9 +5,18 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+import pytest
+
 from megengine import module as Float
+from megengine import tensor
 from megengine.module import qat as QAT
-from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat
+from megengine.quantization import min_max_fakequant_qconfig
+from megengine.quantization.quantize import (
+    _get_quantable_module_names,
+    disable_fake_quant,
+    quantize_qat,
+)


 def test_get_quantable_module_names():
@@ -78,3 +87,30 @@ def test_convert_with_custom_mapping():
     net = Net()
     qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
     assert isinstance(qat_net.example, QATExample)
+
+
+def test_disable_fake_quant():
+    class Net(Float.Module):
+        def __init__(self):
+            super().__init__()
+            self.quant = Float.QuantStub()
+            self.linear = Float.Linear(3, 3)
+            self.dequant = Float.DequantStub()
+            self.linear.bias.set_value(np.random.rand(3))
+
+        def forward(self, x):
+            x = self.quant(x)
+            x = self.linear(x)
+            x = self.dequant(x)
+            return x
+
+    x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
+    net = Net()
+    y1 = net(x).numpy()
+    net = quantize_qat(net, min_max_fakequant_qconfig)
+    y2 = net(x).numpy()
+    disable_fake_quant(net)
+    y3 = net(x).numpy()
+    np.testing.assert_allclose(y1, y3)
+    with pytest.raises(AssertionError):
+        np.testing.assert_allclose(y2, y3)
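Note: the new test checks the round trip that the bias-guard hunks above make possible: a float forward pass, a QAT forward pass with fake quantization enabled (expected to differ), and then disable_fake_quant restoring output that matches the float network within numerical tolerance. A minimal usage sketch for user code, assuming only the public API the test itself exercises (float_net stands in for any Float.Module):

    from megengine.quantization import min_max_fakequant_qconfig
    from megengine.quantization.quantize import disable_fake_quant, quantize_qat

    qat_net = quantize_qat(float_net, min_max_fakequant_qconfig)  # insert fake-quant modules
    disable_fake_quant(qat_net)  # outputs now match the original float network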