Browse Source

TFE_Py_FastPathExecute

tags/v0.20
Oceania2018 5 years ago
parent
commit
100d769284
13 changed files with 423 additions and 254 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +4
    -0
      src/TensorFlowNET.Core/Eager/EagerTensor.cs
  3. +50
    -12
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  4. +92
    -0
      src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs
  5. +11
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  6. +13
    -1
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  7. +4
    -0
      src/TensorFlowNET.Core/TensorFlow.Binding.csproj
  8. +234
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
  9. +0
    -213
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  11. +2
    -0
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  12. +2
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  13. +7
    -27
      src/TensorFlowNET.Core/Tensors/tf.constant.cs

+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -41,6 +41,9 @@ namespace Tensorflow
public Tensor asin(Tensor x, string name = null) public Tensor asin(Tensor x, string name = null)
=> gen_math_ops.asin(x, name); => gen_math_ops.asin(x, name);


public Tensor add(Tensor a, Tensor b, string name = null)
=> gen_math_ops.add(a, b, name: name);

public Tensor add<Tx, Ty>(Tx a, Ty b, string name = null) public Tensor add<Tx, Ty>(Tx a, Ty b, string name = null)
=> gen_math_ops.add(a, b, name: name); => gen_math_ops.add(a, b, name: name);




+ 4
- 0
src/TensorFlowNET.Core/Eager/EagerTensor.cs View File

@@ -14,6 +14,10 @@ namespace Tensorflow.Eager
{ {
} }


public EagerTensor(int value, string device_name) : base(value)
{
}

public override string ToString() public override string ToString()
{ {
switch (rank) switch (rank)


+ 50
- 12
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -10,14 +10,14 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <returns>TFE_ContextOptions*</returns> /// <returns>TFE_ContextOptions*</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_NewContextOptions();
internal static extern IntPtr TFE_NewContextOptions();


/// <summary> /// <summary>
/// Destroy an options object. /// Destroy an options object.
/// </summary> /// </summary>
/// <param name="options">TFE_ContextOptions*</param> /// <param name="options">TFE_ContextOptions*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TFE_DeleteContextOptions(IntPtr options);
internal static extern void TFE_DeleteContextOptions(IntPtr options);


/// <summary> /// <summary>
/// ///
@@ -26,14 +26,14 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns>TFE_Context*</returns> /// <returns>TFE_Context*</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_NewContext(IntPtr opts, IntPtr status);
internal static extern IntPtr TFE_NewContext(IntPtr opts, IntPtr status);


/// <summary> /// <summary>
/// ///
/// </summary> /// </summary>
/// <param name="ctx">TFE_Context*</param> /// <param name="ctx">TFE_Context*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TFE_DeleteContext(IntPtr ctx);
internal static extern void TFE_DeleteContext(IntPtr ctx);


/// <summary> /// <summary>
/// Execute the operation defined by 'op' and return handles to computed /// Execute the operation defined by 'op' and return handles to computed
@@ -44,7 +44,7 @@ namespace Tensorflow
/// <param name="num_retvals">int*</param> /// <param name="num_retvals">int*</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TFE_Execute(IntPtr op, IntPtr retvals, int[] num_retvals, IntPtr status);
internal static extern void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, IntPtr status);


/// <summary> /// <summary>
/// ///
@@ -54,14 +54,14 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status);
internal static extern IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status);


/// <summary> /// <summary>
/// ///
/// </summary> /// </summary>
/// <param name="op">TFE_Op*</param> /// <param name="op">TFE_Op*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TFE_DeleteOp(IntPtr op);
internal static extern void TFE_DeleteOp(IntPtr op);


/// <summary> /// <summary>
/// ///
@@ -70,7 +70,10 @@ namespace Tensorflow
/// <param name="attr_name">const char*</param> /// <param name="attr_name">const char*</param>
/// <param name="value">TF_DataType</param> /// <param name="value">TF_DataType</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value);
internal static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value);

[DllImport(TensorFlowLibName)]
internal static extern void TFE_OpSetAttrInt(IntPtr op, string attr_name, long value);


/// <summary> /// <summary>
/// ///
@@ -81,7 +84,7 @@ namespace Tensorflow
/// <param name="num_dims">const int</param> /// <param name="num_dims">const int</param>
/// <param name="out_status">TF_Status*</param> /// <param name="out_status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, Status out_status);
internal static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, Status out_status);


/// <summary> /// <summary>
/// ///
@@ -91,7 +94,16 @@ namespace Tensorflow
/// <param name="value">const void*</param> /// <param name="value">const void*</param>
/// <param name="length">size_t</param> /// <param name="length">size_t</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length);
internal static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length);

/// <summary>
///
/// </summary>
/// <param name="op"></param>
/// <param name="device_name"></param>
/// <param name="status"></param>
[DllImport(TensorFlowLibName)]
internal static extern void TFE_OpSetDevice(IntPtr op, string device_name, IntPtr status);


/// <summary> /// <summary>
/// ///
@@ -100,7 +112,7 @@ namespace Tensorflow
/// <param name="h">TFE_TensorHandle*</param> /// <param name="h">TFE_TensorHandle*</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TFE_OpAddInput(IntPtr op, IntPtr h, IntPtr status);
internal static extern void TFE_OpAddInput(IntPtr op, IntPtr h, IntPtr status);


/// <summary> /// <summary>
/// ///
@@ -108,6 +120,32 @@ namespace Tensorflow
/// <param name="t">const tensorflow::Tensor&</param> /// <param name="t">const tensorflow::Tensor&</param>
/// <returns>TFE_TensorHandle*</returns> /// <returns>TFE_TensorHandle*</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_NewTensorHandle(IntPtr t);
internal static extern IntPtr TFE_NewTensorHandle(IntPtr t, IntPtr status);

/// <summary>
///
/// </summary>
/// <param name="t"></param>
/// <param name="status"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
internal static extern IntPtr TFE_DeleteTensorHandle(IntPtr t, IntPtr status);

/// <summary>
///
/// </summary>
/// <param name="h">TFE_TensorHandle*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
internal static extern TF_DataType TFE_TensorHandleDataType(IntPtr h);

/// <summary>
/// This function will block till the operation that produces `h` has completed.
/// </summary>
/// <param name="h">TFE_TensorHandle*</param>
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
internal static extern int TFE_TensorHandleNumDims(IntPtr h, IntPtr status);
} }
} }

+ 92
- 0
src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs View File

@@ -1,5 +1,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System;
using static Tensorflow.OpDef.Types;


namespace Tensorflow.Eager namespace Tensorflow.Eager
{ {
@@ -8,6 +10,96 @@ namespace Tensorflow.Eager
/// </summary> /// </summary>
public class pywrap_tfe_src public class pywrap_tfe_src
{ {
public static EagerTensor TFE_Py_FastPathExecute(Context ctx,
string device_name,
string opName,
string name,
params Tensor[] inputs)
{
IntPtr op = IntPtr.Zero;
var attr_list_sizes = new Dictionary<string, int>();
using (var status = new Status())
{
op = c_api.TFE_NewOp(ctx, opName, status);

var op_def = Graph.TFE_GetOpDef(opName);

// SetOpAttrWithDefaults
c_api.TFE_OpSetDevice(op, "", status);

for (int i = 0; i < op_def.InputArg.Count; i++)
{
var input_arg = op_def.InputArg[i];
if (!string.IsNullOrEmpty(input_arg.NumberAttr))
{
c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, 0);
attr_list_sizes[input_arg.NumberAttr] = 0;
}
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
{

}
else
{
// The item is a single item.
AddInputToOp(inputs[i], true, input_arg, op, status);
}
}

int num_retvals = 0;
for (int i = 0; i < op_def.OutputArg.Count; i++)
{
var output_arg = op_def.OutputArg[i];
var delta = 1;
if (!string.IsNullOrEmpty(output_arg.NumberAttr))
delta = attr_list_sizes[output_arg.NumberAttr];
else if (!string.IsNullOrEmpty(output_arg.TypeListAttr))
delta = attr_list_sizes[output_arg.TypeListAttr];
if(delta < 0)
throw new RuntimeError("Attributes suggest that the size of an output list is less than 0");
num_retvals += delta;
}

var retVals = new IntPtr[num_retvals];
c_api.TFE_Execute(op, retVals, ref num_retvals, status);

var h = c_api.TFE_NewTensorHandle(retVals[0], status);
var data = new Tensor(h);
status.Check(true);
}

throw new NotImplementedException("");
}

/// <summary>
/// Adds input and type attr to the op, and to the list of flattened
/// inputs/attrs.
/// </summary>
/// <param name="inputs"></param>
/// <param name="add_type_attr"></param>
/// <param name="input_arg"></param>
/// <param name="op"></param>
/// <param name="status"></param>
/// <returns></returns>
private static bool AddInputToOp(Tensor input,
bool add_type_attr,
ArgDef input_arg,
IntPtr op,
Status status)
{
var input_handle = c_api.TFE_NewTensorHandle(input, status);

if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr))
{
var dtype = c_api.TFE_TensorHandleDataType(input_handle);
c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype);
}

c_api.TFE_OpAddInput(op, input_handle, status);
status.Check(true);
return true;
}

public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = null) public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = null)
{ {
var input_ids = inputs.Select(x => x.Id).ToArray(); var input_ids = inputs.Select(x => x.Id).ToArray();


+ 11
- 0
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -35,6 +35,17 @@ namespace Tensorflow
} }
} }


public static OpDef TFE_GetOpDef(string type)
{
IntPtr handle = tf.get_default_graph();
using (var buffer = new Buffer())
using (var status = new Status())
{
c_api.TF_GraphGetOpDef(handle, type, buffer, status);
return OpDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
}
}

public OperationDescription NewOperation(string opType, string opName) public OperationDescription NewOperation(string opType, string opName)
{ {
return c_api.TF_NewOperation(_handle, opType, opName); return c_api.TF_NewOperation(_handle, opType, opName);


+ 13
- 1
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -14,6 +14,7 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Tensorflow.Eager;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
@@ -139,9 +140,20 @@ namespace Tensorflow
return _op.outputs[0]; return _op.outputs[0];
} }


public static EagerTensor add(Tensor x, Tensor y, string name = null)
{
// _op_def_lib._apply_op_helper("Add", name, args: new { x, y });

if (tf.context.executing_eagerly())
{
var _result = pywrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Add", name, new[] { x, y });
}

return null;
}

public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null) public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null)
{ {
// forward_compatible(2019, 6, 25):
var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y });


return _op.output; return _op.output;


+ 4
- 0
src/TensorFlowNET.Core/TensorFlow.Binding.csproj View File

@@ -71,4 +71,8 @@ https://tensorflownet.readthedocs.io</Description>
<ItemGroup> <ItemGroup>
<Folder Include="Keras\Initializers\" /> <Folder Include="Keras\Initializers\" />
</ItemGroup> </ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp.Lite\NumSharp\NumSharp.Lite.csproj" />
</ItemGroup>
</Project> </Project>

+ 234
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Value.cs View File

@@ -0,0 +1,234 @@
using NumSharp;
using NumSharp.Backends.Unmanaged;
using NumSharp.Utilities;
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
public partial class Tensor
{
[Obsolete("Please use ToArray<T>() instead.", false)]
public T[] Data<T>() where T : unmanaged
{
return ToArray<T>();
}

/// <summary>
///
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public T[] ToArray<T>() where T : unmanaged
{
//Are the types matching?
if (typeof(T).as_dtype() == dtype)
{
if (NDims == 0 && size == 1) //is it a scalar?
{
unsafe
{
return new T[] { *(T*)buffer };
}
}

//types match, no need to perform cast
var ret = new T[size];
unsafe
{
var len = (long)size;
fixed (T* dst = ret)
{
//T can only be unmanaged, I believe it is safe to say that MemoryCopy is valid for all cases this method can be called.
var src = (T*)buffer;
len *= ((long)itemsize);
System.Buffer.MemoryCopy(src, dst, len, len);
}
}

return ret;
}
else
{

//types do not match, need to perform cast
if (NDims == 0 && size == 1) //is it a scalar?
{
unsafe
{
#if _REGEN
#region Compute
switch (dtype.as_numpy_dtype().GetTypeCode())
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: return new T[] {Converts.ChangeType<T>(*(#2*) buffer)};
%
case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this)};
default:
throw new NotSupportedException();
}
#endregion
#else
#region Compute
switch (dtype.as_numpy_dtype().GetTypeCode())
{
case NPTypeCode.Boolean: return new T[] { Converts.ChangeType<T>(*(bool*)buffer) };
case NPTypeCode.Byte: return new T[] { Converts.ChangeType<T>(*(byte*)buffer) };
case NPTypeCode.Int16: return new T[] { Converts.ChangeType<T>(*(short*)buffer) };
case NPTypeCode.UInt16: return new T[] { Converts.ChangeType<T>(*(ushort*)buffer) };
case NPTypeCode.Int32: return new T[] { Converts.ChangeType<T>(*(int*)buffer) };
case NPTypeCode.UInt32: return new T[] { Converts.ChangeType<T>(*(uint*)buffer) };
case NPTypeCode.Int64: return new T[] { Converts.ChangeType<T>(*(long*)buffer) };
case NPTypeCode.UInt64: return new T[] { Converts.ChangeType<T>(*(ulong*)buffer) };
case NPTypeCode.Char: return new T[] { Converts.ChangeType<T>(*(char*)buffer) };
case NPTypeCode.Double: return new T[] { Converts.ChangeType<T>(*(double*)buffer) };
case NPTypeCode.Single: return new T[] { Converts.ChangeType<T>(*(float*)buffer) };
case NPTypeCode.String: return new T[] { Converts.ChangeType<T>((string)this) };
default:
throw new NotSupportedException();
}
#endregion
#endif
}
}

var ret = new T[size];
unsafe
{
var len = (long)size;
fixed (T* dstRet = ret)
{
T* dst = dstRet; //local stack copy

#if _REGEN
#region Compute
switch (dtype.as_numpy_dtype().GetTypeCode())
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
%
default:
throw new NotSupportedException();
}
#endregion
#else
#region Compute
switch (dtype.as_numpy_dtype().GetTypeCode())
{
case NPTypeCode.Boolean: new UnmanagedMemoryBlock<bool>((bool*)buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Byte: new UnmanagedMemoryBlock<byte>((byte*)buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int16: new UnmanagedMemoryBlock<short>((short*)buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt16: new UnmanagedMemoryBlock<ushort>((ushort*)buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int32: new UnmanagedMemoryBlock<int>((int*)buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt32: new UnmanagedMemoryBlock<uint>((uint*)buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int64: new UnmanagedMemoryBlock<long>((long*)buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt64: new UnmanagedMemoryBlock<ulong>((ulong*)buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Char: new UnmanagedMemoryBlock<char>((char*)buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Double: new UnmanagedMemoryBlock<double>((double*)buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Single: new UnmanagedMemoryBlock<float>((float*)buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To<T>
default:
throw new NotSupportedException();
}
#endregion
#endif

}
}

return ret;
}
}

/// <summary>
/// Copy of the contents of this Tensor into a NumPy array or scalar.
/// </summary>
/// <returns>
/// A NumPy array of the same shape and dtype or a NumPy scalar, if this
/// Tensor has rank 0.
/// </returns>
public NDArray numpy()
{
if(NDims == 0)
{
return GetScalar(dtype);
}
else
{
throw new NotImplementedException("numpy not implemented when ndim > 0");
}
}

private unsafe NDArray GetScalar(TF_DataType dtype)
{
switch(dtype)
{
case TF_DataType.TF_STRING:
return StringData()[0];
case TF_DataType.TF_INT32:
return *(int*)buffer;
default:
return BufferToArray();
}
}

/// <summary>
/// Copies the memory of current buffer onto newly allocated array.
/// </summary>
/// <returns></returns>
public unsafe byte[] BufferToArray()
{
// ReSharper disable once LocalVariableHidesMember
var bytesize = (long)this.bytesize;
var data = new byte[bytesize];
fixed (byte* dst = data)
System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize);

return data;
}

/// <summary>
/// Extracts string array from current Tensor.
/// </summary>
/// <exception cref="InvalidOperationException">When <see cref="dtype"/> != TF_DataType.TF_STRING</exception>
public unsafe string[] StringData()
{
if (dtype != TF_DataType.TF_STRING)
throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})");

//
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
// [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes]
//
long size = 1;
foreach (var s in TensorShape.dims)
size *= s;

var buffer = new byte[size][];
var src = c_api.TF_TensorData(_handle);
var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize);
src += (int)(size * 8);
using (var status = new Status())
{
for (int i = 0; i < buffer.Length; i++)
{
IntPtr dst = IntPtr.Zero;
UIntPtr dstLen = UIntPtr.Zero;
var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status);
status.Check(true);
buffer[i] = new byte[(int)dstLen];
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
src += (int)read;
}
}

var _str = new string[buffer.Length];
for (int i = 0; i < _str.Length; i++)
_str[i] = Encoding.UTF8.GetString(buffer[i]);

return _str;
}
}
}

+ 0
- 213
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -20,13 +20,9 @@ using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Globalization; using System.Globalization;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;
using NumSharp.Utilities;
using Tensorflow.Framework; using Tensorflow.Framework;
#if SERIALIZABLE #if SERIALIZABLE
using Newtonsoft.Json; using Newtonsoft.Json;
@@ -249,215 +245,6 @@ namespace Tensorflow
return _tf_output.Value; return _tf_output.Value;
} }


[Obsolete("Please use ToArray<T>() instead.", false)]
public T[] Data<T>() where T : unmanaged
{
return ToArray<T>();
}

/// <summary>
///
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public T[] ToArray<T>() where T : unmanaged
{
//Are the types matching?
if (typeof(T).as_dtype() == dtype)
{
if (NDims == 0 && size == 1) //is it a scalar?
{
unsafe
{
return new T[] {*(T*) buffer};
}
}

//types match, no need to perform cast
var ret = new T[size];
unsafe
{
var len = (long) size;
fixed (T* dst = ret)
{
//T can only be unmanaged, I believe it is safe to say that MemoryCopy is valid for all cases this method can be called.
var src = (T*) buffer;
len *= ((long) itemsize);
System.Buffer.MemoryCopy(src, dst, len, len);
}
}

return ret;
} else
{
//types do not match, need to perform cast
if (NDims == 0 && size == 1) //is it a scalar?
{
unsafe
{
#if _REGEN
#region Compute
switch (dtype.as_numpy_dtype().GetTypeCode())
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: return new T[] {Converts.ChangeType<T>(*(#2*) buffer)};
%
case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this)};
default:
throw new NotSupportedException();
}
#endregion
#else
#region Compute
switch (dtype.as_numpy_dtype().GetTypeCode())
{
case NPTypeCode.Boolean: return new T[] {Converts.ChangeType<T>(*(bool*) buffer)};
case NPTypeCode.Byte: return new T[] {Converts.ChangeType<T>(*(byte*) buffer)};
case NPTypeCode.Int16: return new T[] {Converts.ChangeType<T>(*(short*) buffer)};
case NPTypeCode.UInt16: return new T[] {Converts.ChangeType<T>(*(ushort*) buffer)};
case NPTypeCode.Int32: return new T[] {Converts.ChangeType<T>(*(int*) buffer)};
case NPTypeCode.UInt32: return new T[] {Converts.ChangeType<T>(*(uint*) buffer)};
case NPTypeCode.Int64: return new T[] {Converts.ChangeType<T>(*(long*) buffer)};
case NPTypeCode.UInt64: return new T[] {Converts.ChangeType<T>(*(ulong*) buffer)};
case NPTypeCode.Char: return new T[] {Converts.ChangeType<T>(*(char*) buffer)};
case NPTypeCode.Double: return new T[] {Converts.ChangeType<T>(*(double*) buffer)};
case NPTypeCode.Single: return new T[] {Converts.ChangeType<T>(*(float*) buffer)};
case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this)};
default:
throw new NotSupportedException();
}
#endregion
#endif
}
}

var ret = new T[size];
unsafe
{
var len = (long) size;
fixed (T* dstRet = ret)
{
T* dst = dstRet; //local stack copy

#if _REGEN
#region Compute
switch (dtype.as_numpy_dtype().GetTypeCode())
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
%
default:
throw new NotSupportedException();
}
#endregion
#else
#region Compute
switch (dtype.as_numpy_dtype().GetTypeCode())
{
case NPTypeCode.Boolean: new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Byte: new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int16: new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt16: new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int32: new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt32: new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int64: new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt64: new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Char: new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Double: new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Single: new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To<T>
default:
throw new NotSupportedException();
}
#endregion
#endif
}
}
return ret;
}
}

/// <summary>
/// Copy of the contents of this Tensor into a NumPy array or scalar.
/// </summary>
/// <returns>
/// A NumPy array of the same shape and dtype or a NumPy scalar, if this
/// Tensor has rank 0.
/// </returns>
public NDArray numpy()
{
switch (dtype)
{
case TF_DataType.TF_STRING:
return StringData()[0];
default:
return BufferToArray();
}
}

/// <summary>
/// Copies the memory of current buffer onto newly allocated array.
/// </summary>
/// <returns></returns>
public byte[] BufferToArray()
{
unsafe
{
// ReSharper disable once LocalVariableHidesMember
var bytesize = (long)this.bytesize;
var data = new byte[bytesize];
fixed (byte* dst = data)
System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize);

return data;
}
}

/// <summary>
/// Extracts string array from current Tensor.
/// </summary>
/// <exception cref="InvalidOperationException">When <see cref="dtype"/> != TF_DataType.TF_STRING</exception>
public unsafe string[] StringData()
{
if (dtype != TF_DataType.TF_STRING)
throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})");

//
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
// [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes]
//
long size = 1;
foreach (var s in TensorShape.dims)
size *= s;

var buffer = new byte[size][];
var src = c_api.TF_TensorData(_handle);
var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize);
src += (int) (size * 8);
for (int i = 0; i < buffer.Length; i++)
{
using (var status = new Status())
{
IntPtr dst = IntPtr.Zero;
UIntPtr dstLen = UIntPtr.Zero;
var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status);
status.Check(true);
buffer[i] = new byte[(int) dstLen];
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
src += (int) read;
}
}

var _str = new string[buffer.Length];
for (int i = 0; i < _str.Length; i++)
_str[i] = Encoding.UTF8.GetString(buffer[i]);

return _str;
}

public Tensor MaybeMove() public Tensor MaybeMove()
{ {
var tensor = c_api.TF_TensorMaybeMove(_handle); var tensor = c_api.TF_TensorMaybeMove(_handle);


+ 1
- 1
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -252,7 +252,7 @@ namespace Tensorflow
public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone());
public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone());
public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes
public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes
public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);


public static explicit operator int(TensorShape shape) => shape.size; public static explicit operator int(TensorShape shape) => shape.size;


+ 2
- 0
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -86,6 +86,8 @@ namespace Tensorflow
{ {
case string str: case string str:
return new EagerTensor(str, ctx.device_name); return new EagerTensor(str, ctx.device_name);
case int int32:
return new EagerTensor(int32, ctx.device_name);
default: default:
throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}"); throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}");
} }


+ 2
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -201,6 +201,8 @@ namespace Tensorflow
=> type switch => type switch
{ {
TF_DataType.TF_STRING => "string", TF_DataType.TF_STRING => "string",
TF_DataType.TF_INT32 => "int32",
TF_DataType.TF_FLOAT => "float32",
_ => type.ToString() _ => type.ToString()
}; };




+ 7
- 27
src/TensorFlowNET.Core/Tensors/tf.constant.cs View File

@@ -31,33 +31,13 @@ namespace Tensorflow
public Tensor constant(object value, public Tensor constant(object value,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
TensorShape shape = null, TensorShape shape = null,
string name = "Const")
{
switch (value)
{
case string str:
return constant_op._constant_impl(str,
@string,
null,
name,
verify_shape: false,
allow_broadcast: true);
case float val:
return constant_op._constant_impl(value,
float32,
new int[] { (int)shape },
name,
verify_shape: false,
allow_broadcast: true);
default:
return constant_op._constant_impl(value,
dtype,
shape,
name,
verify_shape: false,
allow_broadcast: true);
}
}
string name = "Const")
=> constant_op._constant_impl(value,
dtype,
shape,
name,
verify_shape: false,
allow_broadcast: true);


public Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) public Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
=> array_ops.zeros(shape, dtype, name); => array_ops.zeros(shape, dtype, name);


Loading…
Cancel
Save