diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/AveragePooling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/AveragePooling2DArgs.cs
new file mode 100644
index 00000000..06903e37
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/AveragePooling2DArgs.cs
@@ -0,0 +1,7 @@
+namespace Tensorflow.Keras.ArgsDefinition
+{
+ public class AveragePooling2DArgs : Pooling2DArgs
+ {
+
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling2D.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling2DArgs.cs
similarity index 100%
rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling2D.cs
rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/MaxPooling2DArgs.cs
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/AveragePoolFunction.cs b/src/TensorFlowNET.Core/Operations/NnOps/AveragePoolFunction.cs
new file mode 100644
index 00000000..d43f8a0c
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/NnOps/AveragePoolFunction.cs
@@ -0,0 +1,47 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Operations
+{
+ ///
+ /// Performs the average pooling on the input.
+ ///
+ public class AveragePoolFunction : IPoolFunction
+ {
+ public Tensor Apply(Tensor value,
+ int[] ksize,
+ int[] strides,
+ string padding,
+ string data_format = "NHWC",
+ string name = null)
+ {
+ return tf_with(ops.name_scope(name, "AveragePool", value), scope =>
+ {
+ name = scope;
+ value = ops.convert_to_tensor(value, name: "input");
+ return gen_nn_ops.average_pool(
+ value,
+ ksize: ksize,
+ strides: strides,
+ padding: padding,
+ data_format: data_format,
+ name: name);
+ });
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index 31ac8650..0567858f 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -255,6 +255,21 @@ namespace Tensorflow.Operations
=> tf.Context.ExecuteOp("LeakyRelu", name,
new ExecuteOpArgs(features).SetAttributes(new { alpha }));
+ public static Tensor average_pool(Tensor input,
+ int[] ksize,
+ int[] strides,
+ string padding,
+ string data_format = "NHWC",
+ string name = null)
+ => tf.Context.ExecuteOp("AvgPool", name, new ExecuteOpArgs(input)
+ .SetAttributes(new
+ {
+ ksize,
+ strides,
+ padding,
+ data_format
+ }));
+
public static Tensor max_pool(Tensor input,
int[] ksize,
int[] strides,
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 8bbc0cc8..d99d8bae 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -531,6 +531,26 @@ namespace Tensorflow.Keras.Layers
Ragged = ragged
});
+ ///
+ /// Average pooling operation for spatial data.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public AveragePooling2D AveragePooling2D(Shape pool_size = null,
+ Shape strides = null,
+ string padding = "valid",
+ string data_format = null)
+ => new AveragePooling2D(new AveragePooling2DArgs
+ {
+ PoolSize = pool_size ?? (2, 2),
+ Strides = strides,
+ Padding = padding,
+ DataFormat = data_format
+ });
+
///
/// Max pooling operation for 1D temporal data.
///
diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/AveragePooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/AveragePooling2D.cs
new file mode 100644
index 00000000..fbdb557c
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Pooling/AveragePooling2D.cs
@@ -0,0 +1,14 @@
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Operations;
+
+namespace Tensorflow.Keras.Layers
+{
+ public class AveragePooling2D : Pooling2D
+ {
+ public AveragePooling2D(AveragePooling2DArgs args)
+ : base(args)
+ {
+ args.PoolFunction = new AveragePoolFunction();
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
index 335e5da2..c9385908 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
@@ -15,7 +15,24 @@ namespace TensorFlowNET.Keras.UnitTest
[TestClass]
public class LayersTest : EagerModeTestBase
{
- // [TestMethod]
+ [TestMethod]
+ public void AveragePooling2D()
+ {
+ var x = tf.constant(new float[,]
+ {
+ { 1, 2, 3 },
+ { 4, 5, 6 },
+ { 7, 8, 9 }
+ });
+ x = tf.reshape(x, (1, 3, 3, 1));
+ var avg_pool_2d = keras.layers.AveragePooling2D(pool_size: (2, 2),
+ strides: (1, 1), padding: "valid");
+ Tensor avg = avg_pool_2d.Apply(x);
+ Assert.AreEqual((1, 2, 2, 1), avg.shape);
+ Equal(new float[] { 3, 4, 6, 7 }, avg.ToArray());
+ }
+
+ [TestMethod]
public void InputLayer()
{
var model = keras.Sequential(new List
@@ -23,8 +40,10 @@ namespace TensorFlowNET.Keras.UnitTest
keras.layers.InputLayer(input_shape: 4),
keras.layers.Dense(8)
});
- model.compile(optimizer: keras.optimizers.RMSprop(0.001f));
- model.fit(np.zeros((10, 4)), np.ones((10, 8)));
+ model.compile(optimizer: keras.optimizers.RMSprop(0.001f),
+ loss: keras.losses.MeanSquaredError(),
+ metrics: new[] { "accuracy" });
+ model.fit(np.zeros((10, 4), dtype: tf.float32), np.ones((10, 8), dtype: tf.float32));
}
[TestMethod]