Browse Source

add AvgPooling layer

tags/v0.2.0-alpha
zhaojichen 5 years ago
parent
commit
e1b6addefd
1 changed files with 6 additions and 0 deletions
  1. +6
    -0
      mindspore/nn/layer/pooling.py

+ 6
- 0
mindspore/nn/layer/pooling.py View File

@@ -18,6 +18,7 @@ from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from ... import context from ... import context
from ..cell import Cell from ..cell import Cell
from ..._checkparam import Rel




class _PoolNd(Cell): class _PoolNd(Cell):
@@ -263,10 +264,15 @@ class AvgPool1d(_PoolNd):
stride=1, stride=1,
pad_mode="valid"): pad_mode="valid"):
super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode) super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
validator.check_type('kernel_size', kernel_size, [int,])
validator.check_type('stride', stride, [int,])
self.padding = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'])
if not isinstance(kernel_size, int): if not isinstance(kernel_size, int):
validator.check_integer("kernel_size", kernel_size, 1, Rel.GE)
raise ValueError("kernel_size should be 1 int number but got {}". raise ValueError("kernel_size should be 1 int number but got {}".
format(kernel_size)) format(kernel_size))
if not isinstance(stride, int): if not isinstance(stride, int):
validator.check_integer("stride", stride, 1, Rel.GE)
raise ValueError("stride should be 1 int number but got {}".format(stride)) raise ValueError("stride should be 1 int number but got {}".format(stride))
self.kernel_size = (1, kernel_size) self.kernel_size = (1, kernel_size)
self.stride = (1, stride) self.stride = (1, stride)


Loading…
Cancel
Save