diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index 089dd8a5..39bacfde 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -15,6 +15,8 @@ ******************************************************************************/ using System.Collections.Generic; +using System.Linq; +using NumSharp; using Tensorflow.Keras.Layers; using Tensorflow.Operations.Activation; using static Tensorflow.Binding; @@ -182,6 +184,7 @@ namespace Tensorflow string name = null, string data_format = "channels_last") { + var input_shape = inputs.shape; if (inputs.shape.Length == 0) throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()"); @@ -193,9 +196,25 @@ namespace Tensorflow inputs = array_ops.transpose(inputs, premutation.ToArray()); } - var ret = array_ops.reshape(inputs, new int[] {inputs.shape[0], -1}); - ret.set_shape(new int[] {inputs.shape[0], -1}); + var ret = array_ops.reshape(inputs, new int[] {input_shape[0], -1}); + ret.shape = ret.shape; + //ret.set_shape(compute_output_shape(ret.shape)); return ret; + + int[] compute_output_shape(int[] inputshape) + { + if (inputshape == null || inputshape.Length == 0) + inputshape = new int[] {1}; + + if (inputshape.Skip(1).All(d => d > 0)) + { + int[] output_shape = new int[2]; + output_shape[0] = inputshape[0]; + output_shape[1] = inputshape.Skip(1).Aggregate(1, (acc, rhs) => acc*rhs); //calculate size of all the rest dimensions + return output_shape; + } else + return new int[] {inputshape[0], -1}; //-1 == Binding.None + } } } } diff --git a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs index d533f128..981af9a4 100644 --- a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs +++ b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs @@ -36,5 +36,14 @@ namespace TensorFlowNET.UnitTest.layers_test var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape()); new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw(); } + + [TestMethod] + public void Case4() + { + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, None, 1, 2)); + sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); + } } } \ No newline at end of file