diff --git a/src/TensorFlowNET.Core/Binding.cs b/src/TensorFlowNET.Core/Binding.cs index e8355e32..e3136f83 100644 --- a/src/TensorFlowNET.Core/Binding.cs +++ b/src/TensorFlowNET.Core/Binding.cs @@ -10,7 +10,13 @@ namespace Tensorflow /// /// Alias to null, similar to python's None. + /// For TensorShape, please use Unknown /// public static readonly object None = null; + + /// + /// Used for TensorShape None + /// + public static readonly int Unknown = -1; } } diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 1edd634f..853255cb 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -71,61 +71,32 @@ namespace Tensorflow } } - public TensorShape(params object[] dims) + public TensorShape(params int[] dims) { - Array arr; + switch (dims.Length) + { + case 0: shape = new Shape(new int[0]); break; + case 1: shape = Shape.Vector((int)dims[0]); break; + case 2: shape = Shape.Matrix(dims[0], dims[1]); break; + default: shape = new Shape(dims); break; + } + } - if (dims.Length == 1) + public TensorShape(int[][] dims) + { + if(dims.Length == 1) { - switch (dims[0]) + switch (dims[0].Length) { - case int[] intarr: - arr = intarr; - break; - case long[] longarr: - arr = longarr; - break; - case object[] objarr: - arr = objarr; - break; - case int _: - case long _: - arr = dims; - break; - case null: //==Binding.None - arr = dims; - break; - default: - Binding.print(dims); - throw new ArgumentException(nameof(dims)); + case 0: shape = new Shape(new int[0]); break; + case 1: shape = Shape.Vector((int)dims[0][0]); break; + case 2: shape = Shape.Matrix(dims[0][0], dims[1][2]); break; + default: shape = new Shape(dims[0]); break; } - } else - arr = dims; - - var intdims = new int[arr.Length]; - for (int i = 0; i < arr.Length; i++) - { - var val = arr.GetValue(i); - if (val == Binding.None) - intdims[i] = -1; - else - intdims[i] = Converts.ToInt32(val); } - - switch (intdims.Length) + else { - case 0: - shape = new Shape(new int[0]); - break; - case 1: - shape = Shape.Vector((int) intdims[0]); - break; - case 2: - shape = Shape.Matrix(intdims[0], intdims[1]); - break; - default: - shape = new Shape(intdims); - break; + throw new NotImplementedException("TensorShape int[][] dims"); } } @@ -232,8 +203,6 @@ namespace Tensorflow public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); - public static implicit operator TensorShape(object[] dims) => new TensorShape(dims); - public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); @@ -260,16 +229,5 @@ namespace Tensorflow public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0); public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8); - - public static implicit operator TensorShape(int?[] dims) => new TensorShape(dims); - public static implicit operator TensorShape(int? dim) => new TensorShape(dim); - public static implicit operator TensorShape((object, object) dims) => new TensorShape(dims.Item1, dims.Item2); - public static implicit operator TensorShape((object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); - public static implicit operator TensorShape((object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); - public static implicit operator TensorShape((object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); - public static implicit operator TensorShape((object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); - public static implicit operator TensorShape((object, object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7); - public static implicit operator TensorShape((object, object, object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8); - } } diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 494f2e89..bdb2f537 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -63,11 +63,6 @@ namespace Tensorflow return gen_array_ops.placeholder(dtype, shape, name); } - public unsafe Tensor placeholder(TF_DataType dtype, object[] shape, string name = null) - { - return placeholder(dtype, new TensorShape(shape), name); - } - public void enable_eager_execution() { // contex = new Context(); diff --git a/test/TensorFlowNET.UnitTest/TensorShapeTest.cs b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs index 90048aa6..efa7def3 100644 --- a/test/TensorFlowNET.UnitTest/TensorShapeTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs @@ -12,48 +12,48 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void Case1() { - int? a = 2; - int? b = 3; - var dims = new object[] {(int?) None, a, b}; + int a = 2; + int b = 3; + var dims = new [] { Unknown, a, b}; new TensorShape(dims).GetPrivate("shape").Should().BeShaped(-1, 2, 3); } [TestMethod] public void Case2() { - int? a = 2; - int? b = 3; - var dims = new object[] {(int?) None, a, b}; - new TensorShape(new object[] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, 3); + int a = 2; + int b = 3; + var dims = new[] { Unknown, a, b}; + new TensorShape(new [] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, 3); } [TestMethod] public void Case3() { - int? a = 2; - int? b = null; - var dims = new object[] {(int?) None, a, b}; - new TensorShape(new object[] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, -1); + int a = 2; + int b = Unknown; + var dims = new [] { Unknown, a, b}; + new TensorShape(new [] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, -1); } [TestMethod] public void Case4() { - TensorShape shape = (None, None); + TensorShape shape = (Unknown, Unknown); shape.GetPrivate("shape").Should().BeShaped(-1, -1); } [TestMethod] public void Case5() { - TensorShape shape = (1, None, 3); + TensorShape shape = (1, Unknown, 3); shape.GetPrivate("shape").Should().BeShaped(1, -1, 3); } [TestMethod] public void Case6() { - TensorShape shape = (None, 1, 2, 3, None); + TensorShape shape = (Unknown, 1, 2, 3, Unknown); shape.GetPrivate("shape").Should().BeShaped(-1, 1, 2, 3, -1); } } diff --git a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs index 8f97d5c2..fa8ec792 100644 --- a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs +++ b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs @@ -42,7 +42,7 @@ namespace TensorFlowNET.UnitTest.layers_test { var sess = tf.Session().as_default(); - var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, None, 1, 2)); + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2)); sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); } @@ -51,7 +51,7 @@ namespace TensorFlowNET.UnitTest.layers_test { var sess = tf.Session().as_default(); - var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(None, 4, 3, 1, 2)); + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2)); sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); } }