Browse Source

copy data from pointer.

tags/v0.12
Oceania2018 6 years ago
parent
commit
af9f178060
1 changed files with 8 additions and 24 deletions
  1. +8
    -24
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

+ 8
- 24
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -520,41 +520,24 @@ namespace Tensorflow
var dataType = ToTFDataType(nd.dtype); var dataType = ToTFDataType(nd.dtype);
// shape // shape
var dims = nd.shape.Select(x => (long)x).ToArray(); var dims = nd.shape.Select(x => (long)x).ToArray();
var nd1 = nd.ravel();
switch (nd.dtype.Name)
// var nd1 = nd.ravel();
/*switch (nd.dtype.Name)
{ {
case "Boolean": case "Boolean":
var boolVals = Array.ConvertAll(nd1.ToArray<bool>(), x => Convert.ToByte(x)); var boolVals = Array.ConvertAll(nd1.ToArray<bool>(), x => Convert.ToByte(x));
Marshal.Copy(boolVals, 0, dotHandle, nd.size); Marshal.Copy(boolVals, 0, dotHandle, nd.size);
break; break;
case "Int16":
Marshal.Copy(nd1.ToArray<short>(), 0, dotHandle, nd.size);
break;
case "Int32":
Marshal.Copy(nd1.ToArray<int>(), 0, dotHandle, nd.size);
break;
case "Int64":
Marshal.Copy(nd1.ToArray<long>(), 0, dotHandle, nd.size);
break;
case "Single":
Marshal.Copy(nd1.ToArray<float>(), 0, dotHandle, nd.size);
break;
case "Double":
Marshal.Copy(nd1.ToArray<double>(), 0, dotHandle, nd.size);
break;
case "Byte":
Marshal.Copy(nd1.ToArray<byte>(), 0, dotHandle, nd.size);
break;
case "String": case "String":
throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}."); throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}.");
//return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.ToArray<string>(0)), TF_DataType.TF_STRING); //return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.ToArray<string>(0)), TF_DataType.TF_STRING);
default: default:
throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}.");
}
System.Buffer.MemoryCopy(nd1.Unsafe.Address, dotHandle.ToPointer(), nd.size, nd.size);
break;
}*/
var tfHandle = c_api.TF_NewTensor(dataType, var tfHandle = c_api.TF_NewTensor(dataType,
dims, dims,
dims.Length, dims.Length,
dotHandle,
new IntPtr(nd.Unsafe.Address),
(UIntPtr)buffersize, (UIntPtr)buffersize,
_hGlobalDeallocator, _hGlobalDeallocator,
ref _deallocatorArgs); ref _deallocatorArgs);
@@ -673,7 +656,8 @@ namespace Tensorflow
{ {
if (args.deallocator_called) if (args.deallocator_called)
return; return;
Marshal.FreeHGlobal(dataPtr);
// NumSharp will dispose
// Marshal.FreeHGlobal(dataPtr);
args.deallocator_called = true; args.deallocator_called = true;
} }




Loading…
Cancel
Save