Browse Source

Fixed Tensor.Allocate

tags/v0.12
Eli Belash 6 years ago
parent
commit
fa94a89da6
1 changed files with 13 additions and 35 deletions
  1. +13
    -35
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

+ 13
- 35
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -21,6 +21,7 @@ using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using NumSharp.Backends;
using static Tensorflow.c_api;
namespace Tensorflow
@@ -453,22 +454,25 @@ namespace Tensorflow
public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null)
{
if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte")
if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte)
{
var buffer = nd.Data<byte>();
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length);
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Count);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));
IntPtr tensor = c_api.TF_TensorData(handle);
//TODO! handle when nd is a slice, probably by copying.
Marshal.WriteInt64(tensor, 0);
fixed (byte* src = &buffer[0])
c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status);
fixed (byte* src = nd.Unsafe)
c_api.TF_StringEncode(src, (UIntPtr)buffer.Count, (sbyte*)(tensor + sizeof(Int64)), size, status);
status.Check(true);
_handle=handle;
IsMemoryOwner = false;
return;
}
_handle = Allocate(nd, tensorDType: tensorDType);
IsMemoryOwner = true;
}
@@ -487,36 +491,10 @@ namespace Tensorflow
var dataType = ToTFDataType(nd.dtype);
// shape
var dims = nd.shape.Select(x => (long)x).ToArray();
var nd1 = nd.ravel();
switch (nd.dtype.Name)
{
case "Boolean":
var boolVals = Array.ConvertAll(nd1.Data<bool>(), x => Convert.ToByte(x));
Marshal.Copy(boolVals, 0, dotHandle, nd.size);
break;
case "Int16":
Marshal.Copy(nd1.Data<short>(), 0, dotHandle, nd.size);
break;
case "Int32":
Marshal.Copy(nd1.Data<int>(), 0, dotHandle, nd.size);
break;
case "Int64":
Marshal.Copy(nd1.Data<long>(), 0, dotHandle, nd.size);
break;
case "Single":
Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size);
break;
case "Double":
Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size);
break;
case "Byte":
Marshal.Copy(nd1.Data<byte>(), 0, dotHandle, nd.size);
break;
case "String":
return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>()[0]), TF_DataType.TF_STRING);
default:
throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}.");
}
if (nd.typecode == NPTypeCode.String || nd.typecode == NPTypeCode.Char)
; //TODO! handle it properly.
nd.CopyTo(dotHandle);
var tfHandle = c_api.TF_NewTensor(dataType,
dims,
dims.Length,


Loading…
Cancel
Save