|
|
|
@@ -5081,12 +5081,12 @@ class SpaceToDepth(PrimitiveWithInfer): |
|
|
|
"""Initialize SpaceToDepth""" |
|
|
|
self.init_prim_io_names(inputs=['x'], outputs=['y']) |
|
|
|
validator.check_value_type('block_size', block_size, [int], self.name) |
|
|
|
validator.check('block_size', block_size, '', 2, Rel.GE) |
|
|
|
validator.check('block_size', block_size, self.name, 2, Rel.GE) |
|
|
|
self.block_size = block_size |
|
|
|
self.add_prim_attr("data_format", "NCHW") |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
validator.check('x dimension', len(x_shape), '', 4, Rel.EQ) |
|
|
|
validator.check('x dimension', len(x_shape), self.name, 4, Rel.EQ) |
|
|
|
out_shape = copy.deepcopy(x_shape) |
|
|
|
for i in range(2): |
|
|
|
if out_shape[i + 2] % self.block_size != 0: |
|
|
|
@@ -5218,9 +5218,9 @@ class SpaceToBatch(PrimitiveWithInfer): |
|
|
|
logger.warning("WARN_DEPRECATED: The usage of SpaceToBatch is deprecated." |
|
|
|
" Please use SpaceToBatchND.") |
|
|
|
validator.check_value_type('block_size', block_size, [int], self.name) |
|
|
|
validator.check('block_size', block_size, '', 2, Rel.GE, self.name) |
|
|
|
validator.check('block_size', block_size, self.name, 2, Rel.GE, self.name) |
|
|
|
self.block_size = block_size |
|
|
|
validator.check('paddings shape', np.array(paddings).shape, '', (2, 2), Rel.EQ, self.name) |
|
|
|
validator.check('paddings shape', np.array(paddings).shape, self.name, (2, 2), Rel.EQ, self.name) |
|
|
|
for elem in itertools.chain(*paddings): |
|
|
|
validator.check_non_negative_int(elem, 'paddings element', self.name) |
|
|
|
validator.check_value_type('paddings element', elem, [int], self.name) |
|
|
|
@@ -5308,7 +5308,7 @@ class BatchToSpace(PrimitiveWithInfer): |
|
|
|
validator.check('block_size', block_size, '', 2, Rel.GE, self.name) |
|
|
|
self.block_size = block_size |
|
|
|
validator.check_value_type('crops type', crops, [list, tuple], self.name) |
|
|
|
validator.check('crops shape', np.array(crops).shape, '', (2, 2)) |
|
|
|
validator.check('crops shape', np.array(crops).shape, self.name, (2, 2)) |
|
|
|
for elem in itertools.chain(*crops): |
|
|
|
validator.check_non_negative_int(elem, 'crops element', self.name) |
|
|
|
validator.check_value_type('crops element', elem, [int], self.name) |
|
|
|
@@ -5319,7 +5319,7 @@ class BatchToSpace(PrimitiveWithInfer): |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
validator.check('rank of input_x', len(x_shape), '', 4) |
|
|
|
validator.check('rank of input_x', len(x_shape), self.name, 4) |
|
|
|
out_shape = copy.deepcopy(x_shape) |
|
|
|
for i in range(2): |
|
|
|
x_block_prod = out_shape[i + 2] * self.block_size |
|
|
|
|