diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py index afd080ad95..ba2fe6f146 100644 --- a/mindspore/_extends/parse/resources.py +++ b/mindspore/_extends/parse/resources.py @@ -17,14 +17,14 @@ """Resources for ast tree parse.""" import ast import math + from mindspore import RowTensor, SparseTensor -from mindspore.ops.composite import multitype_ops from mindspore.ops import functional as F, composite as C +from mindspore.ops.composite import multitype_ops from . import standard_method as M from . import trope as T from .namespace import CellNamespace - # namespace define functional_ns = CellNamespace('mindspore.ops.functional') composite_ns = CellNamespace('mindspore.ops.composite') @@ -109,7 +109,7 @@ convert_object_map = { # system function T.len: M.ms_len, - T.bool: M.bool_, + T.bool_: M.bool_, T.map: C.Map(), T.partial: F.partial, T.zip: C.zip_operation, diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index 8ab5da6545..e30777b7f7 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -16,13 +16,15 @@ # ============================================================================ """standard_method""" from dataclasses import dataclass -from mindspore.common import dtype as mstype + +from mindspore import Tensor +from mindspore import dtype as mstype from ...ops import functional as F from ...ops import operations as P -from ...ops.primitive import constexpr from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \ zeros_like, ones_like from ...ops.composite.base import _append +from ...ops.primitive import constexpr __all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like'] @@ -219,9 +221,23 @@ def while_cond(x): @constexpr def check_type_same(x_type, base_type): """Check x_type is same as base_type.""" - if mstype.issubclass_(x_type, base_type): - return True - return False + pytype_to_mstype = { + bool: mstype.Bool, + int: mstype.Int, + float: mstype.Float, + str: mstype.String, + list: mstype.List, + tuple: mstype.Tuple, + Tensor: mstype.tensor_type + } + try: + if isinstance(base_type, (tuple, list)): + target_type = tuple(pytype_to_mstype[i] for i in base_type) + else: + target_type = pytype_to_mstype[base_type] + return isinstance(x_type, target_type) + except KeyError: + raise TypeError(f"The type '{base_type}' is not supported for 'isinstance'") @constexpr @@ -235,7 +251,7 @@ def check_is_tensor(x): @constexpr def check_is_tuple_or_list_or_tensor(x, op_name, arg_name): """check whether x is list or tuple or tensor.""" - if isinstance(x, (mstype.list_type, mstype.tuple_type, mstype.tensor_type)): + if isinstance(x, (mstype.List, mstype.Tuple, mstype.tensor_type)): return True raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.") diff --git a/mindspore/_extends/parse/trope.py b/mindspore/_extends/parse/trope.py index 674715ef59..48e42ca232 100644 --- a/mindspore/_extends/parse/trope.py +++ b/mindspore/_extends/parse/trope.py @@ -95,3 +95,7 @@ def not_contains(x): # pragma: no cover def while_cond(x): # pragma: no cover """Not in function.""" raise RuntimeError('This operation is not meant to be called directly.') + +def bool_(x): # pragma: no cover + """judge true function.""" + raise RuntimeError('This operation is not meant to be called directly.') diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index 3873dc68bf..d38a7d0436 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -116,7 +116,7 @@ const char NAMED_PRIMITIVE_NEXT[] = "next"; const char NAMED_PRIMITIVE_GETITEM[] = "getitem"; const char NAMED_PRIMITIVE_SETITEM[] = "setitem"; const char NAMED_PRIMITIVE_HASNEXT[] = "hasnext"; -const char NAMED_PRIMITIVE_BOOL[] = "bool"; // bool: P.identity +const char NAMED_PRIMITIVE_BOOL[] = "bool_"; // bool: P.identity const char NAMED_PRIMITIVE_MAKETUPLE[] = "make_tuple"; const char NAMED_PRIMITIVE_MAKELIST[] = "make_list"; const char NAMED_PRIMITIVE_MAKESLICE[] = "make_slice"; diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.h b/mindspore/ccsrc/pipeline/jit/parse/resolve.h index db937daebf..e1effacab2 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.h +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.h @@ -109,6 +109,7 @@ class ClassType : public PyObjectWrapper { MS_DECLARE_PARENT(ClassType, PyObjectWrapper); abstract::AbstractBasePtr ToAbstract() override; }; +using ClassTypePtr = std::shared_ptr; // SymbolResolver class for resolving symbol extracted from AnfNode. class SymbolResolver { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 535ed4f3bf..830e94ee90 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -280,24 +280,20 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) { py::tuple max_shape_tuple(len); auto dic = py::dict(); bool dyn_shape = false; - bool is_build_value = true; + bool dyn_value = false; for (size_t i = 0; i < len; i++) { auto arg = arg_tuple->elements()[i]; py::dict out = ConvertAbstractToPython(arg); shape_tuple[i] = out[ATTR_SHAPE]; dtype_tuple[i] = out[ATTR_DTYPE]; + value_tuple[i] = out[ATTR_VALUE]; // Elements in tuple is tensor shape value. if (out.contains(py::str(ATTR_MIN_VALUE)) && out.contains(py::str(ATTR_MAX_VALUE))) { - value_tuple[i] = out[ATTR_VALUE]; min_value_tuple[i] = out[ATTR_MIN_VALUE]; max_value_tuple[i] = out[ATTR_MAX_VALUE]; - is_build_value = false; - } else { - value_tuple[i] = BuildValue(arg->BuildValue()); - min_value_tuple[i] = value_tuple[i]; - max_value_tuple[i] = value_tuple[i]; + dyn_value = true; } // Elements in tuple is tensor, which shape is dynamic. @@ -305,21 +301,21 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) { min_shape_tuple[i] = out[ATTR_MIN_SHAPE]; max_shape_tuple[i] = out[ATTR_MAX_SHAPE]; dyn_shape = true; - } else { - min_shape_tuple[i] = out[ATTR_SHAPE]; - max_shape_tuple[i] = out[ATTR_SHAPE]; } } + dic[ATTR_SHAPE] = shape_tuple; dic[ATTR_DTYPE] = dtype_tuple; - if (is_build_value) { - dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue()); + if (arg_tuple->BuildValue()->isa()) { + dic[ATTR_VALUE] = py::none(); } else { dic[ATTR_VALUE] = value_tuple; + } + + if (dyn_value) { dic[ATTR_MIN_VALUE] = min_value_tuple; dic[ATTR_MAX_VALUE] = max_value_tuple; } - if (dyn_shape) { dic[ATTR_MIN_SHAPE] = min_shape_tuple; dic[ATTR_MAX_SHAPE] = max_shape_tuple; @@ -333,6 +329,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) { size_t len = arg_list->size(); py::list shape_list(len); py::list dtype_list(len); + py::list value_list(len); py::list min_shape_list(len); py::list max_shape_list(len); auto dic = py::dict(); @@ -342,27 +339,29 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) { py::dict out = ConvertAbstractToPython(arg_list->elements()[i]); shape_list[i] = out[ATTR_SHAPE]; dtype_list[i] = out[ATTR_DTYPE]; + value_list[i] = out[ATTR_VALUE]; // Elements in list is tensor, which shape is dynamic. if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) { min_shape_list[i] = out[ATTR_MIN_SHAPE]; max_shape_list[i] = out[ATTR_MAX_SHAPE]; dyn_shape = true; - } else { - min_shape_list[i] = out[ATTR_SHAPE]; - max_shape_list[i] = out[ATTR_SHAPE]; } } + dic[ATTR_SHAPE] = shape_list; + dic[ATTR_DTYPE] = dtype_list; + if (arg_list->BuildValue()->isa()) { + dic[ATTR_VALUE] = py::none(); + } else { + dic[ATTR_VALUE] = value_list; + } + if (dyn_shape) { dic[ATTR_MIN_SHAPE] = min_shape_list; dic[ATTR_MAX_SHAPE] = max_shape_list; } - dic[ATTR_SHAPE] = shape_list; - dic[ATTR_DTYPE] = dtype_list; - dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue()); - return dic; } } // end anonymous namespace @@ -428,6 +427,16 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { dic[ATTR_SHAPE] = py::none(); dic[ATTR_DTYPE] = abs_base->BuildType(); dic[ATTR_VALUE] = py::none(); + if (abs_base->isa()) { + AbstractBasePtrList args = abs_base->cast()->args(); + if (!args.empty()) { + auto value = args[0]->BuildValue()->cast(); + if (value != nullptr) { + dic[ATTR_DTYPE] = std::make_shared(); + dic[ATTR_VALUE] = value->obj(); + } + } + } } else if (abs_base->isa()) { auto arg = dyn_cast(abs_base); dic[ATTR_SHAPE] = py::none(); diff --git a/mindspore/ccsrc/utils/convert_utils_py.cc b/mindspore/ccsrc/utils/convert_utils_py.cc index b78c4f9398..0dadae5cae 100644 --- a/mindspore/ccsrc/utils/convert_utils_py.cc +++ b/mindspore/ccsrc/utils/convert_utils_py.cc @@ -390,8 +390,8 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py } return PyList2DynamicShapeTensor(shape_obj, type_obj, output); } else if (py::isinstance(shape_obj) && py::isinstance(type_obj)) { - py::tuple shape_tuple = shape_obj.cast(); - py::tuple typeid_tuple = type_obj.cast(); + auto shape_tuple = shape_obj.cast(); + auto typeid_tuple = type_obj.cast(); AbstractBasePtrList ptr_list; for (size_t it = 0; it < shape_tuple.size(); ++it) { auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]); @@ -400,8 +400,8 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py auto tuple = std::make_shared(ptr_list); return tuple; } else if (py::isinstance(shape_obj) && py::isinstance(type_obj)) { - py::list shape_list = shape_obj.cast(); - py::list typeid_list = type_obj.cast(); + auto shape_list = shape_obj.cast(); + auto typeid_list = type_obj.cast(); AbstractBasePtrList ptr_list; for (size_t it = 0; it < shape_list.size(); ++it) { auto tensor_it = PyListDtype2AbstractTensor(shape_list[it], typeid_list[it]); diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py index 3a71e66091..338ebb0227 100644 --- a/mindspore/common/dtype.py +++ b/mindspore/common/dtype.py @@ -78,36 +78,39 @@ single = float32 float64 = typing.Float(64) double = float64 +number = typing.Number() int_ = typing.Int() uint = typing.UInt() float_ = typing.Float() -number = typing.Number() - +string = typing.String() list_ = typing.List() tuple_ = typing.Tuple() +type_none = typing.TypeNone() + tensor = typing.TensorType() +index_slices = typing.RowTensorType() +sparse_tensor = typing.SparseTensorType() +undetermined = typing.UndeterminedType() + function = typing.Function() -function_type = typing.Function symbolic_key = typing.SymbolicKeyType() env_type = typing.EnvType() -env_type_type = typing.EnvType type_type = typing.TypeType() -type_none = typing.TypeNone() -type_bool = typing.Bool() -string = typing.String() type_refkey = typing.RefKeyType() -tensor_type = typing.TensorType -anything_type = typing.TypeAnything -slice_type = typing.Slice -ellipsis_type = typing.TypeEllipsis -list_type = typing.List -tuple_type = typing.Tuple -index_slices = typing.RowTensorType() -sparse_tensor = typing.SparseTensorType() -undetermined = typing.UndeterminedType() + Int = typing.Int -bool_type = typing.Bool +Float = typing.Float +Bool = typing.Bool +String = typing.String +List = typing.List +Tuple = typing.Tuple +Slice = typing.Slice +function_type = typing.Function +Ellipsis_ = typing.TypeEllipsis none_type = typing.TypeNone +env_type_type = typing.EnvType +tensor_type = typing.TensorType +anything_type = typing.TypeAnything number_type = (int8, int16, diff --git a/mindspore/core/abstract/abstract_function.h b/mindspore/core/abstract/abstract_function.h index c0eea99d40..ac18fa689f 100644 --- a/mindspore/core/abstract/abstract_function.h +++ b/mindspore/core/abstract/abstract_function.h @@ -86,6 +86,8 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom { std::string ToString() const override { return "Prim: " + prim_->name(); } + ValuePtr RealBuildValue() const override { return prim_; } + private: PrimitivePtr prim_; // store it as weak_ptr to break reference cycle. @@ -183,6 +185,7 @@ class PartialAbstractClosure : public AbstractFuncAtom { AbstractFunctionPtr fn() { return fn_; } AbstractBasePtrList args() { return args_spec_list_; } + ValuePtr RealBuildValue() const override { return fn_->BuildValue(); } AnfNodePtr node() { return node_.lock(); } void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); } AbstractFunctionPtr Copy() const override { @@ -199,6 +202,7 @@ class PartialAbstractClosure : public AbstractFuncAtom { // The CNode which this PartialAbstractClosure evaluated from. AnfNodeWeakPtr node_; }; +using PartialAbstractClosurePtr = std::shared_ptr; class JTransformedAbstractClosure : public AbstractFuncAtom { public: diff --git a/mindspore/numpy/utils.py b/mindspore/numpy/utils.py index 1c3db5ca64..76bd383b12 100644 --- a/mindspore/numpy/utils.py +++ b/mindspore/numpy/utils.py @@ -339,13 +339,13 @@ def _cpu_not_support(name): @constexpr def _check_is_tuple(obj): """Check whether obj is a tuple""" - return isinstance(obj, mstype.tuple_type) + return isinstance(obj, mstype.Tuple) @constexpr def _check_is_list(obj): """Check whether obj is a list""" - return isinstance(obj, mstype.list_type) + return isinstance(obj, mstype.List) @constexpr diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 113d71afd3..c2262af9a1 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -148,7 +148,7 @@ def _expand_data_dims_with_bool(data, tuple_index, op_name): bool_positions, tuple_index_without_bool = (), () for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)): - bool_type_tag = const_utils.judge_index_type(index_type, mstype.type_bool) + bool_type_tag = const_utils.judge_index_type(index_type, mstype.bool_) if bool_type_tag: if index: tuple_index_without_bool += (const_utils.make_tensor([0], mstype.int64),) @@ -653,6 +653,6 @@ def tensor_in_sequence(x, y): """Assigns whether a sequence contains the given tensor""" result = const_utils.scalar_to_tensor(False) for i in y: - if isinstance(i, mstype.tensor) and x.shape == i.shape and x.dtype == i.dtype: + if isinstance(i, Tensor) and x.shape == i.shape and x.dtype == i.dtype: result = F.logical_or(F.equal(x, i).all(), result) return result diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 373d6a47fe..cc918e2aff 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -171,15 +171,15 @@ def get_pos_of_indexes_types(indexes_types, op_name): slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \ sequence_positions = [], [], [], [], [], [], [] for i, index_type in enumerate(indexes_types): - if isinstance(index_type, mstype.slice_type): + if isinstance(index_type, mstype.Slice): slice_positions.append(i) - elif isinstance(index_type, mstype.ellipsis_type): + elif isinstance(index_type, mstype.Ellipsis_): ellipsis_positions.append(i) elif isinstance(index_type, mstype.none_type): none_positions.append(i) elif isinstance(index_type, mstype.Int): int_positions.append(i) - elif isinstance(index_type, mstype.bool_type): + elif isinstance(index_type, mstype.Bool): bool_positions.append(i) elif isinstance(index_type, mstype.tensor_type): tensor_positions.append(i) @@ -341,7 +341,7 @@ def tuple_index_int_cnt(types, op_name): def tuple_index_type_cnt(types, op_name): """count the tensor type of types which contains the tuple elements' type.""" tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types) - basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.ellipsis_type, mstype.slice_type)) for ele in types) + basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types) if tensor_cnt == len(types): return ALL_TENSOR if basic_cnt == len(types): @@ -614,7 +614,7 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, t indexes_info[pos] = tensor_indexes_shapes[tensor_count] index_tensors_info[pos] = tensor_indexes_shapes[tensor_count] tensor_count += 1 - elif isinstance(index_type, mstype.slice_type): + elif isinstance(index_type, mstype.Slice): slice_obj = slice(slice_indexes[slice_count].start, slice_indexes[slice_count].stop, slice_indexes[slice_count].step) @@ -680,7 +680,7 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info) -@ constexpr +@constexpr def scalar_in_sequence(x, y): """Determine whether the scalar in the sequence.""" if x is None: diff --git a/tests/ut/cpp/pipeline/static_analysis/data_test.cc b/tests/ut/cpp/pipeline/static_analysis/data_test.cc index 5264cd9e0c..5c333ed52f 100644 --- a/tests/ut/cpp/pipeline/static_analysis/data_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/data_test.cc @@ -63,7 +63,7 @@ TEST_F(TestData, test_build_value) { // BuildValue(AbstractFunction) should return kAnyValue. AbstractBasePtr abs_f1 = FromValue(prim::kPrimReturn, false); ValuePtr abs_f1_built = abs_f1->BuildValue(); - ASSERT_EQ(abs_f1_built, kAnyValue); + ASSERT_EQ(abs_f1_built, prim::kPrimReturn); FuncGraphPtr fg1 = std::make_shared(); AbstractBasePtr abs_fg1 = FromValue(fg1, false); @@ -74,17 +74,20 @@ TEST_F(TestData, test_build_value) { AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false); AbstractBasePtr abs_func_tuple = std::make_shared(AbstractBasePtrList({abs_f1, abs_f2})); ValuePtr func_tuple_built = abs_func_tuple->BuildValue(); - ASSERT_EQ(func_tuple_built, kAnyValue); + ASSERT_EQ(*func_tuple_built, + ValueTuple(std::vector{prim::kPrimReturn, prim::kPrimScalarAdd})); // BuildValue(List(AbstractFunction)) should return kAnyValue; AbstractBasePtr abs_func_list = std::make_shared(AbstractBasePtrList({abs_f1, abs_f2})); ValuePtr func_list_built = abs_func_list->BuildValue(); - ASSERT_EQ(func_list_built, kAnyValue); + ASSERT_EQ(*func_list_built, + ValueList(std::vector{prim::kPrimReturn, prim::kPrimScalarAdd})); // BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue abs_func_tuple = std::make_shared(AbstractBasePtrList({base1, abs_f2})); func_tuple_built = abs_func_tuple->BuildValue(); - ASSERT_EQ(func_tuple_built, kAnyValue); + ASSERT_EQ(*func_tuple_built, + ValueTuple(std::vector{std::make_shared(1), prim::kPrimScalarAdd})); } TEST_F(TestData, test_build_type) { diff --git a/tests/ut/python/pipeline/parse/test_isinstance.py b/tests/ut/python/pipeline/parse/test_isinstance.py new file mode 100644 index 0000000000..beaca638b3 --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_isinstance.py @@ -0,0 +1,78 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test instance""" +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + + +def test_isinstance(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.int_member = 1 + self.float_member = 1.0 + self.bool_member = True + self.string_member = "abcd" + self.tensor_member = Tensor(np.arange(4)) + self.tuple_member = (1, 1.0, True, "abcd", self.tensor_member) + self.list_member = list(self.tuple_member) + + def construct(self, x, y): + is_int = isinstance(self.int_member, int) + is_float = isinstance(self.float_member, float) + is_bool = isinstance(self.bool_member, bool) + is_string = isinstance(self.string_member, str) + is_tensor_const = isinstance(self.tensor_member, Tensor) + is_tensor_var = isinstance(x, Tensor) + is_tuple_const = isinstance(self.tuple_member, tuple) + is_tuple_var = isinstance((x, 1, 1.0, y), tuple) + is_list_const = isinstance(self.list_member, list) + is_list_var = isinstance([x, 1, 1.0, y], list) + is_list_or_tensor = isinstance([x, y], (Tensor, list)) + is_int_or_float_or_tensor_or_tuple = isinstance(x, (Tensor, tuple, int, float)) + float_is_int = isinstance(self.float_member, int) + bool_is_string = isinstance(self.bool_member, str) + tensor_is_tuple = isinstance(x, tuple) + tuple_is_list = isinstance(self.tuple_member, list) + return is_int, is_float, is_bool, is_string, is_tensor_const, is_tensor_var, \ + is_tuple_const, is_tuple_var, is_list_const, is_list_var, \ + is_int_or_float_or_tensor_or_tuple, is_list_or_tensor, \ + float_is_int, bool_is_string, tensor_is_tuple, tuple_is_list + + net = Net() + x = Tensor(np.arange(4)) + y = Tensor(np.arange(5)) + assert net(x, y) == (True,) * 12 + (False,) * 4 + + +def test_isinstance_not_supported(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = (11, 22, 33, 44) + + def construct(self): + return isinstance(self.value, None) + + net = Net() + with pytest.raises(TypeError) as err: + net() + assert "The type 'None' is not supported for 'isinstance'" in str(err.value)