@@ -25,8 +25,8 @@ batch_to_space_nd_op_info = TBERegOp("BatchToSpaceND") \
     .partial_flag(True) \
     .attr("block_shape", "required", "listInt", "all") \
     .attr("crops", "required", "listListInt", "all") \
-    .input(0, "x", False, "required", "all") \
-    .output(0, "y", False, "required", "all") \
+    .input(0, "x", False, "required", "all", reshape_type="NH") \
+    .output(0, "y", False, "required", "all", reshape_type="NH") \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
     .get_op_info()
@@ -27,6 +27,8 @@ conv2d_backprop_filter_op_info = TBERegOp("Conv2DBackpropFilter") \
     .attr("stride", "required", "listInt", "all") \
     .attr("pad_list", "required", "listInt", "all") \
     .attr("dilation", "required", "listInt", "all") \
+    .attr("groups", "optional", "int", "all") \
+    .attr("data_format", "optional", "str", "all") \
     .input(0, "out_backprop", False, "required", "all") \
     .input(1, "x", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
@@ -25,8 +25,8 @@ space_to_batch_nd_op_info = TBERegOp("SpaceToBatchND") \
     .partial_flag(True) \
     .attr("block_shape", "required", "listInt", "all") \
     .attr("paddings", "required", "listListInt", "all") \
-    .input(0, "x", False, "required", "all") \
-    .output(0, "y", False, "required", "all") \
+    .input(0, "x", False, "required", "all", reshape_type="NH") \
+    .output(0, "y", False, "required", "all", reshape_type="NH") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()
@@ -237,6 +237,7 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
         self.add_prim_attr('stride', self.stride)
         self.dilation = dilation
         self.group = group
+        self.add_prim_attr('groups', group)
         self.add_prim_attr('data_format', "NCHW")

     def __infer__(self, doutput, x, w_size):
@@ -2635,16 +2635,20 @@ class SpaceToBatchND(PrimitiveWithInfer):
     def infer_shape(self, x_shape):
         x_rank = len(x_shape)
-        validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
         out_shape = copy.deepcopy(x_shape)
         block_shape_prod = 1
-        for i in range(x_rank - 2):
-            padded = out_shape[i + 2] + self.paddings[i][0] + \
+        offset = 2
+        if x_rank < 4:
+            offset = 1
+        for i in range(len(self.block_shape)):
+            padded = out_shape[i + offset] + self.paddings[i][0] + \
                      self.paddings[i][1]
             if padded % self.block_shape[i] != 0:
                 raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
                                  f'block_shape[{i}] {self.block_shape[i]}')
-            out_shape[i + 2] = padded // self.block_shape[i]
+            out_shape[i + offset] = padded // self.block_shape[i]
             block_shape_prod = block_shape_prod * self.block_shape[i]
         out_shape[0] *= block_shape_prod
         return out_shape
@@ -2715,15 +2719,19 @@ class BatchToSpaceND(PrimitiveWithInfer):
     def infer_shape(self, x_shape):
         x_rank = len(x_shape)
-        validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
         out_shape = copy.deepcopy(x_shape)
         block_shape_prod = 1
-        for i in range(x_rank - 2):
+        offset = 2
+        if x_rank < 4:
+            offset = 1
+        for i in range(len(self.block_shape)):
             block_shape_prod = block_shape_prod * self.block_shape[i]
-            x_block_prod = out_shape[i + 2] * self.block_shape[i]
+            x_block_prod = out_shape[i + offset] * self.block_shape[i]
             crops_sum = self.crops[i][0] + self.crops[i][1]
             validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
-            out_shape[i + 2] = x_block_prod - crops_sum
+            out_shape[i + offset] = x_block_prod - crops_sum
         if out_shape[0] % block_shape_prod != 0:
             raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '