|
|
|
@@ -37,6 +37,7 @@ from ..._c_expression import signature_kind as sig_kind |
|
|
|
from ..._c_expression import signature_dtype as sig_dtype |
|
|
|
from ..._c_expression import typing |
|
|
|
|
|
|
|
|
|
|
|
def _check_infer_attr_reduce(axis, keep_dims, prim_name): |
|
|
|
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) |
|
|
|
validator.check_value_type('axis', axis, [int, tuple], prim_name) |
|
|
|
@@ -193,7 +194,7 @@ class Cast(PrimitiveWithInfer): |
|
|
|
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) |
|
|
|
|
|
|
|
def check_elim(self, x, dtype): |
|
|
|
if isinstance(x, (Tensor, numbers.Number, Parameter)): |
|
|
|
if isinstance(x, (Tensor, numbers.Number, Parameter)): |
|
|
|
if isinstance(x, Tensor) and x.dtype == dtype: |
|
|
|
return (True, x) |
|
|
|
if isinstance(x, numbers.Number): |
|
|
|
@@ -987,10 +988,10 @@ class InvertPermutation(PrimitiveWithInfer): |
|
|
|
z.sort() |
|
|
|
|
|
|
|
for i in range(1, len(z)): |
|
|
|
if z[i-1] == z[i]: |
|
|
|
if z[i - 1] == z[i]: |
|
|
|
raise ValueError(f"For {self.name}, {z[i]} is duplicated in the input.") |
|
|
|
validator.check(f'value min', min(x_value), '', 0, Rel.EQ, self.name) |
|
|
|
validator.check(f'value max', max(x_value), '', len(x_value)-1, Rel.EQ, self.name) |
|
|
|
validator.check(f'value max', max(x_value), '', len(x_value) - 1, Rel.EQ, self.name) |
|
|
|
|
|
|
|
y = [None] * len(x_value) |
|
|
|
for i, value in enumerate(x_value): |
|
|
|
@@ -1693,6 +1694,57 @@ class Select(PrimitiveWithInfer): |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
def _compute_slicing_length(begin, end, stride, x_shape, i): |
|
|
|
"""Compute the length of the slicing.""" |
|
|
|
if i >= len(x_shape): |
|
|
|
raise ValueError(f"For 'StridedSlice', When their is no new axis, the index length must be less or " |
|
|
|
f"equal than the dim of x.") |
|
|
|
x_dim = x_shape[i] |
|
|
|
if stride > 0: |
|
|
|
# When slicing forward, convert begin and end to positive numbers. |
|
|
|
if begin >= x_dim or end < -x_dim: |
|
|
|
# When slicing forward, if begin >= x_dim or end < -x_dim, the length of the slicing is 0. |
|
|
|
slicing_length = 0 |
|
|
|
else: |
|
|
|
if -x_dim <= begin < 0: |
|
|
|
begin += x_dim |
|
|
|
if begin < -x_dim: |
|
|
|
# When slicing forward, if begin < -x_dim, set begin = 0, which means start from the 0th element. |
|
|
|
begin = 0 |
|
|
|
if -x_dim <= end < 0: |
|
|
|
end += x_dim |
|
|
|
if end > x_dim: |
|
|
|
# When slicing forward, if end > x_dim, set end = x_dims, which means slice to the last element. |
|
|
|
end = x_dim |
|
|
|
if begin >= end: |
|
|
|
# When slicing forward, if begin >= end, the length of the slicing is 0. |
|
|
|
slicing_length = 0 |
|
|
|
else: |
|
|
|
slicing_length = 1 + (end - 1 - begin) // stride |
|
|
|
else: |
|
|
|
# When slicing backward, convert begin and end to negative numbers. |
|
|
|
if begin < -x_dim or end >= x_dim: |
|
|
|
# When slicing backward, if begin < -x_dim or end >= x_dim, the length of the slicing is 0. |
|
|
|
slicing_length = 0 |
|
|
|
else: |
|
|
|
if 0 <= begin < x_dim: |
|
|
|
begin += -x_dim |
|
|
|
if begin >= x_dim: |
|
|
|
# When slicing backward, if begin >= x_dim, set begin = -1, which means start from the last element. |
|
|
|
begin = -1 |
|
|
|
if 0 < end < x_dim: |
|
|
|
end += -x_dim |
|
|
|
if end < -x_dim - 1: |
|
|
|
# When slicing backward, if end < -x_dim - 1, set end = -x_dim - 1, which means |
|
|
|
# slicing to the 0th element. |
|
|
|
end = -x_dim - 1 |
|
|
|
if begin <= end: |
|
|
|
# When slicing backward, if begin <= end, the length of the slicing is 0. |
|
|
|
slicing_length = 0 |
|
|
|
else: |
|
|
|
slicing_length = 1 + (end + 1 - begin) // stride |
|
|
|
return slicing_length |
|
|
|
|
|
|
|
|
|
|
|
class StridedSlice(PrimitiveWithInfer): |
|
|
|
r""" |
|
|
|
@@ -1756,13 +1808,15 @@ class StridedSlice(PrimitiveWithInfer): |
|
|
|
ellipsis_mask=0, |
|
|
|
new_axis_mask=0, |
|
|
|
shrink_axis_mask=0): |
|
|
|
"""init StrideSlice""" |
|
|
|
"""Init StrideSlice""" |
|
|
|
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output']) |
|
|
|
validator.check_value_type('begin_mask', begin_mask, [int], self.name) |
|
|
|
validator.check_value_type('end_mask', end_mask, [int], self.name) |
|
|
|
validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name) |
|
|
|
validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name) |
|
|
|
validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name) |
|
|
|
validator.check_integer('begin_mask', begin_mask, 0, Rel.GE, self.name) |
|
|
|
validator.check_integer('end_mask', end_mask, 0, Rel.GE, self.name) |
|
|
|
validator.check_integer('ellipsis_mask', ellipsis_mask, 0, Rel.GE, self.name) |
|
|
|
if len(tuple(filter(lambda x: x == '1', bin(ellipsis_mask)[-1:1:-1]))) > 1: |
|
|
|
raise ValueError(f"For '{self.name}', only support one ellipsis in the index, but got {end_mask}.") |
|
|
|
validator.check_integer('new_axis_mask', new_axis_mask, 0, Rel.GE, self.name) |
|
|
|
validator.check_integer('shrink_axis_mask', shrink_axis_mask, 0, Rel.GE, self.name) |
|
|
|
|
|
|
|
def __infer__(self, x, begin, end, strides): |
|
|
|
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value'] |
|
|
|
@@ -1770,58 +1824,103 @@ class StridedSlice(PrimitiveWithInfer): |
|
|
|
validator.check_value_type("end", end_v, [tuple], self.name) |
|
|
|
validator.check_value_type("strides", strides_v, [tuple], self.name) |
|
|
|
|
|
|
|
x_shape = x['shape'] |
|
|
|
x_shp_len = len(x_shape) |
|
|
|
if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len: |
|
|
|
raise ValueError(f"For \'{self.name}\' the length of begin index{begin_v}, end index{end_v} and " |
|
|
|
f"strides{strides_v} must be equal to the dims({x_shp_len}) of input.") |
|
|
|
if tuple(filter(lambda x: not isinstance(x, int), begin_v + end_v + strides_v)): |
|
|
|
raise ValueError(f"For {self.name}, both the begins, ends, and strides must be a tuple of int, " |
|
|
|
f"but got begins: {begin_v}, ends: {end_v}, strides: {strides_v}.") |
|
|
|
|
|
|
|
ret_shape = [] |
|
|
|
append_dimensions = [] |
|
|
|
shrink_pos = bin(self.shrink_axis_mask)[::-1] |
|
|
|
new_pos = bin(self.new_axis_mask)[::-1] |
|
|
|
for i in range(x_shp_len): |
|
|
|
# After the integer is converted to binary, it is a str and the first two chars are the flag char '0b' |
|
|
|
if i < (len(new_pos) - 2) and new_pos[i] == '1': |
|
|
|
ret_shape.append(1) |
|
|
|
append_dimensions.append(x_shape[x_shp_len - 1 - len(append_dimensions)]) |
|
|
|
continue |
|
|
|
if i < (len(shrink_pos) - 2) and shrink_pos[i] == '1': |
|
|
|
validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE, self.name) |
|
|
|
validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT, self.name) |
|
|
|
continue |
|
|
|
|
|
|
|
begin_idx = begin_v[i] |
|
|
|
end_idx = end_v[i] |
|
|
|
strides_idx = strides_v[i] |
|
|
|
if self.begin_mask: |
|
|
|
begin_idx = 0 |
|
|
|
if self.end_mask: |
|
|
|
end_idx = x_shape[i] |
|
|
|
validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE, self.name) |
|
|
|
validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE, self.name) |
|
|
|
validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE, self.name) |
|
|
|
if strides_idx > 0: |
|
|
|
# If sliced forward , end_idx >= begin_idx |
|
|
|
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.LE) |
|
|
|
if begin_idx < 0 < end_idx: |
|
|
|
# Turn negative begin_idx into positive values |
|
|
|
begin_idx = x_shape[i] + begin_idx |
|
|
|
num_elems = (end_idx - begin_idx + strides_idx - 1) // strides_idx |
|
|
|
else: |
|
|
|
# If sliced backwards, end_idx <= begin_idx |
|
|
|
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.GE) |
|
|
|
if end_idx < 0 < begin_idx: |
|
|
|
# Turn negative end_idx into positive values |
|
|
|
end_idx = x_shape[i] + end_idx |
|
|
|
num_elems = (end_idx - begin_idx + strides_idx + 1) // strides_idx |
|
|
|
|
|
|
|
ret_shape.append(num_elems) |
|
|
|
if append_dimensions: |
|
|
|
ret_shape += append_dimensions[::-1] |
|
|
|
if tuple(filter(lambda x: x == 0, strides_v)): |
|
|
|
raise ValueError(f"For '{self.name}', the strides cannot contain 0, but got strides: {strides_v}.") |
|
|
|
|
|
|
|
if len(end_v) != len(begin_v) or len(strides_v) != len(begin_v): |
|
|
|
raise ValueError(f"For '{self.name}' the length of begin index: {begin_v}, end index: {end_v} and " |
|
|
|
f"strides: {strides_v} must be equal.") |
|
|
|
|
|
|
|
ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v) |
|
|
|
|
|
|
|
value = None if all(ret_shape) else Tensor(np.array([]).reshape(ret_shape), x['dtype'].element_type()) |
|
|
|
return {'shape': ret_shape, |
|
|
|
'dtype': x['dtype'], |
|
|
|
'value': None} |
|
|
|
'value': value} |
|
|
|
|
|
|
|
def _compute_slicing_shape(self, x_shape, begin_v, end_v, strides_v): |
|
|
|
"""Compute the shape of the slicing.""" |
|
|
|
x_rank = len(x_shape) |
|
|
|
slice_len = len(begin_v) |
|
|
|
|
|
|
|
# After the integer is converted to binary, it is a str and the first two chars are the flag char '0b'. |
|
|
|
begin_pos = bin(self.begin_mask)[-1:1:-1] |
|
|
|
end_pos = bin(self.end_mask)[-1:1:-1] |
|
|
|
ellipsis_pos = bin(self.ellipsis_mask)[-1:1:-1] |
|
|
|
new_axis_pos = bin(self.new_axis_mask)[-1:1:-1] |
|
|
|
shrink_axis_pos = bin(self.shrink_axis_mask)[-1:1:-1] |
|
|
|
|
|
|
|
ret_shape = [] |
|
|
|
i, j = 0, 0 |
|
|
|
has_ellipsis = False |
|
|
|
while i < x_rank or j < slice_len: |
|
|
|
if j < slice_len: |
|
|
|
begin, end, stride = begin_v[j], end_v[j], strides_v[j] |
|
|
|
|
|
|
|
if j < len(ellipsis_pos) and ellipsis_pos[j] == '1': |
|
|
|
# When there is ellipsis, the latter part of the ellipsis will be processed separately. |
|
|
|
has_ellipsis = True |
|
|
|
break |
|
|
|
if j < len(begin_pos) and begin_pos[j] == '1': |
|
|
|
begin = -1 if strides_v[j] < 0 else 0 |
|
|
|
if j < len(end_pos) and end_pos[j] == '1': |
|
|
|
end = -(x_shape[i] + 1) if strides_v[j] < 0 else x_shape[i] |
|
|
|
if j < len(new_axis_pos) and new_axis_pos[j] == '1': |
|
|
|
ret_shape.append(1) |
|
|
|
j += 1 |
|
|
|
continue |
|
|
|
if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1': |
|
|
|
if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0: |
|
|
|
raise ValueError(f"For {self.name}, when shrink axis, the stride cannot be negative number, " |
|
|
|
f"and begin should be in [-{x_shape[i]}, {x_shape[i]}), " |
|
|
|
f"but got stride: {stride}, begin: {begin}.") |
|
|
|
j += 1 |
|
|
|
i += 1 |
|
|
|
continue |
|
|
|
else: |
|
|
|
begin, end, stride = 0, x_shape[i], 1 |
|
|
|
|
|
|
|
slicing_length = _compute_slicing_length(begin, end, stride, x_shape, i) |
|
|
|
ret_shape.append(slicing_length) |
|
|
|
i += 1 |
|
|
|
j += 1 |
|
|
|
if has_ellipsis: |
|
|
|
# When there is ellipsis, handle the second half of the ellipsis split. |
|
|
|
ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \ |
|
|
|
len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len]))) |
|
|
|
ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims]) |
|
|
|
j += 1 |
|
|
|
i += ellipsis_occupied_dims |
|
|
|
|
|
|
|
while i < x_rank or j < slice_len: |
|
|
|
begin, end, stride = begin_v[j], end_v[j], strides_v[j] |
|
|
|
|
|
|
|
if j < len(begin_pos) and begin_pos[j] == '1': |
|
|
|
begin = -1 if strides_v[j] < 0 else 0 |
|
|
|
if j < len(end_pos) and end_pos[j] == '1': |
|
|
|
end = -(x_shape[i] + 1) if strides_v[j] < 0 else x_shape[i] |
|
|
|
if j < len(new_axis_pos) and new_axis_pos[j] == '1': |
|
|
|
ret_shape.append(1) |
|
|
|
j += 1 |
|
|
|
continue |
|
|
|
if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1': |
|
|
|
if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0: |
|
|
|
raise ValueError(f"For {self.name}, when shrink axis, the stride cannot be negative number, " |
|
|
|
f"and begin should be in [-{x_shape[i]}, {x_shape[i]}), " |
|
|
|
f"but got stride: {stride}, begin: {begin}.") |
|
|
|
j += 1 |
|
|
|
i += 1 |
|
|
|
continue |
|
|
|
|
|
|
|
slicing_length = _compute_slicing_length(begin, end, stride, x_shape, i) |
|
|
|
ret_shape.append(slicing_length) |
|
|
|
i += 1 |
|
|
|
j += 1 |
|
|
|
return ret_shape |
|
|
|
|
|
|
|
|
|
|
|
class Diag(PrimitiveWithInfer): |
|
|
|
@@ -2102,6 +2201,7 @@ class TensorScatterUpdate(PrimitiveWithInfer): |
|
|
|
>>> op = P.TensorScatterUpdate() |
|
|
|
>>> output = op(input_x, indices, update) |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self): |
|
|
|
"""Init TensorScatterUpdate""" |
|
|
|
@@ -2153,6 +2253,7 @@ class ScatterUpdate(PrimitiveWithInfer): |
|
|
|
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), |
|
|
|
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) |
|
|
|
) |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, use_locking=True): |
|
|
|
"""Init ScatterUpdate""" |
|
|
|
@@ -2201,6 +2302,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): |
|
|
|
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), |
|
|
|
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) |
|
|
|
) |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, use_locking=True): |
|
|
|
"""Init ScatterNdUpdate""" |
|
|
|
@@ -2220,6 +2322,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): |
|
|
|
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name): |
|
|
|
if updates_shape and updates_shape != indices_shape + x_shape[1:]: |
|
|
|
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or " |
|
|
|
@@ -2912,6 +3015,7 @@ class InplaceUpdate(PrimitiveWithInfer): |
|
|
|
[ 4. 5.] |
|
|
|
[ 6. 7.]]] |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, indices): |
|
|
|
"""Init InplaceUpdate""" |
|
|
|
|