Browse Source

improve isinstance function in graph mode

tags/v1.2.0-rc1
buxue 4 years ago
parent
commit
15593dc98f
4 changed files with 28 additions and 7 deletions
  1. +6
    -3
      mindspore/_extends/parse/standard_method.py
  2. +1
    -0
      mindspore/common/dtype.py
  3. +1
    -1
      mindspore/ops/composite/multitype_ops/zeros_like_impl.py
  4. +20
    -3
      tests/ut/python/pipeline/parse/test_isinstance.py

+ 6
- 3
mindspore/_extends/parse/standard_method.py View File

@@ -18,7 +18,7 @@


from dataclasses import dataclass from dataclasses import dataclass


from mindspore import Tensor
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype from mindspore import dtype as mstype


from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
@@ -361,10 +361,13 @@ def check_type_same(x_type, base_type):
str: mstype.String, str: mstype.String,
list: mstype.List, list: mstype.List,
tuple: mstype.Tuple, tuple: mstype.Tuple,
Tensor: mstype.tensor_type
Tensor: mstype.tensor_type,
Parameter: mstype.ref_type
} }
try: try:
if isinstance(base_type, (tuple, list)):
if isinstance(base_type, list):
raise TypeError("The second arg of 'isinstance' must be a type or a tuple of types, but got a list")
if isinstance(base_type, tuple):
target_type = tuple(pytype_to_mstype[i] for i in base_type) target_type = tuple(pytype_to_mstype[i] for i in base_type)
else: else:
target_type = pytype_to_mstype[base_type] target_type = pytype_to_mstype[base_type]


+ 1
- 0
mindspore/common/dtype.py View File

@@ -111,6 +111,7 @@ none_type = typing.TypeNone
env_type_type = typing.EnvType env_type_type = typing.EnvType
tensor_type = typing.TensorType tensor_type = typing.TensorType
anything_type = typing.TypeAnything anything_type = typing.TypeAnything
ref_type = typing.RefType


number_type = (int8, number_type = (int8,
int16, int16,


+ 1
- 1
mindspore/ops/composite/multitype_ops/zeros_like_impl.py View File

@@ -26,7 +26,7 @@ using ".register" decorator.




@zeros_like_leaf.register("Number") @zeros_like_leaf.register("Number")
def _zeros_like_scala(x):
def _zeros_like_scalar(x):
"""Returns 0 which has the same dtype as x where x is a scalar.""" """Returns 0 which has the same dtype as x where x is a scalar."""
return 0 return 0




+ 20
- 3
tests/ut/python/pipeline/parse/test_isinstance.py View File

@@ -17,7 +17,7 @@ import numpy as np
import pytest import pytest


import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Tensor, Parameter
from mindspore import context from mindspore import context


context.set_context(mode=context.GRAPH_MODE, save_graphs=True) context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
@@ -34,12 +34,14 @@ def test_isinstance():
self.tensor_member = Tensor(np.arange(4)) self.tensor_member = Tensor(np.arange(4))
self.tuple_member = (1, 1.0, True, "abcd", self.tensor_member) self.tuple_member = (1, 1.0, True, "abcd", self.tensor_member)
self.list_member = list(self.tuple_member) self.list_member = list(self.tuple_member)
self.weight = Parameter(1.0)


def construct(self, x, y): def construct(self, x, y):
is_int = isinstance(self.int_member, int) is_int = isinstance(self.int_member, int)
is_float = isinstance(self.float_member, float) is_float = isinstance(self.float_member, float)
is_bool = isinstance(self.bool_member, bool) is_bool = isinstance(self.bool_member, bool)
is_string = isinstance(self.string_member, str) is_string = isinstance(self.string_member, str)
is_parameter = isinstance(self.weight, Parameter)
is_tensor_const = isinstance(self.tensor_member, Tensor) is_tensor_const = isinstance(self.tensor_member, Tensor)
is_tensor_var = isinstance(x, Tensor) is_tensor_var = isinstance(x, Tensor)
is_tuple_const = isinstance(self.tuple_member, tuple) is_tuple_const = isinstance(self.tuple_member, tuple)
@@ -52,7 +54,7 @@ def test_isinstance():
bool_is_string = isinstance(self.bool_member, str) bool_is_string = isinstance(self.bool_member, str)
tensor_is_tuple = isinstance(x, tuple) tensor_is_tuple = isinstance(x, tuple)
tuple_is_list = isinstance(self.tuple_member, list) tuple_is_list = isinstance(self.tuple_member, list)
return is_int, is_float, is_bool, is_string, is_tensor_const, is_tensor_var, \
return is_int, is_float, is_bool, is_string, is_parameter, is_tensor_const, is_tensor_var, \
is_tuple_const, is_tuple_var, is_list_const, is_list_var, \ is_tuple_const, is_tuple_var, is_list_const, is_list_var, \
is_int_or_float_or_tensor_or_tuple, is_list_or_tensor, \ is_int_or_float_or_tensor_or_tuple, is_list_or_tensor, \
float_is_int, bool_is_string, tensor_is_tuple, tuple_is_list float_is_int, bool_is_string, tensor_is_tuple, tuple_is_list
@@ -60,7 +62,7 @@ def test_isinstance():
net = Net() net = Net()
x = Tensor(np.arange(4)) x = Tensor(np.arange(4))
y = Tensor(np.arange(5)) y = Tensor(np.arange(5))
assert net(x, y) == (True,) * 12 + (False,) * 4
assert net(x, y) == (True,) * 13 + (False,) * 4




def test_isinstance_not_supported(): def test_isinstance_not_supported():
@@ -76,3 +78,18 @@ def test_isinstance_not_supported():
with pytest.raises(TypeError) as err: with pytest.raises(TypeError) as err:
net() net()
assert "The type 'None' is not supported for 'isinstance'" in str(err.value) assert "The type 'None' is not supported for 'isinstance'" in str(err.value)


def test_isinstance_second_arg_is_list():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = (11, 22, 33, 44)

def construct(self):
return isinstance(self.value, [tuple, int, float])

net = Net()
with pytest.raises(TypeError) as err:
net()
assert "The second arg of 'isinstance' must be a type or a tuple of types, but got a list" in str(err.value)

Loading…
Cancel
Save