| @@ -57,10 +57,36 @@ namespace Tensorflow | |||||
| public TensorShape(params object[] dims) | public TensorShape(params object[] dims) | ||||
| { | { | ||||
| var intdims = new int[dims.Length]; | |||||
| for (int i = 0; i < dims.Length; i++) | |||||
| Array arr; | |||||
| if (dims.Length == 1) | |||||
| { | |||||
| switch (dims[0]) | |||||
| { | |||||
| case int[] intarr: | |||||
| arr = intarr; | |||||
| break; | |||||
| case long[] longarr: | |||||
| arr = longarr; | |||||
| break; | |||||
| case object[] objarr: | |||||
| arr = objarr; | |||||
| break; | |||||
| case int _: | |||||
| case long _: | |||||
| arr = dims; | |||||
| break; | |||||
| default: | |||||
| Binding.print(dims); | |||||
| throw new ArgumentException(nameof(dims)); | |||||
| } | |||||
| } else | |||||
| arr = dims; | |||||
| var intdims = new int[arr.Length]; | |||||
| for (int i = 0; i < arr.Length; i++) | |||||
| { | { | ||||
| var val = dims[i]; | |||||
| var val = arr.GetValue(i); | |||||
| if (val == Binding.None) | if (val == Binding.None) | ||||
| intdims[i] = -1; | intdims[i] = -1; | ||||
| else | else | ||||
| @@ -69,10 +95,18 @@ namespace Tensorflow | |||||
| switch (dims.Length) | switch (dims.Length) | ||||
| { | { | ||||
| case 0: shape = new Shape(new int[0]); break; | |||||
| case 1: shape = Shape.Vector((int) intdims[0]); break; | |||||
| case 2: shape = Shape.Matrix(intdims[0], intdims[1]); break; | |||||
| default: shape = new Shape(intdims); break; | |||||
| case 0: | |||||
| shape = new Shape(new int[0]); | |||||
| break; | |||||
| case 1: | |||||
| shape = Shape.Vector((int) intdims[0]); | |||||
| break; | |||||
| case 2: | |||||
| shape = Shape.Matrix(intdims[0], intdims[1]); | |||||
| break; | |||||
| default: | |||||
| shape = new Shape(intdims); | |||||
| break; | |||||
| } | } | ||||
| } | } | ||||