Browse Source

Tensor.ToArray<T>(): Added support for cases when tensor is scalar.

tags/v0.12
Eli Belash 6 years ago
parent
commit
52955b3d9d
1 changed files with 62 additions and 14 deletions
  1. +62
    -14
      src/TensorFlowNET.Core/Tensors/Tensor.cs

+ 62
- 14
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -233,6 +233,14 @@ namespace Tensorflow
//Are the types matching?
if (typeof(T).as_dtype() == _dtype)
{
if (NDims == 0 && size == 1) //is it a scalar?
{
unsafe
{
return new T[] {*(T*) buffer};
}
}

//types match, no need to perform cast
var ret = new T[size];
unsafe
@@ -260,6 +268,46 @@ namespace Tensorflow
{
//types do not match, need to perform cast
if (NDims == 0 && size == 1) //is it a scalar?
{
unsafe
{
#if _REGEN
#region Compute
switch (_dtype.as_numpy_dtype().GetTypeCode())
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: return new T[] {Converts.ChangeType<T>(*(#2*) buffer, NPTypeCode.#1)};
%
case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)};
default:
throw new NotSupportedException();
}
#endregion
#else
#region Compute
switch (_dtype.as_numpy_dtype().GetTypeCode())
{
case NPTypeCode.Boolean: return new T[] {Converts.ChangeType<T>(*(bool*) buffer, NPTypeCode.Boolean)};
case NPTypeCode.Byte: return new T[] {Converts.ChangeType<T>(*(byte*) buffer, NPTypeCode.Byte)};
case NPTypeCode.Int16: return new T[] {Converts.ChangeType<T>(*(short*) buffer, NPTypeCode.Int16)};
case NPTypeCode.UInt16: return new T[] {Converts.ChangeType<T>(*(ushort*) buffer, NPTypeCode.UInt16)};
case NPTypeCode.Int32: return new T[] {Converts.ChangeType<T>(*(int*) buffer, NPTypeCode.Int32)};
case NPTypeCode.UInt32: return new T[] {Converts.ChangeType<T>(*(uint*) buffer, NPTypeCode.UInt32)};
case NPTypeCode.Int64: return new T[] {Converts.ChangeType<T>(*(long*) buffer, NPTypeCode.Int64)};
case NPTypeCode.UInt64: return new T[] {Converts.ChangeType<T>(*(ulong*) buffer, NPTypeCode.UInt64)};
case NPTypeCode.Char: return new T[] {Converts.ChangeType<T>(*(char*) buffer, NPTypeCode.Char)};
case NPTypeCode.Double: return new T[] {Converts.ChangeType<T>(*(double*) buffer, NPTypeCode.Double)};
case NPTypeCode.Single: return new T[] {Converts.ChangeType<T>(*(float*) buffer, NPTypeCode.Single)};
case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)};
default:
throw new NotSupportedException();
}
#endregion
#endif
}
}

var ret = new T[size];
unsafe
{
@@ -270,10 +318,10 @@ namespace Tensorflow

#if _REGEN
#region Compute
switch (_dtype.as_numpy_datatype().GetTypeCode())
switch (_dtype.as_numpy_dtype().GetTypeCode())
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1:new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
%
default:
throw new NotSupportedException();
@@ -283,17 +331,18 @@ namespace Tensorflow
#region Compute
switch (_dtype.as_numpy_dtype().GetTypeCode())
{
case NPTypeCode.Boolean:new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Byte:new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int16:new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt16:new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int32:new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt32:new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int64:new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt64:new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Char:new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Double:new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Single:new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Boolean: new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Byte: new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int16: new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt16: new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int32: new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt32: new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int64: new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt64: new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Char: new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Double: new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Single: new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To<T>
default:
throw new NotSupportedException();
}
@@ -307,7 +356,6 @@ namespace Tensorflow
}
}


/// <summary>
/// Copies the memory of current buffer onto newly allocated array.
/// </summary>


Loading…
Cancel
Save