From 1ac115623cc11ca018e598720783c868c404d593 Mon Sep 17 00:00:00 2001 From: Meinrad Recheis Date: Tue, 13 Aug 2019 21:35:53 +0200 Subject: [PATCH] Tensor: correctly pass unmanaged ptr of NDArray to TF --- .../Tensors/Tensor.Creation.cs | 56 ++++++++----------- 1 file changed, 22 insertions(+), 34 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 4988836d..46c6fa0a 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -21,6 +21,7 @@ using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; +using NumSharp.Backends.Unmanaged; using static Tensorflow.c_api; namespace Tensorflow @@ -484,6 +485,7 @@ namespace Tensorflow public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) { + // todo: handle nd of type "String" here too if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") { var buffer = nd.ToArray(); @@ -502,47 +504,33 @@ namespace Tensorflow IsMemoryOwner = false; return; } - _handle = Allocate(nd, tensorDType: tensorDType); + _handle = CreateTensorFromNDArray(nd, tensorDType); IsMemoryOwner = true; } - private unsafe IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null) + private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) { - IntPtr dotHandle = IntPtr.Zero; - int buffersize = 0; - - if (nd.dtype.Name != "String") + if (nd.dtype.Name == "String") + throw new NotImplementedException("Support for NDArray of type string not implemented yet"); + IArraySlice arraySlice; + var shape = nd.Unsafe.Storage.Shape; + if (shape.IsSliced || shape.IsBroadcasted) { - buffersize = (nd.size * nd.dtypesize); - // dotHandle = Marshal.AllocHGlobal(buffersize); + // the memory is NOT contiguous, so we have to copy the view into a contiguous memory block. + arraySlice = nd.CloneData(); } - - var dataType = ToTFDataType(nd.dtype); - // shape - var dims = nd.shape.Select(x => (long)x).ToArray(); - // var nd1 = nd.ravel(); - /*switch (nd.dtype.Name) + else { - case "Boolean": - var boolVals = Array.ConvertAll(nd1.ToArray(), x => Convert.ToByte(x)); - Marshal.Copy(boolVals, 0, dotHandle, nd.size); - break; - case "String": - throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}."); - //return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.ToArray(0)), TF_DataType.TF_STRING); - default: - System.Buffer.MemoryCopy(nd1.Unsafe.Address, dotHandle.ToPointer(), nd.size, nd.size); - break; - }*/ - var tfHandle = c_api.TF_NewTensor(dataType, - dims, - dims.Length, - new IntPtr(nd.Unsafe.Address), - (UIntPtr)buffersize, - _hGlobalDeallocator, - ref _deallocatorArgs); - - return tfHandle; + // the memory is contiguous + arraySlice = nd.GetData(); + } + this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it + var ptr = new IntPtr(arraySlice.Address); + int num_bytes = (nd.size * nd.dtypesize); + var dtype = given_dtype ?? ToTFDataType(nd.dtype); + var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); + IsMemoryOwner = false; + return handle; } public unsafe Tensor(byte[][] buffer, long[] shape)