| @@ -66,8 +66,8 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): | |||||
| tuple_len = len(tuple_index) | tuple_len = len(tuple_index) | ||||
| for i in range(tuple_len): | for i in range(tuple_len): | ||||
| if i in int_positions: | if i in int_positions: | ||||
| tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] + \ | |||||
| data_shape[i], mstype.int32),) | |||||
| tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] + | |||||
| data_shape[i], mstype.int32),) | |||||
| else: | else: | ||||
| tuple_index_new += (tuple_index[i],) | tuple_index_new += (tuple_index[i],) | ||||
| indexes_types = hyper_map(F.typeof, tuple_index_new) | indexes_types = hyper_map(F.typeof, tuple_index_new) | ||||
| @@ -95,24 +95,16 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): | |||||
| index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) | index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) | ||||
| for i in range(tuple_index_size): | for i in range(tuple_index_size): | ||||
| if i in tensor_positions: | if i in tensor_positions: | ||||
| transform_tensor = _transform_indexing_tensor(broadcast_shape, | |||||
| final_shape, | |||||
| index_tensor_new_shape, | |||||
| tuple_index_new[i]) | |||||
| transform_tensor = _transform_indexing_tensor( | |||||
| broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i]) | |||||
| final_index_tensors.append(transform_tensor) | final_index_tensors.append(transform_tensor) | ||||
| if i in slice_positions: | if i in slice_positions: | ||||
| slice_tensor = const_utils.convert_slice_to_tensor(slice_number, | |||||
| final_shape, | |||||
| indexes_shapes_info, | |||||
| op_name) | |||||
| slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name) | |||||
| final_index_tensors.append(slice_tensor) | final_index_tensors.append(slice_tensor) | ||||
| slice_number += 1 | slice_number += 1 | ||||
| if i == ellipsis_position: | if i == ellipsis_position: | ||||
| ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number, | |||||
| ellipsis_occupied_dims, | |||||
| final_shape, | |||||
| indexes_shapes_info, | |||||
| op_name) | |||||
| ellipsis_tensors = const_utils.convert_ellipsis_to_tensors( | |||||
| slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name) | |||||
| for ele in ellipsis_tensors: | for ele in ellipsis_tensors: | ||||
| final_index_tensors.append(ele) | final_index_tensors.append(ele) | ||||
| slice_number += ellipsis_occupied_dims | slice_number += ellipsis_occupied_dims | ||||
| @@ -266,12 +258,13 @@ def _tensor_index_by_tuple_slice(data, tuple_index): | |||||
| return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) | return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) | ||||
| def tensor_expand_dims(data, tuple_index): | |||||
| """Expand tensor dims by tuple contains None and replace the None by slice in tuple_index """ | |||||
| none_positions, tuple_index_without_none = const_utils.split_tuple_index_for_none(tuple_index) | |||||
| for position in none_positions: | |||||
| data = F.expand_dims(data, position) | |||||
| return data, tuple_index_without_none | |||||
| def tensor_index_by_list(data, list_index): | |||||
| """Tensor getitem by list of int and bool""" | |||||
| data_shape = F.shape(data) | |||||
| const_utils.check_list_index_type(list_index) | |||||
| list_index = const_utils.transform_list(list_index, data_shape[0]) | |||||
| tensor_index = const_utils.convert_list_to_tensor(list_index) | |||||
| return F.gather(data, tensor_index, 0) | |||||
| def tensor_index_by_tuple(data, tuple_index): | def tensor_index_by_tuple(data, tuple_index): | ||||
| @@ -128,12 +128,14 @@ def is_same_type(inst, type_): | |||||
| """ | """ | ||||
| return inst == type_ | return inst == type_ | ||||
| @constexpr | @constexpr | ||||
| def check_valid_dim(dim, name): | def check_valid_dim(dim, name): | ||||
| if dim not in (1, 2): | if dim not in (1, 2): | ||||
| raise ValueError( | raise ValueError( | ||||
| f"For {name}, inputs dim must be 1d or 2d") | f"For {name}, inputs dim must be 1d or 2d") | ||||
| @constexpr | @constexpr | ||||
| def check_valid_type(data_type, value_type, name): | def check_valid_type(data_type, value_type, name): | ||||
| if not data_type in value_type: | if not data_type in value_type: | ||||
| @@ -422,6 +424,42 @@ def compute_new_shape(origin_shape, indexes_shapes_info): | |||||
| return tuple(new_shape) | return tuple(new_shape) | ||||
| @constexpr | |||||
| def check_list_index_type(list_index): | |||||
| """check if the item's type of list_index is bool or int""" | |||||
| if not all([isinstance(index, (int, bool)) for index in list_index]): | |||||
| raise IndexError( | |||||
| f"Tensor only support 'integer' or 'boolean' array(list/tuple), but got {type(index)} in array") | |||||
| @constexpr | |||||
| def transform_list(list_index, shape): | |||||
| """transfor list_index from int or bool to int""" | |||||
| bool_count = len(list(filter(lambda index: isinstance(index, bool), list_index))) | |||||
| int_count = len(list(filter(lambda index: isinstance(index, int), list_index)))-bool_count | |||||
| if int_count == 0: | |||||
| if bool_count == shape: | |||||
| list_index = list(filter(lambda i: list_index[i], range(bool_count))) | |||||
| else: | |||||
| raise IndexError("The boolean array should have the same length with the corresponding dimensiton") | |||||
| else: | |||||
| list_index = [int(index) for index in list_index] | |||||
| for i, index in enumerate(list_index): | |||||
| if index < -shape or index >= shape: | |||||
| raise IndexError(f"The index should in the range [-{shape}, {shape-1}] to fit the corresponding dim " | |||||
| f"length, but get {index}.") | |||||
| if index < 0: | |||||
| index += shape | |||||
| list_index[i] = index | |||||
| return list_index | |||||
| @constexpr | |||||
| def convert_list_to_tensor(list_index): | |||||
| """convert the list_index to tensor_index with mstype.int64 dtype""" | |||||
| return Tensor(list_index, mstype.int64) | |||||
| @constexpr | @constexpr | ||||
| def convert_int_to_slice(tuple_indexes): | def convert_int_to_slice(tuple_indexes): | ||||
| tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes) | tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes) | ||||
| @@ -234,3 +234,18 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index): | |||||
| Tensor, same as data. | Tensor, same as data. | ||||
| """ | """ | ||||
| return data | return data | ||||
| @getitem.register("Tensor", "List") | |||||
| def _tensor_getitem_by_list(data, list_index): | |||||
| """ | |||||
| Getting item of tensor by list. | |||||
| Inputs: | |||||
| data (Tensor): A tensor | |||||
| list_index (List): A list object. | |||||
| Outputs: | |||||
| Tensor ,same as data. | |||||
| """ | |||||
| return compile_utils.tensor_index_by_list(data, list_index) | |||||
| @@ -0,0 +1,81 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test_tensor_slice """ | |||||
| import numpy as np | |||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore import dtype as mstype | |||||
| from mindspore.nn import Cell | |||||
| class NetWorkFancyIndexBoolean(Cell): | |||||
| def __init__(self, index): | |||||
| super(NetWorkFancyIndexBoolean, self).__init__() | |||||
| self.index = index | |||||
| def construct(self, tensor): | |||||
| return tensor[self.index] | |||||
| class NetWorkFancyIndexInterger(Cell): | |||||
| def __init__(self, index): | |||||
| super(NetWorkFancyIndexInterger, self).__init__() | |||||
| self.index = index | |||||
| def construct(self, tensor): | |||||
| return tensor[self.index] | |||||
| class NetWorkFancyIndexIntergerBooleanMixed(Cell): | |||||
| def __init__(self, index): | |||||
| super(NetWorkFancyIndexIntergerBooleanMixed, self).__init__() | |||||
| self.index = index | |||||
| def construct(self, tensor): | |||||
| return tensor[self.index] | |||||
| def test_tensor_fancy_index_integer_list(): | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| index = [0, 2, 1] | |||||
| net = NetWorkFancyIndexBoolean(index) | |||||
| input_np = np.arange(60).reshape(3, 4, 5) | |||||
| input_me = Tensor(input_np, dtype=mstype.float32) | |||||
| output_me = net(input_me).asnumpy() | |||||
| output_np = input_np[index] | |||||
| assert np.allclose(output_np, output_me, 0, 0) | |||||
| def test_tensor_fancy_boolean_list(): | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| index = [True, True, False] | |||||
| net = NetWorkFancyIndexInterger(index) | |||||
| input_np = np.arange(60).reshape(3, 4, 5) | |||||
| input_me = Tensor(input_np, dtype=mstype.float32) | |||||
| output_me = net(input_me).asnumpy() | |||||
| output_np = input_np[index] | |||||
| assert np.allclose(output_np, output_me, 0, 0) | |||||
| def test_tensor_fancy_integer_boolean_list(): | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| index = [1, 2, True, False] | |||||
| net = NetWorkFancyIndexIntergerBooleanMixed(index) | |||||
| input_np = np.arange(60).reshape(3, 4, 5) | |||||
| input_me = Tensor(input_np, dtype=mstype.float32) | |||||
| output_me = net(input_me).asnumpy() | |||||
| output_np = input_np[index] | |||||
| assert np.allclose(output_np, output_me, 0, 0) | |||||