|
|
|
@@ -156,7 +156,6 @@ def generate_updates_from_tensor(data, index, value, op_type): |
|
|
|
return value |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_getitem(self, index): |
|
|
|
"""Handle tensor getitem""" |
|
|
|
if isinstance(index, Tensor): |
|
|
|
@@ -164,16 +163,15 @@ def tensor_getitem(self, index): |
|
|
|
if isinstance(index, tuple): |
|
|
|
return tensor_index_by_tuple(self, index) |
|
|
|
if isinstance(index, int): |
|
|
|
return tensor_index_by_number(self, index) |
|
|
|
return tensor_index_by_integer(self, index) |
|
|
|
if isinstance(index, slice): |
|
|
|
return tensor_index_by_slice(self, index) |
|
|
|
if isinstance(index, bool): |
|
|
|
return tensor_index_by_bool(self, index) |
|
|
|
if index is ...: |
|
|
|
return self |
|
|
|
raise IndexError("Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32,\ |
|
|
|
got {} with type{}".format(index, type(index))) |
|
|
|
|
|
|
|
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, " |
|
|
|
f"got {index} with type {type(index)}.") |
|
|
|
|
|
|
|
|
|
|
|
tensor_operator_registry.register("__getitem__", tensor_getitem) |
|
|
|
@@ -199,13 +197,19 @@ def tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): |
|
|
|
|
|
|
|
def tensor_index_by_slice(data, slice_index): |
|
|
|
"""Tensor getitem by a single slice""" |
|
|
|
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(F.shape(data), slice_index) |
|
|
|
shape = F.shape(data) |
|
|
|
if not shape: |
|
|
|
const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor cannot be 0.") |
|
|
|
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index) |
|
|
|
return F.strided_slice(data, begin_strides, end_strides, step_strides) |
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_integer(data, number): |
|
|
|
"""Tensor getitem by a single integer number""" |
|
|
|
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(F.shape(data), number) |
|
|
|
shape = F.shape(data) |
|
|
|
if not shape: |
|
|
|
const_utils.raise_index_error("When tensor is indexed by an integer, the dimension of the tensor cannot be 0.") |
|
|
|
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(shape, number) |
|
|
|
shrink_axis_mask = 1 |
|
|
|
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) |
|
|
|
|
|
|
|
@@ -214,7 +218,7 @@ def tensor_index_by_bool(data, bool_value): |
|
|
|
"""Tensor getitem by a single bool value""" |
|
|
|
if bool_value: |
|
|
|
return F.expand_dims(data, 0) |
|
|
|
return const_utils.raise_index_error("bool value as indexing ,false is not supported") |
|
|
|
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.") |
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_number(data, number): |
|
|
|
@@ -224,7 +228,7 @@ def tensor_index_by_number(data, number): |
|
|
|
return tensor_index_by_bool(data, number) |
|
|
|
if number_type == const_utils.INT_: |
|
|
|
return tensor_index_by_integer(data, number) |
|
|
|
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool") |
|
|
|
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.") |
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_tensor(data, tensor_index): |
|
|
|
@@ -233,13 +237,18 @@ def tensor_index_by_tensor(data, tensor_index): |
|
|
|
const_utils.TENSOR_GETITEM) |
|
|
|
if dtype_valid: |
|
|
|
return F.gather(data, tensor_index, 0) |
|
|
|
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool") |
|
|
|
return const_utils.raise_index_error("For 'tensor getitem', " |
|
|
|
"the index tensor data type only support mstype.int32.") |
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_tuple_slice(data, t): |
|
|
|
"""Tensor getitem by a tuple of slice""" |
|
|
|
shape = F.shape(data) |
|
|
|
if len(t) > len(shape): |
|
|
|
const_utils.raise_index_error("When tensor is indexed by a tuple, " |
|
|
|
"the length of the tuple cannot be greater than the dimension of the tensor.") |
|
|
|
begin_strides, end_strides, step_strides, shrink_axis_mask = \ |
|
|
|
const_utils.get_stride_info_from_tuple(F.shape(data), t) |
|
|
|
const_utils.get_stride_info_from_tuple(shape, t) |
|
|
|
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) |
|
|
|
|
|
|
|
|
|
|
|
|