GitOrigin-RevId: 9d09c8fa6f
@@ -14,7 +14,7 @@ import numpy as np
 from ..ops import builtin
 from ..ops.special import Const
 from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
-from .dtype import is_equal
+from .dtype import is_equal, is_quantize
 
 
 def dtype_promotion(inputs):
@@ -122,7 +122,7 @@ def convert_single_value(v, inputs, *, dtype=None, device=None):
     tensors = [i for i in inputs if isinstance(i, (TensorBase, TensorWrapperBase))]
     assert len(tensors) > 0
     if isinstance(v, (TensorWrapperBase, TensorBase)):
-        v = astype(v, dtype)
+        v = astype(v, v.dtype if is_quantize(v.dtype) else dtype)
     else:
         (v,) = Const(v, dtype=dtype, device=device)(*tensors)
     return v
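Why the quantized dtype is kept here: a quantized dtype carries scale/zero_point metadata, and casting such a tensor to the promoted dtype would silently drop that metadata. Below is a minimal sketch of the new selection rule; `pick_dtype` is a hypothetical helper written only to mirror the one-line change, not part of MegEngine's API.

```python
import numpy as np
from megengine.core.tensor.dtype import is_quantize, qint8

def pick_dtype(value_dtype, promoted_dtype):
    # Mirrors the changed line: v.dtype if is_quantize(v.dtype) else dtype.
    # Quantized dtypes keep their scale/zero_point; plain dtypes get promoted.
    return value_dtype if is_quantize(value_dtype) else promoted_dtype

q = qint8(0.5)
assert pick_dtype(q, np.float32) is q                           # quantized: preserved
assert pick_dtype(np.dtype("int8"), np.float32) is np.float32   # plain: promoted
```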
@@ -12,7 +12,6 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
 import numpy as np
 
-from ..core.tensor.dtype import is_quantize
 from ..core.tensor.utils import make_shape_tuple
 from ..logger import get_logger
 from ..tensor import Parameter, Tensor
 
| @@ -529,11 +528,7 @@ class Module(metaclass=ABCMeta): | |||||
| ), "param `{}` shape mismatch, should be {}, get {}".format( | ), "param `{}` shape mismatch, should be {}, get {}".format( | ||||
| k, var.shape, to_be_load.shape | k, var.shape, to_be_load.shape | ||||
| ) | ) | ||||
| # For quantized dtype, the initialized dtype | |||||
| # scale/zero_points maybe invalid, use pretrained dtype instead. | |||||
| if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | |||||
| var = var.astype(to_be_load.dtype) | |||||
| var._reset(to_be_load) | |||||
| var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device)) | |||||
| loaded.append(k) | loaded.append(k) | ||||
| return set(loaded), set(skipped) | return set(loaded), set(skipped) | ||||
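The loader no longer special-cases quantized parameters: rebuilding the wrapper from the checkpoint value takes the pretrained dtype (quantization scale/zero_point included) and the module's device in one step. A minimal sketch of the rebuilt call, assuming a toy float parameter and a qint8 checkpoint payload; `var` and `to_be_load` are illustrative stand-ins for one state-dict entry, and the snippet presumes a working MegEngine install.

```python
import numpy as np
from megengine import Parameter
from megengine.core.tensor.dtype import qint8

# Toy stand-ins: `var` is the freshly initialized module parameter,
# `to_be_load` one pretrained value from the state dict (here qint8-tagged
# via numpy's dtype-metadata view).
var = Parameter(np.zeros((2, 3), dtype="float32"))
to_be_load = np.arange(6, dtype=np.int8).reshape(2, 3).view(qint8(0.5))

# The rewritten line: rebuild the wrapper from the checkpoint value, so the
# pretrained dtype wins while the parameter stays on the module's device.
var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device))
```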
@@ -10,6 +10,7 @@ import numpy as np
 
 import megengine.functional as F
 from megengine import tensor
+from megengine.core.tensor import dtype
 from megengine.functional.elemwise import _elwise
 
 
@@ -150,3 +151,18 @@ def test_logical_oprs():
     np.testing.assert_equal(x & y, F.logical_and(xx, yy).numpy())
     np.testing.assert_equal(x | y, F.logical_or(xx, yy).numpy())
     np.testing.assert_equal(x ^ y, F.logical_xor(xx, yy).numpy())
+
+
+def test_qadd():
+    inp_scale = 0.5
+    outp_scale = 0.2
+    x = np.arange(6).reshape(2, 3).astype("float32")
+    y = np.arange(6).reshape(2, 3).astype("float32")
+    x = tensor(x, dtype=dtype.qint8(inp_scale))
+    y = tensor(y, dtype=dtype.qint8(inp_scale))
+    result_mge = F.elemwise._elemwise_multi_type(
+        x, y, mode="QADD", dtype=dtype.qint8(outp_scale)
+    )
+    result_mge = result_mge.astype("float32").numpy()
+    result_expect = x.astype("float32").numpy() + y.astype("float32").numpy()
+    np.testing.assert_almost_equal(result_mge, result_expect, decimal=6)
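The chosen values make the check exact: every input is an integer multiple of inp_scale = 0.5 and every sum is a multiple of outp_scale = 0.2, so quantization introduces no rounding error and `decimal=6` holds. Below is a rough numpy-only model of the quantize/dequantize round trip the test exercises, assuming round-to-nearest and int8 saturation; the actual QADD kernel's rounding rules are not spelled out here.

```python
import numpy as np

def quantize(arr, scale, qmin=-128, qmax=127):
    # float -> qint8: divide by scale, round, saturate to the int8 range
    return np.clip(np.round(arr / scale), qmin, qmax).astype(np.int8)

def dequantize(q, scale):
    # qint8 -> float: undo the scaling
    return q.astype(np.float32) * scale

inp_scale, outp_scale = 0.5, 0.2
x = np.arange(6, dtype=np.float32).reshape(2, 3)
y = x.copy()

# Modeled QADD: add the dequantized inputs, re-quantize at the output scale.
qsum = quantize(dequantize(quantize(x, inp_scale), inp_scale)
                + dequantize(quantize(y, inp_scale), inp_scale), outp_scale)
np.testing.assert_almost_equal(dequantize(qsum, outp_scale), x + y, decimal=6)
```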