|
|
|
@@ -271,7 +271,7 @@ def tensor_index_by_tuple(data, tuple_index): |
|
|
|
if len(tuple_index) == 1: |
|
|
|
return data[tuple_index[0]] |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) |
|
|
|
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM) |
|
|
|
if index_elements_type == const_utils.NO_TENSOR: |
|
|
|
return _tensor_index_by_tuple_slice(data, tuple_index) |
|
|
|
if index_elements_type == const_utils.ALL_TENSOR: |
|
|
|
@@ -437,13 +437,16 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value): |
|
|
|
data[tuple_index[0]] = value |
|
|
|
return data |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
|
|
if index_elements_type == const_utils.ALL_TENSOR: |
|
|
|
indices = _generate_indices_from_tuple_of_tensor(data, |
|
|
|
tuple_index, |
|
|
|
const_utils.TENSOR_SETITEM) |
|
|
|
else: |
|
|
|
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
if int_cnt == const_utils.ALL_INT: |
|
|
|
tuple_index = const_utils.convert_int_to_slice(tuple_index) |
|
|
|
indices = _generate_indices_from_tuple_of_mixed_tensors(data, |
|
|
|
tuple_index, |
|
|
|
const_utils.TENSOR_SETITEM) |
|
|
|
@@ -498,16 +501,37 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): |
|
|
|
if len(tuple_index) == 1: |
|
|
|
data[tuple_index[0]] = value |
|
|
|
return data |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
data_shape = data.shape |
|
|
|
tuple_index_new = () |
|
|
|
for i, index in enumerate(tuple_index): |
|
|
|
if isinstance(index, mstype.Int): |
|
|
|
if index < -data_shape[i] or index >= data_shape[i]: |
|
|
|
const_utils.raise_index_error("The index is out of the data's special dimension range.") |
|
|
|
elif index < 0: |
|
|
|
tuple_index_new += (tuple_index[i]+data_shape[i],) |
|
|
|
else: |
|
|
|
tuple_index_new += (tuple_index[i],) |
|
|
|
else: |
|
|
|
tuple_index_new += (tuple_index[i],) |
|
|
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index_new) |
|
|
|
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
|
|
if index_elements_type == const_utils.ALL_TENSOR: |
|
|
|
indices = _generate_indices_from_tuple_of_tensor(data, |
|
|
|
tuple_index, |
|
|
|
tuple_index_new, |
|
|
|
const_utils.TENSOR_SETITEM) |
|
|
|
else: |
|
|
|
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
if int_cnt == const_utils.ALL_INT: |
|
|
|
tuple_index_new = const_utils.convert_int_to_slice(tuple_index_new) |
|
|
|
new_shape = () |
|
|
|
for _ in tuple_index_new: |
|
|
|
new_shape += (1,) |
|
|
|
new_shape += value.shape |
|
|
|
value = F.reshape(value, new_shape) |
|
|
|
indices = _generate_indices_from_tuple_of_mixed_tensors(data, |
|
|
|
tuple_index, |
|
|
|
tuple_index_new, |
|
|
|
const_utils.TENSOR_SETITEM) |
|
|
|
updates = _generate_updates_from_tensor(data, |
|
|
|
indices, |
|
|
|
@@ -522,13 +546,16 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): |
|
|
|
data[tuple_index[0]] = value |
|
|
|
return data |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
|
|
if index_elements_type == const_utils.ALL_TENSOR: |
|
|
|
indices = _generate_indices_from_tuple_of_tensor(data, |
|
|
|
tuple_index, |
|
|
|
const_utils.TENSOR_SETITEM) |
|
|
|
else: |
|
|
|
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
if int_cnt == const_utils.ALL_INT: |
|
|
|
tuple_index = const_utils.convert_int_to_slice(tuple_index) |
|
|
|
indices = _generate_indices_from_tuple_of_mixed_tensors(data, |
|
|
|
tuple_index, |
|
|
|
const_utils.TENSOR_SETITEM) |
|
|
|
|