From cc54fe19bf9cb109c66e95c80bef240b3c3ea782 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Sun, 8 Sep 2019 17:01:37 +0300 Subject: [PATCH] TensorShape: Fixed construction when passing int[] or long[] --- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 48 ++++++++++++++++--- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 4f2accc3..734e9bc5 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -57,10 +57,36 @@ namespace Tensorflow public TensorShape(params object[] dims) { - var intdims = new int[dims.Length]; - for (int i = 0; i < dims.Length; i++) + Array arr; + + if (dims.Length == 1) + { + switch (dims[0]) + { + case int[] intarr: + arr = intarr; + break; + case long[] longarr: + arr = longarr; + break; + case object[] objarr: + arr = objarr; + break; + case int _: + case long _: + arr = dims; + break; + default: + Binding.print(dims); + throw new ArgumentException(nameof(dims)); + } + } else + arr = dims; + + var intdims = new int[arr.Length]; + for (int i = 0; i < arr.Length; i++) { - var val = dims[i]; + var val = arr.GetValue(i); if (val == Binding.None) intdims[i] = -1; else @@ -69,10 +95,18 @@ namespace Tensorflow switch (dims.Length) { - case 0: shape = new Shape(new int[0]); break; - case 1: shape = Shape.Vector((int) intdims[0]); break; - case 2: shape = Shape.Matrix(intdims[0], intdims[1]); break; - default: shape = new Shape(intdims); break; + case 0: + shape = new Shape(new int[0]); + break; + case 1: + shape = Shape.Vector((int) intdims[0]); + break; + case 2: + shape = Shape.Matrix(intdims[0], intdims[1]); + break; + default: + shape = new Shape(intdims); + break; } }