| @@ -3,6 +3,7 @@ using System; | |||||
| using System.Diagnostics.CodeAnalysis; | using System.Diagnostics.CodeAnalysis; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using NumSharp.Utilities; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -65,6 +66,30 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| /// <summary> | |||||
| /// An overload that can accept <see cref="Binding.None"/>. | |||||
| /// </summary> | |||||
| public TensorShape(params object[] dims) | |||||
| { | |||||
| var intdims = new int[dims.Length]; | |||||
| for (int i = 0; i < dims.Length; i++) | |||||
| { | |||||
| var val = dims[i]; | |||||
| if (val == Binding.None) | |||||
| intdims[i] = -1; | |||||
| else | |||||
| intdims[i] = Converts.ToInt32(val); | |||||
| } | |||||
| switch (dims.Length) | |||||
| { | |||||
| case 0: shape = new Shape(new int[0]); break; | |||||
| case 1: shape = Shape.Vector((int) intdims[0]); break; | |||||
| case 2: shape = Shape.Matrix(intdims[0], intdims[1]); break; | |||||
| default: shape = new Shape(intdims); break; | |||||
| } | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||