Browse Source

Adapt MaxPool3DGrad

pull/15169/head
liuxiao93 4 years ago
parent
commit
2b3d7845da
2 changed files with 2 additions and 5 deletions
  1. +2
    -1
      mindspore/ops/_op_impl/tbe/max_pool3d_grad.py
  2. +0
    -4
      mindspore/ops/operations/_grad_ops.py

+ 2
- 1
mindspore/ops/_op_impl/tbe/max_pool3d_grad.py View File

@@ -26,7 +26,8 @@ max_pool3d_grad_op_info = TBERegOp("MaxPool3DGrad") \
.partial_flag(True) \
.attr("kernel_size", "required", "listInt", "all") \
.attr("strides", "required", "listInt", "all") \
.attr("pad_list", "required", "listInt", "all") \
.attr("pad_mode", "optional", "str", "all") \
.attr("pad_list", "required", "listInt", "all", "0,0,0") \
.attr("format", "optional", "str", "all") \
.input(0, "orig_x", False, "required", "all") \
.input(1, "orig_y", False, "required", "all") \


+ 0
- 4
mindspore/ops/operations/_grad_ops.py View File

@@ -977,10 +977,6 @@ class MaxPool3DGrad(PrimitiveWithInfer):

def infer_shape(self, x_shape, y_shape, grad_shape):
validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
pad_list = _get_max_pool3d_grad_pads_by_pad_mode(x_shape, self.kernel_size, self.strides, self.pad_mode)
for pad in pad_list:
validator.check_non_negative_int(pad, 'element of pad_list', self.name)
self.add_prim_attr("pad_list", pad_list)
return x_shape

def infer_dtype(self, x_dtype, y_dtype, grad_dtype):


Loading…
Cancel
Save