From dded90a56486fe2676ac18ed7ac0779002607ce0 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 11 Jul 2021 10:50:28 -0500 Subject: [PATCH] fix string tensor for non ascii char. --- src/TensorFlowNET.Core/APIs/tf.linalg.cs | 2 +- .../Eager/EagerTensor.Creation.cs | 5 +- src/TensorFlowNET.Core/Numpy/NDArray.cs | 4 +- src/TensorFlowNET.Core/Numpy/Shape.cs | 8 +- .../Operations/linalg_ops.cs | 2 +- .../Tensors/Tensor.Creation.cs | 131 +++++++++++------- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 3 + src/TensorFlowNET.Core/Tensors/constant_op.cs | 3 +- .../ManagedAPI/LinalgTest.cs | 6 +- 9 files changed, 102 insertions(+), 62 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs index beb3122c..6687839c 100644 --- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs +++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs @@ -28,7 +28,7 @@ namespace Tensorflow public Tensor eye(int num_rows, int num_columns = -1, TensorShape batch_shape = null, - TF_DataType dtype = TF_DataType.TF_FLOAT, + TF_DataType dtype = TF_DataType.TF_DOUBLE, string name = null) => ops.eye(num_rows, num_columns: num_columns, batch_shape: batch_shape, dtype: dtype, name: name); diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index d1789aae..a512fba9 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -50,7 +50,10 @@ namespace Tensorflow.Eager public EagerTensor(Shape shape, TF_DataType dtype) : base(shape, dtype) => NewEagerTensorHandle(_handle); - internal unsafe EagerTensor(Array array, Shape shape) : base(array, shape) + public EagerTensor(Array array, Shape shape) : base(array, shape) + => NewEagerTensorHandle(_handle); + + public EagerTensor(byte[] bytes, TF_DataType dtype) : base(bytes, dtype) => NewEagerTensorHandle(_handle); void NewEagerTensorHandle(IntPtr h) diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs index 05cc420b..1cfc9b4e 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs @@ -39,8 +39,8 @@ namespace Tensorflow.NumPy public bool HasNext() => throw new NotImplementedException(""); public T MoveNext() => throw new NotImplementedException(""); public NDArray reshape(Shape newshape) => new NDArray(_tensor, newshape); - public NDArray astype(Type type) => throw new NotImplementedException(""); - public NDArray astype(TF_DataType type) => throw new NotImplementedException(""); + public NDArray astype(Type type) => new NDArray(math_ops.cast(_tensor, type.as_tf_dtype())); + public NDArray astype(TF_DataType dtype) => new NDArray(math_ops.cast(_tensor, dtype)); public NDArray ravel() => throw new NotImplementedException(""); public void shuffle(NDArray nd) => throw new NotImplementedException(""); public Array ToMuliDimArray() => throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs index ee8d981d..961955dd 100644 --- a/src/TensorFlowNET.Core/Numpy/Shape.cs +++ b/src/TensorFlowNET.Core/Numpy/Shape.cs @@ -62,11 +62,17 @@ namespace Tensorflow { get { + // scalar + if (ndim == 0) + return 1; + var computed = 1L; for (int i = 0; i < _dims.Length; i++) { var val = _dims[i]; - if (val <= 0) + if (val == 0) + return 0; + else if (val < 0) continue; computed *= val; } diff --git a/src/TensorFlowNET.Core/Operations/linalg_ops.cs b/src/TensorFlowNET.Core/Operations/linalg_ops.cs index 89ff28ef..d383830c 100644 --- a/src/TensorFlowNET.Core/Operations/linalg_ops.cs +++ b/src/TensorFlowNET.Core/Operations/linalg_ops.cs @@ -8,7 +8,7 @@ namespace Tensorflow public Tensor eye(int num_rows, int num_columns = -1, TensorShape batch_shape = null, - TF_DataType dtype = TF_DataType.TF_FLOAT, + TF_DataType dtype = TF_DataType.TF_DOUBLE, string name = null) { return tf_with(ops.name_scope(name, default_name: "eye", new { num_rows, num_columns, batch_shape }), scope => diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 991c6a51..49847856 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -67,16 +67,17 @@ namespace Tensorflow } #region scala - public Tensor(bool value) => InitTensor(value); - public Tensor(byte value) => InitTensor(value); - public Tensor(sbyte value) => InitTensor(value); - public Tensor(short value) => InitTensor(value); - public Tensor(int value) => InitTensor(value); - public Tensor(uint value) => InitTensor(value); - public Tensor(long value) => InitTensor(value); - public Tensor(ulong value) => InitTensor(value); - public Tensor(float value) => InitTensor(value); - public Tensor(double value) => InitTensor(value); + public Tensor(bool value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(byte value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(sbyte value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(short value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(int value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(uint value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(long value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(ulong value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(float value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(double value) => InitTensor(new[] { value }, Shape.Scalar); + public Tensor(string value) => InitTensor(new[] { value }, TensorShape.Scalar); #endregion #region 1d array @@ -94,6 +95,10 @@ namespace Tensorflow public Tensor(Complex[] data, Shape? shape = null) => InitTensor(data, shape); #endregion + public Tensor(Shape shape, TF_DataType dtype) => InitTensor(shape, dtype); + public Tensor(Array array, Shape? shape = null) => InitTensor(array, shape); + public Tensor(byte[] bytes, TF_DataType dtype) => InitTensor(bytes, dtype); + public Tensor(Operation op, int value_index, TF_DataType dtype) { _op = op; @@ -103,65 +108,87 @@ namespace Tensorflow isCreatedInGraphMode = !tf.executing_eagerly(); } - internal Tensor(Shape shape, TF_DataType dtype) => InitTensor(shape, dtype); - internal Tensor(Array array, Shape? shape = null) => InitTensor(array, shape); - internal Tensor(string value) => InitTensor(value); - - protected unsafe void InitTensor(T data) where T : unmanaged - { - _handle = TF_NewTensor(data); - isCreatedInGraphMode = !tf.executing_eagerly(); - } - protected unsafe void InitTensor(Shape shape, TF_DataType dtype) { _handle = TF_NewTensor(shape, dtype, null); isCreatedInGraphMode = !tf.executing_eagerly(); } - protected void InitTensor(string value) + protected unsafe void InitTensor(byte[] bytes, TF_DataType dtype) { - _handle = StringTensor(new[] { value }, TensorShape.Scalar); + if (dtype == TF_DataType.TF_STRING) + _handle = StringTensor(new byte[][] { bytes }, TensorShape.Scalar); + else + throw new NotImplementedException(""); isCreatedInGraphMode = !tf.executing_eagerly(); } protected unsafe void InitTensor(Array array, Shape? shape = null) { + isCreatedInGraphMode = !tf.executing_eagerly(); + shape = shape ?? array.GetShape(); - var dtype = array.GetType().GetElementType().as_tf_dtype(); + var dtype = array.GetDataType(); - switch (array) + if (shape.size == 0) { - case bool[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case bool[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case bool[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case bool[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case byte[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case byte[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case byte[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case byte[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case int[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case int[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case int[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case int[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case long[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case long[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case long[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case long[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case float[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case float[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case float[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case float[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case double[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case double[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case double[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case double[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; - case string[] val: _handle = StringTensor(val, shape); break; - default: - throw new NotImplementedException(""); + _handle = TF_NewTensor(shape, dtype, null); + return; } - isCreatedInGraphMode = !tf.executing_eagerly(); + _handle = array switch + { + bool[] val => InitTensor(val, shape, dtype), + bool[,] val => InitTensor(val, shape, dtype), + bool[,,] val => InitTensor(val, shape, dtype), + bool[,,,] val => InitTensor(val, shape, dtype), + byte[] val => InitTensor(val, shape, dtype), + byte[,] val => InitTensor(val, shape, dtype), + byte[,,] val => InitTensor(val, shape, dtype), + byte[,,,] val => InitTensor(val, shape, dtype), + int[] val => InitTensor(val, shape, dtype), + int[,] val => InitTensor(val, shape, dtype), + int[,,] val => InitTensor(val, shape, dtype), + int[,,,] val => InitTensor(val, shape, dtype), + long[] val => InitTensor(val, shape, dtype), + long[,] val => InitTensor(val, shape, dtype), + long[,,] val => InitTensor(val, shape, dtype), + long[,,,] val => InitTensor(val, shape, dtype), + float[] val => InitTensor(val, shape, dtype), + float[,] val => InitTensor(val, shape, dtype), + float[,,] val => InitTensor(val, shape, dtype), + float[,,,] val => InitTensor(val, shape, dtype), + double[] val => InitTensor(val, shape, dtype), + double[,] val => InitTensor(val, shape, dtype), + double[,,] val => InitTensor(val, shape, dtype), + double[,,,] val => InitTensor(val, shape, dtype), + string[] val => StringTensor(val, shape), + _ => throw new NotImplementedException("") + }; + } + + unsafe IntPtr InitTensor(T[] array, Shape shape, TF_DataType dtype) where T : unmanaged + { + fixed (T* addr = &array[0]) + return TF_NewTensor(shape, dtype, addr); + } + + unsafe IntPtr InitTensor(T[,] array, Shape shape, TF_DataType dtype) where T : unmanaged + { + fixed (T* addr = &array[0, 0]) + return TF_NewTensor(shape, dtype, addr); + } + + unsafe IntPtr InitTensor(T[,,] array, Shape shape, TF_DataType dtype) where T : unmanaged + { + fixed (T* addr = &array[0, 0, 0]) + return TF_NewTensor(shape, dtype, addr); + } + + unsafe IntPtr InitTensor(T[,,,] array, Shape shape, TF_DataType dtype) where T : unmanaged + { + fixed (T* addr = &array[0, 0, 0, 0]) + return TF_NewTensor(shape, dtype, addr); } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 0f86065d..4fe422c3 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -160,6 +160,9 @@ namespace Tensorflow { if (dims != null && shape2.dims != null) { + if (dims.Contains(-1) || shape2.dims.Contains(-1)) + return true; + if (shape.size != (ulong)shape2.size) return false; } diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 185fd8a5..2d38f9b7 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -152,8 +152,9 @@ namespace Tensorflow value = nd.astype(dtype.as_system_dtype()); } + // non ascii char if (dtype == TF_DataType.TF_STRING && value is byte[] bytes) - return new EagerTensor(bytes, ctx.DeviceName, TF_DataType.TF_STRING); + return new EagerTensor(bytes, TF_DataType.TF_STRING); switch (value) { diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs index 4d1ea26f..73c6415b 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs @@ -13,9 +13,9 @@ namespace TensorFlowNET.UnitTest.ManagedAPI Assert.AreEqual((3, 3), tensor.TensorShape); - Assert.AreEqual(0.0f, (float)tensor[2, 0]); - Assert.AreEqual(0.0f, (float)tensor[2, 1]); - Assert.AreEqual(1.0f, (float)tensor[2, 2]); + Assert.AreEqual(0.0f, (double)tensor[2, 0]); + Assert.AreEqual(0.0f, (double)tensor[2, 1]); + Assert.AreEqual(1.0f, (double)tensor[2, 2]); } } }