diff --git a/src/TensorFlowNET.Core/Binding.cs b/src/TensorFlowNET.Core/Binding.cs
index e8355e32..e3136f83 100644
--- a/src/TensorFlowNET.Core/Binding.cs
+++ b/src/TensorFlowNET.Core/Binding.cs
@@ -10,7 +10,13 @@ namespace Tensorflow
///
/// Alias to null, similar to python's None.
+ /// For TensorShape, please use Unknown
///
public static readonly object None = null;
+
+ ///
+ /// Used for TensorShape None
+ ///
+ public static readonly int Unknown = -1;
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 1edd634f..853255cb 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -71,61 +71,32 @@ namespace Tensorflow
}
}
- public TensorShape(params object[] dims)
+ public TensorShape(params int[] dims)
{
- Array arr;
+ switch (dims.Length)
+ {
+ case 0: shape = new Shape(new int[0]); break;
+ case 1: shape = Shape.Vector((int)dims[0]); break;
+ case 2: shape = Shape.Matrix(dims[0], dims[1]); break;
+ default: shape = new Shape(dims); break;
+ }
+ }
- if (dims.Length == 1)
+ public TensorShape(int[][] dims)
+ {
+ if(dims.Length == 1)
{
- switch (dims[0])
+ switch (dims[0].Length)
{
- case int[] intarr:
- arr = intarr;
- break;
- case long[] longarr:
- arr = longarr;
- break;
- case object[] objarr:
- arr = objarr;
- break;
- case int _:
- case long _:
- arr = dims;
- break;
- case null: //==Binding.None
- arr = dims;
- break;
- default:
- Binding.print(dims);
- throw new ArgumentException(nameof(dims));
+ case 0: shape = new Shape(new int[0]); break;
+ case 1: shape = Shape.Vector((int)dims[0][0]); break;
+ case 2: shape = Shape.Matrix(dims[0][0], dims[1][2]); break;
+ default: shape = new Shape(dims[0]); break;
}
- } else
- arr = dims;
-
- var intdims = new int[arr.Length];
- for (int i = 0; i < arr.Length; i++)
- {
- var val = arr.GetValue(i);
- if (val == Binding.None)
- intdims[i] = -1;
- else
- intdims[i] = Converts.ToInt32(val);
}
-
- switch (intdims.Length)
+ else
{
- case 0:
- shape = new Shape(new int[0]);
- break;
- case 1:
- shape = Shape.Vector((int) intdims[0]);
- break;
- case 2:
- shape = Shape.Matrix(intdims[0], intdims[1]);
- break;
- default:
- shape = new Shape(intdims);
- break;
+ throw new NotImplementedException("TensorShape int[][] dims");
}
}
@@ -232,8 +203,6 @@ namespace Tensorflow
public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone());
public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone());
- public static implicit operator TensorShape(object[] dims) => new TensorShape(dims);
-
public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes
public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);
@@ -260,16 +229,5 @@ namespace Tensorflow
public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
-
- public static implicit operator TensorShape(int?[] dims) => new TensorShape(dims);
- public static implicit operator TensorShape(int? dim) => new TensorShape(dim);
- public static implicit operator TensorShape((object, object) dims) => new TensorShape(dims.Item1, dims.Item2);
- public static implicit operator TensorShape((object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
- public static implicit operator TensorShape((object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
- public static implicit operator TensorShape((object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);
- public static implicit operator TensorShape((object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);
- public static implicit operator TensorShape((object, object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);
- public static implicit operator TensorShape((object, object, object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
-
}
}
diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs
index 494f2e89..bdb2f537 100644
--- a/src/TensorFlowNET.Core/tensorflow.cs
+++ b/src/TensorFlowNET.Core/tensorflow.cs
@@ -63,11 +63,6 @@ namespace Tensorflow
return gen_array_ops.placeholder(dtype, shape, name);
}
- public unsafe Tensor placeholder(TF_DataType dtype, object[] shape, string name = null)
- {
- return placeholder(dtype, new TensorShape(shape), name);
- }
-
public void enable_eager_execution()
{
// contex = new Context();
diff --git a/test/TensorFlowNET.UnitTest/TensorShapeTest.cs b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs
index 90048aa6..efa7def3 100644
--- a/test/TensorFlowNET.UnitTest/TensorShapeTest.cs
+++ b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs
@@ -12,48 +12,48 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void Case1()
{
- int? a = 2;
- int? b = 3;
- var dims = new object[] {(int?) None, a, b};
+ int a = 2;
+ int b = 3;
+ var dims = new [] { Unknown, a, b};
new TensorShape(dims).GetPrivate("shape").Should().BeShaped(-1, 2, 3);
}
[TestMethod]
public void Case2()
{
- int? a = 2;
- int? b = 3;
- var dims = new object[] {(int?) None, a, b};
- new TensorShape(new object[] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, 3);
+ int a = 2;
+ int b = 3;
+ var dims = new[] { Unknown, a, b};
+ new TensorShape(new [] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, 3);
}
[TestMethod]
public void Case3()
{
- int? a = 2;
- int? b = null;
- var dims = new object[] {(int?) None, a, b};
- new TensorShape(new object[] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, -1);
+ int a = 2;
+ int b = Unknown;
+ var dims = new [] { Unknown, a, b};
+ new TensorShape(new [] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, -1);
}
[TestMethod]
public void Case4()
{
- TensorShape shape = (None, None);
+ TensorShape shape = (Unknown, Unknown);
shape.GetPrivate("shape").Should().BeShaped(-1, -1);
}
[TestMethod]
public void Case5()
{
- TensorShape shape = (1, None, 3);
+ TensorShape shape = (1, Unknown, 3);
shape.GetPrivate("shape").Should().BeShaped(1, -1, 3);
}
[TestMethod]
public void Case6()
{
- TensorShape shape = (None, 1, 2, 3, None);
+ TensorShape shape = (Unknown, 1, 2, 3, Unknown);
shape.GetPrivate("shape").Should().BeShaped(-1, 1, 2, 3, -1);
}
}
diff --git a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs
index 8f97d5c2..fa8ec792 100644
--- a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs
+++ b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs
@@ -42,7 +42,7 @@ namespace TensorFlowNET.UnitTest.layers_test
{
var sess = tf.Session().as_default();
- var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, None, 1, 2));
+ var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2));
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
}
@@ -51,7 +51,7 @@ namespace TensorFlowNET.UnitTest.layers_test
{
var sess = tf.Session().as_default();
- var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(None, 4, 3, 1, 2));
+ var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2));
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
}
}