Performance optimization, refactoring and revamping.tags/v0.12
| @@ -0,0 +1,4 @@ | |||
| using System.Runtime.CompilerServices; | |||
| #if DEBUG | |||
| [assembly: InternalsVisibleTo("TensorFlowNET.UnitTest, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] | |||
| #endif | |||
| @@ -178,13 +178,18 @@ namespace Tensorflow | |||
| public static IEnumerable<(TKey, TValue)> enumerate<TKey, TValue>(KeyValuePair<TKey, TValue>[] values) | |||
| { | |||
| foreach (var item in values) | |||
| var len = values.Length; | |||
| for (var i = 0; i < len; i++) | |||
| { | |||
| var item = values[i]; | |||
| yield return (item.Key, item.Value); | |||
| } | |||
| } | |||
| public static IEnumerable<(int, T)> enumerate<T>(IList<T> values) | |||
| { | |||
| for (int i = 0; i < values.Count; i++) | |||
| var len = values.Count; | |||
| for (int i = 0; i < len; i++) | |||
| yield return (i, values[i]); | |||
| } | |||
| @@ -15,58 +15,116 @@ | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Runtime.InteropServices; | |||
| using NumSharp.Backends.Unmanaged; | |||
| using static Tensorflow.c_api; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Represents a TF_Buffer that can be passed to Tensorflow. | |||
| /// </summary> | |||
| public class Buffer : DisposableObject | |||
| { | |||
| private TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle); | |||
| private unsafe TF_Buffer buffer | |||
| { | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| get => *bufferptr; | |||
| } | |||
| private unsafe TF_Buffer* bufferptr | |||
| { | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| get => (TF_Buffer*) _handle; | |||
| } | |||
| public byte[] Data | |||
| /// <summary> | |||
| /// The memory block representing this buffer. | |||
| /// </summary> | |||
| /// <remarks>The deallocator is set to null.</remarks> | |||
| public UnmanagedMemoryBlock<byte> MemoryBlock | |||
| { | |||
| get | |||
| get | |||
| { | |||
| var data = new byte[buffer.length]; | |||
| if (data.Length > 0) | |||
| Marshal.Copy(buffer.data, data, 0, data.Length); | |||
| return data; | |||
| unsafe | |||
| { | |||
| EnsureNotDisposed(); | |||
| var buff = (TF_Buffer*) _handle; | |||
| return new UnmanagedMemoryBlock<byte>((byte*) buff->data.ToPointer(), (long) buff->length); | |||
| } | |||
| } | |||
| } | |||
| public int Length => (int)buffer.length; | |||
| public Buffer() | |||
| /// <summary> | |||
| /// The bytes length of this buffer. | |||
| /// </summary> | |||
| public ulong Length | |||
| { | |||
| _handle = c_api.TF_NewBuffer(); | |||
| get | |||
| { | |||
| EnsureNotDisposed(); | |||
| return buffer.length; | |||
| } | |||
| } | |||
| public Buffer(IntPtr handle) | |||
| public Buffer() => _handle = TF_NewBuffer(); | |||
| internal Buffer(IntPtr handle) | |||
| { | |||
| if (handle == IntPtr.Zero) | |||
| throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle)); | |||
| _handle = handle; | |||
| } | |||
| public Buffer(byte[] data) | |||
| { | |||
| var dst = Marshal.AllocHGlobal(data.Length); | |||
| Marshal.Copy(data, 0, dst, data.Length); | |||
| public Buffer(byte[] data) : this(_toBuffer(data)) | |||
| { } | |||
| _handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length); | |||
| private static IntPtr _toBuffer(byte[] data) | |||
| { | |||
| if (data == null) | |||
| throw new ArgumentNullException(nameof(data)); | |||
| Marshal.FreeHGlobal(dst); | |||
| unsafe | |||
| { | |||
| fixed (byte* src = data) | |||
| return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength); | |||
| } | |||
| } | |||
| public static implicit operator IntPtr(Buffer buffer) | |||
| { | |||
| buffer.EnsureNotDisposed(); | |||
| return buffer._handle; | |||
| } | |||
| public static implicit operator byte[](Buffer buffer) | |||
| public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit; otherwise developers will assume the conversion is free of copying cost. | |||
| /// <summary> | |||
| /// Copies this buffer's contents onto a <see cref="byte"/> array. | |||
| /// </summary> | |||
| public byte[] ToArray() | |||
| { | |||
| return buffer.Data; | |||
| EnsureNotDisposed(); | |||
| unsafe | |||
| { | |||
| var len = buffer.length; | |||
| if (len == 0) | |||
| return Array.Empty<byte>(); | |||
| byte[] data = new byte[len]; | |||
| fixed (byte* dst = data) | |||
| System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len); | |||
| return data; | |||
| } | |||
| } | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| => c_api.TF_DeleteBuffer(handle); | |||
| { | |||
| TF_DeleteBuffer(handle); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -16,6 +16,8 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| @@ -26,27 +28,33 @@ namespace Tensorflow | |||
| public abstract class DisposableObject : IDisposable | |||
| { | |||
| protected IntPtr _handle; | |||
| protected bool _disposed; | |||
| protected DisposableObject() { } | |||
| [SuppressMessage("ReSharper", "UnusedMember.Global")] | |||
| protected DisposableObject() | |||
| { } | |||
| protected DisposableObject(IntPtr handle) | |||
| protected DisposableObject(IntPtr handle) | |||
| => _handle = handle; | |||
| [SuppressMessage("ReSharper", "InvertIf")] | |||
| private void internal_dispose(bool disposing) | |||
| { | |||
| if (disposing) | |||
| { | |||
| // free unmanaged resources (unmanaged objects) and override a finalizer below. | |||
| if (_handle != IntPtr.Zero) | |||
| { | |||
| // dispose managed state (managed objects). | |||
| DisposeManagedResources(); | |||
| if (_disposed) | |||
| return; | |||
| _disposed = true; | |||
| // set large fields to null. | |||
| DisposeUnmanagedResources(_handle); | |||
| //first handle managed, they might use the unmanaged resources. | |||
| if (disposing) | |||
| // dispose managed state (managed objects). | |||
| DisposeManagedResources(); | |||
| _handle = IntPtr.Zero; | |||
| } | |||
| //free unmanaged memory | |||
| if (_handle != IntPtr.Zero) | |||
| { | |||
| DisposeUnmanagedResources(_handle); | |||
| _handle = IntPtr.Zero; | |||
| } | |||
| } | |||
| @@ -55,28 +63,33 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks> | |||
| protected virtual void DisposeManagedResources() | |||
| { | |||
| } | |||
| { } | |||
| /// <summary> | |||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||
| /// </summary> | |||
| protected abstract void DisposeUnmanagedResources(IntPtr handle); | |||
| // override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. | |||
| ~DisposableObject() | |||
| { | |||
| // Do not change this code. Put cleanup code in Dispose(bool disposing) above. | |||
| internal_dispose(false); | |||
| } | |||
| // This code added to correctly implement the disposable pattern. | |||
| public void Dispose() | |||
| { | |||
| // Do not change this code. Put cleanup code in Dispose(bool disposing) above. | |||
| internal_dispose(true); | |||
| // uncomment the following line if the finalizer is overridden above. | |||
| GC.SuppressFinalize(this); | |||
| } | |||
| /// <summary> | |||
| /// If <see cref="_handle"/> is <see cref="IntPtr.Zero"/> then throws <see cref="ObjectDisposedException"/> | |||
| /// </summary> | |||
| /// <exception cref="ObjectDisposedException">When <see cref="_handle"/> is <see cref="IntPtr.Zero"/></exception> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| protected void EnsureNotDisposed() | |||
| { | |||
| if (_disposed) | |||
| throw new ObjectDisposedException($"Unable to access disposed object, Type: {GetType().Name}"); | |||
| } | |||
| } | |||
| } | |||
| @@ -2,12 +2,10 @@ | |||
| namespace Tensorflow.Eager | |||
| { | |||
| public class Context : IDisposable | |||
| public class Context : DisposableObject | |||
| { | |||
| private IntPtr _handle; | |||
| public static int GRAPH_MODE = 0; | |||
| public static int EAGER_MODE = 1; | |||
| public const int GRAPH_MODE = 0; | |||
| public const int EAGER_MODE = 1; | |||
| public int default_execution_mode; | |||
| @@ -17,19 +15,16 @@ namespace Tensorflow.Eager | |||
| status.Check(true); | |||
| } | |||
| public void Dispose() | |||
| { | |||
| c_api.TFE_DeleteContext(_handle); | |||
| } | |||
| /// <summary> | |||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||
| /// </summary> | |||
| protected sealed override void DisposeUnmanagedResources(IntPtr handle) | |||
| => c_api.TFE_DeleteContext(_handle); | |||
| public bool executing_eagerly() | |||
| { | |||
| return false; | |||
| } | |||
| public static implicit operator IntPtr(Context ctx) | |||
| { | |||
| return ctx._handle; | |||
| } | |||
| public bool executing_eagerly() => false; | |||
| public static implicit operator IntPtr(Context ctx) | |||
| => ctx._handle; | |||
| } | |||
| } | |||
| @@ -3,23 +3,20 @@ using System.IO; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| public class ContextOptions : IDisposable //TODO! Eli: Shouldn't this be inheriting DisposableObject? | |||
| public class ContextOptions : DisposableObject | |||
| { | |||
| private IntPtr _handle; | |||
| public ContextOptions() : base(c_api.TFE_NewContextOptions()) | |||
| { } | |||
| public ContextOptions() | |||
| { | |||
| _handle = c_api.TFE_NewContextOptions(); | |||
| } | |||
| /// <summary> | |||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||
| /// </summary> | |||
| protected sealed override void DisposeUnmanagedResources(IntPtr handle) | |||
| => c_api.TFE_DeleteContextOptions(_handle); | |||
| public void Dispose() | |||
| { | |||
| c_api.TFE_DeleteContextOptions(_handle); | |||
| } | |||
| public static implicit operator IntPtr(ContextOptions opts) | |||
| { | |||
| return opts._handle; | |||
| } | |||
| public static implicit operator IntPtr(ContextOptions opts) | |||
| => opts._handle; | |||
| } | |||
| } | |||
| @@ -2,7 +2,7 @@ | |||
| namespace Tensorflow | |||
| { | |||
| public class KeyError : Exception | |||
| public class KeyError : TensorflowException | |||
| { | |||
| public KeyError() : base() | |||
| { | |||
| @@ -2,7 +2,7 @@ | |||
| namespace Tensorflow | |||
| { | |||
| public class RuntimeError : Exception | |||
| public class RuntimeError : TensorflowException | |||
| { | |||
| public RuntimeError() : base() | |||
| { | |||
| @@ -0,0 +1,36 @@ | |||
| using System; | |||
| using System.Runtime.Serialization; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Serves as a base class to all exceptions of Tensorflow.NET. | |||
| /// </summary> | |||
| [Serializable] | |||
| public class TensorflowException : Exception | |||
| { | |||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class.</summary> | |||
| public TensorflowException() | |||
| { } | |||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with serialized data.</summary> | |||
| /// <param name="info">The <see cref="T:System.Runtime.Serialization.SerializationInfo"></see> that holds the serialized object data about the exception being thrown.</param> | |||
| /// <param name="context">The <see cref="T:System.Runtime.Serialization.StreamingContext"></see> that contains contextual information about the source or destination.</param> | |||
| /// <exception cref="T:System.ArgumentNullException">The <paramref name="info">info</paramref> parameter is null.</exception> | |||
| /// <exception cref="T:System.Runtime.Serialization.SerializationException">The class name is null or <see cref="P:System.Exception.HResult"></see> is zero (0).</exception> | |||
| protected TensorflowException(SerializationInfo info, StreamingContext context) : base(info, context) | |||
| { } | |||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with a specified error message.</summary> | |||
| /// <param name="message">The message that describes the error.</param> | |||
| public TensorflowException(string message) : base(message) | |||
| { } | |||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with a specified error message and a reference to the inner exception that is the cause of this exception.</summary> | |||
| /// <param name="message">The error message that explains the reason for the exception.</param> | |||
| /// <param name="innerException">The exception that is the cause of the current exception, or a null reference (Nothing in Visual Basic) if no inner exception is specified.</param> | |||
| public TensorflowException(string message, Exception innerException) : base(message, innerException) | |||
| { } | |||
| } | |||
| } | |||
| @@ -2,7 +2,7 @@ | |||
| namespace Tensorflow | |||
| { | |||
| public class TypeError : Exception | |||
| public class TypeError : TensorflowException | |||
| { | |||
| public TypeError() : base() | |||
| { | |||
| @@ -2,7 +2,7 @@ | |||
| namespace Tensorflow | |||
| { | |||
| public class ValueError : Exception | |||
| public class ValueError : TensorflowException | |||
| { | |||
| public ValueError() : base() | |||
| { | |||
| @@ -6,10 +6,5 @@ | |||
| { | |||
| } | |||
| ~ScopedTFImportGraphDefOptions() | |||
| { | |||
| base.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -15,6 +15,8 @@ | |||
| ******************************************************************************/ | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -27,12 +29,12 @@ namespace Tensorflow | |||
| if(_registered_ops == null) | |||
| { | |||
| _registered_ops = new Dictionary<string, OpDef>(); | |||
| var handle = c_api.TF_GetAllOpList(); | |||
| var buffer = new Buffer(handle); | |||
| var op_list = OpList.Parser.ParseFrom(buffer); | |||
| foreach (var op_def in op_list.Op) | |||
| _registered_ops[op_def.Name] = op_def; | |||
| using (var buffer = new Buffer(c_api.TF_GetAllOpList())) | |||
| { | |||
| var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
| foreach (var op_def in op_list.Op) | |||
| _registered_ops[op_def.Name] = op_def; | |||
| } | |||
| } | |||
| return _registered_ops; | |||
| @@ -14,49 +14,62 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public class DefaultGraphStack | |||
| /// <summary> | |||
| /// Serves as a stack for determining current default graph. | |||
| /// </summary> | |||
| public class DefaultGraphStack | |||
| { | |||
| List<StackModel> stack = new List<StackModel>(); | |||
| private readonly List<StackModel> _stack = new List<StackModel>(); | |||
| public void set_controller(Graph @default) | |||
| { | |||
| if (!stack.Exists(x => x.Graph == @default)) | |||
| stack.Add(new StackModel { Graph = @default, IsDefault = true }); | |||
| if (!_stack.Exists(x => x.Graph == @default)) | |||
| _stack.Add(new StackModel {Graph = @default, IsDefault = true}); | |||
| foreach (var s in stack) | |||
| foreach (var s in _stack) | |||
| s.IsDefault = s.Graph == @default; | |||
| } | |||
| public Graph get_controller() | |||
| { | |||
| if (stack.Count(x => x.IsDefault) == 0) | |||
| stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); | |||
| if (_stack.Count(x => x.IsDefault) == 0) | |||
| _stack.Add(new StackModel {Graph = tf.Graph(), IsDefault = true}); | |||
| for (var i = _stack.Count - 1; i >= 0; i--) | |||
| { | |||
| var x = _stack[i]; | |||
| if (x.IsDefault) | |||
| return x.Graph; | |||
| } | |||
| return stack.Last(x => x.IsDefault).Graph; | |||
| throw new TensorflowException("Unable to find a default graph"); | |||
| } | |||
| public bool remove(Graph g) | |||
| { | |||
| var sm = stack.FirstOrDefault(x => x.Graph == g); | |||
| if (sm == null) return false; | |||
| return stack.Remove(sm); | |||
| if (_stack.Count == 0) | |||
| return false; | |||
| var sm = _stack.Find(model => model.Graph == g); | |||
| return sm != null && _stack.Remove(sm); | |||
| } | |||
| public void reset() | |||
| { | |||
| stack.Clear(); | |||
| _stack.Clear(); | |||
| } | |||
| } | |||
| public class StackModel | |||
| { | |||
| public Graph Graph { get; set; } | |||
| public bool IsDefault { get; set; } | |||
| private class StackModel | |||
| { | |||
| public Graph Graph { get; set; } | |||
| public bool IsDefault { get; set; } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -15,6 +15,7 @@ | |||
| ******************************************************************************/ | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using System.Linq; | |||
| using Tensorflow.Operations; | |||
| @@ -66,8 +67,9 @@ namespace Tensorflow | |||
| /// within the context should have control dependencies on | |||
| /// `control_inputs`. | |||
| /// </summary> | |||
| [SuppressMessage("ReSharper", "CoVariantArrayConversion")] | |||
| public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | |||
| => control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray()); | |||
| => control_dependencies((object[])control_inputs); | |||
| /// <summary> | |||
| /// Returns a context manager that specifies control dependencies. | |||
| @@ -14,6 +14,9 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.IO; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Graph | |||
| @@ -23,21 +26,19 @@ namespace Tensorflow | |||
| var buffer = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||
| s.Check(true); | |||
| // var def = GraphDef.Parser.ParseFrom(buffer); | |||
| // buffer.Dispose(); | |||
| return buffer; | |||
| } | |||
| private GraphDef _as_graph_def(bool add_shapes = false) | |||
| { | |||
| var status = new Status(); | |||
| var buffer = ToGraphDef(status); | |||
| status.Check(true); | |||
| status.Dispose(); | |||
| var def = GraphDef.Parser.ParseFrom(buffer); | |||
| buffer.Dispose(); | |||
| GraphDef def; | |||
| using (var status = new Status()) | |||
| using (var buffer = ToGraphDef(status)) | |||
| { | |||
| status.Check(true); | |||
| def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
| } | |||
| // Strip the experimental library field iff it's empty. | |||
| // if(def.Library.Function.Count == 0) | |||
| @@ -45,7 +46,7 @@ namespace Tensorflow | |||
| return def; | |||
| } | |||
| public GraphDef as_graph_def(bool add_shapes = false) | |||
| public GraphDef as_graph_def(bool add_shapes = false) | |||
| => _as_graph_def(add_shapes); | |||
| } | |||
| } | |||
| } | |||
| @@ -30,11 +30,10 @@ namespace Tensorflow | |||
| var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | |||
| c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| { | |||
| var handle = return_output_handle + i * size; | |||
| return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle); | |||
| } | |||
| var tf_output_ptr = (TF_Output*) return_output_handle; | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| return_outputs[i] = *(tf_output_ptr + i); | |||
| Marshal.FreeHGlobal(return_output_handle); | |||
| @@ -18,6 +18,7 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -30,7 +31,7 @@ namespace Tensorflow | |||
| using (var status = new Status()) | |||
| { | |||
| c_api.TF_GraphGetOpDef(_handle, type, buffer, status); | |||
| return OpDef.Parser.ParseFrom(buffer.Data); | |||
| return OpDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
| } | |||
| } | |||
| @@ -39,16 +40,20 @@ namespace Tensorflow | |||
| return c_api.TF_NewOperation(_handle, opType, opName); | |||
| } | |||
| public unsafe Operation[] ReturnOperations(IntPtr results) | |||
| public Operation[] ReturnOperations(IntPtr results) | |||
| { | |||
| TF_Operation return_oper_handle = new TF_Operation(); | |||
| int num_return_opers = 0; | |||
| c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); | |||
| Operation[] return_opers = new Operation[num_return_opers]; | |||
| var tf_op_size = Marshal.SizeOf<TF_Operation>(); | |||
| for (int i = 0; i < num_return_opers; i++) | |||
| { | |||
| var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i; | |||
| return_opers[i] = new Operation(*(IntPtr*)handle); | |||
| unsafe | |||
| { | |||
| var handle = return_oper_handle.node + tf_op_size * i; | |||
| return_opers[i] = new Operation(*(IntPtr*)handle); | |||
| } | |||
| } | |||
| return return_opers; | |||
| @@ -67,7 +72,7 @@ namespace Tensorflow | |||
| public ITensorOrOperation[] get_operations() | |||
| { | |||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
| return _nodes_by_name.Values.ToArray(); | |||
| } | |||
| /// <summary> | |||
| @@ -81,7 +86,7 @@ namespace Tensorflow | |||
| public ITensorOrOperation _get_operation_by_name_unsafe(string name) | |||
| { | |||
| return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null; | |||
| return _nodes_by_name.TryGetValue(name, out var val) ? val : null; | |||
| } | |||
| public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) | |||
| @@ -369,7 +369,7 @@ namespace Tensorflow | |||
| var name_key = name.ToLower(); | |||
| int i = 0; | |||
| if (_names_in_use.ContainsKey(name_key)) | |||
| i = _names_in_use[name_key]; | |||
| i = _names_in_use[name_key]; | |||
| // Increment the number for "name_key". | |||
| if (mark_as_used) | |||
| _names_in_use[name_key] = i + 1; | |||
| @@ -399,13 +399,13 @@ namespace Tensorflow | |||
| int num_return_outputs = 0; | |||
| c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); | |||
| TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| unsafe | |||
| { | |||
| var handle = return_output_handle + (Marshal.SizeOf<TF_Output>() * i); | |||
| return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle); | |||
| var tf_output_ptr = (TF_Output*) return_output_handle; | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| return_outputs[i] = *(tf_output_ptr + i); | |||
| return return_outputs; | |||
| } | |||
| return return_outputs; | |||
| } | |||
| public string[] get_all_collection_keys() | |||
| @@ -497,11 +497,9 @@ namespace Tensorflow | |||
| IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | |||
| => GetEnumerable().GetEnumerator(); | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| => throw new NotImplementedException(); | |||
| public static implicit operator IntPtr(Graph graph) | |||
| { | |||
| return graph._handle; | |||
| @@ -20,7 +20,8 @@ namespace Tensorflow | |||
| { | |||
| public class ImportGraphDefOptions : DisposableObject | |||
| { | |||
| public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||
| public int NumReturnOutputs | |||
| => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||
| public ImportGraphDefOptions() | |||
| { | |||
| @@ -50,14 +50,12 @@ namespace Tensorflow | |||
| public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | |||
| { | |||
| int size = Marshal.SizeOf<TF_Input>(); | |||
| var handle = Marshal.AllocHGlobal(size); | |||
| var handle = Marshal.AllocHGlobal(Marshal.SizeOf<TF_Input>()); | |||
| int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | |||
| var consumers = new TF_Input[num]; | |||
| var inputptr = (TF_Input*) handle; | |||
| for (int i = 0; i < num; i++) | |||
| { | |||
| consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||
| } | |||
| consumers[i] = *(inputptr + i); | |||
| return consumers; | |||
| } | |||
| @@ -17,7 +17,9 @@ | |||
| using Google.Protobuf.Collections; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -226,9 +228,12 @@ namespace Tensorflow | |||
| using (var status = new Status()) | |||
| using (var buf = new Buffer()) | |||
| { | |||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||
| status.Check(true); | |||
| x = AttrValue.Parser.ParseFrom(buf); | |||
| unsafe | |||
| { | |||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||
| status.Check(true); | |||
| x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | |||
| } | |||
| } | |||
| string oneof_value = x.ValueCase.ToString(); | |||
| @@ -259,7 +264,7 @@ namespace Tensorflow | |||
| { | |||
| c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||
| s.Check(); | |||
| return NodeDef.Parser.ParseFrom(buffer); | |||
| return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
| } | |||
| } | |||
| @@ -299,8 +304,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public TF_Output _tf_output(int output_idx) | |||
| { | |||
| var tf_output = new TF_Output(op, output_idx); | |||
| return tf_output; | |||
| return new TF_Output(op, output_idx); | |||
| } | |||
| /// <summary> | |||
| @@ -308,8 +312,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public TF_Input _tf_input(int input_idx) | |||
| { | |||
| var tf_input = new TF_Input(op, input_idx); | |||
| return tf_input; | |||
| return new TF_Input(op, input_idx); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,413 +1,413 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Numerics; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class BaseSession : DisposableObject | |||
| { | |||
| protected Graph _graph; | |||
| protected bool _opened; | |||
| protected bool _closed; | |||
| protected int _current_version; | |||
| protected byte[] _target; | |||
| public Graph graph => _graph; | |||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||
| { | |||
| _graph = g is null ? ops.get_default_graph() : g; | |||
| _graph.as_default(); | |||
| _target = UTF8Encoding.UTF8.GetBytes(target); | |||
| SessionOptions newOpts = null; | |||
| if (opts == null) | |||
| newOpts = new SessionOptions(); | |||
| var status = new Status(); | |||
| _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | |||
| // dispose newOpts | |||
| if (opts == null) | |||
| newOpts.Dispose(); | |||
| status.Check(true); | |||
| } | |||
| public virtual void run(Operation op, params FeedItem[] feed_dict) | |||
| { | |||
| _run(op, feed_dict); | |||
| } | |||
| public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetche, feed_dict)[0]; | |||
| } | |||
| public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetche, feed_dict)[0]; | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||
| return (results[0], results[1], results[2], results[3]); | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||
| return (results[0], results[1], results[2]); | |||
| } | |||
| public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||
| return (results[0], results[1]); | |||
| } | |||
| public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetches, feed_dict); | |||
| } | |||
| public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||
| { | |||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||
| feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||
| return _run(fetches, feed_items); | |||
| } | |||
/// <summary>
/// Core run implementation: resolves each feed against the graph, builds a
/// fetch handler for the (possibly nested) fetches structure, executes the
/// session and maps the flat results back onto that structure.
/// </summary>
private NDArray[] _run(object fetches, FeedItem[] feed_dict = null)
{
    var feed_dict_tensor = new Dictionary<object, object>();

    // Validate and process feed_dict: every feed key must resolve to a
    // tensor (operations are not feedable). The former `feed_map` local and
    // `feed_fn` lambda were dead code and have been removed.
    if (feed_dict != null)
    {
        foreach (var feed in feed_dict)
        {
            var subfeed_t = _graph.as_graph_element(feed.Key, allow_tensor: true, allow_operation: false);
            feed_dict_tensor[subfeed_t] = feed.Value;
        }
    }

    // Create a fetch handler to take care of the structure of fetches.
    var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);

    // We need to keep the returned movers alive for the following _do_run().
    // These movers are no longer needed when _do_run() completes, and
    // are deleted when `movers` goes out of scope when this _run() ends.
    var _ = _update_with_movers();
    var final_fetches = fetch_handler.fetches();
    var final_targets = fetch_handler.targets();

    // We only want to really perform the run if fetches or targets are provided,
    // or if the call is a partial run that specifies feeds.
    var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor);
    return fetch_handler.build_results(this, results);
}
/// <summary>
/// Runs a step based on the given fetches and feeds.
/// </summary>
/// <param name="target_list">A list of operations to be run, but not fetched.</param>
/// <param name="fetch_list">Tensors whose values will be fetched.</param>
/// <param name="feed_dict">Maps feed tensors to the values substituted for them.</param>
/// <returns>
/// A list of numpy ndarrays, corresponding to the elements of
/// `fetch_list`. If the ith element of `fetch_list` contains the
/// name of an operation, the first Tensor output of that operation
/// will be returned for that element.
/// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
{
    // Convert every feed value into the (TF_Output, Tensor) pair the C API
    // expects; a fresh Tensor is created for raw CLR values.
    var feeds = feed_dict.Select(x =>
    {
        if (x.Key is Tensor tensor)
        {
            switch (x.Value)
            {
#if _REGEN
                %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
                %foreach types%
                case #1 v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case #1[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                %
#else
                case sbyte v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case sbyte[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case byte v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case byte[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case short v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case short[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case ushort v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case ushort[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case int v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case int[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case uint v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case uint[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case long v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case long[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case ulong v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case ulong[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case float v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case float[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case double v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case double[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case Complex v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case Complex[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
#endif
                case bool v:
                    // TF stores booleans as single bytes.
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL));
                case string v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case IntPtr v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case Tensor v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
                case NDArray v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
                default:
                    throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}");
            }
        }
        throw new NotImplementedException("_do_run.feed_dict");
    }).ToArray();
    var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
    // NOTE: the unused `targets` local was removed; target_list is passed directly.
    return _call_tf_sessionrun(feeds, fetches, target_list);
}
/// <summary>
/// Invokes TF_SessionRun with the prepared feeds/fetches/targets and wraps
/// each fetched native tensor into an NDArray.
/// </summary>
/// <param name="feed_dict">Feed outputs paired with the tensors to substitute.</param>
/// <param name="fetch_list">Outputs to fetch.</param>
/// <param name="target_list">Operations to run without fetching.</param>
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
{
    // Ensure any changes to the graph are reflected in the runtime.
    _extend_graph();

    // IntPtr defaults to IntPtr.Zero, so a plain array suffices.
    var output_values = new IntPtr[fetch_list.Length];
    try
    {
        // Status is IDisposable; the original leaked it on every call.
        using (var status = new Status())
        {
            c_api.TF_SessionRun(_handle,
                run_options: null,
                inputs: feed_dict.Select(f => f.Key).ToArray(),
                input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
                ninputs: feed_dict.Length,
                outputs: fetch_list,
                output_values: output_values,
                noutputs: fetch_list.Length,
                target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
                ntargets: target_list.Count,
                run_metadata: IntPtr.Zero,
                status: status);
            status.Check(true);
        }

        var result = new NDArray[fetch_list.Length];
        for (int i = 0; i < fetch_list.Length; i++)
            result[i] = fetchValue(output_values[i]);
        return result;
    }
    finally
    {
        // The feed tensors were created solely for this run; release them
        // even if TF_SessionRun or fetchValue throws.
        for (int i = 0; i < feed_dict.Length; i++)
            feed_dict[i].Value.Dispose();
    }
}
/// <summary>
/// Copies the buffer of a raw TF_Tensor handle (as returned by
/// TF_SessionRun) into a managed NDArray. The native tensor wrapper is
/// disposed before returning, even on failure.
/// </summary>
/// <param name="output">Handle of a tensor produced by the native runtime.</param>
/// <returns>An NDArray holding a managed copy of the tensor data.</returns>
/// <exception cref="NotImplementedException">For dtypes without a copy path.</exception>
private unsafe NDArray fetchValue(IntPtr output)
{
    var tensor = new Tensor(output);
    try
    {
        NDArray nd = null;
        var ndims = tensor.shape;
        var offset = c_api.TF_TensorData(output);

        if (ndims.Length == 0)
        {
            // Scalar: read a single value straight from the native buffer.
            switch (tensor.dtype)
            {
                case TF_DataType.TF_BOOL:
                    nd = NDArray.Scalar(*(bool*)offset);
                    break;
                case TF_DataType.TF_STRING:
                    var bytes = tensor.BufferToArray();
                    // weird, don't know why we have to start from offset 9.
                    // length in the begin
                    var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
                    nd = NDArray.FromString(str);
                    break;
                case TF_DataType.TF_UINT8:
                    nd = NDArray.Scalar(*(byte*)offset);
                    break;
                case TF_DataType.TF_INT16:
                    nd = NDArray.Scalar(*(short*)offset);
                    break;
                case TF_DataType.TF_INT32:
                    nd = NDArray.Scalar(*(int*)offset);
                    break;
                case TF_DataType.TF_INT64:
                    nd = NDArray.Scalar(*(long*)offset);
                    break;
                case TF_DataType.TF_FLOAT:
                    nd = NDArray.Scalar(*(float*)offset);
                    break;
                case TF_DataType.TF_DOUBLE:
                    nd = NDArray.Scalar(*(double*)offset);
                    break;
                default:
                    throw new NotImplementedException("can't fetch output");
            }
        }
        else
        {
            // Array: copy element-wise through a typed pointer. Indexing the
            // typed pointer avoids the int overflow the previous
            // (int)(itemsize * i) offset arithmetic had for large buffers.
            switch (tensor.dtype)
            {
                case TF_DataType.TF_BOOL:
                    var bools = new bool[tensor.size];
                    var boolSrc = (bool*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        bools[i] = boolSrc[i];
                    nd = np.array(bools).reshape(ndims);
                    break;
                case TF_DataType.TF_STRING:
                    var bytes = tensor.BufferToArray();
                    // weird, don't know why we have to start from offset 9.
                    // length in the begin
                    var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
                    nd = np.array(str);
                    break;
                case TF_DataType.TF_UINT8:
                    var _bytes = new byte[tensor.size];
                    var byteSrc = (byte*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        _bytes[i] = byteSrc[i];
                    nd = np.array(_bytes).reshape(ndims);
                    break;
                case TF_DataType.TF_INT16:
                    var shorts = new short[tensor.size];
                    var shortSrc = (short*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        shorts[i] = shortSrc[i];
                    nd = np.array(shorts).reshape(ndims);
                    break;
                case TF_DataType.TF_INT32:
                    var ints = new int[tensor.size];
                    var intSrc = (int*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        ints[i] = intSrc[i];
                    nd = np.array(ints).reshape(ndims);
                    break;
                case TF_DataType.TF_INT64:
                    var longs = new long[tensor.size];
                    var longSrc = (long*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        longs[i] = longSrc[i];
                    nd = np.array(longs).reshape(ndims);
                    break;
                case TF_DataType.TF_FLOAT:
                    var floats = new float[tensor.size];
                    var floatSrc = (float*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        floats[i] = floatSrc[i];
                    nd = np.array(floats).reshape(ndims);
                    break;
                case TF_DataType.TF_DOUBLE:
                    var doubles = new double[tensor.size];
                    var doubleSrc = (double*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        doubles[i] = doubleSrc[i];
                    nd = np.array(doubles).reshape(ndims);
                    break;
                default:
                    throw new NotImplementedException("can't fetch output");
            }
        }

        return nd;
    }
    finally
    {
        // Release the native tensor even when the dtype is unsupported;
        // the original leaked it on the NotImplementedException path.
        tensor.Dispose();
    }
}
/// <summary>
/// If a tensor handle that is fed to a device incompatible placeholder,
/// we move the tensor to the right device, generate a new tensor handle,
/// and update feed_dict to use the new handle. Currently a stub: no movers
/// are ever created, so an empty list is returned.
/// </summary>
private List<object> _update_with_movers()
    => new List<object>();
// Intentionally empty: no explicit graph-extension step is performed here.
private void _extend_graph()
{
}
/// <summary>
/// Closes the session by disposing the underlying native handle.
/// </summary>
public void close()
    => Dispose();
/// <summary>
/// Deletes the native TF_Session referenced by <paramref name="handle"/>.
/// Invoked by the DisposableObject base class during disposal.
/// </summary>
protected override void DisposeUnmanagedResources(IntPtr handle)
{
    using (var status = new Status())
    {
        c_api.TF_DeleteSession(handle, status);
        status.Check(true);
    }
}
| } | |||
| } | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Numerics; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class BaseSession : DisposableObject | |||
| { | |||
| protected Graph _graph; | |||
| protected bool _opened; | |||
| protected bool _closed; | |||
| protected int _current_version; | |||
| protected byte[] _target; | |||
| public Graph graph => _graph; | |||
/// <summary>
/// Creates a session bound to a graph (the current default graph when
/// <paramref name="g"/> is null) and makes that graph the default.
/// </summary>
/// <param name="target">Execution target; empty string means in-process.</param>
/// <param name="g">Graph to run; defaults to the current default graph.</param>
/// <param name="opts">Optional session options; a temporary default instance is used when null.</param>
public BaseSession(string target = "", Graph g = null, SessionOptions opts = null)
{
    _graph = g is null ? ops.get_default_graph() : g;
    _graph.as_default();
    _target = UTF8Encoding.UTF8.GetBytes(target);

    // Track whether we own the options so only locally created ones are disposed.
    var ownsOpts = opts == null;
    var options = opts ?? new SessionOptions();
    // Status is IDisposable; the original leaked it.
    using (var status = new Status())
    {
        try
        {
            _handle = c_api.TF_NewSession(_graph, options, status);
        }
        finally
        {
            // Dispose the locally created options even if TF_NewSession throws.
            if (ownsOpts)
                options.Dispose();
        }
        status.Check(true);
    }
}
/// <summary>
/// Runs an operation for its side effects; the result is discarded.
/// </summary>
public virtual void run(Operation op, params FeedItem[] feed_dict)
    => _run(op, feed_dict);
/// <summary>
/// Evaluates a single tensor and returns its value.
/// </summary>
public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict)
    => _run(fetche, feed_dict)[0];
/// <summary>
/// Evaluates a single tensor-or-operation fetch and returns its value.
/// </summary>
public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict)
    => _run(fetche, feed_dict)[0];
/// <summary>
/// Evaluates four fetches in a single session run and returns all results.
/// </summary>
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
    var (a, b, c, d) = fetches;
    var outputs = _run(new object[] { a, b, c, d }, feed_dict);
    return (outputs[0], outputs[1], outputs[2], outputs[3]);
}
/// <summary>
/// Evaluates three fetches in a single session run and returns all results.
/// </summary>
public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
    var (a, b, c) = fetches;
    var outputs = _run(new object[] { a, b, c }, feed_dict);
    return (outputs[0], outputs[1], outputs[2]);
}
/// <summary>
/// Evaluates a pair of fetches in a single session run and returns both results.
/// </summary>
public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
    var (first, second) = fetches;
    var outputs = _run(new object[] { first, second }, feed_dict);
    return (outputs[0], outputs[1]);
}
/// <summary>
/// Runs an arbitrary fetches structure, returning one NDArray per fetched value.
/// </summary>
public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict)
    => _run(fetches, feed_dict);
/// <summary>
/// Runs fetches with an optional <see cref="Hashtable"/> feed dictionary;
/// each key/value pair is converted to a FeedItem before delegating to _run.
/// </summary>
public virtual NDArray[] run(object fetches, Hashtable feed_dict = null)
{
    // Array.Empty avoids allocating a fresh zero-length array per call.
    if (feed_dict == null)
        return _run(fetches, Array.Empty<FeedItem>());

    var feed_items = feed_dict.Keys
        .OfType<object>()
        .Select(key => new FeedItem(key, feed_dict[key]))
        .ToArray();
    return _run(fetches, feed_items);
}
/// <summary>
/// Core run implementation: resolves each feed against the graph, builds a
/// fetch handler for the (possibly nested) fetches structure, executes the
/// session and maps the flat results back onto that structure.
/// </summary>
private NDArray[] _run(object fetches, FeedItem[] feed_dict = null)
{
    var feed_dict_tensor = new Dictionary<object, object>();

    // Validate and process feed_dict: every feed key must resolve to a
    // tensor (operations are not feedable). The former `feed_map` local and
    // `feed_fn` lambda were dead code and have been removed.
    if (feed_dict != null)
    {
        foreach (var feed in feed_dict)
        {
            var subfeed_t = _graph.as_graph_element(feed.Key, allow_tensor: true, allow_operation: false);
            feed_dict_tensor[subfeed_t] = feed.Value;
        }
    }

    // Create a fetch handler to take care of the structure of fetches.
    var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);

    // We need to keep the returned movers alive for the following _do_run().
    // These movers are no longer needed when _do_run() completes, and
    // are deleted when `movers` goes out of scope when this _run() ends.
    var _ = _update_with_movers();
    var final_fetches = fetch_handler.fetches();
    var final_targets = fetch_handler.targets();

    // We only want to really perform the run if fetches or targets are provided,
    // or if the call is a partial run that specifies feeds.
    var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor);
    return fetch_handler.build_results(this, results);
}
/// <summary>
/// Runs a step based on the given fetches and feeds.
/// </summary>
/// <param name="target_list">A list of operations to be run, but not fetched.</param>
/// <param name="fetch_list">Tensors whose values will be fetched.</param>
/// <param name="feed_dict">Maps feed tensors to the values substituted for them.</param>
/// <returns>
/// A list of numpy ndarrays, corresponding to the elements of
/// `fetch_list`. If the ith element of `fetch_list` contains the
/// name of an operation, the first Tensor output of that operation
/// will be returned for that element.
/// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
{
    // Convert every feed value into the (TF_Output, Tensor) pair the C API
    // expects; a fresh Tensor is created for raw CLR values.
    var feeds = feed_dict.Select(x =>
    {
        if (x.Key is Tensor tensor)
        {
            switch (x.Value)
            {
#if _REGEN
                %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
                %foreach types%
                case #1 v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case #1[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                %
#else
                case sbyte v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case sbyte[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case byte v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case byte[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case short v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case short[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case ushort v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case ushort[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case int v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case int[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case uint v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case uint[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case long v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case long[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case ulong v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case ulong[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case float v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case float[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case double v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case double[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case Complex v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case Complex[] v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
#endif
                case bool v:
                    // TF stores booleans as single bytes.
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL));
                case string v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case IntPtr v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
                case Tensor v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
                case NDArray v:
                    return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
                default:
                    throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}");
            }
        }
        throw new NotImplementedException("_do_run.feed_dict");
    }).ToArray();
    var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
    // NOTE: the unused `targets` local was removed; target_list is passed directly.
    return _call_tf_sessionrun(feeds, fetches, target_list);
}
/// <summary>
/// Invokes TF_SessionRun with the prepared feeds/fetches/targets and wraps
/// each fetched native tensor into an NDArray.
/// </summary>
/// <param name="feed_dict">Feed outputs paired with the tensors to substitute.</param>
/// <param name="fetch_list">Outputs to fetch.</param>
/// <param name="target_list">Operations to run without fetching.</param>
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
{
    // Ensure any changes to the graph are reflected in the runtime.
    _extend_graph();

    // IntPtr defaults to IntPtr.Zero, so a plain array suffices.
    var output_values = new IntPtr[fetch_list.Length];
    try
    {
        // Status is IDisposable; the original leaked it on every call.
        using (var status = new Status())
        {
            c_api.TF_SessionRun(_handle,
                run_options: null,
                inputs: feed_dict.Select(f => f.Key).ToArray(),
                input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
                ninputs: feed_dict.Length,
                outputs: fetch_list,
                output_values: output_values,
                noutputs: fetch_list.Length,
                target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
                ntargets: target_list.Count,
                run_metadata: IntPtr.Zero,
                status: status);
            status.Check(true);
        }

        var result = new NDArray[fetch_list.Length];
        for (int i = 0; i < fetch_list.Length; i++)
            result[i] = fetchValue(output_values[i]);
        return result;
    }
    finally
    {
        // The feed tensors were created solely for this run; release them
        // even if TF_SessionRun or fetchValue throws.
        for (int i = 0; i < feed_dict.Length; i++)
            feed_dict[i].Value.Dispose();
    }
}
/// <summary>
/// Copies the buffer of a raw TF_Tensor handle (as returned by
/// TF_SessionRun) into a managed NDArray. The native tensor wrapper is
/// disposed before returning, even on failure.
/// </summary>
/// <param name="output">Handle of a tensor produced by the native runtime.</param>
/// <returns>An NDArray holding a managed copy of the tensor data.</returns>
/// <exception cref="NotImplementedException">For dtypes without a copy path.</exception>
private unsafe NDArray fetchValue(IntPtr output)
{
    var tensor = new Tensor(output);
    try
    {
        NDArray nd = null;
        var ndims = tensor.shape;
        var offset = c_api.TF_TensorData(output);

        if (ndims.Length == 0)
        {
            // Scalar: read a single value straight from the native buffer.
            switch (tensor.dtype)
            {
                case TF_DataType.TF_BOOL:
                    nd = NDArray.Scalar(*(bool*)offset);
                    break;
                case TF_DataType.TF_STRING:
                    var bytes = tensor.BufferToArray();
                    // weird, don't know why we have to start from offset 9.
                    // length in the begin
                    var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
                    nd = NDArray.FromString(str);
                    break;
                case TF_DataType.TF_UINT8:
                    nd = NDArray.Scalar(*(byte*)offset);
                    break;
                case TF_DataType.TF_INT16:
                    nd = NDArray.Scalar(*(short*)offset);
                    break;
                case TF_DataType.TF_INT32:
                    nd = NDArray.Scalar(*(int*)offset);
                    break;
                case TF_DataType.TF_INT64:
                    nd = NDArray.Scalar(*(long*)offset);
                    break;
                case TF_DataType.TF_FLOAT:
                    nd = NDArray.Scalar(*(float*)offset);
                    break;
                case TF_DataType.TF_DOUBLE:
                    nd = NDArray.Scalar(*(double*)offset);
                    break;
                default:
                    throw new NotImplementedException("can't fetch output");
            }
        }
        else
        {
            // Array: copy element-wise through a typed pointer. Indexing the
            // typed pointer avoids the int overflow the previous
            // (int)(itemsize * i) offset arithmetic had for large buffers.
            switch (tensor.dtype)
            {
                case TF_DataType.TF_BOOL:
                    var bools = new bool[tensor.size];
                    var boolSrc = (bool*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        bools[i] = boolSrc[i];
                    nd = np.array(bools).reshape(ndims);
                    break;
                case TF_DataType.TF_STRING:
                    var bytes = tensor.BufferToArray();
                    // weird, don't know why we have to start from offset 9.
                    // length in the begin
                    var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
                    nd = np.array(str);
                    break;
                case TF_DataType.TF_UINT8:
                    var _bytes = new byte[tensor.size];
                    var byteSrc = (byte*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        _bytes[i] = byteSrc[i];
                    nd = np.array(_bytes).reshape(ndims);
                    break;
                case TF_DataType.TF_INT16:
                    var shorts = new short[tensor.size];
                    var shortSrc = (short*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        shorts[i] = shortSrc[i];
                    nd = np.array(shorts).reshape(ndims);
                    break;
                case TF_DataType.TF_INT32:
                    var ints = new int[tensor.size];
                    var intSrc = (int*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        ints[i] = intSrc[i];
                    nd = np.array(ints).reshape(ndims);
                    break;
                case TF_DataType.TF_INT64:
                    var longs = new long[tensor.size];
                    var longSrc = (long*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        longs[i] = longSrc[i];
                    nd = np.array(longs).reshape(ndims);
                    break;
                case TF_DataType.TF_FLOAT:
                    var floats = new float[tensor.size];
                    var floatSrc = (float*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        floats[i] = floatSrc[i];
                    nd = np.array(floats).reshape(ndims);
                    break;
                case TF_DataType.TF_DOUBLE:
                    var doubles = new double[tensor.size];
                    var doubleSrc = (double*)offset;
                    for (ulong i = 0; i < tensor.size; i++)
                        doubles[i] = doubleSrc[i];
                    nd = np.array(doubles).reshape(ndims);
                    break;
                default:
                    throw new NotImplementedException("can't fetch output");
            }
        }

        return nd;
    }
    finally
    {
        // Release the native tensor even when the dtype is unsupported;
        // the original leaked it on the NotImplementedException path.
        tensor.Dispose();
    }
}
/// <summary>
/// If a tensor handle that is fed to a device incompatible placeholder,
/// we move the tensor to the right device, generate a new tensor handle,
/// and update feed_dict to use the new handle. Currently a stub: no movers
/// are ever created, so an empty list is returned.
/// </summary>
private List<object> _update_with_movers()
    => new List<object>();
// Intentionally empty: no explicit graph-extension step is performed here.
private void _extend_graph()
{
}
/// <summary>
/// Closes the session by disposing the underlying native handle.
/// </summary>
public void close()
    => Dispose();
/// <summary>
/// Deletes the native TF_Session referenced by <paramref name="handle"/>.
/// Invoked by the DisposableObject base class during disposal.
/// </summary>
protected override void DisposeUnmanagedResources(IntPtr handle)
{
    using (var status = new Status())
    {
        c_api.TF_DeleteSession(handle, status);
        status.Check(true);
    }
}
| } | |||
| } | |||
| @@ -16,5 +16,11 @@ | |||
| public static implicit operator FeedItem((object, object) feed) | |||
| => new FeedItem(feed.Item1, feed.Item2); | |||
/// <summary>
/// Enables tuple-style deconstruction of a FeedItem into its key and value:
/// <c>var (key, value) = feedItem;</c>
/// </summary>
public void Deconstruct(out object key, out object value)
{
    key = Key;
    value = Value;
}
| } | |||
| } | |||
| @@ -37,8 +37,8 @@ namespace Tensorflow | |||
| public void SetConfig(ConfigProto config) | |||
| { | |||
| var bytes = config.ToByteArray(); | |||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||
| var bytes = config.ToByteArray(); //TODO! we can use WriteTo | |||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak | |||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
| using (var status = new Status()) | |||
| @@ -27,13 +27,17 @@ namespace Tensorflow | |||
| var handle = Marshal.AllocHGlobal(size * num_consumers); | |||
| int num = TF_OperationOutputConsumers(oper_out, handle, num_consumers); | |||
| var consumers = new string[num_consumers]; | |||
| for (int i = 0; i < num; i++) | |||
| unsafe | |||
| { | |||
| TF_Input input = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||
| consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(input.oper)); | |||
| var inputptr = (TF_Input*) handle; | |||
| for (int i = 0; i < num; i++) | |||
| { | |||
| var oper = (inputptr + i)->oper; | |||
| consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(oper)); | |||
| } | |||
| } | |||
| return consumers; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -15,6 +15,8 @@ | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Runtime.CompilerServices; | |||
| using static Tensorflow.c_api; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -27,36 +29,36 @@ namespace Tensorflow | |||
| /// <summary> | |||
| /// Error message | |||
| /// </summary> | |||
| public string Message => c_api.StringPiece(c_api.TF_Message(_handle)); | |||
| public string Message => c_api.StringPiece(TF_Message(_handle)); | |||
| /// <summary> | |||
| /// Error code | |||
| /// </summary> | |||
| public TF_Code Code => c_api.TF_GetCode(_handle); | |||
| public TF_Code Code => TF_GetCode(_handle); | |||
| public Status() | |||
| { | |||
| _handle = c_api.TF_NewStatus(); | |||
| _handle = TF_NewStatus(); | |||
| } | |||
| public void SetStatus(TF_Code code, string msg) | |||
| { | |||
| c_api.TF_SetStatus(_handle, code, msg); | |||
| TF_SetStatus(_handle, code, msg); | |||
| } | |||
| /// <summary> | |||
| /// Check status | |||
| /// Throw exception with error message if code != TF_OK | |||
| /// </summary> | |||
| /// <exception cref="TensorflowException">When the returned check is not TF_Code.TF_OK</exception> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public void Check(bool throwException = false) | |||
| { | |||
| if (Code != TF_Code.TF_OK) | |||
| { | |||
| Console.WriteLine(Message); | |||
| if (throwException) | |||
| { | |||
| throw new Exception(Message); | |||
| } | |||
| throw new TensorflowException(Message); | |||
| } | |||
| } | |||
| @@ -66,6 +68,6 @@ namespace Tensorflow | |||
| } | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| => c_api.TF_DeleteStatus(handle); | |||
| => TF_DeleteStatus(handle); | |||
| } | |||
| } | |||
| @@ -51,7 +51,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static unsafe extern IntPtr TF_NewStatus(); | |||
| public static extern IntPtr TF_NewStatus(); | |||
| /// <summary> | |||
| /// Record <code, msg> in *s. Any previous information is lost. | |||
| @@ -52,9 +52,9 @@ namespace Tensorflow | |||
| private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero }; | |||
| // note: they must be assigned to a static variable in order to work as unmanaged callbacks | |||
| static Deallocator _hGlobalDeallocator = FreeHGlobalMemory; | |||
| static Deallocator _gcHandleDeallocator = FreeGCHandle; | |||
| private static Deallocator _nothingDeallocator = FreeNothing; | |||
| private static readonly Deallocator _hGlobalDeallocator = FreeHGlobalMemory; | |||
| private static readonly Deallocator _gcHandleDeallocator = FreeGCHandle; | |||
| private static readonly Deallocator _nothingDeallocator = FreeNothing; | |||
| /// <summary> | |||
| /// Create a Tensor object from an existing TF handle | |||
| @@ -528,7 +528,6 @@ namespace Tensorflow | |||
| } | |||
| _handle = CreateTensorFromNDArray(nd, tensorDType); | |||
| IsMemoryOwner = true; | |||
| } | |||
| private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) | |||
| @@ -624,7 +623,7 @@ namespace Tensorflow | |||
| Marshal.WriteInt64(tensor, 0); | |||
| var status = new Status(); | |||
| fixed (byte* src = &buffer[0]) | |||
| fixed (byte* src = buffer) | |||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | |||
| status.Check(true); | |||
| @@ -667,8 +666,9 @@ namespace Tensorflow | |||
| { | |||
| if (args.deallocator_called) | |||
| return; | |||
| // NumSharp will dispose | |||
| // Marshal.FreeHGlobal(dataPtr); | |||
| Marshal.FreeHGlobal(dataPtr); | |||
| args.deallocator_called = true; | |||
| } | |||
| @@ -221,15 +221,6 @@ namespace Tensorflow | |||
| /// <exception cref="ArgumentException">When <typeparam name="T"> is string </typeparam></exception> | |||
| public T[] ToArray<T>() where T : unmanaged | |||
| { | |||
| //when T is string | |||
| if (typeof(T) == typeof(string)) | |||
| { | |||
| if (dtype != TF_DataType.TF_STRING) | |||
| throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string."); | |||
| return (T[]) (object) StringData(); | |||
| } | |||
| //Are the types matching? | |||
| if (typeof(T).as_dtype() == dtype) | |||
| { | |||
| @@ -246,20 +237,12 @@ namespace Tensorflow | |||
| unsafe | |||
| { | |||
| var len = (long) size; | |||
| fixed (T* dstRet = ret) | |||
| fixed (T* dst = ret) | |||
| { | |||
| T* dst = dstRet; //local stack copy | |||
| if (typeof(T).IsPrimitive) | |||
| { | |||
| var src = (T*) buffer; | |||
| len *= ((long) itemsize); | |||
| System.Buffer.MemoryCopy(src, dst, len, len); | |||
| } else | |||
| { | |||
| var itemsize = (long) this.itemsize; | |||
| var buffer = this.buffer.ToInt64(); | |||
| Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure<T>(new IntPtr(buffer + i * itemsize))); | |||
| } | |||
| //T can only be unmanaged, I believe it is safe to say that MemoryCopy is valid for all cases this method can be called. | |||
| var src = (T*) buffer; | |||
| len *= ((long) itemsize); | |||
| System.Buffer.MemoryCopy(src, dst, len, len); | |||
| } | |||
| } | |||
| @@ -384,9 +367,15 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| /// Used internally in ToArray<T> | |||
| private unsafe string[] StringData() | |||
| /// <summary> | |||
| /// Extracts string array from current Tensor. | |||
| /// </summary> | |||
| /// <exception cref="InvalidOperationException">When <see cref="dtype"/> != TF_DataType.TF_STRING</exception> | |||
| public unsafe string[] StringData() | |||
| { | |||
| if (dtype != TF_DataType.TF_STRING) | |||
| throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); | |||
| // | |||
| // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | |||
| // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] | |||
| @@ -442,7 +431,7 @@ namespace Tensorflow | |||
| /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | |||
| /// <param name="session">The `Session` to be used to evaluate this tensor.</param> | |||
| /// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns> | |||
| public NDArray eval(Session session, FeedItem[] feed_dict = null) | |||
| public NDArray eval(Session session, params FeedItem[] feed_dict) | |||
| { | |||
| return ops._eval_using_default_session(this, feed_dict, graph, session); | |||
| } | |||
| @@ -568,23 +557,10 @@ namespace Tensorflow | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| { | |||
| if (handle != IntPtr.Zero) | |||
| { | |||
| c_api.TF_DeleteTensor(handle); | |||
| _handle = IntPtr.Zero; | |||
| } | |||
| c_api.TF_DeleteTensor(handle); | |||
| } | |||
| public bool IsDisposed | |||
| { | |||
| get | |||
| { | |||
| lock (this) | |||
| { | |||
| return _handle == IntPtr.Zero; | |||
| } | |||
| } | |||
| } | |||
| public bool IsDisposed => _disposed; | |||
| public int tensor_int_val { get; set; } | |||
| } | |||
| @@ -83,6 +83,12 @@ namespace Tensorflow | |||
| throw new NotImplementedException("MakeNdarray"); | |||
| } | |||
| private static readonly TF_DataType[] quantized_types = new TF_DataType[] | |||
| { | |||
| TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, | |||
| TF_DataType.TF_QINT32 | |||
| }; | |||
| /// <summary> | |||
| /// Create a TensorProto. | |||
| /// </summary> | |||
| @@ -99,15 +105,6 @@ namespace Tensorflow | |||
| if (values is TensorProto tp) | |||
| return tp; | |||
| if (dtype != TF_DataType.DtInvalid) | |||
| ; | |||
| bool is_quantized = new TF_DataType[] | |||
| { | |||
| TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, | |||
| TF_DataType.TF_QINT32 | |||
| }.Contains(dtype); | |||
| // We first convert value to a numpy array or scalar. | |||
| NDArray nparray = null; | |||
| var np_dt = dtype.as_numpy_dtype(); | |||
| @@ -227,13 +224,13 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| var numpy_dtype = dtypes.as_dtype(nparray.dtype, dtype: dtype); | |||
| var numpy_dtype = nparray.dtype.as_dtype(dtype: dtype); | |||
| if (numpy_dtype == TF_DataType.DtInvalid) | |||
| throw new TypeError($"Unrecognized data type: {nparray.dtype}"); | |||
| // If dtype was specified and is a quantized type, we convert | |||
| // numpy_dtype back into the quantized version. | |||
| if (is_quantized) | |||
| if (quantized_types.Contains(dtype)) | |||
| numpy_dtype = dtype; | |||
| bool is_same_size = false; | |||
| @@ -0,0 +1,94 @@ | |||
| using System; | |||
| using System.IO; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Runtime.InteropServices; | |||
| using NumSharp.Backends.Unmanaged; | |||
| namespace Tensorflow.Util | |||
| { | |||
| public static class UnmanagedExtensions | |||
| { | |||
| //internally UnmanagedMemoryStream can't construct with null address. | |||
| private static readonly unsafe byte* _empty = (byte*) Marshal.AllocHGlobal(1); | |||
| /// <summary> | |||
| /// Creates a memory stream based on given <paramref name="block"/>. | |||
| /// </summary> | |||
| /// <param name="block">The block to stream. Can be default/null.</param> | |||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock<byte> block) | |||
| { | |||
| unsafe | |||
| { | |||
| if (block.Address == null) | |||
| return new UnmanagedMemoryStream(_empty, 0); | |||
| return new UnmanagedMemoryStream(block.Address, block.BytesCount); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Creates a memory stream based on given <paramref name="block"/>. | |||
| /// </summary> | |||
| /// <param name="block">The block to stream. Can be default/null.</param> | |||
| /// <param name="offset">Offset from the start of the block.</param> | |||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock<byte> block, long offset) | |||
| { | |||
| if (block.BytesCount - offset <= 0) | |||
| throw new ArgumentOutOfRangeException(nameof(offset)); | |||
| unsafe | |||
| { | |||
| if (block.Address == null) | |||
| return new UnmanagedMemoryStream(_empty, 0); | |||
| return new UnmanagedMemoryStream(block.Address + offset, block.BytesCount - offset); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Creates a memory stream based on given <paramref name="address"/>. | |||
| /// </summary> | |||
| /// <param name="address">The block to stream. Can be IntPtr.Zero.</param> | |||
| /// <param name="length">The length of the block in bytes.</param> | |||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public static UnmanagedMemoryStream Stream(this IntPtr address, long length) | |||
| { | |||
| if (length <= 0) | |||
| throw new ArgumentOutOfRangeException(nameof(length)); | |||
| unsafe | |||
| { | |||
| if (address == IntPtr.Zero) | |||
| return new UnmanagedMemoryStream(_empty, 0); | |||
| // ReSharper disable once AssignNullToNotNullAttribute | |||
| return new UnmanagedMemoryStream((byte*) address, length); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Creates a memory stream based on given <paramref name="address"/>. | |||
| /// </summary> | |||
| /// <param name="address">The block to stream. Can be IntPtr.Zero.</param> | |||
| /// <param name="offset">Offset from the start of the block.</param> | |||
| /// <param name="length">The length of the block in bytes.</param> | |||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
| public static UnmanagedMemoryStream Stream(this IntPtr address, long offset, long length) | |||
| { | |||
| if (length <= 0) | |||
| throw new ArgumentOutOfRangeException(nameof(length)); | |||
| unsafe | |||
| { | |||
| if (address == IntPtr.Zero) | |||
| return new UnmanagedMemoryStream(_empty, 0); | |||
| return new UnmanagedMemoryStream((byte*) address + offset, length); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -8,7 +8,8 @@ | |||
| %supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"] | |||
| %supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||
| %supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"] | |||
| %supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||
| %supported_numericals_TF_DataType = ["TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] | |||
| %supported_numericals_TF_DataType_full = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||
| //this is the type we use in summerizing/reducting: | |||
| %supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] | |||
| @@ -25,7 +26,8 @@ | |||
| %supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"] | |||
| %supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] | |||
| %supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||
| %supported_dtypes_TF_DataType = ["TF_BOOL","TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] | |||
| %supported_dtypes_TF_DataType_full = ["TF_DataType.TF_BOOL","TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||
| %supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"] | |||
| %supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||
| @@ -230,8 +230,8 @@ namespace Tensorflow | |||
| // Add attrs | |||
| foreach (var attr in node_def.Attr) | |||
| { | |||
| var bytes = attr.Value.ToByteArray(); | |||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
| uint len = (uint)bytes.Length; | |||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||
| @@ -64,8 +64,7 @@ namespace Tensorflow | |||
| public Session Session() | |||
| { | |||
| defaultSession = new Session(); | |||
| return defaultSession; | |||
| return new Session(); | |||
| } | |||
| public Session Session(Graph graph, SessionOptions opts = null) | |||
| @@ -9,24 +9,18 @@ namespace TensorFlowBenchmark | |||
| { | |||
| static void Main(string[] args) | |||
| { | |||
| #if DEBUG | |||
| IConfig config = new DebugInProcessConfig(); | |||
| #else | |||
| IConfig config = null; | |||
| #endif | |||
| if (args?.Length > 0) | |||
| { | |||
| for (int i = 0; i < args.Length; i++) | |||
| { | |||
| string name = $"TensorFlowBenchmark.{args[i]}"; | |||
| var type = Type.GetType(name); | |||
| BenchmarkRunner.Run(type, config); | |||
| BenchmarkRunner.Run(type); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, config); | |||
| BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, ManualConfig.Create(DefaultConfig.Instance).With(ConfigOptions.DisableOptimizationsValidator)); | |||
| } | |||
| Console.ReadLine(); | |||
| @@ -6,6 +6,7 @@ | |||
| <NoWin32Manifest>true</NoWin32Manifest> | |||
| <AssemblyName>TensorFlowBenchmark</AssemblyName> | |||
| <RootNamespace>TensorFlowBenchmark</RootNamespace> | |||
| <LangVersion>7.3</LangVersion> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | |||
| @@ -0,0 +1,76 @@ | |||
| using System; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Runtime.InteropServices; | |||
| using BenchmarkDotNet.Attributes; | |||
| using Google.Protobuf.WellKnownTypes; | |||
| using NumSharp; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowBenchmark.Unmanaged | |||
| { | |||
| public struct UnmanagedStruct | |||
| { | |||
| public int a; | |||
| public long b; | |||
| public UnmanagedStruct(int _) | |||
| { | |||
| a = 2; | |||
| b = 3; | |||
| } | |||
| } | |||
| [SimpleJob(launchCount: 1, warmupCount: 2, targetCount: 10)] | |||
| [MinColumn, MaxColumn, MeanColumn, MedianColumn] | |||
| public unsafe class StructCastBenchmark | |||
| { | |||
| private static void EnsureIsUnmanaged<T>(T _) where T : unmanaged | |||
| { } | |||
| static StructCastBenchmark() //if UnmanagedStruct is not unmanaged struct then this will fail to compile. | |||
| => EnsureIsUnmanaged(new UnmanagedStruct()); | |||
| private IntPtr data; | |||
| private void* dataptr; | |||
| [GlobalSetup] | |||
| public void Setup() | |||
| { | |||
| data = Marshal.AllocHGlobal(Marshal.SizeOf<UnmanagedStruct>()); | |||
| dataptr = data.ToPointer(); | |||
| } | |||
| [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] | |||
| public void Marshal_PtrToStructure() | |||
| { | |||
| UnmanagedStruct _; | |||
| for (int i = 0; i < 10000; i++) | |||
| { | |||
| _ = Marshal.PtrToStructure<UnmanagedStruct>(data); | |||
| } | |||
| } | |||
| [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] | |||
| public void PointerCast() | |||
| { | |||
| var dptr = dataptr; | |||
| UnmanagedStruct _; | |||
| for (int i = 0; i < 10000; i++) | |||
| { | |||
| _ = *(UnmanagedStruct*) dptr; | |||
| } | |||
| } | |||
| [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] | |||
| public void Unsafe_Read() | |||
| { | |||
| var dptr = dataptr; | |||
| UnmanagedStruct _; | |||
| for (int i = 0; i < 10000; i++) | |||
| { | |||
| _ = Unsafe.Read<UnmanagedStruct>(dptr); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -102,7 +102,7 @@ namespace TensorFlowNET.Examples | |||
| // Display logs per epoch step | |||
| if ((epoch + 1) % display_step == 0) | |||
| print($"Epoch: {(epoch + 1).ToString("D4")} Cost: {avg_cost.ToString("G9")} Elapse: {sw.ElapsedMilliseconds}ms"); | |||
| print($"Epoch: {(epoch + 1):D4} Cost: {avg_cost:G9} Elapse: {sw.ElapsedMilliseconds}ms"); | |||
| sw.Reset(); | |||
| } | |||
| @@ -114,8 +114,8 @@ namespace TensorFlowNET.Examples | |||
| var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); | |||
| // Calculate accuracy | |||
| var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); | |||
| float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||
| print($"Accuracy: {acc.ToString("F4")}"); | |||
| float acc = accuracy.eval(sess, (x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||
| print($"Accuracy: {acc:F4}"); | |||
| return acc > 0.9; | |||
| } | |||
| @@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples | |||
| public void PrepareData() | |||
| { | |||
| // get model file | |||
| string url = "http://download.tf.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz"; | |||
| string url = "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz"; | |||
| Web.Download(url, modelDir, "ssd_mobilenet_v1_coco.tar.gz"); | |||
| Compress.ExtractTGZ(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz"), "./"); | |||
| @@ -21,6 +21,7 @@ using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Threading.Tasks; | |||
| using Tensorflow; | |||
| using TensorFlowNET.Examples.Utility; | |||
| using static Tensorflow.Binding; | |||
| @@ -381,10 +382,15 @@ namespace TensorFlowNET.Examples | |||
| Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name) | |||
| { | |||
| int how_many_bottlenecks = 0; | |||
| foreach (var (label_name, label_lists) in image_lists) | |||
| var kvs = image_lists.ToArray(); | |||
| var categories = new string[] {"training", "testing", "validation"}; | |||
| Parallel.For(0, kvs.Length, i => | |||
| { | |||
| foreach (var category in new string[] { "training", "testing", "validation" }) | |||
| var (label_name, label_lists) = kvs[i]; | |||
| Parallel.For(0, categories.Length, j => | |||
| { | |||
| var category = categories[j]; | |||
| var category_list = label_lists[category]; | |||
| foreach (var (index, unused_base_name) in enumerate(category_list)) | |||
| { | |||
| @@ -395,8 +401,8 @@ namespace TensorFlowNET.Examples | |||
| if (how_many_bottlenecks % 300 == 0) | |||
| print($"{how_many_bottlenecks} bottleneck files created."); | |||
| } | |||
| } | |||
| } | |||
| }); | |||
| }); | |||
| } | |||
| private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists, | |||
| @@ -508,7 +514,7 @@ namespace TensorFlowNET.Examples | |||
| { | |||
| // get a set of images to teach the network about the new classes | |||
| string fileName = "flower_photos.tgz"; | |||
| string url = $"http://download.tf.org/example_images/{fileName}"; | |||
| string url = $"http://download.tensorflow.org/example_images/{fileName}"; | |||
| Web.Download(url, data_dir, fileName); | |||
| Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir); | |||
| @@ -1,4 +1,5 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.Basics | |||
| @@ -14,21 +15,22 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| var expected = new[] { false, true, false, false, true, false, true }; | |||
| var spike = tf.Variable(false); | |||
| spike.initializer.run(); | |||
| foreach (var i in range(1, 2)) | |||
| using (var sess = new Session()) | |||
| { | |||
| if (raw_data[i] - raw_data[i - 1] > 5d) | |||
| { | |||
| var updater = tf.assign(spike, tf.constant(true)); | |||
| updater.eval(); | |||
| } | |||
| else | |||
| spike.initializer.run(session: sess); | |||
| foreach (var i in range(1, 2)) | |||
| { | |||
| tf.assign(spike, tf.constant(true)).eval(); | |||
| } | |||
| if (raw_data[i] - raw_data[i - 1] > 5d) | |||
| { | |||
| var updater = tf.assign(spike, tf.constant(true)); | |||
| updater.eval(sess); | |||
| } else | |||
| { | |||
| tf.assign(spike, tf.constant(true)).eval(sess); | |||
| } | |||
| Assert.AreEqual((bool)spike.eval(), expected[i - 1]); | |||
| Assert.AreEqual((bool) spike.eval(), expected[i - 1]); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -2,6 +2,7 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using Tensorflow; | |||
| using Tensorflow.Util; | |||
| using Buffer = Tensorflow.Buffer; | |||
| namespace TensorFlowNET.UnitTest | |||
| @@ -45,15 +46,18 @@ namespace TensorFlowNET.UnitTest | |||
| private bool GetGraphDef(Graph graph, out GraphDef graph_def) | |||
| { | |||
| graph_def = null; | |||
| var s = new Status(); | |||
| var buffer = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
| bool ret = TF_GetCode(s) == TF_OK; | |||
| EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||
| if (ret) graph_def = GraphDef.Parser.ParseFrom(buffer.Data); | |||
| buffer.Dispose(); | |||
| s.Dispose(); | |||
| return ret; | |||
| using (var s = new Status()) | |||
| { | |||
| using (var buffer = new Buffer()) | |||
| { | |||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
| bool ret = TF_GetCode(s) == TF_OK; | |||
| EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||
| if (ret) | |||
| graph_def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
| return ret; | |||
| } | |||
| } | |||
| } | |||
| private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) | |||
| @@ -40,10 +40,7 @@ namespace TensorFlowNET.UnitTest | |||
| private void DeleteInputValues() | |||
| { | |||
| for (var i = 0; i < input_values_.Count; ++i) | |||
| { | |||
| input_values_[i].Dispose(); | |||
| } | |||
| //clearing is enough as they will be disposed by the GC unless they are referenced else-where. | |||
| input_values_.Clear(); | |||
| } | |||
| @@ -60,11 +57,7 @@ namespace TensorFlowNET.UnitTest | |||
| private void ResetOutputValues() | |||
| { | |||
| for (var i = 0; i < output_values_.Count; ++i) | |||
| { | |||
| if (output_values_[i] != IntPtr.Zero) | |||
| output_values_[i].Dispose(); | |||
| } | |||
| //clearing is enough as they will be disposed by the GC unless they are referenced else-where. | |||
| output_values_.Clear(); | |||
| } | |||
| @@ -322,7 +322,6 @@ namespace TensorFlowNET.UnitTest | |||
| EXPECT_EQ(feed2, control_inputs[1]); | |||
| // Export to a graph def so we can import a graph with control dependencies | |||
| graph_def.Dispose(); | |||
| graph_def = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(graph, graph_def, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| @@ -346,14 +345,10 @@ namespace TensorFlowNET.UnitTest | |||
| EXPECT_EQ(feed4, control_inputs[1]); | |||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||
| c_api.TF_DeleteBuffer(graph_def); | |||
| // Can add nodes to the imported graph without trouble. | |||
| c_test_util.Add(feed, scalar, graph, s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| graph.Dispose(); | |||
| s.Dispose(); | |||
| } | |||
| /// <summary> | |||
| @@ -1,4 +1,5 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| @@ -42,5 +43,39 @@ namespace TensorFlowNET.UnitTest | |||
| Assert.AreEqual("", g._name_stack); | |||
| } | |||
| [TestMethod] | |||
| public void NestedNameScope_Using() | |||
| { | |||
| Graph g = tf.Graph().as_default(); | |||
| using (var name = new ops.NameScope("scope1")) | |||
| { | |||
| Assert.AreEqual("scope1", g._name_stack); | |||
| Assert.AreEqual("scope1/", name); | |||
| var const1 = tf.constant(1.0); | |||
| Assert.AreEqual("scope1/Const:0", const1.name); | |||
| using (var name2 = new ops.NameScope("scope2")) | |||
| { | |||
| Assert.AreEqual("scope1/scope2", g._name_stack); | |||
| Assert.AreEqual("scope1/scope2/", name); | |||
| var const2 = tf.constant(2.0); | |||
| Assert.AreEqual("scope1/scope2/Const:0", const2.name); | |||
| } | |||
| Assert.AreEqual("scope1", g._name_stack); | |||
| var const3 = tf.constant(2.0); | |||
| Assert.AreEqual("scope1/Const_1:0", const3.name); | |||
| } | |||
| ; | |||
| g.Dispose(); | |||
| Assert.AreEqual("", g._name_stack); | |||
| } | |||
| } | |||
| } | |||
| @@ -4,6 +4,7 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using NumSharp; | |||
| using Tensorflow; | |||
| using Tensorflow.Util; | |||
| using Buffer = Tensorflow.Buffer; | |||
| using static Tensorflow.Binding; | |||
| @@ -21,7 +22,7 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| var handle = c_api.TF_GetAllOpList(); | |||
| var buffer = new Buffer(handle); | |||
| var op_list = OpList.Parser.ParseFrom(buffer); | |||
| var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
| var _registered_ops = new Dictionary<string, OpDef>(); | |||
| foreach (var op_def in op_list.Op) | |||
| @@ -165,7 +165,7 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var ndarray=tensor.eval(); | |||
| var ndarray=tensor.eval(sess); | |||
| if (typeof(T) == typeof(double)) | |||
| { | |||
| double x = ndarray; | |||
| @@ -72,8 +72,6 @@ namespace TensorFlowNET.UnitTest | |||
| // Clean up | |||
| csession.CloseAndDelete(s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| graph.Dispose(); | |||
| s.Dispose(); | |||
| } | |||
| [TestMethod] | |||
| @@ -84,7 +82,7 @@ namespace TensorFlowNET.UnitTest | |||
| var c = math_ops.matmul(a, b, name: "matmul"); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = c.eval(); | |||
| var result = c.eval(sess); | |||
| Assert.AreEqual(6, result.Data<double>()[0]); | |||
| } | |||
| } | |||
| @@ -4,6 +4,12 @@ | |||
| <TargetFramework>netcoreapp2.2</TargetFramework> | |||
| <IsPackable>false</IsPackable> | |||
| <SignAssembly>true</SignAssembly> | |||
| <DelaySign>false</DelaySign> | |||
| <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
| @@ -119,7 +119,7 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| sess.run(init_op); | |||
| // o some work with the model. | |||
| inc_v1.op.run(); | |||
| inc_v1.op.run(session: sess); | |||
| } | |||
| } | |||
| @@ -1,4 +1,6 @@ | |||
| using Tensorflow; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using Tensorflow; | |||
| using Tensorflow.Util; | |||
| using Buffer = Tensorflow.Buffer; | |||
| namespace TensorFlowNET.UnitTest | |||
| @@ -26,12 +28,15 @@ namespace TensorFlowNET.UnitTest | |||
| return op; | |||
| } | |||
| [SuppressMessage("ReSharper", "RedundantAssignment")] | |||
| public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | |||
| { | |||
| var buffer = new Buffer(); | |||
| c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||
| attr_value = AttrValue.Parser.ParseFrom(buffer); | |||
| buffer.Dispose(); | |||
| using (var buffer = new Buffer()) | |||
| { | |||
| c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||
| attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
| } | |||
| return s.Code == TF_Code.TF_OK; | |||
| } | |||
| @@ -42,7 +47,7 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
| s.Check(); | |||
| return GraphDef.Parser.ParseFrom(buffer); | |||
| return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
| } | |||
| } | |||
| @@ -24,16 +24,17 @@ namespace TensorFlowNET.UnitTest.ops_test | |||
| [TestMethod] | |||
| public void TestShape() | |||
| { | |||
| var g = tf.Graph().as_default(); | |||
| var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }); | |||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]); | |||
| var op = g._create_op_from_tf_operation(c_op); | |||
| Assert.AreEqual("myop", op.name); | |||
| Assert.AreEqual("Identity", op.type); | |||
| Assert.AreEqual(1, len(op.outputs)); | |||
| assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape); | |||
| using (var g = tf.Graph().as_default()) | |||
| { | |||
| var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); | |||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||
| var op = g._create_op_from_tf_operation(c_op); | |||
| Assert.AreEqual("myop", op.name); | |||
| Assert.AreEqual("Identity", op.type); | |||
| Assert.AreEqual(1, len(op.outputs)); | |||
| assertItemsEqual(new[] {2, 3}, op.outputs[0].shape); | |||
| } | |||
| } | |||
| [TestMethod] | |||