Merge pull request !1407 from zhangbuxue/support_mixed_tensor_for_tensor_get_item_and_tensor_set_itemtags/v0.5.0-beta
| @@ -105,7 +105,7 @@ convert_object_map = { | |||||
| T.ge: multitype_ops.greater_equal, | T.ge: multitype_ops.greater_equal, | ||||
| T.is_: F.is_, | T.is_: F.is_, | ||||
| T.is_not: F.is_not, | T.is_not: F.is_not, | ||||
| T.contains: F.in_dict, | |||||
| T.contains: multitype_ops.in_, | |||||
| T.not_contains: F.not_in_dict, | T.not_contains: F.not_in_dict, | ||||
| # system function | # system function | ||||
| @@ -474,6 +474,8 @@ REGISTER_PYBIND_DEFINE( | |||||
| (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init()); | (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init()); | ||||
| (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init()); | (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init()); | ||||
| (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init()); | (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init()); | ||||
| (void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init()); | |||||
| (void)py::class_<Ellipsis, Type, std::shared_ptr<Ellipsis>>(m_sub, "Ellipsis").def(py::init()); | |||||
| })); | })); | ||||
| const TypePtr kTypeExternal = std::make_shared<External>(); | const TypePtr kTypeExternal = std::make_shared<External>(); | ||||
| @@ -95,6 +95,8 @@ string = typing.String() | |||||
| type_refkey = typing.RefKeyType() | type_refkey = typing.RefKeyType() | ||||
| tensor_type = typing.TensorType | tensor_type = typing.TensorType | ||||
| anything_type = typing.TypeAnything | anything_type = typing.TypeAnything | ||||
| slice_type = typing.Slice | |||||
| ellipsis_type = typing.Ellipsis | |||||
| number_type = (int8, | number_type = (int8, | ||||
| int16, | int16, | ||||
| @@ -37,6 +37,7 @@ from .logical_and_impl import logical_and | |||||
| from .logical_or_impl import logical_or | from .logical_or_impl import logical_or | ||||
| from .logic_not_impl import logical_not | from .logic_not_impl import logical_not | ||||
| from .uadd_impl import uadd | from .uadd_impl import uadd | ||||
| from .in_impl import in_ | |||||
| __all__ = [ | __all__ = [ | ||||
| 'add', | 'add', | ||||
| 'sub', | 'sub', | ||||
| @@ -59,5 +60,6 @@ __all__ = [ | |||||
| 'setitem', | 'setitem', | ||||
| 'logical_and', | 'logical_and', | ||||
| 'logical_or', | 'logical_or', | ||||
| 'logical_not' | |||||
| 'logical_not', | |||||
| 'in_' | |||||
| ] | ] | ||||
| @@ -0,0 +1,154 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """constexpr util""" | |||||
| from . import _constexpr_utils as const_utils | |||||
| from ... import functional as F | |||||
| from ... import operations as P | |||||
| from ...composite import base | |||||
| from ....common import dtype as mstype | |||||
| hyper_map = base.HyperMap() | |||||
| pack = P.Pack(axis=-1) | |||||
| def broadcast(broadcast_shape, x): | |||||
| """Broadcast tensor to the required shape.""" | |||||
| if F.shape(x) == broadcast_shape: | |||||
| return x | |||||
| multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape) | |||||
| if multiples: | |||||
| return F.tile(x, multiples) | |||||
| return x | |||||
| def transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x): | |||||
| """Transform indexing tensor to the required.""" | |||||
| x = broadcast(broadcast_shape, x) | |||||
| return broadcast(final_shape, F.reshape(x, new_shape)) | |||||
| def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): | |||||
| """Generate an indices tensor from a tuple of tensor.""" | |||||
| indices = None | |||||
| check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name) | |||||
| if check_index_tensor_number: | |||||
| dtype_tuple = hyper_map(F.dtype, tuple_index) | |||||
| check_dtypes = const_utils.check_index_tensors_dtype(dtype_tuple, op_name) | |||||
| if check_dtypes: | |||||
| shape_tuple = hyper_map(F.shape, tuple_index) | |||||
| broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name) | |||||
| broadcast_tensors = hyper_map(F.partial(broadcast, broadcast_shape), tuple_index) | |||||
| indices = pack(broadcast_tensors) | |||||
| return indices | |||||
| def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): | |||||
| """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.""" | |||||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||||
| int_positions = const_utils.get_pos_of_int_index(indexes_types) | |||||
| for i in int_positions: | |||||
| tuple_index = F.tuple_setitem(tuple_index, i, F.scalar_to_tensor(tuple_index[i], mstype.int32)) | |||||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||||
| tensor_positions, slice_positions, ellipsis_position = \ | |||||
| const_utils.separate_mixed_tensors_index(indexes_types, op_name) | |||||
| tensor_indexes = [] | |||||
| slice_indexes = [] | |||||
| for i in tensor_positions: | |||||
| tensor_indexes.append(tuple_index[i]) | |||||
| for j in slice_positions: | |||||
| slice_indexes.append(tuple_index[j]) | |||||
| data_shape = F.shape(data) | |||||
| tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) | |||||
| tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) | |||||
| broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \ | |||||
| const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape, | |||||
| indexes_types, | |||||
| tensor_indexes_shapes, | |||||
| tensor_indexes_dtypes, | |||||
| slice_indexes, | |||||
| op_name) | |||||
| slice_number = 0 | |||||
| final_index_tensors = [] | |||||
| tuple_index_size = len(tuple_index) | |||||
| index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) | |||||
| for i in range(tuple_index_size): | |||||
| if i in tensor_positions: | |||||
| transform_tensor = transform_indexing_tensor(broadcast_shape, | |||||
| final_shape, | |||||
| index_tensor_new_shape, | |||||
| tuple_index[i]) | |||||
| final_index_tensors.append(transform_tensor) | |||||
| if i in slice_positions: | |||||
| slice_tensor = const_utils.convert_slice_to_tensor(slice_number, | |||||
| final_shape, | |||||
| indexes_shapes_info, | |||||
| op_name) | |||||
| final_index_tensors.append(slice_tensor) | |||||
| slice_number += 1 | |||||
| if i == ellipsis_position: | |||||
| ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number, | |||||
| ellipsis_occupied_dims, | |||||
| final_shape, | |||||
| indexes_shapes_info, | |||||
| op_name) | |||||
| for ele in ellipsis_tensors: | |||||
| final_index_tensors.append(ele) | |||||
| slice_number += ellipsis_occupied_dims | |||||
| indices = pack(final_index_tensors) | |||||
| return indices | |||||
| def generate_updates_from_scalar(data, indices, value, op_type): | |||||
| """Generate an updates tensor from a scalar.""" | |||||
| data_shape = F.shape(data) | |||||
| indices_shape = F.shape(indices) | |||||
| data_dtype = F.dtype(data) | |||||
| return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) | |||||
| def generate_updates_from_tuple(data, index, value, op_type): | |||||
| """Generate an updates tensor from a tuple.""" | |||||
| value_types = hyper_map(F.typeof, value) | |||||
| data_dtype = F.dtype(data) | |||||
| value_elements_type = const_utils.check_value_elements(data_dtype, value_types) | |||||
| if value_elements_type == const_utils.ALL_TENSOR: | |||||
| value_shapes = hyper_map(F.shape, value) | |||||
| shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM) | |||||
| if shapes_same: | |||||
| value = F.pack(value) | |||||
| return generate_updates_from_tensor(data, index, value, op_type) | |||||
| data_shape = F.shape(data) | |||||
| index_shape = F.shape(index) | |||||
| return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type) | |||||
| def generate_updates_from_tensor(data, index, value, op_type): | |||||
| """Generate an updates tensor from a tensor.""" | |||||
| data_shape = F.shape(data) | |||||
| index_shape = F.shape(index) | |||||
| value_shape = F.shape(value) | |||||
| data_dtype = F.dtype(data) | |||||
| value_dtype = F.dtype(value) | |||||
| updates_shape = value_shape | |||||
| check_dtype_same = const_utils.check_tensors_dtype_same(data_dtype, value_dtype, const_utils.TENSOR_SETITEM) | |||||
| if check_dtype_same: | |||||
| updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type) | |||||
| need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape) | |||||
| if need_broadcast: | |||||
| return broadcast(updates_shape, value) | |||||
| return value | |||||
| @@ -15,19 +15,15 @@ | |||||
| """constexpr util""" | """constexpr util""" | ||||
| from functools import reduce | from functools import reduce | ||||
| import numpy as np | import numpy as np | ||||
| from ...primitive import constexpr | from ...primitive import constexpr | ||||
| from ....common.tensor import Tensor | |||||
| from ....common import dtype as mstype | |||||
| from .... import log as logger | |||||
| from ...._extends.utils import Slice, Ellipsis_ | from ...._extends.utils import Slice, Ellipsis_ | ||||
| from ....common import dtype as mstype | |||||
| from ....common.tensor import Tensor | |||||
| from ....ops import _utils as op_utils | from ....ops import _utils as op_utils | ||||
| from ...composite import base | |||||
| from .... import log as logger | |||||
| from ... import functional as F | |||||
| from ... import operations as P | |||||
| hyper_map = base.HyperMap() | |||||
| pack = P.Pack(axis=-1) | |||||
| ALL_TENSOR = 0 | ALL_TENSOR = 0 | ||||
| NO_TENSOR = 1 | NO_TENSOR = 1 | ||||
| @@ -264,7 +260,7 @@ def tuple_index_elements_type(types, op_name): | |||||
| return ALL_TENSOR | return ALL_TENSOR | ||||
| if tensors_number == 0: | if tensors_number == 0: | ||||
| return NO_TENSOR | return NO_TENSOR | ||||
| raise IndexError(f"For '{op_name}', the index does not support mixed tensor.") | |||||
| return CONTAIN_TENSOR | |||||
| @constexpr | @constexpr | ||||
| @@ -279,12 +275,12 @@ def check_value_elements(data_dtype, types): | |||||
| tensors_number += 1 | tensors_number += 1 | ||||
| else: | else: | ||||
| raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' " | raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' " | ||||
| f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.") | |||||
| f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.") | |||||
| elif mstype.issubclass_(ele, data_dtype): | elif mstype.issubclass_(ele, data_dtype): | ||||
| scalars_number += 1 | scalars_number += 1 | ||||
| else: | else: | ||||
| raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in " | raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in " | ||||
| f"value tuple is not consistent with origin tensor data type '{data_dtype}'.") | |||||
| f"value tuple is not consistent with assigned tensor data type '{data_dtype}'.") | |||||
| if tensors_number == len(types): | if tensors_number == len(types): | ||||
| return ALL_TENSOR | return ALL_TENSOR | ||||
| if scalars_number == len(types): | if scalars_number == len(types): | ||||
| @@ -299,51 +295,46 @@ def get_index_tensor_dtype(dtype): | |||||
| return INT_ | return INT_ | ||||
| if dtype == mstype.bool_: | if dtype == mstype.bool_: | ||||
| return BOOL_ | return BOOL_ | ||||
| raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.") | |||||
| raise IndexError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.") | |||||
| @constexpr | @constexpr | ||||
| def check_index_tensors_dtype(dtypes, op_name): | def check_index_tensors_dtype(dtypes, op_name): | ||||
| """Check a tuple of tensor data type.""" | """Check a tuple of tensor data type.""" | ||||
| if op_name == TENSOR_GETITEM: | |||||
| valid_dtypes = (mstype.int32, mstype.int64) | |||||
| elif op_name == TENSOR_SETITEM: | |||||
| valid_dtypes = (mstype.int32,) | |||||
| else: | |||||
| raise ValueError("Unsupported operation.") | |||||
| for ele in dtypes: | for ele in dtypes: | ||||
| if ele in valid_dtypes and ele == dtypes[0]: | |||||
| continue | |||||
| raise TypeError(f"For '{op_name}', the index tensors data type must be same, " | |||||
| f"and should be one of the following: {valid_dtypes}, but got {dtypes}.") | |||||
| if not ele == mstype.int32: | |||||
| raise IndexError(f"For '{op_name}', the all index tensor " | |||||
| f"data types should be mstype.int32, but got {dtypes}.") | |||||
| return True | return True | ||||
| @constexpr | @constexpr | ||||
| def check_tensor_dtype_valid(dtype, valid_dtypes): | |||||
| def check_index_tensor_dtype(dtype, op_name): | |||||
| """Check a tensor data type.""" | """Check a tensor data type.""" | ||||
| if dtype in valid_dtypes: | |||||
| if dtype == mstype.int32: | |||||
| return True | return True | ||||
| raise TypeError(f"The index tensor data type must be one of " | |||||
| f"the following: {valid_dtypes}, but got {dtype}.") | |||||
| raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.") | |||||
| @constexpr | @constexpr | ||||
| def check_tensors_dtype_same(x_dtype, y_dtype, op_name): | |||||
| def check_tensors_dtype_same(data_dtype, value_dtype, op_name): | |||||
| """Check tensors data type same.""" | """Check tensors data type same.""" | ||||
| if x_dtype == y_dtype: | |||||
| if value_dtype == data_dtype: | |||||
| return True | return True | ||||
| raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' " | |||||
| f"is not consistent with origin tensor data type {x_dtype}.") | |||||
| raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' " | |||||
| f"is not consistent with assigned tensor data type {data_dtype}.") | |||||
| @constexpr | @constexpr | ||||
| def broadcast_shapes(shapes, op_name): | |||||
| """Broadcasts a tuple of tensor.""" | |||||
| def generate_broadcast_shape(shapes, op_name): | |||||
| """Generate broadcast shape for a tuple of shape.""" | |||||
| broadcast_shape = shapes[0] | broadcast_shape = shapes[0] | ||||
| for i, shape in enumerate(shapes): | for i, shape in enumerate(shapes): | ||||
| logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") | logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") | ||||
| broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name) | |||||
| try: | |||||
| broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name) | |||||
| except ValueError as ex: | |||||
| raise IndexError(ex) | |||||
| return tuple(broadcast_shape) | return tuple(broadcast_shape) | ||||
| @@ -366,14 +357,82 @@ def check_two_shapes_need_broadcast(shape_x, shape_y): | |||||
| @constexpr | @constexpr | ||||
| def compute_multiples(origin_shape, broadcast_shape): | def compute_multiples(origin_shape, broadcast_shape): | ||||
| """Compute multiples between broadcast_shape with origin_shape.""" | |||||
| """Compute multiples between origin shape with broadcast shape.""" | |||||
| len_gap = len(broadcast_shape) - len(origin_shape) | len_gap = len(broadcast_shape) - len(origin_shape) | ||||
| return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape)) | return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape)) | ||||
| def tile(broadcast_shape, x): | |||||
| multiples = compute_multiples(F.shape(x), broadcast_shape) | |||||
| return F.tile(x, multiples) | |||||
| @constexpr | |||||
| def compute_new_shape(origin_shape, indexes_shapes_info): | |||||
| """Compute new shape between origin shape with final shape.""" | |||||
| new_shape = [] | |||||
| for i in indexes_shapes_info: | |||||
| if i == origin_shape: | |||||
| new_shape.extend(origin_shape) | |||||
| else: | |||||
| new_shape.append(1) | |||||
| return tuple(new_shape) | |||||
| @constexpr | |||||
| def convert_ellipsis_to_tensors(slice_number, | |||||
| ellipsis_occupied_dims, | |||||
| final_shape, | |||||
| indexes_shapes_info, | |||||
| op_name): | |||||
| """Convert an ellipsis to a list of tensor.""" | |||||
| tensor_list = [] | |||||
| dims_dealt_count = 0 | |||||
| while dims_dealt_count < ellipsis_occupied_dims: | |||||
| shape = [] | |||||
| slice_count = 0 | |||||
| array = None | |||||
| for ele in indexes_shapes_info: | |||||
| if isinstance(ele, list): | |||||
| if slice_count == slice_number: | |||||
| array = np.array(ele, np.int32) | |||||
| shape.append(len(ele)) | |||||
| else: | |||||
| shape.append(1) | |||||
| slice_count += 1 | |||||
| if isinstance(ele, tuple): | |||||
| shape.extend([1] * len(ele)) | |||||
| if array is None: | |||||
| raise ValueError(f"For '{op_name}', generate tensors from ellipsis failed.") | |||||
| array = np.reshape(array, shape) | |||||
| reps = compute_multiples(shape, final_shape) | |||||
| tensor = Tensor(np.tile(array, reps)) | |||||
| tensor_list.append(tensor) | |||||
| slice_number += 1 | |||||
| dims_dealt_count += 1 | |||||
| return tensor_list | |||||
| @constexpr | |||||
| def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name): | |||||
| """Convert a slice to a tensor.""" | |||||
| shape = [] | |||||
| count = 0 | |||||
| array = None | |||||
| for ele in indexes_shapes_info: | |||||
| if isinstance(ele, list): | |||||
| if count == slice_number: | |||||
| array = np.array(ele, np.int32) | |||||
| shape.append(len(ele)) | |||||
| else: | |||||
| # When the slice is not the slice looking for, the shape is filled with 1. | |||||
| shape.append(1) | |||||
| count += 1 | |||||
| elif isinstance(ele, tuple): | |||||
| shape.extend([1] * len(ele)) | |||||
| else: | |||||
| shape.append(1) | |||||
| if array is None: | |||||
| raise ValueError(f"For '{op_name}', generate tensor from 'slice' failed.") | |||||
| array = np.reshape(array, shape) | |||||
| reps = compute_multiples(shape, final_shape) | |||||
| tensor = Tensor(np.tile(array, reps)) | |||||
| return tensor | |||||
| @constexpr | @constexpr | ||||
| @@ -381,8 +440,8 @@ def check_shapes_same(value_shapes, op_name): | |||||
| """Check if the shapes in the tuple are consistent.""" | """Check if the shapes in the tuple are consistent.""" | ||||
| for i, shape in enumerate(value_shapes): | for i, shape in enumerate(value_shapes): | ||||
| if shape != value_shapes[0]: | if shape != value_shapes[0]: | ||||
| raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple " | |||||
| f"is not same as the first tensor shape.") | |||||
| raise ValueError(f"For '{op_name}', the {i}th tensor shape in " | |||||
| f"value tuple is not same as the first tensor shape.") | |||||
| return True | return True | ||||
| @@ -396,7 +455,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty | |||||
| if isinstance(value, mstype.dtype_to_pytype(data_dtype)): | if isinstance(value, mstype.dtype_to_pytype(data_dtype)): | ||||
| return Tensor(np.full(updates_shape, value), dtype=data_dtype) | return Tensor(np.full(updates_shape, value), dtype=data_dtype) | ||||
| raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'" | raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'" | ||||
| f" is not consistent with tensor data type {data_dtype}.") | |||||
| f" is not consistent with the assigned tensor data type {data_dtype}.") | |||||
| @constexpr | @constexpr | ||||
| @@ -404,8 +463,8 @@ def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value | |||||
| """Convert a tuple of scalar to a tensor.""" | """Convert a tuple of scalar to a tensor.""" | ||||
| updates_shape = generate_updates_shape(data_shape, index_shape, op_type) | updates_shape = generate_updates_shape(data_shape, index_shape, op_type) | ||||
| if len(value) != updates_shape[-1]: | if len(value) != updates_shape[-1]: | ||||
| raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} in the updates tuple " | |||||
| f"does not meet the requirements: {updates_shape[-1]}.") | |||||
| raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} " | |||||
| f"in the updates tuple does not meet the requirements: {updates_shape[-1]}.") | |||||
| array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype)) | array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype)) | ||||
| reps = compute_multiples(updates_shape[-1:], updates_shape) | reps = compute_multiples(updates_shape[-1:], updates_shape) | ||||
| return Tensor(np.tile(array, reps)) | return Tensor(np.tile(array, reps)) | ||||
| @@ -430,58 +489,145 @@ def check_number_of_index_tensor(data_shape, tuple_len, op_name): | |||||
| f"is greater than the dimension {len(data_shape)} of the operated tensor.") | f"is greater than the dimension {len(data_shape)} of the operated tensor.") | ||||
| def generate_indeices_from_tuple_of_tensor(data, tuple_index, op_name): | |||||
| """Generate an indices tensor from a tuple of tensor.""" | |||||
| indices = None | |||||
| check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name) | |||||
| if check_index_tensor_number: | |||||
| dtype_tuple = hyper_map(F.dtype, tuple_index) | |||||
| check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name) | |||||
| if check_dtypes: | |||||
| shape_tuple = hyper_map(F.shape, tuple_index) | |||||
| broadcast_shape = broadcast_shapes(shape_tuple, op_name) | |||||
| broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index) | |||||
| indices = pack(broadcast_tensors) | |||||
| return indices | |||||
| def generate_updates_from_scalar(data, indices, value, op_type): | |||||
| """Generate an updates tensor from a scalar.""" | |||||
| data_shape = F.shape(data) | |||||
| indices_shape = F.shape(indices) | |||||
| data_dtype = F.dtype(data) | |||||
| return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) | |||||
| def generate_updates_from_tuple(data, index, value, op_type): | |||||
| """Generate an updates tensor from a tuple.""" | |||||
| value_types = hyper_map(F.typeof, value) | |||||
| data_dtype = F.dtype(data) | |||||
| value_elements_type = check_value_elements(data_dtype, value_types) | |||||
| if value_elements_type == ALL_TENSOR: | |||||
| value_shapes = hyper_map(F.shape, value) | |||||
| shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM) | |||||
| if shapes_same: | |||||
| value = F.pack(value) | |||||
| return generate_updates_from_tensor(data, index, value, op_type) | |||||
| data_shape = F.shape(data) | |||||
| index_shape = F.shape(index) | |||||
| return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type) | |||||
| def generate_updates_from_tensor(data, index, value, op_type): | |||||
| """Generate an updates tensor from a tensor.""" | |||||
| data_shape = F.shape(data) | |||||
| index_shape = F.shape(index) | |||||
| value_shape = F.shape(value) | |||||
| data_dtype = F.dtype(data) | |||||
| value_dtype = F.dtype(value) | |||||
| updates_shape = value_shape | |||||
| check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM) | |||||
| if check_dtype_same: | |||||
| updates_shape = generate_updates_shape(data_shape, index_shape, op_type) | |||||
| need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape) | |||||
| if need_broadcast: | |||||
| return tile(updates_shape, value) | |||||
| return value | |||||
| @constexpr | |||||
| def generate_index_info_from_tuple_of_mixed_tensors(data_shape, | |||||
| indexes_types, | |||||
| tensor_indexes_shapes, | |||||
| tensor_indexes_dtypes, | |||||
| slice_indexes, | |||||
| op_name): | |||||
| """ | |||||
| Generate index info which contain broadcast shape, final shape, | |||||
| indexes shapes info, ellipsis size from a tuple of mixed tensors. | |||||
| """ | |||||
| check_index_tensors_dtype(tensor_indexes_dtypes, op_name) | |||||
| data_rank = len(data_shape) | |||||
| indexes_size = len(indexes_types) | |||||
| if indexes_size > data_rank: | |||||
| raise IndexError(f"For '{op_name}', the number {indexes_size} of index elements " | |||||
| f"is greater than the dimension {len(data_shape)} of the operated tensor.") | |||||
| indexes_info = {} | |||||
| index_tensors_info = {} | |||||
| ellipsis_num = 0 | |||||
| ellipsis_occupied_dims = 0 | |||||
| tensor_count = 0 | |||||
| slice_count = 0 | |||||
| for i, ele_type in enumerate(indexes_types): | |||||
| if ellipsis_num == 0: | |||||
| pos = i | |||||
| else: | |||||
| pos = i + ellipsis_occupied_dims - 1 | |||||
| if isinstance(ele_type, mstype.tensor_type): | |||||
| indexes_info[pos] = tensor_indexes_shapes[tensor_count] | |||||
| index_tensors_info[pos] = tensor_indexes_shapes[tensor_count] | |||||
| tensor_count += 1 | |||||
| elif isinstance(ele_type, mstype.slice_type): | |||||
| slice_obj = slice(slice_indexes[slice_count].start, | |||||
| slice_indexes[slice_count].end, | |||||
| slice_indexes[slice_count].step) | |||||
| # Use list to represent slicing result. | |||||
| indexes_info[pos] = list(range(data_shape[pos]))[slice_obj] | |||||
| slice_count += 1 | |||||
| elif isinstance(ele_type, mstype.ellipsis_type): | |||||
| if ellipsis_num != 0: | |||||
| raise IndexError(f"For '{op_name}', the index could only contain one ellipsis.") | |||||
| ellipsis_occupied_dims = data_rank - indexes_size + 1 | |||||
| for j in range(pos, pos + ellipsis_occupied_dims): | |||||
| # Use list to represent slicing result. | |||||
| indexes_info[j] = list(range(data_shape[j])) | |||||
| ellipsis_num += 1 | |||||
| else: | |||||
| raise IndexError(f"For '{op_name}', the index elements only support " | |||||
| f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.") | |||||
| broadcast_shape, final_shape, indexes_shapes_info = \ | |||||
| _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name) | |||||
| return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims | |||||
| def _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key: list): | |||||
| """Determine whether the tensor in the index appears continuously.""" | |||||
| for i in range(len(index_tensor_info_key) - 1): | |||||
| if index_tensor_info_key[i + 1] != index_tensor_info_key[i] + 1: | |||||
| return False | |||||
| return True | |||||
| def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name): | |||||
| """Derive the resulting shape information from the a tuple index of mixed tensors.""" | |||||
| index_tensor_info_key = list(index_tensors_info.keys()) | |||||
| index_tensor_info_value = list(index_tensors_info.values()) | |||||
| broadcast_shape = generate_broadcast_shape(index_tensor_info_value, op_name) | |||||
| final_shape = [] | |||||
| indexes_shapes_info = [] | |||||
| mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key) | |||||
| if mixed_tensors_continuous: | |||||
| tensor_shape_dealt = False | |||||
| for ele in indexes_info.values(): | |||||
| if isinstance(ele, list): | |||||
| final_shape.append(len(ele)) | |||||
| indexes_shapes_info.append(ele) | |||||
| elif isinstance(ele, tuple): | |||||
| if not tensor_shape_dealt: | |||||
| final_shape.extend(broadcast_shape) | |||||
| indexes_shapes_info.append(broadcast_shape) | |||||
| tensor_shape_dealt = True | |||||
| else: | |||||
| raise IndexError(f"For '{op_name}', the index elements only support " | |||||
| f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.") | |||||
| else: | |||||
| final_shape.extend(broadcast_shape) | |||||
| indexes_shapes_info.append(broadcast_shape) | |||||
| for ele in indexes_info.values(): | |||||
| if isinstance(ele, list): | |||||
| final_shape.append(len(ele)) | |||||
| indexes_shapes_info.append(ele) | |||||
| elif isinstance(ele, tuple): | |||||
| continue | |||||
| else: | |||||
| raise IndexError(f"For '{op_name}', the index elements only support " | |||||
| f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.") | |||||
| return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info) | |||||
| @constexpr | |||||
| def get_pos_of_int_index(indexes_types): | |||||
| """Get int index positions from the mixed tensors index which contains int, tensor, slice, and ellipsis.""" | |||||
| int_positions = [] | |||||
| for i, ele_type in enumerate(indexes_types): | |||||
| if ele_type == mstype.int32: | |||||
| int_positions.append(i) | |||||
| return int_positions | |||||
| @constexpr | |||||
| def separate_mixed_tensors_index(indexes_types, op_name): | |||||
| """Separate the position information of tensor and slice and ellipsis from the mixed tensors index.""" | |||||
| tensor_positions = [] | |||||
| slice_positions = [] | |||||
| ellipsis_position = None | |||||
| for i, ele_type in enumerate(indexes_types): | |||||
| if isinstance(ele_type, mstype.tensor_type): | |||||
| tensor_positions.append(i) | |||||
| elif isinstance(ele_type, mstype.slice_type): | |||||
| slice_positions.append(i) | |||||
| elif isinstance(ele_type, mstype.ellipsis_type): | |||||
| ellipsis_position = i | |||||
| else: | |||||
| raise IndexError(f"For '{op_name}', the index elements only support " | |||||
| f"'Tensor', 'int32', 'Slice', 'Ellipsis', but got {ele_type}.") | |||||
| return tensor_positions, slice_positions, ellipsis_position | |||||
| @constexpr | |||||
| def scalar_in_sequence(x, y): | |||||
| """Determine whether the scalar in the sequence.""" | |||||
| if x is None: | |||||
| raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, " | |||||
| "but the scalar is not.") | |||||
| if y is None: | |||||
| raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, " | |||||
| "but the sequence is not.") | |||||
| if x in y: | |||||
| return True | |||||
| return False | |||||
| @@ -14,11 +14,11 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Implementation for getitem.""" | """Implementation for getitem.""" | ||||
| from . import _utils as multi_utils | |||||
| from ..import base | |||||
| from . import _compile_utils as compile_utils | |||||
| from . import _constexpr_utils as const_utils | |||||
| from .. import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| from ....common import dtype as mstype | |||||
| getitem = base.MultitypeFuncGraph('getitem') | getitem = base.MultitypeFuncGraph('getitem') | ||||
| """ | """ | ||||
| @@ -227,7 +227,8 @@ def _tensor_getitem_by_tensor(data, tensor_index): | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type is same as the element type of data. | Tensor, element type is same as the element type of data. | ||||
| """ | """ | ||||
| check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64)) | |||||
| check_dtypes = const_utils.check_index_tensor_dtype(F.dtype(tensor_index), | |||||
| const_utils.TENSOR_GETITEM) | |||||
| result = None | result = None | ||||
| if check_dtypes: | if check_dtypes: | ||||
| result = F.gather(data, tensor_index, 0) | result = F.gather(data, tensor_index, 0) | ||||
| @@ -246,14 +247,13 @@ def _tensor_getitem_by_tuple(data, tuple_index): | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type is same as the element type of data. | Tensor, element type is same as the element type of data. | ||||
| """ | """ | ||||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_GETITEM) | |||||
| result = None | |||||
| if index_elements_type == multi_utils.NO_TENSOR: | |||||
| result = _tensor_slice(data, tuple_index) | |||||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||||
| result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index) | |||||
| return result | |||||
| indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) | |||||
| if index_elements_type == const_utils.NO_TENSOR: | |||||
| return _tensor_slice(data, tuple_index) | |||||
| if index_elements_type == const_utils.ALL_TENSOR: | |||||
| return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) | |||||
| return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index) | |||||
| @getitem.register("Tensor", "Ellipsis") | @getitem.register("Tensor", "Ellipsis") | ||||
| @@ -273,6 +273,17 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index): | |||||
| def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): | def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): | ||||
| """Tensor getitem by a tuple of tensor.""" | """Tensor getitem by a tuple of tensor.""" | ||||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM) | |||||
| indices = compile_utils.generate_indices_from_tuple_of_tensor(data, | |||||
| tuple_index, | |||||
| const_utils.TENSOR_GETITEM) | |||||
| result = F.gather_nd(data, indices) | |||||
| return result | |||||
| def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): | |||||
| """Tensor getitem by a tuple of mixed tensor.""" | |||||
| indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, | |||||
| tuple_index, | |||||
| const_utils.TENSOR_GETITEM) | |||||
| result = F.gather_nd(data, indices) | result = F.gather_nd(data, indices) | ||||
| return result | return result | ||||
| @@ -0,0 +1,101 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """in_impl""" | |||||
| from . import _constexpr_utils as const_utils | |||||
| from ... import functional as F | |||||
| from ...composite import base | |||||
| in_ = base.MultitypeFuncGraph("in") | |||||
| """ | |||||
| in_ is a metafuncgraph object which will determine if a in b | |||||
| using ".register" decorator | |||||
| """ | |||||
| @in_.register("Number", "Tuple") | |||||
| def _number_in_tuple(x, y): | |||||
| """ | |||||
| Determine if a number in tuple. | |||||
| Args: | |||||
| x (Number): x | |||||
| y (tuple): y | |||||
| Returns: | |||||
| bool, if x in y return true, x not in y return false. | |||||
| """ | |||||
| return const_utils.scalar_in_sequence(x, y) | |||||
| @in_.register("Number", "List") | |||||
| def _number_in_list(x, y): | |||||
| """ | |||||
| Determine if a number in list. | |||||
| Args: | |||||
| x (Number): x | |||||
| y (list): y | |||||
| Returns: | |||||
| bool, if x in y return true, x not in y return false. | |||||
| """ | |||||
| return const_utils.scalar_in_sequence(x, y) | |||||
| @in_.register("String", "Tuple") | |||||
| def _string_in_tuple(x, y): | |||||
| """ | |||||
| Determine if a str in a tuple. | |||||
| Args: | |||||
| x (str): x | |||||
| y (tuple): y | |||||
| Returns: | |||||
| bool, if x in y return true, x not in y return false. | |||||
| """ | |||||
| return const_utils.scalar_in_sequence(x, y) | |||||
| @in_.register("String", "List") | |||||
| def _string_in_list(x, y): | |||||
| """ | |||||
| Determine if a str in a list. | |||||
| Args: | |||||
| x (str): x | |||||
| y (list): y | |||||
| Returns: | |||||
| bool, if x in y return true, x not in y return false. | |||||
| """ | |||||
| return const_utils.scalar_in_sequence(x, y) | |||||
| @in_.register("String", "Dictionary") | |||||
| def _str_in_dict(x, y): | |||||
| """ | |||||
| Determine if a str in dict. | |||||
| Args: | |||||
| x: str | |||||
| y: dict | |||||
| Returns: | |||||
| bool, if x in y return true, x not in y return false. | |||||
| """ | |||||
| return F.in_dict(x, y) | |||||
| @@ -15,10 +15,11 @@ | |||||
| """Implementation for setitem.""" | """Implementation for setitem.""" | ||||
| from . import _compile_utils as compile_utils | |||||
| from . import _constexpr_utils as const_utils | |||||
| from ... import functional as F | |||||
| from ...composite import base | from ...composite import base | ||||
| from ....common import dtype as mstype | from ....common import dtype as mstype | ||||
| from ... import functional as F | |||||
| from . import _utils as multi_utils | |||||
| setitem = base.MultitypeFuncGraph('setitem') | setitem = base.MultitypeFuncGraph('setitem') | ||||
| @@ -139,8 +140,8 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor): | |||||
| Tensor, element type and shape is same as data. | Tensor, element type and shape is same as data. | ||||
| """ | """ | ||||
| index_dtype = F.dtype(index) | index_dtype = F.dtype(index) | ||||
| tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype) | |||||
| if tensor_dtype == multi_utils.INT_: | |||||
| tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) | |||||
| if tensor_dtype == const_utils.INT_: | |||||
| return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor) | return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor) | ||||
| return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor) | return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor) | ||||
| @@ -166,8 +167,8 @@ def _tensor_setitem_by_tensor_with_number(data, index, value): | |||||
| Tensor, element type and shape is same as data. | Tensor, element type and shape is same as data. | ||||
| """ | """ | ||||
| index_dtype = F.dtype(index) | index_dtype = F.dtype(index) | ||||
| tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype) | |||||
| if tensor_dtype == multi_utils.BOOL_: | |||||
| tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) | |||||
| if tensor_dtype == const_utils.BOOL_: | |||||
| return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value) | return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value) | ||||
| return _tensor_setitem_by_int_tensor_with_scalar(data, index, value) | return _tensor_setitem_by_int_tensor_with_scalar(data, index, value) | ||||
| @@ -190,17 +191,24 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value): | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type and shape is same as data. | Tensor, element type and shape is same as data. | ||||
| """ | """ | ||||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM) | |||||
| result = None | |||||
| if index_elements_type == multi_utils.NO_TENSOR: | |||||
| result = _tensor_assgin_number(data, tuple_index, value) | |||||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) | |||||
| updates = multi_utils.generate_updates_from_scalar(data, indices, value, | |||||
| multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||||
| result = F.scatter_nd_update(data, indices, updates) | |||||
| return result | |||||
| indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) | |||||
| if index_elements_type == const_utils.NO_TENSOR: | |||||
| return _tensor_assgin_number(data, tuple_index, value) | |||||
| if index_elements_type == const_utils.ALL_TENSOR: | |||||
| indices = compile_utils.generate_indices_from_tuple_of_tensor(data, | |||||
| tuple_index, | |||||
| const_utils.TENSOR_SETITEM) | |||||
| else: | |||||
| indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, | |||||
| tuple_index, | |||||
| const_utils.TENSOR_SETITEM) | |||||
| updates = compile_utils.generate_updates_from_scalar(data, | |||||
| indices, | |||||
| value, | |||||
| const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||||
| return F.scatter_nd_update(data, indices, updates) | |||||
| @setitem.register("Tensor", "Tuple", "Tensor") | @setitem.register("Tensor", "Tuple", "Tensor") | ||||
| @@ -221,17 +229,24 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type and shape is same as data. | Tensor, element type and shape is same as data. | ||||
| """ | """ | ||||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM) | |||||
| result = None | |||||
| if index_elements_type == multi_utils.NO_TENSOR: | |||||
| result = _tensor_assgin_tensor(data, tuple_index, value) | |||||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) | |||||
| updates = multi_utils.generate_updates_from_tensor(data, indices, value, | |||||
| multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||||
| result = F.scatter_nd_update(data, indices, updates) | |||||
| return result | |||||
| indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) | |||||
| if index_elements_type == const_utils.NO_TENSOR: | |||||
| return _tensor_assgin_tensor(data, tuple_index, value) | |||||
| if index_elements_type == const_utils.ALL_TENSOR: | |||||
| indices = compile_utils.generate_indices_from_tuple_of_tensor(data, | |||||
| tuple_index, | |||||
| const_utils.TENSOR_SETITEM) | |||||
| else: | |||||
| indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, | |||||
| tuple_index, | |||||
| const_utils.TENSOR_SETITEM) | |||||
| updates = compile_utils.generate_updates_from_tensor(data, | |||||
| indices, | |||||
| value, | |||||
| const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||||
| return F.scatter_nd_update(data, indices, updates) | |||||
| @setitem.register("Tensor", "Tuple", "Tuple") | @setitem.register("Tensor", "Tuple", "Tuple") | ||||
| @@ -253,15 +268,22 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type and shape is same as data. | Tensor, element type and shape is same as data. | ||||
| """ | """ | ||||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM) | |||||
| result = None | |||||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) | |||||
| updates = multi_utils.generate_updates_from_tuple(data, indices, value, | |||||
| multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||||
| result = F.scatter_nd_update(data, indices, updates) | |||||
| return result | |||||
| indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) | |||||
| if index_elements_type == const_utils.ALL_TENSOR: | |||||
| indices = compile_utils.generate_indices_from_tuple_of_tensor(data, | |||||
| tuple_index, | |||||
| const_utils.TENSOR_SETITEM) | |||||
| else: | |||||
| indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, | |||||
| tuple_index, | |||||
| const_utils.TENSOR_SETITEM) | |||||
| updates = compile_utils.generate_updates_from_tuple(data, | |||||
| indices, | |||||
| value, | |||||
| const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||||
| return F.scatter_nd_update(data, indices, updates) | |||||
| @setitem.register("Tensor", "Tensor", "Tuple") | @setitem.register("Tensor", "Tensor", "Tuple") | ||||
| @@ -278,7 +300,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value): | |||||
| Tensor, element type and shape is same as data. | Tensor, element type and shape is same as data. | ||||
| """ | """ | ||||
| index_dtype = F.dtype(index) | index_dtype = F.dtype(index) | ||||
| check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64)) | |||||
| check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM) | |||||
| result = None | result = None | ||||
| if check_dtype: | if check_dtype: | ||||
| result = _tensor_setitem_by_tensor_with_tuple(data, index, value) | result = _tensor_setitem_by_tensor_with_tuple(data, index, value) | ||||
| @@ -331,14 +353,14 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value): | |||||
| def _tensor_assgin_number(data, input_slice, value): | def _tensor_assgin_number(data, input_slice, value): | ||||
| """Givens a scalar assign to tensor by slice""" | """Givens a scalar assign to tensor by slice""" | ||||
| check_result = multi_utils.check_tensor_setitem_index(input_slice) | |||||
| check_result = const_utils.check_tensor_setitem_index(input_slice) | |||||
| result = None | result = None | ||||
| if check_result: | if check_result: | ||||
| data_shape = F.shape(data) | data_shape = F.shape(data) | ||||
| indices = multi_utils.slice2indices(input_slice, data_shape) | |||||
| is_tuple_int = multi_utils.tuple_element_is_int(input_slice) | |||||
| indices = const_utils.slice2indices(input_slice, data_shape) | |||||
| is_tuple_int = const_utils.tuple_element_is_int(input_slice) | |||||
| if is_tuple_int: | if is_tuple_int: | ||||
| indices = multi_utils.integer_to_indices(input_slice, data_shape) | |||||
| indices = const_utils.integer_to_indices(input_slice, data_shape) | |||||
| result = _tensor_indices_number(data, data_shape, input_slice, indices, value) | result = _tensor_indices_number(data, data_shape, input_slice, indices, value) | ||||
| return result | return result | ||||
| @@ -347,7 +369,7 @@ def _tensor_assgin_number(data, input_slice, value): | |||||
| def _tensor_setitem_with_int_v1(data, index, value): | def _tensor_setitem_with_int_v1(data, index, value): | ||||
| """Syntax: A[1] = 3""" | """Syntax: A[1] = 3""" | ||||
| data_shape = F.shape(data) | data_shape = F.shape(data) | ||||
| indices = multi_utils.integer_to_indices(index, data_shape) | |||||
| indices = const_utils.integer_to_indices(index, data_shape) | |||||
| return _tensor_indices_number(data, data_shape, index, indices, value) | return _tensor_indices_number(data, data_shape, index, indices, value) | ||||
| @@ -355,7 +377,7 @@ def _tensor_setitem_with_int_v1(data, index, value): | |||||
| def _tensor_setitem_with_int_v2(data, index, value): | def _tensor_setitem_with_int_v2(data, index, value): | ||||
| """Syntax: A[1] = Tensor""" | """Syntax: A[1] = Tensor""" | ||||
| data_shape = F.shape(data) | data_shape = F.shape(data) | ||||
| indices = multi_utils.integer_to_indices(index, data_shape) | |||||
| indices = const_utils.integer_to_indices(index, data_shape) | |||||
| return _tensor_indices_tensor(data, data_shape, index, indices, value) | return _tensor_indices_tensor(data, data_shape, index, indices, value) | ||||
| @@ -376,7 +398,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value): | |||||
| data_size = F.size(data) | data_size = F.size(data) | ||||
| value_shape = F.shape(value) | value_shape = F.shape(value) | ||||
| value_size = F.size(value) | value_size = F.size(value) | ||||
| check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) | |||||
| check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) | |||||
| if check_result: | if check_result: | ||||
| if data_size == value_size: | if data_size == value_size: | ||||
| result = F.reshape(value, data_shape) | result = F.reshape(value, data_shape) | ||||
| @@ -391,13 +413,13 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value): | |||||
| def _tensor_assgin_tensor(data, input_slice, value): | def _tensor_assgin_tensor(data, input_slice, value): | ||||
| """Assigns a tensor value to the tensor by slice.""" | """Assigns a tensor value to the tensor by slice.""" | ||||
| result = None | result = None | ||||
| check_result = multi_utils.check_tensor_setitem_index(input_slice) | |||||
| check_result = const_utils.check_tensor_setitem_index(input_slice) | |||||
| if check_result: | if check_result: | ||||
| data_shape = F.shape(data) | data_shape = F.shape(data) | ||||
| indices = multi_utils.slice2indices(input_slice, data_shape) | |||||
| is_tuple_int = multi_utils.tuple_element_is_int(input_slice) | |||||
| indices = const_utils.slice2indices(input_slice, data_shape) | |||||
| is_tuple_int = const_utils.tuple_element_is_int(input_slice) | |||||
| if is_tuple_int: | if is_tuple_int: | ||||
| indices = multi_utils.integer_to_indices(input_slice, data_shape) | |||||
| indices = const_utils.integer_to_indices(input_slice, data_shape) | |||||
| result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) | result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) | ||||
| return result | return result | ||||
| @@ -407,7 +429,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value): | |||||
| data_size = F.size(data) | data_size = F.size(data) | ||||
| data_dtype = F.dtype(data) | data_dtype = F.dtype(data) | ||||
| indices_size = F.size(indices) | indices_size = F.size(indices) | ||||
| indices_size = multi_utils.check_indices(indices_size, index) | |||||
| indices_size = const_utils.check_indices(indices_size, index) | |||||
| update = F.fill(mstype.int32, (indices_size,), 1) | update = F.fill(mstype.int32, (indices_size,), 1) | ||||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | condition_1d = F.scatter_nd(indices, update, (data_size,)) | ||||
| condition = F.reshape(condition_1d, data_shape) | condition = F.reshape(condition_1d, data_shape) | ||||
| @@ -415,7 +437,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value): | |||||
| value_fill = None | value_fill = None | ||||
| value_size = F.size(value) | value_size = F.size(value) | ||||
| value_size = multi_utils.check_indices_value_size(indices_size, value_size) | |||||
| value_size = const_utils.check_indices_value_size(indices_size, value_size) | |||||
| if value_size == 1: | if value_size == 1: | ||||
| value_fill = F.fill(data_dtype, (indices_size,), 1) | value_fill = F.fill(data_dtype, (indices_size,), 1) | ||||
| value = F.cast(value, data_dtype) | value = F.cast(value, data_dtype) | ||||
| @@ -432,7 +454,7 @@ def _tensor_indices_number(data, data_shape, index, indices, value): | |||||
| data_size = F.size(data) | data_size = F.size(data) | ||||
| data_dtype = F.dtype(data) | data_dtype = F.dtype(data) | ||||
| indices_size = F.size(indices) | indices_size = F.size(indices) | ||||
| indices_size = multi_utils.check_indices(indices_size, index) | |||||
| indices_size = const_utils.check_indices(indices_size, index) | |||||
| update = F.fill(mstype.int32, (indices_size,), 1) | update = F.fill(mstype.int32, (indices_size,), 1) | ||||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | condition_1d = F.scatter_nd(indices, update, (data_size,)) | ||||
| condition = F.reshape(condition_1d, data_shape) | condition = F.reshape(condition_1d, data_shape) | ||||
| @@ -445,16 +467,16 @@ def _tensor_indices_number(data, data_shape, index, indices, value): | |||||
| def _tensor_setitem_by_tensor_with_tuple(data, index, value): | def _tensor_setitem_by_tensor_with_tuple(data, index, value): | ||||
| """Set a tensor item by a tensor with a tuple.""" | """Set a tensor item by a tensor with a tuple.""" | ||||
| updates = multi_utils.generate_updates_from_tuple(data, index, value, | |||||
| multi_utils.SET_ITEM_BY_ONE_TENSOR) | |||||
| updates = compile_utils.generate_updates_from_tuple(data, index, value, | |||||
| const_utils.SET_ITEM_BY_ONE_TENSOR) | |||||
| result = F.scatter_update(data, index, updates) | result = F.scatter_update(data, index, updates) | ||||
| return result | return result | ||||
| def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): | def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): | ||||
| """Set a tensor item by a int tensor with a scalar.""" | """Set a tensor item by a int tensor with a scalar.""" | ||||
| updates = multi_utils.generate_updates_from_scalar(data, index, value, | |||||
| multi_utils.SET_ITEM_BY_ONE_TENSOR) | |||||
| updates = compile_utils.generate_updates_from_scalar(data, index, value, | |||||
| const_utils.SET_ITEM_BY_ONE_TENSOR) | |||||
| return F.scatter_update(data, index, updates) | return F.scatter_update(data, index, updates) | ||||
| @@ -462,7 +484,7 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): | |||||
| """Set a tensor item by a bool tensor with a scalar.""" | """Set a tensor item by a bool tensor with a scalar.""" | ||||
| index_shape = F.shape(index) | index_shape = F.shape(index) | ||||
| shape = F.shape(data) | shape = F.shape(data) | ||||
| shape = multi_utils.check_equal( | |||||
| shape = const_utils.check_equal( | |||||
| shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | ||||
| dtype = F.dtype(data) | dtype = F.dtype(data) | ||||
| u = F.fill(dtype, shape, value) | u = F.fill(dtype, shape, value) | ||||
| @@ -471,8 +493,8 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): | |||||
| def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): | def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): | ||||
| """Set a tensor item by a int tensor with a tensor.""" | """Set a tensor item by a int tensor with a tensor.""" | ||||
| updates = multi_utils.generate_updates_from_tensor(data, index, value, | |||||
| multi_utils.SET_ITEM_BY_ONE_TENSOR) | |||||
| updates = compile_utils.generate_updates_from_tensor(data, index, value, | |||||
| const_utils.SET_ITEM_BY_ONE_TENSOR) | |||||
| return F.scatter_update(data, index, updates) | return F.scatter_update(data, index, updates) | ||||
| @@ -480,10 +502,10 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value): | |||||
| """Set a tensor item by a bool tensor with a tensor.""" | """Set a tensor item by a bool tensor with a tensor.""" | ||||
| index_shape = F.shape(index) | index_shape = F.shape(index) | ||||
| data_shape = F.shape(data) | data_shape = F.shape(data) | ||||
| data_shape = multi_utils.check_equal(data_shape, index_shape, | |||||
| data_shape = const_utils.check_equal(data_shape, index_shape, | |||||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | ||||
| size = F.size(value) | size = F.size(value) | ||||
| size = multi_utils.check_equal(1, size, | |||||
| size = const_utils.check_equal(1, size, | |||||
| "When assign value is a tensor, its size should be {}, but current size is {}.") | "When assign value is a tensor, its size should be {}, but current size is {}.") | ||||
| dtype = F.dtype(data) | dtype = F.dtype(data) | ||||
| u_cast = F.cast(value, dtype) | u_cast = F.cast(value, dtype) | ||||
| @@ -1419,7 +1419,6 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name): | |||||
| validator.check_value_type("shape", x_shape, [tuple, list], prim_name) | validator.check_value_type("shape", x_shape, [tuple, list], prim_name) | ||||
| validator.check_integer("len of input_x", len(x_shape), 1, Rel.GT, prim_name) | validator.check_integer("len of input_x", len(x_shape), 1, Rel.GT, prim_name) | ||||
| validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name) | validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name) | ||||
| validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT, prim_name) | |||||
| rank_base = len(x_shape[0]) | rank_base = len(x_shape[0]) | ||||
| N = len(x_shape) | N = len(x_shape) | ||||
| out_shape = x_shape[0] | out_shape = x_shape[0] | ||||
| @@ -33,9 +33,4 @@ class IdentityEC(IExectorComponent): | |||||
| keyword.desc_inputs: self.inputs[keyword.desc_inputs], | keyword.desc_inputs: self.inputs[keyword.desc_inputs], | ||||
| keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs]) | keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs]) | ||||
| } | } | ||||
| print("buxue------------------------------------------------") | |||||
| print("inputs") | |||||
| print(ret[keyword.desc_inputs]) | |||||
| print("outputs") | |||||
| print(ret[keyword.result]) | |||||
| return ret | return ret | ||||
| @@ -19,9 +19,9 @@ import mindspore.nn as nn | |||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from ..ut_filter import non_graph_engine | |||||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||||
| from tests.ut.python.ut_filter import non_graph_engine | |||||
| from tests.mindspore_test_framework.mindspore_test import mindspore_test | |||||
| from tests.mindspore_test_framework.pipeline.forward.compile_forward \ | |||||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config | import pipeline_for_compile_forward_ge_graph_for_case_by_case_config | ||||
| @@ -133,7 +133,7 @@ def test_list_append_2(): | |||||
| class ListOperate(nn.Cell): | class ListOperate(nn.Cell): | ||||
| def __init__(self,): | |||||
| def __init__(self, ): | |||||
| super(ListOperate, self).__init__() | super(ListOperate, self).__init__() | ||||
| def construct(self, t, l): | def construct(self, t, l): | ||||
| @@ -152,6 +152,20 @@ class ListOperate(nn.Cell): | |||||
| return x | return x | ||||
| class InListNet(nn.Cell): | |||||
| def __init__(self, ): | |||||
| super(InListNet, self).__init__() | |||||
| self.list_ = [1, 2, 3, 4, 5, "ok"] | |||||
| def construct(self, x): | |||||
| ret = x | |||||
| if 2 in self.list_: | |||||
| ret = x + x | |||||
| if "ok" in self.list_: | |||||
| ret = x - x | |||||
| return ret | |||||
| class AxisListNet(nn.Cell): | class AxisListNet(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(AxisListNet, self).__init__() | super(AxisListNet, self).__init__() | ||||
| @@ -204,10 +218,15 @@ test_case_ops = [ | |||||
| ('AxisListDefault', { | ('AxisListDefault', { | ||||
| 'block': AxisListDefaultNet(), | 'block': AxisListDefaultNet(), | ||||
| 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), | 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), | ||||
| ('InList', { | |||||
| 'block': InListNet(), | |||||
| 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), | |||||
| ] | ] | ||||
| test_case_lists = [test_case_ops] | test_case_lists = [test_case_ops] | ||||
| test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists) | test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists) | ||||
| # use -k to select certain testcast | # use -k to select certain testcast | ||||
| # pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm | # pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm | ||||
| @@ -19,9 +19,9 @@ import mindspore.context as context | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import dtype as mstype | from mindspore import dtype as mstype | ||||
| from ..ut_filter import non_graph_engine | |||||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||||
| from tests.ut.python.ut_filter import non_graph_engine | |||||
| from tests.mindspore_test_framework.mindspore_test import mindspore_test | |||||
| from tests.mindspore_test_framework.pipeline.forward.compile_forward \ | |||||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config | import pipeline_for_compile_forward_ge_graph_for_case_by_case_config | ||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | ||||
| @@ -52,6 +52,20 @@ class NestTupleGraphNet(nn.Cell): | |||||
| return self.layers[0][1](x) | return self.layers[0][1](x) | ||||
| class InTupleNet(nn.Cell): | |||||
| def __init__(self, ): | |||||
| super(InTupleNet, self).__init__() | |||||
| self.tuple_ = (1, 2, 3, 4, 5, "ok") | |||||
| def construct(self, x): | |||||
| ret = x | |||||
| if 2 in self.tuple_: | |||||
| ret = x + x | |||||
| if "ok" in self.tuple_: | |||||
| ret = x - x | |||||
| return ret | |||||
| test_case_ops = [ | test_case_ops = [ | ||||
| ('TupleGraph', { | ('TupleGraph', { | ||||
| 'block': TupleGraphNet(), | 'block': TupleGraphNet(), | ||||
| @@ -59,6 +73,9 @@ test_case_ops = [ | |||||
| ('NestTupleGraph', { | ('NestTupleGraph', { | ||||
| 'block': NestTupleGraphNet(), | 'block': NestTupleGraphNet(), | ||||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), | 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), | ||||
| ('InTuple', { | |||||
| 'block': InTupleNet(), | |||||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}) | |||||
| ] | ] | ||||
| test_case_lists = [test_case_ops] | test_case_lists = [test_case_ops] | ||||
| @@ -176,12 +176,134 @@ class TensorGetItemByThreeTensors(Cell): | |||||
| return ret | return ret | ||||
| class TensorGetItemByMixedTensors(Cell): | |||||
| class TensorGetItemByMixedTensors_0(Cell): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(TensorGetItemByMixedTensors, self).__init__() | |||||
| super(TensorGetItemByMixedTensors_0, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1): | |||||
| ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensors_1(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensors_1, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1): | |||||
| ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensors_2(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensors_2, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1): | |||||
| ret = tensor[0, index_0, index_1, ..., 3] + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensors_3(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensors_3, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1): | |||||
| ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensors_4(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensors_4, self).__init__() | |||||
| self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1, index_2): | |||||
| ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensors_5(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensors_5, self).__init__() | |||||
| self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1, index_2): | |||||
| ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensors_6(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensors_6, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1, index_2): | |||||
| ret = tensor[..., index_0, index_1, index_2, 3] + self.const | |||||
| return ret | |||||
| class TensorSetItemByMixedTensors_0(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByMixedTensors_0, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32)) | |||||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), | |||||
| mstype.float32), | |||||
| name="x") | |||||
| self.value = value | |||||
| def construct(self, index_0, index_1, index_2): | |||||
| self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByMixedTensors_1(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByMixedTensors_1, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32)) | |||||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||||
| name="x") | |||||
| self.value = value | |||||
| def construct(self, index_0, index_1, index_2): | |||||
| self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByMixedTensors_2(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByMixedTensors_2, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32)) | |||||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||||
| name="x") | |||||
| self.value = value | |||||
| def construct(self, index_0, index_1, index_2): | |||||
| self.param[..., index_0, index_1, index_2, 3] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensorsTypeError(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensorsTypeError, self).__init__() | |||||
| def construct(self, x, index_0, index_1): | def construct(self, x, index_0, index_1): | ||||
| ret = x[index_0, index_1, 0:6] | |||||
| ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]] | |||||
| return ret | |||||
| class TensorGetItemByMixedTensorsNumberError(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensorsNumberError, self).__init__() | |||||
| def construct(self, x, index_0, index_1): | |||||
| ret = x[index_0, index_1, 0:3, ..., index_1, index_0] | |||||
| return ret | return ret | ||||
| @@ -189,7 +311,7 @@ class TensorSetItemByOneTensorWithNumber(Cell): | |||||
| def __init__(self, value): | def __init__(self, value): | ||||
| super(TensorSetItemByOneTensorWithNumber, self).__init__() | super(TensorSetItemByOneTensorWithNumber, self).__init__() | ||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | ||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = value | self.value = value | ||||
| def construct(self, index): | def construct(self, index): | ||||
| @@ -202,7 +324,7 @@ class TensorSetItemByOneTensorWithTensor(Cell): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(TensorSetItemByOneTensorWithTensor, self).__init__() | super(TensorSetItemByOneTensorWithTensor, self).__init__() | ||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | ||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index, value): | def construct(self, index, value): | ||||
| self.param[index] = value | self.param[index] = value | ||||
| @@ -214,7 +336,7 @@ class TensorSetItemByOneTensorWithTupleOfNumber(Cell): | |||||
| def __init__(self, value): | def __init__(self, value): | ||||
| super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__() | super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__() | ||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | ||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = value | self.value = value | ||||
| def construct(self, index): | def construct(self, index): | ||||
| @@ -227,7 +349,7 @@ class TensorSetItemByOneTensorWithTupleOfTensor(Cell): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__() | super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__() | ||||
| self.const = Tensor(np.ones((6, 3, 8)), mstype.float32) | self.const = Tensor(np.ones((6, 3, 8)), mstype.float32) | ||||
| self.param = Parameter(Tensor(np.arange(6*3*8).reshape((6, 3, 8)), mstype.float32), name="x") | |||||
| self.param = Parameter(Tensor(np.arange(6 * 3 * 8).reshape((6, 3, 8)), mstype.float32), name="x") | |||||
| def construct(self, index, value_0, value_1, value_2): | def construct(self, index, value_0, value_1, value_2): | ||||
| self.param[index] = (value_0, value_1, value_2) | self.param[index] = (value_0, value_1, value_2) | ||||
| @@ -239,7 +361,7 @@ class TensorSetItemByTensorsWithNumber(Cell): | |||||
| def __init__(self, value): | def __init__(self, value): | ||||
| super(TensorSetItemByTensorsWithNumber, self).__init__() | super(TensorSetItemByTensorsWithNumber, self).__init__() | ||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | ||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = value | self.value = value | ||||
| def construct(self, index_0, index_1, index_2): | def construct(self, index_0, index_1, index_2): | ||||
| @@ -252,7 +374,7 @@ class TensorSetItemByTensorsWithTensor(Cell): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(TensorSetItemByTensorsWithTensor, self).__init__() | super(TensorSetItemByTensorsWithTensor, self).__init__() | ||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | ||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index_0, index_1, index_2, value): | def construct(self, index_0, index_1, index_2, value): | ||||
| self.param[index_0, index_1, index_2] = value | self.param[index_0, index_1, index_2] = value | ||||
| @@ -264,7 +386,7 @@ class TensorSetItemByTensorsWithTensorNumberError(Cell): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(TensorSetItemByTensorsWithTensorNumberError, self).__init__() | super(TensorSetItemByTensorsWithTensorNumberError, self).__init__() | ||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | ||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index_0, index_1, index_2, index_3, value): | def construct(self, index_0, index_1, index_2, index_3, value): | ||||
| self.param[index_0, index_1, index_2, index_3] = value | self.param[index_0, index_1, index_2, index_3] = value | ||||
| @@ -276,7 +398,7 @@ class TensorSetItemByTensorsWithTupleOfNumber(Cell): | |||||
| def __init__(self, value): | def __init__(self, value): | ||||
| super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__() | super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__() | ||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | ||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = value | self.value = value | ||||
| def construct(self, index_0, index_1, index_2): | def construct(self, index_0, index_1, index_2): | ||||
| @@ -289,7 +411,7 @@ class TensorSetItemByTensorsWithTupleOfTensor(Cell): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__() | super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__() | ||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | ||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index_0, index_1, index_2, value_0, value_1, value_2): | def construct(self, index_0, index_1, index_2, value_0, value_1, value_2): | ||||
| self.param[index_0, index_1, index_2] = (value_0, value_1, value_2) | self.param[index_0, index_1, index_2] = (value_0, value_1, value_2) | ||||
| @@ -301,7 +423,7 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__() | super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__() | ||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | ||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index_0, index_1, index_2, value_0, value_1): | def construct(self, index_0, index_1, index_2, value_0, value_1): | ||||
| self.param[index_0, index_1, index_2] = (value_0, value_1) | self.param[index_0, index_1, index_2] = (value_0, value_1) | ||||
| @@ -313,7 +435,7 @@ class TensorSetItemByMixedTensors(Cell): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(TensorSetItemByMixedTensors, self).__init__() | super(TensorSetItemByMixedTensors, self).__init__() | ||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | ||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = 99.0 | self.value = 99.0 | ||||
| def construct(self, index_0, index_1): | def construct(self, index_0, index_1): | ||||
| @@ -538,11 +660,11 @@ def test_tensor_assign_bool_index(): | |||||
| net1(Ta, Tb, Tc, u_tensor) | net1(Ta, Tb, Tc, u_tensor) | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| net1(Ta, Td, Tc, u_tensor) | net1(Ta, Td, Tc, u_tensor) | ||||
| with pytest.raises(TypeError): | |||||
| with pytest.raises(IndexError): | |||||
| net1(Ta, u_tensor, Tc, u_tensor) | net1(Ta, u_tensor, Tc, u_tensor) | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| net1(Ta, Tb, Td, u_tensor) | net1(Ta, Tb, Td, u_tensor) | ||||
| with pytest.raises(TypeError): | |||||
| with pytest.raises(IndexError): | |||||
| net1(Ta, Tb, Ta, u_tensor) | net1(Ta, Tb, Ta, u_tensor) | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| net1(Ta, Tb, Tc, u_tensor_error) | net1(Ta, Tb, Tc, u_tensor_error) | ||||
| @@ -620,22 +742,67 @@ test_cases = [ | |||||
| }), | }), | ||||
| ('TensorGetItemByOneTensor', { | ('TensorGetItemByOneTensor', { | ||||
| 'block': TensorGetItemByOneTensor(), | 'block': TensorGetItemByOneTensor(), | ||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)], | Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)], | ||||
| }), | }), | ||||
| ('TensorGetItemByTwoTensors', { | ('TensorGetItemByTwoTensors', { | ||||
| 'block': TensorGetItemByTwoTensors(), | 'block': TensorGetItemByTwoTensors(), | ||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)], | Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)], | ||||
| }), | }), | ||||
| ('TensorGetItemByThreeTensors', { | ('TensorGetItemByThreeTensors', { | ||||
| 'block': TensorGetItemByThreeTensors(), | 'block': TensorGetItemByThreeTensors(), | ||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | ||||
| }), | }), | ||||
| ('TensorGetItemByMixedTensors_0', { | |||||
| 'block': TensorGetItemByMixedTensors_0(), | |||||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensors_1', { | |||||
| 'block': TensorGetItemByMixedTensors_1(), | |||||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensors_2', { | |||||
| 'block': TensorGetItemByMixedTensors_2(), | |||||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensors_3', { | |||||
| 'block': TensorGetItemByMixedTensors_3(), | |||||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensors_4', { | |||||
| 'block': TensorGetItemByMixedTensors_4(), | |||||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.float32), | |||||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensors_5', { | |||||
| 'block': TensorGetItemByMixedTensors_5(), | |||||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensors_6', { | |||||
| 'block': TensorGetItemByMixedTensors_6(), | |||||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByOneTensorWithNumber', { | ('TensorSetItemByOneTensorWithNumber', { | ||||
| 'block': TensorSetItemByOneTensorWithNumber(value=0.0), | 'block': TensorSetItemByOneTensorWithNumber(value=0.0), | ||||
| 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)], | 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)], | ||||
| @@ -683,46 +850,143 @@ test_cases = [ | |||||
| Tensor(np.zeros((4, 5)), mstype.float32), | Tensor(np.zeros((4, 5)), mstype.float32), | ||||
| Tensor(np.ones((4, 5)), mstype.float32), | Tensor(np.ones((4, 5)), mstype.float32), | ||||
| Tensor(np.ones((4, 5)) * 2, mstype.float32)], | Tensor(np.ones((4, 5)) * 2, mstype.float32)], | ||||
| }) | |||||
| }), | |||||
| ('TensorSetItemByMixedTensorsWithNumber_0', { | |||||
| 'block': TensorSetItemByMixedTensors_0(value=88.0), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByMixedTensorsWithTensor_0', { | |||||
| 'block': TensorSetItemByMixedTensors_0(value=Tensor(np.ones((4, 5, 3, 9), np.float32))), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensorsWithTupleOfNumber_0', { | |||||
| 'block': TensorSetItemByMixedTensors_0(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensorsWithTupleOfTensor_0', { | |||||
| 'block': TensorSetItemByMixedTensors_0(value=(Tensor(np.ones((4, 5, 3, 9), np.float32)), | |||||
| Tensor(np.zeros((4, 5, 3, 9), np.float32)), | |||||
| Tensor(np.ones((4, 5, 3, 9), np.float32)))), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByMixedTensorsWithNumber_1', { | |||||
| 'block': TensorSetItemByMixedTensors_1(value=88.0), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByMixedTensorsWithTensor_1', { | |||||
| 'block': TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensorsWithTupleOfNumber_1', { | |||||
| 'block': TensorSetItemByMixedTensors_1(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensorsWithTupleOfTensor_1', { | |||||
| 'block': TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)), | |||||
| Tensor(np.zeros((5, 2, 6), np.float32)), | |||||
| Tensor(np.ones((5, 2, 6), np.float32)), | |||||
| Tensor(np.ones((5, 2, 6), np.float32)))), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByMixedTensorsWithNumber_2', { | |||||
| 'block': TensorSetItemByMixedTensors_2(value=88.0), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByMixedTensorsWithTensor_2', { | |||||
| 'block': TensorSetItemByMixedTensors_2(value=Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensorsWithTupleOfNumber_2', { | |||||
| 'block': TensorSetItemByMixedTensors_2(value=(1.0, 2.0, 3.0, 4.0, 5.0)), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensorsWithTupleOfTensor_2', { | |||||
| 'block': TensorSetItemByMixedTensors_2(value=(Tensor(np.ones((4, 5), np.float32)), | |||||
| Tensor(np.zeros((4, 5), np.float32)), | |||||
| Tensor(np.ones((4, 5), np.float32)))), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ] | ] | ||||
| raise_error_set = [ | raise_error_set = [ | ||||
| ('TensorGetItemByOneTensorDtypeError', { | ('TensorGetItemByOneTensorDtypeError', { | ||||
| 'block': (TensorGetItemByOneTensor(), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| 'block': (TensorGetItemByOneTensor(), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)], | Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)], | ||||
| }), | }), | ||||
| ('TensorGetItemByTwoTensorsShapeError', { | ('TensorGetItemByTwoTensorsShapeError', { | ||||
| 'block': (TensorGetItemByTwoTensors(), {'exception': ValueError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| 'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)], | Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)], | ||||
| }), | }), | ||||
| ('TensorGetItemByTwoTensorsDtypeError', { | ('TensorGetItemByTwoTensorsDtypeError', { | ||||
| 'block': (TensorGetItemByTwoTensors(), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| 'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)], | Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)], | ||||
| }), | }), | ||||
| ('TensorGetItemByThreeTensorsShapeError', { | ('TensorGetItemByThreeTensorsShapeError', { | ||||
| 'block': (TensorGetItemByThreeTensors(), {'exception': ValueError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| 'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32), | Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)], | Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)], | ||||
| }), | }), | ||||
| ('TensorGetItemByThreeTensorsDtypeError', { | ('TensorGetItemByThreeTensorsDtypeError', { | ||||
| 'block': (TensorGetItemByThreeTensors(), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| 'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int64), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | ||||
| }), | }), | ||||
| ('TensorGetItemByMixedTensors', { | |||||
| 'block': (TensorGetItemByMixedTensors(), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| ('TensorGetItemByMixedTensorsNumberError', { | |||||
| 'block': (TensorGetItemByMixedTensorsNumberError(), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64)], | |||||
| Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensorsTypeError', { | |||||
| 'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensorsDtypeError', { | |||||
| 'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.float32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensorsShapeError', { | |||||
| 'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(2, 4, 5)), mstype.int32)], | |||||
| }), | }), | ||||
| ('TensorSetItemByOneTensorWithNumberTypeError', { | ('TensorSetItemByOneTensorWithNumberTypeError', { | ||||
| 'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}), | 'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}), | ||||
| @@ -760,21 +1024,21 @@ raise_error_set = [ | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | ||||
| }), | }), | ||||
| ('TensorSetItemByTensorsWithTensorShapeError', { | ('TensorSetItemByTensorsWithTensorShapeError', { | ||||
| 'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}), | |||||
| 'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | ||||
| Tensor(np.zeros((2, 5)), mstype.float32)], | Tensor(np.zeros((2, 5)), mstype.float32)], | ||||
| }), | }), | ||||
| ('TensorSetItemByTensorsWithTensorTypeError', { | ('TensorSetItemByTensorsWithTensorTypeError', { | ||||
| 'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}), | |||||
| 'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | ||||
| Tensor(np.zeros((4, 5)), mstype.int32)], | Tensor(np.zeros((4, 5)), mstype.int32)], | ||||
| }), | }), | ||||
| ('TensorSetItemByTensorsWithTensorNumberError', { | ('TensorSetItemByTensorsWithTensorNumberError', { | ||||
| 'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}), | |||||
| 'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | ||||
| @@ -782,19 +1046,19 @@ raise_error_set = [ | |||||
| Tensor(np.zeros((2, 5)), mstype.float32)], | Tensor(np.zeros((2, 5)), mstype.float32)], | ||||
| }), | }), | ||||
| ('TensorSetItemByTensorsWithTupleOfNumberTypeError', { | ('TensorSetItemByTensorsWithTupleOfNumberTypeError', { | ||||
| 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}), | |||||
| 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1, 2, 3, 4)), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | ||||
| }), | }), | ||||
| ('TensorSetItemByTensorsWithTupleOfNumberNumberError', { | ('TensorSetItemByTensorsWithTupleOfNumberNumberError', { | ||||
| 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}), | |||||
| 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | ||||
| }), | }), | ||||
| ('TensorSetItemByTensorsWithTupleOfTensorNumberError', { | ('TensorSetItemByTensorsWithTupleOfTensorNumberError', { | ||||
| 'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}), | |||||
| 'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | ||||
| @@ -802,7 +1066,7 @@ raise_error_set = [ | |||||
| Tensor(np.ones((4, 5)), mstype.float32)], | Tensor(np.ones((4, 5)), mstype.float32)], | ||||
| }), | }), | ||||
| ('TensorSetItemByTensorsWithTupleOfTensorTypeError', { | ('TensorSetItemByTensorsWithTupleOfTensorTypeError', { | ||||
| 'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}), | |||||
| 'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | ||||
| @@ -810,10 +1074,65 @@ raise_error_set = [ | |||||
| Tensor(np.ones((4, 5)), mstype.int32), | Tensor(np.ones((4, 5)), mstype.int32), | ||||
| Tensor(np.ones((4, 5)) * 2, mstype.int32)], | Tensor(np.ones((4, 5)) * 2, mstype.int32)], | ||||
| }), | }), | ||||
| ('TensorSetItemByMixedTensors', { | |||||
| 'block': (TensorSetItemByMixedTensors(), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)], | |||||
| ('TensorSetItemByMixedTensorsWithNumberValueTypeError', { | |||||
| 'block': (TensorSetItemByMixedTensors_1(value=88), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByMixedTensorsWithNumberIndexTypeError', { | |||||
| 'block': (TensorSetItemByMixedTensors_1(value=88.0), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.float32)], | |||||
| }), | |||||
| ('TensorSetItemByMixedTensorsWithTensorValueDtypeError', { | |||||
| 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.int32))), | |||||
| {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByMixedTensorsWithTensorValueShapeError', { | |||||
| 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((3, 2, 6), np.float32))), | |||||
| {'exception': ValueError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByMixedTensorsWithTensorIndexDtypeError', { | |||||
| 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))), | |||||
| {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.float32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensorsWithTupleOfNumberValueTypeError', { | |||||
| 'block': (TensorSetItemByMixedTensors_1(value=(1.0, 2, 3.0, 4.0, 5.0, 6.0)), | |||||
| {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensorsWithTupleOfTensorValueDtypeError', { | |||||
| 'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)), | |||||
| Tensor(np.zeros((5, 2, 6), np.float32)), | |||||
| Tensor(np.ones((5, 2, 6), np.float32)), | |||||
| Tensor(np.ones((5, 2, 6), np.int32)))), | |||||
| {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorGetItemByMixedTensorsWithTupleOfTensorIndexDtypeError', { | |||||
| 'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)), | |||||
| Tensor(np.zeros((5, 2, 6), np.float32)), | |||||
| Tensor(np.ones((5, 2, 6), np.float32)), | |||||
| Tensor(np.ones((5, 2, 6), np.int32)))), | |||||
| {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.float32), | |||||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||||
| }) | }) | ||||
| ] | ] | ||||