|
|
|
@@ -5,9 +5,18 @@ |
|
|
|
# Unless required by applicable law or agreed to in writing, |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import numpy as np |
|
|
|
import pytest |
|
|
|
|
|
|
|
from megengine import module as Float |
|
|
|
from megengine import tensor |
|
|
|
from megengine.module import qat as QAT |
|
|
|
from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat |
|
|
|
from megengine.quantization import min_max_fakequant_qconfig |
|
|
|
from megengine.quantization.quantize import ( |
|
|
|
_get_quantable_module_names, |
|
|
|
disable_fake_quant, |
|
|
|
quantize_qat, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def test_get_quantable_module_names(): |
|
|
|
@@ -78,3 +87,30 @@ def test_convert_with_custom_mapping(): |
|
|
|
net = Net() |
|
|
|
qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample}) |
|
|
|
assert isinstance(qat_net.example, QATExample) |
|
|
|
|
|
|
|
|
|
|
|
def test_disable_fake_quant(): |
|
|
|
class Net(Float.Module): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.quant = Float.QuantStub() |
|
|
|
self.linear = Float.Linear(3, 3) |
|
|
|
self.dequant = Float.DequantStub() |
|
|
|
self.linear.bias.set_value(np.random.rand(3)) |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
x = self.quant(x) |
|
|
|
x = self.linear(x) |
|
|
|
x = self.dequant(x) |
|
|
|
return x |
|
|
|
|
|
|
|
x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32)) |
|
|
|
net = Net() |
|
|
|
y1 = net(x).numpy() |
|
|
|
net = quantize_qat(net, min_max_fakequant_qconfig) |
|
|
|
y2 = net(x).numpy() |
|
|
|
disable_fake_quant(net) |
|
|
|
y3 = net(x).numpy() |
|
|
|
np.testing.assert_allclose(y1, y3) |
|
|
|
with pytest.raises(AssertionError): |
|
|
|
np.testing.assert_allclose(y2, y3) |