
Tensor perf ops and cleanup, revamped dtypes.cs, some renames.

- Cleanup and docs for all Tensor.cs files
- Changed all uses of System.Convert to NumSharp.Utilities.Converts
- Added all missing types in dtypes.cs
- Renamed tensor.Data&lt;T&gt; to tensor.ToArray&lt;T&gt;, marked the old name obsolete
- Renamed tensor.Data() to tensor.BufferToArray(), marked the old name obsolete
- Made GraphKeys use const strings instead of allocating a string on every use of GraphKeys
tags/v0.12
Eli Belash, 6 years ago
commit 2129dbd675
13 changed files with 379 additions and 157 deletions
  1. +1   -1    src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  2. +1   -1    src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  3. +1   -1    src/TensorFlowNET.Core/Operations/array_ops.py.cs
  4. +1   -1    src/TensorFlowNET.Core/Operations/nn_ops.cs
  5. +3   -3    src/TensorFlowNET.Core/Sessions/BaseSession.cs
  6. +2   -1    src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  7. +214 -99   src/TensorFlowNET.Core/Tensors/Tensor.cs
  8. +74  -25   src/TensorFlowNET.Core/Tensors/dtypes.cs
  9. +8   -7    src/TensorFlowNET.Core/Tensors/tensor_util.cs
  10. +70 -14   src/TensorFlowNET.Core/ops.GraphKeys.cs
  11. +2  -2    test/TensorFlowNET.UnitTest/SessionTest.cs
  12. +1  -1    test/TensorFlowNET.UnitTest/TensorTest.cs
  13. +1  -1    test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
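
The renames listed in the commit message are mechanical at call sites. A minimal migration sketch, assuming "using NumSharp;" and "using Tensorflow;" at the top of the file; the variable names are illustrative only:

    // old: tensor.Data<float>()          -> new: tensor.ToArray<float>()
    var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);
    var tensor = new Tensor(nd);
    float[] values = tensor.ToArray<float>();   // typed copy of the tensor's buffer

    // old: tensor.Data()                 -> new: tensor.BufferToArray()
    byte[] raw = tensor.BufferToArray();        // raw byte copy of the buffer

    // old: graphTensor.SetShape(shape)   -> new: graphTensor.set_shape(shape)
    var graphTensor = constant_op.constant(nd);
    graphTensor.set_shape(new TensorShape(new[] { 2, 3 }));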

+1 -1  src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

@@ -141,7 +141,7 @@ namespace Tensorflow.Operations
  data, frame_name, is_constant, parallel_iterations, name: name);

  if (use_input_shape)
- result.SetShape(data.TensorShape);
+ result.set_shape(data.TensorShape);

  return result;
  }


+1 -1  src/TensorFlowNET.Core/Operations/NnOps/rnn.cs

@@ -233,7 +233,7 @@ namespace Tensorflow.Operations
  dims.AddRange(x_static_shape.dims.Skip(2));
  var shape = new TensorShape(dims.ToArray());

- x_t.SetShape(shape);
+ x_t.set_shape(shape);

  return x_t;
  }


+1 -1  src/TensorFlowNET.Core/Operations/array_ops.py.cs

@@ -351,7 +351,7 @@ namespace Tensorflow
  var input_shape = tensor_util.to_shape(input_tensor.shape);
  if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined())
  {
- var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype());
+ var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype());
  return constant_op.constant(nd, name: name);
  }
  }


+1 -1  src/TensorFlowNET.Core/Operations/nn_ops.cs

@@ -98,7 +98,7 @@ namespace Tensorflow
  // float to be selected, hence we use a >= comparison.
  var keep_mask = random_tensor >= rate;
  var ret = x * scale * math_ops.cast(keep_mask, x.dtype);
- ret.SetShape(x.TensorShape);
+ ret.set_shape(x.TensorShape);
  return ret;
  });
  }


+3 -3  src/TensorFlowNET.Core/Sessions/BaseSession.cs

@@ -273,7 +273,7 @@ namespace Tensorflow
  {
  var tensor = new Tensor(output);
  NDArray nd = null;
- Type type = tensor.dtype.as_numpy_datatype();
+ Type type = tensor.dtype.as_numpy_dtype();
  var ndims = tensor.shape;
  var offset = c_api.TF_TensorData(output);

@@ -285,7 +285,7 @@ namespace Tensorflow
  nd = NDArray.Scalar(*(bool*)offset);
  break;
  case TF_DataType.TF_STRING:
- var bytes = tensor.Data();
+ var bytes = tensor.BufferToArray();
  // wired, don't know why we have to start from offset 9.
  // length in the begin
  var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);

@@ -324,7 +324,7 @@ namespace Tensorflow
  nd = np.array(bools).reshape(ndims);
  break;
  case TF_DataType.TF_STRING:
- var bytes = tensor.Data();
+ var bytes = tensor.BufferToArray();
  // wired, don't know why we have to start from offset 9.
  // length in the begin
  var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
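
The "offset 9" in the TF_STRING branches above follows from the string encoding described in Tensor.cs below: a TF_STRING tensor begins with a table of 8-byte offsets, followed by TF_StringEncode-encoded (length-prefixed) bytes. For a single short string, bytes 0-7 are the offset-table entry, byte 8 is the one-byte length prefix, and the payload starts at byte 9. A minimal decoding sketch; the helper name and the single-element, short-string assumptions are not from the diff:

    using System.Text;

    static class TfStringDebug
    {
        // Decodes the raw buffer of a scalar TF_STRING tensor (as returned by BufferToArray()).
        // Assumes exactly one element and a payload under 128 bytes, so the length prefix is a single byte.
        public static string DecodeScalar(byte[] raw)
        {
            int length = raw[8];                           // length prefix right after the 8-byte offset table
            return Encoding.UTF8.GetString(raw, 9, length);
        }
    }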


+2 -1  src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

@@ -549,10 +549,11 @@ namespace Tensorflow
  this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it
  var ptr = new IntPtr(arraySlice.Address);
  int num_bytes = (nd.size * nd.dtypesize);
- var dtype = given_dtype ?? ToTFDataType(nd.dtype);
+ var dtype = given_dtype ?? nd.dtype.as_dtype();
  var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs);
  IsMemoryOwner = false;
  return handle;

  }

  public unsafe Tensor(byte[][] buffer, long[] shape)


+214 -99  src/TensorFlowNET.Core/Tensors/Tensor.cs

@@ -17,9 +17,16 @@
  using NumSharp;
  using System;
  using System.Collections.Generic;
+ using System.Diagnostics.CodeAnalysis;
+ using System.Globalization;
  using System.Linq;
+ using System.Runtime.CompilerServices;
  using System.Runtime.InteropServices;
  using System.Text;
+ using System.Threading.Tasks;
+ using NumSharp.Backends;
+ using NumSharp.Backends.Unmanaged;
+ using NumSharp.Utilities;
  using Tensorflow.Framework;
  using static Tensorflow.Binding;

@@ -29,42 +36,68 @@ namespace Tensorflow
  /// A tensor is a generalization of vectors and matrices to potentially higher dimensions.
  /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
  /// </summary>
+ [SuppressMessage("ReSharper", "ConvertToAutoProperty")]
  public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike
  {
- private int _id;
- private Operation _op;
+ private readonly int _id;
+ private readonly Operation _op;
+ private readonly int _value_index;
+ private TF_Output? _tf_output;
+ private readonly TF_DataType _dtype;

  public int Id => _id;

+ /// <summary>
+ /// The Graph that contains this tensor.
+ /// </summary>
  public Graph graph => op?.graph;

+ /// <summary>
+ /// The Operation that produces this tensor as an output.
+ /// </summary>
  public Operation op => _op;

  public Tensor[] outputs => op.outputs;

  /// <summary>
  /// The string name of this tensor.
  /// </summary>
  public string name => $"{(op == null ? "<unnamed Operation>" : $"{op.name}:{_value_index}")}";

- private int _value_index;
+ /// <summary>
+ /// The index of this tensor in the outputs of its Operation.
+ /// </summary>
  public int value_index => _value_index;

- private TF_DataType _dtype = TF_DataType.DtInvalid;
+ /// <summary>
+ /// The DType of elements in this tensor.
+ /// </summary>
  public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);

  public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);

  public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
  public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;

  public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
  public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
+ public int NDims => rank;

- private TF_Output? _tf_output;
+ /// <summary>
+ /// The name of the device on which this tensor will be produced, or null.
+ /// </summary>
+ public string Device => op.Device;
+
+ public int[] dims => shape;

  /// <summary>
- /// used for keep other pointer when do implicit operating
+ /// Used for keep other pointer when do implicit operating
  /// </summary>
  public object Tag { get; set; }

+ /// <summary>
+ /// Returns the shape of a tensor.
+ /// </summary>
+ /// <remarks>https://www.tensorflow.org/api_docs/python/tf/shape</remarks>
  public int[] shape
  {
  get

@@ -76,14 +109,13 @@ namespace Tensorflow
  var status = new Status();
  c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status);
  status.Check();
- }
- else
+ } else
  {
  for (int i = 0; i < rank; i++)
  dims[i] = c_api.TF_Dim(_handle, i);
  }

- return dims.Select(x => Convert.ToInt32(x)).ToArray();
+ return dims.Select(x => ((IConvertible) x).ToInt32(CultureInfo.InvariantCulture)).ToArray();
  }

  set

@@ -93,38 +125,52 @@ namespace Tensorflow
  if (value == null)
  c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status);
  else
- c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(x => Convert.ToInt64(x)).ToArray(), value.Length, status);
+ c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);
  }
  }

  public int[] _shape_tuple()
  {
- if (shape == null) return null;
- return shape.Select(x => (int)x).ToArray();
+ return (int[]) shape.Clone();
  }

  public TensorShape TensorShape => tensor_util.to_shape(shape);

- public void SetShape(TensorShape shape)
+ /// <summary>
+ /// Updates the shape of this tensor.
+ /// </summary>
+ public void set_shape(TensorShape shape)
+ {
+ this.shape = (int[]) shape.dims.Clone();
+ }
+
+ /// <summary>
+ /// Updates the shape of this tensor.
+ /// </summary>
+ [Obsolete("Please use set_shape(TensorShape shape) instead.", false)]
+ public void SetShape(TensorShape shape)
  {
- this.shape = shape.dims;
+ this.shape = (int[]) shape.dims.Clone();
  }

+ /// <summary>
+ /// Updates the shape of this tensor.
+ /// </summary>
  public void set_shape(Tensor shape)
  {
+ // ReSharper disable once MergeConditionalExpression
  this.shape = shape is null ? null : shape.shape;
  }

- public int[] dims => shape;

  /// <summary>
- /// number of dimensions
- /// 0 Scalar (magnitude only)
- /// 1 Vector (magnitude and direction)
- /// 2 Matrix (table of numbers)
- /// 3 3-Tensor (cube of numbers)
+ /// number of dimensions <br></br>
+ /// 0 Scalar (magnitude only) <br></br>
+ /// 1 Vector (magnitude and direction) <br></br>
+ /// 2 Matrix (table of numbers) <br></br>
+ /// 3 3-Tensor (cube of numbers) <br></br>
  /// n n-Tensor (you get the idea)
  /// </summary>
+ /// <remarks>https://www.tensorflow.org/api_docs/python/tf/rank</remarks>
  public int rank
  {
  get

@@ -137,17 +183,15 @@ namespace Tensorflow
  status.Check();
  return ndim;
  }
- else
- {
- return c_api.TF_NumDims(_handle);
- }

+ return c_api.TF_NumDims(_handle);
  }
  }

- public int NDims => rank;
- public string Device => op.Device;
+ /// <summary>
+ /// Returns a list of Operations that consume this tensor.
+ /// </summary>
+ /// <returns></returns>
  public Operation[] consumers()
  {
  var output = _as_tf_output();

@@ -157,37 +201,136 @@ namespace Tensorflow

  public TF_Output _as_tf_output()
  {
- if(!_tf_output.HasValue)
+ if (!_tf_output.HasValue)
  _tf_output = new TF_Output(op, value_index);

  return _tf_output.Value;
  }

- public T[] Data<T>()
+ [Obsolete("Please use ToArray<T>() instead.", false)]
+ public T[] Data<T>() where T : unmanaged
+ {
+ return ToArray<T>();
+ }
+
+ /// <summary>
+ ///
+ /// </summary>
+ /// <typeparam name="T"></typeparam>
+ /// <returns></returns>
+ /// <exception cref="ArgumentException">When <typeparam name="T"> is string </typeparam></exception>
+ public T[] ToArray<T>() where T : unmanaged
  {
- // Column major order
- // https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg
- // matrix:[[1, 2, 3], [4, 5, 6]]
- // index: 0 2 4 1 3 5
- // result: 1 4 2 5 3 6
- var data = new T[size];
-
- for (ulong i = 0; i < size; i++)
+ //when T is string
+ if (typeof(T) == typeof(string))
  {
- data[i] = Marshal.PtrToStructure<T>(buffer + (int)(i * itemsize));
+ if (dtype != TF_DataType.TF_STRING)
+ throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string.");
+
+ return (T[]) (object) StringData();
  }

- return data;
+ //Are the types matching?
+ if (typeof(T).as_dtype() == _dtype)
+ {
+ //types match, no need to perform cast
+ var ret = new T[size];
+ unsafe
+ {
+ var len = (long) size;
+ fixed (T* dstRet = ret)
+ {
+ T* dst = dstRet; //local stack copy
+ if (typeof(T).IsPrimitive)
+ {
+ var src = (T*) buffer;
+ len *= ((long) itemsize);
+ System.Buffer.MemoryCopy(src, dst, len, len);
+ } else
+ {
+ var itemsize = (long) this.itemsize;
+ var buffer = this.buffer.ToInt64();
+ Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure<T>(new IntPtr(buffer + i * itemsize)));
+ }
+ }
+ }
+
+ return ret;
+ } else
+ {
+ //types do not match, need to perform cast
+ var ret = new T[size];
+ unsafe
+ {
+ var len = (long) size;
+ fixed (T* dstRet = ret)
+ {
+ T* dst = dstRet; //local stack copy
+
+ #if _REGEN
+ #region Compute
+ switch (_dtype.as_numpy_datatype().GetTypeCode())
+ {
+ %foreach supported_dtypes,supported_dtypes_lowercase%
+ case NPTypeCode.#1:new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
+ %
+ default:
+ throw new NotSupportedException();
+ }
+ #endregion
+ #else
+ #region Compute
+ switch (_dtype.as_numpy_dtype().GetTypeCode())
+ {
+ case NPTypeCode.Boolean:new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
+ case NPTypeCode.Byte:new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
+ case NPTypeCode.Int16:new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
+ case NPTypeCode.UInt16:new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
+ case NPTypeCode.Int32:new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
+ case NPTypeCode.UInt32:new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
+ case NPTypeCode.Int64:new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
+ case NPTypeCode.UInt64:new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
+ case NPTypeCode.Char:new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
+ case NPTypeCode.Double:new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
+ case NPTypeCode.Single:new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
+ default:
+ throw new NotSupportedException();
+ }
+ #endregion
+ #endif
+ }
+ }
+ return ret;
+ }
  }

+ /// <summary>
+ /// Copies the memory of current buffer onto newly allocated array.
+ /// </summary>
+ /// <returns></returns>
+ [Obsolete("Please use set_shape(TensorShape shape) instead.", false)]
  public byte[] Data()
+ {
+ return BufferToArray();
+ }
+
+ /// <summary>
+ /// Copies the memory of current buffer onto newly allocated array.
+ /// </summary>
+ /// <returns></returns>
+ public byte[] BufferToArray()
  {
  var data = new byte[bytesize];
- Marshal.Copy(buffer, data, 0, (int)bytesize);
+ Marshal.Copy(buffer, data, 0, (int) bytesize);
  return data;
  }

- public unsafe string[] StringData()
+ /// Used internally in ToArray&lt;T&gt;
+ private unsafe string[] StringData()
  {
  //
  // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.

@@ -199,19 +342,19 @@ namespace Tensorflow

  var buffer = new byte[size][];
  var src = c_api.TF_TensorData(_handle);
- var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize);
- src += (int)(size * 8);
+ var srcLen = (IntPtr) (src.ToInt64() + (long) bytesize);
+ src += (int) (size * 8);
  for (int i = 0; i < buffer.Length; i++)
  {
  using (var status = new Status())
  {
  IntPtr dst = IntPtr.Zero;
  UIntPtr dstLen = UIntPtr.Zero;
- var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status);
+ var read = c_api.TF_StringDecode((byte*) src, (UIntPtr) (srcLen.ToInt64() - src.ToInt64()), (byte**) &dst, &dstLen, status);
  status.Check(true);
- buffer[i] = new byte[(int)dstLen];
+ buffer[i] = new byte[(int) dstLen];
  Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
- src += (int)read;
+ src += (int) read;
  }
  }

@@ -229,51 +372,29 @@ namespace Tensorflow
  }

  /// <summary>
  /// Evaluates this tensor in a `Session`.
  /// </summary>
  /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param>
- /// <param name="session">The `Session` to be used to evaluate this tensor.</param>
- /// <returns></returns>
+ /// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns>
  public NDArray eval(params FeedItem[] feed_dict)
  {
  return ops._eval_using_default_session(this, feed_dict, graph);
  }

+ /// <summary>
+ /// Evaluates this tensor in a `Session`.
+ /// </summary>
+ /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param>
+ /// <param name="session">The `Session` to be used to evaluate this tensor.</param>
+ /// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns>
  public NDArray eval(Session session, FeedItem[] feed_dict = null)
  {
  return ops._eval_using_default_session(this, feed_dict, graph, session);
  }

- public TF_DataType ToTFDataType(Type type)
- {
- switch (type.Name)
- {
- case "Char":
- return TF_DataType.TF_UINT8;
- case "Int16":
- return TF_DataType.TF_INT16;
- case "Int32":
- return TF_DataType.TF_INT32;
- case "Int64":
- return TF_DataType.TF_INT64;
- case "Single":
- return TF_DataType.TF_FLOAT;
- case "Double":
- return TF_DataType.TF_DOUBLE;
- case "Byte":
- return TF_DataType.TF_UINT8;
- case "String":
- return TF_DataType.TF_STRING;
- case "Boolean":
- return TF_DataType.TF_BOOL;
- default:
- throw new NotImplementedException("ToTFDataType error");
- }
- }
-
  public Tensor slice(Slice slice)
  {
- var slice_spec = new int[] { slice.Start.Value };
+ var slice_spec = new int[] {slice.Start.Value};
  var begin = new List<int>();
  var end = new List<int>();
  var strides = new List<int>();

@@ -289,26 +410,26 @@ namespace Tensorflow
  if (slice.Stop.HasValue)
  {
  end.Add(slice.Stop.Value);
- }
- else
+ } else
  {
  end.Add(0);
  end_mask |= (1 << index);
  }

  strides.Add(slice.Step);

  index += 1;
  }

- return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
+ return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope =>
  {
  string name = scope;
  if (begin != null)
  {
  var (packed_begin, packed_end, packed_strides) =
  (array_ops.stack(begin.ToArray()),
  array_ops.stack(end.ToArray()),
  array_ops.stack(strides.ToArray()));

  return gen_array_ops.strided_slice(
  this,

@@ -320,7 +441,6 @@ namespace Tensorflow
  shrink_axis_mask: shrink_axis_mask,
  new_axis_mask: new_axis_mask,
  ellipsis_mask: ellipsis_mask,
  name: name);
  }

@@ -330,7 +450,7 @@ namespace Tensorflow

  public Tensor slice(int start)
  {
- var slice_spec = new int[] { start };
+ var slice_spec = new int[] {start};
  var begin = new List<int>();
  var end = new List<int>();
  var strides = new List<int>();

@@ -349,15 +469,15 @@ namespace Tensorflow
  index += 1;
  }

- return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
+ return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope =>
  {
  string name = scope;
  if (begin != null)
  {
  var (packed_begin, packed_end, packed_strides) =
  (array_ops.stack(begin.ToArray()),
  array_ops.stack(end.ToArray()),
  array_ops.stack(strides.ToArray()));

  return gen_array_ops.strided_slice(
  this,

@@ -369,7 +489,6 @@ namespace Tensorflow
  shrink_axis_mask: shrink_axis_mask,
  new_axis_mask: new_axis_mask,
  ellipsis_mask: ellipsis_mask,
  name: name);
  }

@@ -392,13 +511,9 @@ namespace Tensorflow
  return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}";
  }

- //protected override void DisposeManagedState()
- //{
- //}
-
  protected override void DisposeUnmanagedResources(IntPtr handle)
  {
- if(handle != IntPtr.Zero)
+ if (handle != IntPtr.Zero)
  {
  c_api.TF_DeleteTensor(handle);
  }

@@ -417,4 +532,4 @@ namespace Tensorflow

  public int tensor_int_val { get; set; }
  }
  }
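
The new ToArray&lt;T&gt; shown above has two paths: a single memory copy when T matches the tensor's dtype, and a per-element cast through UnmanagedMemoryBlock when it does not. A short usage sketch (assumes "using NumSharp;" and "using Tensorflow;", and a tensor that reports TF_FLOAT):

    var tensor = new Tensor(np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3));

    float[] fast = tensor.ToArray<float>();    // T matches the dtype: one Buffer.MemoryCopy of the native buffer
    double[] cast = tensor.ToArray<double>();  // T differs: elements are cast via UnmanagedMemoryBlock<float> -> <double>

    // The obsolete overloads still compile but simply forward to the new ones:
    // tensor.Data<float>() -> ToArray<float>(),  tensor.Data() -> BufferToArray()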

+74 -25  src/TensorFlowNET.Core/Tensors/dtypes.cs

@@ -15,6 +15,8 @@
  ******************************************************************************/

  using System;
+ using System.Numerics;
+ using NumSharp.Backends;

  namespace Tensorflow
  {

@@ -23,35 +25,100 @@ namespace Tensorflow
  public static TF_DataType int8 = TF_DataType.TF_INT8;
  public static TF_DataType int32 = TF_DataType.TF_INT32;
  public static TF_DataType int64 = TF_DataType.TF_INT64;
+ public static TF_DataType uint8 = TF_DataType.TF_UINT8;
+ public static TF_DataType uint32 = TF_DataType.TF_UINT32;
+ public static TF_DataType uint64 = TF_DataType.TF_UINT64;
  public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
  public static TF_DataType float16 = TF_DataType.TF_HALF;
  public static TF_DataType float64 = TF_DataType.TF_DOUBLE;

- public static Type as_numpy_datatype(this TF_DataType type)
+ /// <summary>
+ ///
+ /// </summary>
+ /// <param name="type"></param>
+ /// <returns><see cref="System.Type"/> equivalent to <paramref name="type"/>, if none exists, returns null.</returns>
+ public static Type as_numpy_dtype(this TF_DataType type)
  {
  switch (type)
  {
  case TF_DataType.TF_BOOL:
  return typeof(bool);
+ case TF_DataType.TF_UINT8:
+ return typeof(byte);
  case TF_DataType.TF_INT64:
  return typeof(long);
+ case TF_DataType.TF_UINT64:
+ return typeof(ulong);
  case TF_DataType.TF_INT32:
  return typeof(int);
+ case TF_DataType.TF_UINT32:
+ return typeof(uint);
  case TF_DataType.TF_INT16:
  return typeof(short);
+ case TF_DataType.TF_UINT16:
+ return typeof(ushort);
  case TF_DataType.TF_FLOAT:
  return typeof(float);
  case TF_DataType.TF_DOUBLE:
  return typeof(double);
  case TF_DataType.TF_STRING:
  return typeof(string);
+ case TF_DataType.TF_COMPLEX128:
+ case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX
+ return typeof(Complex);
  default:
  return null;
  }
  }

- // "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"
- public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null)
+ /// <summary>
+ ///
+ /// </summary>
+ /// <param name="type"></param>
+ /// <returns></returns>
+ /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="NPTypeCode"/></exception>
+ public static NPTypeCode as_numpy_typecode(this TF_DataType type)
+ {
+ switch (type)
+ {
+ case TF_DataType.TF_BOOL:
+ return NPTypeCode.Boolean;
+ case TF_DataType.TF_UINT8:
+ return NPTypeCode.Byte;
+ case TF_DataType.TF_INT64:
+ return NPTypeCode.Int64;
+ case TF_DataType.TF_INT32:
+ return NPTypeCode.Int32;
+ case TF_DataType.TF_INT16:
+ return NPTypeCode.Int16;
+ case TF_DataType.TF_UINT64:
+ return NPTypeCode.UInt64;
+ case TF_DataType.TF_UINT32:
+ return NPTypeCode.UInt32;
+ case TF_DataType.TF_UINT16:
+ return NPTypeCode.UInt16;
+ case TF_DataType.TF_FLOAT:
+ return NPTypeCode.Single;
+ case TF_DataType.TF_DOUBLE:
+ return NPTypeCode.Double;
+ case TF_DataType.TF_STRING:
+ return NPTypeCode.String;
+ case TF_DataType.TF_COMPLEX128:
+ case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX
+ return NPTypeCode.Complex;
+ default:
+ throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode.");
+ }
+ }
+
+ /// <summary>
+ ///
+ /// </summary>
+ /// <param name="type"></param>
+ /// <param name="dtype"></param>
+ /// <returns></returns>
+ /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception>
+ public static TF_DataType as_dtype(this Type type, TF_DataType? dtype = null)
  {
  switch (type.Name)
  {

@@ -98,7 +165,7 @@ namespace Tensorflow
  dtype = TF_DataType.TF_BOOL;
  break;
  default:
- throw new Exception("as_dtype Not Implemented");
+ throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode.");
  }

  return dtype.Value;

@@ -106,16 +173,7 @@ namespace Tensorflow

  public static DataType as_datatype_enum(this TF_DataType type)
  {
- DataType dtype = DataType.DtInvalid;
-
- switch (type)
- {
- default:
- Enum.TryParse(((int)type).ToString(), out dtype);
- break;
- }
-
- return dtype;
+ return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid;
  }

  public static TF_DataType as_base_dtype(this TF_DataType type)

@@ -132,7 +190,7 @@ namespace Tensorflow

  public static Type as_numpy_dtype(this DataType type)
  {
- return type.as_tf_dtype().as_numpy_datatype();
+ return type.as_tf_dtype().as_numpy_dtype();
  }

  public static DataType as_base_dtype(this DataType type)

@@ -144,16 +202,7 @@ namespace Tensorflow

  public static TF_DataType as_tf_dtype(this DataType type)
  {
- TF_DataType dtype = TF_DataType.DtInvalid;
-
- switch (type)
- {
- default:
- Enum.TryParse(((int)type).ToString(), out dtype);
- break;
- }
-
- return dtype;
+ return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid;
  }

  public static TF_DataType as_ref(this TF_DataType type)
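
dtypes.cs now maps in all three directions: TF_DataType to System.Type (as_numpy_dtype), TF_DataType to NumSharp's NPTypeCode (as_numpy_typecode), and System.Type to TF_DataType (as_dtype, now an extension method). A minimal sketch of the three calls (assumes "using Tensorflow;" and "using NumSharp.Backends;"):

    TF_DataType tfType = typeof(float).as_dtype();                    // TF_DataType.TF_FLOAT
    Type clrType = TF_DataType.TF_INT32.as_numpy_dtype();             // typeof(int); unmapped dtypes return null
    NPTypeCode typeCode = TF_DataType.TF_DOUBLE.as_numpy_typecode();  // NPTypeCode.Double; unmapped dtypes throw NotSupportedException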


+8 -7  src/TensorFlowNET.Core/Tensors/tensor_util.cs

@@ -17,6 +17,7 @@
  using NumSharp;
  using System;
  using System.Linq;
+ using NumSharp.Utilities;

  namespace Tensorflow
  {

@@ -109,7 +110,7 @@ namespace Tensorflow

  // We first convert value to a numpy array or scalar.
  NDArray nparray = null;
- var np_dt = dtype.as_numpy_datatype();
+ var np_dt = dtype.as_numpy_dtype();

  if (values is NDArray nd)
  {

@@ -188,37 +189,37 @@ namespace Tensorflow
  if (values.GetType().IsArray)
  nparray = np.array((int[])values, np_dt);
  else
- nparray = Convert.ToInt32(values);
+ nparray = Converts.ToInt32(values);
  break;
  case "Int64":
  if (values.GetType().IsArray)
  nparray = np.array((int[])values, np_dt);
  else
- nparray = Convert.ToInt64(values);
+ nparray = Converts.ToInt64(values);
  break;
  case "Single":
  if (values.GetType().IsArray)
  nparray = np.array((float[])values, np_dt);
  else
- nparray = Convert.ToSingle(values);
+ nparray = Converts.ToSingle(values);
  break;
  case "Double":
  if (values.GetType().IsArray)
  nparray = np.array((double[])values, np_dt);
  else
- nparray = Convert.ToDouble(values);
+ nparray = Converts.ToDouble(values);
  break;
  case "String":
  if (values.GetType().IsArray)
  nparray = np.array((string[])values, np_dt);
  else
- nparray = NDArray.FromString(Convert.ToString(values));
+ nparray = NDArray.FromString(Converts.ToString(values));
  break;
  case "Boolean":
  if (values.GetType().IsArray)
  nparray = np.array((bool[])values, np_dt);
  else
- nparray = Convert.ToBoolean(values);
+ nparray = Converts.ToBoolean(values);
  break;
  default:
  throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented");
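
As the hunk above shows, NumSharp.Utilities.Converts is used as a drop-in for System.Convert: the arguments stay the same, only the helper class changes. A tiny illustrative sketch (the boxed value is made up):

    using NumSharp.Utilities;

    object boxed = 42;
    long asLong = Converts.ToInt64(boxed);     // was: Convert.ToInt64(boxed)
    float asFloat = Converts.ToSingle(boxed);  // was: Convert.ToSingle(boxed)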


+70 -14  src/TensorFlowNET.Core/ops.GraphKeys.cs

@@ -29,55 +29,111 @@ namespace Tensorflow
  /// </summary>
  public class GraphKeys
  {
+ #region const
+
+ /// <summary>
+ /// the subset of `Variable` objects that will be trained by an optimizer.
+ /// </summary>
+ public const string TRAINABLE_VARIABLES_ = "trainable_variables";
+
+ /// <summary>
+ /// Trainable resource-style variables.
+ /// </summary>
+ public const string TRAINABLE_RESOURCE_VARIABLES_ = "trainable_resource_variables";
+
+ /// <summary>
+ /// Key for streaming model ports.
+ /// </summary>
+ public const string _STREAMING_MODEL_PORTS_ = "streaming_model_ports";
+
+ /// <summary>
+ /// Key to collect losses
+ /// </summary>
+ public const string LOSSES_ = "losses";
+
+ /// <summary>
+ /// Key to collect Variable objects that are global (shared across machines).
+ /// Default collection for all variables, except local ones.
+ /// </summary>
+ public const string GLOBAL_VARIABLES_ = "variables";
+
+ public const string TRAIN_OP_ = "train_op";
+
+ public const string GLOBAL_STEP_ = "global_step";
+
+ public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" };
+ /// <summary>
+ /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
+ /// </summary>
+ public const string SAVEABLE_OBJECTS_ = "saveable_objects";
+ /// <summary>
+ /// Key to collect update_ops
+ /// </summary>
+ public const string UPDATE_OPS_ = "update_ops";
+
+ // Key to collect summaries.
+ public const string SUMMARIES_ = "summaries";
+
+ // Used to store v2 summary names.
+ public const string _SUMMARY_COLLECTION_ = "_SUMMARY_V2";
+
+ // Key for control flow context.
+ public const string COND_CONTEXT_ = "cond_context";
+ public const string WHILE_CONTEXT_ = "while_context";
+
+ #endregion
+
  /// <summary>
  /// the subset of `Variable` objects that will be trained by an optimizer.
  /// </summary>
- public string TRAINABLE_VARIABLES = "trainable_variables";
+ public string TRAINABLE_VARIABLES => TRAINABLE_VARIABLES_;

  /// <summary>
  /// Trainable resource-style variables.
  /// </summary>
- public string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables";
+ public string TRAINABLE_RESOURCE_VARIABLES => TRAINABLE_RESOURCE_VARIABLES_;

  /// <summary>
  /// Key for streaming model ports.
  /// </summary>
- public string _STREAMING_MODEL_PORTS = "streaming_model_ports";
+ public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_;

  /// <summary>
  /// Key to collect losses
  /// </summary>
- public string LOSSES = "losses";
+ public string LOSSES => LOSSES_;

  /// <summary>
  /// Key to collect Variable objects that are global (shared across machines).
  /// Default collection for all variables, except local ones.
  /// </summary>
- public string GLOBAL_VARIABLES = "variables";
+ public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_;

- public string TRAIN_OP = "train_op";
+ public string TRAIN_OP => TRAIN_OP_;

- public string GLOBAL_STEP = "global_step";
+ public string GLOBAL_STEP => GLOBAL_STEP_;

- public string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" };
+ public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_;
  /// <summary>
  /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
  /// </summary>
- public string SAVEABLE_OBJECTS = "saveable_objects";
+ public string SAVEABLE_OBJECTS => SAVEABLE_OBJECTS_;
  /// <summary>
  /// Key to collect update_ops
  /// </summary>
- public string UPDATE_OPS = "update_ops";
+ public string UPDATE_OPS => UPDATE_OPS_;

  // Key to collect summaries.
- public string SUMMARIES = "summaries";
+ public string SUMMARIES => SUMMARIES_;

  // Used to store v2 summary names.
- public string _SUMMARY_COLLECTION = "_SUMMARY_V2";
+ public string _SUMMARY_COLLECTION => _SUMMARY_COLLECTION_;

  // Key for control flow context.
- public string COND_CONTEXT = "cond_context";
- public string WHILE_CONTEXT = "while_context";
+ public string COND_CONTEXT => COND_CONTEXT_;
+ public string WHILE_CONTEXT => WHILE_CONTEXT_;
  }
  }
  }
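
Because the collection keys are now compile-time constants, call sites no longer need a GraphKeys instance and the values can be used where constant expressions are required. A hypothetical illustration (the nesting of GraphKeys under ops is taken from this file; the helper itself is made up):

    // Const fields are valid switch-case labels; the old instance properties were not.
    static string Describe(string collectionKey)
    {
        switch (collectionKey)
        {
            case ops.GraphKeys.TRAINABLE_VARIABLES_: return "variables trained by an optimizer";
            case ops.GraphKeys.GLOBAL_VARIABLES_: return "all global (shared) variables";
            default: return "some other collection";
        }
    }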

+2 -2  test/TensorFlowNET.UnitTest/SessionTest.cs

@@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest
  EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
  EXPECT_EQ(0, outTensor.NDims);
  ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
- var output_contents = outTensor.Data<int>();
+ var output_contents = outTensor.ToArray<int>();
  EXPECT_EQ(3 + 2, output_contents[0]);

  // Add another operation to the graph.

@@ -66,7 +66,7 @@ namespace TensorFlowNET.UnitTest
  EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
  EXPECT_EQ(0, outTensor.NDims); // scalar
  ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
- output_contents = outTensor.Data<int>();
+ output_contents = outTensor.ToArray<int>();
  EXPECT_EQ(-(7 + 2), output_contents[0]);

  // Clean up


+1 -1  test/TensorFlowNET.UnitTest/TensorTest.cs

@@ -112,7 +112,7 @@ namespace TensorFlowNET.UnitTest
  var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);

  var tensor = new Tensor(nd);
- var array = tensor.Data<float>();
+ var array = tensor.ToArray<float>();

  EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
  EXPECT_EQ(tensor.rank, nd.ndim);


+1 -1  test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs

@@ -31,7 +31,7 @@ namespace TensorFlowNET.UnitTest.nn_test
  var y_np = this._ZeroFraction(x_np);
  var x_tf = constant_op.constant(x_np);
- x_tf.SetShape(x_shape);
+ x_tf.set_shape(x_shape);
  var y_tf = nn_impl.zero_fraction(x_tf);
  var y_tf_np = self.evaluate<NDArray>(y_tf);

