Browse Source

Tensor.Creation: Revamp of CreateTensorFromArray to properly handle TF_NewTensor

- Added other missing AllocationType setting in different cases.
tags/v0.12
Eli Belash 6 years ago
parent
commit
be996facb8
1 changed files with 27 additions and 5 deletions
  1. +27
    -5
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

+ 27
- 5
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -440,7 +440,7 @@ namespace Tensorflow
#endif #endif


/// <summary> /// <summary>
/// Create a string Tensor from the given string
/// Create a string Tensor from the given string
/// </summary> /// </summary>
public unsafe Tensor(string str) public unsafe Tensor(string str)
{ {
@@ -448,6 +448,7 @@ namespace Tensorflow
var buffer = Encoding.UTF8.GetBytes(str); var buffer = Encoding.UTF8.GetBytes(str);
var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
AllocationType = AllocationType.Tensorflow;


IntPtr tensor = c_api.TF_TensorData(handle); IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0); Marshal.WriteInt64(tensor, 0);
@@ -459,6 +460,9 @@ namespace Tensorflow


public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null)
{ {
if (tensorDType == null)
tensorDType = nd.dtype.as_dtype();

// todo: handle nd of type "String" here too // todo: handle nd of type "String" here too
if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte)
{ {
@@ -467,6 +471,7 @@ namespace Tensorflow
var bytesLength = (UIntPtr) nd.size; var bytesLength = (UIntPtr) nd.size;
var size = c_api.TF_StringEncodedSize(bytesLength); var size = c_api.TF_StringEncodedSize(bytesLength);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
AllocationType = AllocationType.Tensorflow;


IntPtr tensor = c_api.TF_TensorData(handle); IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0); Marshal.WriteInt64(tensor, 0);
@@ -481,6 +486,7 @@ namespace Tensorflow
var buffer = nd.ToArray<byte>(); var buffer = nd.ToArray<byte>();
var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
AllocationType = AllocationType.Tensorflow;


IntPtr tensor = c_api.TF_TensorData(handle); IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0); Marshal.WriteInt64(tensor, 0);
@@ -535,6 +541,7 @@ namespace Tensorflow
int totalSize = size + buffer.Length * 8; int totalSize = size + buffer.Length * 8;
ulong offset = 0; ulong offset = 0;
IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr) totalSize); IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr) totalSize);
AllocationType = AllocationType.Tensorflow;


// Clear offset table // Clear offset table
IntPtr pOffset = TF_TensorData(handle); IntPtr pOffset = TF_TensorData(handle);
@@ -626,12 +633,27 @@ namespace Tensorflow


// get a handle to the pinned array which we will pass on to the tensor computation engine to use // get a handle to the pinned array which we will pass on to the tensor computation engine to use
var gcHandle = GCHandle.Alloc(data, GCHandleType.Pinned); var gcHandle = GCHandle.Alloc(data, GCHandleType.Pinned);
AllocationType = AllocationType.GCHandle;
AllocationHandle = gcHandle;
var pinnedAddr = gcHandle.AddrOfPinnedObject();


//call NewTensor
IntPtr handle;
if (shape == null || shape.Length == 0) if (shape == null || shape.Length == 0)
return TF_NewTensor(dt, new long[0], 0, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr) (count * element_size));
return TF_NewTensor(dt, shape, shape.Length, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr) (count * element_size));
handle = TF_NewTensor(dt, new long[0], 0, pinnedAddr + start * element_size, (UIntPtr) (count * element_size));
else
handle = TF_NewTensor(dt, shape, shape.Length, pinnedAddr + start * element_size, (UIntPtr) (count * element_size));

//Figure if TF decided to clone or not.
if (c_api.TF_TensorData(handle) == pinnedAddr)
{
AllocationType = AllocationType.GCHandle;
AllocationHandle = gcHandle;
} else
{
AllocationType = AllocationType.Tensorflow;
gcHandle.Free();
}

return handle;
} }
} }
} }

Loading…
Cancel
Save