Browse Source

Allocate tensor without memory copy

tags/v0.12
Oceania2018 6 years ago
parent
commit
0d6b287955
1 changed files with 135 additions and 9 deletions
  1. +135
    -9
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

+ 135
- 9
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -506,7 +506,7 @@ namespace Tensorflow
IsMemoryOwner = true;
}

private unsafe IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null)
private unsafe IntPtr AllocateWithMemoryCopy(NDArray nd, TF_DataType? tensorDType = null)
{
IntPtr dotHandle = IntPtr.Zero;
int buffersize = 0;
@@ -520,30 +520,30 @@ namespace Tensorflow
var dataType = ToTFDataType(nd.dtype);
// shape
var dims = nd.shape.Select(x => (long)x).ToArray();
var nd1 = nd.ravel();
// var nd1 = nd.ravel();
switch (nd.dtype.Name)
{
case "Boolean":
var boolVals = Array.ConvertAll(nd1.Data<bool>(), x => Convert.ToByte(x));
var boolVals = Array.ConvertAll(nd.Data<bool>(), x => Convert.ToByte(x));
Marshal.Copy(boolVals, 0, dotHandle, nd.size);
break;
case "Int16":
Marshal.Copy(nd1.Data<short>(), 0, dotHandle, nd.size);
Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size);
break;
case "Int32":
Marshal.Copy(nd1.Data<int>(), 0, dotHandle, nd.size);
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size);
break;
case "Int64":
Marshal.Copy(nd1.Data<long>(), 0, dotHandle, nd.size);
Marshal.Copy(nd.Data<long>(), 0, dotHandle, nd.size);
break;
case "Single":
Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size);
Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size);
break;
case "Double":
Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size);
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
break;
case "Byte":
Marshal.Copy(nd1.Data<byte>(), 0, dotHandle, nd.size);
Marshal.Copy(nd.Data<byte>(), 0, dotHandle, nd.size);
break;
case "String":
return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>(0)), TF_DataType.TF_STRING);
@@ -559,6 +559,132 @@ namespace Tensorflow
ref _deallocatorArgs);

return tfHandle;
}
private unsafe IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null)
{
IntPtr dotHandle = IntPtr.Zero;
IntPtr tfHandle = IntPtr.Zero;
int buffersize = nd.size * nd.dtypesize;
var dataType = ToTFDataType(nd.dtype);
// shape
var dims = nd.shape.Select(x => (long)x).ToArray();
switch (nd.dtype.Name)
{
case "Boolean":
{
var boolVals = Array.ConvertAll(nd.Data<bool>(), x => Convert.ToByte(x));
var array = nd.Data<byte>();
fixed (byte* h = &array[0])
{
tfHandle = c_api.TF_NewTensor(dataType,
dims,
dims.Length,
new IntPtr(h),
(UIntPtr)buffersize,
_nothingDeallocator,
ref _deallocatorArgs);
}
}
break;
case "Int16":
{
var array = nd.Data<short>();
fixed (short* h = &array[0])
{
tfHandle = c_api.TF_NewTensor(dataType,
dims,
dims.Length,
new IntPtr(h),
(UIntPtr)buffersize,
_nothingDeallocator,
ref _deallocatorArgs);
}
}
break;
case "Int32":
{
var array = nd.Data<int>();
fixed (int* h = &array[0])
{
tfHandle = c_api.TF_NewTensor(dataType,
dims,
dims.Length,
new IntPtr(h),
(UIntPtr)buffersize,
_nothingDeallocator,
ref _deallocatorArgs);
}
}
break;
case "Int64":
{
var array = nd.Data<long>();
fixed (long* h = &array[0])
{
tfHandle = c_api.TF_NewTensor(dataType,
dims,
dims.Length,
new IntPtr(h),
(UIntPtr)buffersize,
_nothingDeallocator,
ref _deallocatorArgs);
}
}
break;
case "Single":
{
var array = nd.Data<float>();
fixed (float* h = &array[0])
{
tfHandle = c_api.TF_NewTensor(dataType,
dims,
dims.Length,
new IntPtr(h),
(UIntPtr)buffersize,
_nothingDeallocator,
ref _deallocatorArgs);
}
}
break;
case "Double":
{
var array = nd.Data<double>();
fixed (double* h = &array[0])
{
tfHandle = c_api.TF_NewTensor(dataType,
dims,
dims.Length,
new IntPtr(h),
(UIntPtr)buffersize,
_nothingDeallocator,
ref _deallocatorArgs);
}
}
break;
case "Byte":
{
var array = nd.Data<byte>();
fixed (byte* h = &array[0])
{
tfHandle = c_api.TF_NewTensor(dataType,
dims,
dims.Length,
new IntPtr(h),
(UIntPtr)buffersize,
_nothingDeallocator,
ref _deallocatorArgs);
}
}
break;
case "String":
return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>(0)), TF_DataType.TF_STRING);
default:
throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}.");
}
return tfHandle;
}

public unsafe Tensor(byte[][] buffer, long[] shape)


Loading…
Cancel
Save