|
|
|
@@ -72,6 +72,7 @@ def _broadcast(broadcast_shape, x): |
|
|
|
return x |
|
|
|
multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape) |
|
|
|
if multiples: |
|
|
|
x = F.reshape(x, const_utils.expanded_shape(F.shape(x), len(multiples) - F.rank(x))) |
|
|
|
return F.tile(x, multiples) |
|
|
|
return x |
|
|
|
|
|
|
|
@@ -794,29 +795,42 @@ def ignore_dim_expand(idx): |
|
|
|
def remove_ignored_dim(idx, value_shape, data_rank): |
|
|
|
"""Removes dimensions in value that correspond to dimension expansion flags in index.""" |
|
|
|
has_ellipsis = False |
|
|
|
has_true = False |
|
|
|
has_leading_true = False |
|
|
|
has_trailing_true = False |
|
|
|
cnt_leading_expanded = 0 |
|
|
|
cnt_trailing_expanded = 0 |
|
|
|
cnt_not_dim_expand = 0 |
|
|
|
for i in idx: |
|
|
|
if not i is True and not i is None: |
|
|
|
cnt_not_dim_expand += 1 |
|
|
|
if const_utils.is_ellipsis(i): |
|
|
|
has_ellipsis = True |
|
|
|
elif has_ellipsis: |
|
|
|
if i is None: |
|
|
|
if i is True: |
|
|
|
if has_ellipsis: |
|
|
|
has_trailing_true = True |
|
|
|
else: |
|
|
|
has_leading_true = True |
|
|
|
elif i is None: |
|
|
|
if has_ellipsis: |
|
|
|
cnt_trailing_expanded += 1 |
|
|
|
elif i is True and not has_true: |
|
|
|
has_true = True |
|
|
|
if has_true and cnt_not_dim_expand + 1 < data_rank: |
|
|
|
cnt_trailing_expanded += 1 |
|
|
|
else: |
|
|
|
cnt_leading_expanded += 1 |
|
|
|
else: |
|
|
|
if const_utils.is_ellipsis(i): |
|
|
|
has_ellipsis = True |
|
|
|
cnt_not_dim_expand += 1 |
|
|
|
if cnt_not_dim_expand + 1 < data_rank: |
|
|
|
if has_leading_true: |
|
|
|
cnt_leading_expanded += 1 |
|
|
|
elif has_trailing_true: |
|
|
|
cnt_trailing_expanded += 1 |
|
|
|
|
|
|
|
value_starting_pos = 0 |
|
|
|
while cnt_leading_expanded > 0 and value_shape[value_starting_pos] == 1: |
|
|
|
value_starting_pos += 1 |
|
|
|
cnt_leading_expanded -= 1 |
|
|
|
|
|
|
|
if cnt_trailing_expanded == 0: |
|
|
|
return value_shape |
|
|
|
value_expanded_pos = len(value_shape) - cnt_trailing_expanded |
|
|
|
value_expanded_not_unit = False |
|
|
|
for i in value_shape[value_expanded_pos:]: |
|
|
|
for i in const_utils.tuple_slice(value_shape, value_expanded_pos, None): |
|
|
|
if i != 1: |
|
|
|
value_expanded_not_unit = True |
|
|
|
if value_expanded_pos < 0 or value_expanded_not_unit: |
|
|
|
const_utils.raise_value_error('shape mismatch') |
|
|
|
return value_shape[:value_expanded_pos] |
|
|
|
return const_utils.tuple_slice(value_shape, value_starting_pos, value_expanded_pos) |