Browse Source

Improve string display from byte array.

pull/579/head
Esther2013 5 years ago
parent
commit
caf9e9264a
8 changed files with 52 additions and 29 deletions
  1. +7
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  2. +3
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.cs
  3. +24
    -22
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
  5. +2
    -2
      src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
  6. +5
    -0
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  7. +2
    -2
      test/TensorFlowNET.UnitTest/ConstantTest.cs
  8. +8
    -0
      test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs

+ 7
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -25,7 +25,13 @@ namespace Tensorflow.Eager
EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status.Handle); EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status.Handle);
Resolve(); Resolve();
} }

public EagerTensor(byte[] value, string device_name, TF_DataType dtype) : base(value, dType: dtype)
{
EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status.Handle);
Resolve();
}

public EagerTensor(NDArray value, string device_name) : base(value) public EagerTensor(NDArray value, string device_name) : base(value)
{ {
EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status.Handle); EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status.Handle);


+ 3
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.cs View File

@@ -1,6 +1,7 @@
using NumSharp; using NumSharp;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text; using System.Text;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -49,7 +50,8 @@ namespace Tensorflow.Eager
switch (dtype) switch (dtype)
{ {
case TF_DataType.TF_STRING: case TF_DataType.TF_STRING:
return $"b'{(string)nd}'";
return string.Join(string.Empty, nd.ToArray<byte>()
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString()));
case TF_DataType.TF_BOOL: case TF_DataType.TF_BOOL:
return (nd.GetByte(0) > 0).ToString(); return (nd.GetByte(0) > 0).ToString();
case TF_DataType.TF_RESOURCE: case TF_DataType.TF_RESOURCE:


+ 24
- 22
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -452,14 +452,14 @@ namespace Tensorflow
public unsafe Tensor(string str) public unsafe Tensor(string str)
{ {
var buffer = Encoding.UTF8.GetBytes(str); var buffer = Encoding.UTF8.GetBytes(str);
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + sizeof(ulong)));
var size = c_api.TF_StringEncodedSize((ulong)buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)(size + sizeof(ulong)));
AllocationType = AllocationType.Tensorflow; AllocationType = AllocationType.Tensorflow;


IntPtr tensor = c_api.TF_TensorData(handle); IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0); Marshal.WriteInt64(tensor, 0);
fixed (byte* src = buffer) fixed (byte* src = buffer)
c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(long)), size, tf.status.Handle);
c_api.TF_StringEncode(src, (ulong)buffer.Length, (sbyte*)(tensor + sizeof(long)), size, tf.status.Handle);
_handle = handle; _handle = handle;
tf.status.Check(true); tf.status.Check(true);
} }
@@ -474,7 +474,7 @@ namespace Tensorflow
{ {
if (nd.Unsafe.Storage.Shape.IsContiguous) if (nd.Unsafe.Storage.Shape.IsContiguous)
{ {
var bytesLength = (UIntPtr) nd.size;
var bytesLength = (ulong)nd.size;
var size = c_api.TF_StringEncodedSize(bytesLength); var size = c_api.TF_StringEncodedSize(bytesLength);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
AllocationType = AllocationType.Tensorflow; AllocationType = AllocationType.Tensorflow;
@@ -482,13 +482,13 @@ namespace Tensorflow
IntPtr tensor = c_api.TF_TensorData(handle); IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0); Marshal.WriteInt64(tensor, 0);


c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, tf.status.Handle);
c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(long)), size, tf.status.Handle);
tf.status.Check(true); tf.status.Check(true);
_handle = handle; _handle = handle;
} else } else
{ {
var buffer = nd.ToArray<byte>(); var buffer = nd.ToArray<byte>();
var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length);
var size = c_api.TF_StringEncodedSize((ulong)buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
AllocationType = AllocationType.Tensorflow; AllocationType = AllocationType.Tensorflow;


@@ -496,7 +496,7 @@ namespace Tensorflow
Marshal.WriteInt64(tensor, 0); Marshal.WriteInt64(tensor, 0);


fixed (byte* src = buffer) fixed (byte* src = buffer)
c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, tf.status.Handle);
c_api.TF_StringEncode(src, (ulong)buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, tf.status.Handle);


tf.status.Check(true); tf.status.Check(true);
_handle = handle; _handle = handle;
@@ -538,7 +538,7 @@ namespace Tensorflow
int size = 0; int size = 0;
foreach (var b in buffer) foreach (var b in buffer)
{ {
size += (int) TF_StringEncodedSize((UIntPtr) b.Length);
size += (int)TF_StringEncodedSize((ulong)b.Length);
} }


int totalSize = size + buffer.Length * 8; int totalSize = size + buffer.Length * 8;
@@ -557,7 +557,7 @@ namespace Tensorflow
{ {
fixed (byte* src = &buffer[i][0]) fixed (byte* src = &buffer[i][0])
{ {
var written = TF_StringEncode(src, (UIntPtr) buffer[i].Length, (sbyte*) dst, (UIntPtr) (dstLimit.ToInt64() - dst.ToInt64()), status.Handle);
var written = TF_StringEncode(src, (ulong)buffer[i].Length, (sbyte*)dst, (ulong)(dstLimit.ToInt64() - dst.ToInt64()), status.Handle);
status.Check(true); status.Check(true);
pOffset += 8; pOffset += 8;
dst += (int) written; dst += (int) written;
@@ -592,25 +592,27 @@ namespace Tensorflow
/// </remarks> /// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
[SuppressMessage("ReSharper", "LocalVariableHidesMember")] [SuppressMessage("ReSharper", "LocalVariableHidesMember")]
protected unsafe IntPtr CreateTensorFromArray(TF_DataType dt, long[] shape, Array data, int element_size)
protected IntPtr CreateTensorFromArray(TF_DataType dt, long[] shape, Array data, int element_size)
{ {
if (dt == TF_DataType.TF_STRING && data is byte[] buffer) if (dt == TF_DataType.TF_STRING && data is byte[] buffer)
{
var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
AllocationType = AllocationType.Tensorflow;
return CreateStringTensorFromBytes(buffer, shape);
return CreateTensorFromArray(dt, shape, data, 0, data.Length, element_size);
}


IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0);
protected unsafe IntPtr CreateStringTensorFromBytes(byte[] buffer, long[] shape)
{
var size = c_api.TF_StringEncodedSize((ulong)buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, 0, size + sizeof(long));
AllocationType = AllocationType.Tensorflow;


fixed (byte* src = buffer)
c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(long)), size, tf.status.Handle);
IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0);


tf.status.Check(true);
return handle;
}
fixed (byte* src = buffer)
c_api.TF_StringEncode(src, (ulong)buffer.Length, (sbyte*)(tensor + sizeof(long)), size, tf.status.Handle);


return CreateTensorFromArray(dt, shape, data, 0, data.Length, element_size);
tf.status.Check(true);
return handle;
} }


/// <summary> /// <summary>


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Value.cs View File

@@ -159,7 +159,7 @@ namespace Tensorflow
switch (dtype) switch (dtype)
{ {
case TF_DataType.TF_STRING: case TF_DataType.TF_STRING:
return (NDArray)StringData()[0];
return np.array(StringBytes()[0]);
case TF_DataType.TF_INT32: case TF_DataType.TF_INT32:
storage = new UnmanagedStorage(NPTypeCode.Int32); storage = new UnmanagedStorage(NPTypeCode.Int32);
break; break;


+ 2
- 2
src/TensorFlowNET.Core/Tensors/c_api.tensor.cs View File

@@ -170,7 +170,7 @@ namespace Tensorflow
/// <param name="len">size_t</param> /// <param name="len">size_t</param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern UIntPtr TF_StringEncodedSize(UIntPtr len);
public static extern ulong TF_StringEncodedSize(ulong len);


/// <summary> /// <summary>
/// Encode the string `src` (`src_len` bytes long) into `dst` in the format /// Encode the string `src` (`src_len` bytes long) into `dst` in the format
@@ -185,7 +185,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns>On success returns the size in bytes of the encoded string.</returns> /// <returns>On success returns the size in bytes of the encoded string.</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe ulong TF_StringEncode(byte* src, UIntPtr src_len, sbyte* dst, UIntPtr dst_len, SafeStatusHandle status);
public static extern unsafe ulong TF_StringEncode(byte* src, ulong src_len, sbyte* dst, ulong dst_len, SafeStatusHandle status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe ulong TF_StringEncode(IntPtr src, ulong src_len, IntPtr dst, ulong dst_len, SafeStatusHandle status); public static extern unsafe ulong TF_StringEncode(IntPtr src, ulong src_len, IntPtr dst, ulong dst_len, SafeStatusHandle status);


+ 5
- 0
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -130,6 +130,11 @@ namespace Tensorflow
} }
} }


if(dtype == TF_DataType.TF_STRING && value is byte[] bytes)
{
return new EagerTensor(bytes, ctx.device_name, TF_DataType.TF_STRING);
}

switch (value) switch (value)
{ {
case EagerTensor val: case EagerTensor val:


+ 2
- 2
test/TensorFlowNET.UnitTest/ConstantTest.cs View File

@@ -165,10 +165,10 @@ namespace TensorFlowNET.UnitTest.Basics
{ {
string str = "Hello, TensorFlow.NET!"; string str = "Hello, TensorFlow.NET!";
var handle = Marshal.StringToHGlobalAnsi(str); var handle = Marshal.StringToHGlobalAnsi(str);
ulong dst_len = (ulong)c_api.TF_StringEncodedSize((UIntPtr)str.Length);
var dst_len = c_api.TF_StringEncodedSize((ulong)str.Length);
Assert.AreEqual(dst_len, (ulong)23); Assert.AreEqual(dst_len, (ulong)23);
IntPtr dst = Marshal.AllocHGlobal((int)dst_len); IntPtr dst = Marshal.AllocHGlobal((int)dst_len);
ulong encoded_len = c_api.TF_StringEncode(handle, (ulong)str.Length, dst, dst_len, status.Handle);
var encoded_len = c_api.TF_StringEncode(handle, (ulong)str.Length, dst, dst_len, status.Handle);
Assert.AreEqual((ulong)23, encoded_len); Assert.AreEqual((ulong)23, encoded_len);
Assert.AreEqual(status.Code, TF_Code.TF_OK); Assert.AreEqual(status.Code, TF_Code.TF_OK);
string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte));


+ 8
- 0
test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs View File

@@ -9,6 +9,14 @@ namespace Tensorflow.UnitTest.TF_API
[TestClass] [TestClass]
public class StringsApiTest public class StringsApiTest
{ {
[TestMethod]
public void StringFromBytes()
{
var jpg = tf.constant(new byte[] { 0x41, 0xff, 0xd8, 0xff }, tf.@string);
var strings = jpg.ToString();
Assert.AreEqual(strings, @"tf.Tensor: shape=(), dtype=string, numpy=A\xff\xd8\xff");
}

[TestMethod] [TestMethod]
public void StringEqual() public void StringEqual()
{ {


Loading…
Cancel
Save