| @@ -21,6 +21,7 @@ using System.Numerics; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| using NumSharp.Backends; | |||
| using static Tensorflow.c_api; | |||
| namespace Tensorflow | |||
| @@ -453,22 +454,25 @@ namespace Tensorflow | |||
| public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | |||
| { | |||
| if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") | |||
| if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) | |||
| { | |||
| var buffer = nd.Data<byte>(); | |||
| var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | |||
| var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Count); | |||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | |||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||
| //TODO! handle when nd is a slice, probably by copying. | |||
| Marshal.WriteInt64(tensor, 0); | |||
| fixed (byte* src = &buffer[0]) | |||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | |||
| fixed (byte* src = nd.Unsafe) | |||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Count, (sbyte*)(tensor + sizeof(Int64)), size, status); | |||
| status.Check(true); | |||
| _handle=handle; | |||
| IsMemoryOwner = false; | |||
| return; | |||
| } | |||
| _handle = Allocate(nd, tensorDType: tensorDType); | |||
| IsMemoryOwner = true; | |||
| } | |||
| @@ -487,36 +491,10 @@ namespace Tensorflow | |||
| var dataType = ToTFDataType(nd.dtype); | |||
| // shape | |||
| var dims = nd.shape.Select(x => (long)x).ToArray(); | |||
| var nd1 = nd.ravel(); | |||
| switch (nd.dtype.Name) | |||
| { | |||
| case "Boolean": | |||
| var boolVals = Array.ConvertAll(nd1.Data<bool>(), x => Convert.ToByte(x)); | |||
| Marshal.Copy(boolVals, 0, dotHandle, nd.size); | |||
| break; | |||
| case "Int16": | |||
| Marshal.Copy(nd1.Data<short>(), 0, dotHandle, nd.size); | |||
| break; | |||
| case "Int32": | |||
| Marshal.Copy(nd1.Data<int>(), 0, dotHandle, nd.size); | |||
| break; | |||
| case "Int64": | |||
| Marshal.Copy(nd1.Data<long>(), 0, dotHandle, nd.size); | |||
| break; | |||
| case "Single": | |||
| Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); | |||
| break; | |||
| case "Double": | |||
| Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size); | |||
| break; | |||
| case "Byte": | |||
| Marshal.Copy(nd1.Data<byte>(), 0, dotHandle, nd.size); | |||
| break; | |||
| case "String": | |||
| return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>()[0]), TF_DataType.TF_STRING); | |||
| default: | |||
| throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}."); | |||
| } | |||
| if (nd.typecode == NPTypeCode.String || nd.typecode == NPTypeCode.Char) | |||
| ; //TODO! handle it properly. | |||
| nd.CopyTo(dotHandle); | |||
| var tfHandle = c_api.TF_NewTensor(dataType, | |||
| dims, | |||
| dims.Length, | |||