Merge pull request !2060 from amongo/SupportPynativeSetItemtags/v0.5.0-beta
| @@ -92,6 +92,10 @@ Tensor &Tensor::operator=(const Tensor &tensor) { | |||
| } | |||
| return *this; | |||
| } | |||
| Tensor &Tensor::AssignValue(const Tensor &tensor) { | |||
| *this = tensor; | |||
| return *this; | |||
| } | |||
| bool Tensor::operator==(const Tensor &tensor) const { | |||
| return (MetaTensor::operator==(tensor) && data_ == tensor.data_); | |||
| @@ -470,6 +474,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||
| >>> data.set_dtype(mindspore.int32) | |||
| mindspore.int32 | |||
| )mydelimiter") | |||
| .def("assign_value", &Tensor::AssignValue, R"mydelimiter( | |||
| Assign another tensor value to this. | |||
| Arg: | |||
| value (:class:`mindspore.tensor`): The value tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) | |||
| >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32)) | |||
| >>> data.assign_value(data2) | |||
| >>> data.shape | |||
| (2, 2) | |||
| )mydelimiter") | |||
| .def("__str__", &Tensor::ToString) | |||
| .def("__repr__", &Tensor::ToStringRepr) | |||
| .def(py::pickle( | |||
| @@ -173,6 +173,9 @@ class Tensor : public MetaTensor { | |||
| // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. | |||
| bool ValueEqual(const Tensor &other) const; | |||
| // assgin value to this tensor | |||
| Tensor &AssignValue(const Tensor &tensor); | |||
| bool operator==(const Value &other) const override { | |||
| if (other.isa<Tensor>()) { | |||
| auto other_ = static_cast<const Tensor &>(other); | |||
| @@ -203,6 +203,8 @@ class Parameter: | |||
| return self.default_input / other | |||
| def __setitem__(self, index, value): | |||
| default_input = self.default_input | |||
| default_input[index] = value | |||
| return self | |||
| def set_parameter_data(self, data): | |||
| @@ -150,6 +150,8 @@ class Tensor(Tensor_): | |||
| return out | |||
| def __setitem__(self, index, value): | |||
| out = tensor_operator_registry.get('__setitem__')(self, index, value) | |||
| self.assign_value(out) | |||
| return self | |||
| def __gt__(self, other): | |||
| @@ -26,7 +26,7 @@ hyper_map = base.HyperMap() | |||
| pack = P.Pack(axis=-1) | |||
| def broadcast(broadcast_shape, x): | |||
| def _broadcast(broadcast_shape, x): | |||
| """Broadcast tensor to the required shape.""" | |||
| if F.shape(x) == broadcast_shape: | |||
| return x | |||
| @@ -36,13 +36,13 @@ def broadcast(broadcast_shape, x): | |||
| return x | |||
| def transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x): | |||
| def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x): | |||
| """Transform indexing tensor to the required.""" | |||
| x = broadcast(broadcast_shape, x) | |||
| return broadcast(final_shape, F.reshape(x, new_shape)) | |||
| x = _broadcast(broadcast_shape, x) | |||
| return _broadcast(final_shape, F.reshape(x, new_shape)) | |||
| def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): | |||
| def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): | |||
| """Generate an indices tensor from a tuple of tensor.""" | |||
| indices = None | |||
| check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name) | |||
| @@ -52,26 +52,31 @@ def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): | |||
| if check_dtypes: | |||
| shape_tuple = hyper_map(F.shape, tuple_index) | |||
| broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name) | |||
| broadcast_tensors = hyper_map(F.partial(broadcast, broadcast_shape), tuple_index) | |||
| broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) | |||
| indices = pack(broadcast_tensors) | |||
| return indices | |||
| def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): | |||
| def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): | |||
| """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.""" | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| int_positions = const_utils.get_pos_of_int_index(indexes_types) | |||
| for i in int_positions: | |||
| tuple_index = F.tuple_setitem(tuple_index, i, F.scalar_to_tensor(tuple_index[i], mstype.int32)) | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| tuple_index_new = () | |||
| tuple_len = len(tuple_index) | |||
| for i in range(tuple_len): | |||
| if i in int_positions: | |||
| tuple_index_new = tuple_index_new + (F.scalar_to_tensor(tuple_index[i], mstype.int32),) | |||
| else: | |||
| tuple_index_new = tuple_index_new + (tuple_index[i],) | |||
| indexes_types = hyper_map(F.typeof, tuple_index_new) | |||
| tensor_positions, slice_positions, ellipsis_position = \ | |||
| const_utils.separate_mixed_tensors_index(indexes_types, op_name) | |||
| tensor_indexes = [] | |||
| slice_indexes = [] | |||
| for i in tensor_positions: | |||
| tensor_indexes.append(tuple_index[i]) | |||
| tensor_indexes.append(tuple_index_new[i]) | |||
| for j in slice_positions: | |||
| slice_indexes.append(tuple_index[j]) | |||
| slice_indexes.append(tuple_index_new[j]) | |||
| data_shape = F.shape(data) | |||
| tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) | |||
| tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) | |||
| @@ -85,14 +90,14 @@ def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): | |||
| slice_number = 0 | |||
| final_index_tensors = [] | |||
| tuple_index_size = len(tuple_index) | |||
| tuple_index_size = len(tuple_index_new) | |||
| index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) | |||
| for i in range(tuple_index_size): | |||
| if i in tensor_positions: | |||
| transform_tensor = transform_indexing_tensor(broadcast_shape, | |||
| final_shape, | |||
| index_tensor_new_shape, | |||
| tuple_index[i]) | |||
| transform_tensor = _transform_indexing_tensor(broadcast_shape, | |||
| final_shape, | |||
| index_tensor_new_shape, | |||
| tuple_index_new[i]) | |||
| final_index_tensors.append(transform_tensor) | |||
| if i in slice_positions: | |||
| slice_tensor = const_utils.convert_slice_to_tensor(slice_number, | |||
| @@ -114,7 +119,7 @@ def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): | |||
| return indices | |||
| def generate_updates_from_scalar(data, indices, value, op_type): | |||
| def _generate_updates_from_scalar(data, indices, value, op_type): | |||
| """Generate an updates tensor from a scalar.""" | |||
| data_shape = F.shape(data) | |||
| indices_shape = F.shape(indices) | |||
| @@ -122,7 +127,7 @@ def generate_updates_from_scalar(data, indices, value, op_type): | |||
| return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) | |||
| def generate_updates_from_tuple(data, index, value, op_type): | |||
| def _generate_updates_from_tuple(data, index, value, op_type): | |||
| """Generate an updates tensor from a tuple.""" | |||
| value_types = hyper_map(F.typeof, value) | |||
| data_dtype = F.dtype(data) | |||
| @@ -132,14 +137,14 @@ def generate_updates_from_tuple(data, index, value, op_type): | |||
| shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM) | |||
| if shapes_same: | |||
| value = F.pack(value) | |||
| return generate_updates_from_tensor(data, index, value, op_type) | |||
| return _generate_updates_from_tensor(data, index, value, op_type) | |||
| data_shape = F.shape(data) | |||
| index_shape = F.shape(index) | |||
| return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type) | |||
| def generate_updates_from_tensor(data, index, value, op_type): | |||
| def _generate_updates_from_tensor(data, index, value, op_type): | |||
| """Generate an updates tensor from a tensor.""" | |||
| data_shape = F.shape(data) | |||
| index_shape = F.shape(index) | |||
| @@ -152,45 +157,47 @@ def generate_updates_from_tensor(data, index, value, op_type): | |||
| updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type) | |||
| need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape) | |||
| if need_broadcast: | |||
| return broadcast(updates_shape, value) | |||
| return _broadcast(updates_shape, value) | |||
| return value | |||
| def tensor_getitem(self, index): | |||
| def _tensor_getitem(self, index): | |||
| """Handle tensor getitem""" | |||
| if isinstance(index, Tensor): | |||
| return tensor_index_by_tensor(self, index) | |||
| if isinstance(index, tuple): | |||
| return tensor_index_by_tuple(self, index) | |||
| if isinstance(index, int): | |||
| return tensor_index_by_integer(self, index) | |||
| return _tensor_index_by_integer(self, index) | |||
| if isinstance(index, slice): | |||
| return tensor_index_by_slice(self, index) | |||
| if isinstance(index, bool): | |||
| return tensor_index_by_bool(self, index) | |||
| return _tensor_index_by_bool(self, index) | |||
| if index is None: | |||
| return F.expand_dims(self, 0) | |||
| if index is ...: | |||
| return self | |||
| raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, " | |||
| f"got {index} with type {type(index)}.") | |||
| tensor_operator_registry.register("__getitem__", tensor_getitem) | |||
| tensor_operator_registry.register("__getitem__", _tensor_getitem) | |||
| def tensor_getitem_by_tuple_of_tensor(data, tuple_index): | |||
| def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): | |||
| """Tensor getitem by a tuple of tensor.""" | |||
| indices = generate_indices_from_tuple_of_tensor(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_GETITEM) | |||
| indices = _generate_indices_from_tuple_of_tensor(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_GETITEM) | |||
| result = F.gather_nd(data, indices) | |||
| return result | |||
| def tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): | |||
| def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): | |||
| """Tensor getitem by a tuple of mixed tensor.""" | |||
| indices = generate_indices_from_tuple_of_mixed_tensors(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_GETITEM) | |||
| indices = _generate_indices_from_tuple_of_mixed_tensors(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_GETITEM) | |||
| result = F.gather_nd(data, indices) | |||
| return result | |||
| @@ -204,7 +211,7 @@ def tensor_index_by_slice(data, slice_index): | |||
| return F.strided_slice(data, begin_strides, end_strides, step_strides) | |||
| def tensor_index_by_integer(data, number): | |||
| def _tensor_index_by_integer(data, number): | |||
| """Tensor getitem by a single integer number""" | |||
| shape = F.shape(data) | |||
| if not shape: | |||
| @@ -214,7 +221,7 @@ def tensor_index_by_integer(data, number): | |||
| return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) | |||
| def tensor_index_by_bool(data, bool_value): | |||
| def _tensor_index_by_bool(data, bool_value): | |||
| """Tensor getitem by a single bool value""" | |||
| if bool_value: | |||
| return F.expand_dims(data, 0) | |||
| @@ -225,9 +232,9 @@ def tensor_index_by_number(data, number): | |||
| """Tensor getitem by a Number which may be integer/float/bool value""" | |||
| number_type = const_utils.check_number_index_type(number) | |||
| if number_type == const_utils.BOOL_: | |||
| return tensor_index_by_bool(data, number) | |||
| return _tensor_index_by_bool(data, number) | |||
| if number_type == const_utils.INT_: | |||
| return tensor_index_by_integer(data, number) | |||
| return _tensor_index_by_integer(data, number) | |||
| return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.") | |||
| @@ -241,7 +248,7 @@ def tensor_index_by_tensor(data, tensor_index): | |||
| "the index tensor data type only support mstype.int32.") | |||
| def tensor_index_by_tuple_slice(data, t): | |||
| def _tensor_index_by_tuple_slice(data, t): | |||
| """Tensor getitem by a tuple of slice""" | |||
| shape = F.shape(data) | |||
| if len(t) > len(shape): | |||
| @@ -257,7 +264,303 @@ def tensor_index_by_tuple(data, tuple_index): | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) | |||
| if index_elements_type == const_utils.NO_TENSOR: | |||
| return tensor_index_by_tuple_slice(data, tuple_index) | |||
| return _tensor_index_by_tuple_slice(data, tuple_index) | |||
| if index_elements_type == const_utils.ALL_TENSOR: | |||
| return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) | |||
| return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index) | |||
| def _tensor_setitem(self, index, value): | |||
| """Handle tensor getitem""" | |||
| if isinstance(index, Tensor): | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_tensor_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_tensor_with_tensor(self, index, value) | |||
| if isinstance(value, tuple): | |||
| return tensor_setitem_by_tensor_with_tuple(self, index, value) | |||
| if isinstance(index, tuple): | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_tuple_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_tuple_with_tensor(self, index, value) | |||
| if isinstance(value, tuple): | |||
| return tensor_setitem_by_tuple_with_tuple(self, index, value) | |||
| if isinstance(index, int): | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_number_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_number_with_tensor(self, index, value) | |||
| if isinstance(index, slice): | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_slice_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_slice_with_tensor(self, index, value) | |||
| if isinstance(index, bool): | |||
| return _tensor_index_by_bool(self, index) | |||
| if index is ...: | |||
| if isinstance(value, (int, float, bool)): | |||
| return tensor_setitem_by_ellipsis_with_number(self, index, value) | |||
| if isinstance(value, Tensor): | |||
| return tensor_setitem_by_ellipsis_with_tensor(self, index, value) | |||
| raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\ | |||
| and tensor with int32, got {} with type{}".format(index, type(index))) | |||
| tensor_operator_registry.register("__setitem__", _tensor_setitem) | |||
| def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): | |||
| """Set a tensor item by a int tensor with a tensor.""" | |||
| updates = _generate_updates_from_tensor(data, index, value, | |||
| const_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| index = F.expand_dims(index, -1) | |||
| return P.TensorScatterUpdate()(data, index, updates) | |||
| def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value): | |||
| """Set a tensor item by a bool tensor with a tensor.""" | |||
| index_shape = F.shape(index) | |||
| data_shape = F.shape(data) | |||
| data_shape = const_utils.check_equal(data_shape, index_shape, | |||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||
| size = F.size(value) | |||
| size = const_utils.check_equal(1, size, | |||
| "When assign value is a tensor, its size should be {}, but current size is {}.") | |||
| dtype = F.dtype(data) | |||
| u_cast = F.cast(value, dtype) | |||
| one_data = F.ones_like(data) | |||
| u = F.tensor_mul(one_data, u_cast) | |||
| result = F.select(index, u, data) | |||
| return result | |||
| def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor): | |||
| """setitem by tensor index(dtype is int or bool) with tensor as value""" | |||
| index_dtype = F.dtype(index) | |||
| tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) | |||
| if tensor_dtype == const_utils.INT_: | |||
| return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor) | |||
| return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor) | |||
| def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): | |||
| """Set a tensor item by a bool tensor with a scalar.""" | |||
| index_shape = F.shape(index) | |||
| shape = F.shape(data) | |||
| shape = const_utils.check_equal( | |||
| shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||
| dtype = F.dtype(data) | |||
| u = F.fill(dtype, shape, value) | |||
| return F.select(index, u, data) | |||
| def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): | |||
| """Set a tensor item by a int tensor with a scalar.""" | |||
| updates = _generate_updates_from_scalar(data, index, value, | |||
| const_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| index = F.expand_dims(index, -1) | |||
| return P.TensorScatterUpdate()(data, index, updates) | |||
| def tensor_setitem_by_tensor_with_number(data, index, value): | |||
| index_dtype = F.dtype(index) | |||
| tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) | |||
| if tensor_dtype == const_utils.BOOL_: | |||
| return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value) | |||
| if tensor_dtype == const_utils.INT_: | |||
| return _tensor_setitem_by_int_tensor_with_scalar(data, index, value) | |||
| return const_utils.raise_index_error("For tensor setitem, indexing tensor dtype only supports bool/int") | |||
| def tensor_setitem_by_tensor_with_tuple(data, index, value): | |||
| """Assigns the tensor by tensor with tuple value.""" | |||
| index_dtype = F.dtype(index) | |||
| check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM) | |||
| result = None | |||
| if check_dtype: | |||
| result = _tensor_setitem_by_tensor_with_tuple(data, index, value) | |||
| return result | |||
| def _tensor_indices_number(data, data_shape, index, indices, value): | |||
| """Assigns a scalar value to the tensor.""" | |||
| data_size = F.size(data) | |||
| data_dtype = F.dtype(data) | |||
| indices_size = F.size(indices) | |||
| indices_size = const_utils.check_indices(indices_size, index) | |||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||
| condition = F.reshape(condition_1d, data_shape) | |||
| condition = F.cast(condition, mstype.bool_) | |||
| value_fill = F.fill(data_dtype, (indices_size,), value) | |||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||
| u = F.reshape(value_1d, data_shape) | |||
| return F.select(condition, u, data) | |||
| def _tensor_setitem_by_tensor_with_tuple(data, index, value): | |||
| """Set a tensor item by a tensor with a tuple.""" | |||
| updates = _generate_updates_from_tuple(data, index, value, | |||
| const_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| index = F.expand_dims(index, -1) | |||
| result = P.TensorScatterUpdate()(data, index, updates) | |||
| return result | |||
| def tensor_setitem_by_slice_with_number(data, input_slice, value): | |||
| """Givens a scalar assign to tensor by slice""" | |||
| check_result = const_utils.check_tensor_setitem_index(input_slice) | |||
| result = None | |||
| if check_result: | |||
| data_shape = F.shape(data) | |||
| indices = const_utils.slice2indices(input_slice, data_shape) | |||
| is_tuple_int = const_utils.tuple_element_is_int(input_slice) | |||
| if is_tuple_int: | |||
| indices = const_utils.integer_to_indices(input_slice, data_shape) | |||
| result = _tensor_indices_number(data, data_shape, input_slice, indices, value) | |||
| return result | |||
| def tensor_setitem_by_tuple_with_number(data, tuple_index, value): | |||
| """Assigns the tensor by tuple with number value.""" | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if index_elements_type == const_utils.NO_TENSOR: | |||
| return tensor_setitem_by_slice_with_number(data, tuple_index, value) | |||
| if index_elements_type == const_utils.ALL_TENSOR: | |||
| indices = _generate_indices_from_tuple_of_tensor(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| else: | |||
| indices = _generate_indices_from_tuple_of_mixed_tensors(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| updates = _generate_updates_from_scalar(data, | |||
| indices, | |||
| value, | |||
| const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| return P.TensorScatterUpdate()(data, indices, updates) | |||
| def _tensor_indices_tensor(data, data_shape, index, indices, value): | |||
| """Assigns a tensor value to the tensor.""" | |||
| data_size = F.size(data) | |||
| data_dtype = F.dtype(data) | |||
| indices_size = F.size(indices) | |||
| indices_size = const_utils.check_indices(indices_size, index) | |||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||
| condition = F.reshape(condition_1d, data_shape) | |||
| condition = F.cast(condition, mstype.bool_) | |||
| value_fill = None | |||
| value_size = F.size(value) | |||
| value_size = const_utils.check_indices_value_size(indices_size, value_size) | |||
| if value_size == 1: | |||
| value_fill = F.fill(data_dtype, (indices_size,), 1) | |||
| value = F.cast(value, data_dtype) | |||
| value_fill = F.tensor_mul(value_fill, value) | |||
| elif value_size > 1: | |||
| value_fill = F.reshape(value, (indices_size,)) | |||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||
| u = F.reshape(value_1d, data_shape) | |||
| return F.select(condition, u, data) | |||
| def tensor_setitem_by_slice_with_tensor(data, input_slice, value): | |||
| """Assigns a tensor value to the tensor by slice.""" | |||
| result = None | |||
| check_result = const_utils.check_tensor_setitem_index(input_slice) | |||
| if check_result: | |||
| data_shape = F.shape(data) | |||
| indices = const_utils.slice2indices(input_slice, data_shape) | |||
| is_tuple_int = const_utils.tuple_element_is_int(input_slice) | |||
| if is_tuple_int: | |||
| indices = const_utils.integer_to_indices(input_slice, data_shape) | |||
| result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) | |||
| return result | |||
| def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): | |||
| """Assigns the tensor by tuple with tensor value.""" | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if index_elements_type == const_utils.NO_TENSOR: | |||
| return tensor_setitem_by_slice_with_tensor(data, tuple_index, value) | |||
| if index_elements_type == const_utils.ALL_TENSOR: | |||
| indices = _generate_indices_from_tuple_of_tensor(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| else: | |||
| indices = _generate_indices_from_tuple_of_mixed_tensors(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| updates = _generate_updates_from_tensor(data, | |||
| indices, | |||
| value, | |||
| const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| return P.TensorScatterUpdate()(data, indices, updates) | |||
| def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): | |||
| """Assigns the tensor by tuple with tuple of value.""" | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if index_elements_type == const_utils.ALL_TENSOR: | |||
| return tensor_getitem_by_tuple_of_tensor(data, tuple_index) | |||
| return tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index) | |||
| indices = _generate_indices_from_tuple_of_tensor(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| else: | |||
| indices = _generate_indices_from_tuple_of_mixed_tensors(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| updates = _generate_updates_from_tuple(data, | |||
| indices, | |||
| value, | |||
| const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| return P.TensorScatterUpdate()(data, indices, updates) | |||
| def tensor_setitem_by_number_with_number(data, index, value): | |||
| """Assigns the tensor by number with number value.""" | |||
| data_shape = F.shape(data) | |||
| indices = const_utils.integer_to_indices(index, data_shape) | |||
| return _tensor_indices_number(data, data_shape, index, indices, value) | |||
| def tensor_setitem_by_number_with_tensor(data, index, value): | |||
| """Assigns the tensor by number with tensor value.""" | |||
| data_shape = F.shape(data) | |||
| indices = const_utils.integer_to_indices(index, data_shape) | |||
| return _tensor_indices_tensor(data, data_shape, index, indices, value) | |||
| def tensor_setitem_by_ellipsis_with_number(data, index, value): | |||
| """Assigns the tensor by ellipsis with number value.""" | |||
| data_shape = F.shape(data) | |||
| data_dtype = F.dtype(data) | |||
| return F.fill(data_dtype, data_shape, value) | |||
| def tensor_setitem_by_ellipsis_with_tensor(data, index, value): | |||
| """Assigns the tensor by ellipsis with tensor value.""" | |||
| result = None | |||
| data_shape = F.shape(data) | |||
| data_dtype = F.dtype(data) | |||
| data_size = F.size(data) | |||
| value_shape = F.shape(value) | |||
| value_size = F.size(value) | |||
| check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) | |||
| if check_result: | |||
| if data_size == value_size: | |||
| result = F.reshape(value, data_shape) | |||
| result = F.cast(result, data_dtype) | |||
| elif value_size == 1: | |||
| param1 = F.fill(data_dtype, data_shape, 1) | |||
| param2 = F.cast(value, data_dtype) | |||
| result = F.tensor_mul(param1, param2) | |||
| return result | |||
| @@ -16,10 +16,8 @@ | |||
| """Implementation for setitem.""" | |||
| from . import _compile_utils as compile_utils | |||
| from . import _constexpr_utils as const_utils | |||
| from ... import functional as F | |||
| from ...composite import base | |||
| from ....common import dtype as mstype | |||
| setitem = base.MultitypeFuncGraph('setitem') | |||
| @@ -139,11 +137,7 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| index_dtype = F.dtype(index) | |||
| tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) | |||
| if tensor_dtype == const_utils.INT_: | |||
| return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor) | |||
| return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor) | |||
| return compile_utils.tensor_setitem_by_tensor_with_tensor(data, index, value_tensor) | |||
| @setitem.register("Tensor", "Tensor", "Number") | |||
| @@ -166,11 +160,7 @@ def _tensor_setitem_by_tensor_with_number(data, index, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| index_dtype = F.dtype(index) | |||
| tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) | |||
| if tensor_dtype == const_utils.BOOL_: | |||
| return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value) | |||
| return _tensor_setitem_by_int_tensor_with_scalar(data, index, value) | |||
| return compile_utils.tensor_setitem_by_tensor_with_number(data, index, value) | |||
| @setitem.register("Tensor", "Tuple", "Number") | |||
| @@ -191,24 +181,7 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if index_elements_type == const_utils.NO_TENSOR: | |||
| return _tensor_assgin_number(data, tuple_index, value) | |||
| if index_elements_type == const_utils.ALL_TENSOR: | |||
| indices = compile_utils.generate_indices_from_tuple_of_tensor(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| else: | |||
| indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| updates = compile_utils.generate_updates_from_scalar(data, | |||
| indices, | |||
| value, | |||
| const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| return F.scatter_nd_update(data, indices, updates) | |||
| return compile_utils.tensor_setitem_by_tuple_with_number(data, tuple_index, value) | |||
| @setitem.register("Tensor", "Tuple", "Tensor") | |||
| @@ -229,24 +202,7 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if index_elements_type == const_utils.NO_TENSOR: | |||
| return _tensor_assgin_tensor(data, tuple_index, value) | |||
| if index_elements_type == const_utils.ALL_TENSOR: | |||
| indices = compile_utils.generate_indices_from_tuple_of_tensor(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| else: | |||
| indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| updates = compile_utils.generate_updates_from_tensor(data, | |||
| indices, | |||
| value, | |||
| const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| return F.scatter_nd_update(data, indices, updates) | |||
| return compile_utils.tensor_setitem_by_tuple_with_tensor(data, tuple_index, value) | |||
| @setitem.register("Tensor", "Tuple", "Tuple") | |||
| @@ -268,22 +224,7 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if index_elements_type == const_utils.ALL_TENSOR: | |||
| indices = compile_utils.generate_indices_from_tuple_of_tensor(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| else: | |||
| indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, | |||
| tuple_index, | |||
| const_utils.TENSOR_SETITEM) | |||
| updates = compile_utils.generate_updates_from_tuple(data, | |||
| indices, | |||
| value, | |||
| const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| return F.scatter_nd_update(data, indices, updates) | |||
| return compile_utils.tensor_setitem_by_tuple_with_tuple(data, tuple_index, value) | |||
| @setitem.register("Tensor", "Tensor", "Tuple") | |||
| @@ -299,12 +240,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| index_dtype = F.dtype(index) | |||
| check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM) | |||
| result = None | |||
| if check_dtype: | |||
| result = _tensor_setitem_by_tensor_with_tuple(data, index, value) | |||
| return result | |||
| return compile_utils.tensor_setitem_by_tensor_with_tuple(data, index, value) | |||
| @setitem.register("Tensor", "Slice", "Tensor") | |||
| @@ -326,7 +262,7 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return _tensor_assgin_tensor(data, input_slice, value) | |||
| return compile_utils.tensor_setitem_by_slice_with_tensor(data, input_slice, value) | |||
| @setitem.register("Tensor", "Slice", "Number") | |||
| @@ -348,168 +284,28 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return _tensor_assgin_number(data, input_slice, value) | |||
| def _tensor_assgin_number(data, input_slice, value): | |||
| """Givens a scalar assign to tensor by slice""" | |||
| check_result = const_utils.check_tensor_setitem_index(input_slice) | |||
| result = None | |||
| if check_result: | |||
| data_shape = F.shape(data) | |||
| indices = const_utils.slice2indices(input_slice, data_shape) | |||
| is_tuple_int = const_utils.tuple_element_is_int(input_slice) | |||
| if is_tuple_int: | |||
| indices = const_utils.integer_to_indices(input_slice, data_shape) | |||
| result = _tensor_indices_number(data, data_shape, input_slice, indices, value) | |||
| return result | |||
| return compile_utils.tensor_setitem_by_slice_with_number(data, input_slice, value) | |||
| @setitem.register("Tensor", "Number", "Number") | |||
| def _tensor_setitem_with_int_v1(data, index, value): | |||
| """Syntax: A[1] = 3""" | |||
| data_shape = F.shape(data) | |||
| indices = const_utils.integer_to_indices(index, data_shape) | |||
| return _tensor_indices_number(data, data_shape, index, indices, value) | |||
| return compile_utils.tensor_setitem_by_number_with_number(data, index, value) | |||
| @setitem.register("Tensor", "Number", "Tensor") | |||
| def _tensor_setitem_with_int_v2(data, index, value): | |||
| """Syntax: A[1] = Tensor""" | |||
| data_shape = F.shape(data) | |||
| indices = const_utils.integer_to_indices(index, data_shape) | |||
| return _tensor_indices_tensor(data, data_shape, index, indices, value) | |||
| return compile_utils.tensor_setitem_by_number_with_tensor(data, index, value) | |||
| @setitem.register("Tensor", "Ellipsis", "Number") | |||
| def _tensor_setitem_with_ellipsis_v1(data, index, value): | |||
| """Syntax: A[...] = number.""" | |||
| data_shape = F.shape(data) | |||
| data_dtype = F.dtype(data) | |||
| return F.fill(data_dtype, data_shape, value) | |||
| return compile_utils.tensor_setitem_by_ellipsis_with_number(data, index, value) | |||
| @setitem.register("Tensor", "Ellipsis", "Tensor") | |||
| def _tensor_setitem_with_ellipsis_v2(data, index, value): | |||
| """Syntax: A[...] = Tensor.""" | |||
| result = None | |||
| data_shape = F.shape(data) | |||
| data_dtype = F.dtype(data) | |||
| data_size = F.size(data) | |||
| value_shape = F.shape(value) | |||
| value_size = F.size(value) | |||
| check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) | |||
| if check_result: | |||
| if data_size == value_size: | |||
| result = F.reshape(value, data_shape) | |||
| result = F.cast(result, data_dtype) | |||
| elif value_size == 1: | |||
| param1 = F.fill(data_dtype, data_shape, 1) | |||
| param2 = F.cast(value, data_dtype) | |||
| result = F.tensor_mul(param1, param2) | |||
| return result | |||
| def _tensor_assgin_tensor(data, input_slice, value): | |||
| """Assigns a tensor value to the tensor by slice.""" | |||
| result = None | |||
| check_result = const_utils.check_tensor_setitem_index(input_slice) | |||
| if check_result: | |||
| data_shape = F.shape(data) | |||
| indices = const_utils.slice2indices(input_slice, data_shape) | |||
| is_tuple_int = const_utils.tuple_element_is_int(input_slice) | |||
| if is_tuple_int: | |||
| indices = const_utils.integer_to_indices(input_slice, data_shape) | |||
| result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) | |||
| return result | |||
| def _tensor_indices_tensor(data, data_shape, index, indices, value): | |||
| """Assigns a tensor value to the tensor.""" | |||
| data_size = F.size(data) | |||
| data_dtype = F.dtype(data) | |||
| indices_size = F.size(indices) | |||
| indices_size = const_utils.check_indices(indices_size, index) | |||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||
| condition = F.reshape(condition_1d, data_shape) | |||
| condition = F.cast(condition, mstype.bool_) | |||
| value_fill = None | |||
| value_size = F.size(value) | |||
| value_size = const_utils.check_indices_value_size(indices_size, value_size) | |||
| if value_size == 1: | |||
| value_fill = F.fill(data_dtype, (indices_size,), 1) | |||
| value = F.cast(value, data_dtype) | |||
| value_fill = F.tensor_mul(value_fill, value) | |||
| elif value_size > 1: | |||
| value_fill = F.reshape(value, (indices_size,)) | |||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||
| u = F.reshape(value_1d, data_shape) | |||
| return F.select(condition, u, data) | |||
| def _tensor_indices_number(data, data_shape, index, indices, value): | |||
| """Assigns a scalar value to the tensor.""" | |||
| data_size = F.size(data) | |||
| data_dtype = F.dtype(data) | |||
| indices_size = F.size(indices) | |||
| indices_size = const_utils.check_indices(indices_size, index) | |||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||
| condition = F.reshape(condition_1d, data_shape) | |||
| condition = F.cast(condition, mstype.bool_) | |||
| value_fill = F.fill(data_dtype, (indices_size,), value) | |||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||
| u = F.reshape(value_1d, data_shape) | |||
| return F.select(condition, u, data) | |||
| def _tensor_setitem_by_tensor_with_tuple(data, index, value): | |||
| """Set a tensor item by a tensor with a tuple.""" | |||
| updates = compile_utils.generate_updates_from_tuple(data, index, value, | |||
| const_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| result = F.scatter_update(data, index, updates) | |||
| return result | |||
| def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): | |||
| """Set a tensor item by a int tensor with a scalar.""" | |||
| updates = compile_utils.generate_updates_from_scalar(data, index, value, | |||
| const_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| return F.scatter_update(data, index, updates) | |||
| def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): | |||
| """Set a tensor item by a bool tensor with a scalar.""" | |||
| index_shape = F.shape(index) | |||
| shape = F.shape(data) | |||
| shape = const_utils.check_equal( | |||
| shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||
| dtype = F.dtype(data) | |||
| u = F.fill(dtype, shape, value) | |||
| return F.select(index, u, data) | |||
| def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): | |||
| """Set a tensor item by a int tensor with a tensor.""" | |||
| updates = compile_utils.generate_updates_from_tensor(data, index, value, | |||
| const_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| return F.scatter_update(data, index, updates) | |||
| def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value): | |||
| """Set a tensor item by a bool tensor with a tensor.""" | |||
| index_shape = F.shape(index) | |||
| data_shape = F.shape(data) | |||
| data_shape = const_utils.check_equal(data_shape, index_shape, | |||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||
| size = F.size(value) | |||
| size = const_utils.check_equal(1, size, | |||
| "When assign value is a tensor, its size should be {}, but current size is {}.") | |||
| dtype = F.dtype(data) | |||
| u_cast = F.cast(value, dtype) | |||
| one_data = F.ones_like(data) | |||
| u = F.tensor_mul(one_data, u_cast) | |||
| result = F.select(index, u, data) | |||
| return result | |||
| return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, index, value) | |||
| @@ -20,10 +20,14 @@ from mindspore import Tensor, Parameter | |||
| from mindspore import context | |||
| from mindspore import dtype as mstype | |||
| from mindspore.nn import Cell | |||
| from mindspore.common.parameter import ParameterTuple | |||
| from mindspore.ops import composite as C | |||
| def setup_module(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||
| class NetWorkSlicePositive(Cell): | |||
| def __init__(self): | |||
| super(NetWorkSlicePositive, self).__init__() | |||
| @@ -139,7 +143,7 @@ class TensorGetItemByThreeTensors(Cell): | |||
| return ret0, ret1, ret2 | |||
| def Xtest_getitem_by_tensors(): | |||
| def test_getitem_by_tensors(): | |||
| net = TensorGetItemByThreeTensors() | |||
| input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32) | |||
| index_0 = np.random.randint(6, size=(3, 4, 5)).astype(np.int32) | |||
| @@ -155,119 +159,140 @@ def Xtest_getitem_by_tensors(): | |||
| assert np.all(output2.asnumpy() == input_x[index_0, index_1, index_2] + np.ones([5, 3, 4, 5])) | |||
| class TensorGetItemByMixedTensors_0(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors_0, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32)) | |||
| def construct(self, tensor, index_0, index_1): | |||
| ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const | |||
| return ret | |||
| class TensorGetItemByMixedTensors_1(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors_1, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32)) | |||
| class TensorGetItemByMixedTensorsBasicCase(Cell): | |||
| def __init__(self, c0, c1, c2, c3, c4, c5): | |||
| super(TensorGetItemByMixedTensorsBasicCase, self).__init__() | |||
| self.const0 = Tensor(c0) | |||
| self.const1 = Tensor(c1) | |||
| self.const2 = Tensor(c2) | |||
| self.const3 = Tensor(c3) | |||
| self.const4 = Tensor(c4) | |||
| self.const5 = Tensor(c5) | |||
| def construct(self, tensor, index_0, index_1): | |||
| ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const | |||
| return ret | |||
| class TensorGetItemByMixedTensors_2(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors_2, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32)) | |||
| def construct(self, tensor, index_0, index_1): | |||
| ret = tensor[0, index_0, index_1, ..., 3] + self.const | |||
| return ret | |||
| class TensorGetItemByMixedTensors_3(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors_3, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32)) | |||
| def construct(self, tensor, index_0, index_1): | |||
| ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const | |||
| return ret | |||
| class TensorGetItemByMixedTensors_4(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors_4, self).__init__() | |||
| self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32)) | |||
| def construct(self, tensor, index_0, index_1, index_2): | |||
| ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const | |||
| return ret | |||
| class TensorGetItemByMixedTensors_5(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors_5, self).__init__() | |||
| self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32)) | |||
| def construct(self, tensor, index_0, index_1, index_2): | |||
| ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const | |||
| return ret | |||
| class TensorGetItemByMixedTensors_6(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensors_6, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32)) | |||
| def construct(self, tensor, index_0, index_1, index_2): | |||
| ret = tensor[..., index_0, index_1, index_2, 3] + self.const | |||
| return ret | |||
| ret0 = tensor[index_0, index_1, 0:3] + self.const0 | |||
| ret1 = tensor[0:3, index_0, ...] + self.const1 | |||
| ret2 = tensor[0, index_0, index_1] + self.const2 | |||
| ret3 = tensor[..., index_0, 0:3] + self.const3 | |||
| ret4 = tensor[0:2, index_0, index_1] + self.const4 | |||
| ret5 = tensor[..., index_0, index_1] + self.const5 | |||
| return ret0, ret1, ret2, ret3, ret4, ret5 | |||
| def test_getitem_by_mixed_tensors(): | |||
| const0 = np.ones((3, 4, 5, 3), np.float32) | |||
| const1 = np.ones((3, 3, 4, 5, 5), np.float32) | |||
| const2 = np.ones((3, 4, 5), np.float32) | |||
| const3 = np.ones((3, 3, 4, 5, 3), np.float32) | |||
| const4 = np.ones((2, 3, 4, 5), np.float32) | |||
| const5 = np.ones((3, 3, 4, 5), np.float32) | |||
| net = TensorGetItemByMixedTensorsBasicCase(const0, const1, const2, const3, const4, const5) | |||
| input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32) | |||
| input_ms = Tensor(input_np, mstype.float32) | |||
| index_np_0 = np.random.randint(3, size=(3, 4, 5)).astype(np.int32) | |||
| index_np_1 = np.random.randint(4, size=(4, 5)).astype(np.int32) | |||
| index_0 = Tensor(index_np_0, mstype.int32) | |||
| index_1 = Tensor(index_np_1, mstype.int32) | |||
| out0, out1, out2, out3, out4, out5 = net(input_ms, index_0, index_1) | |||
| assert np.all(out0.asnumpy() == (input_np[index_np_0, index_np_1, 0:3] + const0)) | |||
| assert np.all(out1.asnumpy() == (input_np[0:3, index_np_0, ...] + const1)) | |||
| assert np.all(out2.asnumpy() == (input_np[0, index_np_0, index_np_1] + const2)) | |||
| assert np.all(out3.asnumpy() == (input_np[..., index_np_0, 0:3] + const3)) | |||
| assert np.all(out4.asnumpy() == (input_np[0:2, index_np_0, index_np_1] + const4)) | |||
| assert np.all(out5.asnumpy() == (input_np[..., index_np_0, index_np_1] + const5)) | |||
| class TensorSetItemByMixedTensors_0(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByMixedTensors_0, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32)) | |||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), | |||
| self.const = Tensor(np.ones((3, 4, 5), np.float32)) | |||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), | |||
| mstype.float32), | |||
| name="x") | |||
| self.value = value | |||
| def construct(self, index_0, index_1, index_2): | |||
| self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value | |||
| self.param[0:2, index_0, index_1] = self.value | |||
| ret = self.param + self.const | |||
| return ret | |||
| def test_setitem_by_mixed_tensors_0(): | |||
| value = 88.0 | |||
| net = TensorSetItemByMixedTensors_0(value) | |||
| index_0 = np.random.randint(3, size=(3, 4, 5)) | |||
| index_1 = np.random.randint(4, size=(4, 5)) | |||
| index_2 = np.random.randint(3, size=(2, 1, 4, 5)) | |||
| index_0_ms = Tensor(index_0, mstype.int32) | |||
| index_1_ms = Tensor(index_1, mstype.int32) | |||
| index_2_ms = Tensor(index_2, mstype.int32) | |||
| input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32) | |||
| const = np.ones((3, 4, 5), np.float32) | |||
| out = net(index_0_ms, index_1_ms, index_2_ms) | |||
| input_np[0:2, index_0, index_1] = value | |||
| assert np.all(out.asnumpy() == (input_np + const)) | |||
| class TensorSetItemByMixedTensors_1(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByMixedTensors_1, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32)) | |||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||
| self.const = Tensor(np.ones((3, 4, 5), np.float32)) | |||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32), | |||
| name="x") | |||
| self.value = value | |||
| def construct(self, index_0, index_1, index_2): | |||
| self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value | |||
| self.param[0:2, index_0, ...] = self.value | |||
| ret = self.param + self.const | |||
| return ret | |||
| def test_setitem_by_mixed_tensors_1(): | |||
| value = 88.0 | |||
| net = TensorSetItemByMixedTensors_1(value) | |||
| index_0 = np.random.randint(3, size=(3, 4, 5)) | |||
| index_1 = np.random.randint(4, size=(4, 5)) | |||
| index_2 = np.random.randint(3, size=(2, 1, 4, 5)) | |||
| index_0_ms = Tensor(index_0, mstype.int32) | |||
| index_1_ms = Tensor(index_1, mstype.int32) | |||
| index_2_ms = Tensor(index_2, mstype.int32) | |||
| input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32) | |||
| const = np.ones((3, 4, 5), np.float32) | |||
| out = net(index_0_ms, index_1_ms, index_2_ms) | |||
| input_np[0:2, index_0, ...] = value | |||
| assert np.all(out.asnumpy() == (input_np + const)) | |||
| class TensorSetItemByMixedTensors_2(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByMixedTensors_2, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float16)) | |||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float16), | |||
| self.const = Tensor(np.ones((3, 4, 5), np.float16)) | |||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float16), | |||
| name="x") | |||
| self.value = value | |||
| def construct(self, index_0, index_1, index_2): | |||
| self.param[..., index_0, index_1, index_2, 3] = self.value | |||
| self.param[..., index_0, 1] = self.value | |||
| ret = self.param + self.const | |||
| return ret | |||
| def test_setitem_by_mixed_tensors_2(): | |||
| value = 88.0 | |||
| net = TensorSetItemByMixedTensors_2(value) | |||
| index_0 = np.random.randint(3, size=(3, 4, 5)) | |||
| index_1 = np.random.randint(4, size=(4, 5)) | |||
| index_2 = np.random.randint(3, size=(2, 1, 4, 5)) | |||
| index_0_ms = Tensor(index_0, mstype.int32) | |||
| index_1_ms = Tensor(index_1, mstype.int32) | |||
| index_2_ms = Tensor(index_2, mstype.int32) | |||
| input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32) | |||
| const = np.ones((3, 4, 5), np.float32) | |||
| out = net(index_0_ms, index_1_ms, index_2_ms) | |||
| input_np[..., index_0, 1] = value | |||
| assert np.all(out.asnumpy() == (input_np + const)) | |||
| class TensorGetItemByMixedTensorsTypeError(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensorsTypeError, self).__init__() | |||
| @@ -277,13 +302,13 @@ class TensorGetItemByMixedTensorsTypeError(Cell): | |||
| return ret | |||
| class TensorGetItemByMixedTensorsNumberError(Cell): | |||
| def __init__(self): | |||
| super(TensorGetItemByMixedTensorsNumberError, self).__init__() | |||
| def construct(self, x, index_0, index_1): | |||
| ret = x[index_0, index_1, 0:3, ..., index_1, index_0] | |||
| return ret | |||
| def test_getitem_by_mixedtensor_exception(): | |||
| input_ms = Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32) | |||
| index_0 = Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32) | |||
| index_1 = Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32) | |||
| net1 = TensorGetItemByMixedTensorsTypeError() | |||
| with pytest.raises(TypeError): | |||
| net1(input_ms, index_0, index_1) | |||
| class TensorSetItemByOneTensorWithNumber(Cell): | |||
| @@ -299,6 +324,18 @@ class TensorSetItemByOneTensorWithNumber(Cell): | |||
| return ret | |||
| def test_setitem_one_tensor_with_number(): | |||
| value = 0.0 | |||
| net = TensorSetItemByOneTensorWithNumber(value) | |||
| index_np = np.random.randint(4, size=(5, 4)) | |||
| index = Tensor(index_np, mstype.int32) | |||
| input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)) | |||
| const = np.ones((6, 7, 8)).astype(np.float32) | |||
| out = net(index) | |||
| input_data[index_np] = value | |||
| assert np.all(out.asnumpy() == (input_data + const)) | |||
| class TensorSetItemByOneTensorWithTensor(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByOneTensorWithTensor, self).__init__() | |||
| @@ -311,6 +348,19 @@ class TensorSetItemByOneTensorWithTensor(Cell): | |||
| return ret | |||
| def test_setitem_by_one_tensor_with_tensor(): | |||
| net = TensorSetItemByOneTensorWithTensor() | |||
| index_np = np.random.randint(4, size=(5, 4)) | |||
| index = Tensor(index_np, mstype.int32) | |||
| input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)) | |||
| const = np.ones((6, 7, 8)).astype(np.float32) | |||
| value = np.zeros((4, 7, 8)).astype(np.float32) | |||
| value_ms = Tensor(value, mstype.float32) | |||
| out = net(index, value_ms) | |||
| input_data[index_np] = value | |||
| assert np.all(out.asnumpy() == (input_data + const)) | |||
| class TensorSetItemByOneTensorWithTupleOfNumber(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__() | |||
| @@ -324,6 +374,18 @@ class TensorSetItemByOneTensorWithTupleOfNumber(Cell): | |||
| return ret | |||
| def test_setitem_by_one_tensor_with_tuple_number(): | |||
| value = (0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7) | |||
| net = TensorSetItemByOneTensorWithTupleOfNumber(value) | |||
| input_np = np.random.randint(5, size=(5, 4)) | |||
| input_ms = Tensor(input_np, mstype.int32) | |||
| input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32) | |||
| const = np.ones((6, 7, 8)).astype(np.float32) | |||
| out = net(input_ms) | |||
| input_data[input_np] = value | |||
| assert np.all(out.asnumpy() == (input_data + const)) | |||
| class TensorSetItemByOneTensorWithTupleOfTensor(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__() | |||
| @@ -336,6 +398,23 @@ class TensorSetItemByOneTensorWithTupleOfTensor(Cell): | |||
| return ret | |||
| def test_setitem_by_one_tensor_with_tuple_tensors(): | |||
| net = TensorSetItemByOneTensorWithTupleOfTensor() | |||
| input_np = np.random.randint(6, size=(5, 4)).astype(np.int32) | |||
| input_ms = Tensor(input_np, mstype.int32) | |||
| input_data = np.arange(6 * 3 * 8).reshape((6, 3, 8)).astype(np.float32) | |||
| value_0_np = np.zeros((8,), np.float32) | |||
| value_1_np = np.ones((8,), np.float32) | |||
| value_2_np = np.ones((8,), np.float32)*2 | |||
| value_0 = Tensor(value_0_np) | |||
| value_1 = Tensor(value_1_np) | |||
| value_2 = Tensor(value_2_np) | |||
| const = np.ones((6, 3, 8)).astype(np.float32) | |||
| out = net(input_ms, value_0, value_1, value_2) | |||
| input_data[input_np] = (value_0_np, value_1_np, value_2_np) | |||
| assert np.all(out.asnumpy() == (input_data + const)) | |||
| class TensorSetItemByTensorsWithNumber(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByTensorsWithNumber, self).__init__() | |||
| @@ -349,6 +428,22 @@ class TensorSetItemByTensorsWithNumber(Cell): | |||
| return ret | |||
| def test_setitem_by_tensors_with_number(): | |||
| value = 0.0 | |||
| net = TensorSetItemByTensorsWithNumber(value) | |||
| index_0 = np.random.randint(6, size=(3, 4, 5)) | |||
| index_1 = np.random.randint(7, size=(4, 5)) | |||
| index_2 = np.random.randint(8, size=(5, 3, 4, 5)) | |||
| index_0_ms = Tensor(index_0, mstype.int32) | |||
| index_1_ms = Tensor(index_1, mstype.int32) | |||
| index_2_ms = Tensor(index_2, mstype.int32) | |||
| out = net(index_0_ms, index_1_ms, index_2_ms) | |||
| const = np.ones((6, 7, 8)).astype(np.float32) | |||
| input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32) | |||
| input_data[index_0, index_1, index_2] = value | |||
| assert np.all(out.asnumpy() == (input_data + const)) | |||
| class TensorSetItemByTensorsWithTensor(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByTensorsWithTensor, self).__init__() | |||
| @@ -361,6 +456,23 @@ class TensorSetItemByTensorsWithTensor(Cell): | |||
| return ret | |||
| def test_setitem_by_tensors_with_tensor(): | |||
| net = TensorSetItemByTensorsWithTensor() | |||
| index_0 = np.random.randint(6, size=(3, 4, 5)) | |||
| index_1 = np.random.randint(7, size=(4, 5)) | |||
| index_2 = np.random.randint(8, size=(5, 3, 4, 5)) | |||
| value = np.zeros((4, 5)).astype(np.float32) | |||
| index_0_ms = Tensor(index_0, mstype.int32) | |||
| index_1_ms = Tensor(index_1, mstype.int32) | |||
| index_2_ms = Tensor(index_2, mstype.int32) | |||
| value_ms = Tensor(value, mstype.float32) | |||
| out = net(index_0_ms, index_1_ms, index_2_ms, value_ms) | |||
| const = np.ones((6, 7, 8)).astype(np.float32) | |||
| input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32) | |||
| input_data[index_0, index_1, index_2] = value | |||
| assert np.all(out.asnumpy() == (input_data + const)) | |||
| class TensorSetItemByTensorsWithTensorNumberError(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByTensorsWithTensorNumberError, self).__init__() | |||
| @@ -373,6 +485,17 @@ class TensorSetItemByTensorsWithTensorNumberError(Cell): | |||
| return ret | |||
| def test_setitem_by_tensors_with_tensor_error(): | |||
| index_0 = Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32) | |||
| index_1 = Tensor(np.random.randint(7, size=(4, 5)), mstype.int32) | |||
| index_2 = Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32) | |||
| index_3 = Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32) | |||
| value = Tensor(np.zeros((2, 5)), mstype.float32) | |||
| net = TensorSetItemByTensorsWithTensorNumberError() | |||
| with pytest.raises(IndexError): | |||
| net(index_0, index_1, index_2, index_3, value) | |||
| class TensorSetItemByTensorsWithTupleOfNumber(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__() | |||
| @@ -386,6 +509,22 @@ class TensorSetItemByTensorsWithTupleOfNumber(Cell): | |||
| return ret | |||
| def test_setitem_by_tensors_with_tuple_of_number(): | |||
| value = (0.0, 1.1, 2.2, 3.3, 4.4) | |||
| net = TensorSetItemByTensorsWithTupleOfNumber(value) | |||
| index_0 = np.random.randint(6, size=(3, 4, 5)) | |||
| index_1 = np.random.randint(7, size=(4, 5)) | |||
| index_2 = np.random.randint(8, size=(5, 3, 4, 5)) | |||
| index_0_ms = Tensor(index_0, mstype.int32) | |||
| index_1_ms = Tensor(index_1, mstype.int32) | |||
| index_2_ms = Tensor(index_2, mstype.int32) | |||
| input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32) | |||
| input_data[index_0, index_1, index_2] = value | |||
| const = np.ones((6, 7, 8)).astype(np.float32) | |||
| out = net(index_0_ms, index_1_ms, index_2_ms) | |||
| assert np.all(out.asnumpy() == (input_data + const)) | |||
| class TensorSetItemByTensorsWithTupleOfTensor(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__() | |||
| @@ -398,6 +537,27 @@ class TensorSetItemByTensorsWithTupleOfTensor(Cell): | |||
| return ret | |||
| def test_setitem_by_tensors_with_tuple_of_tensor(): | |||
| value_0 = np.zeros((4, 5)) | |||
| value_1 = np.ones((4, 5)) | |||
| value_2 = np.ones((4, 5)) * 2 | |||
| value_0_ms = Tensor(value_0, mstype.float32) | |||
| value_1_ms = Tensor(value_1, mstype.float32) | |||
| value_2_ms = Tensor(value_2, mstype.float32) | |||
| net = TensorSetItemByTensorsWithTupleOfTensor() | |||
| index_0 = np.random.randint(6, size=(3, 4, 5)) | |||
| index_1 = np.random.randint(7, size=(4, 5)) | |||
| index_2 = np.random.randint(8, size=(5, 3, 4, 5)) | |||
| index_0_ms = Tensor(index_0, mstype.int32) | |||
| index_1_ms = Tensor(index_1, mstype.int32) | |||
| index_2_ms = Tensor(index_2, mstype.int32) | |||
| input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32) | |||
| input_data[index_0, index_1, index_2] = (value_0, value_1, value_2) | |||
| const = np.ones((6, 7, 8)).astype(np.float32) | |||
| out = net(index_0_ms, index_1_ms, index_2_ms, value_0_ms, value_1_ms, value_2_ms) | |||
| assert np.all(out.asnumpy() == (input_data + const)) | |||
| class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__() | |||
| @@ -410,17 +570,44 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): | |||
| return ret | |||
| class TensorSetItemByMixedTensors(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByMixedTensors, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.value = 99.0 | |||
| def test_setitem_by_tensor_with_tuple_of_tensor_error(): | |||
| net = TensorSetItemByTensorsWithTupleOfTensorNumberError() | |||
| index_0_ms = Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32) | |||
| index_1_ms = Tensor(np.random.randint(7, size=(4, 5)), mstype.int32) | |||
| index_2_ms = Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32) | |||
| value_0 = np.zeros((4, 5)) | |||
| value_1 = np.ones((4, 5)) | |||
| value_0_ms = Tensor(value_0, mstype.float32) | |||
| value_1_ms = Tensor(value_1, mstype.float32) | |||
| with pytest.raises(ValueError): | |||
| net(index_0_ms, index_1_ms, index_2_ms, value_0_ms, value_1_ms) | |||
| def construct(self, index_0, index_1): | |||
| self.param[index_0, index_1, 0:6] = self.value | |||
| ret = self.param + self.const | |||
| return ret | |||
| def test_setitem_grad(): | |||
| class Net(Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.weight = Parameter( | |||
| Tensor(np.ones([4, 4, 5]), dtype=mstype.float32), "b1", requires_grad=True) | |||
| def construct(self, a, b): | |||
| a[1:3:1, ::] = b | |||
| c = a + self.weight | |||
| return c | |||
| class GradNet(Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, x, y, sens): | |||
| return C.grad_by_list_with_sens(self.net, self.weights)(x, y, sens) | |||
| net = GradNet(Net()) | |||
| x = Tensor(np.ones([4, 4, 5]).astype(np.float32), mstype.float32) | |||
| y = Tensor(np.array([3]).astype(np.float32), mstype.float32) | |||
| sens = Tensor(np.ones([4, 4, 5]).astype(np.float32), mstype.float32) | |||
| net(x, y, sens) | |||
| class TensorAssignWithSliceError1(Cell): | |||
| @@ -475,7 +662,6 @@ class TensorAssignWithSlice(Cell): | |||
| def test_tensor_assign(): | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| net = TensorAssignWithSlice() | |||
| net2 = TensorAssignWithSlice2() | |||
| net_e1 = TensorAssignWithSliceError1() | |||
| @@ -621,7 +807,7 @@ class TensorAssignWithTupleInteger(Cell): | |||
| class TensorAssignWithBoolTensorIndex(Cell): | |||
| def __init__(self): | |||
| super(TensorAssignWithBoolTensorIndex, self).__init__() | |||
| self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32) | |||
| self.t = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| self.u_scalar = 5 | |||
| def construct(self, a, b, c, u_tensor): | |||
| @@ -643,8 +829,7 @@ class TensorAssignWithBoolTensorIndexError(Cell): | |||
| class TensorAssignWithBoolTensorIndex2(Cell): | |||
| def __init__(self): | |||
| super(TensorAssignWithBoolTensorIndex2, self).__init__() | |||
| self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32) | |||
| self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32) | |||
| self.t = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| self.u_scalar = 5 | |||
| def construct(self, a, u_tensor): | |||
| @@ -666,7 +851,40 @@ class TensorAssignWithBoolTensorIndex2Error(Cell): | |||
| return a | |||
| def Xtest_tensor_assign_bool_index(): | |||
| def test_tensor_assign_bool_index_0(): | |||
| a = np.arange(60).reshape(3, 4, 5) | |||
| b = a > 5 | |||
| c = a < 3 | |||
| Ta = Tensor(a, dtype=mstype.float32) | |||
| Tb = Tensor(b) | |||
| Tc = Tensor(c) | |||
| u_tensor = Tensor([1], dtype=mstype.float32) | |||
| net1 = TensorAssignWithBoolTensorIndex() | |||
| out = net1(Ta, Tb, Tc, u_tensor) | |||
| res = np.arange(60).reshape(3, 4, 5) | |||
| res[c] = 5 | |||
| res[b] = 1 | |||
| res = res + np.ones([3, 4, 5]) | |||
| assert np.all(out.asnumpy() == res) | |||
| def test_tensor_assign_bool_index_1(): | |||
| a = np.arange(60).reshape(3, 4, 5) | |||
| Ta = Tensor(a, dtype=mstype.float32) | |||
| u_tensor = Tensor([1], dtype=mstype.float32) | |||
| net2 = TensorAssignWithBoolTensorIndex2() | |||
| out = net2(Ta, u_tensor) | |||
| res = np.arange(60).reshape(3, 4, 5) | |||
| res[res > 8] = 1 | |||
| res[res >= 6] = 5 | |||
| res[res < 3] = 5 | |||
| res[res <= 5] = 1 | |||
| res[res == 5] = 5 | |||
| res = res + np.ones([3, 4, 5]) | |||
| assert np.all(out.asnumpy() == res) | |||
| def test_tensor_assign_bool_index_exception(): | |||
| a = np.arange(60).reshape(3, 4, 5) | |||
| b = a > 5 | |||
| c = a < 3 | |||
| @@ -679,8 +897,6 @@ def Xtest_tensor_assign_bool_index(): | |||
| u_scalar = 5 | |||
| net1 = TensorAssignWithBoolTensorIndex() | |||
| net2 = TensorAssignWithBoolTensorIndex2() | |||
| net1(Ta, Tb, Tc, u_tensor) | |||
| net1(Ta, Tb, Tc, u_tensor) | |||
| with pytest.raises(ValueError): | |||
| net1(Ta, Td, Tc, u_tensor) | |||
| with pytest.raises(IndexError): | |||
| @@ -695,14 +911,14 @@ def Xtest_tensor_assign_bool_index(): | |||
| with pytest.raises(ValueError): | |||
| net2(Ta, u_tensor_error) | |||
| net3 = TensorAssignWithBoolTensorIndexError() | |||
| with pytest.raises(AttributeError): | |||
| with pytest.raises(IndexError): | |||
| net3(Ta, Tb, Tc, u_tensor) | |||
| with pytest.raises(AttributeError): | |||
| with pytest.raises(IndexError): | |||
| net3(Ta, Tb, Tc, u_scalar) | |||
| net4 = TensorAssignWithBoolTensorIndex2Error() | |||
| with pytest.raises(AttributeError): | |||
| with pytest.raises(IndexError): | |||
| net4(Ta, u_tensor) | |||
| with pytest.raises(AttributeError): | |||
| with pytest.raises(IndexError): | |||
| net4(Ta, u_scalar) | |||