Merge pull request !419 from candanzg/tensor_assign_bool_indextags/v0.2.0-alpha
| @@ -83,6 +83,7 @@ convert_object_map = { | |||
| T.mul: multitype_ops.mul, | |||
| T.truediv: multitype_ops.div, | |||
| T.getitem: multitype_ops.getitem, | |||
| T.setitem: multitype_ops.setitem, | |||
| T.floordiv: multitype_ops.floordiv, | |||
| T.mod: multitype_ops.mod, | |||
| T.pow: multitype_ops.pow_, | |||
| @@ -118,7 +119,6 @@ convert_object_map = { | |||
| T.iter: M.ms_iter, | |||
| T.next: M.ms_next, | |||
| T.hasnext: M.hasnext, | |||
| T.setitem: M.setitem, | |||
| T.make_tuple: F.make_tuple, | |||
| T.make_dict: F.make_dict, | |||
| @@ -23,6 +23,7 @@ from .pow_impl import pow_ | |||
| from .floordiv_impl import floordiv | |||
| from .mod_impl import mod | |||
| from .getitem_impl import getitem | |||
| from .setitem_impl import setitem | |||
| from .zeros_like_impl import zeros_like | |||
| from .ones_like_impl import ones_like | |||
| from .equal_impl import equal | |||
| @@ -55,6 +56,7 @@ __all__ = [ | |||
| 'greater_equal', | |||
| 'negative', | |||
| 'getitem', | |||
| 'setitem', | |||
| 'logical_and', | |||
| 'logical_or', | |||
| 'logical_not' | |||
| @@ -0,0 +1,45 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """constexpr util""" | |||
| from ...primitive import constexpr | |||
| @constexpr | |||
| def is_same_type(inst, type_): | |||
| """ | |||
| Check whether an object is an instance of a target type. | |||
| Inputs: | |||
| inst (mindspore.dtype): Inspected type. | |||
| type_ (mindspore.dtype): Target type. | |||
| Outputs: | |||
| bool, the check result. | |||
| """ | |||
| return inst == type_ | |||
| @constexpr | |||
| def error_msg(msg="", format_values=""): | |||
| """ | |||
| Used to throw exception information. | |||
| Inputs: | |||
| msg (str): information content. | |||
| """ | |||
| raise ValueError(msg.format(*format_values)) | |||
| @@ -0,0 +1,194 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Implementation for setitem.""" | |||
| from ...composite import base | |||
| from ....common import dtype as mstype | |||
| from ... import functional as F | |||
| from . import _multitype_ops_util as mult_util | |||
| setitem = base.MultitypeFuncGraph('setitem') | |||
| @setitem.register("List", "Number", "String") | |||
| def _list_setitem_with_string(data, number_index, value): | |||
| """ | |||
| Assign value to list. | |||
| Inputs: | |||
| data (list): Data of type lis. | |||
| number_index (Number): Index of data. | |||
| value (String): Value given. | |||
| Outputs: | |||
| List, type is same as the element type of data. | |||
| """ | |||
| return F.list_setitem(data, number_index, value) | |||
| @setitem.register("List", "Number", "Number") | |||
| def _list_setitem_with_number(data, number_index, value): | |||
| """ | |||
| Assign value to list. | |||
| Inputs: | |||
| data (list): Data of type lis. | |||
| number_index (Number): Index of data. | |||
| value (Number): Value given. | |||
| Outputs: | |||
| List, type is same as the element type of data. | |||
| """ | |||
| return F.list_setitem(data, number_index, value) | |||
| @setitem.register("List", "Number", "Tensor") | |||
| def _list_setitem_with_Tensor(data, number_index, value): | |||
| """ | |||
| Assign value to list. | |||
| Inputs: | |||
| data (list): Data of type lis. | |||
| number_index (Number): Index of data. | |||
| value (Tensor): Value given. | |||
| Outputs: | |||
| List, type is same as the element type of data. | |||
| """ | |||
| return F.list_setitem(data, number_index, value) | |||
| @setitem.register("List", "Number", "List") | |||
| def _list_setitem_with_List(data, number_index, value): | |||
| """ | |||
| Assign value to list. | |||
| Inputs: | |||
| data (list): Data of type lis. | |||
| number_index (Number): Index of data. | |||
| value (List): Value given. | |||
| Outputs: | |||
| List, type is same as the element type of data. | |||
| """ | |||
| return F.list_setitem(data, number_index, value) | |||
| @setitem.register("Dictionary", "String", "Tensor") | |||
| def _dict_setitem_with_tensor(data, key, value): | |||
| """ | |||
| Assign value to dictionary. | |||
| Inputs: | |||
| data (Dictionary): Data of type dict. | |||
| key (str): Key of the data. | |||
| value (Tensor): Value given. | |||
| Outputs: | |||
| Dict, type is as same as the element type of data. | |||
| """ | |||
| return F.dict_setitem(data, key, value) | |||
| @setitem.register("Dictionary", "String", "Number") | |||
| def _dict_setitem_with_number(data, key, value): | |||
| """ | |||
| Assign value to dictionary. | |||
| Inputs: | |||
| data (Dictionary): Data of type dict. | |||
| key (str): Key of the data. | |||
| value (Number): Value given. | |||
| Outputs: | |||
| Dict, type is as same as the element type of data. | |||
| """ | |||
| return F.dict_setitem(data, key, value) | |||
| @setitem.register("Tensor", "Tensor", "Tensor") | |||
| def _tensor_setitem_by_tensor_v1(data, index, value_tensor): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[B] = U and A[A>n] = U. | |||
| Restraint condition: 1) A, U is a Tensor, and B is a bool Tensor. | |||
| 2) A.shape == B.shape | |||
| 3) U.size == 1 | |||
| 4) n is a number | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Tensor): Tensor of bool type. | |||
| value_tensor (Tensor): Tensor with size 1. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| index_dtype = F.dtype(index) | |||
| index_shape = F.shape(index) | |||
| is_bool = mult_util.is_same_type(index_dtype, mstype.bool_) | |||
| if not is_bool: | |||
| return mult_util.error_msg( | |||
| "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,)) | |||
| data_shape = F.shape(data) | |||
| if index_shape != data_shape: | |||
| return mult_util.error_msg( | |||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (data_shape, index_shape)) | |||
| size = F.size(value_tensor) | |||
| if size != 1: | |||
| return mult_util.error_msg( | |||
| "When assign value is a tensor, its size should be 1, but current size is {}.", (size,)) | |||
| dtype = F.dtype(data) | |||
| u_cast = F.cast(value_tensor, dtype) | |||
| one_data = F.ones_like(data) | |||
| u = F.tensor_mul(one_data, u_cast) | |||
| return F.select(index, u, data) | |||
| @setitem.register("Tensor", "Tensor", "Number") | |||
| def _tensor_setitem_by_tensor_v2(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[B] = u and A[A>n] = u. | |||
| Restraint condition: 1) A is a Tensor, and B is a bool Tensor. | |||
| 2) A.shape == B.shape | |||
| 3) u is a scalar | |||
| 4) n is a number | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Tensor): Tensor of bool type. | |||
| value_tensor (Number): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| index_dtype = F.dtype(index) | |||
| index_shape = F.shape(index) | |||
| is_bool = mult_util.is_same_type(index_dtype, mstype.bool_) | |||
| if not is_bool: | |||
| return mult_util.error_msg( | |||
| "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,)) | |||
| shape = F.shape(data) | |||
| if index_shape != shape: | |||
| return mult_util.error_msg( | |||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (shape, index_shape)) | |||
| dtype = F.dtype(data) | |||
| u = F.fill(dtype, shape, value) | |||
| return F.select(index, u, data) | |||
| @@ -31,6 +31,9 @@ dtype = P.DType() | |||
| issubclass_ = P.IsSubClass() | |||
| isinstance_ = P.IsInstance() | |||
| fill = P.Fill() | |||
| select = P.Select() | |||
| size = P.Size() | |||
| ones_like = P.OnesLike() | |||
| shape = P.Shape() | |||
| rank = P.Rank() | |||
| reshape = P.Reshape() | |||
| @@ -68,7 +71,9 @@ scalar_cast = P.ScalarCast() | |||
| tuple_setitem = Primitive('tuple_setitem') | |||
| tuple_getitem = Primitive('tuple_getitem') | |||
| list_getitem = Primitive('list_getitem') | |||
| list_setitem = Primitive('list_setitem') | |||
| dict_getitem = Primitive('dict_getitem') | |||
| dict_setitem = Primitive('dict_setitem') | |||
| tuple_div = Primitive("tuple_div") | |||
| tuple_len = Primitive("tuple_len") | |||
| tuple_reversed = Primitive("tuple_reversed") | |||
| @@ -18,6 +18,7 @@ import pytest | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore import dtype as mstype | |||
| from mindspore.nn import Cell | |||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||
| @@ -79,7 +80,102 @@ class NetWorkReduceToScalar(Cell): | |||
| return ret | |||
| class TensorAssignWithBoolTensorIndex(Cell): | |||
| def __init__(self): | |||
| super(TensorAssignWithBoolTensorIndex, self).__init__() | |||
| self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) | |||
| def construct(self, a, b, c, u_tensor, _scalar): | |||
| a[c] = u_scalar | |||
| a[b] = u_tensor | |||
| z = a + self.t | |||
| return z | |||
| class TensorAssignWithBoolTensorIndexError(Cell): | |||
| def __init__(self): | |||
| super(TensorAssignWithBoolTensorIndexError, self).__init__() | |||
| def construct(self, a, b, c, u_tensor): | |||
| a[b][c] = u_tensor | |||
| return a | |||
| class TensorAssignWithBoolTensorIndex2(Cell): | |||
| def __init__(self): | |||
| super(TensorAssignWithBoolTensorIndex2, self).__init__() | |||
| self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) | |||
| def construct(self, a, u_tensor, _scalar): | |||
| a[a>8] = u_tensor | |||
| a[a>=6] = u_scalar | |||
| a[a<3] = u_scalar | |||
| a[a<=5] = u_tensor | |||
| a[a==5] = u_scalar | |||
| z = a + self.t | |||
| return z | |||
| class TensorAssignWithBoolTensorIndex2Error(Cell): | |||
| def __init__(self): | |||
| super(TensorAssignWithBoolTensorIndex2Error, self).__init__() | |||
| def construct(self, a, u_tensor): | |||
| a[a>8][a>5] = u_tensor | |||
| return a | |||
| a = np.random.uniform(1,10,[2,3]) | |||
| b = a > 5 | |||
| c = a < 3 | |||
| Ta = Tensor(a) | |||
| Tb = Tensor(b) | |||
| Tc = Tensor(c) | |||
| Td = Tensor([True, True]) | |||
| u_tensor = Tensor([1]) | |||
| u_tensor_error = Tensor([1, 2]) | |||
| u_scalar = 5 | |||
| def test_tensor_assign_bool_index(): | |||
| net1 = TensorAssignWithBoolTensorIndex() | |||
| net2 = TensorAssignWithBoolTensorIndex2() | |||
| net1(Ta, Tb, Tc, u_tensor, u_scalar) | |||
| with pytest.raises(ValueError): | |||
| net1(Ta, Td, Tc, u_tensor, u_scalar) | |||
| with pytest.raises(ValueError): | |||
| net1(Ta, u_tensor, Tc, u_tensor, u_scalar) | |||
| with pytest.raises(ValueError): | |||
| net1(Ta, Tb, Td, u_tensor, u_scalar) | |||
| with pytest.raises(ValueError): | |||
| net1(Ta, Tb, Ta, u_tensor, u_scalar) | |||
| with pytest.raises(ValueError): | |||
| net1(Ta, Tb, Tc, u_tensor_error, u_scalar) | |||
| #net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar) | |||
| with pytest.raises(ValueError): | |||
| net2(Ta, u_tensor_error, u_scalar) | |||
| net3 = TensorAssignWithBoolTensorIndexError() | |||
| with pytest.raises(AttributeError): | |||
| net3(Ta, Tb, Tc, u_tensor) | |||
| with pytest.raises(AttributeError): | |||
| net3(Ta, Tb, Tc, u_scalar) | |||
| net4 = TensorAssignWithBoolTensorIndex2Error() | |||
| with pytest.raises(AttributeError): | |||
| net4(Ta, u_tensor) | |||
| with pytest.raises(AttributeError): | |||
| net4(Ta, u_scalar) | |||
| test_cases = [ | |||
| ('TensorAssignWithBoolTensorIndex', { | |||
| 'block': TensorAssignWithBoolTensorIndex(), | |||
| 'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar], | |||
| }), | |||
| ('TensorAssignWithBoolTensorIndex2', { | |||
| 'block': TensorAssignWithBoolTensorIndex2(), | |||
| 'desc_inputs': [Ta, u_tensor, u_scalar], | |||
| }), | |||
| ('SlicePositive', { | |||
| 'block': NetWorkSlicePositive(), | |||
| 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))], | |||