|
|
|
@@ -161,6 +161,13 @@ def tensor_index_by_slice(data, slice_index): |
|
|
|
|
|
|
|
def tensor_index_by_number(data, number): |
|
|
|
"""Tensor getitem by a Number which may be integer/float/bool value""" |
|
|
|
data_type = F.typeof(data) |
|
|
|
if const_utils.judge_index_type(data_type, mstype.tensor_type): |
|
|
|
data_shape = F.shape(data) |
|
|
|
data_rank = len(data_shape) |
|
|
|
min_data_rank, max_data_rank = 0, 8 |
|
|
|
const_utils.judge_data_rank(data_rank, min_data_rank, max_data_rank) |
|
|
|
|
|
|
|
number_type = const_utils.check_number_index_type(number) |
|
|
|
if number_type == const_utils.BOOL_: |
|
|
|
return _tensor_index_by_bool(data, number) |
|
|
|
|