diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs new file mode 100644 index 00000000..b9b71154 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs @@ -0,0 +1,411 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using NumSharp; +using System; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Linq; +using System.Numerics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using NumSharp.Backends; +using NumSharp.Backends.Unmanaged; +using NumSharp.Utilities; +using static Tensorflow.c_api; + +#if SERIALIZABLE +using Newtonsoft.Json; +#endif + +namespace Tensorflow +{ + [SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] + public partial class Tensor + { + public T ToScalar() + { + unsafe + { + if (typeof(T).as_dtype() == this.dtype && this.dtype != TF_DataType.TF_STRING) + return Unsafe.Read(this.buffer.ToPointer()); + + switch (this.dtype) + { +#if _REGEN + %foreach supported_numericals_TF_DataType,supported_numericals,supported_numericals_lowercase% + case TF_DataType.#1: + return Converts.ChangeType(*(#3*) this.buffer); + % +#else + + case TF_DataType.TF_UINT8: + return Converts.ChangeType(*(byte*) this.buffer); + case TF_DataType.TF_INT16: + return Converts.ChangeType(*(short*) this.buffer); + case TF_DataType.TF_UINT16: + return Converts.ChangeType(*(ushort*) this.buffer); + case TF_DataType.TF_INT32: + return Converts.ChangeType(*(int*) this.buffer); + case TF_DataType.TF_UINT32: + return Converts.ChangeType(*(uint*) this.buffer); + case TF_DataType.TF_INT64: + return Converts.ChangeType(*(long*) this.buffer); + case TF_DataType.TF_UINT64: + return Converts.ChangeType(*(ulong*) this.buffer); + case TF_DataType.TF_DOUBLE: + return Converts.ChangeType(*(double*) this.buffer); + case TF_DataType.TF_FLOAT: + return Converts.ChangeType(*(float*) this.buffer); +#endif + case TF_DataType.TF_STRING: + if (this.NDims != 0) + throw new ArgumentException($"{nameof(Tensor)} can only be scalar."); + + IntPtr stringStartAddress = IntPtr.Zero; + UIntPtr dstLen = UIntPtr.Zero; + + using (var status = new Status()) + { + c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, status); + status.Check(true); + } + + var dstLenInt = checked((int) dstLen); + var value = Encoding.UTF8.GetString((byte*) stringStartAddress, dstLenInt); + if (typeof(T) == typeof(string)) + return (T) (object) value; + else + return Converts.ChangeType(value); + + case TF_DataType.TF_COMPLEX64: + case TF_DataType.TF_COMPLEX128: + default: + throw new NotSupportedException(); + } + } + } + + public unsafe void CopyTo(NDArray nd) + { + if (!nd.Shape.IsContiguous) + throw new ArgumentException("NDArray has to be contiguous (ndarray.Shape.IsContiguous)."); + +#if _REGEN + #region Compute + switch (nd.typecode) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: + { + CopyTo<#2>(new Span<#2>(nd.Unsafe.Address, nd.size*nd.dtypesize)); + break; + } + % + default: + throw new NotSupportedException(); + } + #endregion +#else + + #region Compute + + switch (nd.typecode) + { + case NPTypeCode.Boolean: + { + CopyTo(new Span(nd.Unsafe.Address, nd.size * nd.dtypesize)); + break; + } + case NPTypeCode.Byte: + { + CopyTo(new Span(nd.Unsafe.Address, nd.size * nd.dtypesize)); + break; + } + case NPTypeCode.Int16: + { + CopyTo(new Span(nd.Unsafe.Address, nd.size * nd.dtypesize)); + break; + } + case NPTypeCode.UInt16: + { + CopyTo(new Span(nd.Unsafe.Address, nd.size * nd.dtypesize)); + break; + } + case NPTypeCode.Int32: + { + CopyTo(new Span(nd.Unsafe.Address, nd.size * nd.dtypesize)); + break; + } + case NPTypeCode.UInt32: + { + CopyTo(new Span(nd.Unsafe.Address, nd.size * nd.dtypesize)); + break; + } + case NPTypeCode.Int64: + { + CopyTo(new Span(nd.Unsafe.Address, nd.size * nd.dtypesize)); + break; + } + case NPTypeCode.UInt64: + { + CopyTo(new Span(nd.Unsafe.Address, nd.size * nd.dtypesize)); + break; + } + case NPTypeCode.Char: + { + CopyTo(new Span(nd.Unsafe.Address, nd.size * nd.dtypesize)); + break; + } + case NPTypeCode.Double: + { + CopyTo(new Span(nd.Unsafe.Address, nd.size * nd.dtypesize)); + break; + } + case NPTypeCode.Single: + { + CopyTo(new Span(nd.Unsafe.Address, nd.size * nd.dtypesize)); + break; + } + default: + throw new NotSupportedException(); + } + + #endregion +#endif + } + + public void CopyTo(Span destination) where T : unmanaged + { + unsafe + { + var len = checked((int) this.size); + //perform regular CopyTo using Span.CopyTo. + if (typeof(T).as_dtype() == this.dtype && this.dtype != TF_DataType.TF_STRING) //T can't be a string but tensor can. + { + var src = (T*) this.buffer; + var srcSpan = new Span(src, len); + srcSpan.CopyTo(destination); + + return; + } + + if (len > destination.Length) + throw new ArgumentException("Destinion was too short to perform CopyTo."); + + //Perform cast to type . + fixed (T* dst = destination) + { + switch (this.dtype) + { +#if _REGEN + %foreach supported_numericals_TF_DataType,supported_numericals,supported_numericals_lowercase% + case TF_DataType.#1: + { + var converter = Converts.FindConverter<#3, T>(); + var src = (#3*) this.buffer; + for (var i = 0; i < len; i++) + *(dst + i) = converter(unchecked(*(src + i))); + return; + } + % +#else + case TF_DataType.TF_BOOL: + { + var converter = Converts.FindConverter(); + var src = (bool*) this.buffer; + for (var i = 0; i < len; i++) + *(dst + i) = converter(unchecked(*(src + i))); + return; + } + case TF_DataType.TF_UINT8: + { + var converter = Converts.FindConverter(); + var src = (byte*) this.buffer; + for (var i = 0; i < len; i++) + *(dst + i) = converter(unchecked(*(src + i))); + return; + } + case TF_DataType.TF_INT16: + { + var converter = Converts.FindConverter(); + var src = (short*) this.buffer; + for (var i = 0; i < len; i++) + *(dst + i) = converter(unchecked(*(src + i))); + return; + } + case TF_DataType.TF_UINT16: + { + var converter = Converts.FindConverter(); + var src = (ushort*) this.buffer; + for (var i = 0; i < len; i++) + *(dst + i) = converter(unchecked(*(src + i))); + return; + } + case TF_DataType.TF_INT32: + { + var converter = Converts.FindConverter(); + var src = (int*) this.buffer; + for (var i = 0; i < len; i++) + *(dst + i) = converter(unchecked(*(src + i))); + return; + } + case TF_DataType.TF_UINT32: + { + var converter = Converts.FindConverter(); + var src = (uint*) this.buffer; + for (var i = 0; i < len; i++) + *(dst + i) = converter(unchecked(*(src + i))); + return; + } + case TF_DataType.TF_INT64: + { + var converter = Converts.FindConverter(); + var src = (long*) this.buffer; + for (var i = 0; i < len; i++) + *(dst + i) = converter(unchecked(*(src + i))); + return; + } + case TF_DataType.TF_UINT64: + { + var converter = Converts.FindConverter(); + var src = (ulong*) this.buffer; + for (var i = 0; i < len; i++) + *(dst + i) = converter(unchecked(*(src + i))); + return; + } + case TF_DataType.TF_DOUBLE: + { + var converter = Converts.FindConverter(); + var src = (double*) this.buffer; + for (var i = 0; i < len; i++) + *(dst + i) = converter(unchecked(*(src + i))); + return; + } + case TF_DataType.TF_FLOAT: + { + var converter = Converts.FindConverter(); + var src = (float*) this.buffer; + for (var i = 0; i < len; i++) + *(dst + i) = converter(unchecked(*(src + i))); + return; + } +#endif + case TF_DataType.TF_STRING: + { + var src = this.StringData(); + var culture = CultureInfo.InvariantCulture; + + //pin to prevent GC from moving the span around. + fixed (T* _ = destination) + switch (typeof(T).as_dtype()) + { +#if _REGEN + %foreach supported_numericals_TF_DataType,supported_numericals,supported_numericals_lowercase% + case TF_DataType.#1: { + var sdst = (#3*)Unsafe.AsPointer(ref destination.GetPinnableReference()); + for (var i = 0; i < len; i++) + *(sdst + i) = ((IConvertible)src[i]).To#2(culture); + return; + } + % +#else + case TF_DataType.TF_BOOL: + { + var sdst = (bool*) Unsafe.AsPointer(ref destination.GetPinnableReference()); + for (var i = 0; i < len; i++) + *(sdst + i) = ((IConvertible) src[i]).ToBoolean(culture); + return; + } + case TF_DataType.TF_UINT8: + { + var sdst = (byte*) Unsafe.AsPointer(ref destination.GetPinnableReference()); + for (var i = 0; i < len; i++) + *(sdst + i) = ((IConvertible) src[i]).ToByte(culture); + return; + } + case TF_DataType.TF_INT16: + { + var sdst = (short*) Unsafe.AsPointer(ref destination.GetPinnableReference()); + for (var i = 0; i < len; i++) + *(sdst + i) = ((IConvertible) src[i]).ToInt16(culture); + return; + } + case TF_DataType.TF_UINT16: + { + var sdst = (ushort*) Unsafe.AsPointer(ref destination.GetPinnableReference()); + for (var i = 0; i < len; i++) + *(sdst + i) = ((IConvertible) src[i]).ToUInt16(culture); + return; + } + case TF_DataType.TF_INT32: + { + var sdst = (int*) Unsafe.AsPointer(ref destination.GetPinnableReference()); + for (var i = 0; i < len; i++) + *(sdst + i) = ((IConvertible) src[i]).ToInt32(culture); + return; + } + case TF_DataType.TF_UINT32: + { + var sdst = (uint*) Unsafe.AsPointer(ref destination.GetPinnableReference()); + for (var i = 0; i < len; i++) + *(sdst + i) = ((IConvertible) src[i]).ToUInt32(culture); + return; + } + case TF_DataType.TF_INT64: + { + var sdst = (long*) Unsafe.AsPointer(ref destination.GetPinnableReference()); + for (var i = 0; i < len; i++) + *(sdst + i) = ((IConvertible) src[i]).ToInt64(culture); + return; + } + case TF_DataType.TF_UINT64: + { + var sdst = (ulong*) Unsafe.AsPointer(ref destination.GetPinnableReference()); + for (var i = 0; i < len; i++) + *(sdst + i) = ((IConvertible) src[i]).ToUInt64(culture); + return; + } + case TF_DataType.TF_DOUBLE: + { + var sdst = (double*) Unsafe.AsPointer(ref destination.GetPinnableReference()); + for (var i = 0; i < len; i++) + *(sdst + i) = ((IConvertible) src[i]).ToDouble(culture); + return; + } + case TF_DataType.TF_FLOAT: + { + var sdst = (float*) Unsafe.AsPointer(ref destination.GetPinnableReference()); + for (var i = 0; i < len; i++) + *(sdst + i) = ((IConvertible) src[i]).ToSingle(culture); + return; + } +#endif + default: + throw new NotSupportedException(); + } + } + case TF_DataType.TF_COMPLEX64: + case TF_DataType.TF_COMPLEX128: + default: + throw new NotSupportedException(); + } + } + } + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 99fba404..efac802d 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -257,7 +257,6 @@ namespace Tensorflow /// /// /// - /// When is string public T[] ToArray() where T : unmanaged { //Are the types matching?