| @@ -21,6 +21,7 @@ using System.Numerics; | |||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using NumSharp.Backends; | |||||
| using static Tensorflow.c_api; | using static Tensorflow.c_api; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -453,22 +454,25 @@ namespace Tensorflow | |||||
| public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | ||||
| { | { | ||||
| if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") | |||||
| if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) | |||||
| { | { | ||||
| var buffer = nd.Data<byte>(); | var buffer = nd.Data<byte>(); | ||||
| var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | |||||
| var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Count); | |||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | ||||
| IntPtr tensor = c_api.TF_TensorData(handle); | IntPtr tensor = c_api.TF_TensorData(handle); | ||||
| //TODO! handle when nd is a slice, probably by copying. | |||||
| Marshal.WriteInt64(tensor, 0); | Marshal.WriteInt64(tensor, 0); | ||||
| fixed (byte* src = &buffer[0]) | |||||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | |||||
| fixed (byte* src = nd.Unsafe) | |||||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Count, (sbyte*)(tensor + sizeof(Int64)), size, status); | |||||
| status.Check(true); | status.Check(true); | ||||
| _handle=handle; | _handle=handle; | ||||
| IsMemoryOwner = false; | IsMemoryOwner = false; | ||||
| return; | return; | ||||
| } | } | ||||
| _handle = Allocate(nd, tensorDType: tensorDType); | _handle = Allocate(nd, tensorDType: tensorDType); | ||||
| IsMemoryOwner = true; | IsMemoryOwner = true; | ||||
| } | } | ||||
| @@ -487,36 +491,10 @@ namespace Tensorflow | |||||
| var dataType = ToTFDataType(nd.dtype); | var dataType = ToTFDataType(nd.dtype); | ||||
| // shape | // shape | ||||
| var dims = nd.shape.Select(x => (long)x).ToArray(); | var dims = nd.shape.Select(x => (long)x).ToArray(); | ||||
| var nd1 = nd.ravel(); | |||||
| switch (nd.dtype.Name) | |||||
| { | |||||
| case "Boolean": | |||||
| var boolVals = Array.ConvertAll(nd1.Data<bool>(), x => Convert.ToByte(x)); | |||||
| Marshal.Copy(boolVals, 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Int16": | |||||
| Marshal.Copy(nd1.Data<short>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Int32": | |||||
| Marshal.Copy(nd1.Data<int>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Int64": | |||||
| Marshal.Copy(nd1.Data<long>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Single": | |||||
| Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Double": | |||||
| Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Byte": | |||||
| Marshal.Copy(nd1.Data<byte>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "String": | |||||
| return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>()[0]), TF_DataType.TF_STRING); | |||||
| default: | |||||
| throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}."); | |||||
| } | |||||
| if (nd.typecode == NPTypeCode.String || nd.typecode == NPTypeCode.Char) | |||||
| ; //TODO! handle it properly. | |||||
| nd.CopyTo(dotHandle); | |||||
| var tfHandle = c_api.TF_NewTensor(dataType, | var tfHandle = c_api.TF_NewTensor(dataType, | ||||
| dims, | dims, | ||||
| dims.Length, | dims.Length, | ||||