GitOrigin-RevId: 383458acbf
tags/v1.1.0
| @@ -62,6 +62,21 @@ def get_zero_point(dtype): | |||||
| return metadata["zero_point"] | return metadata["zero_point"] | ||||
| def is_equal(dt0, dt1): | |||||
| def _get_zero_point(dtype): | |||||
| assert is_quantize(dtype) | |||||
| metadata = dtype.metadata["mgb_dtype"] | |||||
| return metadata.get("zero_point") | |||||
| if is_quantize(dt0) and is_quantize(dt1): | |||||
| return get_scale(dt0) == get_scale(dt1) and _get_zero_point( | |||||
| dt0 | |||||
| ) == _get_zero_point(dt1) | |||||
| if not (is_quantize(dt0) or is_quantize(dt1)): | |||||
| return dt0 == dt1 | |||||
| return False | |||||
| def _check_zero_point(zp: int, dtype_str: str): | def _check_zero_point(zp: int, dtype_str: str): | ||||
| qmin = _metadata_dict[dtype_str].qmin | qmin = _metadata_dict[dtype_str].qmin | ||||
| qmax = _metadata_dict[dtype_str].qmax | qmax = _metadata_dict[dtype_str].qmax | ||||
| @@ -14,6 +14,7 @@ import numpy as np | |||||
| from ..ops import builtin | from ..ops import builtin | ||||
| from ..ops.special import Const | from ..ops.special import Const | ||||
| from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | ||||
| from .dtype import is_equal | |||||
| def dtype_promotion(inputs): | def dtype_promotion(inputs): | ||||
| @@ -112,7 +113,7 @@ def concatenate(inputs, axis=0, *, device=None): | |||||
| def astype(x, dtype): | def astype(x, dtype): | ||||
| dtype = np.dtype(dtype) | dtype = np.dtype(dtype) | ||||
| if x.dtype != dtype: | |||||
| if not is_equal(x.dtype, dtype): | |||||
| (x,) = apply(builtin.TypeCvt(param=dtype), x) | (x,) = apply(builtin.TypeCvt(param=dtype), x) | ||||
| return x | return x | ||||
| @@ -8,6 +8,7 @@ | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import numpy as np | import numpy as np | ||||
| from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | |||||
| from megengine.core.tensor.tensor_wrapper import TensorWrapper | from megengine.core.tensor.tensor_wrapper import TensorWrapper | ||||
| @@ -71,3 +72,17 @@ def test_transpose(): | |||||
| x = np.random.rand(2, 5).astype("float32") | x = np.random.rand(2, 5).astype("float32") | ||||
| xx = TensorWrapper(x) | xx = TensorWrapper(x) | ||||
| np.testing.assert_almost_equal(xx.T.numpy(), x.T) | np.testing.assert_almost_equal(xx.T.numpy(), x.T) | ||||
| def test_as_type(): | |||||
| x = TensorWrapper([1, 2, 3], dtype=np.float32) | |||||
| y = x.astype(qint8(0.1)) | |||||
| np.testing.assert_almost_equal(get_scale(y.dtype), 0.1) | |||||
| z = y.astype(qint8(0.2)) | |||||
| np.testing.assert_almost_equal(get_scale(z.dtype), 0.2) | |||||
| a = z.astype(quint8(0.3, 127)) | |||||
| np.testing.assert_almost_equal(get_scale(a.dtype), 0.3) | |||||
| np.testing.assert_equal(get_zero_point(a.dtype), 127) | |||||
| b = a.astype(quint8(0.3, 128)) | |||||
| np.testing.assert_almost_equal(get_scale(b.dtype), 0.3) | |||||
| np.testing.assert_equal(get_zero_point(b.dtype), 128) | |||||