GitOrigin-RevId: 0a94cb6b17
tags/v0.6.0
| @@ -235,6 +235,14 @@ class Tensor: | |||
| return self.__val.dtype | |||
| return self._symvar.dtype | |||
def set_dtype(self, dtype: str = None):
    r"""Set the data type of the tensor in place.

    :param dtype: target dtype descriptor; judging by the tests, either a
        numpy dtype string (e.g. ``"float32"``) or a quantized dtype such as
        ``mgb.dtype.qint8(scale)``.

    If the tensor holds an eager value, the value is converted with
    ``astype`` and re-wrapped as a new shared device tensor; otherwise, if it
    wraps a symbolic var, the var is replaced by its ``astype`` result.
    If neither is set, this is silently a no-op.
    """
    if self.__val is not None:
        # Eager path: cast the current value, then rebuild the shared
        # device tensor so the stored buffer actually carries the new dtype.
        self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy())
    elif self.__sym is not None:
        # Symbolic path: replace the var with a cast of itself.
        self.__sym = self.__sym.astype(dtype)
| @property | |||
| def _comp_node(self): | |||
| if self.__val is not None: | |||
| @@ -11,6 +11,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||
| import numpy as np | |||
| from .._internal.dtype import is_quantize | |||
| from ..core import Buffer, Parameter, Tensor | |||
| from ..logger import get_logger | |||
| @@ -460,6 +461,10 @@ class Module(metaclass=ABCMeta): | |||
| ), "param `{}` shape mismatch, should be {}, get {}".format( | |||
| k, var.shape, to_be_load.shape | |||
| ) | |||
| # For quantized dtype, the initialized dtype | |||
| # scale/zero_points maybe invalid, use pretrained dtype instead. | |||
| if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | |||
| var.set_dtype(to_be_load.dtype) | |||
| var.set_value(to_be_load) | |||
| loaded.append(k) | |||
| @@ -10,6 +10,7 @@ import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine._internal as mgb | |||
| def test_wrong_dtype(): | |||
| @@ -26,3 +27,48 @@ def test_tensor_routine(): | |||
| mge.tensor([1]) | |||
| mge.tensor(1.5) | |||
def test_tensor_set_dtype():
    """Exercise ``Tensor.set_dtype`` on Parameter, Buffer and derived tensors.

    ``set_dtype`` must rescale the stored value so the dequantized ("real")
    value is preserved: float 1.0 stored under qint8 with scale 0.1 becomes
    raw integer 10; re-quantizing from scale 1 to scale 0.3 turns raw 1
    into raw 3 (both still represent ~1.0 * old_scale).
    """

    def check_dtype_value(tensor, dtype_scale, value):
        # For quantized dtypes, first confirm the dtype carries the
        # expected scale, then compare the raw stored value.
        if mgb.dtype.is_quantize(tensor.dtype):
            if np.abs(mgb.dtype.get_scale(tensor.dtype) - dtype_scale) > 1e-5:
                raise AssertionError(
                    "compare scale failed expect {} got {}".format(
                        dtype_scale, mgb.dtype.get_scale(tensor.dtype)
                    )
                )
        if np.abs(tensor.numpy()[0][0] - value) > 1e-5:
            raise AssertionError(
                # Fixed: expected/actual were swapped in the message
                # (the scale branch above formats (expected, actual)).
                "compare value failed expect {} got {}".format(
                    value, tensor.numpy()[0][0]
                )
            )

    # Parameter: float -> quantized, and re-quantize scale 1 -> 0.3.
    t = mge.Parameter(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(t, 0.1, 10)

    t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    check_dtype_value(t, 0.3, 3)

    # Buffer: same two conversions.
    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(t, 0.1, 10)

    t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    check_dtype_value(t, 0.3, 3)

    # Symbolic path: tensors derived from an expression (t + 1).
    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.2))
    check_dtype_value(s, 0.2, 10)

    # Changing the source's dtype propagates through the expression:
    # t becomes raw 3 (scale 0.3), so t + 1 dequantizes to 1.8,
    # which under scale 0.1 is raw 18, and back to float32 is 1.8.
    t.set_dtype(mgb.dtype.qint8(0.3))
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(s, 0.1, 18)
    s.set_dtype("float32")
    check_dtype_value(s, 0, 1.8)
| @@ -14,8 +14,10 @@ import pytest | |||
| from helpers import MLP | |||
| import megengine as mge | |||
| import megengine._internal as mgb | |||
| from megengine.core import Buffer, Parameter, Tensor, tensor | |||
| from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential | |||
| from megengine.quantization.quantize import quantize, quantize_qat | |||
| from megengine.test import assertTensorClose | |||
| @@ -347,3 +349,38 @@ def test_dump_model(): | |||
| pred = mlp(data) | |||
| with tempfile.NamedTemporaryFile() as f: | |||
| mge.dump(pred, f.name) | |||
def test_load_quantized():
    """Round-trip a quantized MLP through save/load and compare predictions.

    Quantizes an MLP, records its predictions, snapshots its state_dict,
    perturbs the weight scales, then reloads the checkpoint. The reloaded
    model must reproduce the original predictions, which requires
    ``load_state_dict`` to restore the pretrained quantized dtypes
    (scales), not just the raw values.
    """
    shape = (2, 28)
    inp = tensor(np.random.random(shape), dtype="float32").astype(
        mgb.dtype.qint8(0.1)
    )

    net = MLP()
    quantize_qat(net)
    quantize(net)
    # Give the two dense layers distinct, known weight scales.
    for layer, scale in ((net.dense0, 0.001), (net.dense1, 0.0002)):
        layer.weight = Parameter(layer.weight.astype(mgb.dtype.qint8(scale)).numpy())
    net.eval()
    expected = net(inp)

    with BytesIO() as buf:
        mge.save(net.state_dict(), buf)
        buf.seek(0)
        checkpoint = mge.load(buf)

    # Perturb the weight scales so a successful load must restore them.
    for layer, scale in ((net.dense0, 0.00001), (net.dense1, 0.2)):
        layer.weight = Parameter(layer.weight.astype(mgb.dtype.qint8(scale)).numpy())
    net.load_state_dict(checkpoint)
    actual = net(inp)

    assertTensorClose(
        expected.astype("float32").numpy(),
        actual.astype("float32").numpy(),
        max_err=5e-6,
    )