Browse Source

Merge pull request #366 from SciSharp/perf-ops

Performance optimization, refactoring and revamping.
tags/v0.12
Haiping GitHub 6 years ago
parent
commit
af73e3cb26
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
54 changed files with 1044 additions and 722 deletions
  1. +4
    -0
      src/TensorFlowNET.Core/Assembly/Properties.cs
  2. +7
    -2
      src/TensorFlowNET.Core/Binding.Util.cs
  3. +80
    -22
      src/TensorFlowNET.Core/Buffers/Buffer.cs
  4. +33
    -20
      src/TensorFlowNET.Core/DisposableObject.cs
  5. +12
    -17
      src/TensorFlowNET.Core/Eager/Context.cs
  6. +11
    -14
      src/TensorFlowNET.Core/Eager/ContextOptions.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Exceptions/KeyError.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Exceptions/RuntimeError.cs
  9. +36
    -0
      src/TensorFlowNET.Core/Exceptions/TensorflowException.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Exceptions/TypeError.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Exceptions/ValueError.cs
  12. +0
    -5
      src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs
  13. +8
    -6
      src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
  14. +31
    -18
      src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs
  15. +3
    -1
      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  16. +12
    -11
      src/TensorFlowNET.Core/Graphs/Graph.Export.cs
  17. +4
    -5
      src/TensorFlowNET.Core/Graphs/Graph.Import.cs
  18. +11
    -6
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  19. +9
    -11
      src/TensorFlowNET.Core/Graphs/Graph.cs
  20. +2
    -1
      src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
  21. +3
    -5
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  22. +11
    -8
      src/TensorFlowNET.Core/Operations/Operation.cs
  23. +413
    -413
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  24. +6
    -0
      src/TensorFlowNET.Core/Sessions/FeedItem.cs
  25. +2
    -2
      src/TensorFlowNET.Core/Sessions/SessionOptions.cs
  26. +8
    -4
      src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs
  27. +10
    -8
      src/TensorFlowNET.Core/Status/Status.cs
  28. +1
    -1
      src/TensorFlowNET.Core/Status/c_api.status.cs
  29. +6
    -6
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  30. +16
    -40
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  31. +8
    -11
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  32. +94
    -0
      src/TensorFlowNET.Core/Util/UnmanagedExtensions.cs
  33. +4
    -2
      src/TensorFlowNET.Core/globals.regen
  34. +2
    -2
      src/TensorFlowNET.Core/ops.cs
  35. +1
    -2
      src/TensorFlowNET.Core/tensorflow.cs
  36. +2
    -8
      src/TensorFlowNet.Benchmarks/Program.cs
  37. +1
    -0
      src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj
  38. +76
    -0
      src/TensorFlowNet.Benchmarks/Unmanaged/StructCastBenchmark.cs
  39. +3
    -3
      test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs
  40. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs
  41. +11
    -5
      test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
  42. +14
    -12
      test/TensorFlowNET.UnitTest/Basics/AssignTests.cs
  43. +13
    -9
      test/TensorFlowNET.UnitTest/CApiGradientsTest.cs
  44. +2
    -9
      test/TensorFlowNET.UnitTest/CSession.cs
  45. +0
    -5
      test/TensorFlowNET.UnitTest/GraphTest.cs
  46. +36
    -1
      test/TensorFlowNET.UnitTest/NameScopeTest.cs
  47. BIN
      test/TensorFlowNET.UnitTest/Open.snk
  48. +2
    -1
      test/TensorFlowNET.UnitTest/OperationsTest.cs
  49. +1
    -1
      test/TensorFlowNET.UnitTest/PythonTest.cs
  50. +1
    -3
      test/TensorFlowNET.UnitTest/SessionTest.cs
  51. +6
    -0
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
  52. +1
    -1
      test/TensorFlowNET.UnitTest/VariableTest.cs
  53. +11
    -6
      test/TensorFlowNET.UnitTest/c_test_util.cs
  54. +11
    -10
      test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs

+ 4
- 0
src/TensorFlowNET.Core/Assembly/Properties.cs View File

@@ -0,0 +1,4 @@
using System.Runtime.CompilerServices;
#if DEBUG
[assembly: InternalsVisibleTo("TensorFlowNET.UnitTest, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]
#endif

+ 7
- 2
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -178,13 +178,18 @@ namespace Tensorflow

public static IEnumerable<(TKey, TValue)> enumerate<TKey, TValue>(KeyValuePair<TKey, TValue>[] values)
{
foreach (var item in values)
var len = values.Length;
for (var i = 0; i < len; i++)
{
var item = values[i];
yield return (item.Key, item.Value);
}
}

public static IEnumerable<(int, T)> enumerate<T>(IList<T> values)
{
for (int i = 0; i < values.Count; i++)
var len = values.Count;
for (int i = 0; i < len; i++)
yield return (i, values[i]);
}



+ 80
- 22
src/TensorFlowNET.Core/Buffers/Buffer.cs View File

@@ -15,58 +15,116 @@
******************************************************************************/

using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using NumSharp.Backends.Unmanaged;
using static Tensorflow.c_api;

namespace Tensorflow
{
/// <summary>
/// Represents a TF_Buffer that can be passed to Tensorflow.
/// </summary>
public class Buffer : DisposableObject
{
private TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle);
private unsafe TF_Buffer buffer
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
get => *bufferptr;
}

private unsafe TF_Buffer* bufferptr
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
get => (TF_Buffer*) _handle;
}

public byte[] Data
/// <summary>
/// The memory block representing this buffer.
/// </summary>
/// <remarks>The deallocator is set to null.</remarks>
public UnmanagedMemoryBlock<byte> MemoryBlock
{
get
get
{
var data = new byte[buffer.length];
if (data.Length > 0)
Marshal.Copy(buffer.data, data, 0, data.Length);
return data;
unsafe
{
EnsureNotDisposed();
var buff = (TF_Buffer*) _handle;
return new UnmanagedMemoryBlock<byte>((byte*) buff->data.ToPointer(), (long) buff->length);
}
}
}

public int Length => (int)buffer.length;

public Buffer()
/// <summary>
/// The bytes length of this buffer.
/// </summary>
public ulong Length
{
_handle = c_api.TF_NewBuffer();
get
{
EnsureNotDisposed();
return buffer.length;
}
}

public Buffer(IntPtr handle)
public Buffer() => _handle = TF_NewBuffer();

internal Buffer(IntPtr handle)
{
if (handle == IntPtr.Zero)
throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle));

_handle = handle;
}

public Buffer(byte[] data)
{
var dst = Marshal.AllocHGlobal(data.Length);
Marshal.Copy(data, 0, dst, data.Length);
public Buffer(byte[] data) : this(_toBuffer(data))
{ }

_handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length);
private static IntPtr _toBuffer(byte[] data)
{
if (data == null)
throw new ArgumentNullException(nameof(data));

Marshal.FreeHGlobal(dst);
unsafe
{
fixed (byte* src = data)
return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength);
}
}

public static implicit operator IntPtr(Buffer buffer)
{
buffer.EnsureNotDisposed();
return buffer._handle;
}

public static implicit operator byte[](Buffer buffer)
public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit, developer will assume it doesn't cost.

/// <summary>
/// Copies this buffer's contents onto a <see cref="byte"/> array.
/// </summary>
public byte[] ToArray()
{
return buffer.Data;
EnsureNotDisposed();

unsafe
{
var len = buffer.length;
if (len == 0)
return Array.Empty<byte>();

byte[] data = new byte[len];
fixed (byte* dst = data)
System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len);

return data;
}
}

protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteBuffer(handle);
{
TF_DeleteBuffer(handle);
}
}
}
}

+ 33
- 20
src/TensorFlowNET.Core/DisposableObject.cs View File

@@ -16,6 +16,8 @@

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Text;

namespace Tensorflow
@@ -26,27 +28,33 @@ namespace Tensorflow
public abstract class DisposableObject : IDisposable
{
protected IntPtr _handle;
protected bool _disposed;

protected DisposableObject() { }
[SuppressMessage("ReSharper", "UnusedMember.Global")]
protected DisposableObject()
{ }

protected DisposableObject(IntPtr handle)
protected DisposableObject(IntPtr handle)
=> _handle = handle;

[SuppressMessage("ReSharper", "InvertIf")]
private void internal_dispose(bool disposing)
{
if (disposing)
{
// free unmanaged resources (unmanaged objects) and override a finalizer below.
if (_handle != IntPtr.Zero)
{
// dispose managed state (managed objects).
DisposeManagedResources();
if (_disposed)
return;

_disposed = true;

// set large fields to null.
DisposeUnmanagedResources(_handle);
//first handle managed, they might use the unmanaged resources.
if (disposing)
// dispose managed state (managed objects).
DisposeManagedResources();

_handle = IntPtr.Zero;
}
//free unmanaged memory
if (_handle != IntPtr.Zero)
{
DisposeUnmanagedResources(_handle);
_handle = IntPtr.Zero;
}
}

@@ -55,28 +63,33 @@ namespace Tensorflow
/// </summary>
/// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks>
protected virtual void DisposeManagedResources()
{
}
{ }

/// <summary>
/// Dispose any unmanaged resources related to given <paramref name="handle"/>.
/// </summary>
protected abstract void DisposeUnmanagedResources(IntPtr handle);

// override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources.
~DisposableObject()
{
// Do not change this code. Put cleanup code in Dispose(bool disposing) above.
internal_dispose(false);
}

// This code added to correctly implement the disposable pattern.
public void Dispose()
{
// Do not change this code. Put cleanup code in Dispose(bool disposing) above.
internal_dispose(true);
// uncomment the following line if the finalizer is overridden above.
GC.SuppressFinalize(this);
}

/// <summary>
/// If <see cref="_handle"/> is <see cref="IntPtr.Zero"/> then throws <see cref="ObjectDisposedException"/>
/// </summary>
/// <exception cref="ObjectDisposedException">When <see cref="_handle"/> is <see cref="IntPtr.Zero"/></exception>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
protected void EnsureNotDisposed()
{
if (_disposed)
throw new ObjectDisposedException($"Unable to access disposed object, Type: {GetType().Name}");
}
}
}

+ 12
- 17
src/TensorFlowNET.Core/Eager/Context.cs View File

@@ -2,12 +2,10 @@

namespace Tensorflow.Eager
{
public class Context : IDisposable
public class Context : DisposableObject
{
private IntPtr _handle;

public static int GRAPH_MODE = 0;
public static int EAGER_MODE = 1;
public const int GRAPH_MODE = 0;
public const int EAGER_MODE = 1;

public int default_execution_mode;

@@ -17,19 +15,16 @@ namespace Tensorflow.Eager
status.Check(true);
}

public void Dispose()
{
c_api.TFE_DeleteContext(_handle);
}
/// <summary>
/// Dispose any unmanaged resources related to given <paramref name="handle"/>.
/// </summary>
protected sealed override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TFE_DeleteContext(_handle);

public bool executing_eagerly()
{
return false;
}

public static implicit operator IntPtr(Context ctx)
{
return ctx._handle;
}
public bool executing_eagerly() => false;

public static implicit operator IntPtr(Context ctx)
=> ctx._handle;
}
}

+ 11
- 14
src/TensorFlowNET.Core/Eager/ContextOptions.cs View File

@@ -3,23 +3,20 @@ using System.IO;

namespace Tensorflow.Eager
{
public class ContextOptions : IDisposable //TODO! Eli: Shouldn't this inherieting DisposableObject?
public class ContextOptions : DisposableObject
{
private IntPtr _handle;
public ContextOptions() : base(c_api.TFE_NewContextOptions())
{ }

public ContextOptions()
{
_handle = c_api.TFE_NewContextOptions();
}
/// <summary>
/// Dispose any unmanaged resources related to given <paramref name="handle"/>.
/// </summary>
protected sealed override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TFE_DeleteContextOptions(_handle);

public void Dispose()
{
c_api.TFE_DeleteContextOptions(_handle);
}

public static implicit operator IntPtr(ContextOptions opts)
{
return opts._handle;
}
public static implicit operator IntPtr(ContextOptions opts)
=> opts._handle;
}

}

+ 1
- 1
src/TensorFlowNET.Core/Exceptions/KeyError.cs View File

@@ -2,7 +2,7 @@

namespace Tensorflow
{
public class KeyError : Exception
public class KeyError : TensorflowException
{
public KeyError() : base()
{


+ 1
- 1
src/TensorFlowNET.Core/Exceptions/RuntimeError.cs View File

@@ -2,7 +2,7 @@

namespace Tensorflow
{
public class RuntimeError : Exception
public class RuntimeError : TensorflowException
{
public RuntimeError() : base()
{


+ 36
- 0
src/TensorFlowNET.Core/Exceptions/TensorflowException.cs View File

@@ -0,0 +1,36 @@
using System;
using System.Runtime.Serialization;

namespace Tensorflow
{

/// <summary>
/// Serves as a base class to all exceptions of Tensorflow.NET.
/// </summary>
[Serializable]
public class TensorflowException : Exception
{
/// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class.</summary>
public TensorflowException()
{ }

/// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with serialized data.</summary>
/// <param name="info">The <see cref="T:System.Runtime.Serialization.SerializationInfo"></see> that holds the serialized object data about the exception being thrown.</param>
/// <param name="context">The <see cref="T:System.Runtime.Serialization.StreamingContext"></see> that contains contextual information about the source or destination.</param>
/// <exception cref="T:System.ArgumentNullException">The <paramref name="info">info</paramref> parameter is null.</exception>
/// <exception cref="T:System.Runtime.Serialization.SerializationException">The class name is null or <see cref="P:System.Exception.HResult"></see> is zero (0).</exception>
protected TensorflowException(SerializationInfo info, StreamingContext context) : base(info, context)
{ }

/// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with a specified error message.</summary>
/// <param name="message">The message that describes the error.</param>
public TensorflowException(string message) : base(message)
{ }

/// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with a specified error message and a reference to the inner exception that is the cause of this exception.</summary>
/// <param name="message">The error message that explains the reason for the exception.</param>
/// <param name="innerException">The exception that is the cause of the current exception, or a null reference (Nothing in Visual Basic) if no inner exception is specified.</param>
public TensorflowException(string message, Exception innerException) : base(message, innerException)
{ }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Exceptions/TypeError.cs View File

@@ -2,7 +2,7 @@

namespace Tensorflow
{
public class TypeError : Exception
public class TypeError : TensorflowException
{
public TypeError() : base()
{


+ 1
- 1
src/TensorFlowNET.Core/Exceptions/ValueError.cs View File

@@ -2,7 +2,7 @@

namespace Tensorflow
{
public class ValueError : Exception
public class ValueError : TensorflowException
{
public ValueError() : base()
{


+ 0
- 5
src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs View File

@@ -6,10 +6,5 @@
{

}

~ScopedTFImportGraphDefOptions()
{
base.Dispose();
}
}
}

+ 8
- 6
src/TensorFlowNET.Core/Framework/op_def_registry.py.cs View File

@@ -15,6 +15,8 @@
******************************************************************************/

using System.Collections.Generic;
using System.IO;
using Tensorflow.Util;

namespace Tensorflow
{
@@ -27,12 +29,12 @@ namespace Tensorflow
if(_registered_ops == null)
{
_registered_ops = new Dictionary<string, OpDef>();
var handle = c_api.TF_GetAllOpList();
var buffer = new Buffer(handle);
var op_list = OpList.Parser.ParseFrom(buffer);
foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def;
using (var buffer = new Buffer(c_api.TF_GetAllOpList()))
{
var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream());
foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def;
}
}

return _registered_ops;


+ 31
- 18
src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs View File

@@ -14,49 +14,62 @@
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class DefaultGraphStack

/// <summary>
/// Serves as a stack for determining current default graph.
/// </summary>
public class DefaultGraphStack
{
List<StackModel> stack = new List<StackModel>();
private readonly List<StackModel> _stack = new List<StackModel>();

public void set_controller(Graph @default)
{
if (!stack.Exists(x => x.Graph == @default))
stack.Add(new StackModel { Graph = @default, IsDefault = true });
if (!_stack.Exists(x => x.Graph == @default))
_stack.Add(new StackModel {Graph = @default, IsDefault = true});

foreach (var s in stack)
foreach (var s in _stack)
s.IsDefault = s.Graph == @default;
}

public Graph get_controller()
{
if (stack.Count(x => x.IsDefault) == 0)
stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true });
if (_stack.Count(x => x.IsDefault) == 0)
_stack.Add(new StackModel {Graph = tf.Graph(), IsDefault = true});
for (var i = _stack.Count - 1; i >= 0; i--)
{
var x = _stack[i];
if (x.IsDefault)
return x.Graph;
}

return stack.Last(x => x.IsDefault).Graph;
throw new TensorflowException("Unable to find a default graph");
}

public bool remove(Graph g)
{
var sm = stack.FirstOrDefault(x => x.Graph == g);
if (sm == null) return false;
return stack.Remove(sm);
if (_stack.Count == 0)
return false;

var sm = _stack.Find(model => model.Graph == g);
return sm != null && _stack.Remove(sm);
}

public void reset()
{
stack.Clear();
_stack.Clear();
}
}

public class StackModel
{
public Graph Graph { get; set; }
public bool IsDefault { get; set; }
private class StackModel
{
public Graph Graph { get; set; }
public bool IsDefault { get; set; }
}
}
}
}

+ 3
- 1
src/TensorFlowNET.Core/Graphs/Graph.Control.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/

using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Tensorflow.Operations;

@@ -66,8 +67,9 @@ namespace Tensorflow
/// within the context should have control dependencies on
/// `control_inputs`.
/// </summary>
[SuppressMessage("ReSharper", "CoVariantArrayConversion")]
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
=> control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray());
=> control_dependencies((object[])control_inputs);

/// <summary>
/// Returns a context manager that specifies control dependencies.


+ 12
- 11
src/TensorFlowNET.Core/Graphs/Graph.Export.cs View File

@@ -14,6 +14,9 @@
limitations under the License.
******************************************************************************/

using System.IO;
using Tensorflow.Util;

namespace Tensorflow
{
public partial class Graph
@@ -23,21 +26,19 @@ namespace Tensorflow
var buffer = new Buffer();
c_api.TF_GraphToGraphDef(_handle, buffer, s);
s.Check(true);
// var def = GraphDef.Parser.ParseFrom(buffer);
// buffer.Dispose();

return buffer;
}

private GraphDef _as_graph_def(bool add_shapes = false)
{
var status = new Status();
var buffer = ToGraphDef(status);
status.Check(true);
status.Dispose();
var def = GraphDef.Parser.ParseFrom(buffer);
buffer.Dispose();
GraphDef def;
using (var status = new Status())
using (var buffer = ToGraphDef(status))
{
status.Check(true);
def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
}

// Strip the experimental library field iff it's empty.
// if(def.Library.Function.Count == 0)
@@ -45,7 +46,7 @@ namespace Tensorflow
return def;
}

public GraphDef as_graph_def(bool add_shapes = false)
public GraphDef as_graph_def(bool add_shapes = false)
=> _as_graph_def(add_shapes);
}
}
}

+ 4
- 5
src/TensorFlowNET.Core/Graphs/Graph.Import.cs View File

@@ -30,11 +30,10 @@ namespace Tensorflow
var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs);

c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s);
for (int i = 0; i < num_return_outputs; i++)
{
var handle = return_output_handle + i * size;
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
}

var tf_output_ptr = (TF_Output*) return_output_handle;
for (int i = 0; i < num_return_outputs; i++)
return_outputs[i] = *(tf_output_ptr + i);

Marshal.FreeHGlobal(return_output_handle);



+ 11
- 6
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -18,6 +18,7 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -30,7 +31,7 @@ namespace Tensorflow
using (var status = new Status())
{
c_api.TF_GraphGetOpDef(_handle, type, buffer, status);
return OpDef.Parser.ParseFrom(buffer.Data);
return OpDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
}
}

@@ -39,16 +40,20 @@ namespace Tensorflow
return c_api.TF_NewOperation(_handle, opType, opName);
}
public unsafe Operation[] ReturnOperations(IntPtr results)
public Operation[] ReturnOperations(IntPtr results)
{
TF_Operation return_oper_handle = new TF_Operation();
int num_return_opers = 0;
c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle);
Operation[] return_opers = new Operation[num_return_opers];
var tf_op_size = Marshal.SizeOf<TF_Operation>();
for (int i = 0; i < num_return_opers; i++)
{
var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i;
return_opers[i] = new Operation(*(IntPtr*)handle);
unsafe
{
var handle = return_oper_handle.node + tf_op_size * i;
return_opers[i] = new Operation(*(IntPtr*)handle);
}
}
return return_opers;
@@ -67,7 +72,7 @@ namespace Tensorflow

public ITensorOrOperation[] get_operations()
{
return _nodes_by_name.Values.Select(x => x).ToArray();
return _nodes_by_name.Values.ToArray();
}
/// <summary>
@@ -81,7 +86,7 @@ namespace Tensorflow

public ITensorOrOperation _get_operation_by_name_unsafe(string name)
{
return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null;
return _nodes_by_name.TryGetValue(name, out var val) ? val : null;
}

public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper)


+ 9
- 11
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -369,7 +369,7 @@ namespace Tensorflow
var name_key = name.ToLower();
int i = 0;
if (_names_in_use.ContainsKey(name_key))
i = _names_in_use[name_key];
i = _names_in_use[name_key];
// Increment the number for "name_key".
if (mark_as_used)
_names_in_use[name_key] = i + 1;
@@ -399,13 +399,13 @@ namespace Tensorflow
int num_return_outputs = 0;
c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle);
TF_Output[] return_outputs = new TF_Output[num_return_outputs];
for (int i = 0; i < num_return_outputs; i++)
unsafe
{
var handle = return_output_handle + (Marshal.SizeOf<TF_Output>() * i);
return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
var tf_output_ptr = (TF_Output*) return_output_handle;
for (int i = 0; i < num_return_outputs; i++)
return_outputs[i] = *(tf_output_ptr + i);
return return_outputs;
}

return return_outputs;
}

public string[] get_all_collection_keys()
@@ -497,11 +497,9 @@ namespace Tensorflow
IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator()
=> GetEnumerable().GetEnumerator();
IEnumerator IEnumerable.GetEnumerator()
{
throw new NotImplementedException();
}
IEnumerator IEnumerable.GetEnumerator()
=> throw new NotImplementedException();

public static implicit operator IntPtr(Graph graph)
{
return graph._handle;


+ 2
- 1
src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs View File

@@ -20,7 +20,8 @@ namespace Tensorflow
{
public class ImportGraphDefOptions : DisposableObject
{
public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle);
public int NumReturnOutputs
=> c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle);

public ImportGraphDefOptions()
{


+ 3
- 5
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -50,14 +50,12 @@ namespace Tensorflow

public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
{
int size = Marshal.SizeOf<TF_Input>();
var handle = Marshal.AllocHGlobal(size);
var handle = Marshal.AllocHGlobal(Marshal.SizeOf<TF_Input>());
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers);
var consumers = new TF_Input[num];
var inputptr = (TF_Input*) handle;
for (int i = 0; i < num; i++)
{
consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size);
}
consumers[i] = *(inputptr + i);

return consumers;
}


+ 11
- 8
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -17,7 +17,9 @@
using Google.Protobuf.Collections;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Tensorflow.Util;

namespace Tensorflow
{
@@ -226,9 +228,12 @@ namespace Tensorflow
using (var status = new Status())
using (var buf = new Buffer())
{
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
status.Check(true);
x = AttrValue.Parser.ParseFrom(buf);
unsafe
{
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
status.Check(true);
x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream());
}
}

string oneof_value = x.ValueCase.ToString();
@@ -259,7 +264,7 @@ namespace Tensorflow
{
c_api.TF_OperationToNodeDef(_handle, buffer, s);
s.Check();
return NodeDef.Parser.ParseFrom(buffer);
return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
}
}
@@ -299,8 +304,7 @@ namespace Tensorflow
/// </summary>
public TF_Output _tf_output(int output_idx)
{
var tf_output = new TF_Output(op, output_idx);
return tf_output;
return new TF_Output(op, output_idx);
}
/// <summary>
@@ -308,8 +312,7 @@ namespace Tensorflow
/// </summary>
public TF_Input _tf_input(int input_idx)
{
var tf_input = new TF_Input(op, input_idx);
return tf_input;
return new TF_Input(op, input_idx);
}
}
}

+ 413
- 413
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -1,413 +1,413 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/
using NumSharp;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;
using System.Text;
namespace Tensorflow
{
public class BaseSession : DisposableObject
{
protected Graph _graph;
protected bool _opened;
protected bool _closed;
protected int _current_version;
protected byte[] _target;
public Graph graph => _graph;
public BaseSession(string target = "", Graph g = null, SessionOptions opts = null)
{
_graph = g is null ? ops.get_default_graph() : g;
_graph.as_default();
_target = UTF8Encoding.UTF8.GetBytes(target);
SessionOptions newOpts = null;
if (opts == null)
newOpts = new SessionOptions();
var status = new Status();
_handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status);
// dispose newOpts
if (opts == null)
newOpts.Dispose();
status.Check(true);
}
public virtual void run(Operation op, params FeedItem[] feed_dict)
{
_run(op, feed_dict);
}
public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict)
{
return _run(fetche, feed_dict)[0];
}
public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict)
{
return _run(fetche, feed_dict)[0];
}
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
return (results[0], results[1], results[2], results[3]);
}
public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict);
return (results[0], results[1], results[2]);
}
public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict);
return (results[0], results[1]);
}
public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict)
{
return _run(fetches, feed_dict);
}
public virtual NDArray[] run(object fetches, Hashtable feed_dict = null)
{
var feed_items = feed_dict == null ? new FeedItem[0] :
feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
return _run(fetches, feed_items);
}
private NDArray[] _run(object fetches, FeedItem[] feed_dict = null)
{
var feed_dict_tensor = new Dictionary<object, object>();
var feed_map = new Dictionary<object, object>();
Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) =>
{
return new (object, object)[] { (item.Key, item.Value) };
};
// Validate and process feed_dict.
if (feed_dict != null)
{
foreach (var feed in feed_dict)
{
foreach (var (subfeed, subfeed_val) in feed_fn(feed))
{
var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
//var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
feed_dict_tensor[subfeed_t] = subfeed_val;
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
}
}
}
// Create a fetch handler to take care of the structure of fetches.
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);
// Run request and get response.
// We need to keep the returned movers alive for the following _do_run().
// These movers are no longer needed when _do_run() completes, and
// are deleted when `movers` goes out of scope when this _run() ends.
var _ = _update_with_movers();
var final_fetches = fetch_handler.fetches();
var final_targets = fetch_handler.targets();
// We only want to really perform the run if fetches or targets are provided,
// or if the call is a partial run that specifies feeds.
var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor);
return fetch_handler.build_results(this, results);
}
/// <summary>
/// Runs a step based on the given fetches and feeds.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="target_list">A list of operations to be run, but not fetched.</param>
/// <param name="fetch_list"></param>
/// <param name="feed_dict"></param>
/// <returns>
/// A list of numpy ndarrays, corresponding to the elements of
/// `fetch_list`. If the ith element of `fetch_list` contains the
/// name of an operation, the first Tensor output of that operation
/// will be returned for that element.
/// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
{
var feeds = feed_dict.Select(x =>
{
if (x.Key is Tensor tensor)
{
switch (x.Value)
{
#if _REGEN
%types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types%
case #1 v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case #1[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
%
#else
case sbyte v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case sbyte[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case byte v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case byte[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case short v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case short[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ushort v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ushort[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case int v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case int[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case uint v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case uint[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case long v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case long[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ulong v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ulong[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case float v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case float[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case double v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case double[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Complex v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Complex[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
#endif
case bool v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL));
case string v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case IntPtr v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Tensor v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
case NDArray v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
default:
throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}");
}
}
throw new NotImplementedException("_do_run.feed_dict");
}).ToArray();
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
var targets = target_list;
return _call_tf_sessionrun(feeds, fetches, target_list);
}
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
{
// Ensure any changes to the graph are reflected in the runtime.
_extend_graph();
var status = new Status();
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();
c_api.TF_SessionRun(_handle,
run_options: null,
inputs: feed_dict.Select(f => f.Key).ToArray(),
input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
ninputs: feed_dict.Length,
outputs: fetch_list,
output_values: output_values,
noutputs: fetch_list.Length,
target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
ntargets: target_list.Count,
run_metadata: IntPtr.Zero,
status: status);
status.Check(true);
var result = new NDArray[fetch_list.Length];
for (int i = 0; i < fetch_list.Length; i++)
result[i] = fetchValue(output_values[i]);
for (int i = 0; i < feed_dict.Length; i++)
feed_dict[i].Value.Dispose();
return result;
}
private unsafe NDArray fetchValue(IntPtr output)
{
var tensor = new Tensor(output);
NDArray nd = null;
Type type = tensor.dtype.as_numpy_dtype();
var ndims = tensor.shape;
var offset = c_api.TF_TensorData(output);
if(ndims.Length == 0)
{
switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
nd = NDArray.Scalar(*(bool*)offset);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = NDArray.FromString(str);
break;
case TF_DataType.TF_UINT8:
nd = NDArray.Scalar(*(byte*)offset);
break;
case TF_DataType.TF_INT16:
nd = NDArray.Scalar(*(short*)offset);
break;
case TF_DataType.TF_INT32:
nd = NDArray.Scalar(*(int*)offset);
break;
case TF_DataType.TF_INT64:
nd = NDArray.Scalar(*(long*)offset);
break;
case TF_DataType.TF_FLOAT:
nd = NDArray.Scalar(*(float*)offset);
break;
case TF_DataType.TF_DOUBLE:
nd = NDArray.Scalar(*(double*)offset);
break;
default:
throw new NotImplementedException("can't fetch output");
}
}
else
{
switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
var bools = new bool[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i));
nd = np.array(bools).reshape(ndims);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = np.array(str);
break;
case TF_DataType.TF_UINT8:
var _bytes = new byte[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
_bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i));
nd = np.array(_bytes).reshape(ndims);
break;
case TF_DataType.TF_INT16:
var shorts = new short[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i));
nd = np.array(shorts).reshape(ndims);
break;
case TF_DataType.TF_INT32:
var ints = new int[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
ints[i] = *(int*)(offset + (int)(tensor.itemsize * i));
nd = np.array(ints).reshape(ndims);
break;
case TF_DataType.TF_INT64:
var longs = new long[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
longs[i] = *(long*)(offset + (int)(tensor.itemsize * i));
nd = np.array(longs).reshape(ndims);
break;
case TF_DataType.TF_FLOAT:
var floats = new float[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
floats[i] = *(float*)(offset + (int)(tensor.itemsize * i));
nd = np.array(floats).reshape(ndims);
break;
case TF_DataType.TF_DOUBLE:
var doubles = new double[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i));
nd = np.array(doubles).reshape(ndims);
break;
default:
throw new NotImplementedException("can't fetch output");
}
}
tensor.Dispose();
return nd;
}
/// <summary>
/// If a tensor handle that is fed to a device incompatible placeholder,
/// we move the tensor to the right device, generate a new tensor handle,
/// and update feed_dict to use the new handle.
/// </summary>
private List<object> _update_with_movers()
{
return new List<object> { };
}
private void _extend_graph()
{
}
public void close()
{
Dispose();
}
protected override void DisposeUnmanagedResources(IntPtr handle)
{
using (var status = new Status())
{
c_api.TF_DeleteSession(handle, status);
status.Check(true);
}
}
}
}
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/
using NumSharp;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;
using System.Text;
namespace Tensorflow
{
public class BaseSession : DisposableObject
{
protected Graph _graph;
protected bool _opened;
protected bool _closed;
protected int _current_version;
protected byte[] _target;
public Graph graph => _graph;
public BaseSession(string target = "", Graph g = null, SessionOptions opts = null)
{
_graph = g is null ? ops.get_default_graph() : g;
_graph.as_default();
_target = UTF8Encoding.UTF8.GetBytes(target);
SessionOptions newOpts = null;
if (opts == null)
newOpts = new SessionOptions();
var status = new Status();
_handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status);
// dispose newOpts
if (opts == null)
newOpts.Dispose();
status.Check(true);
}
public virtual void run(Operation op, params FeedItem[] feed_dict)
{
_run(op, feed_dict);
}
public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict)
{
return _run(fetche, feed_dict)[0];
}
public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict)
{
return _run(fetche, feed_dict)[0];
}
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
return (results[0], results[1], results[2], results[3]);
}
public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict);
return (results[0], results[1], results[2]);
}
public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict);
return (results[0], results[1]);
}
public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict)
{
return _run(fetches, feed_dict);
}
public virtual NDArray[] run(object fetches, Hashtable feed_dict = null)
{
var feed_items = feed_dict == null ? new FeedItem[0] :
feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
return _run(fetches, feed_items);
}
private NDArray[] _run(object fetches, FeedItem[] feed_dict = null)
{
var feed_dict_tensor = new Dictionary<object, object>();
var feed_map = new Dictionary<object, object>();
Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) =>
{
return new (object, object)[] { (item.Key, item.Value) };
};
// Validate and process feed_dict.
if (feed_dict != null)
{
foreach (var feed in feed_dict)
{
foreach (var (subfeed, subfeed_val) in feed_fn(feed))
{
var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
//var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
feed_dict_tensor[subfeed_t] = subfeed_val;
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
}
}
}
// Create a fetch handler to take care of the structure of fetches.
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);
// Run request and get response.
// We need to keep the returned movers alive for the following _do_run().
// These movers are no longer needed when _do_run() completes, and
// are deleted when `movers` goes out of scope when this _run() ends.
var _ = _update_with_movers();
var final_fetches = fetch_handler.fetches();
var final_targets = fetch_handler.targets();
// We only want to really perform the run if fetches or targets are provided,
// or if the call is a partial run that specifies feeds.
var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor);
return fetch_handler.build_results(this, results);
}
/// <summary>
/// Runs a step based on the given fetches and feeds.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="target_list">A list of operations to be run, but not fetched.</param>
/// <param name="fetch_list"></param>
/// <param name="feed_dict"></param>
/// <returns>
/// A list of numpy ndarrays, corresponding to the elements of
/// `fetch_list`. If the ith element of `fetch_list` contains the
/// name of an operation, the first Tensor output of that operation
/// will be returned for that element.
/// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
{
var feeds = feed_dict.Select(x =>
{
if (x.Key is Tensor tensor)
{
switch (x.Value)
{
#if _REGEN
%types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types%
case #1 v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case #1[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
%
#else
case sbyte v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case sbyte[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case byte v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case byte[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case short v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case short[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ushort v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ushort[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case int v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case int[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case uint v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case uint[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case long v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case long[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ulong v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ulong[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case float v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case float[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case double v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case double[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Complex v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Complex[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
#endif
case bool v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL));
case string v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case IntPtr v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Tensor v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
case NDArray v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
default:
throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}");
}
}
throw new NotImplementedException("_do_run.feed_dict");
}).ToArray();
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
var targets = target_list;
return _call_tf_sessionrun(feeds, fetches, target_list);
}
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
{
// Ensure any changes to the graph are reflected in the runtime.
_extend_graph();
var status = new Status();
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();
c_api.TF_SessionRun(_handle,
run_options: null,
inputs: feed_dict.Select(f => f.Key).ToArray(),
input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
ninputs: feed_dict.Length,
outputs: fetch_list,
output_values: output_values,
noutputs: fetch_list.Length,
target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
ntargets: target_list.Count,
run_metadata: IntPtr.Zero,
status: status);
status.Check(true);
var result = new NDArray[fetch_list.Length];
for (int i = 0; i < fetch_list.Length; i++)
result[i] = fetchValue(output_values[i]);
for (int i = 0; i < feed_dict.Length; i++)
feed_dict[i].Value.Dispose();
return result;
}
private unsafe NDArray fetchValue(IntPtr output)
{
var tensor = new Tensor(output);
NDArray nd = null;
Type type = tensor.dtype.as_numpy_dtype();
var ndims = tensor.shape;
var offset = c_api.TF_TensorData(output);
if(ndims.Length == 0)
{
switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
nd = NDArray.Scalar(*(bool*)offset);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = NDArray.FromString(str);
break;
case TF_DataType.TF_UINT8:
nd = NDArray.Scalar(*(byte*)offset);
break;
case TF_DataType.TF_INT16:
nd = NDArray.Scalar(*(short*)offset);
break;
case TF_DataType.TF_INT32:
nd = NDArray.Scalar(*(int*)offset);
break;
case TF_DataType.TF_INT64:
nd = NDArray.Scalar(*(long*)offset);
break;
case TF_DataType.TF_FLOAT:
nd = NDArray.Scalar(*(float*)offset);
break;
case TF_DataType.TF_DOUBLE:
nd = NDArray.Scalar(*(double*)offset);
break;
default:
throw new NotImplementedException("can't fetch output");
}
}
else
{
switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
var bools = new bool[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i));
nd = np.array(bools).reshape(ndims);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = np.array(str);
break;
case TF_DataType.TF_UINT8:
var _bytes = new byte[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
_bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i));
nd = np.array(_bytes).reshape(ndims);
break;
case TF_DataType.TF_INT16:
var shorts = new short[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i));
nd = np.array(shorts).reshape(ndims);
break;
case TF_DataType.TF_INT32:
var ints = new int[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
ints[i] = *(int*)(offset + (int)(tensor.itemsize * i));
nd = np.array(ints).reshape(ndims);
break;
case TF_DataType.TF_INT64:
var longs = new long[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
longs[i] = *(long*)(offset + (int)(tensor.itemsize * i));
nd = np.array(longs).reshape(ndims);
break;
case TF_DataType.TF_FLOAT:
var floats = new float[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
floats[i] = *(float*)(offset + (int)(tensor.itemsize * i));
nd = np.array(floats).reshape(ndims);
break;
case TF_DataType.TF_DOUBLE:
var doubles = new double[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i));
nd = np.array(doubles).reshape(ndims);
break;
default:
throw new NotImplementedException("can't fetch output");
}
}
tensor.Dispose();
return nd;
}
/// <summary>
/// If a tensor handle that is fed to a device incompatible placeholder,
/// we move the tensor to the right device, generate a new tensor handle,
/// and update feed_dict to use the new handle.
/// </summary>
private List<object> _update_with_movers()
{
return new List<object> { };
}
private void _extend_graph()
{
}
public void close()
{
Dispose();
}
protected override void DisposeUnmanagedResources(IntPtr handle)
{
using (var status = new Status())
{
c_api.TF_DeleteSession(handle, status);
status.Check(true);
}
}
}
}

+ 6
- 0
src/TensorFlowNET.Core/Sessions/FeedItem.cs View File

@@ -16,5 +16,11 @@

public static implicit operator FeedItem((object, object) feed)
=> new FeedItem(feed.Item1, feed.Item2);

public void Deconstruct(out object key, out object value)
{
key = Key;
value = Value;
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Sessions/SessionOptions.cs View File

@@ -37,8 +37,8 @@ namespace Tensorflow

public void SetConfig(ConfigProto config)
{
var bytes = config.ToByteArray();
var proto = Marshal.AllocHGlobal(bytes.Length);
var bytes = config.ToByteArray(); //TODO! we can use WriteTo
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak
Marshal.Copy(bytes, 0, proto, bytes.Length);

using (var status = new Status())


+ 8
- 4
src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs View File

@@ -27,13 +27,17 @@ namespace Tensorflow
var handle = Marshal.AllocHGlobal(size * num_consumers);
int num = TF_OperationOutputConsumers(oper_out, handle, num_consumers);
var consumers = new string[num_consumers];
for (int i = 0; i < num; i++)
unsafe
{
TF_Input input = Marshal.PtrToStructure<TF_Input>(handle + i * size);
consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(input.oper));
var inputptr = (TF_Input*) handle;
for (int i = 0; i < num; i++)
{
var oper = (inputptr + i)->oper;
consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(oper));
}
}

return consumers;
}
}
}
}

+ 10
- 8
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -15,6 +15,8 @@
******************************************************************************/

using System;
using System.Runtime.CompilerServices;
using static Tensorflow.c_api;

namespace Tensorflow
{
@@ -27,36 +29,36 @@ namespace Tensorflow
/// <summary>
/// Error message
/// </summary>
public string Message => c_api.StringPiece(c_api.TF_Message(_handle));
public string Message => c_api.StringPiece(TF_Message(_handle));

/// <summary>
/// Error code
/// </summary>
public TF_Code Code => c_api.TF_GetCode(_handle);
public TF_Code Code => TF_GetCode(_handle);

public Status()
{
_handle = c_api.TF_NewStatus();
_handle = TF_NewStatus();
}

public void SetStatus(TF_Code code, string msg)
{
c_api.TF_SetStatus(_handle, code, msg);
TF_SetStatus(_handle, code, msg);
}

/// <summary>
/// Check status
/// Throw exception with error message if code != TF_OK
/// </summary>
/// <exception cref="TensorflowException">When the returned check is not TF_Code.TF_OK</exception>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Check(bool throwException = false)
{
if (Code != TF_Code.TF_OK)
{
Console.WriteLine(Message);
if (throwException)
{
throw new Exception(Message);
}
throw new TensorflowException(Message);
}
}

@@ -66,6 +68,6 @@ namespace Tensorflow
}

protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteStatus(handle);
=> TF_DeleteStatus(handle);
}
}

+ 1
- 1
src/TensorFlowNET.Core/Status/c_api.status.cs View File

@@ -51,7 +51,7 @@ namespace Tensorflow
/// </summary>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_NewStatus();
public static extern IntPtr TF_NewStatus();

/// <summary>
/// Record <code, msg> in *s. Any previous information is lost.


+ 6
- 6
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -52,9 +52,9 @@ namespace Tensorflow
private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero };

// note: they must be assigned to a static variable in order to work as unmanaged callbacks
static Deallocator _hGlobalDeallocator = FreeHGlobalMemory;
static Deallocator _gcHandleDeallocator = FreeGCHandle;
private static Deallocator _nothingDeallocator = FreeNothing;
private static readonly Deallocator _hGlobalDeallocator = FreeHGlobalMemory;
private static readonly Deallocator _gcHandleDeallocator = FreeGCHandle;
private static readonly Deallocator _nothingDeallocator = FreeNothing;

/// <summary>
/// Create a Tensor object from an existing TF handle
@@ -528,7 +528,6 @@ namespace Tensorflow
}

_handle = CreateTensorFromNDArray(nd, tensorDType);
IsMemoryOwner = true;
}

private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype)
@@ -624,7 +623,7 @@ namespace Tensorflow
Marshal.WriteInt64(tensor, 0);

var status = new Status();
fixed (byte* src = &buffer[0])
fixed (byte* src = buffer)
c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status);

status.Check(true);
@@ -667,8 +666,9 @@ namespace Tensorflow
{
if (args.deallocator_called)
return;

// NumSharp will dispose
// Marshal.FreeHGlobal(dataPtr);
Marshal.FreeHGlobal(dataPtr);
args.deallocator_called = true;
}



+ 16
- 40
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -221,15 +221,6 @@ namespace Tensorflow
/// <exception cref="ArgumentException">When <typeparam name="T"> is string </typeparam></exception>
public T[] ToArray<T>() where T : unmanaged
{
//when T is string
if (typeof(T) == typeof(string))
{
if (dtype != TF_DataType.TF_STRING)
throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string.");

return (T[]) (object) StringData();
}

//Are the types matching?
if (typeof(T).as_dtype() == dtype)
{
@@ -246,20 +237,12 @@ namespace Tensorflow
unsafe
{
var len = (long) size;
fixed (T* dstRet = ret)
fixed (T* dst = ret)
{
T* dst = dstRet; //local stack copy
if (typeof(T).IsPrimitive)
{
var src = (T*) buffer;
len *= ((long) itemsize);
System.Buffer.MemoryCopy(src, dst, len, len);
} else
{
var itemsize = (long) this.itemsize;
var buffer = this.buffer.ToInt64();
Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure<T>(new IntPtr(buffer + i * itemsize)));
}
//T can only be unmanaged, I believe it is safe to say that MemoryCopy is valid for all cases this method can be called.
var src = (T*) buffer;
len *= ((long) itemsize);
System.Buffer.MemoryCopy(src, dst, len, len);
}
}

@@ -384,9 +367,15 @@ namespace Tensorflow
}
}

/// Used internally in ToArray&lt;T&gt;
private unsafe string[] StringData()
/// <summary>
/// Extracts string array from current Tensor.
/// </summary>
/// <exception cref="InvalidOperationException">When <see cref="dtype"/> != TF_DataType.TF_STRING</exception>
public unsafe string[] StringData()
{
if (dtype != TF_DataType.TF_STRING)
throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})");

//
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
// [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes]
@@ -442,7 +431,7 @@ namespace Tensorflow
/// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param>
/// <param name="session">The `Session` to be used to evaluate this tensor.</param>
/// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns>
public NDArray eval(Session session, FeedItem[] feed_dict = null)
public NDArray eval(Session session, params FeedItem[] feed_dict)
{
return ops._eval_using_default_session(this, feed_dict, graph, session);
}
@@ -568,23 +557,10 @@ namespace Tensorflow

protected override void DisposeUnmanagedResources(IntPtr handle)
{
if (handle != IntPtr.Zero)
{
c_api.TF_DeleteTensor(handle);
_handle = IntPtr.Zero;
}
c_api.TF_DeleteTensor(handle);
}

public bool IsDisposed
{
get
{
lock (this)
{
return _handle == IntPtr.Zero;
}
}
}
public bool IsDisposed => _disposed;

public int tensor_int_val { get; set; }
}

+ 8
- 11
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -83,6 +83,12 @@ namespace Tensorflow
throw new NotImplementedException("MakeNdarray");
}

private static readonly TF_DataType[] quantized_types = new TF_DataType[]
{
TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16,
TF_DataType.TF_QINT32
};

/// <summary>
/// Create a TensorProto.
/// </summary>
@@ -99,15 +105,6 @@ namespace Tensorflow
if (values is TensorProto tp)
return tp;

if (dtype != TF_DataType.DtInvalid)
;

bool is_quantized = new TF_DataType[]
{
TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16,
TF_DataType.TF_QINT32
}.Contains(dtype);

// We first convert value to a numpy array or scalar.
NDArray nparray = null;
var np_dt = dtype.as_numpy_dtype();
@@ -227,13 +224,13 @@ namespace Tensorflow
}
}

var numpy_dtype = dtypes.as_dtype(nparray.dtype, dtype: dtype);
var numpy_dtype = nparray.dtype.as_dtype(dtype: dtype);
if (numpy_dtype == TF_DataType.DtInvalid)
throw new TypeError($"Unrecognized data type: {nparray.dtype}");

// If dtype was specified and is a quantized type, we convert
// numpy_dtype back into the quantized version.
if (is_quantized)
if (quantized_types.Contains(dtype))
numpy_dtype = dtype;

bool is_same_size = false;


+ 94
- 0
src/TensorFlowNET.Core/Util/UnmanagedExtensions.cs View File

@@ -0,0 +1,94 @@
using System;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using NumSharp.Backends.Unmanaged;

namespace Tensorflow.Util
{
public static class UnmanagedExtensions
{
//internally UnmanagedMemoryStream can't construct with null address.
private static readonly unsafe byte* _empty = (byte*) Marshal.AllocHGlobal(1);

/// <summary>
/// Creates a memory stream based on given <paramref name="block"/>.
/// </summary>
/// <param name="block">The block to stream. Can be default/null.</param>
/// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock<byte> block)
{
unsafe
{
if (block.Address == null)
return new UnmanagedMemoryStream(_empty, 0);
return new UnmanagedMemoryStream(block.Address, block.BytesCount);
}
}

/// <summary>
/// Creates a memory stream based on given <paramref name="block"/>.
/// </summary>
/// <param name="block">The block to stream. Can be default/null.</param>
/// <param name="offset">Offset from the start of the block.</param>
/// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock<byte> block, long offset)
{
if (block.BytesCount - offset <= 0)
throw new ArgumentOutOfRangeException(nameof(offset));

unsafe
{
if (block.Address == null)
return new UnmanagedMemoryStream(_empty, 0);
return new UnmanagedMemoryStream(block.Address + offset, block.BytesCount - offset);
}
}

/// <summary>
/// Creates a memory stream based on given <paramref name="address"/>.
/// </summary>
/// <param name="address">The block to stream. Can be IntPtr.Zero.</param>
/// <param name="length">The length of the block in bytes.</param>
/// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static UnmanagedMemoryStream Stream(this IntPtr address, long length)
{
if (length <= 0)
throw new ArgumentOutOfRangeException(nameof(length));

unsafe
{
if (address == IntPtr.Zero)
return new UnmanagedMemoryStream(_empty, 0);

// ReSharper disable once AssignNullToNotNullAttribute
return new UnmanagedMemoryStream((byte*) address, length);
}
}

/// <summary>
/// Creates a memory stream based on given <paramref name="address"/>.
/// </summary>
/// <param name="address">The block to stream. Can be IntPtr.Zero.</param>
/// <param name="offset">Offset from the start of the block.</param>
/// <param name="length">The length of the block in bytes.</param>
/// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static UnmanagedMemoryStream Stream(this IntPtr address, long offset, long length)
{
if (length <= 0)
throw new ArgumentOutOfRangeException(nameof(length));

unsafe
{
if (address == IntPtr.Zero)
return new UnmanagedMemoryStream(_empty, 0);

return new UnmanagedMemoryStream((byte*) address + offset, length);
}
}
}
}

+ 4
- 2
src/TensorFlowNET.Core/globals.regen View File

@@ -8,7 +8,8 @@
%supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"]
%supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"]
%supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"]
%supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"]
%supported_numericals_TF_DataType = ["TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"]
%supported_numericals_TF_DataType_full = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"]

//this is the type we use in summerizing/reducting:
%supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"]
@@ -25,7 +26,8 @@
%supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"]

%supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"]
%supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"]
%supported_dtypes_TF_DataType = ["TF_BOOL","TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"]
%supported_dtypes_TF_DataType_full = ["TF_DataType.TF_BOOL","TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"]

%supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"]
%supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"]


+ 2
- 2
src/TensorFlowNET.Core/ops.cs View File

@@ -230,8 +230,8 @@ namespace Tensorflow
// Add attrs
foreach (var attr in node_def.Attr)
{
var bytes = attr.Value.ToByteArray();
var proto = Marshal.AllocHGlobal(bytes.Length);
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream.
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak
Marshal.Copy(bytes, 0, proto, bytes.Length);
uint len = (uint)bytes.Length;
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status);


+ 1
- 2
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -64,8 +64,7 @@ namespace Tensorflow

public Session Session()
{
defaultSession = new Session();
return defaultSession;
return new Session();
}

public Session Session(Graph graph, SessionOptions opts = null)


+ 2
- 8
src/TensorFlowNet.Benchmarks/Program.cs View File

@@ -9,24 +9,18 @@ namespace TensorFlowBenchmark
{
static void Main(string[] args)
{
#if DEBUG
IConfig config = new DebugInProcessConfig();
#else
IConfig config = null;
#endif
if (args?.Length > 0)
{
for (int i = 0; i < args.Length; i++)
{
string name = $"TensorFlowBenchmark.{args[i]}";
var type = Type.GetType(name);
BenchmarkRunner.Run(type, config);
BenchmarkRunner.Run(type);
}
}
else
{
BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, config);
BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, ManualConfig.Create(DefaultConfig.Instance).With(ConfigOptions.DisableOptimizationsValidator));
}
Console.ReadLine();


+ 1
- 0
src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj View File

@@ -6,6 +6,7 @@
<NoWin32Manifest>true</NoWin32Manifest>
<AssemblyName>TensorFlowBenchmark</AssemblyName>
<RootNamespace>TensorFlowBenchmark</RootNamespace>
<LangVersion>7.3</LangVersion>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">


+ 76
- 0
src/TensorFlowNet.Benchmarks/Unmanaged/StructCastBenchmark.cs View File

@@ -0,0 +1,76 @@
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using BenchmarkDotNet.Attributes;
using Google.Protobuf.WellKnownTypes;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowBenchmark.Unmanaged
{
public struct UnmanagedStruct
{
public int a;
public long b;
public UnmanagedStruct(int _)
{
a = 2;
b = 3;
}
}

[SimpleJob(launchCount: 1, warmupCount: 2, targetCount: 10)]
[MinColumn, MaxColumn, MeanColumn, MedianColumn]
public unsafe class StructCastBenchmark
{
private static void EnsureIsUnmanaged<T>(T _) where T : unmanaged
{ }

static StructCastBenchmark() //if UnmanagedStruct is not unmanaged struct then this will fail to compile.
=> EnsureIsUnmanaged(new UnmanagedStruct());

private IntPtr data;
private void* dataptr;

[GlobalSetup]
public void Setup()
{
data = Marshal.AllocHGlobal(Marshal.SizeOf<UnmanagedStruct>());
dataptr = data.ToPointer();
}

[Benchmark, MethodImpl(MethodImplOptions.NoOptimization)]
public void Marshal_PtrToStructure()
{
UnmanagedStruct _;
for (int i = 0; i < 10000; i++)
{
_ = Marshal.PtrToStructure<UnmanagedStruct>(data);
}
}

[Benchmark, MethodImpl(MethodImplOptions.NoOptimization)]
public void PointerCast()
{
var dptr = dataptr;
UnmanagedStruct _;
for (int i = 0; i < 10000; i++)
{
_ = *(UnmanagedStruct*) dptr;
}
}

[Benchmark, MethodImpl(MethodImplOptions.NoOptimization)]
public void Unsafe_Read()
{
var dptr = dataptr;
UnmanagedStruct _;
for (int i = 0; i < 10000; i++)
{
_ = Unsafe.Read<UnmanagedStruct>(dptr);
}
}

}
}

+ 3
- 3
test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs View File

@@ -102,7 +102,7 @@ namespace TensorFlowNET.Examples

// Display logs per epoch step
if ((epoch + 1) % display_step == 0)
print($"Epoch: {(epoch + 1).ToString("D4")} Cost: {avg_cost.ToString("G9")} Elapse: {sw.ElapsedMilliseconds}ms");
print($"Epoch: {(epoch + 1):D4} Cost: {avg_cost:G9} Elapse: {sw.ElapsedMilliseconds}ms");

sw.Reset();
}
@@ -114,8 +114,8 @@ namespace TensorFlowNET.Examples
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
// Calculate accuracy
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels));
print($"Accuracy: {acc.ToString("F4")}");
float acc = accuracy.eval(sess, (x, mnist.Test.Data), (y, mnist.Test.Labels));
print($"Accuracy: {acc:F4}");

return acc > 0.9;
}


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs View File

@@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples
public void PrepareData()
{
// get model file
string url = "http://download.tf.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz";
string url = "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz";
Web.Download(url, modelDir, "ssd_mobilenet_v1_coco.tar.gz");

Compress.ExtractTGZ(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz"), "./");


+ 11
- 5
test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs View File

@@ -21,6 +21,7 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Binding;
@@ -381,10 +382,15 @@ namespace TensorFlowNET.Examples
Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name)
{
int how_many_bottlenecks = 0;
foreach (var (label_name, label_lists) in image_lists)
var kvs = image_lists.ToArray();
var categories = new string[] {"training", "testing", "validation"};
Parallel.For(0, kvs.Length, i =>
{
foreach (var category in new string[] { "training", "testing", "validation" })
var (label_name, label_lists) = kvs[i];

Parallel.For(0, categories.Length, j =>
{
var category = categories[j];
var category_list = label_lists[category];
foreach (var (index, unused_base_name) in enumerate(category_list))
{
@@ -395,8 +401,8 @@ namespace TensorFlowNET.Examples
if (how_many_bottlenecks % 300 == 0)
print($"{how_many_bottlenecks} bottleneck files created.");
}
}
}
});
});
}

private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
@@ -508,7 +514,7 @@ namespace TensorFlowNET.Examples
{
// get a set of images to teach the network about the new classes
string fileName = "flower_photos.tgz";
string url = $"http://download.tf.org/example_images/{fileName}";
string url = $"http://download.tensorflow.org/example_images/{fileName}";
Web.Download(url, data_dir, fileName);
Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir);



+ 14
- 12
test/TensorFlowNET.UnitTest/Basics/AssignTests.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Basics
@@ -14,21 +15,22 @@ namespace TensorFlowNET.UnitTest.Basics
var expected = new[] { false, true, false, false, true, false, true };

var spike = tf.Variable(false);

spike.initializer.run();
foreach (var i in range(1, 2))
using (var sess = new Session())
{
if (raw_data[i] - raw_data[i - 1] > 5d)
{
var updater = tf.assign(spike, tf.constant(true));
updater.eval();
}
else
spike.initializer.run(session: sess);
foreach (var i in range(1, 2))
{
tf.assign(spike, tf.constant(true)).eval();
}
if (raw_data[i] - raw_data[i - 1] > 5d)
{
var updater = tf.assign(spike, tf.constant(true));
updater.eval(sess);
} else
{
tf.assign(spike, tf.constant(true)).eval(sess);
}

Assert.AreEqual((bool)spike.eval(), expected[i - 1]);
Assert.AreEqual((bool) spike.eval(), expected[i - 1]);
}
}
}
}

+ 13
- 9
test/TensorFlowNET.UnitTest/CApiGradientsTest.cs View File

@@ -2,6 +2,7 @@
using NumSharp;
using System;
using Tensorflow;
using Tensorflow.Util;
using Buffer = Tensorflow.Buffer;

namespace TensorFlowNET.UnitTest
@@ -45,15 +46,18 @@ namespace TensorFlowNET.UnitTest
private bool GetGraphDef(Graph graph, out GraphDef graph_def)
{
graph_def = null;
var s = new Status();
var buffer = new Buffer();
c_api.TF_GraphToGraphDef(graph, buffer, s);
bool ret = TF_GetCode(s) == TF_OK;
EXPECT_EQ(TF_OK, TF_GetCode(s));
if (ret) graph_def = GraphDef.Parser.ParseFrom(buffer.Data);
buffer.Dispose();
s.Dispose();
return ret;
using (var s = new Status())
{
using (var buffer = new Buffer())
{
c_api.TF_GraphToGraphDef(graph, buffer, s);
bool ret = TF_GetCode(s) == TF_OK;
EXPECT_EQ(TF_OK, TF_GetCode(s));
if (ret)
graph_def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
return ret;
}
}
}

private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs)


+ 2
- 9
test/TensorFlowNET.UnitTest/CSession.cs View File

@@ -40,10 +40,7 @@ namespace TensorFlowNET.UnitTest

private void DeleteInputValues()
{
for (var i = 0; i < input_values_.Count; ++i)
{
input_values_[i].Dispose();
}
//clearing is enough as they will be disposed by the GC unless they are referenced else-where.
input_values_.Clear();
}

@@ -60,11 +57,7 @@ namespace TensorFlowNET.UnitTest

private void ResetOutputValues()
{
for (var i = 0; i < output_values_.Count; ++i)
{
if (output_values_[i] != IntPtr.Zero)
output_values_[i].Dispose();
}
//clearing is enough as they will be disposed by the GC unless they are referenced else-where.
output_values_.Clear();
}



+ 0
- 5
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -322,7 +322,6 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(feed2, control_inputs[1]);

// Export to a graph def so we can import a graph with control dependencies
graph_def.Dispose();
graph_def = new Buffer();
c_api.TF_GraphToGraphDef(graph, graph_def, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code);
@@ -346,14 +345,10 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(feed4, control_inputs[1]);

c_api.TF_DeleteImportGraphDefOptions(opts);
c_api.TF_DeleteBuffer(graph_def);

// Can add nodes to the imported graph without trouble.
c_test_util.Add(feed, scalar, graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

graph.Dispose();
s.Dispose();
}

/// <summary>


+ 36
- 1
test/TensorFlowNET.UnitTest/NameScopeTest.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using static Tensorflow.Binding;

@@ -42,5 +43,39 @@ namespace TensorFlowNET.UnitTest

Assert.AreEqual("", g._name_stack);
}

[TestMethod]
public void NestedNameScope_Using()
{
Graph g = tf.Graph().as_default();

using (var name = new ops.NameScope("scope1"))
{
Assert.AreEqual("scope1", g._name_stack);
Assert.AreEqual("scope1/", name);

var const1 = tf.constant(1.0);
Assert.AreEqual("scope1/Const:0", const1.name);

using (var name2 = new ops.NameScope("scope2"))
{
Assert.AreEqual("scope1/scope2", g._name_stack);
Assert.AreEqual("scope1/scope2/", name);

var const2 = tf.constant(2.0);
Assert.AreEqual("scope1/scope2/Const:0", const2.name);
}

Assert.AreEqual("scope1", g._name_stack);
var const3 = tf.constant(2.0);
Assert.AreEqual("scope1/Const_1:0", const3.name);
}

;

g.Dispose();

Assert.AreEqual("", g._name_stack);
}
}
}

BIN
test/TensorFlowNET.UnitTest/Open.snk View File


+ 2
- 1
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -4,6 +4,7 @@ using System.Collections.Generic;
using System.Linq;
using NumSharp;
using Tensorflow;
using Tensorflow.Util;
using Buffer = Tensorflow.Buffer;
using static Tensorflow.Binding;

@@ -21,7 +22,7 @@ namespace TensorFlowNET.UnitTest
{
var handle = c_api.TF_GetAllOpList();
var buffer = new Buffer(handle);
var op_list = OpList.Parser.ParseFrom(buffer);
var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream());

var _registered_ops = new Dictionary<string, OpDef>();
foreach (var op_def in op_list.Op)


+ 1
- 1
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -165,7 +165,7 @@ namespace TensorFlowNET.UnitTest
{
using (var sess = tf.Session())
{
var ndarray=tensor.eval();
var ndarray=tensor.eval(sess);
if (typeof(T) == typeof(double))
{
double x = ndarray;


+ 1
- 3
test/TensorFlowNET.UnitTest/SessionTest.cs View File

@@ -72,8 +72,6 @@ namespace TensorFlowNET.UnitTest
// Clean up
csession.CloseAndDelete(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);
graph.Dispose();
s.Dispose();
}

[TestMethod]
@@ -84,7 +82,7 @@ namespace TensorFlowNET.UnitTest
var c = math_ops.matmul(a, b, name: "matmul");
using (var sess = tf.Session())
{
var result = c.eval();
var result = c.eval(sess);
Assert.AreEqual(6, result.Data<double>()[0]);
}
}


+ 6
- 0
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -4,6 +4,12 @@
<TargetFramework>netcoreapp2.2</TargetFramework>

<IsPackable>false</IsPackable>

<SignAssembly>true</SignAssembly>

<DelaySign>false</DelaySign>

<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">


+ 1
- 1
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -119,7 +119,7 @@ namespace TensorFlowNET.UnitTest
{
sess.run(init_op);
// o some work with the model.
inc_v1.op.run();
inc_v1.op.run(session: sess);
}
}



+ 11
- 6
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -1,4 +1,6 @@
using Tensorflow;
using System.Diagnostics.CodeAnalysis;
using Tensorflow;
using Tensorflow.Util;
using Buffer = Tensorflow.Buffer;

namespace TensorFlowNET.UnitTest
@@ -26,12 +28,15 @@ namespace TensorFlowNET.UnitTest
return op;
}

[SuppressMessage("ReSharper", "RedundantAssignment")]
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s)
{
var buffer = new Buffer();
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
attr_value = AttrValue.Parser.ParseFrom(buffer);
buffer.Dispose();
using (var buffer = new Buffer())
{
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream());
}

return s.Code == TF_Code.TF_OK;
}

@@ -42,7 +47,7 @@ namespace TensorFlowNET.UnitTest
{
c_api.TF_GraphToGraphDef(graph, buffer, s);
s.Check();
return GraphDef.Parser.ParseFrom(buffer);
return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
}
}



+ 11
- 10
test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs View File

@@ -24,16 +24,17 @@ namespace TensorFlowNET.UnitTest.ops_test
[TestMethod]
public void TestShape()
{
var g = tf.Graph().as_default();
var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } });
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]);
var op = g._create_op_from_tf_operation(c_op);
Assert.AreEqual("myop", op.name);
Assert.AreEqual("Identity", op.type);
Assert.AreEqual(1, len(op.outputs));
assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape);
using (var g = tf.Graph().as_default())
{
var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}});
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]);
var op = g._create_op_from_tf_operation(c_op);
Assert.AreEqual("myop", op.name);
Assert.AreEqual("Identity", op.type);
Assert.AreEqual(1, len(op.outputs));
assertItemsEqual(new[] {2, 3}, op.outputs[0].shape);
}
}
[TestMethod]


Loading…
Cancel
Save