Browse Source

!5669 Fix get_py_obj_dtype() for mindspore type

Merge pull request !5669 from hewei/fix_get_py_obj_dtype
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
77dd91a646
3 changed files with 10 additions and 6 deletions
  1. +8
    -2
      mindspore/common/dtype.py
  2. +1
    -2
      tests/ut/python/pynative_mode/ops/test_grad.py
  3. +1
    -2
      tests/ut/python/pynative_mode/test_staging.py

+ 8
- 2
mindspore/common/dtype.py View File

@@ -178,12 +178,18 @@ def get_py_obj_dtype(obj):
Type of MindSpore type. Type of MindSpore type.
""" """
# Tensor # Tensor
if hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type):
if hasattr(obj, 'shape') and hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type):
return tensor_type(obj.dtype) return tensor_type(obj.dtype)
# Primitive or Cell
if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'): if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
return function return function
if isinstance(obj, (typing.Type, type)):
# mindspore type
if isinstance(obj, typing.Type):
return type_type
# python type
if isinstance(obj, type):
return pytype_to_dtype(obj) return pytype_to_dtype(obj)
# others
return pytype_to_dtype(type(obj)) return pytype_to_dtype(type(obj))






+ 1
- 2
tests/ut/python/pynative_mode/ops/test_grad.py View File

@@ -19,7 +19,6 @@ import mindspore as ms
import mindspore.ops.operations as P import mindspore.ops.operations as P
from mindspore import Tensor, context from mindspore import Tensor, context
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore.common.dtype import get_py_obj_dtype
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
from ...ut_filter import non_graph_engine from ...ut_filter import non_graph_engine
@@ -90,7 +89,7 @@ def test_cast_grad():
def test_scalar_cast_grad(): def test_scalar_cast_grad():
""" test_scalar_cast_grad """ """ test_scalar_cast_grad """
input_x = 255.5 input_x = 255.5
input_t = get_py_obj_dtype(ms.int8)
input_t = ms.int8


def fx_cast(x): def fx_cast(x):
output = F.scalar_cast(x, input_t) output = F.scalar_cast(x, input_t)


+ 1
- 2
tests/ut/python/pynative_mode/test_staging.py View File

@@ -23,7 +23,6 @@ from mindspore import context
from mindspore.common import MetaTensor from mindspore.common import MetaTensor
from mindspore.common import dtype from mindspore.common import dtype
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore.common.dtype import get_py_obj_dtype
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine
@@ -185,7 +184,7 @@ def test_input_signature():
def test_scalar_cast(): def test_scalar_cast():
""" test_scalar_cast """ """ test_scalar_cast """
input_x = 8.5 input_x = 8.5
input_t = get_py_obj_dtype(ms.int64)
input_t = ms.int64


@ms_function @ms_function
def fn_cast(x, t): def fn_cast(x, t):


Loading…
Cancel
Save