Merge pull request !5669 from hewei/fix_get_py_obj_dtypetags/v1.0.0
| @@ -178,12 +178,18 @@ def get_py_obj_dtype(obj): | |||||
| Type of MindSpore type. | Type of MindSpore type. | ||||
| """ | """ | ||||
| # Tensor | # Tensor | ||||
| if hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type): | |||||
| if hasattr(obj, 'shape') and hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type): | |||||
| return tensor_type(obj.dtype) | return tensor_type(obj.dtype) | ||||
| # Primitive or Cell | |||||
| if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'): | if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'): | ||||
| return function | return function | ||||
| if isinstance(obj, (typing.Type, type)): | |||||
| # mindspore type | |||||
| if isinstance(obj, typing.Type): | |||||
| return type_type | |||||
| # python type | |||||
| if isinstance(obj, type): | |||||
| return pytype_to_dtype(obj) | return pytype_to_dtype(obj) | ||||
| # others | |||||
| return pytype_to_dtype(type(obj)) | return pytype_to_dtype(type(obj)) | ||||
| @@ -19,7 +19,6 @@ import mindspore as ms | |||||
| import mindspore.ops.operations as P | import mindspore.ops.operations as P | ||||
| from mindspore import Tensor, context | from mindspore import Tensor, context | ||||
| from mindspore.common.api import ms_function | from mindspore.common.api import ms_function | ||||
| from mindspore.common.dtype import get_py_obj_dtype | |||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from ...ut_filter import non_graph_engine | from ...ut_filter import non_graph_engine | ||||
| @@ -90,7 +89,7 @@ def test_cast_grad(): | |||||
| def test_scalar_cast_grad(): | def test_scalar_cast_grad(): | ||||
| """ test_scalar_cast_grad """ | """ test_scalar_cast_grad """ | ||||
| input_x = 255.5 | input_x = 255.5 | ||||
| input_t = get_py_obj_dtype(ms.int8) | |||||
| input_t = ms.int8 | |||||
| def fx_cast(x): | def fx_cast(x): | ||||
| output = F.scalar_cast(x, input_t) | output = F.scalar_cast(x, input_t) | ||||
| @@ -23,7 +23,6 @@ from mindspore import context | |||||
| from mindspore.common import MetaTensor | from mindspore.common import MetaTensor | ||||
| from mindspore.common import dtype | from mindspore.common import dtype | ||||
| from mindspore.common.api import ms_function | from mindspore.common.api import ms_function | ||||
| from mindspore.common.dtype import get_py_obj_dtype | |||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from ..ut_filter import non_graph_engine | from ..ut_filter import non_graph_engine | ||||
| @@ -185,7 +184,7 @@ def test_input_signature(): | |||||
| def test_scalar_cast(): | def test_scalar_cast(): | ||||
| """ test_scalar_cast """ | """ test_scalar_cast """ | ||||
| input_x = 8.5 | input_x = 8.5 | ||||
| input_t = get_py_obj_dtype(ms.int64) | |||||
| input_t = ms.int64 | |||||
| @ms_function | @ms_function | ||||
| def fn_cast(x, t): | def fn_cast(x, t): | ||||