Browse Source

!15208 tensor setitem fix value with expanded dimensions before ellipsis & fix broadcast with value ndim < slice ndim

From: @jachua
Reviewed-by: @guoqi1024,@liangchenghui
Signed-off-by: @liangchenghui
pull/15208/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
121f8a7289
3 changed files with 41 additions and 15 deletions
  1. +29
    -15
      mindspore/ops/composite/multitype_ops/_compile_utils.py
  2. +11
    -0
      mindspore/ops/composite/multitype_ops/_constexpr_utils.py
  3. +1
    -0
      tests/st/pynative/test_tensor_setitem.py

+ 29
- 15
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -72,6 +72,7 @@ def _broadcast(broadcast_shape, x):
return x
multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape)
if multiples:
x = F.reshape(x, const_utils.expanded_shape(F.shape(x), len(multiples) - F.rank(x)))
return F.tile(x, multiples)
return x

@@ -794,29 +795,42 @@ def ignore_dim_expand(idx):
def remove_ignored_dim(idx, value_shape, data_rank):
"""Removes dimensions in value that correspond to dimension expansion flags in index."""
has_ellipsis = False
has_true = False
has_leading_true = False
has_trailing_true = False
cnt_leading_expanded = 0
cnt_trailing_expanded = 0
cnt_not_dim_expand = 0
for i in idx:
if not i is True and not i is None:
cnt_not_dim_expand += 1
if const_utils.is_ellipsis(i):
has_ellipsis = True
elif has_ellipsis:
if i is None:
if i is True:
if has_ellipsis:
has_trailing_true = True
else:
has_leading_true = True
elif i is None:
if has_ellipsis:
cnt_trailing_expanded += 1
elif i is True and not has_true:
has_true = True
if has_true and cnt_not_dim_expand + 1 < data_rank:
cnt_trailing_expanded += 1
else:
cnt_leading_expanded += 1
else:
if const_utils.is_ellipsis(i):
has_ellipsis = True
cnt_not_dim_expand += 1
if cnt_not_dim_expand + 1 < data_rank:
if has_leading_true:
cnt_leading_expanded += 1
elif has_trailing_true:
cnt_trailing_expanded += 1

value_starting_pos = 0
while cnt_leading_expanded > 0 and value_shape[value_starting_pos] == 1:
value_starting_pos += 1
cnt_leading_expanded -= 1

if cnt_trailing_expanded == 0:
return value_shape
value_expanded_pos = len(value_shape) - cnt_trailing_expanded
value_expanded_not_unit = False
for i in value_shape[value_expanded_pos:]:
for i in const_utils.tuple_slice(value_shape, value_expanded_pos, None):
if i != 1:
value_expanded_not_unit = True
if value_expanded_pos < 0 or value_expanded_not_unit:
const_utils.raise_value_error('shape mismatch')
return value_shape[:value_expanded_pos]
return const_utils.tuple_slice(value_shape, value_starting_pos, value_expanded_pos)

+ 11
- 0
mindspore/ops/composite/multitype_ops/_constexpr_utils.py View File

@@ -830,3 +830,14 @@ def normalize_stop(stop, dim_size):
@constexpr
def is_ellipsis(x):
return x is Ellipsis


@constexpr
def tuple_slice(tup, start, end):
"""get sliced tuple from start and end."""
return tup[start:end]


@constexpr
def expanded_shape(shape, expand_size):
return (1,)*expand_size + shape

+ 1
- 0
tests/st/pynative/test_tensor_setitem.py View File

@@ -129,6 +129,7 @@ def test_setitem_by_tuple_with_list():
x[0, True, 0, None, True] = [-2, -2, -2, -2]
x[0, ..., None] = [[-3], [-3], [-3], [-3]]
x[..., 0, None, 1, True, True, None] = [[[-4]], [[-4]]]
x[None, True, [1, 0], (False, True, True), [2]] = [[2, 3]]
return x
setup_testcase(x, cases)



Loading…
Cancel
Save