|
|
@@ -23,137 +23,6 @@ from ..primitive import PrimitiveWithInfer, prim_attr_register |
|
|
from ..operations.math_ops import _infer_shape_reduce |
|
|
from ..operations.math_ops import _infer_shape_reduce |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StridedSliceAICPU(PrimitiveWithInfer): |
|
|
|
|
|
r""" |
|
|
|
|
|
|
|
|
|
|
|
Extracts a strided slice of a tensor. |
|
|
|
|
|
|
|
|
|
|
|
Given an input tensor, this operation inserts a dimension of length 1 at the dimension. |
|
|
|
|
|
This operation extracts a fragment of size (end-begin)/stride from the given |
|
|
|
|
|
'input_tensor'. Starting from the position specified by the begin, the fragment |
|
|
|
|
|
continues adding stride to the index until all dimensions are not less than end. |
|
|
|
|
|
|
|
|
|
|
|
Note: |
|
|
|
|
|
The stride may be negative value, which causes reverse slicing. |
|
|
|
|
|
The shape of `begin`, `end` and `strides` must be the same. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
begin_mask (int): Starting index of the slice. Default: 0. |
|
|
|
|
|
end_mask (int): Ending index of the slice. Default: 0. |
|
|
|
|
|
ellipsis_mask (int): An int mask. Default: 0. |
|
|
|
|
|
new_axis_mask (int): An int mask. Default: 0. |
|
|
|
|
|
shrink_axis_mask (int): An int mask. Default: 0. |
|
|
|
|
|
Currently all the masks are not in used. Use default 0 only. |
|
|
|
|
|
|
|
|
|
|
|
Inputs: |
|
|
|
|
|
- **input_x** (Tensor) - The input Tensor. |
|
|
|
|
|
- **begin** (tuple[int]) - A tuple which represents the location where to start. Only |
|
|
|
|
|
constant value is allowed. |
|
|
|
|
|
- **end** (tuple[int]) - A tuple or which represents the maximum location where to stop. |
|
|
|
|
|
Only constant value is allowed. |
|
|
|
|
|
- **strides** (tuple[int]) - A tuple which represents the stride continuously added |
|
|
|
|
|
before reach the maximum location. Only constant value is allowed. |
|
|
|
|
|
|
|
|
|
|
|
Outputs: |
|
|
|
|
|
Tensor. |
|
|
|
|
|
Explain with the following example. |
|
|
|
|
|
- In the 0th dim, begin is 1, end is 2, and strides is 1, |
|
|
|
|
|
because :math:`1+1=2\geq2`, the interval is :math:`[1,2)`. |
|
|
|
|
|
Thus, return the element with :math:`index = 1` in 0th dim, i.e., [[3, 3, 3], [4, 4, 4]]. |
|
|
|
|
|
- In the 1st dim, similarly, the interval is :math:`[0,1)`. |
|
|
|
|
|
Based on the return value of the 0th dim, return the element with :math:`index = 0`, |
|
|
|
|
|
i.e., [3, 3, 3]. |
|
|
|
|
|
- In the 2nd dim, similarly, the interval is :math:`[0,3)`. |
|
|
|
|
|
Based on the return value of the 1st dim, return the element with :math:`index = 0,1,2`, |
|
|
|
|
|
i.e., [3, 3, 3]. |
|
|
|
|
|
- Finally, the output is [3, 3, 3]. |
|
|
|
|
|
|
|
|
|
|
|
Examples |
|
|
|
|
|
>>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], |
|
|
|
|
|
>>> [[5, 5, 5], [6, 6, 6]]], mindspore.float32) |
|
|
|
|
|
>>> slice = P.StridedSliceAICPU() |
|
|
|
|
|
>>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 2)) |
|
|
|
|
|
>>> output.shape |
|
|
|
|
|
(1, 1, 2) |
|
|
|
|
|
>>> output |
|
|
|
|
|
[[[3, 3]]] |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
|
|
def __init__(self, |
|
|
|
|
|
begin_mask=0, |
|
|
|
|
|
end_mask=0, |
|
|
|
|
|
ellipsis_mask=0, |
|
|
|
|
|
new_axis_mask=0, |
|
|
|
|
|
shrink_axis_mask=0): |
|
|
|
|
|
"""Initialize StrideSlice""" |
|
|
|
|
|
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output']) |
|
|
|
|
|
validator.check_value_type('begin_mask', begin_mask, [int], self.name) |
|
|
|
|
|
validator.check_value_type('end_mask', end_mask, [int], self.name) |
|
|
|
|
|
validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name) |
|
|
|
|
|
validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name) |
|
|
|
|
|
validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name) |
|
|
|
|
|
|
|
|
|
|
|
def __infer__(self, x, begin, end, strides): |
|
|
|
|
|
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value'] |
|
|
|
|
|
validator.check_value_type("begin", begin_v, [tuple], self.name) |
|
|
|
|
|
validator.check_value_type("end", end_v, [tuple], self.name) |
|
|
|
|
|
validator.check_value_type("strides", strides_v, [tuple], self.name) |
|
|
|
|
|
|
|
|
|
|
|
x_shape = x['shape'] |
|
|
|
|
|
x_shp_len = len(x_shape) |
|
|
|
|
|
if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len: |
|
|
|
|
|
raise ValueError(f"For \'{self.name}\' the length of begin index{begin_v}, end index{end_v} and " |
|
|
|
|
|
f"strides{strides_v} must be equal to the dims({x_shp_len}) of input.") |
|
|
|
|
|
|
|
|
|
|
|
ret_shape = [] |
|
|
|
|
|
append_dimensions = [] |
|
|
|
|
|
shrink_pos = bin(self.shrink_axis_mask)[::-1] |
|
|
|
|
|
new_pos = bin(self.new_axis_mask)[::-1] |
|
|
|
|
|
for i in range(x_shp_len): |
|
|
|
|
|
# After the integer is converted to binary, it is a str and the first two chars are the flag char '0b' |
|
|
|
|
|
if i < (len(new_pos) - 2) and new_pos[i] == '1': |
|
|
|
|
|
ret_shape.append(1) |
|
|
|
|
|
append_dimensions.append(x_shape[x_shp_len - 1 - len(append_dimensions)]) |
|
|
|
|
|
continue |
|
|
|
|
|
if i < (len(shrink_pos) - 2) and shrink_pos[i] == '1': |
|
|
|
|
|
validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE, self.name) |
|
|
|
|
|
validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT, self.name) |
|
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
begin_idx = begin_v[i] |
|
|
|
|
|
end_idx = end_v[i] |
|
|
|
|
|
strides_idx = strides_v[i] |
|
|
|
|
|
if self.begin_mask: |
|
|
|
|
|
begin_idx = 0 |
|
|
|
|
|
if self.end_mask: |
|
|
|
|
|
end_idx = x_shape[i] |
|
|
|
|
|
validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE, self.name) |
|
|
|
|
|
validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE, self.name) |
|
|
|
|
|
validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE, self.name) |
|
|
|
|
|
if strides_idx > 0: |
|
|
|
|
|
# If sliced forward , end_idx >= begin_idx |
|
|
|
|
|
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.LE) |
|
|
|
|
|
if begin_idx < 0 < end_idx: |
|
|
|
|
|
# Turn negative begin_idx into positive values |
|
|
|
|
|
begin_idx = x_shape[i] + begin_idx |
|
|
|
|
|
num_elems = (end_idx - begin_idx + strides_idx - 1) // strides_idx |
|
|
|
|
|
else: |
|
|
|
|
|
# If sliced backwards, end_idx <= begin_idx |
|
|
|
|
|
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.GE) |
|
|
|
|
|
if end_idx < 0 < begin_idx: |
|
|
|
|
|
# Turn negative end_idx into positive values |
|
|
|
|
|
end_idx = x_shape[i] + end_idx |
|
|
|
|
|
num_elems = (end_idx - begin_idx + strides_idx + 1) // strides_idx |
|
|
|
|
|
|
|
|
|
|
|
ret_shape.append(num_elems) |
|
|
|
|
|
if append_dimensions: |
|
|
|
|
|
ret_shape += append_dimensions[::-1] |
|
|
|
|
|
return {'shape': ret_shape, |
|
|
|
|
|
'dtype': x['dtype'], |
|
|
|
|
|
'value': None} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExtractImagePatches(PrimitiveWithInfer): |
|
|
class ExtractImagePatches(PrimitiveWithInfer): |
|
|
""" |
|
|
""" |
|
|
Extracts patches from images. |
|
|
Extracts patches from images. |
|
|
|