| @@ -233,6 +233,14 @@ namespace Tensorflow | |||
| //Are the types matching? | |||
| if (typeof(T).as_dtype() == _dtype) | |||
| { | |||
| if (NDims == 0 && size == 1) //is it a scalar? | |||
| { | |||
| unsafe | |||
| { | |||
| return new T[] {*(T*) buffer}; | |||
| } | |||
| } | |||
| //types match, no need to perform cast | |||
| var ret = new T[size]; | |||
| unsafe | |||
| @@ -260,6 +268,46 @@ namespace Tensorflow | |||
| { | |||
| //types do not match, need to perform cast | |||
| if (NDims == 0 && size == 1) //is it a scalar? | |||
| { | |||
| unsafe | |||
| { | |||
| #if _REGEN | |||
| #region Compute | |||
| switch (_dtype.as_numpy_dtype().GetTypeCode()) | |||
| { | |||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||
| case NPTypeCode.#1: return new T[] {Converts.ChangeType<T>(*(#2*) buffer, NPTypeCode.#1)}; | |||
| % | |||
| case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)}; | |||
| default: | |||
| throw new NotSupportedException(); | |||
| } | |||
| #endregion | |||
| #else | |||
| #region Compute | |||
| switch (_dtype.as_numpy_dtype().GetTypeCode()) | |||
| { | |||
| case NPTypeCode.Boolean: return new T[] {Converts.ChangeType<T>(*(bool*) buffer, NPTypeCode.Boolean)}; | |||
| case NPTypeCode.Byte: return new T[] {Converts.ChangeType<T>(*(byte*) buffer, NPTypeCode.Byte)}; | |||
| case NPTypeCode.Int16: return new T[] {Converts.ChangeType<T>(*(short*) buffer, NPTypeCode.Int16)}; | |||
| case NPTypeCode.UInt16: return new T[] {Converts.ChangeType<T>(*(ushort*) buffer, NPTypeCode.UInt16)}; | |||
| case NPTypeCode.Int32: return new T[] {Converts.ChangeType<T>(*(int*) buffer, NPTypeCode.Int32)}; | |||
| case NPTypeCode.UInt32: return new T[] {Converts.ChangeType<T>(*(uint*) buffer, NPTypeCode.UInt32)}; | |||
| case NPTypeCode.Int64: return new T[] {Converts.ChangeType<T>(*(long*) buffer, NPTypeCode.Int64)}; | |||
| case NPTypeCode.UInt64: return new T[] {Converts.ChangeType<T>(*(ulong*) buffer, NPTypeCode.UInt64)}; | |||
| case NPTypeCode.Char: return new T[] {Converts.ChangeType<T>(*(char*) buffer, NPTypeCode.Char)}; | |||
| case NPTypeCode.Double: return new T[] {Converts.ChangeType<T>(*(double*) buffer, NPTypeCode.Double)}; | |||
| case NPTypeCode.Single: return new T[] {Converts.ChangeType<T>(*(float*) buffer, NPTypeCode.Single)}; | |||
| case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)}; | |||
| default: | |||
| throw new NotSupportedException(); | |||
| } | |||
| #endregion | |||
| #endif | |||
| } | |||
| } | |||
| var ret = new T[size]; | |||
| unsafe | |||
| { | |||
| @@ -270,10 +318,10 @@ namespace Tensorflow | |||
| #if _REGEN | |||
| #region Compute | |||
| switch (_dtype.as_numpy_datatype().GetTypeCode()) | |||
| switch (_dtype.as_numpy_dtype().GetTypeCode()) | |||
| { | |||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||
| case NPTypeCode.#1:new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| % | |||
| default: | |||
| throw new NotSupportedException(); | |||
| @@ -283,17 +331,18 @@ namespace Tensorflow | |||
| #region Compute | |||
| switch (_dtype.as_numpy_dtype().GetTypeCode()) | |||
| { | |||
| case NPTypeCode.Boolean:new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Byte:new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Int16:new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.UInt16:new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Int32:new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.UInt32:new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Int64:new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.UInt64:new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Char:new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Double:new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Single:new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Boolean: new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Byte: new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Int16: new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.UInt16: new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Int32: new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.UInt32: new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Int64: new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.UInt64: new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Char: new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Double: new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.Single: new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
| case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To<T> | |||
| default: | |||
| throw new NotSupportedException(); | |||
| } | |||
| @@ -307,7 +356,6 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Copies the memory of current buffer onto newly allocated array. | |||
| /// </summary> | |||