From fa94a89da6a2e5da57ddd7a998b6622806fdf143 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Mon, 29 Jul 2019 14:55:06 +0300 Subject: [PATCH] Fixed Tensor.Allocate --- .../Tensors/Tensor.Creation.cs | 48 +++++-------------- 1 file changed, 13 insertions(+), 35 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index c93b2296..0fc32f5a 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -21,6 +21,7 @@ using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; +using NumSharp.Backends; using static Tensorflow.c_api; namespace Tensorflow @@ -453,22 +454,25 @@ namespace Tensorflow public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) { - if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") + if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) { var buffer = nd.Data(); - var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); + var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Count); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); - + IntPtr tensor = c_api.TF_TensorData(handle); + + //TODO! handle when nd is a slice, probably by copying. Marshal.WriteInt64(tensor, 0); - fixed (byte* src = &buffer[0]) - c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); + fixed (byte* src = nd.Unsafe) + c_api.TF_StringEncode(src, (UIntPtr)buffer.Count, (sbyte*)(tensor + sizeof(Int64)), size, status); status.Check(true); _handle=handle; IsMemoryOwner = false; return; } + _handle = Allocate(nd, tensorDType: tensorDType); IsMemoryOwner = true; } @@ -487,36 +491,10 @@ namespace Tensorflow var dataType = ToTFDataType(nd.dtype); // shape var dims = nd.shape.Select(x => (long)x).ToArray(); - var nd1 = nd.ravel(); - switch (nd.dtype.Name) - { - case "Boolean": - var boolVals = Array.ConvertAll(nd1.Data(), x => Convert.ToByte(x)); - Marshal.Copy(boolVals, 0, dotHandle, nd.size); - break; - case "Int16": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); - break; - case "Int32": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); - break; - case "Int64": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); - break; - case "Single": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); - break; - case "Double": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); - break; - case "Byte": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); - break; - case "String": - return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data()[0]), TF_DataType.TF_STRING); - default: - throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}."); - } + if (nd.typecode == NPTypeCode.String || nd.typecode == NPTypeCode.Char) + ; //TODO! handle it properly. + nd.CopyTo(dotHandle); + var tfHandle = c_api.TF_NewTensor(dataType, dims, dims.Length,