| @@ -21,6 +21,7 @@ using System.Numerics; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| using NumSharp.Backends.Unmanaged; | |||
| using static Tensorflow.c_api; | |||
| namespace Tensorflow | |||
| @@ -484,6 +485,7 @@ namespace Tensorflow | |||
| public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | |||
| { | |||
| // todo: handle nd of type "String" here too | |||
| if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") | |||
| { | |||
| var buffer = nd.ToArray<byte>(); | |||
| @@ -502,47 +504,33 @@ namespace Tensorflow | |||
| IsMemoryOwner = false; | |||
| return; | |||
| } | |||
| _handle = Allocate(nd, tensorDType: tensorDType); | |||
| _handle = CreateTensorFromNDArray(nd, tensorDType); | |||
| IsMemoryOwner = true; | |||
| } | |||
| private unsafe IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null) | |||
| private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) | |||
| { | |||
| IntPtr dotHandle = IntPtr.Zero; | |||
| int buffersize = 0; | |||
| if (nd.dtype.Name != "String") | |||
| if (nd.dtype.Name == "String") | |||
| throw new NotImplementedException("Support for NDArray of type string not implemented yet"); | |||
| IArraySlice arraySlice; | |||
| var shape = nd.Unsafe.Storage.Shape; | |||
| if (shape.IsSliced || shape.IsBroadcasted) | |||
| { | |||
| buffersize = (nd.size * nd.dtypesize); | |||
| // dotHandle = Marshal.AllocHGlobal(buffersize); | |||
| // the memory is NOT contiguous, so we have to copy the view into a contiguous memory block. | |||
| arraySlice = nd.CloneData(); | |||
| } | |||
| var dataType = ToTFDataType(nd.dtype); | |||
| // shape | |||
| var dims = nd.shape.Select(x => (long)x).ToArray(); | |||
| // var nd1 = nd.ravel(); | |||
| /*switch (nd.dtype.Name) | |||
| else | |||
| { | |||
| case "Boolean": | |||
| var boolVals = Array.ConvertAll(nd1.ToArray<bool>(), x => Convert.ToByte(x)); | |||
| Marshal.Copy(boolVals, 0, dotHandle, nd.size); | |||
| break; | |||
| case "String": | |||
| throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}."); | |||
| //return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.ToArray<string>(0)), TF_DataType.TF_STRING); | |||
| default: | |||
| System.Buffer.MemoryCopy(nd1.Unsafe.Address, dotHandle.ToPointer(), nd.size, nd.size); | |||
| break; | |||
| }*/ | |||
| var tfHandle = c_api.TF_NewTensor(dataType, | |||
| dims, | |||
| dims.Length, | |||
| new IntPtr(nd.Unsafe.Address), | |||
| (UIntPtr)buffersize, | |||
| _hGlobalDeallocator, | |||
| ref _deallocatorArgs); | |||
| return tfHandle; | |||
| // the memory is contiguous | |||
| arraySlice = nd.GetData(); | |||
| } | |||
| this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it | |||
| var ptr = new IntPtr(arraySlice.Address); | |||
| int num_bytes = (nd.size * nd.dtypesize); | |||
| var dtype = given_dtype ?? ToTFDataType(nd.dtype); | |||
| var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); | |||
| IsMemoryOwner = false; | |||
| return handle; | |||
| } | |||
| public unsafe Tensor(byte[][] buffer, long[] shape) | |||