From 2ea8b9e738013a57ca8e6a11f52ebd1efb8ca72f Mon Sep 17 00:00:00 2001 From: buxue Date: Thu, 21 Jan 2021 14:12:50 +0800 Subject: [PATCH] add dict for isinstance --- mindspore/_extends/parse/standard_method.py | 8 ++++++-- mindspore/ccsrc/pybind_api/ir/dtype_py.cc | 3 +++ mindspore/common/dtype.py | 1 + .../python/pipeline/parse/test_isinstance.py | 18 +++++++++++++----- 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index bcf1cc4247..1f3d3ab058 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -362,6 +362,7 @@ def check_type_same(x_type, base_type): str: mstype.String, list: mstype.List, tuple: mstype.Tuple, + dict: mstype.Dict, Tensor: mstype.tensor_type, Parameter: mstype.ref_type } @@ -371,11 +372,14 @@ def check_type_same(x_type, base_type): if isinstance(base_type, tuple): target_type = tuple(pytype_to_mstype[i] for i in base_type) else: - target_type = pytype_to_mstype[base_type] + target_type = (pytype_to_mstype[base_type],) + if (isinstance(x_type, mstype.Bool) and mstype.Int in target_type) or \ + (isinstance(x_type, mstype.ref_type) and mstype.tensor_type in target_type): + return True return isinstance(x_type, target_type) except KeyError: raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, " - f"Tensor, Parameter, or a tuple only including these types, but got {base_type}") + f"Tensor, Parameter, or a tuple containing only these types, but got {base_type}") @constexpr diff --git a/mindspore/ccsrc/pybind_api/ir/dtype_py.cc b/mindspore/ccsrc/pybind_api/ir/dtype_py.cc index 7c0bc4366b..46bc8c50e2 100644 --- a/mindspore/ccsrc/pybind_api/ir/dtype_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/dtype_py.cc @@ -115,6 +115,9 @@ REGISTER_PYBIND_DEFINE( (void)py::class_>(m_sub, "Tuple") .def(py::init()) .def(py::init>(), py::arg("elements")); + (void)py::class_>(m_sub, "Dict") + .def(py::init()) + .def(py::init>>(), py::arg("key_values")); (void)py::class_>(m_sub, "TensorType") .def(py::init()) .def(py::init(), py::arg("element")) diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py index e3b8e7ae76..6c381793fb 100644 --- a/mindspore/common/dtype.py +++ b/mindspore/common/dtype.py @@ -104,6 +104,7 @@ Bool = typing.Bool String = typing.String List = typing.List Tuple = typing.Tuple +Dict = typing.Dict Slice = typing.Slice function_type = typing.Function Ellipsis_ = typing.TypeEllipsis diff --git a/tests/ut/python/pipeline/parse/test_isinstance.py b/tests/ut/python/pipeline/parse/test_isinstance.py index 9e54be0a24..8abe40afca 100644 --- a/tests/ut/python/pipeline/parse/test_isinstance.py +++ b/tests/ut/python/pipeline/parse/test_isinstance.py @@ -36,19 +36,26 @@ def test_isinstance(): self.list_member = list(self.tuple_member) self.weight = Parameter(1.0) self.empty_list = [] + self.dict_member = {"x": Tensor(np.arange(4)), "y": Tensor(np.arange(5))} + self.empty_dict = {} def construct(self, x, y): is_int = isinstance(self.int_member, int) is_float = isinstance(self.float_member, float) is_bool = isinstance(self.bool_member, bool) + bool_is_int = isinstance(self.bool_member, int) is_string = isinstance(self.string_member, str) is_parameter = isinstance(self.weight, Parameter) + parameter_is_tensor = isinstance(self.weight, Tensor) is_tensor_const = isinstance(self.tensor_member, Tensor) is_tensor_var = isinstance(x, Tensor) is_tuple_const = isinstance(self.tuple_member, tuple) is_tuple_var = isinstance((x, 1, 1.0, y), tuple) is_list_const = isinstance(self.list_member, list) is_list_var = isinstance([x, 1, 1.0, y], list) + is_dict_const = isinstance(self.dict_member, dict) + is_dict_var = isinstance({"x": x, "y": y}, dict) + is_empty_dic = isinstance(self.empty_dict, dict) is_list_or_tensor = isinstance([x, y], (Tensor, list)) is_int_or_float_or_tensor_or_tuple = isinstance(x, (Tensor, tuple, int, float)) float_is_int = isinstance(self.float_member, int) @@ -56,16 +63,17 @@ def test_isinstance(): tensor_is_tuple = isinstance(x, tuple) tuple_is_list = isinstance(self.tuple_member, list) is_empty_list = isinstance(self.empty_list, list) - return is_int, is_float, is_bool, is_string, \ - is_empty_list, is_parameter, is_tensor_const, is_tensor_var, \ - is_tuple_const, is_tuple_var, is_list_const, is_list_var, \ + return is_int, is_float, is_bool, bool_is_int, is_string, is_parameter, \ + parameter_is_tensor, is_tensor_const, is_tensor_var, \ + is_tuple_const, is_tuple_var, is_list_const, is_list_var, is_empty_list, \ + is_dict_const, is_dict_var, is_empty_dic, \ is_int_or_float_or_tensor_or_tuple, is_list_or_tensor, \ float_is_int, bool_is_string, tensor_is_tuple, tuple_is_list net = Net() x = Tensor(np.arange(4)) y = Tensor(np.arange(5)) - assert net(x, y) == (True,) * 14 + (False,) * 4 + assert net(x, y) == (True,) * 19 + (False,) * 4 def test_isinstance_not_supported(): @@ -81,7 +89,7 @@ def test_isinstance_not_supported(): with pytest.raises(TypeError) as err: net() assert "The second arg of 'isinstance' should be bool, int, float, str, list, tuple, Tensor, Parameter, " \ - "or a tuple only including these types, but got None" in str(err.value) + "or a tuple containing only these types, but got None" in str(err.value) def test_isinstance_second_arg_is_list():