GitOrigin-RevId: f0aaea99b9
tags/v1.9.0
@@ -6,11 +6,11 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import weakref
-from collections import OrderedDict
 from typing import Callable, Iterable, List, Union
 
 from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
 from ..core.autodiff.grad import Grad
+from ..core.tensor.dtype import is_differentible_dtype
 from ..logger import get_logger
 from ..tensor import Tensor
 from ..utils.future import Future
@@ -208,6 +208,10 @@ class GradManager:
         for x in tensors:
             assert isinstance(x, Tensor), "Object to be attached should be Tensor"
+            assert is_differentible_dtype(x.dtype), (
+                "Only tensors of floating point dtype can be attached to get gradients, "
+                "got tensor dtype: {} and shape: {}".format(x.dtype, x.shape)
+            )
             spec = self._attach_specs.get(id(x))
             new_attach = spec is None
             if spec is None:
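A minimal usage sketch of the new guard (assuming a local MegEngine install; tensor values are illustrative): attaching a float tensor succeeds, while a non-float tensor now trips the assertion at attach time rather than failing later during backward.

    import numpy as np
    import megengine as mge
    from megengine.autodiff import GradManager

    gm = GradManager()
    w = mge.tensor([1.0, 2.0], dtype=np.float32)
    gm.attach([w])                     # fine: float32 is differentiable

    x = mge.tensor([1, 2], dtype=np.int32)
    try:
        gm.attach([x])                 # rejected by the new dtype assertion
    except AssertionError as exc:
        print(exc)                     # message reports the offending dtype and shape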
@@ -38,6 +38,10 @@ def is_bfloat16(dtype):
     return dtype is bfloat16
 
 
+def is_differentible_dtype(dtype):
+    return dtype == np.float32 or dtype == np.float16 or is_bfloat16(dtype)
+
+
 # quantization dtype related
 # use namedtuple to make class immutable, comparable and easy to print
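For intuition, a hypothetical numpy-only stand-in for the helper (the real one also accepts MegEngine's custom bfloat16 dtype object via is_bfloat16, omitted here): only float32 and float16 pass, so float64 and all integer dtypes are rejected.

    import numpy as np

    def is_differentible_dtype_sketch(dtype):
        # mirrors the whitelist above, minus the bfloat16 identity check
        return dtype == np.float32 or dtype == np.float16

    assert is_differentible_dtype_sketch(np.float32)
    assert is_differentible_dtype_sketch(np.dtype("float16"))
    assert not is_differentible_dtype_sketch(np.int32)
    assert not is_differentible_dtype_sketch(np.float64)  # not whitelisted either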
@@ -114,7 +118,7 @@ def create_quantized_dtype(
     dtype_meta: QuantDtypeMeta, scale: float, zp: Union[int, None]
 ):
     r"""Get quantized dtype with metadata attribute according to _metadata_dict.
     Note that unsigned dtype must have ``zero_point`` and signed dtype must
     not have ``zero_point``, to be consistent with tensor generated by calling
     compiled function from `CompGraph.compile(inputs, outspec)`.
@@ -13,6 +13,7 @@ import numpy as np
 import pytest
 
 import megengine as mge
+import megengine.core.tensor.dtype as dtype
 import megengine.distributed as dist
 import megengine.functional as F
 import megengine.module as M
@@ -469,3 +470,18 @@ def test_2nd_grad_with_custom_gradient():
     np.testing.assert_almost_equal(
         x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5
     )
+
+
+@pytest.mark.parametrize("invalid_dtype", [np.uint8, np.int8, np.int32])
+def test_attach_invalid_tensor_dtype(invalid_dtype):
+    gm = GradManager()
+    x = mge.tensor([1], dtype=invalid_dtype)
+    with pytest.raises(AssertionError):
+        gm.attach([x])
+
+
+@pytest.mark.parametrize("differentible_dtype", [np.float32, np.float16])
+def test_attach_differentible_tensor_dtype(differentible_dtype):
+    gm = GradManager()
+    x = mge.tensor([1], dtype=differentible_dtype)
+    gm.attach([x])
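For context, a short end-to-end sketch (again assuming a working MegEngine install) of why attach matters: once a float tensor is attached, backward populates its grad.

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager

    gm = GradManager()
    x = mge.tensor(np.ones((3,), dtype=np.float32))
    gm.attach([x])
    with gm:                      # record the forward computation
        y = F.sum(x * 2.0)
        gm.backward(y)            # fills in x.grad for attached tensors
    print(x.grad.numpy())         # -> [2. 2. 2.]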