| @@ -23,7 +23,7 @@ from ....common import dtype as mstype | |||
| from ....common._register_for_tensor import tensor_operator_registry | |||
| hyper_map = base.HyperMap() | |||
| pack = P.Stack(axis=-1) | |||
| stack = P.Stack(axis=-1) | |||
| def _tensor_getitem(self, index): | |||
| @@ -36,44 +36,35 @@ def _tensor_getitem(self, index): | |||
| return tensor_index_by_tuple(self, (index,)) | |||
| raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, " | |||
| f"list and tuple ,but got {index} with type {type(index)}.") | |||
| tensor_operator_registry.register("__getitem__", _tensor_getitem) | |||
| def _tensor_setitem(self, index, value): | |||
| """Handle tensor getitem""" | |||
| """Handle tensor setitem""" | |||
| if not isinstance(value, (int, float, bool, list, tuple, Tensor)): | |||
| raise ValueError(f"only support numbers, Tensor, tuple, list as value," | |||
| f"but got {value} with type {type(value)}.") | |||
| if isinstance(index, list): | |||
| index = format_list_indices(index, self.shape[0]) | |||
| if isinstance(index, Tensor): | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_tensor_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_tensor_with_tensor(self, index, value) | |||
| if isinstance(value, tuple): | |||
| return tensor_setitem_by_tensor_with_tuple(self, index, value) | |||
| return tensor_setitem_by_tensor(self, index, value) | |||
| if isinstance(index, tuple): | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_tuple_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_tuple_with_tensor(self, index, value) | |||
| if isinstance(value, tuple): | |||
| return tensor_setitem_by_tuple_with_tuple(self, index, value) | |||
| if tuple_indices_have_false(index): | |||
| return self | |||
| index = format_tuple_indices(index) | |||
| return tensor_setitem_by_tuple(self, index, value) | |||
| if isinstance(index, bool): | |||
| return tensor_setitem_by_bool(self, index, value) | |||
| if isinstance(index, int): | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_number_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_number_with_tensor(self, index, value) | |||
| return tensor_setitem_by_number(self, index, value) | |||
| if isinstance(index, slice): | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_slice_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_slice_with_tensor(self, index, value) | |||
| if isinstance(index, bool): | |||
| return _tensor_index_by_bool(self, index) | |||
| return tensor_setitem_by_slice(self, index, value) | |||
| if index is ...: | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_ellipsis_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_ellipsis_with_tensor(self, index, value) | |||
| return tensor_setitem_by_ellipsis(self, index, value) | |||
| raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\ | |||
| and tensor with int32, got {} with type{}".format(index, type(index))) | |||
| tensor_operator_registry.register("__setitem__", _tensor_setitem) | |||
| def _broadcast(broadcast_shape, x): | |||
| """Broadcast tensor to the required shape.""" | |||
| @@ -103,7 +94,8 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name): | |||
| ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) + | |||
| len(tensor_positions) + len(sequence_positions)) | |||
| ellipsis_cnt = len(ellipsis_positions) | |||
| if (ellipsis_cnt == 0 and ellipsis_occupy_dims < 0) or (ellipsis_cnt > 0 and ellipsis_occupy_dims < 1): | |||
| # pylint: disable=chained-comparison | |||
| if ellipsis_occupy_dims < 0 and ellipsis_cnt >= 0: | |||
| const_utils.raise_index_error("For the 'getitem Operator', the data_shape should be no less than the " | |||
| "tuple index dims") | |||
| @@ -155,14 +147,6 @@ def tensor_index_by_number(data, number_index): | |||
| return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.") | |||
| # TODO wait to remove after setitem by Yang Linfeng | |||
| def _tensor_index_by_bool(data, bool_index): | |||
| """Tensor getitem by a single bool value""" | |||
| if bool_index: | |||
| return F.expand_dims(data, 0) | |||
| return const_utils.make_tensor([], data.dtype, (0,) + F.shape(data)) | |||
| def _tensor_index_by_integer(data, int_index): | |||
| """Tensor getitem by a single integer number""" | |||
| data_shape = F.shape(data) | |||
| @@ -218,6 +202,31 @@ def tensor_index_by_tuple(data, tuple_index): | |||
| return _tensor_getitem_by_tuple(data, tuple_index, op_name) | |||
| def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name): | |||
| """Tensor getitem by a tuple of tensor.""" | |||
| data_shape = F.shape(data) | |||
| tuple_index_len = len(tuple_index) | |||
| indexes_types = hyper_map(F.dtype, tuple_index) | |||
| const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name) | |||
| tensor_index_shape = hyper_map(F.shape, tuple_index) | |||
| broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name) | |||
| if 0 in broadcast_shape: | |||
| res_shape = broadcast_shape | |||
| if tuple_index_len < len(data_shape): | |||
| res_shape += data_shape[tuple_index_len:] | |||
| res = const_utils.make_tensor([], data.dtype, res_shape) | |||
| return res | |||
| broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) | |||
| new_broadcast_tensors = () | |||
| for tensor in broadcast_tensors: | |||
| new_broadcast_tensors += (F.cast(tensor, mstype.int64),) | |||
| indices = stack(new_broadcast_tensors) | |||
| result = F.gather_nd(data, indices) | |||
| return result | |||
| def _tensor_getitem_by_tuple_slice(data, tuple_index): | |||
| """Tensor getitem by a tuple of slice""" | |||
| data_shape = F.shape(data) | |||
| @@ -284,8 +293,9 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name): | |||
| final_index_tensors.append(slice_index_tensor) | |||
| slice_cnt += 1 | |||
| indices = pack(final_index_tensors) | |||
| return F.gather_nd(data, indices) | |||
| indices = stack(final_index_tensors) | |||
| result = F.gather_nd(data, indices) | |||
| return result | |||
| def _generate_indices_from_tuple_of_tensor(tuple_index, op_name): | |||
| @@ -299,7 +309,7 @@ def _generate_indices_from_tuple_of_tensor(tuple_index, op_name): | |||
| new_broadcast_tensors = () | |||
| for tensor in broadcast_tensors: | |||
| new_broadcast_tensors += (F.cast(tensor, mstype.int64),) | |||
| indices = pack(new_broadcast_tensors) | |||
| indices = stack(new_broadcast_tensors) | |||
| return indices | |||
| @@ -332,6 +342,11 @@ def _generate_indices_from_tuple(data, tuple_index, op_name): | |||
| tuple_index_new += (tensor_index,) | |||
| tensor_indexes.append(tensor_index) | |||
| elif i in slice_positions: | |||
| start, stop, _ = const_utils.slice_to_tuple(index) | |||
| start = const_utils.normalize_start(start, dim_size) | |||
| stop = const_utils.normalize_stop(stop, dim_size) | |||
| if start >= stop: | |||
| return None | |||
| slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size) | |||
| slice_shapes += (len(slice_ele_list_index),) | |||
| tuple_index_new += (slice_ele_list_index,) | |||
| @@ -354,7 +369,7 @@ def _generate_indices_from_tuple(data, tuple_index, op_name): | |||
| final_index_tensors.append(slice_index_tensor) | |||
| slice_cnt += 1 | |||
| indices = pack(final_index_tensors) | |||
| indices = stack(final_index_tensors) | |||
| return indices | |||
| @@ -366,44 +381,76 @@ def _generate_updates_from_scalar(data, indices, value, op_type): | |||
| return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) | |||
| def _generate_updates_from_tuple(data, index, value, op_type): | |||
| """Generate an updates tensor from a tuple.""" | |||
| def _generate_updates_from_sequence(data, index, value, op_type): | |||
| """Generate an updates tensor from a tuple, can only handle 1-D tensor/non-tensor mixtures.""" | |||
| value_types = hyper_map(F.typeof, value) | |||
| data_dtype = F.dtype(data) | |||
| value_elements_type = const_utils.check_value_elements(data_dtype, value_types) | |||
| value_elements_type = const_utils.check_value_elements(value_types) | |||
| if value_elements_type == const_utils.ALL_TENSOR: | |||
| value_shapes = hyper_map(F.shape, value) | |||
| shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM) | |||
| if shapes_same: | |||
| value = F.stack(value) | |||
| return _generate_updates_from_tensor(data, index, value, op_type) | |||
| data_shape = F.shape(data) | |||
| index_shape = F.shape(index) | |||
| return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type) | |||
| value = F.stack(value).astype(data.dtype) | |||
| elif value_elements_type == const_utils.NO_TENSOR: | |||
| value = const_utils.make_tensor(value, data.dtype) | |||
| else: | |||
| new_value = () | |||
| for ele in value: | |||
| ele = ele if isinstance(ele, Tensor) else const_utils.make_tensor(ele) | |||
| new_value += (ele,) | |||
| value = F.stack(new_value).astype(data.dtype) | |||
| if op_type == const_utils.SET_ITEM_BY_NON_TENSOR: | |||
| return value | |||
| return _generate_updates_from_tensor(data, index, value, op_type) | |||
| def _generate_updates_from_tensor(data, index, value, op_type): | |||
| """Generate an updates tensor from a tensor.""" | |||
| data_shape = F.shape(data) | |||
| index_shape = F.shape(index) | |||
| value_shape = F.shape(value) | |||
| data_dtype = F.dtype(data) | |||
| value_dtype = F.dtype(value) | |||
| updates_shape = value_shape | |||
| check_dtype_same = const_utils.check_tensors_dtype_same(data_dtype, value_dtype, const_utils.TENSOR_SETITEM) | |||
| if check_dtype_same: | |||
| updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type) | |||
| need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape) | |||
| value = value.astype(data.dtype) | |||
| updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type) | |||
| need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape) | |||
| if need_broadcast: | |||
| return _broadcast(updates_shape, value) | |||
| return value | |||
| tensor_operator_registry.register("__getitem__", _tensor_getitem) | |||
| # Tensor getitem implementations are above this line, setitem implementations below. | |||
| tensor_operator_registry.register("__setitem__", _tensor_setitem) | |||
| def tensor_setitem_by_tensor(self, index, value): | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_tensor_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_tensor_with_tensor(self, index, value) | |||
| return tensor_setitem_by_tensor_with_sequence(self, index, value) | |||
| def tensor_setitem_by_tuple(self, index, value): | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_tuple_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_tuple_with_tensor(self, index, value) | |||
| return tensor_setitem_by_tuple_with_sequence(self, index, value) | |||
| def tensor_setitem_by_number(self, index, value): | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_number_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_number_with_tensor(self, index, value) | |||
| return tensor_setitem_by_number_with_sequence(self, index, value) | |||
| def tensor_setitem_by_slice(self, index, value): | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_slice_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_slice_with_tensor(self, index, value) | |||
| return tensor_setitem_by_slice_with_sequence(self, index, value) | |||
| def tensor_setitem_by_ellipsis(self, index, value): | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_ellipsis_with_number(self, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_ellipsis_with_tensor(self, value) | |||
| return tensor_setitem_by_ellipsis_with_sequence(self, value) | |||
| def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): | |||
| @@ -469,17 +516,16 @@ def tensor_setitem_by_tensor_with_number(data, index, value): | |||
| return const_utils.raise_index_error("For tensor setitem, indexing tensor dtype only supports bool/int") | |||
| def tensor_setitem_by_tensor_with_tuple(data, index, value): | |||
| def tensor_setitem_by_tensor_with_sequence(data, index, value): | |||
| """Assigns the tensor by tensor with tuple value.""" | |||
| index_dtype = F.dtype(index) | |||
| const_utils.check_type_valid(index_dtype, (mstype.int32, mstype.int64), const_utils.TENSOR_SETITEM) | |||
| result = _tensor_setitem_by_tensor_with_tuple(data, index, value) | |||
| return result | |||
| return _tensor_setitem_by_tensor_with_sequence(data, index, value) | |||
| def _tensor_indices_number(data, data_shape, index, indices, value): | |||
| """Assigns a scalar value to the tensor.""" | |||
| data_size = F.size(data) | |||
| data_size = F.shape_mul(data.shape) | |||
| data_dtype = F.dtype(data) | |||
| indices_size = F.size(indices) | |||
| indices_size = const_utils.check_indices(indices_size, index) | |||
| @@ -493,9 +539,9 @@ def _tensor_indices_number(data, data_shape, index, indices, value): | |||
| return F.select(condition, u, data) | |||
| def _tensor_setitem_by_tensor_with_tuple(data, index, value): | |||
| def _tensor_setitem_by_tensor_with_sequence(data, index, value): | |||
| """Set a tensor item by a tensor with a tuple.""" | |||
| updates = _generate_updates_from_tuple(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| updates = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| index = F.expand_dims(index, -1) | |||
| return P.TensorScatterUpdate()(data, index, updates) | |||
| @@ -507,6 +553,8 @@ def tensor_setitem_by_slice_with_number(data, input_slice, value): | |||
| if check_result: | |||
| data_shape = F.shape(data) | |||
| indices = const_utils.slice2indices(input_slice, data_shape) | |||
| if indices is False: | |||
| return data | |||
| is_tuple_int = const_utils.tuple_element_is_int(input_slice) | |||
| if is_tuple_int: | |||
| indices = const_utils.integer_to_indices(input_slice, data_shape) | |||
| @@ -516,6 +564,8 @@ def tensor_setitem_by_slice_with_number(data, input_slice, value): | |||
| def tensor_setitem_by_tuple_with_number(data, tuple_index, value): | |||
| """Assigns the tensor by tuple with number value.""" | |||
| tuple_index = ignore_dim_expand(tuple_index) | |||
| if len(tuple_index) == 1: | |||
| data[tuple_index[0]] = value | |||
| return data | |||
| @@ -533,13 +583,15 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value): | |||
| if int_cnt == const_utils.ALL_INT: | |||
| tuple_index = const_utils.convert_int_to_slice(tuple_index) | |||
| indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) | |||
| if indices is None: | |||
| return data | |||
| updates = _generate_updates_from_scalar(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| return P.TensorScatterUpdate()(data, indices, updates) | |||
| def _tensor_indices_tensor(data, data_shape, index, indices, value): | |||
| """Assigns a tensor value to the tensor.""" | |||
| data_size = F.size(data) | |||
| data_size = F.shape_mul(data.shape) | |||
| data_dtype = F.dtype(data) | |||
| indices_size = F.size(indices) | |||
| indices_size = const_utils.check_indices(indices_size, index) | |||
| @@ -548,7 +600,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value): | |||
| condition = F.reshape(condition_1d, data_shape) | |||
| condition = F.cast(condition, mstype.bool_) | |||
| value_fill = None | |||
| value_size = F.size(value) | |||
| value_size = value.size | |||
| value_size = const_utils.check_indices_value_size(indices_size, value_size) | |||
| if value_size == 1: | |||
| @@ -559,7 +611,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value): | |||
| value_fill = F.reshape(value, (indices_size,)) | |||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||
| u = F.reshape(value_1d, data_shape) | |||
| return F.select(condition, u, data) | |||
| return F.select(condition, u.astype(data_dtype), data) | |||
| def tensor_setitem_by_slice_with_tensor(data, input_slice, value): | |||
| @@ -569,6 +621,8 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value): | |||
| if check_result: | |||
| data_shape = F.shape(data) | |||
| indices = const_utils.slice2indices(input_slice, data_shape) | |||
| if indices is False: | |||
| return data | |||
| is_tuple_int = const_utils.tuple_element_is_int(input_slice) | |||
| if is_tuple_int: | |||
| indices = const_utils.integer_to_indices(input_slice, data_shape) | |||
| @@ -576,8 +630,18 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value): | |||
| return result | |||
| def tensor_setitem_by_slice_with_sequence(data, input_slice, value): | |||
| """Assigns a list/tuple value to the tensor by slice.""" | |||
| value = _generate_updates_from_sequence(data, input_slice, value, const_utils.SET_ITEM_BY_NON_TENSOR) | |||
| return tensor_setitem_by_slice_with_tensor(data, input_slice, value) | |||
| def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): | |||
| """Assigns the tensor by tuple with tensor value.""" | |||
| value_shape = remove_ignored_dim(tuple_index, F.shape(value), F.rank(data)) | |||
| value = F.reshape(value, value_shape) | |||
| tuple_index = ignore_dim_expand(tuple_index) | |||
| if len(tuple_index) == 1: | |||
| data[tuple_index[0]] = value | |||
| return data | |||
| @@ -600,31 +664,15 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): | |||
| new_shape += value.shape | |||
| value = F.reshape(value, new_shape) | |||
| indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) | |||
| if indices is None: | |||
| return data | |||
| updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| return P.TensorScatterUpdate()(data, indices, updates) | |||
| def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): | |||
| """Assigns the tensor by tuple with tuple of value.""" | |||
| if len(tuple_index) == 1: | |||
| data[tuple_index[0]] = value | |||
| return data | |||
| op_name = const_utils.TENSOR_GETITEM | |||
| tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) | |||
| data, tuple_index = _expand_data_dims(data, tuple_index) | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if contain_type == const_utils.ALL_TENSOR: | |||
| indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM) | |||
| else: | |||
| int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if int_cnt == const_utils.ALL_INT: | |||
| tuple_index = const_utils.convert_int_to_slice(tuple_index) | |||
| indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) | |||
| updates = _generate_updates_from_tuple(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| return P.TensorScatterUpdate()(data, indices, updates) | |||
| def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value): | |||
| value = _generate_updates_from_sequence(data, tuple_index, value, const_utils.SET_ITEM_BY_NON_TENSOR) | |||
| return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value) | |||
| def tensor_setitem_by_number_with_number(data, index, value): | |||
| @@ -634,6 +682,12 @@ def tensor_setitem_by_number_with_number(data, index, value): | |||
| return _tensor_indices_number(data, data_shape, index, indices, value) | |||
| def tensor_setitem_by_number_with_sequence(data, index, value): | |||
| """Assigns a list/tuple value to the tensor by slice.""" | |||
| value = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_NON_TENSOR) | |||
| return tensor_setitem_by_number_with_tensor(data, index, value) | |||
| def tensor_setitem_by_number_with_tensor(data, index, value): | |||
| """Assigns the tensor by number with tensor value.""" | |||
| data_shape = F.shape(data) | |||
| @@ -641,31 +695,46 @@ def tensor_setitem_by_number_with_tensor(data, index, value): | |||
| return _tensor_indices_tensor(data, data_shape, index, indices, value) | |||
| def tensor_setitem_by_ellipsis_with_number(data, index, value): | |||
| def tensor_setitem_by_ellipsis_with_number(data, value): | |||
| """Assigns the tensor by ellipsis with number value.""" | |||
| data_shape = F.shape(data) | |||
| data_dtype = F.dtype(data) | |||
| return F.fill(data_dtype, data_shape, value) | |||
| def tensor_setitem_by_ellipsis_with_tensor(data, index, value): | |||
| def tensor_setitem_by_ellipsis_with_tensor(data, value): | |||
| """Assigns the tensor by ellipsis with tensor value.""" | |||
| result = None | |||
| data_shape = F.shape(data) | |||
| data_dtype = F.dtype(data) | |||
| data_size = F.size(data) | |||
| value = value.astype(data_dtype) | |||
| value_shape = F.shape(value) | |||
| value_size = F.size(value) | |||
| check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) | |||
| if check_result: | |||
| if data_size == value_size: | |||
| result = F.reshape(value, data_shape) | |||
| result = F.cast(result, data_dtype) | |||
| elif value_size == 1: | |||
| param1 = F.fill(data_dtype, data_shape, 1) | |||
| param2 = F.cast(value, data_dtype) | |||
| result = F.tensor_mul(param1, param2) | |||
| return result | |||
| source_shape = const_utils.get_source_shape(data_shape, value_shape) | |||
| value = F.reshape(value, source_shape) | |||
| value = _broadcast(data_shape, value) | |||
| data = F.cast(value, data_dtype) | |||
| return data | |||
| def tensor_setitem_by_ellipsis_with_sequence(data, value): | |||
| """Assigns a list/tuple value to the tensor by ellipsis.""" | |||
| value = _generate_updates_from_sequence(data, None, value, const_utils.SET_ITEM_BY_NON_TENSOR) | |||
| return tensor_setitem_by_ellipsis_with_tensor(data, value) | |||
| def tensor_setitem_by_bool(data, index, value): | |||
| """Assigns a value to the tensor by boolean.""" | |||
| data_shape = F.shape(data) | |||
| if not index: | |||
| data_shape = (0,) + data_shape | |||
| if not isinstance(value, Tensor): | |||
| value = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_NON_TENSOR) | |||
| value_shape = F.shape(value) | |||
| source_shape = const_utils.get_source_shape(data_shape, value_shape) | |||
| if index: | |||
| value = F.reshape(value, source_shape) | |||
| value = _broadcast(data_shape, value) | |||
| data = value | |||
| return data | |||
| def tensor_in_sequence(x, y): | |||
| @@ -675,3 +744,79 @@ def tensor_in_sequence(x, y): | |||
| if isinstance(i, Tensor) and x.shape == i.shape and x.dtype == i.dtype: | |||
| result = F.logical_or(F.equal(x, i).all(), result) | |||
| return result | |||
| def format_list_indices(list_indices, length): | |||
| """Convert list indices to tensor or tuple indices based on its contents.""" | |||
| indices_types = hyper_map(F.typeof, list_indices) | |||
| # If eyery element in list is bool, it's treated as 1-D bool tensor. | |||
| # If every element in list is int(not all bool), it's treated as int tensor. | |||
| if const_utils.judge_indexes_types(indices_types, mstype.int_type+(mstype.bool_,)): | |||
| list_indices = const_utils.transform_sequence_index(list_indices, length, const_utils.TENSOR_SETITEM) | |||
| return const_utils.make_tensor(list_indices) | |||
| # If list contains other types(.../list/tuple/None), it's treated as a tuple | |||
| return const_utils.deep_tuple(list_indices) | |||
| def format_tuple_indices(tuple_indices): | |||
| """ | |||
| Format tuple indices by unpacking high-dimension tuple and removing expand | |||
| dimension signs(Bool and None). | |||
| """ | |||
| res = () | |||
| for i in tuple_indices: | |||
| if isinstance(i, (list, tuple)): | |||
| res += (const_utils.unpack(i),) | |||
| else: | |||
| res += (i,) | |||
| return res | |||
| def tuple_indices_have_false(tuple_indices): | |||
| """Returns True if tuple_indices contains False.""" | |||
| for i in tuple_indices: | |||
| if i is False: | |||
| return True | |||
| return False | |||
| def ignore_dim_expand(idx): | |||
| """Filters flags for dimension expansion from idx.""" | |||
| res = () | |||
| for i in idx: | |||
| if not i is True and not i is None: | |||
| res += (i,) | |||
| if not res: | |||
| res = (True,) | |||
| return res | |||
| def remove_ignored_dim(idx, value_shape, data_rank): | |||
| """Removes dimensions in value that correspond to dimension expansion flags in index.""" | |||
| has_ellipsis = False | |||
| has_true = False | |||
| cnt_trailing_expanded = 0 | |||
| cnt_not_dim_expand = 0 | |||
| for i in idx: | |||
| if not i is True and not i is None: | |||
| cnt_not_dim_expand += 1 | |||
| if const_utils.is_ellipsis(i): | |||
| has_ellipsis = True | |||
| elif has_ellipsis: | |||
| if i is None: | |||
| cnt_trailing_expanded += 1 | |||
| elif i is True and not has_true: | |||
| has_true = True | |||
| if has_true and cnt_not_dim_expand + 1 < data_rank: | |||
| cnt_trailing_expanded += 1 | |||
| if cnt_trailing_expanded == 0: | |||
| return value_shape | |||
| value_expanded_pos = len(value_shape) - cnt_trailing_expanded | |||
| value_expanded_not_unit = False | |||
| for i in value_shape[value_expanded_pos:]: | |||
| if i != 1: | |||
| value_expanded_not_unit = True | |||
| if value_expanded_pos < 0 or value_expanded_not_unit: | |||
| const_utils.raise_value_error('shape mismatch') | |||
| return value_shape[:value_expanded_pos] | |||
| @@ -43,6 +43,7 @@ TENSOR_GETITEM = "tensor getitem" | |||
| SET_ITEM_BY_ONE_TENSOR = 0 | |||
| SET_ITEM_BY_TUPLE_OF_TENSOR = 1 | |||
| SET_ITEM_BY_NON_TENSOR = 2 | |||
| @constexpr | |||
| @@ -74,10 +75,85 @@ def make_empty_slice(): | |||
| @constexpr | |||
| def make_tensor(data, data_type=mstype.int64, data_shape=None): | |||
| def _deep_list(array_like): | |||
| """convert nested tuple/list mixtures to pure nested list""" | |||
| if isinstance(array_like, (list, tuple)): | |||
| return list(map(_deep_list, array_like)) | |||
| return array_like | |||
| @constexpr | |||
| def deep_tuple(array_like): | |||
| """convert nested tuple/list mixtures to pure nested tuple""" | |||
| if isinstance(array_like, (list, tuple)): | |||
| return tuple(map(deep_tuple, array_like)) | |||
| return array_like | |||
| def _deep_tensor_to_nparray(array_like): | |||
| """ | |||
| convert a nested list of tensor to nested list of np_array. | |||
| Args: | |||
| array_like(list(tensor)): In any format of nested lists that may contain | |||
| tensors. | |||
| Returns: | |||
| array_like(list(np_array)): Formatted array that can be directly processed | |||
| by numpy.array(), with all tensor elements converted to numpy_array. | |||
| """ | |||
| # Recursively check whether each element is a tensor or not, if is tensor, | |||
| # convert it to a numpy array in place | |||
| if isinstance(array_like, Tensor): | |||
| return array_like.asnumpy() | |||
| if isinstance(array_like, list): | |||
| for idx, value in enumerate(array_like): | |||
| array_like[idx] = _deep_tensor_to_nparray(value) | |||
| return array_like | |||
| @constexpr | |||
| def make_tensor(a, dtype=mstype.int32, data_shape=None): | |||
| """ | |||
| Converts the input to tensor. | |||
| This function converts tensors from an array-like object. | |||
| Args: | |||
| a (Union[int, float, bool, list, tuple]): Input data, in any form that can | |||
| be converted to a `Tensor`. | |||
| dtype (:class:`mindspore.dtype`): Designated tensor dtype. | |||
| Returns: | |||
| Tensor, generated tensor with the specified dtype. | |||
| Raises: | |||
| TypeError: If input arguments have types not specified above. | |||
| ValueError: If input `a` has different sizes at different dimensions. | |||
| """ | |||
| if data_shape: | |||
| return Tensor(np.zeros(data_shape), data_type) | |||
| return Tensor(data, data_type) | |||
| return Tensor(np.zeros(data_shape), dtype) | |||
| if not isinstance(a, (list, tuple, int, float, bool)): | |||
| raise TypeError("input data must be `int`, `float`, `bool`, `list` or `tuple`") | |||
| if isinstance(a, (list, tuple)): | |||
| # Convert all tuple/nested tuples to lists | |||
| a = _deep_list(a) | |||
| # Convert all tensor sub-elements to numpy arrays | |||
| a = _deep_tensor_to_nparray(a) | |||
| a = np.asarray(a) | |||
| if a.dtype is np.dtype('object'): | |||
| raise ValueError('Input array must have the same size across all dimensions.') | |||
| if isinstance(a, np.ndarray): | |||
| if a.dtype is np.dtype('object'): | |||
| raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") | |||
| return Tensor(a, dtype) | |||
| @constexpr | |||
| @@ -88,12 +164,20 @@ def judge_data_rank(data_rank, min_data_rank=0, max_data_rank=8): | |||
| @constexpr | |||
| def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): | |||
| """Checks the shape and size of the sensor and value.""" | |||
| if data_shape == value_shape or data_size == value_size or value_size == 1: | |||
| return True | |||
| raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format( | |||
| value_shape, data_shape)) | |||
| def get_source_shape(data_shape, value_shape): | |||
| """Returns the shape of value that will be used to broadcast against data.""" | |||
| cannot_broadcast = False | |||
| source_shape = value_shape | |||
| for i, j in zip(reversed(data_shape), reversed(value_shape)): | |||
| if j not in (1, i): | |||
| cannot_broadcast = True | |||
| for i in range(len(value_shape) - len(data_shape)): | |||
| source_shape = data_shape | |||
| if value_shape[i] != 1: | |||
| cannot_broadcast = True | |||
| if cannot_broadcast: | |||
| raise ValueError(f'could not broadcast input array from shape {value_shape} to {data_shape}') | |||
| return source_shape | |||
| @constexpr | |||
| @@ -288,8 +372,10 @@ def slice2indices(input_slices, shape): | |||
| begin, end, strides = slice_expand(input_slices, shape) | |||
| np_r = [] | |||
| for i, element in enumerate(shape): | |||
| s = begin[i] if (begin[i] >= 0) else (element + begin[i]) | |||
| e = end[i] if (end[i] >= 0) else (element + end[i]) | |||
| s = normalize_start(begin[i], element) | |||
| e = normalize_stop(end[i], element) | |||
| if s >= e: | |||
| return False | |||
| np_r.append(np.r_[s:e:strides[i]]) | |||
| # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape) | |||
| np_ix = np.ix_(*np_r) | |||
| @@ -364,29 +450,17 @@ def tuple_index_type_cnt(types, op_name): | |||
| @constexpr | |||
| def check_value_elements(data_dtype, types): | |||
| def check_value_elements(types): | |||
| """Judges the type of all elements of the tuple.""" | |||
| tensors_number = 0 | |||
| scalars_number = 0 | |||
| for i, ele in enumerate(types): | |||
| tensor_number = 0 | |||
| for ele in types: | |||
| if isinstance(ele, mstype.tensor_type): | |||
| ele_dtype = ele.element_type() | |||
| if data_dtype == ele_dtype: | |||
| tensors_number += 1 | |||
| else: | |||
| raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' " | |||
| f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.") | |||
| elif mstype.dtype_to_pytype(ele) == mstype.dtype_to_pytype(data_dtype): | |||
| scalars_number += 1 | |||
| else: | |||
| raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in " | |||
| f"value tuple is not consistent with assigned tensor data type '{data_dtype}'.") | |||
| if tensors_number == len(types): | |||
| tensor_number += 1 | |||
| if tensor_number == 0: | |||
| return NO_TENSOR | |||
| if tensor_number == len(types): | |||
| return ALL_TENSOR | |||
| if scalars_number == len(types): | |||
| return ALL_SCALAR | |||
| raise TypeError( | |||
| f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.") | |||
| return CONTAIN_TENSOR | |||
| @constexpr | |||
| @@ -528,10 +602,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty | |||
| updates_shape = indices_shape + data_shape[1:] | |||
| else: | |||
| updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:] | |||
| if isinstance(value, mstype.dtype_to_pytype(data_dtype)): | |||
| return Tensor(np.full(updates_shape, value), dtype=data_dtype) | |||
| raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'" | |||
| f" is not consistent with the assigned tensor data type {data_dtype}.") | |||
| return Tensor(np.full(updates_shape, value), dtype=data_dtype) | |||
| @constexpr | |||
| @@ -716,3 +787,46 @@ def mstype_eq(x, y): | |||
| def scalar_to_tensor(x): | |||
| """Convert a scalar to a tensor""" | |||
| return Tensor(x) | |||
| @constexpr | |||
| def unpack(x): | |||
| if isinstance(x, (tuple, list)) and len(x) == 1: | |||
| return unpack(x[0]) | |||
| return x | |||
| @constexpr | |||
| def slice_to_tuple(s): | |||
| return (s.start, s.stop, s.step) | |||
| @constexpr | |||
| def normalize_start(start, dim_size): | |||
| """ | |||
| Normalize `start` according to the number of dimensions (`dim_size`). | |||
| If the number of dimensions is not given, return the original input directly. | |||
| """ | |||
| if start is None: | |||
| return 0 | |||
| if start < 0: | |||
| return 0 if start < -dim_size else start % dim_size | |||
| return start if start < dim_size else dim_size | |||
| @constexpr | |||
| def normalize_stop(stop, dim_size): | |||
| """ | |||
| Normalize `stop` according to the number of dimensions (`dim_size`). | |||
| If the number of dimensions is not given, return the original input directly. | |||
| """ | |||
| if stop is None: | |||
| return dim_size | |||
| if stop < 0: | |||
| return 0 if stop < -dim_size else stop % dim_size | |||
| return stop if stop < dim_size else dim_size | |||
| @constexpr | |||
| def is_ellipsis(x): | |||
| return x is Ellipsis | |||
| @@ -18,6 +18,7 @@ | |||
| from . import _compile_utils as compile_utils | |||
| from ... import functional as F | |||
| from ...composite import base | |||
| from ....common import Tensor | |||
| setitem = base.MultitypeFuncGraph('setitem') | |||
| @@ -213,6 +214,9 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| if compile_utils.tuple_indices_have_false(tuple_index): | |||
| return data | |||
| tuple_index = compile_utils.format_tuple_indices(tuple_index) | |||
| return compile_utils.tensor_setitem_by_tuple_with_number(data, tuple_index, value) | |||
| @@ -234,6 +238,9 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| if compile_utils.tuple_indices_have_false(tuple_index): | |||
| return data | |||
| tuple_index = compile_utils.format_tuple_indices(tuple_index) | |||
| return compile_utils.tensor_setitem_by_tuple_with_tensor(data, tuple_index, value) | |||
| @@ -246,21 +253,49 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): | |||
| Syntax support: A[B, C, D] = U. | |||
| Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors. | |||
| 2) A B and C could be broadcast. | |||
| 3) U is a Tensor. | |||
| 3) U is a Tuple. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Tuple): A tuple of tensor, these tensor could be broadcast. | |||
| value (Tensor): Assignment tensor, should has the same data type as 'data'. | |||
| value (Tuple): Assignment tuple. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return compile_utils.tensor_setitem_by_tuple_with_tuple(data, tuple_index, value) | |||
| if compile_utils.tuple_indices_have_false(tuple_index): | |||
| return data | |||
| tuple_index = compile_utils.format_tuple_indices(tuple_index) | |||
| return compile_utils.tensor_setitem_by_tuple_with_sequence(data, tuple_index, value) | |||
| @setitem.register("Tensor", "Tuple", "List") | |||
| def _tensor_setitem_by_tuple_with_list(data, tuple_index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[B, C, D] = U. | |||
| Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors. | |||
| 2) A B and C could be broadcast. | |||
| 3) U is a List. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Tuple): A tuple of tensor, these tensor could be broadcast. | |||
| value (List): Assignment tuple. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| if compile_utils.tuple_indices_have_false(tuple_index): | |||
| return data | |||
| tuple_index = compile_utils.format_tuple_indices(tuple_index) | |||
| return compile_utils.tensor_setitem_by_tuple_with_sequence(data, tuple_index, value) | |||
| @setitem.register("Tensor", "Tensor", "Tuple") | |||
| def _tensor_setitem_by_tensor_v2(data, index, value): | |||
| def _tensor_setitem_by_tensor_with_tuple(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| @@ -272,11 +307,27 @@ def _tensor_setitem_by_tensor_v2(data, index, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return compile_utils.tensor_setitem_by_tensor_with_tuple(data, index, value) | |||
| return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value) | |||
| @setitem.register("Tensor", "Tensor", "List") | |||
| def _tensor_setitem_by_tensor_with_list(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Tensor): Tensor of bool type. | |||
| value (List): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value) | |||
| @setitem.register("Tensor", "Slice", "Tensor") | |||
| def _tensor_setitem_with_slice_v3(data, input_slice, value): | |||
| def _tensor_setitem_by_slice_with_tensor(data, input_slice, value): | |||
| """ | |||
| Tensor assignment. | |||
| @@ -298,7 +349,7 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value): | |||
| @setitem.register("Tensor", "Slice", "Number") | |||
| def _tensor_setitem_with_slice_v1(data, input_slice, value): | |||
| def _tensor_setitem_by_slice_with_number(data, input_slice, value): | |||
| """ | |||
| Tensor assignment. | |||
| @@ -319,21 +370,326 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value): | |||
| return compile_utils.tensor_setitem_by_slice_with_number(data, input_slice, value) | |||
| @setitem.register("Tensor", "Slice", "List") | |||
| def _tensor_setitem_by_slice_with_list(data, input_slice, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[Slice] = u | |||
| Restraint condition: A is a Tensor. | |||
| Slice like "1:3" | |||
| u is a list | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| input_slice (Slice): slice expression. | |||
| value (List): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return compile_utils.tensor_setitem_by_slice_with_sequence(data, input_slice, value) | |||
| @setitem.register("Tensor", "Slice", "Tuple") | |||
| def _tensor_setitem_by_slice_with_tuple(data, input_slice, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[Slice] = u | |||
| Restraint condition: A is a Tensor. | |||
| Slice like "1:3" | |||
| u is a tuple | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| input_slice (Slice): slice expression. | |||
| value (Tuple): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return compile_utils.tensor_setitem_by_slice_with_sequence(data, input_slice, value) | |||
| @setitem.register("Tensor", "Number", "Number") | |||
| def _tensor_setitem_with_int_v1(data, index, value): | |||
| def _tensor_setitem_by_number_with_number(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[Number] = u | |||
| Restraint condition: A is a Tensor. | |||
| u is a Number. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Number): An integer index. | |||
| value (Tuple): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| if isinstance(index, bool): | |||
| return compile_utils.tensor_setitem_by_bool(data, index, value) | |||
| return compile_utils.tensor_setitem_by_number_with_number(data, index, value) | |||
| @setitem.register("Tensor", "Number", "Tensor") | |||
| def _tensor_setitem_with_int_v2(data, index, value): | |||
| def _tensor_setitem_by_number_with_tensor(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[Number] = u | |||
| Restraint condition: A is a Tensor. | |||
| u is a Tensor. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Number): An integer index. | |||
| value (Tensor): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| if isinstance(index, bool): | |||
| return compile_utils.tensor_setitem_by_bool(data, index, value) | |||
| return compile_utils.tensor_setitem_by_number_with_tensor(data, index, value) | |||
| @setitem.register("Tensor", "Number", "Tuple") | |||
| def _tensor_setitem_by_number_with_tuple(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[Number] = u | |||
| Restraint condition: A is a Tensor. | |||
| u is a Tuple, with all elements equal in length. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Number): An integer index. | |||
| value (Tuple): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| if isinstance(index, bool): | |||
| return compile_utils.tensor_setitem_by_bool(data, index, value) | |||
| return compile_utils.tensor_setitem_by_number_with_sequence(data, index, value) | |||
| @setitem.register("Tensor", "Number", "List") | |||
| def _tensor_setitem_by_number_with_list(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[Number] = u | |||
| Restraint condition: A is a Tensor. | |||
| u is a List, with all elements equal in length. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Number): An integer index. | |||
| value (List): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| if isinstance(index, bool): | |||
| return compile_utils.tensor_setitem_by_bool(data, index, value) | |||
| return compile_utils.tensor_setitem_by_number_with_sequence(data, index, value) | |||
| @setitem.register("Tensor", "Ellipsis", "Number") | |||
| def _tensor_setitem_with_ellipsis_v1(data, index, value): | |||
| return compile_utils.tensor_setitem_by_ellipsis_with_number(data, index, value) | |||
| def _tensor_setitem_by_ellipsis_with_number(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[...] = u | |||
| Restraint condition: A is a Tensor. | |||
| u is a Number. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Ellipsis): Index is ``...``. | |||
| value (Number): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return compile_utils.tensor_setitem_by_ellipsis_with_number(data, value) | |||
| @setitem.register("Tensor", "Ellipsis", "Tensor") | |||
| def _tensor_setitem_with_ellipsis_v2(data, index, value): | |||
| return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, index, value) | |||
| def _tensor_setitem_by_ellipsis_with_tensor(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[...] = u | |||
| Restraint condition: A is a Tensor. | |||
| u is a Tensor. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Ellipsis): Index is ``...``. | |||
| value (Tensor): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, value) | |||
| @setitem.register("Tensor", "Ellipsis", "List") | |||
| def _tensor_setitem_by_ellipsis_with_list(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[...] = u | |||
| Restraint condition: A is a Tensor. | |||
| u is a List, with all elements equal in length. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Ellipsis): Index is ``...``. | |||
| value (Number): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value) | |||
| @setitem.register("Tensor", "Ellipsis", "Tuple") | |||
| def _tensor_setitem_by_ellipsis_with_tuple(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[...] = u | |||
| Restraint condition: A is a Tensor. | |||
| u is a Tuple, with all elements equal in length. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Ellipsis): Index is ``...``. | |||
| value (Number): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value) | |||
| @setitem.register("Tensor", "List", "Number") | |||
| def _tensor_setitem_by_list_with_number(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[List] = u | |||
| Restraint condition: A is a Tensor. | |||
| u is a Number. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (List). | |||
| value (Number): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| # list indices will be converted to tuple or tensor based on its contents. | |||
| index = compile_utils.format_list_indices(index, data.shape[0]) | |||
| if isinstance(index, Tensor): | |||
| return compile_utils.tensor_setitem_by_tensor_with_number(data, index, value) | |||
| if compile_utils.tuple_indices_have_false(index): | |||
| return data | |||
| index = compile_utils.format_tuple_indices(index) | |||
| return compile_utils.tensor_setitem_by_tuple_with_number(data, index, value) | |||
| @setitem.register("Tensor", "List", "Tensor") | |||
| def _tensor_setitem_by_list_with_tensor(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[List] = u | |||
| Restraint condition: A is a Tensor. | |||
| u is a Tensor. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (List). | |||
| value (Tensor): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| # list indices will be converted to tuple or tensor based on its contents. | |||
| index = compile_utils.format_list_indices(index, data.shape[0]) | |||
| if isinstance(index, Tensor): | |||
| return compile_utils.tensor_setitem_by_tensor_with_tensor(data, index, value) | |||
| if compile_utils.tuple_indices_have_false(index): | |||
| return data | |||
| index = compile_utils.format_tuple_indices(index) | |||
| return compile_utils.tensor_setitem_by_tuple_with_tensor(data, index, value) | |||
| @setitem.register("Tensor", "List", "Tuple") | |||
| def _tensor_setitem_by_list_with_tuple(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[List] = u | |||
| Restraint condition: A is a Tensor. | |||
| u is a Tuple, with all elements equal in length. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (List). | |||
| value (Tuple): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| # list indices will be converted to tuple or tensor based on its contents. | |||
| index = compile_utils.format_list_indices(index, data.shape[0]) | |||
| if isinstance(index, Tensor): | |||
| return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value) | |||
| if compile_utils.tuple_indices_have_false(index): | |||
| return data | |||
| index = compile_utils.format_tuple_indices(index) | |||
| return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value) | |||
| @setitem.register("Tensor", "List", "List") | |||
| def _tensor_setitem_by_list_with_list(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[List] = u | |||
| Restraint condition: A is a Tensor. | |||
| u is a List, with all elements equal in length. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (List). | |||
| value (List): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| # list indices will be converted to tuple or tensor based on its contents. | |||
| index = compile_utils.format_list_indices(index, data.shape[0]) | |||
| if isinstance(index, Tensor): | |||
| return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value) | |||
| if compile_utils.tuple_indices_have_false(index): | |||
| return data | |||
| index = compile_utils.format_tuple_indices(index) | |||
| return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value) | |||
| @@ -321,7 +321,7 @@ def test_setitem_by_mixed_tensors_2(): | |||
| assert np.all(out.asnumpy() == (input_np + const)) | |||
| class TensorGetItemByMixedTensorsTypeError(Cell): | |||
| class TensorGetItemByMixedTensorsIndexError(Cell): | |||
| def construct(self, x, index_0, index_1): | |||
| ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]] | |||
| return ret | |||
| @@ -331,8 +331,8 @@ def test_getitem_by_mixedtensor_exception(): | |||
| input_ms = Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32) | |||
| index_0 = Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32) | |||
| index_1 = Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32) | |||
| net1 = TensorGetItemByMixedTensorsTypeError() | |||
| with pytest.raises(TypeError): | |||
| net1 = TensorGetItemByMixedTensorsIndexError() | |||
| with pytest.raises(IndexError): | |||
| net1(input_ms, index_0, index_1) | |||
| @@ -0,0 +1,215 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test_tensor_setitem """ | |||
| import numpy as onp | |||
| import pytest | |||
| from mindspore import Tensor, context | |||
| from mindspore.nn import Cell | |||
| def setup_module(): | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def setup_testcase(input_np, case_fn): | |||
| input_ms = Tensor(input_np) | |||
| class TensorSetItem(Cell): | |||
| def construct(self, x): | |||
| return case_fn(x) | |||
| class NumpySetItem(): | |||
| def __call__(self, x): | |||
| return case_fn(x) | |||
| out_ms = TensorSetItem()(input_ms) | |||
| out_np = NumpySetItem()(input_np) | |||
| assert onp.all(out_ms.asnumpy() == out_np) | |||
| class TensorSetItemByList(Cell): | |||
| def construct(self, x): | |||
| x[[0, 1], [1, 2], [1, 3]] = [3, 4] | |||
| x[([0, 1], [0, 2], [1, 1])] = [10, 5] | |||
| x[[0, 1], ..., [0, 1]] = 4 | |||
| return x | |||
| class NumpySetItemByList(): | |||
| def __call__(self, x): | |||
| x[[0, 1], [1, 2], [1, 3]] = [3, 4] | |||
| x[([0, 1], [0, 2], [1, 1])] = [10, 5] | |||
| x[[0, 1], ..., [0, 1]] = 4 | |||
| return x | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_setitem_by_list(): | |||
| x = onp.ones((2, 3, 4), dtype=onp.float32) | |||
| def cases(x): | |||
| x[[0, 1], [1, 2], [1, 3]] = [3, 4] | |||
| x[([0, 1], [0, 2], [1, 1])] = [10, 5] | |||
| x[[0, 1], ..., [0, 1]] = 4 | |||
| return x | |||
| setup_testcase(x, cases) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_setitem_with_sequence(): | |||
| x = onp.ones((2, 3, 4), dtype=onp.float32) | |||
| def cases(x): | |||
| x[...] = [3] | |||
| x[..., 1] = ([1, 2, 3], [4, 5, 6]) | |||
| x[0] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11]) | |||
| x[1:2] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11]) | |||
| return x | |||
| setup_testcase(x, cases) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_setitem_dtype(): | |||
| x = onp.ones((2, 3, 4), dtype=onp.float32) | |||
| def cases(x): | |||
| x[...] = 3 | |||
| x[..., 1] = 3.0 | |||
| x[0] = True | |||
| x[1:2] = ((0, False, 2, 3), (4.0, 5, 6, 7), [True, 9, 10, 11]) | |||
| return x | |||
| setup_testcase(x, cases) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_setitem_by_tuple_with_int(): | |||
| x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32) | |||
| def cases(x): | |||
| x[..., 2, False, 1] = -1 | |||
| x[0, True, 0, None, True] = -2 | |||
| x[0, ..., None] = -3 | |||
| x[..., 0, None, 1, True, True, None] = -4 | |||
| return x | |||
| setup_testcase(x, cases) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_setitem_by_tuple_with_list(): | |||
| x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32) | |||
| def cases(x): | |||
| x[..., 2, False, 1] = [-1] | |||
| x[0, True, 0, None, True] = [-2, -2, -2, -2] | |||
| x[0, ..., None] = [[-3], [-3], [-3], [-3]] | |||
| x[..., 0, None, 1, True, True, None] = [[[-4]], [[-4]]] | |||
| return x | |||
| setup_testcase(x, cases) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_setitem_by_nested_unit_list(): | |||
| x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32) | |||
| def cases(x): | |||
| x[[[[0]]], True] = -1 | |||
| x[[1], ..., [[[[2]]]]] = -2 | |||
| x[0, [[[2]]], [1]] = -3 | |||
| return x | |||
| setup_testcase(x, cases) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_setitem_with_broadcast(): | |||
| x = onp.arange(2*3*4*5*6).reshape(2, 3, 4, 5, 6).astype(onp.float32) | |||
| v1 = onp.full((1, 4, 5), -1).tolist() | |||
| v2 = onp.full((4, 1, 6), -2).tolist() | |||
| def cases(x): | |||
| x[..., 4] = v1 | |||
| x[0, 2] = v2 | |||
| x[1, 0, ..., 3] = [[-3], [-3], [-3], [-3]] | |||
| x[0, ..., 1, 3, 5] = -4 | |||
| return x | |||
| setup_testcase(x, cases) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_setitem_mul_by_scalar(): | |||
| x = onp.ones((4, 5), dtype=onp.float32) | |||
| def cases(x): | |||
| x[1, :] = x[1, :]*2 | |||
| x[:, 2] = x[:, 3]*3.0 | |||
| return x | |||
| setup_testcase(x, cases) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_setitem_by_slice(): | |||
| x = onp.ones((3, 4, 5), dtype=onp.float32) | |||
| def cases(x): | |||
| x[1:2] = 2 | |||
| x[-3:1] = 3 | |||
| x[-10:3:2] = 4 | |||
| x[5:0:3] = 5 | |||
| x[5:5:5] = 6 | |||
| x[-1:2] = 7 | |||
| return x | |||
| setup_testcase(x, cases) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_setitem_by_tuple_of_slices(): | |||
| x = onp.ones((3, 4, 5), dtype=onp.float32) | |||
| def cases(x): | |||
| x[1:2, 2] = 2 | |||
| x[0, -4:1] = 3 | |||
| x[1, -10:3:2] = 4 | |||
| x[5:0:3, 3] = 5 | |||
| x[1:1, 2:2] = 6 | |||
| return x | |||
| setup_testcase(x, cases) | |||