diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index 95f808a8..75677e22 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -25,7 +25,13 @@ namespace Tensorflow.Eager EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status.Handle); Resolve(); } - + + public EagerTensor(byte[] value, string device_name, TF_DataType dtype) : base(value, dType: dtype) + { + EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status.Handle); + Resolve(); + } + public EagerTensor(NDArray value, string device_name) : base(value) { EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status.Handle); diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs index 2d247750..f5060126 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs @@ -1,6 +1,7 @@ using NumSharp; using System; using System.Collections.Generic; +using System.Linq; using System.Text; using static Tensorflow.Binding; @@ -49,7 +50,8 @@ namespace Tensorflow.Eager switch (dtype) { case TF_DataType.TF_STRING: - return $"b'{(string)nd}'"; + return string.Join(string.Empty, nd.ToArray() + .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())); case TF_DataType.TF_BOOL: return (nd.GetByte(0) > 0).ToString(); case TF_DataType.TF_RESOURCE: diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index d1a75338..9507f0c0 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -452,14 +452,14 @@ namespace Tensorflow public unsafe Tensor(string str) { var buffer = Encoding.UTF8.GetBytes(str); - var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); - var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + sizeof(ulong))); + var size = c_api.TF_StringEncodedSize((ulong)buffer.Length); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)(size + sizeof(ulong))); AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); fixed (byte* src = buffer) - c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(long)), size, tf.status.Handle); + c_api.TF_StringEncode(src, (ulong)buffer.Length, (sbyte*)(tensor + sizeof(long)), size, tf.status.Handle); _handle = handle; tf.status.Check(true); } @@ -474,7 +474,7 @@ namespace Tensorflow { if (nd.Unsafe.Storage.Shape.IsContiguous) { - var bytesLength = (UIntPtr) nd.size; + var bytesLength = (ulong)nd.size; var size = c_api.TF_StringEncodedSize(bytesLength); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); AllocationType = AllocationType.Tensorflow; @@ -482,13 +482,13 @@ namespace Tensorflow IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); - c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, tf.status.Handle); + c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(long)), size, tf.status.Handle); tf.status.Check(true); _handle = handle; } else { var buffer = nd.ToArray(); - var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); + var size = c_api.TF_StringEncodedSize((ulong)buffer.Length); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); AllocationType = AllocationType.Tensorflow; @@ -496,7 +496,7 @@ namespace Tensorflow Marshal.WriteInt64(tensor, 0); fixed (byte* src = buffer) - c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, tf.status.Handle); + c_api.TF_StringEncode(src, (ulong)buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, tf.status.Handle); tf.status.Check(true); _handle = handle; @@ -538,7 +538,7 @@ namespace Tensorflow int size = 0; foreach (var b in buffer) { - size += (int) TF_StringEncodedSize((UIntPtr) b.Length); + size += (int)TF_StringEncodedSize((ulong)b.Length); } int totalSize = size + buffer.Length * 8; @@ -557,7 +557,7 @@ namespace Tensorflow { fixed (byte* src = &buffer[i][0]) { - var written = TF_StringEncode(src, (UIntPtr) buffer[i].Length, (sbyte*) dst, (UIntPtr) (dstLimit.ToInt64() - dst.ToInt64()), status.Handle); + var written = TF_StringEncode(src, (ulong)buffer[i].Length, (sbyte*)dst, (ulong)(dstLimit.ToInt64() - dst.ToInt64()), status.Handle); status.Check(true); pOffset += 8; dst += (int) written; @@ -592,25 +592,27 @@ namespace Tensorflow /// [MethodImpl(MethodImplOptions.AggressiveInlining)] [SuppressMessage("ReSharper", "LocalVariableHidesMember")] - protected unsafe IntPtr CreateTensorFromArray(TF_DataType dt, long[] shape, Array data, int element_size) + protected IntPtr CreateTensorFromArray(TF_DataType dt, long[] shape, Array data, int element_size) { if (dt == TF_DataType.TF_STRING && data is byte[] buffer) - { - var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); - var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); - AllocationType = AllocationType.Tensorflow; + return CreateStringTensorFromBytes(buffer, shape); + return CreateTensorFromArray(dt, shape, data, 0, data.Length, element_size); + } - IntPtr tensor = c_api.TF_TensorData(handle); - Marshal.WriteInt64(tensor, 0); + protected unsafe IntPtr CreateStringTensorFromBytes(byte[] buffer, long[] shape) + { + var size = c_api.TF_StringEncodedSize((ulong)buffer.Length); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, 0, size + sizeof(long)); + AllocationType = AllocationType.Tensorflow; - fixed (byte* src = buffer) - c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(long)), size, tf.status.Handle); + IntPtr tensor = c_api.TF_TensorData(handle); + Marshal.WriteInt64(tensor, 0); - tf.status.Check(true); - return handle; - } + fixed (byte* src = buffer) + c_api.TF_StringEncode(src, (ulong)buffer.Length, (sbyte*)(tensor + sizeof(long)), size, tf.status.Handle); - return CreateTensorFromArray(dt, shape, data, 0, data.Length, element_size); + tf.status.Check(true); + return handle; } /// diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs index 579ff566..2a8be7a3 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -159,7 +159,7 @@ namespace Tensorflow switch (dtype) { case TF_DataType.TF_STRING: - return (NDArray)StringData()[0]; + return np.array(StringBytes()[0]); case TF_DataType.TF_INT32: storage = new UnmanagedStorage(NPTypeCode.Int32); break; diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index cf3c00c8..22b042d1 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -170,7 +170,7 @@ namespace Tensorflow /// size_t /// [DllImport(TensorFlowLibName)] - public static extern UIntPtr TF_StringEncodedSize(UIntPtr len); + public static extern ulong TF_StringEncodedSize(ulong len); /// /// Encode the string `src` (`src_len` bytes long) into `dst` in the format @@ -185,7 +185,7 @@ namespace Tensorflow /// TF_Status* /// On success returns the size in bytes of the encoded string. [DllImport(TensorFlowLibName)] - public static extern unsafe ulong TF_StringEncode(byte* src, UIntPtr src_len, sbyte* dst, UIntPtr dst_len, SafeStatusHandle status); + public static extern unsafe ulong TF_StringEncode(byte* src, ulong src_len, sbyte* dst, ulong dst_len, SafeStatusHandle status); [DllImport(TensorFlowLibName)] public static extern unsafe ulong TF_StringEncode(IntPtr src, ulong src_len, IntPtr dst, ulong dst_len, SafeStatusHandle status); diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 6fd3b882..369e0d26 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -130,6 +130,11 @@ namespace Tensorflow } } + if(dtype == TF_DataType.TF_STRING && value is byte[] bytes) + { + return new EagerTensor(bytes, ctx.device_name, TF_DataType.TF_STRING); + } + switch (value) { case EagerTensor val: diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs index cb3ea87a..bc6679a9 100644 --- a/test/TensorFlowNET.UnitTest/ConstantTest.cs +++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs @@ -165,10 +165,10 @@ namespace TensorFlowNET.UnitTest.Basics { string str = "Hello, TensorFlow.NET!"; var handle = Marshal.StringToHGlobalAnsi(str); - ulong dst_len = (ulong)c_api.TF_StringEncodedSize((UIntPtr)str.Length); + var dst_len = c_api.TF_StringEncodedSize((ulong)str.Length); Assert.AreEqual(dst_len, (ulong)23); IntPtr dst = Marshal.AllocHGlobal((int)dst_len); - ulong encoded_len = c_api.TF_StringEncode(handle, (ulong)str.Length, dst, dst_len, status.Handle); + var encoded_len = c_api.TF_StringEncode(handle, (ulong)str.Length, dst, dst_len, status.Handle); Assert.AreEqual((ulong)23, encoded_len); Assert.AreEqual(status.Code, TF_Code.TF_OK); string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); diff --git a/test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs b/test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs index 3049505b..314e57fb 100644 --- a/test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs +++ b/test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs @@ -9,6 +9,14 @@ namespace Tensorflow.UnitTest.TF_API [TestClass] public class StringsApiTest { + [TestMethod] + public void StringFromBytes() + { + var jpg = tf.constant(new byte[] { 0x41, 0xff, 0xd8, 0xff }, tf.@string); + var strings = jpg.ToString(); + Assert.AreEqual(strings, @"tf.Tensor: shape=(), dtype=string, numpy=A\xff\xd8\xff"); + } + [TestMethod] public void StringEqual() {