diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 1fc95927..4de72c6c 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; +using NumSharp.Utilities; namespace Tensorflow { @@ -65,6 +66,30 @@ namespace Tensorflow } } + /// + /// An overload that can accept . + /// + public TensorShape(params object[] dims) + { + var intdims = new int[dims.Length]; + for (int i = 0; i < dims.Length; i++) + { + var val = dims[i]; + if (val == Binding.None) + intdims[i] = -1; + else + intdims[i] = Converts.ToInt32(val); + } + + switch (dims.Length) + { + case 0: shape = new Shape(new int[0]); break; + case 1: shape = Shape.Vector((int) intdims[0]); break; + case 2: shape = Shape.Matrix(intdims[0], intdims[1]); break; + default: shape = new Shape(intdims); break; + } + } + /// /// ///