diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 42e76656..b1c0be1d 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -440,7 +440,7 @@ namespace Tensorflow #endif /// - /// Create a string Tensor from the given string + /// Create a string Tensor from the given string /// public unsafe Tensor(string str) { @@ -448,6 +448,7 @@ namespace Tensorflow var buffer = Encoding.UTF8.GetBytes(str); var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); @@ -459,6 +460,9 @@ namespace Tensorflow public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) { + if (tensorDType == null) + tensorDType = nd.dtype.as_dtype(); + // todo: handle nd of type "String" here too if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) { @@ -467,6 +471,7 @@ namespace Tensorflow var bytesLength = (UIntPtr) nd.size; var size = c_api.TF_StringEncodedSize(bytesLength); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); @@ -481,6 +486,7 @@ namespace Tensorflow var buffer = nd.ToArray(); var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); @@ -535,6 +541,7 @@ namespace Tensorflow int totalSize = size + buffer.Length * 8; ulong offset = 0; IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr) totalSize); + AllocationType = AllocationType.Tensorflow; // Clear offset table IntPtr pOffset = TF_TensorData(handle); @@ -626,12 +633,27 @@ namespace Tensorflow // get a handle to the pinned array which we will pass on to the tensor computation engine to use var gcHandle = GCHandle.Alloc(data, GCHandleType.Pinned); - AllocationType = AllocationType.GCHandle; - AllocationHandle = gcHandle; + var pinnedAddr = gcHandle.AddrOfPinnedObject(); + //call NewTensor + IntPtr handle; if (shape == null || shape.Length == 0) - return TF_NewTensor(dt, new long[0], 0, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr) (count * element_size)); - return TF_NewTensor(dt, shape, shape.Length, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr) (count * element_size)); + handle = TF_NewTensor(dt, new long[0], 0, pinnedAddr + start * element_size, (UIntPtr) (count * element_size)); + else + handle = TF_NewTensor(dt, shape, shape.Length, pinnedAddr + start * element_size, (UIntPtr) (count * element_size)); + + //Figure if TF decided to clone or not. + if (c_api.TF_TensorData(handle) == pinnedAddr) + { + AllocationType = AllocationType.GCHandle; + AllocationHandle = gcHandle; + } else + { + AllocationType = AllocationType.Tensorflow; + gcHandle.Free(); + } + + return handle; } } } \ No newline at end of file