|
|
|
@@ -1904,7 +1904,7 @@ class MaxPoolWithArgmax(_Pool): |
|
|
|
|
|
|
|
class MaxPool3D(PrimitiveWithInfer): |
|
|
|
r""" |
|
|
|
Max pooling operation. |
|
|
|
3D max pooling operation. |
|
|
|
|
|
|
|
Applies a 3D max pooling over an input Tensor which can be regarded as a composition of 3D planes. |
|
|
|
|
|
|
|
@@ -1947,7 +1947,7 @@ class MaxPool3D(PrimitiveWithInfer): |
|
|
|
TypeError: If `pad_mode` or `data_format` is not a string. |
|
|
|
ValueError: If numbers in `kernel_size` or `strides` are not positive. |
|
|
|
ValueError: If `pad_mode` is not one of 'same', 'valid'. |
|
|
|
ValueError: If `kernel_size` or `strides` is a tuple whose length is not equal to 3 or 5. |
|
|
|
ValueError: If `kernel_size` or `strides` is a tuple whose length is not equal to 3. |
|
|
|
ValueError: If `data_format` is not 'NCDHW'. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
@@ -1971,9 +1971,10 @@ class MaxPool3D(PrimitiveWithInfer): |
|
|
|
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name) |
|
|
|
self.add_prim_attr("pad_mode", self.pad_mode) |
|
|
|
self.data_format = validator.check_string(data_format, ['NCDHW'], 'data_format', self.name) |
|
|
|
self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name, allow_five=True, ret_five=True) |
|
|
|
self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name, |
|
|
|
allow_five=False, ret_five=True) |
|
|
|
self.add_prim_attr("kernel_size", self.kernel_size) |
|
|
|
self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True) |
|
|
|
self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=False, ret_five=True) |
|
|
|
self.add_prim_attr("strides", self.strides) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
@@ -2274,7 +2275,7 @@ class BiasAdd(PrimitiveWithCheck): |
|
|
|
self.add_prim_attr('data_format', self.format) |
|
|
|
|
|
|
|
def check_shape(self, x_shape, b_shape): |
|
|
|
validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name) |
|
|
|
validator.check_int_range(len(x_shape), 2, 5, Rel.INC_BOTH, "x rank", self.name) |
|
|
|
if self.format == "NCDHW" and (len(x_shape) != 5 or context.get_context("device_target") != "Ascend"): |
|
|
|
raise ValueError("NCDHW format only support 5-dims input in Ascend target.") |
|
|
|
validator.check_equal_int(len(b_shape), 1, "bias rank", self.name) |
|
|
|
|