From: @zhangbuxue Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -17,14 +17,14 @@ | |||
| """Resources for ast tree parse.""" | |||
| import ast | |||
| import math | |||
| from mindspore import RowTensor, SparseTensor | |||
| from mindspore.ops.composite import multitype_ops | |||
| from mindspore.ops import functional as F, composite as C | |||
| from mindspore.ops.composite import multitype_ops | |||
| from . import standard_method as M | |||
| from . import trope as T | |||
| from .namespace import CellNamespace | |||
| # namespace define | |||
| functional_ns = CellNamespace('mindspore.ops.functional') | |||
| composite_ns = CellNamespace('mindspore.ops.composite') | |||
| @@ -109,7 +109,7 @@ convert_object_map = { | |||
| # system function | |||
| T.len: M.ms_len, | |||
| T.bool: M.bool_, | |||
| T.bool_: M.bool_, | |||
| T.map: C.Map(), | |||
| T.partial: F.partial, | |||
| T.zip: C.zip_operation, | |||
| @@ -16,13 +16,15 @@ | |||
| # ============================================================================ | |||
| """standard_method""" | |||
| from dataclasses import dataclass | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore import Tensor | |||
| from mindspore import dtype as mstype | |||
| from ...ops import functional as F | |||
| from ...ops import operations as P | |||
| from ...ops.primitive import constexpr | |||
| from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \ | |||
| zeros_like, ones_like | |||
| from ...ops.composite.base import _append | |||
| from ...ops.primitive import constexpr | |||
| __all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like'] | |||
| @@ -219,9 +221,23 @@ def while_cond(x): | |||
| @constexpr | |||
| def check_type_same(x_type, base_type): | |||
| """Check x_type is same as base_type.""" | |||
| if mstype.issubclass_(x_type, base_type): | |||
| return True | |||
| return False | |||
| pytype_to_mstype = { | |||
| bool: mstype.Bool, | |||
| int: mstype.Int, | |||
| float: mstype.Float, | |||
| str: mstype.String, | |||
| list: mstype.List, | |||
| tuple: mstype.Tuple, | |||
| Tensor: mstype.tensor_type | |||
| } | |||
| try: | |||
| if isinstance(base_type, (tuple, list)): | |||
| target_type = tuple(pytype_to_mstype[i] for i in base_type) | |||
| else: | |||
| target_type = pytype_to_mstype[base_type] | |||
| return isinstance(x_type, target_type) | |||
| except KeyError: | |||
| raise TypeError(f"The type '{base_type}' is not supported for 'isinstance'") | |||
| @constexpr | |||
| @@ -235,7 +251,7 @@ def check_is_tensor(x): | |||
| @constexpr | |||
| def check_is_tuple_or_list_or_tensor(x, op_name, arg_name): | |||
| """check whether x is list or tuple or tensor.""" | |||
| if isinstance(x, (mstype.list_type, mstype.tuple_type, mstype.tensor_type)): | |||
| if isinstance(x, (mstype.List, mstype.Tuple, mstype.tensor_type)): | |||
| return True | |||
| raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.") | |||
| @@ -95,3 +95,7 @@ def not_contains(x): # pragma: no cover | |||
| def while_cond(x): # pragma: no cover | |||
| """Not in function.""" | |||
| raise RuntimeError('This operation is not meant to be called directly.') | |||
| def bool_(x): # pragma: no cover | |||
| """judge true function.""" | |||
| raise RuntimeError('This operation is not meant to be called directly.') | |||
| @@ -116,7 +116,7 @@ const char NAMED_PRIMITIVE_NEXT[] = "next"; | |||
| const char NAMED_PRIMITIVE_GETITEM[] = "getitem"; | |||
| const char NAMED_PRIMITIVE_SETITEM[] = "setitem"; | |||
| const char NAMED_PRIMITIVE_HASNEXT[] = "hasnext"; | |||
| const char NAMED_PRIMITIVE_BOOL[] = "bool"; // bool: P.identity | |||
| const char NAMED_PRIMITIVE_BOOL[] = "bool_"; // bool: P.identity | |||
| const char NAMED_PRIMITIVE_MAKETUPLE[] = "make_tuple"; | |||
| const char NAMED_PRIMITIVE_MAKELIST[] = "make_list"; | |||
| const char NAMED_PRIMITIVE_MAKESLICE[] = "make_slice"; | |||
| @@ -109,6 +109,7 @@ class ClassType : public PyObjectWrapper { | |||
| MS_DECLARE_PARENT(ClassType, PyObjectWrapper); | |||
| abstract::AbstractBasePtr ToAbstract() override; | |||
| }; | |||
| using ClassTypePtr = std::shared_ptr<ClassType>; | |||
| // SymbolResolver class for resolving symbol extracted from AnfNode. | |||
| class SymbolResolver { | |||
| @@ -280,24 +280,20 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) { | |||
| py::tuple max_shape_tuple(len); | |||
| auto dic = py::dict(); | |||
| bool dyn_shape = false; | |||
| bool is_build_value = true; | |||
| bool dyn_value = false; | |||
| for (size_t i = 0; i < len; i++) { | |||
| auto arg = arg_tuple->elements()[i]; | |||
| py::dict out = ConvertAbstractToPython(arg); | |||
| shape_tuple[i] = out[ATTR_SHAPE]; | |||
| dtype_tuple[i] = out[ATTR_DTYPE]; | |||
| value_tuple[i] = out[ATTR_VALUE]; | |||
| // Elements in tuple is tensor shape value. | |||
| if (out.contains(py::str(ATTR_MIN_VALUE)) && out.contains(py::str(ATTR_MAX_VALUE))) { | |||
| value_tuple[i] = out[ATTR_VALUE]; | |||
| min_value_tuple[i] = out[ATTR_MIN_VALUE]; | |||
| max_value_tuple[i] = out[ATTR_MAX_VALUE]; | |||
| is_build_value = false; | |||
| } else { | |||
| value_tuple[i] = BuildValue(arg->BuildValue()); | |||
| min_value_tuple[i] = value_tuple[i]; | |||
| max_value_tuple[i] = value_tuple[i]; | |||
| dyn_value = true; | |||
| } | |||
| // Elements in tuple is tensor, which shape is dynamic. | |||
| @@ -305,21 +301,21 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) { | |||
| min_shape_tuple[i] = out[ATTR_MIN_SHAPE]; | |||
| max_shape_tuple[i] = out[ATTR_MAX_SHAPE]; | |||
| dyn_shape = true; | |||
| } else { | |||
| min_shape_tuple[i] = out[ATTR_SHAPE]; | |||
| max_shape_tuple[i] = out[ATTR_SHAPE]; | |||
| } | |||
| } | |||
| dic[ATTR_SHAPE] = shape_tuple; | |||
| dic[ATTR_DTYPE] = dtype_tuple; | |||
| if (is_build_value) { | |||
| dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue()); | |||
| if (arg_tuple->BuildValue()->isa<AnyValue>()) { | |||
| dic[ATTR_VALUE] = py::none(); | |||
| } else { | |||
| dic[ATTR_VALUE] = value_tuple; | |||
| } | |||
| if (dyn_value) { | |||
| dic[ATTR_MIN_VALUE] = min_value_tuple; | |||
| dic[ATTR_MAX_VALUE] = max_value_tuple; | |||
| } | |||
| if (dyn_shape) { | |||
| dic[ATTR_MIN_SHAPE] = min_shape_tuple; | |||
| dic[ATTR_MAX_SHAPE] = max_shape_tuple; | |||
| @@ -333,6 +329,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) { | |||
| size_t len = arg_list->size(); | |||
| py::list shape_list(len); | |||
| py::list dtype_list(len); | |||
| py::list value_list(len); | |||
| py::list min_shape_list(len); | |||
| py::list max_shape_list(len); | |||
| auto dic = py::dict(); | |||
| @@ -342,27 +339,29 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) { | |||
| py::dict out = ConvertAbstractToPython(arg_list->elements()[i]); | |||
| shape_list[i] = out[ATTR_SHAPE]; | |||
| dtype_list[i] = out[ATTR_DTYPE]; | |||
| value_list[i] = out[ATTR_VALUE]; | |||
| // Elements in list is tensor, which shape is dynamic. | |||
| if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) { | |||
| min_shape_list[i] = out[ATTR_MIN_SHAPE]; | |||
| max_shape_list[i] = out[ATTR_MAX_SHAPE]; | |||
| dyn_shape = true; | |||
| } else { | |||
| min_shape_list[i] = out[ATTR_SHAPE]; | |||
| max_shape_list[i] = out[ATTR_SHAPE]; | |||
| } | |||
| } | |||
| dic[ATTR_SHAPE] = shape_list; | |||
| dic[ATTR_DTYPE] = dtype_list; | |||
| if (arg_list->BuildValue()->isa<AnyValue>()) { | |||
| dic[ATTR_VALUE] = py::none(); | |||
| } else { | |||
| dic[ATTR_VALUE] = value_list; | |||
| } | |||
| if (dyn_shape) { | |||
| dic[ATTR_MIN_SHAPE] = min_shape_list; | |||
| dic[ATTR_MAX_SHAPE] = max_shape_list; | |||
| } | |||
| dic[ATTR_SHAPE] = shape_list; | |||
| dic[ATTR_DTYPE] = dtype_list; | |||
| dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue()); | |||
| return dic; | |||
| } | |||
| } // end anonymous namespace | |||
| @@ -428,6 +427,16 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { | |||
| dic[ATTR_SHAPE] = py::none(); | |||
| dic[ATTR_DTYPE] = abs_base->BuildType(); | |||
| dic[ATTR_VALUE] = py::none(); | |||
| if (abs_base->isa<PartialAbstractClosure>()) { | |||
| AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args(); | |||
| if (!args.empty()) { | |||
| auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>(); | |||
| if (value != nullptr) { | |||
| dic[ATTR_DTYPE] = std::make_shared<TypeType>(); | |||
| dic[ATTR_VALUE] = value->obj(); | |||
| } | |||
| } | |||
| } | |||
| } else if (abs_base->isa<AbstractUndetermined>()) { | |||
| auto arg = dyn_cast<AbstractUndetermined>(abs_base); | |||
| dic[ATTR_SHAPE] = py::none(); | |||
| @@ -390,8 +390,8 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py | |||
| } | |||
| return PyList2DynamicShapeTensor(shape_obj, type_obj, output); | |||
| } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) { | |||
| py::tuple shape_tuple = shape_obj.cast<py::tuple>(); | |||
| py::tuple typeid_tuple = type_obj.cast<py::tuple>(); | |||
| auto shape_tuple = shape_obj.cast<py::tuple>(); | |||
| auto typeid_tuple = type_obj.cast<py::tuple>(); | |||
| AbstractBasePtrList ptr_list; | |||
| for (size_t it = 0; it < shape_tuple.size(); ++it) { | |||
| auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]); | |||
| @@ -400,8 +400,8 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py | |||
| auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list); | |||
| return tuple; | |||
| } else if (py::isinstance<py::list>(shape_obj) && py::isinstance<py::list>(type_obj)) { | |||
| py::list shape_list = shape_obj.cast<py::list>(); | |||
| py::list typeid_list = type_obj.cast<py::list>(); | |||
| auto shape_list = shape_obj.cast<py::list>(); | |||
| auto typeid_list = type_obj.cast<py::list>(); | |||
| AbstractBasePtrList ptr_list; | |||
| for (size_t it = 0; it < shape_list.size(); ++it) { | |||
| auto tensor_it = PyListDtype2AbstractTensor(shape_list[it], typeid_list[it]); | |||
| @@ -78,36 +78,39 @@ single = float32 | |||
| float64 = typing.Float(64) | |||
| double = float64 | |||
| number = typing.Number() | |||
| int_ = typing.Int() | |||
| uint = typing.UInt() | |||
| float_ = typing.Float() | |||
| number = typing.Number() | |||
| string = typing.String() | |||
| list_ = typing.List() | |||
| tuple_ = typing.Tuple() | |||
| type_none = typing.TypeNone() | |||
| tensor = typing.TensorType() | |||
| index_slices = typing.RowTensorType() | |||
| sparse_tensor = typing.SparseTensorType() | |||
| undetermined = typing.UndeterminedType() | |||
| function = typing.Function() | |||
| function_type = typing.Function | |||
| symbolic_key = typing.SymbolicKeyType() | |||
| env_type = typing.EnvType() | |||
| env_type_type = typing.EnvType | |||
| type_type = typing.TypeType() | |||
| type_none = typing.TypeNone() | |||
| type_bool = typing.Bool() | |||
| string = typing.String() | |||
| type_refkey = typing.RefKeyType() | |||
| tensor_type = typing.TensorType | |||
| anything_type = typing.TypeAnything | |||
| slice_type = typing.Slice | |||
| ellipsis_type = typing.TypeEllipsis | |||
| list_type = typing.List | |||
| tuple_type = typing.Tuple | |||
| index_slices = typing.RowTensorType() | |||
| sparse_tensor = typing.SparseTensorType() | |||
| undetermined = typing.UndeterminedType() | |||
| Int = typing.Int | |||
| bool_type = typing.Bool | |||
| Float = typing.Float | |||
| Bool = typing.Bool | |||
| String = typing.String | |||
| List = typing.List | |||
| Tuple = typing.Tuple | |||
| Slice = typing.Slice | |||
| function_type = typing.Function | |||
| Ellipsis_ = typing.TypeEllipsis | |||
| none_type = typing.TypeNone | |||
| env_type_type = typing.EnvType | |||
| tensor_type = typing.TensorType | |||
| anything_type = typing.TypeAnything | |||
| number_type = (int8, | |||
| int16, | |||
| @@ -86,6 +86,8 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom { | |||
| std::string ToString() const override { return "Prim: " + prim_->name(); } | |||
| ValuePtr RealBuildValue() const override { return prim_; } | |||
| private: | |||
| PrimitivePtr prim_; | |||
| // store it as weak_ptr to break reference cycle. | |||
| @@ -183,6 +185,7 @@ class PartialAbstractClosure : public AbstractFuncAtom { | |||
| AbstractFunctionPtr fn() { return fn_; } | |||
| AbstractBasePtrList args() { return args_spec_list_; } | |||
| ValuePtr RealBuildValue() const override { return fn_->BuildValue(); } | |||
| AnfNodePtr node() { return node_.lock(); } | |||
| void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); } | |||
| AbstractFunctionPtr Copy() const override { | |||
| @@ -199,6 +202,7 @@ class PartialAbstractClosure : public AbstractFuncAtom { | |||
| // The CNode which this PartialAbstractClosure evaluated from. | |||
| AnfNodeWeakPtr node_; | |||
| }; | |||
| using PartialAbstractClosurePtr = std::shared_ptr<PartialAbstractClosure>; | |||
| class JTransformedAbstractClosure : public AbstractFuncAtom { | |||
| public: | |||
| @@ -339,13 +339,13 @@ def _cpu_not_support(name): | |||
| @constexpr | |||
| def _check_is_tuple(obj): | |||
| """Check whether obj is a tuple""" | |||
| return isinstance(obj, mstype.tuple_type) | |||
| return isinstance(obj, mstype.Tuple) | |||
| @constexpr | |||
| def _check_is_list(obj): | |||
| """Check whether obj is a list""" | |||
| return isinstance(obj, mstype.list_type) | |||
| return isinstance(obj, mstype.List) | |||
| @constexpr | |||
| @@ -148,7 +148,7 @@ def _expand_data_dims_with_bool(data, tuple_index, op_name): | |||
| bool_positions, tuple_index_without_bool = (), () | |||
| for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)): | |||
| bool_type_tag = const_utils.judge_index_type(index_type, mstype.type_bool) | |||
| bool_type_tag = const_utils.judge_index_type(index_type, mstype.bool_) | |||
| if bool_type_tag: | |||
| if index: | |||
| tuple_index_without_bool += (const_utils.make_tensor([0], mstype.int64),) | |||
| @@ -653,6 +653,6 @@ def tensor_in_sequence(x, y): | |||
| """Assigns whether a sequence contains the given tensor""" | |||
| result = const_utils.scalar_to_tensor(False) | |||
| for i in y: | |||
| if isinstance(i, mstype.tensor) and x.shape == i.shape and x.dtype == i.dtype: | |||
| if isinstance(i, Tensor) and x.shape == i.shape and x.dtype == i.dtype: | |||
| result = F.logical_or(F.equal(x, i).all(), result) | |||
| return result | |||
| @@ -171,15 +171,15 @@ def get_pos_of_indexes_types(indexes_types, op_name): | |||
| slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \ | |||
| sequence_positions = [], [], [], [], [], [], [] | |||
| for i, index_type in enumerate(indexes_types): | |||
| if isinstance(index_type, mstype.slice_type): | |||
| if isinstance(index_type, mstype.Slice): | |||
| slice_positions.append(i) | |||
| elif isinstance(index_type, mstype.ellipsis_type): | |||
| elif isinstance(index_type, mstype.Ellipsis_): | |||
| ellipsis_positions.append(i) | |||
| elif isinstance(index_type, mstype.none_type): | |||
| none_positions.append(i) | |||
| elif isinstance(index_type, mstype.Int): | |||
| int_positions.append(i) | |||
| elif isinstance(index_type, mstype.bool_type): | |||
| elif isinstance(index_type, mstype.Bool): | |||
| bool_positions.append(i) | |||
| elif isinstance(index_type, mstype.tensor_type): | |||
| tensor_positions.append(i) | |||
| @@ -341,7 +341,7 @@ def tuple_index_int_cnt(types, op_name): | |||
| def tuple_index_type_cnt(types, op_name): | |||
| """count the tensor type of types which contains the tuple elements' type.""" | |||
| tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types) | |||
| basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.ellipsis_type, mstype.slice_type)) for ele in types) | |||
| basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types) | |||
| if tensor_cnt == len(types): | |||
| return ALL_TENSOR | |||
| if basic_cnt == len(types): | |||
| @@ -614,7 +614,7 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, t | |||
| indexes_info[pos] = tensor_indexes_shapes[tensor_count] | |||
| index_tensors_info[pos] = tensor_indexes_shapes[tensor_count] | |||
| tensor_count += 1 | |||
| elif isinstance(index_type, mstype.slice_type): | |||
| elif isinstance(index_type, mstype.Slice): | |||
| slice_obj = slice(slice_indexes[slice_count].start, | |||
| slice_indexes[slice_count].stop, | |||
| slice_indexes[slice_count].step) | |||
| @@ -680,7 +680,7 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te | |||
| return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info) | |||
| @ constexpr | |||
| @constexpr | |||
| def scalar_in_sequence(x, y): | |||
| """Determine whether the scalar in the sequence.""" | |||
| if x is None: | |||
| @@ -63,7 +63,7 @@ TEST_F(TestData, test_build_value) { | |||
| // BuildValue(AbstractFunction) should return kAnyValue. | |||
| AbstractBasePtr abs_f1 = FromValue(prim::kPrimReturn, false); | |||
| ValuePtr abs_f1_built = abs_f1->BuildValue(); | |||
| ASSERT_EQ(abs_f1_built, kAnyValue); | |||
| ASSERT_EQ(abs_f1_built, prim::kPrimReturn); | |||
| FuncGraphPtr fg1 = std::make_shared<FuncGraph>(); | |||
| AbstractBasePtr abs_fg1 = FromValue(fg1, false); | |||
| @@ -74,17 +74,20 @@ TEST_F(TestData, test_build_value) { | |||
| AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false); | |||
| AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2})); | |||
| ValuePtr func_tuple_built = abs_func_tuple->BuildValue(); | |||
| ASSERT_EQ(func_tuple_built, kAnyValue); | |||
| ASSERT_EQ(*func_tuple_built, | |||
| ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd})); | |||
| // BuildValue(List(AbstractFunction)) should return kAnyValue; | |||
| AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2})); | |||
| ValuePtr func_list_built = abs_func_list->BuildValue(); | |||
| ASSERT_EQ(func_list_built, kAnyValue); | |||
| ASSERT_EQ(*func_list_built, | |||
| ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd})); | |||
| // BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue | |||
| abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2})); | |||
| func_tuple_built = abs_func_tuple->BuildValue(); | |||
| ASSERT_EQ(func_tuple_built, kAnyValue); | |||
| ASSERT_EQ(*func_tuple_built, | |||
| ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd})); | |||
| } | |||
| TEST_F(TestData, test_build_type) { | |||
| @@ -0,0 +1,78 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test instance""" | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| def test_isinstance(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.int_member = 1 | |||
| self.float_member = 1.0 | |||
| self.bool_member = True | |||
| self.string_member = "abcd" | |||
| self.tensor_member = Tensor(np.arange(4)) | |||
| self.tuple_member = (1, 1.0, True, "abcd", self.tensor_member) | |||
| self.list_member = list(self.tuple_member) | |||
| def construct(self, x, y): | |||
| is_int = isinstance(self.int_member, int) | |||
| is_float = isinstance(self.float_member, float) | |||
| is_bool = isinstance(self.bool_member, bool) | |||
| is_string = isinstance(self.string_member, str) | |||
| is_tensor_const = isinstance(self.tensor_member, Tensor) | |||
| is_tensor_var = isinstance(x, Tensor) | |||
| is_tuple_const = isinstance(self.tuple_member, tuple) | |||
| is_tuple_var = isinstance((x, 1, 1.0, y), tuple) | |||
| is_list_const = isinstance(self.list_member, list) | |||
| is_list_var = isinstance([x, 1, 1.0, y], list) | |||
| is_list_or_tensor = isinstance([x, y], (Tensor, list)) | |||
| is_int_or_float_or_tensor_or_tuple = isinstance(x, (Tensor, tuple, int, float)) | |||
| float_is_int = isinstance(self.float_member, int) | |||
| bool_is_string = isinstance(self.bool_member, str) | |||
| tensor_is_tuple = isinstance(x, tuple) | |||
| tuple_is_list = isinstance(self.tuple_member, list) | |||
| return is_int, is_float, is_bool, is_string, is_tensor_const, is_tensor_var, \ | |||
| is_tuple_const, is_tuple_var, is_list_const, is_list_var, \ | |||
| is_int_or_float_or_tensor_or_tuple, is_list_or_tensor, \ | |||
| float_is_int, bool_is_string, tensor_is_tuple, tuple_is_list | |||
| net = Net() | |||
| x = Tensor(np.arange(4)) | |||
| y = Tensor(np.arange(5)) | |||
| assert net(x, y) == (True,) * 12 + (False,) * 4 | |||
| def test_isinstance_not_supported(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.value = (11, 22, 33, 44) | |||
| def construct(self): | |||
| return isinstance(self.value, None) | |||
| net = Net() | |||
| with pytest.raises(TypeError) as err: | |||
| net() | |||
| assert "The type 'None' is not supported for 'isinstance'" in str(err.value) | |||