| @@ -21,6 +21,7 @@ using System.Numerics; | |||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using NumSharp.Backends; | |||||
| using NumSharp.Backends.Unmanaged; | using NumSharp.Backends.Unmanaged; | ||||
| using static Tensorflow.c_api; | using static Tensorflow.c_api; | ||||
| @@ -477,7 +478,7 @@ namespace Tensorflow | |||||
| IntPtr tensor = c_api.TF_TensorData(handle); | IntPtr tensor = c_api.TF_TensorData(handle); | ||||
| Marshal.WriteInt64(tensor, 0); | Marshal.WriteInt64(tensor, 0); | ||||
| fixed (byte* src = &buffer[0]) | |||||
| fixed (byte* src = buffer) | |||||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | ||||
| _handle = handle; | _handle = handle; | ||||
| status.Check(true); | status.Check(true); | ||||
| @@ -486,24 +487,45 @@ namespace Tensorflow | |||||
| public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | ||||
| { | { | ||||
| // todo: handle nd of type "String" here too | // todo: handle nd of type "String" here too | ||||
| if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") | |||||
| if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) | |||||
| { | { | ||||
| var buffer = nd.ToArray<byte>(); | |||||
| var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | |||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | |||||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||||
| Marshal.WriteInt64(tensor, 0); | |||||
| if (nd.Unsafe.Storage.Shape.IsContiguous) | |||||
| { | |||||
| var bytesLength = (UIntPtr)nd.size; | |||||
| var size = c_api.TF_StringEncodedSize(bytesLength); | |||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); | |||||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||||
| Marshal.WriteInt64(tensor, 0); | |||||
| var status = new Status(); | |||||
| c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||||
| status.Check(true); | |||||
| _handle = handle; | |||||
| IsMemoryOwner = false; | |||||
| } | |||||
| else | |||||
| { | |||||
| var buffer = nd.ToArray<byte>(); | |||||
| var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); | |||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); | |||||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||||
| Marshal.WriteInt64(tensor, 0); | |||||
| var status = new Status(); | |||||
| fixed (byte* src = buffer) | |||||
| c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||||
| status.Check(true); | |||||
| _handle = handle; | |||||
| IsMemoryOwner = false; | |||||
| } | |||||
| var status = new Status(); | |||||
| fixed (byte* src = &buffer[0]) | |||||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | |||||
| status.Check(true); | |||||
| _handle=handle; | |||||
| IsMemoryOwner = false; | |||||
| return; | return; | ||||
| } | } | ||||
| _handle = CreateTensorFromNDArray(nd, tensorDType); | _handle = CreateTensorFromNDArray(nd, tensorDType); | ||||
| IsMemoryOwner = true; | IsMemoryOwner = true; | ||||
| } | } | ||||