diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 798c27b6..bccc2569 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -233,6 +233,14 @@ namespace Tensorflow //Are the types matching? if (typeof(T).as_dtype() == _dtype) { + if (NDims == 0 && size == 1) //is it a scalar? + { + unsafe + { + return new T[] {*(T*) buffer}; + } + } + //types match, no need to perform cast var ret = new T[size]; unsafe @@ -260,6 +268,46 @@ namespace Tensorflow { //types do not match, need to perform cast + if (NDims == 0 && size == 1) //is it a scalar? + { + unsafe + { +#if _REGEN + #region Compute + switch (_dtype.as_numpy_dtype().GetTypeCode()) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: return new T[] {Converts.ChangeType(*(#2*) buffer, NPTypeCode.#1)}; + % + case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this, NPTypeCode.String)}; + default: + throw new NotSupportedException(); + } + #endregion +#else + #region Compute + switch (_dtype.as_numpy_dtype().GetTypeCode()) + { + case NPTypeCode.Boolean: return new T[] {Converts.ChangeType(*(bool*) buffer, NPTypeCode.Boolean)}; + case NPTypeCode.Byte: return new T[] {Converts.ChangeType(*(byte*) buffer, NPTypeCode.Byte)}; + case NPTypeCode.Int16: return new T[] {Converts.ChangeType(*(short*) buffer, NPTypeCode.Int16)}; + case NPTypeCode.UInt16: return new T[] {Converts.ChangeType(*(ushort*) buffer, NPTypeCode.UInt16)}; + case NPTypeCode.Int32: return new T[] {Converts.ChangeType(*(int*) buffer, NPTypeCode.Int32)}; + case NPTypeCode.UInt32: return new T[] {Converts.ChangeType(*(uint*) buffer, NPTypeCode.UInt32)}; + case NPTypeCode.Int64: return new T[] {Converts.ChangeType(*(long*) buffer, NPTypeCode.Int64)}; + case NPTypeCode.UInt64: return new T[] {Converts.ChangeType(*(ulong*) buffer, NPTypeCode.UInt64)}; + case NPTypeCode.Char: return new T[] {Converts.ChangeType(*(char*) buffer, NPTypeCode.Char)}; + case NPTypeCode.Double: return new T[] {Converts.ChangeType(*(double*) buffer, NPTypeCode.Double)}; + case NPTypeCode.Single: return new T[] {Converts.ChangeType(*(float*) buffer, NPTypeCode.Single)}; + case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this, NPTypeCode.String)}; + default: + throw new NotSupportedException(); + } + #endregion +#endif + } + } + var ret = new T[size]; unsafe { @@ -270,10 +318,10 @@ namespace Tensorflow #if _REGEN #region Compute - switch (_dtype.as_numpy_datatype().GetTypeCode()) + switch (_dtype.as_numpy_dtype().GetTypeCode()) { %foreach supported_dtypes,supported_dtypes_lowercase% - case NPTypeCode.#1:new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; % default: throw new NotSupportedException(); @@ -283,17 +331,18 @@ namespace Tensorflow #region Compute switch (_dtype.as_numpy_dtype().GetTypeCode()) { - case NPTypeCode.Boolean:new UnmanagedMemoryBlock((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Byte:new UnmanagedMemoryBlock((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Int16:new UnmanagedMemoryBlock((short*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.UInt16:new UnmanagedMemoryBlock((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Int32:new UnmanagedMemoryBlock((int*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.UInt32:new UnmanagedMemoryBlock((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Int64:new UnmanagedMemoryBlock((long*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.UInt64:new UnmanagedMemoryBlock((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Char:new UnmanagedMemoryBlock((char*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Double:new UnmanagedMemoryBlock((double*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; - case NPTypeCode.Single:new UnmanagedMemoryBlock((float*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Boolean: new UnmanagedMemoryBlock((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Byte: new UnmanagedMemoryBlock((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int16: new UnmanagedMemoryBlock((short*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt16: new UnmanagedMemoryBlock((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int32: new UnmanagedMemoryBlock((int*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt32: new UnmanagedMemoryBlock((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int64: new UnmanagedMemoryBlock((long*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt64: new UnmanagedMemoryBlock((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Char: new UnmanagedMemoryBlock((char*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Double: new UnmanagedMemoryBlock((double*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Single: new UnmanagedMemoryBlock((float*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To default: throw new NotSupportedException(); } @@ -307,7 +356,6 @@ namespace Tensorflow } } - /// /// Copies the memory of current buffer onto newly allocated array. ///