|
|
|
@@ -3259,8 +3259,8 @@ class SpaceToBatchND(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Args: |
|
|
|
block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value greater than 1. |
|
|
|
The length of `block_shape` is M correspoding to the number of spatial dimensions. |
|
|
|
paddings (list): The padding values for H and W dimension, containing M subtraction list. |
|
|
|
The length of `block_shape` is M correspoding to the number of spatial dimensions. M must be 2. |
|
|
|
paddings (list): The padding values for H and W dimension, containing 2 subtraction list. |
|
|
|
Each contains 2 integer value. All values must be greater than 0. |
|
|
|
`paddings[i]` specifies the paddings for the spatial dimension i, |
|
|
|
which corresponds to the input dimension i+2. |
|
|
|
@@ -3294,21 +3294,28 @@ class SpaceToBatchND(PrimitiveWithInfer): |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, block_shape, paddings): |
|
|
|
"""Init SpaceToBatchND""" |
|
|
|
self.ori_block_shape = block_shape |
|
|
|
self.ori_paddings = paddings |
|
|
|
validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name) |
|
|
|
validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name) |
|
|
|
block_rank = len(block_shape) |
|
|
|
|
|
|
|
validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name) |
|
|
|
for elem in block_shape: |
|
|
|
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name) |
|
|
|
validator.check_value_type('block_shape element', elem, [int], self.name) |
|
|
|
|
|
|
|
self.block_shape = block_shape |
|
|
|
|
|
|
|
validator.check_value_type('paddings type', paddings, [list, tuple], self.name) |
|
|
|
validator.check('paddings length', len(paddings), '', 2, Rel.EQ, self.name) |
|
|
|
validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name) |
|
|
|
for elem in itertools.chain(*paddings): |
|
|
|
validator.check_integer('paddings element', elem, 0, Rel.GE, self.name) |
|
|
|
validator.check_value_type('paddings element', elem, [int], self.name) |
|
|
|
self.paddings = paddings |
|
|
|
block_shape_append = [1] + list(self.block_shape) |
|
|
|
self.add_prim_attr("block_shape", block_shape_append) |
|
|
|
paddings_append = [[0, 0]] + list(self.paddings) |
|
|
|
self.add_prim_attr("paddings", paddings_append) |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) |
|
|
|
@@ -3321,7 +3328,7 @@ class SpaceToBatchND(PrimitiveWithInfer): |
|
|
|
|
|
|
|
block_shape_prod = 1 |
|
|
|
offset = 2 |
|
|
|
if x_rank < 4: |
|
|
|
if x_rank <= 4: |
|
|
|
offset = 1 |
|
|
|
for i in range(len(self.block_shape)): |
|
|
|
padded = out_shape[i + offset] + self.paddings[i][0] + \ |
|
|
|
@@ -3345,7 +3352,7 @@ class BatchToSpaceND(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Args: |
|
|
|
block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value >= 1. |
|
|
|
The length of block_shape is M correspoding to the number of spatial dimensions. |
|
|
|
The length of block_shape is M correspoding to the number of spatial dimensions. M must be 2. |
|
|
|
crops (Union[list(int), tuple(int)]): The crop value for H and W dimension, containing 2 subtraction list, |
|
|
|
each containing 2 int value. |
|
|
|
All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to |
|
|
|
@@ -3380,22 +3387,28 @@ class BatchToSpaceND(PrimitiveWithInfer): |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, block_shape, crops): |
|
|
|
"""Init BatchToSpaceND""" |
|
|
|
self.ori_block_shape = block_shape |
|
|
|
self.ori_crops = crops |
|
|
|
validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name) |
|
|
|
validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name) |
|
|
|
block_rank = len(block_shape) |
|
|
|
|
|
|
|
validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name) |
|
|
|
for elem in block_shape: |
|
|
|
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name) |
|
|
|
validator.check_value_type('block_shape element', elem, [int], self.name) |
|
|
|
|
|
|
|
self.block_shape = block_shape |
|
|
|
|
|
|
|
validator.check_value_type('crops type', crops, [list, tuple], self.name) |
|
|
|
validator.check('crops length', len(crops), '', 2, Rel.EQ, self.name) |
|
|
|
validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name) |
|
|
|
for elem in itertools.chain(*crops): |
|
|
|
validator.check_integer('crops element', elem, 0, Rel.GE, self.name) |
|
|
|
validator.check_value_type('crops element', elem, [int], self.name) |
|
|
|
self.crops = crops |
|
|
|
block_shape_append = [1] + list(self.block_shape) |
|
|
|
self.add_prim_attr("block_shape", block_shape_append) |
|
|
|
crops_append = [[0, 0]] + list(self.crops) |
|
|
|
self.add_prim_attr("crops", crops_append) |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) |
|
|
|
@@ -3408,7 +3421,7 @@ class BatchToSpaceND(PrimitiveWithInfer): |
|
|
|
|
|
|
|
block_shape_prod = 1 |
|
|
|
offset = 2 |
|
|
|
if x_rank < 4: |
|
|
|
if x_rank <= 4: |
|
|
|
offset = 1 |
|
|
|
for i in range(len(self.block_shape)): |
|
|
|
block_shape_prod = block_shape_prod * self.block_shape[i] |
|
|
|
@@ -3591,12 +3604,12 @@ class EditDistance(PrimitiveWithInfer): |
|
|
|
The shape of tensor is :math:`(N, R)`. |
|
|
|
- **hypothesis_values** (Tensor) - The values of the hypothesis list SparseTensor. |
|
|
|
Must be 1-D vector with length of N. |
|
|
|
- **hypothesis_shape** (Tensor) - The values of the hypothesis list SparseTensor. |
|
|
|
- **hypothesis_shape** (Tensor) - The shape of the hypothesis list SparseTensor. |
|
|
|
Must be R-length vector with int64 data type. Only constant value is allowed. |
|
|
|
- **truth_indices** (Tensor) - The indices of the truth list SparseTensor. With int64 data type. |
|
|
|
The shape of tensor is :math:`(M, R)`. |
|
|
|
- **truth_values** (Tensor) - The values of the truth list SparseTensor. Must be 1-D vector with length of M. |
|
|
|
- **truth_shape** (Tensor) - The values of the truth list SparseTensor. |
|
|
|
- **truth_shape** (Tensor) - The shape of the truth list SparseTensor. |
|
|
|
Must be R-length vector with int64 data type. Only constant value is allowed. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
|