|
|
|
@@ -14,6 +14,7 @@ |
|
|
|
# ============================================================================ |
|
|
|
"""pooling""" |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from mindspore.ops import functional as F |
|
|
|
from mindspore._checkparam import Validator as validator |
|
|
|
from ... import context |
|
|
|
from ..cell import Cell |
|
|
|
@@ -272,6 +273,17 @@ class AvgPool1d(_PoolNd): |
|
|
|
self.avg_pool = P.AvgPool(ksize=self.kernel_size, |
|
|
|
strides=self.stride, |
|
|
|
padding=self.pad_mode) |
|
|
|
self.shape = F.shape |
|
|
|
self.reduce_mean = P.ReduceMean(keep_dims=True) |
|
|
|
self.slice = P.Slice() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
return self.avg_pool(x) |
|
|
|
batch, channel, high, width = self.shape(x) |
|
|
|
if width == self.kernel_size[1]: |
|
|
|
x = self.reduce_mean(x, 3) |
|
|
|
elif width - self.kernel_size[1] < self.stride[1]: |
|
|
|
x = self.slice(x, (0, 0, 0, 0), (batch, channel, high, self.kernel_size[1])) |
|
|
|
x = self.reduce_mean(x, 3) |
|
|
|
else: |
|
|
|
x = self.avg_pool(x) |
|
|
|
return x |