|
|
|
@@ -268,13 +268,8 @@ class AvgPool1d(_PoolNd): |
|
|
|
ParamValidator.check_type('kernel_size', kernel_size, [int,]) |
|
|
|
ParamValidator.check_type('stride', stride, [int,]) |
|
|
|
self.pad_mode = ParamValidator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) |
|
|
|
if not isinstance(kernel_size, int): |
|
|
|
ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE) |
|
|
|
raise ValueError("kernel_size should be 1 int number but got {}". |
|
|
|
format(kernel_size)) |
|
|
|
if not isinstance(stride, int): |
|
|
|
ParamValidator.check_integer("stride", stride, 1, Rel.GE) |
|
|
|
raise ValueError("stride should be 1 int number but got {}".format(stride)) |
|
|
|
ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE) |
|
|
|
ParamValidator.check_integer("stride", stride, 1, Rel.GE) |
|
|
|
self.kernel_size = (1, kernel_size) |
|
|
|
self.stride = (1, stride) |
|
|
|
self.avg_pool = P.AvgPool(ksize=self.kernel_size, |
|
|
|
|