From 0c32b73fc11b8b75213c0450df68ec9ca0d03816 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Sun, 8 Sep 2019 17:09:27 +0300 Subject: [PATCH] tf.layers.flatten: Added and fixed special unit-test case. --- src/TensorFlowNET.Core/APIs/tf.layers.cs | 3 +-- test/TensorFlowNET.UnitTest/layers_test/flatten.cs | 9 +++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index 39bacfde..786469b5 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -196,8 +196,7 @@ namespace Tensorflow inputs = array_ops.transpose(inputs, premutation.ToArray()); } - var ret = array_ops.reshape(inputs, new int[] {input_shape[0], -1}); - ret.shape = ret.shape; + var ret = array_ops.reshape(inputs, compute_output_shape(input_shape)); //ret.set_shape(compute_output_shape(ret.shape)); return ret; diff --git a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs index 981af9a4..8f97d5c2 100644 --- a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs +++ b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs @@ -45,5 +45,14 @@ namespace TensorFlowNET.UnitTest.layers_test var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, None, 1, 2)); sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); } + + [TestMethod] + public void Case5() + { + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(None, 4, 3, 1, 2)); + sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); + } } } \ No newline at end of file