diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index 089dd8a5..786469b5 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -15,6 +15,8 @@ ******************************************************************************/ using System.Collections.Generic; +using System.Linq; +using NumSharp; using Tensorflow.Keras.Layers; using Tensorflow.Operations.Activation; using static Tensorflow.Binding; @@ -182,6 +184,7 @@ namespace Tensorflow string name = null, string data_format = "channels_last") { + var input_shape = inputs.shape; if (inputs.shape.Length == 0) throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()"); @@ -193,9 +196,24 @@ namespace Tensorflow inputs = array_ops.transpose(inputs, premutation.ToArray()); } - var ret = array_ops.reshape(inputs, new int[] {inputs.shape[0], -1}); - ret.set_shape(new int[] {inputs.shape[0], -1}); + var ret = array_ops.reshape(inputs, compute_output_shape(input_shape)); + //ret.set_shape(compute_output_shape(ret.shape)); return ret; + + int[] compute_output_shape(int[] inputshape) + { + if (inputshape == null || inputshape.Length == 0) + inputshape = new int[] {1}; + + if (inputshape.Skip(1).All(d => d > 0)) + { + int[] output_shape = new int[2]; + output_shape[0] = inputshape[0]; + output_shape[1] = inputshape.Skip(1).Aggregate(1, (acc, rhs) => acc*rhs); //calculate size of all the rest dimensions + return output_shape; + } else + return new int[] {inputshape[0], -1}; //-1 == Binding.None + } } } } diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index ff383642..def78327 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -21,6 +21,7 @@ using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; using System.Linq; +using NumSharp.Utilities; namespace Tensorflow { @@ -29,9 +30,37 @@ namespace Tensorflow /// public static partial class Binding { + private static string _tostring(object obj) + { + switch (obj) + { + case NDArray nd: + return nd.ToString(false); + case Array arr: + if (arr.Rank!=1 || arr.GetType().GetElementType()?.IsArray == true) + arr = Arrays.Flatten(arr); + var objs = toObjectArray(arr); + return $"[{string.Join(", ", objs.Select(_tostring))}]"; + default: + return obj?.ToString() ?? "null"; + } + + object[] toObjectArray(Array arr) + { + var len = arr.LongLength; + var ret = new object[len]; + for (long i = 0; i < len; i++) + { + ret[i] = arr.GetValue(i); + } + + return ret; + } + } + public static void print(object obj) { - Console.WriteLine(obj.ToString()); + Console.WriteLine(_tostring(obj)); } public static int len(object a) diff --git a/src/TensorFlowNET.Core/Binding.cs b/src/TensorFlowNET.Core/Binding.cs index f443f2eb..e3136f83 100644 --- a/src/TensorFlowNET.Core/Binding.cs +++ b/src/TensorFlowNET.Core/Binding.cs @@ -7,5 +7,16 @@ namespace Tensorflow public static partial class Binding { public static tensorflow tf { get; } = New(); + + /// + /// Alias to null, similar to python's None. + /// For TensorShape, please use Unknown + /// + public static readonly object None = null; + + /// + /// Used for TensorShape None + /// + public static readonly int Unknown = -1; } } diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 1fc95927..853255cb 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; +using NumSharp.Utilities; namespace Tensorflow { @@ -32,7 +33,23 @@ namespace Tensorflow /// /// Returns the size this shape represents. /// - public int size => shape.Size; + public int size + { + get + { + var dims = shape.Dimensions; + var computed = 1; + for (int i = 0; i < dims.Length; i++) + { + var val = dims[i]; + if (val <= 0) + continue; + computed *= val; + } + + return computed; + } + } public TensorShape(TensorShapeProto proto) { @@ -59,12 +76,30 @@ namespace Tensorflow switch (dims.Length) { case 0: shape = new Shape(new int[0]); break; - case 1: shape = Shape.Vector((int) dims[0]); break; + case 1: shape = Shape.Vector((int)dims[0]); break; case 2: shape = Shape.Matrix(dims[0], dims[1]); break; default: shape = new Shape(dims); break; } } + public TensorShape(int[][] dims) + { + if(dims.Length == 1) + { + switch (dims[0].Length) + { + case 0: shape = new Shape(new int[0]); break; + case 1: shape = Shape.Vector((int)dims[0][0]); break; + case 2: shape = Shape.Matrix(dims[0][0], dims[1][2]); break; + default: shape = new Shape(dims[0]); break; + } + } + else + { + throw new NotImplementedException("TensorShape int[][] dims"); + } + } + /// /// /// @@ -188,6 +223,11 @@ namespace Tensorflow public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); - + + public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7); + + public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8); } } diff --git a/test/TensorFlowNET.UnitTest/TensorShapeTest.cs b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs new file mode 100644 index 00000000..efa7def3 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs @@ -0,0 +1,60 @@ +using System; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class TensorShapeTest + { + [TestMethod] + public void Case1() + { + int a = 2; + int b = 3; + var dims = new [] { Unknown, a, b}; + new TensorShape(dims).GetPrivate("shape").Should().BeShaped(-1, 2, 3); + } + + [TestMethod] + public void Case2() + { + int a = 2; + int b = 3; + var dims = new[] { Unknown, a, b}; + new TensorShape(new [] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, 3); + } + + [TestMethod] + public void Case3() + { + int a = 2; + int b = Unknown; + var dims = new [] { Unknown, a, b}; + new TensorShape(new [] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, -1); + } + + [TestMethod] + public void Case4() + { + TensorShape shape = (Unknown, Unknown); + shape.GetPrivate("shape").Should().BeShaped(-1, -1); + } + + [TestMethod] + public void Case5() + { + TensorShape shape = (1, Unknown, 3); + shape.GetPrivate("shape").Should().BeShaped(1, -1, 3); + } + + [TestMethod] + public void Case6() + { + TensorShape shape = (Unknown, 1, 2, 3, Unknown); + shape.GetPrivate("shape").Should().BeShaped(-1, 1, 2, 3, -1); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs index d533f128..fa8ec792 100644 --- a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs +++ b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs @@ -36,5 +36,23 @@ namespace TensorFlowNET.UnitTest.layers_test var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape()); new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw(); } + + [TestMethod] + public void Case4() + { + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2)); + sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); + } + + [TestMethod] + public void Case5() + { + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2)); + sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); + } } } \ No newline at end of file