From 0d6b287955630df34c81fe94ab92dcfa81192373 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 8 Aug 2019 07:32:21 -0500 Subject: [PATCH] Allocate tensor without memory copy --- .../Tensors/Tensor.Creation.cs | 144 ++++++++++++++++-- 1 file changed, 135 insertions(+), 9 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index a104f066..896d6d4d 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -506,7 +506,7 @@ namespace Tensorflow IsMemoryOwner = true; } - private unsafe IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null) + private unsafe IntPtr AllocateWithMemoryCopy(NDArray nd, TF_DataType? tensorDType = null) { IntPtr dotHandle = IntPtr.Zero; int buffersize = 0; @@ -520,30 +520,30 @@ namespace Tensorflow var dataType = ToTFDataType(nd.dtype); // shape var dims = nd.shape.Select(x => (long)x).ToArray(); - var nd1 = nd.ravel(); + // var nd1 = nd.ravel(); switch (nd.dtype.Name) { case "Boolean": - var boolVals = Array.ConvertAll(nd1.Data(), x => Convert.ToByte(x)); + var boolVals = Array.ConvertAll(nd.Data(), x => Convert.ToByte(x)); Marshal.Copy(boolVals, 0, dotHandle, nd.size); break; case "Int16": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); break; case "Int32": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); break; case "Int64": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); break; case "Single": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); break; case "Double": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); break; case "Byte": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); break; case "String": return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data(0)), TF_DataType.TF_STRING); @@ -559,6 +559,132 @@ namespace Tensorflow ref _deallocatorArgs); return tfHandle; + } + + private unsafe IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null) + { + IntPtr dotHandle = IntPtr.Zero; + IntPtr tfHandle = IntPtr.Zero; + int buffersize = nd.size * nd.dtypesize; + + var dataType = ToTFDataType(nd.dtype); + // shape + var dims = nd.shape.Select(x => (long)x).ToArray(); + switch (nd.dtype.Name) + { + case "Boolean": + { + var boolVals = Array.ConvertAll(nd.Data(), x => Convert.ToByte(x)); + var array = nd.Data(); + fixed (byte* h = &array[0]) + { + tfHandle = c_api.TF_NewTensor(dataType, + dims, + dims.Length, + new IntPtr(h), + (UIntPtr)buffersize, + _nothingDeallocator, + ref _deallocatorArgs); + } + } + break; + case "Int16": + { + var array = nd.Data(); + fixed (short* h = &array[0]) + { + tfHandle = c_api.TF_NewTensor(dataType, + dims, + dims.Length, + new IntPtr(h), + (UIntPtr)buffersize, + _nothingDeallocator, + ref _deallocatorArgs); + } + } + break; + case "Int32": + { + var array = nd.Data(); + fixed (int* h = &array[0]) + { + tfHandle = c_api.TF_NewTensor(dataType, + dims, + dims.Length, + new IntPtr(h), + (UIntPtr)buffersize, + _nothingDeallocator, + ref _deallocatorArgs); + } + } + break; + case "Int64": + { + var array = nd.Data(); + fixed (long* h = &array[0]) + { + tfHandle = c_api.TF_NewTensor(dataType, + dims, + dims.Length, + new IntPtr(h), + (UIntPtr)buffersize, + _nothingDeallocator, + ref _deallocatorArgs); + } + } + break; + case "Single": + { + var array = nd.Data(); + fixed (float* h = &array[0]) + { + tfHandle = c_api.TF_NewTensor(dataType, + dims, + dims.Length, + new IntPtr(h), + (UIntPtr)buffersize, + _nothingDeallocator, + ref _deallocatorArgs); + } + } + break; + case "Double": + { + var array = nd.Data(); + fixed (double* h = &array[0]) + { + tfHandle = c_api.TF_NewTensor(dataType, + dims, + dims.Length, + new IntPtr(h), + (UIntPtr)buffersize, + _nothingDeallocator, + ref _deallocatorArgs); + } + } + break; + case "Byte": + { + var array = nd.Data(); + fixed (byte* h = &array[0]) + { + tfHandle = c_api.TF_NewTensor(dataType, + dims, + dims.Length, + new IntPtr(h), + (UIntPtr)buffersize, + _nothingDeallocator, + ref _deallocatorArgs); + } + } + break; + case "String": + return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data(0)), TF_DataType.TF_STRING); + default: + throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}."); + } + + return tfHandle; } public unsafe Tensor(byte[][] buffer, long[] shape)