Browse Source

complete fancy index getitem

tags/v1.1.0
Payne 5 years ago
parent
commit
4c8f0914d0
4 changed files with 148 additions and 21 deletions
  1. +14
    -21
      mindspore/ops/composite/multitype_ops/_compile_utils.py
  2. +38
    -0
      mindspore/ops/composite/multitype_ops/_constexpr_utils.py
  3. +15
    -0
      mindspore/ops/composite/multitype_ops/getitem_impl.py
  4. +81
    -0
      tests/ut/python/ops/test_tensor_fancy_index.py

+ 14
- 21
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -66,8 +66,8 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
tuple_len = len(tuple_index) tuple_len = len(tuple_index)
for i in range(tuple_len): for i in range(tuple_len):
if i in int_positions: if i in int_positions:
tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] + \
data_shape[i], mstype.int32),)
tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] +
data_shape[i], mstype.int32),)
else: else:
tuple_index_new += (tuple_index[i],) tuple_index_new += (tuple_index[i],)
indexes_types = hyper_map(F.typeof, tuple_index_new) indexes_types = hyper_map(F.typeof, tuple_index_new)
@@ -95,24 +95,16 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_size): for i in range(tuple_index_size):
if i in tensor_positions: if i in tensor_positions:
transform_tensor = _transform_indexing_tensor(broadcast_shape,
final_shape,
index_tensor_new_shape,
tuple_index_new[i])
transform_tensor = _transform_indexing_tensor(
broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
final_index_tensors.append(transform_tensor) final_index_tensors.append(transform_tensor)
if i in slice_positions: if i in slice_positions:
slice_tensor = const_utils.convert_slice_to_tensor(slice_number,
final_shape,
indexes_shapes_info,
op_name)
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
final_index_tensors.append(slice_tensor) final_index_tensors.append(slice_tensor)
slice_number += 1 slice_number += 1
if i == ellipsis_position: if i == ellipsis_position:
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number,
ellipsis_occupied_dims,
final_shape,
indexes_shapes_info,
op_name)
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(
slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name)
for ele in ellipsis_tensors: for ele in ellipsis_tensors:
final_index_tensors.append(ele) final_index_tensors.append(ele)
slice_number += ellipsis_occupied_dims slice_number += ellipsis_occupied_dims
@@ -266,12 +258,13 @@ def _tensor_index_by_tuple_slice(data, tuple_index):
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)




def tensor_expand_dims(data, tuple_index):
    """Insert a new axis into `data` for every None entry in `tuple_index`.

    Returns the expanded tensor together with the index tuple stripped of
    its None entries (each None has been consumed by an expand_dims call).
    """
    axes_to_insert, stripped_index = const_utils.split_tuple_index_for_none(tuple_index)
    expanded = data
    for axis in axes_to_insert:
        expanded = F.expand_dims(expanded, axis)
    return expanded, stripped_index
def tensor_index_by_list(data, list_index):
    """Tensor getitem by a list of int and/or bool: gather along axis 0."""
    first_dim = F.shape(data)[0]
    # Validate element types, normalize bool masks / negative ints, then gather.
    const_utils.check_list_index_type(list_index)
    normalized_index = const_utils.transform_list(list_index, first_dim)
    gather_index = const_utils.convert_list_to_tensor(normalized_index)
    return F.gather(data, gather_index, 0)




def tensor_index_by_tuple(data, tuple_index): def tensor_index_by_tuple(data, tuple_index):


+ 38
- 0
mindspore/ops/composite/multitype_ops/_constexpr_utils.py View File

@@ -128,12 +128,14 @@ def is_same_type(inst, type_):
""" """
return inst == type_ return inst == type_



@constexpr @constexpr
def check_valid_dim(dim, name): def check_valid_dim(dim, name):
if dim not in (1, 2): if dim not in (1, 2):
raise ValueError( raise ValueError(
f"For {name}, inputs dim must be 1d or 2d") f"For {name}, inputs dim must be 1d or 2d")



@constexpr @constexpr
def check_valid_type(data_type, value_type, name): def check_valid_type(data_type, value_type, name):
if not data_type in value_type: if not data_type in value_type:
@@ -422,6 +424,42 @@ def compute_new_shape(origin_shape, indexes_shapes_info):
return tuple(new_shape) return tuple(new_shape)




@constexpr
def check_list_index_type(list_index):
    """Check that every item of `list_index` is an int or a bool.

    Args:
        list_index (list): candidate index list.

    Raises:
        IndexError: if any element is neither int nor bool; the message
            reports the offending element's type.
    """
    for index in list_index:
        # bool is a subclass of int, but list it explicitly for clarity.
        if not isinstance(index, (int, bool)):
            # The original referenced `index` outside the comprehension that
            # defined it, which raises NameError in Python 3 instead of the
            # intended IndexError; iterating explicitly keeps `index` in scope.
            raise IndexError(
                f"Tensor only support 'integer' or 'boolean' array(list/tuple), but got {type(index)} in array")


@constexpr
def transform_list(list_index, shape):
    """Transform list_index items from int or bool into valid int indices.

    Args:
        list_index (list): indices that are either all bool (a mask over the
            first dimension) or contain ints (bools are then coerced to 0/1).
        shape (int): length of the dimension being indexed.

    Returns:
        list: non-negative int indices in [0, shape).

    Raises:
        IndexError: if a pure boolean mask has the wrong length, or an int
            index falls outside [-shape, shape).
    """
    # bool is a subclass of int, so count bools first and subtract them.
    bool_count = sum(1 for index in list_index if isinstance(index, bool))
    int_count = sum(1 for index in list_index if isinstance(index, int)) - bool_count
    if int_count == 0:
        # Pure boolean mask: its length must match the dimension; keep the
        # positions whose mask value is True.
        if bool_count != shape:
            # typo fixed: "dimensiton" -> "dimension"
            raise IndexError("The boolean array should have the same length with the corresponding dimension")
        return [i for i in range(bool_count) if list_index[i]]
    # Mixed / integer indices: coerce bools to 0/1, then range-check and
    # normalize negative indices.
    list_index = [int(index) for index in list_index]
    for i, index in enumerate(list_index):
        if index < -shape or index >= shape:
            # grammar fixed: "should in" -> "should be in"
            raise IndexError(f"The index should be in the range [-{shape}, {shape-1}] to fit the corresponding dim "
                             f"length, but get {index}.")
        if index < 0:
            list_index[i] = index + shape
    return list_index


@constexpr
def convert_list_to_tensor(list_index):
    """Build an int64 Tensor index from a python list of ints."""
    tensor_index = Tensor(list_index, mstype.int64)
    return tensor_index


@constexpr @constexpr
def convert_int_to_slice(tuple_indexes): def convert_int_to_slice(tuple_indexes):
tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes) tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes)


+ 15
- 0
mindspore/ops/composite/multitype_ops/getitem_impl.py View File

@@ -234,3 +234,18 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
Tensor, same as data. Tensor, same as data.
""" """
return data return data


@getitem.register("Tensor", "List")
def _tensor_getitem_by_list(data, list_index):
    """
    Getting item of tensor by list.

    Inputs:
        data (Tensor): A tensor.
        list_index (List): A list of int and/or bool indices over axis 0.

    Outputs:
        Tensor, same as data.
    """
    return compile_utils.tensor_index_by_list(data, list_index)

+ 81
- 0
tests/ut/python/ops/test_tensor_fancy_index.py View File

@@ -0,0 +1,81 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_tensor_slice """
import numpy as np

from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
from mindspore.nn import Cell


class NetWorkFancyIndexBoolean(Cell):
    """Network that applies a fixed fancy index to its input tensor."""

    def __init__(self, index):
        super().__init__()
        self.index = index

    def construct(self, tensor):
        return tensor[self.index]


class NetWorkFancyIndexInterger(Cell):
    """Network that applies a fixed fancy index to its input tensor."""

    def __init__(self, index):
        super().__init__()
        self.index = index

    def construct(self, tensor):
        return tensor[self.index]


class NetWorkFancyIndexIntergerBooleanMixed(Cell):
    """Network that applies a fixed fancy index to its input tensor."""

    def __init__(self, index):
        super().__init__()
        self.index = index

    def construct(self, tensor):
        return tensor[self.index]


def test_tensor_fancy_index_integer_list():
    """Integer-list fancy index in graph mode should match numpy."""
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    index = [0, 2, 1]
    # Use the integer-index network; the original instantiated the
    # boolean-named class here (swapped with test_tensor_fancy_boolean_list).
    net = NetWorkFancyIndexInterger(index)
    input_np = np.arange(60).reshape(3, 4, 5)
    input_me = Tensor(input_np, dtype=mstype.float32)
    output_me = net(input_me).asnumpy()
    output_np = input_np[index]
    assert np.allclose(output_np, output_me, 0, 0)


def test_tensor_fancy_boolean_list():
    """Boolean-mask fancy index in graph mode should match numpy."""
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    index = [True, True, False]
    # Use the boolean-mask network; the original instantiated the
    # integer-named class here (swapped with test_tensor_fancy_index_integer_list).
    net = NetWorkFancyIndexBoolean(index)
    input_np = np.arange(60).reshape(3, 4, 5)
    input_me = Tensor(input_np, dtype=mstype.float32)
    output_me = net(input_me).asnumpy()
    output_np = input_np[index]
    assert np.allclose(output_np, output_me, 0, 0)


def test_tensor_fancy_integer_boolean_list():
    """Mixed int/bool fancy index in graph mode should match numpy."""
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    index = [1, 2, True, False]
    net = NetWorkFancyIndexIntergerBooleanMixed(index)
    data = np.arange(60).reshape(3, 4, 5)
    tensor_in = Tensor(data, dtype=mstype.float32)
    expected = data[index]
    actual = net(tensor_in).asnumpy()
    assert np.allclose(expected, actual, 0, 0)

Loading…
Cancel
Save