Browse Source

add AvgPooling layer

tags/v0.2.0-alpha
zhaojichen 6 years ago
parent
commit
422bc304df
1 changed files with 13 additions and 1 deletions
  1. +13
    -1
      mindspore/nn/layer/pooling.py

+ 13
- 1
mindspore/nn/layer/pooling.py View File

@@ -14,6 +14,7 @@
# ============================================================================
"""pooling"""
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from ... import context
from ..cell import Cell
@@ -272,6 +273,17 @@ class AvgPool1d(_PoolNd):
self.avg_pool = P.AvgPool(ksize=self.kernel_size,
strides=self.stride,
padding=self.pad_mode)
self.shape = F.shape
self.reduce_mean = P.ReduceMean(keep_dims=True)
self.slice = P.Slice()

def construct(self, x):
return self.avg_pool(x)
batch, channel, high, width = self.shape(x)
if width == self.kernel_size[1]:
x = self.reduce_mean(x, 3)
elif width - self.kernel_size[1] < self.stride[1]:
x = self.slice(x, (0, 0, 0, 0), (batch, channel, high, self.kernel_size[1]))
x = self.reduce_mean(x, 3)
else:
x = self.avg_pool(x)
return x

Loading…
Cancel
Save