Merge pull request !1407 from zhangbuxue/support_mixed_tensor_for_tensor_get_item_and_tensor_set_itemtags/v0.5.0-beta
| @@ -105,7 +105,7 @@ convert_object_map = { | |||
| T.ge: multitype_ops.greater_equal, | |||
| T.is_: F.is_, | |||
| T.is_not: F.is_not, | |||
| T.contains: F.in_dict, | |||
| T.contains: multitype_ops.in_, | |||
| T.not_contains: F.not_in_dict, | |||
| # system function | |||
| @@ -474,6 +474,8 @@ REGISTER_PYBIND_DEFINE( | |||
| (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init()); | |||
| (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init()); | |||
| (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init()); | |||
| (void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init()); | |||
| (void)py::class_<Ellipsis, Type, std::shared_ptr<Ellipsis>>(m_sub, "Ellipsis").def(py::init()); | |||
| })); | |||
| const TypePtr kTypeExternal = std::make_shared<External>(); | |||
| @@ -95,6 +95,8 @@ string = typing.String() | |||
| type_refkey = typing.RefKeyType() | |||
| tensor_type = typing.TensorType | |||
| anything_type = typing.TypeAnything | |||
| slice_type = typing.Slice | |||
| ellipsis_type = typing.Ellipsis | |||
| number_type = (int8, | |||
| int16, | |||
| @@ -37,6 +37,7 @@ from .logical_and_impl import logical_and | |||
| from .logical_or_impl import logical_or | |||
| from .logic_not_impl import logical_not | |||
| from .uadd_impl import uadd | |||
| from .in_impl import in_ | |||
| __all__ = [ | |||
| 'add', | |||
| 'sub', | |||
| @@ -59,5 +60,6 @@ __all__ = [ | |||
| 'setitem', | |||
| 'logical_and', | |||
| 'logical_or', | |||
| 'logical_not' | |||
| 'logical_not', | |||
| 'in_' | |||
| ] | |||
| @@ -0,0 +1,154 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """constexpr util""" | |||
| from . import _constexpr_utils as const_utils | |||
| from ... import functional as F | |||
| from ... import operations as P | |||
| from ...composite import base | |||
| from ....common import dtype as mstype | |||
| hyper_map = base.HyperMap() | |||
| pack = P.Pack(axis=-1) | |||
| def broadcast(broadcast_shape, x): | |||
| """Broadcast tensor to the required shape.""" | |||
| if F.shape(x) == broadcast_shape: | |||
| return x | |||
| multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape) | |||
| if multiples: | |||
| return F.tile(x, multiples) | |||
| return x | |||
| def transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x): | |||
| """Transform indexing tensor to the required.""" | |||
| x = broadcast(broadcast_shape, x) | |||
| return broadcast(final_shape, F.reshape(x, new_shape)) | |||
| def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): | |||
| """Generate an indices tensor from a tuple of tensor.""" | |||
| indices = None | |||
| check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name) | |||
| if check_index_tensor_number: | |||
| dtype_tuple = hyper_map(F.dtype, tuple_index) | |||
| check_dtypes = const_utils.check_index_tensors_dtype(dtype_tuple, op_name) | |||
| if check_dtypes: | |||
| shape_tuple = hyper_map(F.shape, tuple_index) | |||
| broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name) | |||
| broadcast_tensors = hyper_map(F.partial(broadcast, broadcast_shape), tuple_index) | |||
| indices = pack(broadcast_tensors) | |||
| return indices | |||
| def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): | |||
| """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.""" | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| int_positions = const_utils.get_pos_of_int_index(indexes_types) | |||
| for i in int_positions: | |||
| tuple_index = F.tuple_setitem(tuple_index, i, F.scalar_to_tensor(tuple_index[i], mstype.int32)) | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| tensor_positions, slice_positions, ellipsis_position = \ | |||
| const_utils.separate_mixed_tensors_index(indexes_types, op_name) | |||
| tensor_indexes = [] | |||
| slice_indexes = [] | |||
| for i in tensor_positions: | |||
| tensor_indexes.append(tuple_index[i]) | |||
| for j in slice_positions: | |||
| slice_indexes.append(tuple_index[j]) | |||
| data_shape = F.shape(data) | |||
| tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) | |||
| tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) | |||
| broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \ | |||
| const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape, | |||
| indexes_types, | |||
| tensor_indexes_shapes, | |||
| tensor_indexes_dtypes, | |||
| slice_indexes, | |||
| op_name) | |||
| slice_number = 0 | |||
| final_index_tensors = [] | |||
| tuple_index_size = len(tuple_index) | |||
| index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) | |||
| for i in range(tuple_index_size): | |||
| if i in tensor_positions: | |||
| transform_tensor = transform_indexing_tensor(broadcast_shape, | |||
| final_shape, | |||
| index_tensor_new_shape, | |||
| tuple_index[i]) | |||
| final_index_tensors.append(transform_tensor) | |||
| if i in slice_positions: | |||
| slice_tensor = const_utils.convert_slice_to_tensor(slice_number, | |||
| final_shape, | |||
| indexes_shapes_info, | |||
| op_name) | |||
| final_index_tensors.append(slice_tensor) | |||
| slice_number += 1 | |||
| if i == ellipsis_position: | |||
| ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number, | |||
| ellipsis_occupied_dims, | |||
| final_shape, | |||
| indexes_shapes_info, | |||
| op_name) | |||
| for ele in ellipsis_tensors: | |||
| final_index_tensors.append(ele) | |||
| slice_number += ellipsis_occupied_dims | |||
| indices = pack(final_index_tensors) | |||
| return indices | |||
| def generate_updates_from_scalar(data, indices, value, op_type): | |||
| """Generate an updates tensor from a scalar.""" | |||
| data_shape = F.shape(data) | |||
| indices_shape = F.shape(indices) | |||
| data_dtype = F.dtype(data) | |||
| return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) | |||
| def generate_updates_from_tuple(data, index, value, op_type): | |||
| """Generate an updates tensor from a tuple.""" | |||
| value_types = hyper_map(F.typeof, value) | |||
| data_dtype = F.dtype(data) | |||
| value_elements_type = const_utils.check_value_elements(data_dtype, value_types) | |||
| if value_elements_type == const_utils.ALL_TENSOR: | |||
| value_shapes = hyper_map(F.shape, value) | |||
| shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM) | |||
| if shapes_same: | |||
| value = F.pack(value) | |||
| return generate_updates_from_tensor(data, index, value, op_type) | |||
| data_shape = F.shape(data) | |||
| index_shape = F.shape(index) | |||
| return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type) | |||
| def generate_updates_from_tensor(data, index, value, op_type): | |||
| """Generate an updates tensor from a tensor.""" | |||
| data_shape = F.shape(data) | |||
| index_shape = F.shape(index) | |||
| value_shape = F.shape(value) | |||
| data_dtype = F.dtype(data) | |||
| value_dtype = F.dtype(value) | |||
| updates_shape = value_shape | |||
| check_dtype_same = const_utils.check_tensors_dtype_same(data_dtype, value_dtype, const_utils.TENSOR_SETITEM) | |||
| if check_dtype_same: | |||
| updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type) | |||
| need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape) | |||
| if need_broadcast: | |||
| return broadcast(updates_shape, value) | |||
| return value | |||
| @@ -15,19 +15,15 @@ | |||
| """constexpr util""" | |||
| from functools import reduce | |||
| import numpy as np | |||
| from ...primitive import constexpr | |||
| from ....common.tensor import Tensor | |||
| from ....common import dtype as mstype | |||
| from .... import log as logger | |||
| from ...._extends.utils import Slice, Ellipsis_ | |||
| from ....common import dtype as mstype | |||
| from ....common.tensor import Tensor | |||
| from ....ops import _utils as op_utils | |||
| from ...composite import base | |||
| from .... import log as logger | |||
| from ... import functional as F | |||
| from ... import operations as P | |||
| hyper_map = base.HyperMap() | |||
| pack = P.Pack(axis=-1) | |||
| ALL_TENSOR = 0 | |||
| NO_TENSOR = 1 | |||
| @@ -264,7 +260,7 @@ def tuple_index_elements_type(types, op_name): | |||
| return ALL_TENSOR | |||
| if tensors_number == 0: | |||
| return NO_TENSOR | |||
| raise IndexError(f"For '{op_name}', the index does not support mixed tensor.") | |||
| return CONTAIN_TENSOR | |||
| @constexpr | |||
| @@ -279,12 +275,12 @@ def check_value_elements(data_dtype, types): | |||
| tensors_number += 1 | |||
| else: | |||
| raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' " | |||
| f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.") | |||
| f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.") | |||
| elif mstype.issubclass_(ele, data_dtype): | |||
| scalars_number += 1 | |||
| else: | |||
| raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in " | |||
| f"value tuple is not consistent with origin tensor data type '{data_dtype}'.") | |||
| f"value tuple is not consistent with assigned tensor data type '{data_dtype}'.") | |||
| if tensors_number == len(types): | |||
| return ALL_TENSOR | |||
| if scalars_number == len(types): | |||
| @@ -299,51 +295,46 @@ def get_index_tensor_dtype(dtype): | |||
| return INT_ | |||
| if dtype == mstype.bool_: | |||
| return BOOL_ | |||
| raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.") | |||
| raise IndexError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.") | |||
| @constexpr | |||
| def check_index_tensors_dtype(dtypes, op_name): | |||
| """Check a tuple of tensor data type.""" | |||
| if op_name == TENSOR_GETITEM: | |||
| valid_dtypes = (mstype.int32, mstype.int64) | |||
| elif op_name == TENSOR_SETITEM: | |||
| valid_dtypes = (mstype.int32,) | |||
| else: | |||
| raise ValueError("Unsupported operation.") | |||
| for ele in dtypes: | |||
| if ele in valid_dtypes and ele == dtypes[0]: | |||
| continue | |||
| raise TypeError(f"For '{op_name}', the index tensors data type must be same, " | |||
| f"and should be one of the following: {valid_dtypes}, but got {dtypes}.") | |||
| if not ele == mstype.int32: | |||
| raise IndexError(f"For '{op_name}', the all index tensor " | |||
| f"data types should be mstype.int32, but got {dtypes}.") | |||
| return True | |||
| @constexpr | |||
| def check_tensor_dtype_valid(dtype, valid_dtypes): | |||
| def check_index_tensor_dtype(dtype, op_name): | |||
| """Check a tensor data type.""" | |||
| if dtype in valid_dtypes: | |||
| if dtype == mstype.int32: | |||
| return True | |||
| raise TypeError(f"The index tensor data type must be one of " | |||
| f"the following: {valid_dtypes}, but got {dtype}.") | |||
| raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.") | |||
| @constexpr | |||
| def check_tensors_dtype_same(x_dtype, y_dtype, op_name): | |||
| def check_tensors_dtype_same(data_dtype, value_dtype, op_name): | |||
| """Check tensors data type same.""" | |||
| if x_dtype == y_dtype: | |||
| if value_dtype == data_dtype: | |||
| return True | |||
| raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' " | |||
| f"is not consistent with origin tensor data type {x_dtype}.") | |||
| raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' " | |||
| f"is not consistent with assigned tensor data type {data_dtype}.") | |||
| @constexpr | |||
| def broadcast_shapes(shapes, op_name): | |||
| """Broadcasts a tuple of tensor.""" | |||
| def generate_broadcast_shape(shapes, op_name): | |||
| """Generate broadcast shape for a tuple of shape.""" | |||
| broadcast_shape = shapes[0] | |||
| for i, shape in enumerate(shapes): | |||
| logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") | |||
| broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name) | |||
| try: | |||
| broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name) | |||
| except ValueError as ex: | |||
| raise IndexError(ex) | |||
| return tuple(broadcast_shape) | |||
| @@ -366,14 +357,82 @@ def check_two_shapes_need_broadcast(shape_x, shape_y): | |||
| @constexpr | |||
| def compute_multiples(origin_shape, broadcast_shape): | |||
| """Compute multiples between broadcast_shape with origin_shape.""" | |||
| """Compute multiples between origin shape with broadcast shape.""" | |||
| len_gap = len(broadcast_shape) - len(origin_shape) | |||
| return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape)) | |||
| def tile(broadcast_shape, x): | |||
| multiples = compute_multiples(F.shape(x), broadcast_shape) | |||
| return F.tile(x, multiples) | |||
| @constexpr | |||
| def compute_new_shape(origin_shape, indexes_shapes_info): | |||
| """Compute new shape between origin shape with final shape.""" | |||
| new_shape = [] | |||
| for i in indexes_shapes_info: | |||
| if i == origin_shape: | |||
| new_shape.extend(origin_shape) | |||
| else: | |||
| new_shape.append(1) | |||
| return tuple(new_shape) | |||
| @constexpr | |||
| def convert_ellipsis_to_tensors(slice_number, | |||
| ellipsis_occupied_dims, | |||
| final_shape, | |||
| indexes_shapes_info, | |||
| op_name): | |||
| """Convert an ellipsis to a list of tensor.""" | |||
| tensor_list = [] | |||
| dims_dealt_count = 0 | |||
| while dims_dealt_count < ellipsis_occupied_dims: | |||
| shape = [] | |||
| slice_count = 0 | |||
| array = None | |||
| for ele in indexes_shapes_info: | |||
| if isinstance(ele, list): | |||
| if slice_count == slice_number: | |||
| array = np.array(ele, np.int32) | |||
| shape.append(len(ele)) | |||
| else: | |||
| shape.append(1) | |||
| slice_count += 1 | |||
| if isinstance(ele, tuple): | |||
| shape.extend([1] * len(ele)) | |||
| if array is None: | |||
| raise ValueError(f"For '{op_name}', generate tensors from ellipsis failed.") | |||
| array = np.reshape(array, shape) | |||
| reps = compute_multiples(shape, final_shape) | |||
| tensor = Tensor(np.tile(array, reps)) | |||
| tensor_list.append(tensor) | |||
| slice_number += 1 | |||
| dims_dealt_count += 1 | |||
| return tensor_list | |||
| @constexpr | |||
| def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name): | |||
| """Convert a slice to a tensor.""" | |||
| shape = [] | |||
| count = 0 | |||
| array = None | |||
| for ele in indexes_shapes_info: | |||
| if isinstance(ele, list): | |||
| if count == slice_number: | |||
| array = np.array(ele, np.int32) | |||
| shape.append(len(ele)) | |||
| else: | |||
| # When the slice is not the slice looking for, the shape is filled with 1. | |||
| shape.append(1) | |||
| count += 1 | |||
| elif isinstance(ele, tuple): | |||
| shape.extend([1] * len(ele)) | |||
| else: | |||
| shape.append(1) | |||
| if array is None: | |||
| raise ValueError(f"For '{op_name}', generate tensor from 'slice' failed.") | |||
| array = np.reshape(array, shape) | |||
| reps = compute_multiples(shape, final_shape) | |||
| tensor = Tensor(np.tile(array, reps)) | |||
| return tensor | |||
| @constexpr | |||
| @@ -381,8 +440,8 @@ def check_shapes_same(value_shapes, op_name): | |||
| """Check if the shapes in the tuple are consistent.""" | |||
| for i, shape in enumerate(value_shapes): | |||
| if shape != value_shapes[0]: | |||
| raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple " | |||
| f"is not same as the first tensor shape.") | |||
| raise ValueError(f"For '{op_name}', the {i}th tensor shape in " | |||
| f"value tuple is not same as the first tensor shape.") | |||
| return True | |||
| @@ -396,7 +455,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty | |||
| if isinstance(value, mstype.dtype_to_pytype(data_dtype)): | |||
| return Tensor(np.full(updates_shape, value), dtype=data_dtype) | |||
| raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'" | |||
| f" is not consistent with tensor data type {data_dtype}.") | |||
| f" is not consistent with the assigned tensor data type {data_dtype}.") | |||
| @constexpr | |||
| @@ -404,8 +463,8 @@ def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value | |||
| """Convert a tuple of scalar to a tensor.""" | |||
| updates_shape = generate_updates_shape(data_shape, index_shape, op_type) | |||
| if len(value) != updates_shape[-1]: | |||
| raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} in the updates tuple " | |||
| f"does not meet the requirements: {updates_shape[-1]}.") | |||
| raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} " | |||
| f"in the updates tuple does not meet the requirements: {updates_shape[-1]}.") | |||
| array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype)) | |||
| reps = compute_multiples(updates_shape[-1:], updates_shape) | |||
| return Tensor(np.tile(array, reps)) | |||
| @@ -430,58 +489,145 @@ def check_number_of_index_tensor(data_shape, tuple_len, op_name): | |||
| f"is greater than the dimension {len(data_shape)} of the operated tensor.") | |||
| def generate_indeices_from_tuple_of_tensor(data, tuple_index, op_name): | |||
| """Generate an indices tensor from a tuple of tensor.""" | |||
| indices = None | |||
| check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name) | |||
| if check_index_tensor_number: | |||
| dtype_tuple = hyper_map(F.dtype, tuple_index) | |||
| check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name) | |||
| if check_dtypes: | |||
| shape_tuple = hyper_map(F.shape, tuple_index) | |||
| broadcast_shape = broadcast_shapes(shape_tuple, op_name) | |||
| broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index) | |||
| indices = pack(broadcast_tensors) | |||
| return indices | |||
| def generate_updates_from_scalar(data, indices, value, op_type): | |||
| """Generate an updates tensor from a scalar.""" | |||
| data_shape = F.shape(data) | |||
| indices_shape = F.shape(indices) | |||
| data_dtype = F.dtype(data) | |||
| return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) | |||
| def generate_updates_from_tuple(data, index, value, op_type): | |||
| """Generate an updates tensor from a tuple.""" | |||
| value_types = hyper_map(F.typeof, value) | |||
| data_dtype = F.dtype(data) | |||
| value_elements_type = check_value_elements(data_dtype, value_types) | |||
| if value_elements_type == ALL_TENSOR: | |||
| value_shapes = hyper_map(F.shape, value) | |||
| shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM) | |||
| if shapes_same: | |||
| value = F.pack(value) | |||
| return generate_updates_from_tensor(data, index, value, op_type) | |||
| data_shape = F.shape(data) | |||
| index_shape = F.shape(index) | |||
| return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type) | |||
| def generate_updates_from_tensor(data, index, value, op_type): | |||
| """Generate an updates tensor from a tensor.""" | |||
| data_shape = F.shape(data) | |||
| index_shape = F.shape(index) | |||
| value_shape = F.shape(value) | |||
| data_dtype = F.dtype(data) | |||
| value_dtype = F.dtype(value) | |||
| updates_shape = value_shape | |||
| check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM) | |||
| if check_dtype_same: | |||
| updates_shape = generate_updates_shape(data_shape, index_shape, op_type) | |||
| need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape) | |||
| if need_broadcast: | |||
| return tile(updates_shape, value) | |||
| return value | |||
| @constexpr | |||
| def generate_index_info_from_tuple_of_mixed_tensors(data_shape, | |||
| indexes_types, | |||
| tensor_indexes_shapes, | |||
| tensor_indexes_dtypes, | |||
| slice_indexes, | |||
| op_name): | |||
| """ | |||
| Generate index info which contain broadcast shape, final shape, | |||
| indexes shapes info, ellipsis size from a tuple of mixed tensors. | |||
| """ | |||
| check_index_tensors_dtype(tensor_indexes_dtypes, op_name) | |||
| data_rank = len(data_shape) | |||
| indexes_size = len(indexes_types) | |||
| if indexes_size > data_rank: | |||
| raise IndexError(f"For '{op_name}', the number {indexes_size} of index elements " | |||
| f"is greater than the dimension {len(data_shape)} of the operated tensor.") | |||
| indexes_info = {} | |||
| index_tensors_info = {} | |||
| ellipsis_num = 0 | |||
| ellipsis_occupied_dims = 0 | |||
| tensor_count = 0 | |||
| slice_count = 0 | |||
| for i, ele_type in enumerate(indexes_types): | |||
| if ellipsis_num == 0: | |||
| pos = i | |||
| else: | |||
| pos = i + ellipsis_occupied_dims - 1 | |||
| if isinstance(ele_type, mstype.tensor_type): | |||
| indexes_info[pos] = tensor_indexes_shapes[tensor_count] | |||
| index_tensors_info[pos] = tensor_indexes_shapes[tensor_count] | |||
| tensor_count += 1 | |||
| elif isinstance(ele_type, mstype.slice_type): | |||
| slice_obj = slice(slice_indexes[slice_count].start, | |||
| slice_indexes[slice_count].end, | |||
| slice_indexes[slice_count].step) | |||
| # Use list to represent slicing result. | |||
| indexes_info[pos] = list(range(data_shape[pos]))[slice_obj] | |||
| slice_count += 1 | |||
| elif isinstance(ele_type, mstype.ellipsis_type): | |||
| if ellipsis_num != 0: | |||
| raise IndexError(f"For '{op_name}', the index could only contain one ellipsis.") | |||
| ellipsis_occupied_dims = data_rank - indexes_size + 1 | |||
| for j in range(pos, pos + ellipsis_occupied_dims): | |||
| # Use list to represent slicing result. | |||
| indexes_info[j] = list(range(data_shape[j])) | |||
| ellipsis_num += 1 | |||
| else: | |||
| raise IndexError(f"For '{op_name}', the index elements only support " | |||
| f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.") | |||
| broadcast_shape, final_shape, indexes_shapes_info = \ | |||
| _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name) | |||
| return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims | |||
| def _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key: list): | |||
| """Determine whether the tensor in the index appears continuously.""" | |||
| for i in range(len(index_tensor_info_key) - 1): | |||
| if index_tensor_info_key[i + 1] != index_tensor_info_key[i] + 1: | |||
| return False | |||
| return True | |||
| def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name): | |||
| """Derive the resulting shape information from the a tuple index of mixed tensors.""" | |||
| index_tensor_info_key = list(index_tensors_info.keys()) | |||
| index_tensor_info_value = list(index_tensors_info.values()) | |||
| broadcast_shape = generate_broadcast_shape(index_tensor_info_value, op_name) | |||
| final_shape = [] | |||
| indexes_shapes_info = [] | |||
| mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key) | |||
| if mixed_tensors_continuous: | |||
| tensor_shape_dealt = False | |||
| for ele in indexes_info.values(): | |||
| if isinstance(ele, list): | |||
| final_shape.append(len(ele)) | |||
| indexes_shapes_info.append(ele) | |||
| elif isinstance(ele, tuple): | |||
| if not tensor_shape_dealt: | |||
| final_shape.extend(broadcast_shape) | |||
| indexes_shapes_info.append(broadcast_shape) | |||
| tensor_shape_dealt = True | |||
| else: | |||
| raise IndexError(f"For '{op_name}', the index elements only support " | |||
| f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.") | |||
| else: | |||
| final_shape.extend(broadcast_shape) | |||
| indexes_shapes_info.append(broadcast_shape) | |||
| for ele in indexes_info.values(): | |||
| if isinstance(ele, list): | |||
| final_shape.append(len(ele)) | |||
| indexes_shapes_info.append(ele) | |||
| elif isinstance(ele, tuple): | |||
| continue | |||
| else: | |||
| raise IndexError(f"For '{op_name}', the index elements only support " | |||
| f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.") | |||
| return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info) | |||
| @constexpr | |||
| def get_pos_of_int_index(indexes_types): | |||
| """Get int index positions from the mixed tensors index which contains int, tensor, slice, and ellipsis.""" | |||
| int_positions = [] | |||
| for i, ele_type in enumerate(indexes_types): | |||
| if ele_type == mstype.int32: | |||
| int_positions.append(i) | |||
| return int_positions | |||
| @constexpr | |||
| def separate_mixed_tensors_index(indexes_types, op_name): | |||
| """Separate the position information of tensor and slice and ellipsis from the mixed tensors index.""" | |||
| tensor_positions = [] | |||
| slice_positions = [] | |||
| ellipsis_position = None | |||
| for i, ele_type in enumerate(indexes_types): | |||
| if isinstance(ele_type, mstype.tensor_type): | |||
| tensor_positions.append(i) | |||
| elif isinstance(ele_type, mstype.slice_type): | |||
| slice_positions.append(i) | |||
| elif isinstance(ele_type, mstype.ellipsis_type): | |||
| ellipsis_position = i | |||
| else: | |||
| raise IndexError(f"For '{op_name}', the index elements only support " | |||
| f"'Tensor', 'int32', 'Slice', 'Ellipsis', but got {ele_type}.") | |||
| return tensor_positions, slice_positions, ellipsis_position | |||
| @constexpr | |||
| def scalar_in_sequence(x, y): | |||
| """Determine whether the scalar in the sequence.""" | |||
| if x is None: | |||
| raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, " | |||
| "but the scalar is not.") | |||
| if y is None: | |||
| raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, " | |||
| "but the sequence is not.") | |||
| if x in y: | |||
| return True | |||
| return False | |||
| @@ -14,11 +14,11 @@ | |||
| # ============================================================================ | |||
| """Implementation for getitem.""" | |||
| from . import _utils as multi_utils | |||
| from ..import base | |||
| from . import _compile_utils as compile_utils | |||
| from . import _constexpr_utils as const_utils | |||
| from .. import base | |||
| from ... import functional as F | |||
| from ....common import dtype as mstype | |||
| getitem = base.MultitypeFuncGraph('getitem') | |||
| """ | |||
| @@ -227,7 +227,8 @@ def _tensor_getitem_by_tensor(data, tensor_index): | |||
| Outputs: | |||
| Tensor, element type is same as the element type of data. | |||
| """ | |||
| check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64)) | |||
| check_dtypes = const_utils.check_index_tensor_dtype(F.dtype(tensor_index), | |||
| const_utils.TENSOR_GETITEM) | |||
| result = None | |||
| if check_dtypes: | |||
| result = F.gather(data, tensor_index, 0) | |||
| @@ -246,14 +247,13 @@ def _tensor_getitem_by_tuple(data, tuple_index): | |||
| Outputs: | |||
| Tensor, element type is same as the element type of data. | |||
| """ | |||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_GETITEM) | |||
| result = None | |||
| if index_elements_type == multi_utils.NO_TENSOR: | |||
| result = _tensor_slice(data, tuple_index) | |||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||
| result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index) | |||
| return result | |||
| indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) | |||
| if index_elements_type == const_utils.NO_TENSOR: | |||
| return _tensor_slice(data, tuple_index) | |||
| if index_elements_type == const_utils.ALL_TENSOR: | |||
| return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) | |||
| return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index) | |||
| @getitem.register("Tensor", "Ellipsis") | |||
| @@ -273,6 +273,17 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index): | |||
| def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): | |||
| """Tensor getitem by a tuple of tensor.""" | |||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM) | |||
| indices = compile_utils.generate_indices_from_tuple_of_tensor(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_GETITEM) | |||
| result = F.gather_nd(data, indices) | |||
| return result | |||
| def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): | |||
| """Tensor getitem by a tuple of mixed tensor.""" | |||
| indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_GETITEM) | |||
| result = F.gather_nd(data, indices) | |||
| return result | |||
| @@ -0,0 +1,101 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """in_impl""" | |||
| from . import _constexpr_utils as const_utils | |||
| from ... import functional as F | |||
| from ...composite import base | |||
| in_ = base.MultitypeFuncGraph("in") | |||
| """ | |||
| in_ is a metafuncgraph object which will determine if a in b | |||
| using ".register" decorator | |||
| """ | |||
| @in_.register("Number", "Tuple") | |||
| def _number_in_tuple(x, y): | |||
| """ | |||
| Determine if a number in tuple. | |||
| Args: | |||
| x (Number): x | |||
| y (tuple): y | |||
| Returns: | |||
| bool, if x in y return true, x not in y return false. | |||
| """ | |||
| return const_utils.scalar_in_sequence(x, y) | |||
| @in_.register("Number", "List") | |||
| def _number_in_list(x, y): | |||
| """ | |||
| Determine if a number in list. | |||
| Args: | |||
| x (Number): x | |||
| y (list): y | |||
| Returns: | |||
| bool, if x in y return true, x not in y return false. | |||
| """ | |||
| return const_utils.scalar_in_sequence(x, y) | |||
| @in_.register("String", "Tuple") | |||
| def _string_in_tuple(x, y): | |||
| """ | |||
| Determine if a str in a tuple. | |||
| Args: | |||
| x (str): x | |||
| y (tuple): y | |||
| Returns: | |||
| bool, if x in y return true, x not in y return false. | |||
| """ | |||
| return const_utils.scalar_in_sequence(x, y) | |||
| @in_.register("String", "List") | |||
| def _string_in_list(x, y): | |||
| """ | |||
| Determine if a str in a list. | |||
| Args: | |||
| x (str): x | |||
| y (list): y | |||
| Returns: | |||
| bool, if x in y return true, x not in y return false. | |||
| """ | |||
| return const_utils.scalar_in_sequence(x, y) | |||
| @in_.register("String", "Dictionary") | |||
| def _str_in_dict(x, y): | |||
| """ | |||
| Determine if a str in dict. | |||
| Args: | |||
| x: str | |||
| y: dict | |||
| Returns: | |||
| bool, if x in y return true, x not in y return false. | |||
| """ | |||
| return F.in_dict(x, y) | |||
| @@ -15,10 +15,11 @@ | |||
| """Implementation for setitem.""" | |||
| from . import _compile_utils as compile_utils | |||
| from . import _constexpr_utils as const_utils | |||
| from ... import functional as F | |||
| from ...composite import base | |||
| from ....common import dtype as mstype | |||
| from ... import functional as F | |||
| from . import _utils as multi_utils | |||
| setitem = base.MultitypeFuncGraph('setitem') | |||
| @@ -139,8 +140,8 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor): | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| index_dtype = F.dtype(index) | |||
| tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype) | |||
| if tensor_dtype == multi_utils.INT_: | |||
| tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) | |||
| if tensor_dtype == const_utils.INT_: | |||
| return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor) | |||
| return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor) | |||
| @@ -166,8 +167,8 @@ def _tensor_setitem_by_tensor_with_number(data, index, value): | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| index_dtype = F.dtype(index) | |||
| tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype) | |||
| if tensor_dtype == multi_utils.BOOL_: | |||
| tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) | |||
| if tensor_dtype == const_utils.BOOL_: | |||
| return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value) | |||
| return _tensor_setitem_by_int_tensor_with_scalar(data, index, value) | |||
| @@ -190,17 +191,24 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM) | |||
| result = None | |||
| if index_elements_type == multi_utils.NO_TENSOR: | |||
| result = _tensor_assgin_number(data, tuple_index, value) | |||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) | |||
| updates = multi_utils.generate_updates_from_scalar(data, indices, value, | |||
| multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| result = F.scatter_nd_update(data, indices, updates) | |||
| return result | |||
| indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if index_elements_type == const_utils.NO_TENSOR: | |||
| return _tensor_assgin_number(data, tuple_index, value) | |||
| if index_elements_type == const_utils.ALL_TENSOR: | |||
| indices = compile_utils.generate_indices_from_tuple_of_tensor(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| else: | |||
| indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| updates = compile_utils.generate_updates_from_scalar(data, | |||
| indices, | |||
| value, | |||
| const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| return F.scatter_nd_update(data, indices, updates) | |||
| @setitem.register("Tensor", "Tuple", "Tensor") | |||
| @@ -221,17 +229,24 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM) | |||
| result = None | |||
| if index_elements_type == multi_utils.NO_TENSOR: | |||
| result = _tensor_assgin_tensor(data, tuple_index, value) | |||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) | |||
| updates = multi_utils.generate_updates_from_tensor(data, indices, value, | |||
| multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| result = F.scatter_nd_update(data, indices, updates) | |||
| return result | |||
| indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if index_elements_type == const_utils.NO_TENSOR: | |||
| return _tensor_assgin_tensor(data, tuple_index, value) | |||
| if index_elements_type == const_utils.ALL_TENSOR: | |||
| indices = compile_utils.generate_indices_from_tuple_of_tensor(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| else: | |||
| indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| updates = compile_utils.generate_updates_from_tensor(data, | |||
| indices, | |||
| value, | |||
| const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| return F.scatter_nd_update(data, indices, updates) | |||
| @setitem.register("Tensor", "Tuple", "Tuple") | |||
| @@ -253,15 +268,22 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM) | |||
| result = None | |||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) | |||
| updates = multi_utils.generate_updates_from_tuple(data, indices, value, | |||
| multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| result = F.scatter_nd_update(data, indices, updates) | |||
| return result | |||
| indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if index_elements_type == const_utils.ALL_TENSOR: | |||
| indices = compile_utils.generate_indices_from_tuple_of_tensor(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| else: | |||
| indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| updates = compile_utils.generate_updates_from_tuple(data, | |||
| indices, | |||
| value, | |||
| const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| return F.scatter_nd_update(data, indices, updates) | |||
| @setitem.register("Tensor", "Tensor", "Tuple") | |||
| @@ -278,7 +300,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value): | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| index_dtype = F.dtype(index) | |||
| check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64)) | |||
| check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM) | |||
| result = None | |||
| if check_dtype: | |||
| result = _tensor_setitem_by_tensor_with_tuple(data, index, value) | |||
| @@ -331,14 +353,14 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value): | |||
| def _tensor_assgin_number(data, input_slice, value): | |||
| """Givens a scalar assign to tensor by slice""" | |||
| check_result = multi_utils.check_tensor_setitem_index(input_slice) | |||
| check_result = const_utils.check_tensor_setitem_index(input_slice) | |||
| result = None | |||
| if check_result: | |||
| data_shape = F.shape(data) | |||
| indices = multi_utils.slice2indices(input_slice, data_shape) | |||
| is_tuple_int = multi_utils.tuple_element_is_int(input_slice) | |||
| indices = const_utils.slice2indices(input_slice, data_shape) | |||
| is_tuple_int = const_utils.tuple_element_is_int(input_slice) | |||
| if is_tuple_int: | |||
| indices = multi_utils.integer_to_indices(input_slice, data_shape) | |||
| indices = const_utils.integer_to_indices(input_slice, data_shape) | |||
| result = _tensor_indices_number(data, data_shape, input_slice, indices, value) | |||
| return result | |||
| @@ -347,7 +369,7 @@ def _tensor_assgin_number(data, input_slice, value): | |||
| def _tensor_setitem_with_int_v1(data, index, value): | |||
| """Syntax: A[1] = 3""" | |||
| data_shape = F.shape(data) | |||
| indices = multi_utils.integer_to_indices(index, data_shape) | |||
| indices = const_utils.integer_to_indices(index, data_shape) | |||
| return _tensor_indices_number(data, data_shape, index, indices, value) | |||
| @@ -355,7 +377,7 @@ def _tensor_setitem_with_int_v1(data, index, value): | |||
| def _tensor_setitem_with_int_v2(data, index, value): | |||
| """Syntax: A[1] = Tensor""" | |||
| data_shape = F.shape(data) | |||
| indices = multi_utils.integer_to_indices(index, data_shape) | |||
| indices = const_utils.integer_to_indices(index, data_shape) | |||
| return _tensor_indices_tensor(data, data_shape, index, indices, value) | |||
| @@ -376,7 +398,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value): | |||
| data_size = F.size(data) | |||
| value_shape = F.shape(value) | |||
| value_size = F.size(value) | |||
| check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) | |||
| check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) | |||
| if check_result: | |||
| if data_size == value_size: | |||
| result = F.reshape(value, data_shape) | |||
| @@ -391,13 +413,13 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value): | |||
| def _tensor_assgin_tensor(data, input_slice, value): | |||
| """Assigns a tensor value to the tensor by slice.""" | |||
| result = None | |||
| check_result = multi_utils.check_tensor_setitem_index(input_slice) | |||
| check_result = const_utils.check_tensor_setitem_index(input_slice) | |||
| if check_result: | |||
| data_shape = F.shape(data) | |||
| indices = multi_utils.slice2indices(input_slice, data_shape) | |||
| is_tuple_int = multi_utils.tuple_element_is_int(input_slice) | |||
| indices = const_utils.slice2indices(input_slice, data_shape) | |||
| is_tuple_int = const_utils.tuple_element_is_int(input_slice) | |||
| if is_tuple_int: | |||
| indices = multi_utils.integer_to_indices(input_slice, data_shape) | |||
| indices = const_utils.integer_to_indices(input_slice, data_shape) | |||
| result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) | |||
| return result | |||
| @@ -407,7 +429,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value): | |||
| data_size = F.size(data) | |||
| data_dtype = F.dtype(data) | |||
| indices_size = F.size(indices) | |||
| indices_size = multi_utils.check_indices(indices_size, index) | |||
| indices_size = const_utils.check_indices(indices_size, index) | |||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||
| condition = F.reshape(condition_1d, data_shape) | |||
| @@ -415,7 +437,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value): | |||
| value_fill = None | |||
| value_size = F.size(value) | |||
| value_size = multi_utils.check_indices_value_size(indices_size, value_size) | |||
| value_size = const_utils.check_indices_value_size(indices_size, value_size) | |||
| if value_size == 1: | |||
| value_fill = F.fill(data_dtype, (indices_size,), 1) | |||
| value = F.cast(value, data_dtype) | |||
| @@ -432,7 +454,7 @@ def _tensor_indices_number(data, data_shape, index, indices, value): | |||
| data_size = F.size(data) | |||
| data_dtype = F.dtype(data) | |||
| indices_size = F.size(indices) | |||
| indices_size = multi_utils.check_indices(indices_size, index) | |||
| indices_size = const_utils.check_indices(indices_size, index) | |||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||
| condition = F.reshape(condition_1d, data_shape) | |||
| @@ -445,16 +467,16 @@ def _tensor_indices_number(data, data_shape, index, indices, value): | |||
| def _tensor_setitem_by_tensor_with_tuple(data, index, value): | |||
| """Set a tensor item by a tensor with a tuple.""" | |||
| updates = multi_utils.generate_updates_from_tuple(data, index, value, | |||
| multi_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| updates = compile_utils.generate_updates_from_tuple(data, index, value, | |||
| const_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| result = F.scatter_update(data, index, updates) | |||
| return result | |||
| def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): | |||
| """Set a tensor item by a int tensor with a scalar.""" | |||
| updates = multi_utils.generate_updates_from_scalar(data, index, value, | |||
| multi_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| updates = compile_utils.generate_updates_from_scalar(data, index, value, | |||
| const_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| return F.scatter_update(data, index, updates) | |||
| @@ -462,7 +484,7 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): | |||
| """Set a tensor item by a bool tensor with a scalar.""" | |||
| index_shape = F.shape(index) | |||
| shape = F.shape(data) | |||
| shape = multi_utils.check_equal( | |||
| shape = const_utils.check_equal( | |||
| shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||
| dtype = F.dtype(data) | |||
| u = F.fill(dtype, shape, value) | |||
| @@ -471,8 +493,8 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): | |||
| def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): | |||
| """Set a tensor item by a int tensor with a tensor.""" | |||
| updates = multi_utils.generate_updates_from_tensor(data, index, value, | |||
| multi_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| updates = compile_utils.generate_updates_from_tensor(data, index, value, | |||
| const_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| return F.scatter_update(data, index, updates) | |||
| @@ -480,10 +502,10 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value): | |||
| """Set a tensor item by a bool tensor with a tensor.""" | |||
| index_shape = F.shape(index) | |||
| data_shape = F.shape(data) | |||
| data_shape = multi_utils.check_equal(data_shape, index_shape, | |||
| data_shape = const_utils.check_equal(data_shape, index_shape, | |||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||
| size = F.size(value) | |||
| size = multi_utils.check_equal(1, size, | |||
| size = const_utils.check_equal(1, size, | |||
| "When assign value is a tensor, its size should be {}, but current size is {}.") | |||
| dtype = F.dtype(data) | |||
| u_cast = F.cast(value, dtype) | |||
| @@ -1419,7 +1419,6 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name): | |||
| validator.check_value_type("shape", x_shape, [tuple, list], prim_name) | |||
| validator.check_integer("len of input_x", len(x_shape), 1, Rel.GT, prim_name) | |||
| validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name) | |||
| validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT, prim_name) | |||
| rank_base = len(x_shape[0]) | |||
| N = len(x_shape) | |||
| out_shape = x_shape[0] | |||
| @@ -33,9 +33,4 @@ class IdentityEC(IExectorComponent): | |||
| keyword.desc_inputs: self.inputs[keyword.desc_inputs], | |||
| keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs]) | |||
| } | |||
| print("buxue------------------------------------------------") | |||
| print("inputs") | |||
| print(ret[keyword.desc_inputs]) | |||
| print("outputs") | |||
| print(ret[keyword.result]) | |||
| return ret | |||
| @@ -19,9 +19,9 @@ import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from ..ut_filter import non_graph_engine | |||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| from tests.ut.python.ut_filter import non_graph_engine | |||
| from tests.mindspore_test_framework.mindspore_test import mindspore_test | |||
| from tests.mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config | |||
| @@ -133,7 +133,7 @@ def test_list_append_2(): | |||
| class ListOperate(nn.Cell): | |||
| def __init__(self,): | |||
| def __init__(self, ): | |||
| super(ListOperate, self).__init__() | |||
| def construct(self, t, l): | |||
| @@ -152,6 +152,20 @@ class ListOperate(nn.Cell): | |||
| return x | |||
| class InListNet(nn.Cell): | |||
| def __init__(self, ): | |||
| super(InListNet, self).__init__() | |||
| self.list_ = [1, 2, 3, 4, 5, "ok"] | |||
| def construct(self, x): | |||
| ret = x | |||
| if 2 in self.list_: | |||
| ret = x + x | |||
| if "ok" in self.list_: | |||
| ret = x - x | |||
| return ret | |||
| class AxisListNet(nn.Cell): | |||
| def __init__(self): | |||
| super(AxisListNet, self).__init__() | |||
| @@ -204,10 +218,15 @@ test_case_ops = [ | |||
| ('AxisListDefault', { | |||
| 'block': AxisListDefaultNet(), | |||
| 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), | |||
| ('InList', { | |||
| 'block': InListNet(), | |||
| 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), | |||
| ] | |||
| test_case_lists = [test_case_ops] | |||
| test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists) | |||
| # use -k to select certain testcast | |||
| # pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm | |||
| @@ -19,9 +19,9 @@ import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import dtype as mstype | |||
| from ..ut_filter import non_graph_engine | |||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| from tests.ut.python.ut_filter import non_graph_engine | |||
| from tests.mindspore_test_framework.mindspore_test import mindspore_test | |||
| from tests.mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| @@ -52,6 +52,20 @@ class NestTupleGraphNet(nn.Cell): | |||
| return self.layers[0][1](x) | |||
| class InTupleNet(nn.Cell): | |||
| def __init__(self, ): | |||
| super(InTupleNet, self).__init__() | |||
| self.tuple_ = (1, 2, 3, 4, 5, "ok") | |||
| def construct(self, x): | |||
| ret = x | |||
| if 2 in self.tuple_: | |||
| ret = x + x | |||
| if "ok" in self.tuple_: | |||
| ret = x - x | |||
| return ret | |||
| test_case_ops = [ | |||
| ('TupleGraph', { | |||
| 'block': TupleGraphNet(), | |||
| @@ -59,6 +73,9 @@ test_case_ops = [ | |||
| ('NestTupleGraph', { | |||
| 'block': NestTupleGraphNet(), | |||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), | |||
| ('InTuple', { | |||
| 'block': InTupleNet(), | |||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}) | |||
| ] | |||
| test_case_lists = [test_case_ops] | |||
| @@ -176,12 +176,134 @@ class TensorGetItemByThreeTensors(Cell): | |||
| return ret | |||
| class TensorGetItemByMixedTensors(Cell): | |||
| class TensorGetItemByMixedTensors_0(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors, self).__init__() | |||
| super(TensorGetItemByMixedTensors_0, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32)) | |||
| def construct(self, tensor, index_0, index_1): | |||
| ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const | |||
| return ret | |||
| class TensorGetItemByMixedTensors_1(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors_1, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32)) | |||
| def construct(self, tensor, index_0, index_1): | |||
| ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const | |||
| return ret | |||
| class TensorGetItemByMixedTensors_2(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors_2, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32)) | |||
| def construct(self, tensor, index_0, index_1): | |||
| ret = tensor[0, index_0, index_1, ..., 3] + self.const | |||
| return ret | |||
| class TensorGetItemByMixedTensors_3(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors_3, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32)) | |||
| def construct(self, tensor, index_0, index_1): | |||
| ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const | |||
| return ret | |||
| class TensorGetItemByMixedTensors_4(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors_4, self).__init__() | |||
| self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32)) | |||
| def construct(self, tensor, index_0, index_1, index_2): | |||
| ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const | |||
| return ret | |||
| class TensorGetItemByMixedTensors_5(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors_5, self).__init__() | |||
| self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32)) | |||
| def construct(self, tensor, index_0, index_1, index_2): | |||
| ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const | |||
| return ret | |||
| class TensorGetItemByMixedTensors_6(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors_6, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32)) | |||
| def construct(self, tensor, index_0, index_1, index_2): | |||
| ret = tensor[..., index_0, index_1, index_2, 3] + self.const | |||
| return ret | |||
| class TensorSetItemByMixedTensors_0(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByMixedTensors_0, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32)) | |||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), | |||
| mstype.float32), | |||
| name="x") | |||
| self.value = value | |||
| def construct(self, index_0, index_1, index_2): | |||
| self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value | |||
| ret = self.param + self.const | |||
| return ret | |||
| class TensorSetItemByMixedTensors_1(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByMixedTensors_1, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32)) | |||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||
| name="x") | |||
| self.value = value | |||
| def construct(self, index_0, index_1, index_2): | |||
| self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value | |||
| ret = self.param + self.const | |||
| return ret | |||
| class TensorSetItemByMixedTensors_2(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByMixedTensors_2, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32)) | |||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||
| name="x") | |||
| self.value = value | |||
| def construct(self, index_0, index_1, index_2): | |||
| self.param[..., index_0, index_1, index_2, 3] = self.value | |||
| ret = self.param + self.const | |||
| return ret | |||
| class TensorGetItemByMixedTensorsTypeError(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensorsTypeError, self).__init__() | |||
| def construct(self, x, index_0, index_1): | |||
| ret = x[index_0, index_1, 0:6] | |||
| ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]] | |||
| return ret | |||
| class TensorGetItemByMixedTensorsNumberError(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensorsNumberError, self).__init__() | |||
| def construct(self, x, index_0, index_1): | |||
| ret = x[index_0, index_1, 0:3, ..., index_1, index_0] | |||
| return ret | |||
| @@ -189,7 +311,7 @@ class TensorSetItemByOneTensorWithNumber(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByOneTensorWithNumber, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.value = value | |||
| def construct(self, index): | |||
| @@ -202,7 +324,7 @@ class TensorSetItemByOneTensorWithTensor(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByOneTensorWithTensor, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| def construct(self, index, value): | |||
| self.param[index] = value | |||
| @@ -214,7 +336,7 @@ class TensorSetItemByOneTensorWithTupleOfNumber(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.value = value | |||
| def construct(self, index): | |||
| @@ -227,7 +349,7 @@ class TensorSetItemByOneTensorWithTupleOfTensor(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__() | |||
| self.const = Tensor(np.ones((6, 3, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*3*8).reshape((6, 3, 8)), mstype.float32), name="x") | |||
| self.param = Parameter(Tensor(np.arange(6 * 3 * 8).reshape((6, 3, 8)), mstype.float32), name="x") | |||
| def construct(self, index, value_0, value_1, value_2): | |||
| self.param[index] = (value_0, value_1, value_2) | |||
| @@ -239,7 +361,7 @@ class TensorSetItemByTensorsWithNumber(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByTensorsWithNumber, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.value = value | |||
| def construct(self, index_0, index_1, index_2): | |||
| @@ -252,7 +374,7 @@ class TensorSetItemByTensorsWithTensor(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByTensorsWithTensor, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| def construct(self, index_0, index_1, index_2, value): | |||
| self.param[index_0, index_1, index_2] = value | |||
| @@ -264,7 +386,7 @@ class TensorSetItemByTensorsWithTensorNumberError(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByTensorsWithTensorNumberError, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| def construct(self, index_0, index_1, index_2, index_3, value): | |||
| self.param[index_0, index_1, index_2, index_3] = value | |||
| @@ -276,7 +398,7 @@ class TensorSetItemByTensorsWithTupleOfNumber(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.value = value | |||
| def construct(self, index_0, index_1, index_2): | |||
| @@ -289,7 +411,7 @@ class TensorSetItemByTensorsWithTupleOfTensor(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| def construct(self, index_0, index_1, index_2, value_0, value_1, value_2): | |||
| self.param[index_0, index_1, index_2] = (value_0, value_1, value_2) | |||
| @@ -301,7 +423,7 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| def construct(self, index_0, index_1, index_2, value_0, value_1): | |||
| self.param[index_0, index_1, index_2] = (value_0, value_1) | |||
| @@ -313,7 +435,7 @@ class TensorSetItemByMixedTensors(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByMixedTensors, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.value = 99.0 | |||
| def construct(self, index_0, index_1): | |||
| @@ -538,11 +660,11 @@ def test_tensor_assign_bool_index(): | |||
| net1(Ta, Tb, Tc, u_tensor) | |||
| with pytest.raises(ValueError): | |||
| net1(Ta, Td, Tc, u_tensor) | |||
| with pytest.raises(TypeError): | |||
| with pytest.raises(IndexError): | |||
| net1(Ta, u_tensor, Tc, u_tensor) | |||
| with pytest.raises(ValueError): | |||
| net1(Ta, Tb, Td, u_tensor) | |||
| with pytest.raises(TypeError): | |||
| with pytest.raises(IndexError): | |||
| net1(Ta, Tb, Ta, u_tensor) | |||
| with pytest.raises(ValueError): | |||
| net1(Ta, Tb, Tc, u_tensor_error) | |||
| @@ -620,22 +742,67 @@ test_cases = [ | |||
| }), | |||
| ('TensorGetItemByOneTensor', { | |||
| 'block': TensorGetItemByOneTensor(), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByTwoTensors', { | |||
| 'block': TensorGetItemByTwoTensors(), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByThreeTensors', { | |||
| 'block': TensorGetItemByThreeTensors(), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensors_0', { | |||
| 'block': TensorGetItemByMixedTensors_0(), | |||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensors_1', { | |||
| 'block': TensorGetItemByMixedTensors_1(), | |||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensors_2', { | |||
| 'block': TensorGetItemByMixedTensors_2(), | |||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensors_3', { | |||
| 'block': TensorGetItemByMixedTensors_3(), | |||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensors_4', { | |||
| 'block': TensorGetItemByMixedTensors_4(), | |||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.float32), | |||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensors_5', { | |||
| 'block': TensorGetItemByMixedTensors_5(), | |||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensors_6', { | |||
| 'block': TensorGetItemByMixedTensors_6(), | |||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByOneTensorWithNumber', { | |||
| 'block': TensorSetItemByOneTensorWithNumber(value=0.0), | |||
| 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)], | |||
| @@ -683,46 +850,143 @@ test_cases = [ | |||
| Tensor(np.zeros((4, 5)), mstype.float32), | |||
| Tensor(np.ones((4, 5)), mstype.float32), | |||
| Tensor(np.ones((4, 5)) * 2, mstype.float32)], | |||
| }) | |||
| }), | |||
| ('TensorSetItemByMixedTensorsWithNumber_0', { | |||
| 'block': TensorSetItemByMixedTensors_0(value=88.0), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByMixedTensorsWithTensor_0', { | |||
| 'block': TensorSetItemByMixedTensors_0(value=Tensor(np.ones((4, 5, 3, 9), np.float32))), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensorsWithTupleOfNumber_0', { | |||
| 'block': TensorSetItemByMixedTensors_0(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensorsWithTupleOfTensor_0', { | |||
| 'block': TensorSetItemByMixedTensors_0(value=(Tensor(np.ones((4, 5, 3, 9), np.float32)), | |||
| Tensor(np.zeros((4, 5, 3, 9), np.float32)), | |||
| Tensor(np.ones((4, 5, 3, 9), np.float32)))), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByMixedTensorsWithNumber_1', { | |||
| 'block': TensorSetItemByMixedTensors_1(value=88.0), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByMixedTensorsWithTensor_1', { | |||
| 'block': TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensorsWithTupleOfNumber_1', { | |||
| 'block': TensorSetItemByMixedTensors_1(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensorsWithTupleOfTensor_1', { | |||
| 'block': TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)), | |||
| Tensor(np.zeros((5, 2, 6), np.float32)), | |||
| Tensor(np.ones((5, 2, 6), np.float32)), | |||
| Tensor(np.ones((5, 2, 6), np.float32)))), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByMixedTensorsWithNumber_2', { | |||
| 'block': TensorSetItemByMixedTensors_2(value=88.0), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByMixedTensorsWithTensor_2', { | |||
| 'block': TensorSetItemByMixedTensors_2(value=Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensorsWithTupleOfNumber_2', { | |||
| 'block': TensorSetItemByMixedTensors_2(value=(1.0, 2.0, 3.0, 4.0, 5.0)), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensorsWithTupleOfTensor_2', { | |||
| 'block': TensorSetItemByMixedTensors_2(value=(Tensor(np.ones((4, 5), np.float32)), | |||
| Tensor(np.zeros((4, 5), np.float32)), | |||
| Tensor(np.ones((4, 5), np.float32)))), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ] | |||
| raise_error_set = [ | |||
| ('TensorGetItemByOneTensorDtypeError', { | |||
| 'block': (TensorGetItemByOneTensor(), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| 'block': (TensorGetItemByOneTensor(), {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)], | |||
| }), | |||
| ('TensorGetItemByTwoTensorsShapeError', { | |||
| 'block': (TensorGetItemByTwoTensors(), {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| 'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByTwoTensorsDtypeError', { | |||
| 'block': (TensorGetItemByTwoTensors(), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| 'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)], | |||
| }), | |||
| ('TensorGetItemByThreeTensorsShapeError', { | |||
| 'block': (TensorGetItemByThreeTensors(), {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| 'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByThreeTensorsDtypeError', { | |||
| 'block': (TensorGetItemByThreeTensors(), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| 'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int64), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensors', { | |||
| 'block': (TensorGetItemByMixedTensors(), {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| ('TensorGetItemByMixedTensorsNumberError', { | |||
| 'block': (TensorGetItemByMixedTensorsNumberError(), {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64)], | |||
| Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensorsTypeError', { | |||
| 'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensorsDtypeError', { | |||
| 'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.float32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensorsShapeError', { | |||
| 'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(2, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByOneTensorWithNumberTypeError', { | |||
| 'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}), | |||
| @@ -760,21 +1024,21 @@ raise_error_set = [ | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTensorShapeError', { | |||
| 'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}), | |||
| 'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||
| Tensor(np.zeros((2, 5)), mstype.float32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTensorTypeError', { | |||
| 'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}), | |||
| 'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||
| Tensor(np.zeros((4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTensorNumberError', { | |||
| 'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}), | |||
| 'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||
| @@ -782,19 +1046,19 @@ raise_error_set = [ | |||
| Tensor(np.zeros((2, 5)), mstype.float32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTupleOfNumberTypeError', { | |||
| 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}), | |||
| 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1, 2, 3, 4)), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTupleOfNumberNumberError', { | |||
| 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}), | |||
| 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTupleOfTensorNumberError', { | |||
| 'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}), | |||
| 'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||
| @@ -802,7 +1066,7 @@ raise_error_set = [ | |||
| Tensor(np.ones((4, 5)), mstype.float32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTupleOfTensorTypeError', { | |||
| 'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}), | |||
| 'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||
| @@ -810,10 +1074,65 @@ raise_error_set = [ | |||
| Tensor(np.ones((4, 5)), mstype.int32), | |||
| Tensor(np.ones((4, 5)) * 2, mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByMixedTensors', { | |||
| 'block': (TensorSetItemByMixedTensors(), {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)], | |||
| ('TensorSetItemByMixedTensorsWithNumberValueTypeError', { | |||
| 'block': (TensorSetItemByMixedTensors_1(value=88), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByMixedTensorsWithNumberIndexTypeError', { | |||
| 'block': (TensorSetItemByMixedTensors_1(value=88.0), {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.float32)], | |||
| }), | |||
| ('TensorSetItemByMixedTensorsWithTensorValueDtypeError', { | |||
| 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.int32))), | |||
| {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByMixedTensorsWithTensorValueShapeError', { | |||
| 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((3, 2, 6), np.float32))), | |||
| {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByMixedTensorsWithTensorIndexDtypeError', { | |||
| 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))), | |||
| {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.float32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensorsWithTupleOfNumberValueTypeError', { | |||
| 'block': (TensorSetItemByMixedTensors_1(value=(1.0, 2, 3.0, 4.0, 5.0, 6.0)), | |||
| {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensorsWithTupleOfTensorValueDtypeError', { | |||
| 'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)), | |||
| Tensor(np.zeros((5, 2, 6), np.float32)), | |||
| Tensor(np.ones((5, 2, 6), np.float32)), | |||
| Tensor(np.ones((5, 2, 6), np.int32)))), | |||
| {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorGetItemByMixedTensorsWithTupleOfTensorIndexDtypeError', { | |||
| 'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)), | |||
| Tensor(np.zeros((5, 2, 6), np.float32)), | |||
| Tensor(np.ones((5, 2, 6), np.float32)), | |||
| Tensor(np.ones((5, 2, 6), np.int32)))), | |||
| {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.float32), | |||
| Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], | |||
| }) | |||
| ] | |||