Browse Source

fix(mge/quantization): set ``q_dict`` as an instance property

GitOrigin-RevId: 2f32008aad
tags/v1.3.0
Megvii Engine Team 5 years ago
parent
commit
02df634da2
7 changed files with 93 additions and 56 deletions
  1. imperative/python/megengine/quantization/observer.py (+1, -1)
  2. imperative/python/megengine/tensor.py (+8, -2)
  3. imperative/python/test/unit/core/test_serialization.py (+9, -5)
  4. imperative/python/test/unit/functional/test_tensor.py (+16, -0)
  5. imperative/python/test/unit/module/test_module_tensor.py (+1, -0)
  6. imperative/python/test/unit/quantization/test_module.py (+47, -36)
  7. imperative/python/test/unit/quantization/test_quantize.py (+11, -12)

+ 1
- 1
imperative/python/megengine/quantization/observer.py View File

@@ -467,7 +467,7 @@ class PassiveObserver(Observer):
     @scale.setter
     def scale(self, value):
         assert value > 0
-        self.q_dict["scale"].set_value(value)
+        self.q_dict["scale"][...] = Tensor(value)

     def get_qparams(self):
         return self.q_dict


+ 8
- 2
imperative/python/megengine/tensor.py View File

@@ -25,7 +25,7 @@ from .utils.deprecation import deprecated
 class Tensor(_Tensor, ArrayMethodMixin):
     grad = None
     dmap_callback = None
-    q_dict = {"mode": None, "scale": None, "zero_point": None}
+    _q_dict = None

     def __new__(cls, data, dtype=None, device=None, is_const=False, no_cache=False):
         if device is None:
@@ -70,6 +70,12 @@ class Tensor(_Tensor, ArrayMethodMixin):
     def dtype(self) -> np.dtype:
         return super().dtype

+    @property
+    def q_dict(self):
+        if self._q_dict is None:
+            self._q_dict = {"mode": None, "scale": None, "zero_point": None}
+        return self._q_dict
+
     def numpy(self) -> np.ndarray:
         return super().numpy()

@@ -135,7 +141,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
         return state

     def __setstate__(self, state):
-        self.q_dict = state.pop("qdict")
+        self._q_dict = state.pop("qdict")


 tensor = Tensor


+ 9
- 5
imperative/python/test/unit/core/test_serialization.py View File

@@ -16,11 +16,6 @@ from megengine import Parameter, Tensor


 def test_tensor_serialization():
-    def tensor_eq(a, b):
-        assert a.dtype == b.dtype
-        assert a.device == b.device
-        np.testing.assert_equal(a.numpy(), b.numpy())
-
     with TemporaryFile() as f:
         data = np.random.randint(low=0, high=7, size=[233])
         a = Tensor(data, device="xpux", dtype=np.int32)
@@ -67,3 +62,12 @@ def test_tensor_serialization():
assert "cpu0" in str(b.device) assert "cpu0" in str(b.device)
np.testing.assert_equal(a.numpy(), b.numpy()) np.testing.assert_equal(a.numpy(), b.numpy())
mge.set_default_device(device_org) mge.set_default_device(device_org)

with TemporaryFile() as f:
a = Tensor(0)
a.q_dict["scale"] = Tensor(1.0)
pickle.dump(a, f)
f.seek(0)
b = pickle.load(f)
assert isinstance(b.q_dict["scale"], Tensor)
np.testing.assert_equal(b.q_dict["scale"].numpy(), 1.0)

+ 16
- 0
imperative/python/test/unit/functional/test_tensor.py View File

@@ -379,3 +379,19 @@ def test_copy_d2h():
 def test_copy_d2d():
     copy_test("gpu0", "gpu1")
     copy_test("gpu0:0", "gpu0:1")
+
+
+def test_q_dict():
+    x = tensor(1)
+    assert x.q_dict["scale"] is None
+    x.q_dict["scale"] = tensor(1.0)
+
+    y = tensor(1)
+    assert y.q_dict["scale"] is None
+    y.q_dict["scale"] = tensor(2.0)
+
+    assert x.q_dict["scale"].numpy() == 1.0
+    assert y.q_dict["scale"].numpy() == 2.0
+
+    z = x + y
+    assert z.q_dict["scale"] is None

+ 1
- 0
imperative/python/test/unit/module/test_module_tensor.py View File

@@ -17,6 +17,7 @@ from megengine import Parameter, Tensor
 from megengine.module import Conv2d


+# TODO: delete this test after deleting set_value
 def test_set_value():
     v0 = np.random.random((2, 3)).astype(np.float32)
     param = Parameter(v0)


+ 47
- 36
imperative/python/test/unit/quantization/test_module.py View File

@@ -1,3 +1,5 @@
from functools import partial

import numpy as np import numpy as np
import pytest import pytest


@@ -6,17 +8,21 @@ import megengine.functional as F
import megengine.module as Float import megengine.module as Float
import megengine.module.qat as QAT import megengine.module.qat as QAT
import megengine.module.quantized as Q import megengine.module.quantized as Q
from megengine import Parameter, Tensor
from megengine.core.tensor import dtype from megengine.core.tensor import dtype
from megengine.quantization import min_max_fakequant_qconfig
from megengine.quantization import FakeQuantize, MinMaxObserver, QConfig
from megengine.quantization.quantize import ( from megengine.quantization.quantize import (
disable_fake_quant, disable_fake_quant,
disable_observer, disable_observer,
propagate_qconfig, propagate_qconfig,
) )


"""
Calculate testing scales based on ``min_max_fakequant_qconfig``
"""
min_max_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
)


inp_scale = np.float32(np.random.rand() + 1) inp_scale = np.float32(np.random.rand() + 1)


@@ -31,21 +37,26 @@ def quant(x, scale):
return x.astype(inp_dtype) return x.astype(inp_dtype)




def fake_quant(x, scale):
def fake_quant(x, scale, qmin, qmax):
x = x / scale x = x / scale
x = F.round(x) x = F.round(x)
x = F.clip(x, -128, 127)
x = F.clip(x, qmin, qmax)
x = x * scale x = x * scale
return x return x




fake_quant_act = partial(fake_quant, qmin=-128, qmax=127)
fake_quant_weight = partial(fake_quant, qmin=-127, qmax=127)
fake_quant_bias = partial(fake_quant, qmin=-(2 ** 31), qmax=2 ** 31 - 1)


def init_qat_net(net): def init_qat_net(net):
if net.with_weight: if net.with_weight:
net.weight_observer.min_val.set_value(min_val[0])
net.weight_observer.max_val.set_value(max_val[0])
net.weight_observer.min_val[...] = Tensor(min_val[0])
net.weight_observer.max_val[...] = Tensor(max_val[0])
if net.with_act: if net.with_act:
net.act_observer.min_val.set_value(min_val[1])
net.act_observer.max_val.set_value(max_val[1])
net.act_observer.min_val[...] = Tensor(min_val[1])
net.act_observer.max_val[...] = Tensor(max_val[1])




def test_quant_stub(): def test_quant_stub():
@@ -71,7 +82,7 @@ def test_quant_stub():


normal = normal_net(x) normal = normal_net(x)
qat_without_fakequant = qat_from_float(x) qat_without_fakequant = qat_from_float(x)
fake_quant_normal = fake_quant(normal_net(x), act_scale)
fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
qat = qat_net(x) qat = qat_net(x)
q = q_net(x).numpy() * act_scale q = q_net(x).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal) np.testing.assert_allclose(qat_without_fakequant, normal)
@@ -99,7 +110,7 @@ def test_dequant_stub():
q_net.eval() q_net.eval()


x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x = fake_quant(x, inp_scale)
x = fake_quant_act(x, inp_scale)
x.q_dict["scale"] = inp_scale x.q_dict["scale"] = inp_scale


normal = normal_net(x) normal = normal_net(x)
@@ -134,12 +145,12 @@ def test_elemwise(kind):


x1_scale = np.float32(np.random.rand() + 1) x1_scale = np.float32(np.random.rand() + 1)
x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x1 = fake_quant(x1, x1_scale)
x1 = fake_quant_act(x1, x1_scale)
x1.q_dict["scale"] = x1_scale x1.q_dict["scale"] = x1_scale


x2_scale = np.float32(np.random.rand() + 1) x2_scale = np.float32(np.random.rand() + 1)
x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x2 = fake_quant(x2, x2_scale)
x2 = fake_quant_act(x2, x2_scale)
x2.q_dict["scale"] = x2_scale x2.q_dict["scale"] = x2_scale


x1_int8 = quant(x1, x1_scale) x1_int8 = quant(x1, x1_scale)
@@ -149,13 +160,13 @@ def test_elemwise(kind):
if kind in ("ADD", "MUL", "FUSE_ADD_RELU"): if kind in ("ADD", "MUL", "FUSE_ADD_RELU"):
normal = normal_net(x1, x2) normal = normal_net(x1, x2)
qat_without_fakequant = qat_from_float(x1, x2) qat_without_fakequant = qat_from_float(x1, x2)
fake_quant_normal = fake_quant(normal_net(x1, x2), act_scale)
fake_quant_normal = fake_quant_act(normal_net(x1, x2), act_scale)
qat = qat_net(x1, x2) qat = qat_net(x1, x2)
q = q_net(x1_int8, x2_int8).numpy() * act_scale q = q_net(x1_int8, x2_int8).numpy() * act_scale
else: else:
normal = normal_net(x1) normal = normal_net(x1)
qat_without_fakequant = qat_from_float(x1) qat_without_fakequant = qat_from_float(x1)
fake_quant_normal = fake_quant(normal_net(x1), act_scale)
fake_quant_normal = fake_quant_act(normal_net(x1), act_scale)
qat = qat_net(x1) qat = qat_net(x1)
q = q_net(x1_int8).numpy() * act_scale q = q_net(x1_int8).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal) np.testing.assert_allclose(qat_without_fakequant, normal)
@@ -175,17 +186,17 @@ def test_linear():
init_qat_net(qat_net) init_qat_net(qat_net)


x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x = fake_quant(x, inp_scale)
x = fake_quant_act(x, inp_scale)
x.q_dict["scale"] = inp_scale x.q_dict["scale"] = inp_scale


x_int8 = quant(x, inp_scale) x_int8 = quant(x, inp_scale)


weight = np.random.normal(size=(3, 3)).astype("float32") weight = np.random.normal(size=(3, 3)).astype("float32")
bias = np.random.normal(size=(3,)).astype("float32") bias = np.random.normal(size=(3,)).astype("float32")
normal_net.weight.set_value(fake_quant(weight, weight_scale))
normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
qat_net.weight.set_value(weight)
qat_net.bias.set_value(bias)
normal_net.weight[...] = fake_quant_weight(weight, weight_scale)
normal_net.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
qat_net.weight[...] = Parameter(weight)
qat_net.bias[...] = Parameter(bias)


qat_from_float = QAT.Linear.from_float_module(normal_net) qat_from_float = QAT.Linear.from_float_module(normal_net)
qat_from_float.eval() qat_from_float.eval()
@@ -197,11 +208,11 @@ def test_linear():


normal = normal_net(x) normal = normal_net(x)
qat_without_fakequant = qat_from_float(x) qat_without_fakequant = qat_from_float(x)
fake_quant_normal = fake_quant(normal_net(x), act_scale)
fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
qat = qat_net(x) qat = qat_net(x)
q = q_net(x_int8).numpy() * act_scale q = q_net(x_int8).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal) np.testing.assert_allclose(qat_without_fakequant, normal)
np.testing.assert_allclose(qat, fake_quant_normal)
np.testing.assert_allclose(qat, fake_quant_normal.numpy())
np.testing.assert_allclose(q, fake_quant_normal.numpy()) np.testing.assert_allclose(q, fake_quant_normal.numpy())




@@ -218,7 +229,7 @@ def test_conv(module):
init_qat_net(qat_net) init_qat_net(qat_net)


x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32")) x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
x = fake_quant(x, inp_scale)
x = fake_quant_act(x, inp_scale)
x.q_dict["scale"] = inp_scale x.q_dict["scale"] = inp_scale


x_int8 = quant(x, inp_scale) x_int8 = quant(x, inp_scale)
@@ -226,15 +237,15 @@ def test_conv(module):
weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32") weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32")
bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32") bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32")
if module in ("ConvBn2d", "ConvBnRelu2d"): if module in ("ConvBn2d", "ConvBnRelu2d"):
normal_net.conv.weight.set_value(fake_quant(weight, weight_scale))
normal_net.conv.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
qat_net.conv.weight.set_value(weight)
qat_net.conv.bias.set_value(bias)
normal_net.conv.weight[...] = fake_quant_weight(weight, weight_scale)
normal_net.conv.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
qat_net.conv.weight[...] = Parameter(weight)
qat_net.conv.bias[...] = Parameter(bias)
else: else:
normal_net.weight.set_value(fake_quant(weight, weight_scale))
normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale))
qat_net.weight.set_value(weight)
qat_net.bias.set_value(bias)
normal_net.weight[...] = fake_quant_weight(weight, weight_scale)
normal_net.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale)
qat_net.weight[...] = Parameter(weight)
qat_net.bias[...] = Parameter(bias)


qat_from_float = getattr(QAT, module).from_float_module(normal_net) qat_from_float = getattr(QAT, module).from_float_module(normal_net)
qat_from_float.eval() qat_from_float.eval()
@@ -246,9 +257,9 @@ def test_conv(module):


normal = normal_net(x) normal = normal_net(x)
qat_without_fakequant = qat_from_float(x) qat_without_fakequant = qat_from_float(x)
fake_quant_normal = fake_quant(normal_net(x), act_scale)
fake_quant_normal = fake_quant_act(normal_net(x), act_scale)
qat = qat_net(x) qat = qat_net(x)
q = q_net(x_int8).numpy() * act_scale q = q_net(x_int8).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-6)
np.testing.assert_allclose(qat, fake_quant_normal)
np.testing.assert_allclose(q, fake_quant_normal.numpy())
np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5)
np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale)
np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale)

+ 11
- 12
imperative/python/test/unit/quantization/test_quantize.py View File

@@ -8,9 +8,8 @@
import numpy as np import numpy as np
import pytest import pytest


from megengine import functional
from megengine import Parameter, Tensor
from megengine import module as Float from megengine import module as Float
from megengine import tensor
from megengine.module import qat as QAT from megengine.module import qat as QAT
from megengine.module import quantized as Q from megengine.module import quantized as Q
from megengine.quantization import ( from megengine.quantization import (
@@ -40,7 +39,7 @@ class Net(Float.Module):
self.quant = Float.QuantStub() self.quant = Float.QuantStub()
self.linear = Float.Linear(3, 3) self.linear = Float.Linear(3, 3)
self.dequant = Float.DequantStub() self.dequant = Float.DequantStub()
self.linear.bias.set_value(np.random.rand(3))
self.linear.bias[...] = Parameter(np.random.rand(3))


def forward(self, x): def forward(self, x):
x = self.quant(x) x = self.quant(x)
@@ -55,7 +54,7 @@ class QATNet(Float.Module):
self.quant = QAT.QuantStub() self.quant = QAT.QuantStub()
self.linear = QAT.Linear(3, 3) self.linear = QAT.Linear(3, 3)
self.dequant = QAT.DequantStub() self.dequant = QAT.DequantStub()
self.linear.bias.set_value(np.random.rand(3))
self.linear.bias[...] = Parameter(np.random.rand(3))


def forward(self, x): def forward(self, x):
x = self.quant(x) x = self.quant(x)
@@ -90,12 +89,12 @@ def init_qat_net():
propagate_qconfig(net, min_max_fakequant_qconfig) propagate_qconfig(net, min_max_fakequant_qconfig)
min_val = np.random.randint(-127, 0, size=(3,)) min_val = np.random.randint(-127, 0, size=(3,))
max_val = np.random.randint(1, 127, size=(3,)) max_val = np.random.randint(1, 127, size=(3,))
net.quant.act_observer.min_val.set_value(min_val[0])
net.quant.act_observer.max_val.set_value(max_val[0])
net.linear.weight_observer.min_val.set_value(min_val[1])
net.linear.weight_observer.max_val.set_value(max_val[1])
net.linear.act_observer.min_val.set_value(min_val[2])
net.linear.act_observer.max_val.set_value(max_val[2])
net.quant.act_observer.min_val[...] = Parameter(min_val[0])
net.quant.act_observer.max_val[...] = Parameter(max_val[0])
net.linear.weight_observer.min_val[...] = Parameter(min_val[1])
net.linear.weight_observer.max_val[...] = Parameter(max_val[1])
net.linear.act_observer.min_val[...] = Parameter(min_val[2])
net.linear.act_observer.max_val[...] = Parameter(max_val[2])
return net return net




@@ -144,7 +143,7 @@ def init_observer(module, data):




def test_enable_and_disable_all(): def test_enable_and_disable_all():
x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
x = Tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
net = Net() net = Net()
y1 = net(x).numpy() y1 = net(x).numpy()
net = quantize_qat(net, min_max_fakequant_qconfig) net = quantize_qat(net, min_max_fakequant_qconfig)
@@ -180,7 +179,7 @@ def test_quantize():


def test_apply_easy_quant(): def test_apply_easy_quant():
qat_net = init_qat_net() qat_net = init_qat_net()
data = tensor(np.random.rand(2, 3, 3, 3), dtype=np.float32)
data = Tensor(np.random.rand(2, 3, 3, 3), dtype=np.float32)
eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False) eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)
apply_easy_quant(eq_net, data, 0.9, 1.1, 10) apply_easy_quant(eq_net, data, 0.9, 1.1, 10)
assert isinstance(eq_net.quant.act_observer, PassiveObserver) assert isinstance(eq_net.quant.act_observer, PassiveObserver)


Loading…
Cancel
Save