diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index b97ba1cd..a7821c93 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -142,18 +142,32 @@ namespace Tensorflow return new EagerTensor(val, ctx.device_name); case int[,] val: return new EagerTensor(val, ctx.device_name); + case int[,,] val: + return new EagerTensor(val, ctx.device_name); case long val: return new EagerTensor(val, ctx.device_name); + case long[] val: + return new EagerTensor(val, ctx.device_name); + case long[,] val: + return new EagerTensor(val, ctx.device_name); + case long[,,] val: + return new EagerTensor(val, ctx.device_name); case float val: return new EagerTensor(val, ctx.device_name); + case float[] val: + return new EagerTensor(val, ctx.device_name); case float[,] val: return new EagerTensor(val, ctx.device_name); - case double val: + case float[,,] val: return new EagerTensor(val, ctx.device_name); - case float[] val: + case double val: return new EagerTensor(val, ctx.device_name); case double[] val: return new EagerTensor(val, ctx.device_name); + case double[,] val: + return new EagerTensor(val, ctx.device_name); + case double[,,] val: + return new EagerTensor(val, ctx.device_name); default: throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}"); } diff --git a/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs b/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs index f30321a1..39efc8e6 100644 --- a/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs +++ b/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs @@ -26,10 +26,18 @@ namespace Tensorflow.UnitTest.TF_API [TestMethod] public void InitTensorTest() { - var a = tf.constant(new NDArray(new[, ,] { { { 1 }, { 2 }, { 3 } }, { { 4 }, { 5 }, { 6 } } })); - var b = tf.constant(new[, ,] { { { 1 }, { 2 }, { 3 } }, { { 4 }, { 5 }, { 6 } } }); - //Test Result : a is OK , and b is error . + var a = tf.constant(np.array(new[, ,] + { + { { 1 }, { 2 }, { 3 } }, + { { 4 }, { 5 }, { 6 } } + })); Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape)); + + var b = tf.constant(new[, ,] + { + { { 1 }, { 2 }, { 3 } }, + { { 4 }, { 5 }, { 6 } } + }); Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape)); } @@ -46,7 +54,7 @@ namespace Tensorflow.UnitTest.TF_API [TestMethod] public void ConcatDoubleTest() - {//double type has some error + { var a = tf.constant(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } }); var b = tf.constant(new[,] { { 5.0, 6.0 }, { 7.0, 8.0 } }); var c = tf.constant(new[,] { { 9.0, 10.0 }, { 11.0, 12.0 } });