|
|
|
@@ -2635,16 +2635,20 @@ class SpaceToBatchND(PrimitiveWithInfer): |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
x_rank = len(x_shape) |
|
|
|
validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name) |
|
|
|
out_shape = copy.deepcopy(x_shape) |
|
|
|
|
|
|
|
block_shape_prod = 1 |
|
|
|
for i in range(x_rank - 2): |
|
|
|
padded = out_shape[i + 2] + self.paddings[i][0] + \ |
|
|
|
offset = 2 |
|
|
|
if x_rank < 4: |
|
|
|
offset = 1 |
|
|
|
for i in range(len(self.block_shape)): |
|
|
|
padded = out_shape[i + offset] + self.paddings[i][0] + \ |
|
|
|
self.paddings[i][1] |
|
|
|
if padded % self.block_shape[i] != 0: |
|
|
|
raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by ' |
|
|
|
f'block_shape[{i}] {self.block_shape[i]}') |
|
|
|
out_shape[i + 2] = padded // self.block_shape[i] |
|
|
|
out_shape[i + offset] = padded // self.block_shape[i] |
|
|
|
block_shape_prod = block_shape_prod * self.block_shape[i] |
|
|
|
out_shape[0] *= block_shape_prod |
|
|
|
return out_shape |
|
|
|
@@ -2715,15 +2719,19 @@ class BatchToSpaceND(PrimitiveWithInfer): |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
x_rank = len(x_shape) |
|
|
|
validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name) |
|
|
|
out_shape = copy.deepcopy(x_shape) |
|
|
|
|
|
|
|
block_shape_prod = 1 |
|
|
|
for i in range(x_rank - 2): |
|
|
|
offset = 2 |
|
|
|
if x_rank < 4: |
|
|
|
offset = 1 |
|
|
|
for i in range(len(self.block_shape)): |
|
|
|
block_shape_prod = block_shape_prod * self.block_shape[i] |
|
|
|
x_block_prod = out_shape[i + 2] * self.block_shape[i] |
|
|
|
x_block_prod = out_shape[i + offset] * self.block_shape[i] |
|
|
|
crops_sum = self.crops[i][0] + self.crops[i][1] |
|
|
|
validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name) |
|
|
|
out_shape[i + 2] = x_block_prod - crops_sum |
|
|
|
out_shape[i + offset] = x_block_prod - crops_sum |
|
|
|
|
|
|
|
if out_shape[0] % block_shape_prod != 0: |
|
|
|
raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by ' |
|
|
|
|