| @@ -233,6 +233,14 @@ namespace Tensorflow | |||||
| //Are the types matching? | //Are the types matching? | ||||
| if (typeof(T).as_dtype() == _dtype) | if (typeof(T).as_dtype() == _dtype) | ||||
| { | { | ||||
| if (NDims == 0 && size == 1) //is it a scalar? | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| return new T[] {*(T*) buffer}; | |||||
| } | |||||
| } | |||||
| //types match, no need to perform cast | //types match, no need to perform cast | ||||
| var ret = new T[size]; | var ret = new T[size]; | ||||
| unsafe | unsafe | ||||
| @@ -260,6 +268,46 @@ namespace Tensorflow | |||||
| { | { | ||||
| //types do not match, need to perform cast | //types do not match, need to perform cast | ||||
| if (NDims == 0 && size == 1) //is it a scalar? | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| #if _REGEN | |||||
| #region Compute | |||||
| switch (_dtype.as_numpy_dtype().GetTypeCode()) | |||||
| { | |||||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||||
| case NPTypeCode.#1: return new T[] {Converts.ChangeType<T>(*(#2*) buffer, NPTypeCode.#1)}; | |||||
| % | |||||
| case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)}; | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #else | |||||
| #region Compute | |||||
| switch (_dtype.as_numpy_dtype().GetTypeCode()) | |||||
| { | |||||
| case NPTypeCode.Boolean: return new T[] {Converts.ChangeType<T>(*(bool*) buffer, NPTypeCode.Boolean)}; | |||||
| case NPTypeCode.Byte: return new T[] {Converts.ChangeType<T>(*(byte*) buffer, NPTypeCode.Byte)}; | |||||
| case NPTypeCode.Int16: return new T[] {Converts.ChangeType<T>(*(short*) buffer, NPTypeCode.Int16)}; | |||||
| case NPTypeCode.UInt16: return new T[] {Converts.ChangeType<T>(*(ushort*) buffer, NPTypeCode.UInt16)}; | |||||
| case NPTypeCode.Int32: return new T[] {Converts.ChangeType<T>(*(int*) buffer, NPTypeCode.Int32)}; | |||||
| case NPTypeCode.UInt32: return new T[] {Converts.ChangeType<T>(*(uint*) buffer, NPTypeCode.UInt32)}; | |||||
| case NPTypeCode.Int64: return new T[] {Converts.ChangeType<T>(*(long*) buffer, NPTypeCode.Int64)}; | |||||
| case NPTypeCode.UInt64: return new T[] {Converts.ChangeType<T>(*(ulong*) buffer, NPTypeCode.UInt64)}; | |||||
| case NPTypeCode.Char: return new T[] {Converts.ChangeType<T>(*(char*) buffer, NPTypeCode.Char)}; | |||||
| case NPTypeCode.Double: return new T[] {Converts.ChangeType<T>(*(double*) buffer, NPTypeCode.Double)}; | |||||
| case NPTypeCode.Single: return new T[] {Converts.ChangeType<T>(*(float*) buffer, NPTypeCode.Single)}; | |||||
| case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)}; | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #endif | |||||
| } | |||||
| } | |||||
| var ret = new T[size]; | var ret = new T[size]; | ||||
| unsafe | unsafe | ||||
| { | { | ||||
| @@ -270,10 +318,10 @@ namespace Tensorflow | |||||
| #if _REGEN | #if _REGEN | ||||
| #region Compute | #region Compute | ||||
| switch (_dtype.as_numpy_datatype().GetTypeCode()) | |||||
| switch (_dtype.as_numpy_dtype().GetTypeCode()) | |||||
| { | { | ||||
| %foreach supported_dtypes,supported_dtypes_lowercase% | %foreach supported_dtypes,supported_dtypes_lowercase% | ||||
| case NPTypeCode.#1:new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| % | % | ||||
| default: | default: | ||||
| throw new NotSupportedException(); | throw new NotSupportedException(); | ||||
| @@ -283,17 +331,18 @@ namespace Tensorflow | |||||
| #region Compute | #region Compute | ||||
| switch (_dtype.as_numpy_dtype().GetTypeCode()) | switch (_dtype.as_numpy_dtype().GetTypeCode()) | ||||
| { | { | ||||
| case NPTypeCode.Boolean:new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Byte:new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int16:new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt16:new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int32:new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt32:new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int64:new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt64:new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Char:new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Double:new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Single:new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Boolean: new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Byte: new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int16: new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt16: new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int32: new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt32: new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int64: new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt64: new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Char: new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Double: new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Single: new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To<T> | |||||
| default: | default: | ||||
| throw new NotSupportedException(); | throw new NotSupportedException(); | ||||
| } | } | ||||
| @@ -307,7 +356,6 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Copies the memory of current buffer onto newly allocated array. | /// Copies the memory of current buffer onto newly allocated array. | ||||
| /// </summary> | /// </summary> | ||||