diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 1fc95927..4de72c6c 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -3,6 +3,7 @@ using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
+using NumSharp.Utilities;
namespace Tensorflow
{
@@ -65,6 +66,30 @@ namespace Tensorflow
}
}
+ ///
+ /// An overload that can accept .
+ ///
+ public TensorShape(params object[] dims)
+ {
+ var intdims = new int[dims.Length];
+ for (int i = 0; i < dims.Length; i++)
+ {
+ var val = dims[i];
+ if (val == Binding.None)
+ intdims[i] = -1;
+ else
+ intdims[i] = Converts.ToInt32(val);
+ }
+
+ switch (dims.Length)
+ {
+ case 0: shape = new Shape(new int[0]); break;
+ case 1: shape = Shape.Vector((int) intdims[0]); break;
+ case 2: shape = Shape.Matrix(intdims[0], intdims[1]); break;
+ default: shape = new Shape(intdims); break;
+ }
+ }
+
///
///
///