1) A[B]=U
2) A[A>n]=U
A.shape == B.shape
U is a scalar or Tensor(size==1)
B is Tensor(dtype=bool)
n is a Number
Signed-off-by: candanzg <zhangshucheng@huawei.com>
tags/v0.2.0-alpha
| @@ -83,6 +83,7 @@ convert_object_map = { | |||||
| T.mul: multitype_ops.mul, | T.mul: multitype_ops.mul, | ||||
| T.truediv: multitype_ops.div, | T.truediv: multitype_ops.div, | ||||
| T.getitem: multitype_ops.getitem, | T.getitem: multitype_ops.getitem, | ||||
| T.setitem: multitype_ops.setitem, | |||||
| T.floordiv: multitype_ops.floordiv, | T.floordiv: multitype_ops.floordiv, | ||||
| T.mod: multitype_ops.mod, | T.mod: multitype_ops.mod, | ||||
| T.pow: multitype_ops.pow_, | T.pow: multitype_ops.pow_, | ||||
| @@ -118,7 +119,6 @@ convert_object_map = { | |||||
| T.iter: M.ms_iter, | T.iter: M.ms_iter, | ||||
| T.next: M.ms_next, | T.next: M.ms_next, | ||||
| T.hasnext: M.hasnext, | T.hasnext: M.hasnext, | ||||
| T.setitem: M.setitem, | |||||
| T.make_tuple: F.make_tuple, | T.make_tuple: F.make_tuple, | ||||
| T.make_dict: F.make_dict, | T.make_dict: F.make_dict, | ||||
| @@ -23,6 +23,7 @@ from .pow_impl import pow_ | |||||
| from .floordiv_impl import floordiv | from .floordiv_impl import floordiv | ||||
| from .mod_impl import mod | from .mod_impl import mod | ||||
| from .getitem_impl import getitem | from .getitem_impl import getitem | ||||
| from .setitem_impl import setitem | |||||
| from .zeros_like_impl import zeros_like | from .zeros_like_impl import zeros_like | ||||
| from .ones_like_impl import ones_like | from .ones_like_impl import ones_like | ||||
| from .equal_impl import equal | from .equal_impl import equal | ||||
| @@ -55,6 +56,7 @@ __all__ = [ | |||||
| 'greater_equal', | 'greater_equal', | ||||
| 'negative', | 'negative', | ||||
| 'getitem', | 'getitem', | ||||
| 'setitem', | |||||
| 'logical_and', | 'logical_and', | ||||
| 'logical_or', | 'logical_or', | ||||
| 'logical_not' | 'logical_not' | ||||
| @@ -0,0 +1,45 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """constexpr util""" | |||||
| from ...primitive import constexpr | |||||
| @constexpr | |||||
| def is_same_type(inst, type_): | |||||
| """ | |||||
| Check whether an object is an instance of a target type. | |||||
| Inputs: | |||||
| inst (mindspore.dtype): Inspected type. | |||||
| type_ (mindspore.dtype): Target type. | |||||
| Outputs: | |||||
| bool, the check result. | |||||
| """ | |||||
| return inst == type_ | |||||
| @constexpr | |||||
| def error_msg(msg="", format_values=""): | |||||
| """ | |||||
| Used to throw exception information. | |||||
| Inputs: | |||||
| msg (str): information content. | |||||
| """ | |||||
| raise ValueError(msg.format(*format_values)) | |||||
| @@ -0,0 +1,194 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Implementation for setitem.""" | |||||
| from ...composite import base | |||||
| from ....common import dtype as mstype | |||||
| from ... import functional as F | |||||
| from . import _multitype_ops_util as mult_util | |||||
| setitem = base.MultitypeFuncGraph('setitem') | |||||
| @setitem.register("List", "Number", "String") | |||||
| def _list_setitem_with_string(data, number_index, value): | |||||
| """ | |||||
| Assign value to list. | |||||
| Inputs: | |||||
| data (list): Data of type lis. | |||||
| number_index (Number): Index of data. | |||||
| value (String): Value given. | |||||
| Outputs: | |||||
| List, type is same as the element type of data. | |||||
| """ | |||||
| return F.list_setitem(data, number_index, value) | |||||
| @setitem.register("List", "Number", "Number") | |||||
| def _list_setitem_with_number(data, number_index, value): | |||||
| """ | |||||
| Assign value to list. | |||||
| Inputs: | |||||
| data (list): Data of type lis. | |||||
| number_index (Number): Index of data. | |||||
| value (Number): Value given. | |||||
| Outputs: | |||||
| List, type is same as the element type of data. | |||||
| """ | |||||
| return F.list_setitem(data, number_index, value) | |||||
| @setitem.register("List", "Number", "Tensor") | |||||
| def _list_setitem_with_Tensor(data, number_index, value): | |||||
| """ | |||||
| Assign value to list. | |||||
| Inputs: | |||||
| data (list): Data of type lis. | |||||
| number_index (Number): Index of data. | |||||
| value (Tensor): Value given. | |||||
| Outputs: | |||||
| List, type is same as the element type of data. | |||||
| """ | |||||
| return F.list_setitem(data, number_index, value) | |||||
| @setitem.register("List", "Number", "List") | |||||
| def _list_setitem_with_List(data, number_index, value): | |||||
| """ | |||||
| Assign value to list. | |||||
| Inputs: | |||||
| data (list): Data of type lis. | |||||
| number_index (Number): Index of data. | |||||
| value (List): Value given. | |||||
| Outputs: | |||||
| List, type is same as the element type of data. | |||||
| """ | |||||
| return F.list_setitem(data, number_index, value) | |||||
| @setitem.register("Dictionary", "String", "Tensor") | |||||
| def _dict_setitem_with_tensor(data, key, value): | |||||
| """ | |||||
| Assign value to dictionary. | |||||
| Inputs: | |||||
| data (Dictionary): Data of type dict. | |||||
| key (str): Key of the data. | |||||
| value (Tensor): Value given. | |||||
| Outputs: | |||||
| Dict, type is as same as the element type of data. | |||||
| """ | |||||
| return F.dict_setitem(data, key, value) | |||||
| @setitem.register("Dictionary", "String", "Number") | |||||
| def _dict_setitem_with_number(data, key, value): | |||||
| """ | |||||
| Assign value to dictionary. | |||||
| Inputs: | |||||
| data (Dictionary): Data of type dict. | |||||
| key (str): Key of the data. | |||||
| value (Number): Value given. | |||||
| Outputs: | |||||
| Dict, type is as same as the element type of data. | |||||
| """ | |||||
| return F.dict_setitem(data, key, value) | |||||
| @setitem.register("Tensor", "Tensor", "Tensor") | |||||
| def _tensor_setitem_by_tensor_v1(data, index, value_tensor): | |||||
| """ | |||||
| Tensor assignment. | |||||
| Note: | |||||
| Syntax support: A[B] = U and A[A>n] = U. | |||||
| Restraint condition: 1) A, U is a Tensor, and B is a bool Tensor. | |||||
| 2) A.shape == B.shape | |||||
| 3) U.size == 1 | |||||
| 4) n is a number | |||||
| Inputs: | |||||
| data (Tensor): Assigned tensor. | |||||
| index (Tensor): Tensor of bool type. | |||||
| value_tensor (Tensor): Tensor with size 1. | |||||
| Outputs: | |||||
| Tensor, element type and shape is same as data. | |||||
| """ | |||||
| index_dtype = F.dtype(index) | |||||
| index_shape = F.shape(index) | |||||
| is_bool = mult_util.is_same_type(index_dtype, mstype.bool_) | |||||
| if not is_bool: | |||||
| return mult_util.error_msg( | |||||
| "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,)) | |||||
| data_shape = F.shape(data) | |||||
| if index_shape != data_shape: | |||||
| return mult_util.error_msg( | |||||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (data_shape, index_shape)) | |||||
| size = F.size(value_tensor) | |||||
| if size != 1: | |||||
| return mult_util.error_msg( | |||||
| "When assign value is a tensor, its size should be 1, but current size is {}.", (size,)) | |||||
| dtype = F.dtype(data) | |||||
| u_cast = F.cast(value_tensor, dtype) | |||||
| one_data = F.ones_like(data) | |||||
| u = F.tensor_mul(one_data, u_cast) | |||||
| return F.select(index, u, data) | |||||
| @setitem.register("Tensor", "Tensor", "Number") | |||||
| def _tensor_setitem_by_tensor_v2(data, index, value): | |||||
| """ | |||||
| Tensor assignment. | |||||
| Note: | |||||
| Syntax support: A[B] = u and A[A>n] = u. | |||||
| Restraint condition: 1) A is a Tensor, and B is a bool Tensor. | |||||
| 2) A.shape == B.shape | |||||
| 3) u is a scalar | |||||
| 4) n is a number | |||||
| Inputs: | |||||
| data (Tensor): Assigned tensor. | |||||
| index (Tensor): Tensor of bool type. | |||||
| value_tensor (Number): Assignment value. | |||||
| Outputs: | |||||
| Tensor, element type and shape is same as data. | |||||
| """ | |||||
| index_dtype = F.dtype(index) | |||||
| index_shape = F.shape(index) | |||||
| is_bool = mult_util.is_same_type(index_dtype, mstype.bool_) | |||||
| if not is_bool: | |||||
| return mult_util.error_msg( | |||||
| "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,)) | |||||
| shape = F.shape(data) | |||||
| if index_shape != shape: | |||||
| return mult_util.error_msg( | |||||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (shape, index_shape)) | |||||
| dtype = F.dtype(data) | |||||
| u = F.fill(dtype, shape, value) | |||||
| return F.select(index, u, data) | |||||
| @@ -31,6 +31,9 @@ dtype = P.DType() | |||||
| issubclass_ = P.IsSubClass() | issubclass_ = P.IsSubClass() | ||||
| isinstance_ = P.IsInstance() | isinstance_ = P.IsInstance() | ||||
| fill = P.Fill() | fill = P.Fill() | ||||
| select = P.Select() | |||||
| size = P.Size() | |||||
| ones_like = P.OnesLike() | |||||
| shape = P.Shape() | shape = P.Shape() | ||||
| rank = P.Rank() | rank = P.Rank() | ||||
| reshape = P.Reshape() | reshape = P.Reshape() | ||||
| @@ -68,7 +71,9 @@ scalar_cast = P.ScalarCast() | |||||
| tuple_setitem = Primitive('tuple_setitem') | tuple_setitem = Primitive('tuple_setitem') | ||||
| tuple_getitem = Primitive('tuple_getitem') | tuple_getitem = Primitive('tuple_getitem') | ||||
| list_getitem = Primitive('list_getitem') | list_getitem = Primitive('list_getitem') | ||||
| list_setitem = Primitive('list_setitem') | |||||
| dict_getitem = Primitive('dict_getitem') | dict_getitem = Primitive('dict_getitem') | ||||
| dict_setitem = Primitive('dict_setitem') | |||||
| tuple_div = Primitive("tuple_div") | tuple_div = Primitive("tuple_div") | ||||
| tuple_len = Primitive("tuple_len") | tuple_len = Primitive("tuple_len") | ||||
| tuple_reversed = Primitive("tuple_reversed") | tuple_reversed = Primitive("tuple_reversed") | ||||
| @@ -18,6 +18,7 @@ import pytest | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore import dtype as mstype | |||||
| from mindspore.nn import Cell | from mindspore.nn import Cell | ||||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | from ....mindspore_test_framework.mindspore_test import mindspore_test | ||||
| @@ -79,7 +80,102 @@ class NetWorkReduceToScalar(Cell): | |||||
| return ret | return ret | ||||
| class TensorAssignWithBoolTensorIndex(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithBoolTensorIndex, self).__init__() | |||||
| self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) | |||||
| def construct(self, a, b, c, u_tensor, _scalar): | |||||
| a[c] = u_scalar | |||||
| a[b] = u_tensor | |||||
| z = a + self.t | |||||
| return z | |||||
| class TensorAssignWithBoolTensorIndexError(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithBoolTensorIndexError, self).__init__() | |||||
| def construct(self, a, b, c, u_tensor): | |||||
| a[b][c] = u_tensor | |||||
| return a | |||||
| class TensorAssignWithBoolTensorIndex2(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithBoolTensorIndex2, self).__init__() | |||||
| self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) | |||||
| def construct(self, a, u_tensor, _scalar): | |||||
| a[a>8] = u_tensor | |||||
| a[a>=6] = u_scalar | |||||
| a[a<3] = u_scalar | |||||
| a[a<=5] = u_tensor | |||||
| a[a==5] = u_scalar | |||||
| z = a + self.t | |||||
| return z | |||||
| class TensorAssignWithBoolTensorIndex2Error(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithBoolTensorIndex2Error, self).__init__() | |||||
| def construct(self, a, u_tensor): | |||||
| a[a>8][a>5] = u_tensor | |||||
| return a | |||||
| a = np.random.uniform(1,10,[2,3]) | |||||
| b = a > 5 | |||||
| c = a < 3 | |||||
| Ta = Tensor(a) | |||||
| Tb = Tensor(b) | |||||
| Tc = Tensor(c) | |||||
| Td = Tensor([True, True]) | |||||
| u_tensor = Tensor([1]) | |||||
| u_tensor_error = Tensor([1, 2]) | |||||
| u_scalar = 5 | |||||
| def test_tensor_assign_bool_index(): | |||||
| net1 = TensorAssignWithBoolTensorIndex() | |||||
| net2 = TensorAssignWithBoolTensorIndex2() | |||||
| net1(Ta, Tb, Tc, u_tensor, u_scalar) | |||||
| with pytest.raises(ValueError): | |||||
| net1(Ta, Td, Tc, u_tensor, u_scalar) | |||||
| with pytest.raises(ValueError): | |||||
| net1(Ta, u_tensor, Tc, u_tensor, u_scalar) | |||||
| with pytest.raises(ValueError): | |||||
| net1(Ta, Tb, Td, u_tensor, u_scalar) | |||||
| with pytest.raises(ValueError): | |||||
| net1(Ta, Tb, Ta, u_tensor, u_scalar) | |||||
| with pytest.raises(ValueError): | |||||
| net1(Ta, Tb, Tc, u_tensor_error, u_scalar) | |||||
| #net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar) | |||||
| with pytest.raises(ValueError): | |||||
| net2(Ta, u_tensor_error, u_scalar) | |||||
| net3 = TensorAssignWithBoolTensorIndexError() | |||||
| with pytest.raises(AttributeError): | |||||
| net3(Ta, Tb, Tc, u_tensor) | |||||
| with pytest.raises(AttributeError): | |||||
| net3(Ta, Tb, Tc, u_scalar) | |||||
| net4 = TensorAssignWithBoolTensorIndex2Error() | |||||
| with pytest.raises(AttributeError): | |||||
| net4(Ta, u_tensor) | |||||
| with pytest.raises(AttributeError): | |||||
| net4(Ta, u_scalar) | |||||
| test_cases = [ | test_cases = [ | ||||
| ('TensorAssignWithBoolTensorIndex', { | |||||
| 'block': TensorAssignWithBoolTensorIndex(), | |||||
| 'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar], | |||||
| }), | |||||
| ('TensorAssignWithBoolTensorIndex2', { | |||||
| 'block': TensorAssignWithBoolTensorIndex2(), | |||||
| 'desc_inputs': [Ta, u_tensor, u_scalar], | |||||
| }), | |||||
| ('SlicePositive', { | ('SlicePositive', { | ||||
| 'block': NetWorkSlicePositive(), | 'block': NetWorkSlicePositive(), | ||||
| 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))], | 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))], | ||||