| @@ -49,6 +49,8 @@ class QATModule(Module): | |||||
| def _apply_fakequant_with_observer( | def _apply_fakequant_with_observer( | ||||
| self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | ||||
| ): | ): | ||||
| if observer is None: | |||||
| return target | |||||
| oup = observer(target) | oup = observer(target) | ||||
| if fake_quant is None: | if fake_quant is None: | ||||
| return oup | return oup | ||||
| @@ -76,7 +78,7 @@ class QATModule(Module): | |||||
| r""" | r""" | ||||
| Get weight's quantization dtype as the method from ``qconfig``. | Get weight's quantization dtype as the method from ``qconfig``. | ||||
| """ | """ | ||||
| if hasattr(self.act_fake_quant, "get_dtype"): | |||||
| if hasattr(self.weight_fake_quant, "get_dtype"): | |||||
| return self.weight_fake_quant.get_dtype() | return self.weight_fake_quant.get_dtype() | ||||
| else: | else: | ||||
| return self.weight_observer.get_dtype() | return self.weight_observer.get_dtype() | ||||
| @@ -5,7 +5,9 @@ | |||||
| # Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from .fake_quant import FakeQuantize | from .fake_quant import FakeQuantize | ||||
| from .internal_fake_quant import * | |||||
| from .observer import HistogramObserver, Observer, ObserverMode | from .observer import HistogramObserver, Observer, ObserverMode | ||||
| from .qconfig import ( | from .qconfig import ( | ||||
| QConfig, | QConfig, | ||||
| @@ -19,6 +19,15 @@ from .observer import ObserverMode, Round | |||||
| class _FakeQuantize(Module): | class _FakeQuantize(Module): | ||||
| r""" | |||||
| A Basic Fake Quant module. | |||||
| :param dtype: A string indicating the target quantization type of input. | |||||
| :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, | |||||
| instead of 1 greater. Usually True for weight and False for activation. | |||||
| :param enable: Whether to do ``normal_forward`` or ``fake_quant_forward``. | |||||
| """ | |||||
| def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | ||||
| super().__init__() | super().__init__() | ||||
| if not dtype in _metadata_dict.keys(): | if not dtype in _metadata_dict.keys(): | ||||
| @@ -92,9 +101,9 @@ class TQT_Function(Function): | |||||
| class TQT(_FakeQuantize): | class TQT(_FakeQuantize): | ||||
| """ | |||||
| r""" | |||||
| TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds | TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds | ||||
| for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks | |||||
| for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks. | |||||
| """ | """ | ||||
| def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | ||||
| @@ -119,11 +128,6 @@ class TQT(_FakeQuantize): | |||||
| class FakeQuantize(_FakeQuantize): | class FakeQuantize(_FakeQuantize): | ||||
| r""" | r""" | ||||
| A module to do quant and dequant according to observer's scale and zero_point. | A module to do quant and dequant according to observer's scale and zero_point. | ||||
| :param dtype: A string indicating the target quantization type of input. | |||||
| :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, | |||||
| instead of 1 greater. Usually True for weight and False for activation. | |||||
| :param enable: Whether do ``normal_forward`` or ``fake_quant_forward``. | |||||
| """ | """ | ||||
| def fake_quant_forward(self, inp, q_dict): | def fake_quant_forward(self, inp, q_dict): | ||||
| @@ -0,0 +1,19 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import copy | |||||
| import math | |||||
| from functools import partial | |||||
| import numpy as np | |||||
| from .. import functional as F | |||||
| from ..core import Function | |||||
| from .fake_quant import _FakeQuantize | |||||
| from .observer import MinMaxObserver | |||||
| from .qconfig import QConfig | |||||
| @@ -13,6 +13,7 @@ import megengine as mge | |||||
| import megengine._internal as mgb | import megengine._internal as mgb | ||||
| from megengine.core import tensor | from megengine.core import tensor | ||||
| from megengine.quantization.fake_quant import TQT_Function | from megengine.quantization.fake_quant import TQT_Function | ||||
| from megengine.quantization.internal_fake_quant import * | |||||
| from megengine.test import assertTensorClose | from megengine.test import assertTensorClose | ||||
| @@ -75,3 +76,5 @@ def test_TQT(): | |||||
| a.set_value(a_np) | a.set_value(a_np) | ||||
| b.set_value(b_np) | b.set_value(b_np) | ||||
| check_inp(a, b, b, a_np, b_np, b_np) | check_inp(a, b, b, a_np, b_np, b_np) | ||||