Browse Source

!5669 Fix get_py_obj_dtype() for mindspore type

Merge pull request !5669 from hewei/fix_get_py_obj_dtype
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
77dd91a646
3 changed files with 10 additions and 6 deletions
  1. +8
    -2
      mindspore/common/dtype.py
  2. +1
    -2
      tests/ut/python/pynative_mode/ops/test_grad.py
  3. +1
    -2
      tests/ut/python/pynative_mode/test_staging.py

+ 8
- 2
mindspore/common/dtype.py View File

@@ -178,12 +178,18 @@ def get_py_obj_dtype(obj):
Type of MindSpore type.
"""
# Tensor
if hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type):
if hasattr(obj, 'shape') and hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type):
return tensor_type(obj.dtype)
# Primitive or Cell
if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
return function
if isinstance(obj, (typing.Type, type)):
# mindspore type
if isinstance(obj, typing.Type):
return type_type
# python type
if isinstance(obj, type):
return pytype_to_dtype(obj)
# others
return pytype_to_dtype(type(obj))




+ 1
- 2
tests/ut/python/pynative_mode/ops/test_grad.py View File

@@ -19,7 +19,6 @@ import mindspore as ms
import mindspore.ops.operations as P
from mindspore import Tensor, context
from mindspore.common.api import ms_function
from mindspore.common.dtype import get_py_obj_dtype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from ...ut_filter import non_graph_engine
@@ -90,7 +89,7 @@ def test_cast_grad():
def test_scalar_cast_grad():
""" test_scalar_cast_grad """
input_x = 255.5
input_t = get_py_obj_dtype(ms.int8)
input_t = ms.int8

def fx_cast(x):
output = F.scalar_cast(x, input_t)


+ 1
- 2
tests/ut/python/pynative_mode/test_staging.py View File

@@ -23,7 +23,6 @@ from mindspore import context
from mindspore.common import MetaTensor
from mindspore.common import dtype
from mindspore.common.api import ms_function
from mindspore.common.dtype import get_py_obj_dtype
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from ..ut_filter import non_graph_engine
@@ -185,7 +184,7 @@ def test_input_signature():
def test_scalar_cast():
""" test_scalar_cast """
input_x = 8.5
input_t = get_py_obj_dtype(ms.int64)
input_t = ms.int64

@ms_function
def fn_cast(x, t):


Loading…
Cancel
Save