Browse Source

!2828 fix tensor index when there is only one element in tuple

Merge pull request !2828 from zhangbuxue/fix_tensor_index_when_there_is_only_one_element_in_tuple
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
dd666ec3c5
1 changed files with 11 additions and 0 deletions
  1. +11
    -0
      mindspore/ops/composite/multitype_ops/_compile_utils.py

+ 11
- 0
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -267,6 +267,8 @@ def _tensor_index_by_tuple_slice(data, t):

def tensor_index_by_tuple(data, tuple_index):
"""Tensor getitem by tuple of various types"""
if len(tuple_index) == 1:
return data[tuple_index[0]]
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
if index_elements_type == const_utils.NO_TENSOR:
@@ -430,6 +432,9 @@ def tensor_setitem_by_slice_with_number(data, input_slice, value):

def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
"""Assigns the tensor by tuple with number value."""
if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)

@@ -489,6 +494,9 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):

def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
"""Assigns the tensor by tuple with tensor value."""
if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)

@@ -509,6 +517,9 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):

def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
"""Assigns the tensor by tuple with tuple of value."""
if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)



Loading…
Cancel
Save