|
|
|
@@ -178,12 +178,18 @@ def get_py_obj_dtype(obj): |
|
|
|
Type of MindSpore type. |
|
|
|
""" |
|
|
|
# Tensor |
|
|
|
if hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type): |
|
|
|
if hasattr(obj, 'shape') and hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type): |
|
|
|
return tensor_type(obj.dtype) |
|
|
|
# Primitive or Cell |
|
|
|
if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'): |
|
|
|
return function |
|
|
|
if isinstance(obj, (typing.Type, type)): |
|
|
|
# mindspore type |
|
|
|
if isinstance(obj, typing.Type): |
|
|
|
return type_type |
|
|
|
# python type |
|
|
|
if isinstance(obj, type): |
|
|
|
return pytype_to_dtype(obj) |
|
|
|
# others |
|
|
|
return pytype_to_dtype(type(obj)) |
|
|
|
|
|
|
|
|
|
|
|
|