diff --git a/src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs b/src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs
index 92bd95a5..149d2e88 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using System.Linq;
using static Tensorflow.Binding;
namespace Tensorflow.Operations
@@ -24,7 +25,7 @@ namespace Tensorflow.Operations
public class MaxPoolFunction : IPoolFunction
{
public Tensor Apply(Tensor value,
- int[] ksize,
+ int[] pool_size,
int[] strides,
string padding,
string data_format = "NHWC",
@@ -33,10 +34,9 @@ namespace Tensorflow.Operations
return tf_with(ops.name_scope(name, "MaxPool", value), scope =>
{
name = scope;
- value = ops.convert_to_tensor(value, name: "input");
return gen_nn_ops.max_pool(
value,
- ksize: ksize,
+ ksize: pool_size,
strides: strides,
padding: padding,
data_format: data_format,
diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
index 4c42cb8c..a7db6eee 100644
--- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
+++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
@@ -5,7 +5,7 @@
Tensorflow.Binding
Tensorflow
2.10.0
- 0.100.1
+ 0.100.2
10.0
enable
Haiping Chen, Meinrad Recheis, Eli Belash
@@ -20,7 +20,7 @@
Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io
- 0.100.1.0
+ 0.100.2.0
tf.net 0.100.x and above are based on tensorflow native 2.10.0
@@ -38,7 +38,7 @@ https://tensorflownet.readthedocs.io
tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.
tf.net 0.10x.x aligns with TensorFlow v2.10.x native library.
- 0.100.1.0
+ 0.100.2.0
LICENSE
true
true
diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs
index 80b36c86..a2f4c51b 100644
--- a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs
+++ b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs
@@ -14,9 +14,11 @@
limitations under the License.
******************************************************************************/
+using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
+using static Tensorflow.Binding;
namespace Tensorflow.Keras.Layers
{
@@ -36,17 +38,21 @@ namespace Tensorflow.Keras.Layers
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
- int[] pool_shape;
- int[] strides;
+ int pad_axis = args.DataFormat == "channels_first" ? 2 : 3;
+ inputs = tf.expand_dims(inputs, pad_axis);
+ int[] pool_shape = new int[] { args.PoolSize, 1 };
+ int[] strides = new int[] { args.Strides, 1 };
+ var ndim = inputs[0].ndim;
+
if (args.DataFormat == "channels_last")
{
- pool_shape = new int[] { 1, args.PoolSize, 1 };
- strides = new int[] { 1, args.Strides, 1 };
+ pool_shape = new int[] { 1 }.Concat(pool_shape).Concat(new int[] { 1 }).ToArray();
+ strides = new int[] { 1 }.Concat(strides).Concat(new int[] { 1 }).ToArray();
}
else
{
- pool_shape = new int[] { 1, 1, args.PoolSize };
- strides = new int[] { 1, 1, args.Strides };
+ pool_shape = new int[] { 1, 1 }.Concat(pool_shape).ToArray();
+ strides = new int[] { 1, 1 }.Concat(strides).ToArray();
}
var outputs = args.PoolFunction.Apply(
@@ -54,9 +60,9 @@ namespace Tensorflow.Keras.Layers
ksize: pool_shape,
strides: strides,
padding: args.Padding.ToUpper(),
- data_format: conv_utils.convert_data_format(args.DataFormat, 3));
+ data_format: conv_utils.convert_data_format(args.DataFormat, ndim));
- return outputs;
+ return tf.squeeze(outputs, pad_axis);
}
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs
index e65bf038..27032255 100644
--- a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs
+++ b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs
@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Layers
int[] strides;
if (args.DataFormat == "channels_last")
{
- pool_shape = new int[] { 1, (int)args.PoolSize.dims[0], (int)args.PoolSize.dims[1], 1 };
+ pool_shape = new int[] { 1, (int)args.PoolSize.dims[0], (int)args.PoolSize.dims[1], 1 };
strides = new int[] { 1, (int)args.Strides.dims[0], (int)args.Strides.dims[1], 1 };
}
else
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index d45c7de2..f7d18635 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -7,7 +7,7 @@
enable
Tensorflow.Keras
AnyCPU;x64
- 0.10.1
+ 0.10.2
Haiping Chen
Keras for .NET
Apache 2.0, Haiping Chen 2021
@@ -37,8 +37,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
Git
true
Open.snk
- 0.10.1.0
- 0.10.1.0
+ 0.10.2.0
+ 0.10.2.0
LICENSE
Debug;Release;GPU
@@ -70,7 +70,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
-
+
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/PoolingTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/PoolingTest.cs
index 8af40855..0eab0a98 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/PoolingTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/PoolingTest.cs
@@ -4,6 +4,7 @@ using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
+using Microsoft.VisualBasic;
namespace TensorFlowNET.Keras.UnitTest
{
@@ -226,7 +227,7 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual(expected, y[0].numpy());
}
- [TestMethod, Ignore("There's an error generated from TF complaining about the shape of the pool. Needs further investigation.")]
+ [TestMethod]
public void Max1DPoolingChannelsLast()
{
var x = input_array_1D;
@@ -239,7 +240,7 @@ namespace TensorFlowNET.Keras.UnitTest
var expected = np.array(new float[,,]
{
- {{2.0f, 2.0f, 3.0f, 3.0f, 3.0f},
+ {{1.0f, 2.0f, 3.0f, 3.0f, 3.0f},
{ 1.0f, 2.0f, 3.0f, 3.0f, 3.0f}},
{{4.0f, 5.0f, 6.0f, 3.0f, 3.0f},