Performance optimization, refactoring and revamping.tags/v0.12
| @@ -0,0 +1,4 @@ | |||||
| using System.Runtime.CompilerServices; | |||||
| #if DEBUG | |||||
| [assembly: InternalsVisibleTo("TensorFlowNET.UnitTest, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] | |||||
| #endif | |||||
| @@ -178,13 +178,18 @@ namespace Tensorflow | |||||
| public static IEnumerable<(TKey, TValue)> enumerate<TKey, TValue>(KeyValuePair<TKey, TValue>[] values) | public static IEnumerable<(TKey, TValue)> enumerate<TKey, TValue>(KeyValuePair<TKey, TValue>[] values) | ||||
| { | { | ||||
| foreach (var item in values) | |||||
| var len = values.Length; | |||||
| for (var i = 0; i < len; i++) | |||||
| { | |||||
| var item = values[i]; | |||||
| yield return (item.Key, item.Value); | yield return (item.Key, item.Value); | ||||
| } | |||||
| } | } | ||||
| public static IEnumerable<(int, T)> enumerate<T>(IList<T> values) | public static IEnumerable<(int, T)> enumerate<T>(IList<T> values) | ||||
| { | { | ||||
| for (int i = 0; i < values.Count; i++) | |||||
| var len = values.Count; | |||||
| for (int i = 0; i < len; i++) | |||||
| yield return (i, values[i]); | yield return (i, values[i]); | ||||
| } | } | ||||
| @@ -15,58 +15,116 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Runtime.CompilerServices; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using NumSharp.Backends.Unmanaged; | |||||
| using static Tensorflow.c_api; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Represents a TF_Buffer that can be passed to Tensorflow. | |||||
| /// </summary> | |||||
| public class Buffer : DisposableObject | public class Buffer : DisposableObject | ||||
| { | { | ||||
| private TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle); | |||||
| private unsafe TF_Buffer buffer | |||||
| { | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| get => *bufferptr; | |||||
| } | |||||
| private unsafe TF_Buffer* bufferptr | |||||
| { | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| get => (TF_Buffer*) _handle; | |||||
| } | |||||
| public byte[] Data | |||||
| /// <summary> | |||||
| /// The memory block representing this buffer. | |||||
| /// </summary> | |||||
| /// <remarks>The deallocator is set to null.</remarks> | |||||
| public UnmanagedMemoryBlock<byte> MemoryBlock | |||||
| { | { | ||||
| get | |||||
| get | |||||
| { | { | ||||
| var data = new byte[buffer.length]; | |||||
| if (data.Length > 0) | |||||
| Marshal.Copy(buffer.data, data, 0, data.Length); | |||||
| return data; | |||||
| unsafe | |||||
| { | |||||
| EnsureNotDisposed(); | |||||
| var buff = (TF_Buffer*) _handle; | |||||
| return new UnmanagedMemoryBlock<byte>((byte*) buff->data.ToPointer(), (long) buff->length); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| public int Length => (int)buffer.length; | |||||
| public Buffer() | |||||
| /// <summary> | |||||
| /// The bytes length of this buffer. | |||||
| /// </summary> | |||||
| public ulong Length | |||||
| { | { | ||||
| _handle = c_api.TF_NewBuffer(); | |||||
| get | |||||
| { | |||||
| EnsureNotDisposed(); | |||||
| return buffer.length; | |||||
| } | |||||
| } | } | ||||
| public Buffer(IntPtr handle) | |||||
| public Buffer() => _handle = TF_NewBuffer(); | |||||
| internal Buffer(IntPtr handle) | |||||
| { | { | ||||
| if (handle == IntPtr.Zero) | |||||
| throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle)); | |||||
| _handle = handle; | _handle = handle; | ||||
| } | } | ||||
| public Buffer(byte[] data) | |||||
| { | |||||
| var dst = Marshal.AllocHGlobal(data.Length); | |||||
| Marshal.Copy(data, 0, dst, data.Length); | |||||
| public Buffer(byte[] data) : this(_toBuffer(data)) | |||||
| { } | |||||
| _handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length); | |||||
| private static IntPtr _toBuffer(byte[] data) | |||||
| { | |||||
| if (data == null) | |||||
| throw new ArgumentNullException(nameof(data)); | |||||
| Marshal.FreeHGlobal(dst); | |||||
| unsafe | |||||
| { | |||||
| fixed (byte* src = data) | |||||
| return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength); | |||||
| } | |||||
| } | } | ||||
| public static implicit operator IntPtr(Buffer buffer) | public static implicit operator IntPtr(Buffer buffer) | ||||
| { | { | ||||
| buffer.EnsureNotDisposed(); | |||||
| return buffer._handle; | return buffer._handle; | ||||
| } | } | ||||
| public static implicit operator byte[](Buffer buffer) | |||||
| public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit, developer will assume it doesn't cost. | |||||
| /// <summary> | |||||
| /// Copies this buffer's contents onto a <see cref="byte"/> array. | |||||
| /// </summary> | |||||
| public byte[] ToArray() | |||||
| { | { | ||||
| return buffer.Data; | |||||
| EnsureNotDisposed(); | |||||
| unsafe | |||||
| { | |||||
| var len = buffer.length; | |||||
| if (len == 0) | |||||
| return Array.Empty<byte>(); | |||||
| byte[] data = new byte[len]; | |||||
| fixed (byte* dst = data) | |||||
| System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len); | |||||
| return data; | |||||
| } | |||||
| } | } | ||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
| => c_api.TF_DeleteBuffer(handle); | |||||
| { | |||||
| TF_DeleteBuffer(handle); | |||||
| } | |||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -16,6 +16,8 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -26,27 +28,33 @@ namespace Tensorflow | |||||
| public abstract class DisposableObject : IDisposable | public abstract class DisposableObject : IDisposable | ||||
| { | { | ||||
| protected IntPtr _handle; | protected IntPtr _handle; | ||||
| protected bool _disposed; | |||||
| protected DisposableObject() { } | |||||
| [SuppressMessage("ReSharper", "UnusedMember.Global")] | |||||
| protected DisposableObject() | |||||
| { } | |||||
| protected DisposableObject(IntPtr handle) | |||||
| protected DisposableObject(IntPtr handle) | |||||
| => _handle = handle; | => _handle = handle; | ||||
| [SuppressMessage("ReSharper", "InvertIf")] | |||||
| private void internal_dispose(bool disposing) | private void internal_dispose(bool disposing) | ||||
| { | { | ||||
| if (disposing) | |||||
| { | |||||
| // free unmanaged resources (unmanaged objects) and override a finalizer below. | |||||
| if (_handle != IntPtr.Zero) | |||||
| { | |||||
| // dispose managed state (managed objects). | |||||
| DisposeManagedResources(); | |||||
| if (_disposed) | |||||
| return; | |||||
| _disposed = true; | |||||
| // set large fields to null. | |||||
| DisposeUnmanagedResources(_handle); | |||||
| //first handle managed, they might use the unmanaged resources. | |||||
| if (disposing) | |||||
| // dispose managed state (managed objects). | |||||
| DisposeManagedResources(); | |||||
| _handle = IntPtr.Zero; | |||||
| } | |||||
| //free unmanaged memory | |||||
| if (_handle != IntPtr.Zero) | |||||
| { | |||||
| DisposeUnmanagedResources(_handle); | |||||
| _handle = IntPtr.Zero; | |||||
| } | } | ||||
| } | } | ||||
| @@ -55,28 +63,33 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks> | /// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks> | ||||
| protected virtual void DisposeManagedResources() | protected virtual void DisposeManagedResources() | ||||
| { | |||||
| } | |||||
| { } | |||||
| /// <summary> | /// <summary> | ||||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | ||||
| /// </summary> | /// </summary> | ||||
| protected abstract void DisposeUnmanagedResources(IntPtr handle); | protected abstract void DisposeUnmanagedResources(IntPtr handle); | ||||
| // override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. | |||||
| ~DisposableObject() | ~DisposableObject() | ||||
| { | { | ||||
| // Do not change this code. Put cleanup code in Dispose(bool disposing) above. | |||||
| internal_dispose(false); | internal_dispose(false); | ||||
| } | } | ||||
| // This code added to correctly implement the disposable pattern. | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| // Do not change this code. Put cleanup code in Dispose(bool disposing) above. | |||||
| internal_dispose(true); | internal_dispose(true); | ||||
| // uncomment the following line if the finalizer is overridden above. | |||||
| GC.SuppressFinalize(this); | GC.SuppressFinalize(this); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// If <see cref="_handle"/> is <see cref="IntPtr.Zero"/> then throws <see cref="ObjectDisposedException"/> | |||||
| /// </summary> | |||||
| /// <exception cref="ObjectDisposedException">When <see cref="_handle"/> is <see cref="IntPtr.Zero"/></exception> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| protected void EnsureNotDisposed() | |||||
| { | |||||
| if (_disposed) | |||||
| throw new ObjectDisposedException($"Unable to access disposed object, Type: {GetType().Name}"); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -2,12 +2,10 @@ | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| public class Context : IDisposable | |||||
| public class Context : DisposableObject | |||||
| { | { | ||||
| private IntPtr _handle; | |||||
| public static int GRAPH_MODE = 0; | |||||
| public static int EAGER_MODE = 1; | |||||
| public const int GRAPH_MODE = 0; | |||||
| public const int EAGER_MODE = 1; | |||||
| public int default_execution_mode; | public int default_execution_mode; | ||||
| @@ -17,19 +15,16 @@ namespace Tensorflow.Eager | |||||
| status.Check(true); | status.Check(true); | ||||
| } | } | ||||
| public void Dispose() | |||||
| { | |||||
| c_api.TFE_DeleteContext(_handle); | |||||
| } | |||||
| /// <summary> | |||||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||||
| /// </summary> | |||||
| protected sealed override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TFE_DeleteContext(_handle); | |||||
| public bool executing_eagerly() | |||||
| { | |||||
| return false; | |||||
| } | |||||
| public static implicit operator IntPtr(Context ctx) | |||||
| { | |||||
| return ctx._handle; | |||||
| } | |||||
| public bool executing_eagerly() => false; | |||||
| public static implicit operator IntPtr(Context ctx) | |||||
| => ctx._handle; | |||||
| } | } | ||||
| } | } | ||||
| @@ -3,23 +3,20 @@ using System.IO; | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| public class ContextOptions : IDisposable //TODO! Eli: Shouldn't this inherieting DisposableObject? | |||||
| public class ContextOptions : DisposableObject | |||||
| { | { | ||||
| private IntPtr _handle; | |||||
| public ContextOptions() : base(c_api.TFE_NewContextOptions()) | |||||
| { } | |||||
| public ContextOptions() | |||||
| { | |||||
| _handle = c_api.TFE_NewContextOptions(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||||
| /// </summary> | |||||
| protected sealed override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TFE_DeleteContextOptions(_handle); | |||||
| public void Dispose() | |||||
| { | |||||
| c_api.TFE_DeleteContextOptions(_handle); | |||||
| } | |||||
| public static implicit operator IntPtr(ContextOptions opts) | |||||
| { | |||||
| return opts._handle; | |||||
| } | |||||
| public static implicit operator IntPtr(ContextOptions opts) | |||||
| => opts._handle; | |||||
| } | } | ||||
| } | } | ||||
| @@ -2,7 +2,7 @@ | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class KeyError : Exception | |||||
| public class KeyError : TensorflowException | |||||
| { | { | ||||
| public KeyError() : base() | public KeyError() : base() | ||||
| { | { | ||||
| @@ -2,7 +2,7 @@ | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class RuntimeError : Exception | |||||
| public class RuntimeError : TensorflowException | |||||
| { | { | ||||
| public RuntimeError() : base() | public RuntimeError() : base() | ||||
| { | { | ||||
| @@ -0,0 +1,36 @@ | |||||
| using System; | |||||
| using System.Runtime.Serialization; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Serves as a base class to all exceptions of Tensorflow.NET. | |||||
| /// </summary> | |||||
| [Serializable] | |||||
| public class TensorflowException : Exception | |||||
| { | |||||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class.</summary> | |||||
| public TensorflowException() | |||||
| { } | |||||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with serialized data.</summary> | |||||
| /// <param name="info">The <see cref="T:System.Runtime.Serialization.SerializationInfo"></see> that holds the serialized object data about the exception being thrown.</param> | |||||
| /// <param name="context">The <see cref="T:System.Runtime.Serialization.StreamingContext"></see> that contains contextual information about the source or destination.</param> | |||||
| /// <exception cref="T:System.ArgumentNullException">The <paramref name="info">info</paramref> parameter is null.</exception> | |||||
| /// <exception cref="T:System.Runtime.Serialization.SerializationException">The class name is null or <see cref="P:System.Exception.HResult"></see> is zero (0).</exception> | |||||
| protected TensorflowException(SerializationInfo info, StreamingContext context) : base(info, context) | |||||
| { } | |||||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with a specified error message.</summary> | |||||
| /// <param name="message">The message that describes the error.</param> | |||||
| public TensorflowException(string message) : base(message) | |||||
| { } | |||||
| /// <summary>Initializes a new instance of the <see cref="T:System.Exception"></see> class with a specified error message and a reference to the inner exception that is the cause of this exception.</summary> | |||||
| /// <param name="message">The error message that explains the reason for the exception.</param> | |||||
| /// <param name="innerException">The exception that is the cause of the current exception, or a null reference (Nothing in Visual Basic) if no inner exception is specified.</param> | |||||
| public TensorflowException(string message, Exception innerException) : base(message, innerException) | |||||
| { } | |||||
| } | |||||
| } | |||||
| @@ -2,7 +2,7 @@ | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class TypeError : Exception | |||||
| public class TypeError : TensorflowException | |||||
| { | { | ||||
| public TypeError() : base() | public TypeError() : base() | ||||
| { | { | ||||
| @@ -2,7 +2,7 @@ | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class ValueError : Exception | |||||
| public class ValueError : TensorflowException | |||||
| { | { | ||||
| public ValueError() : base() | public ValueError() : base() | ||||
| { | { | ||||
| @@ -6,10 +6,5 @@ | |||||
| { | { | ||||
| } | } | ||||
| ~ScopedTFImportGraphDefOptions() | |||||
| { | |||||
| base.Dispose(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -15,6 +15,8 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -27,12 +29,12 @@ namespace Tensorflow | |||||
| if(_registered_ops == null) | if(_registered_ops == null) | ||||
| { | { | ||||
| _registered_ops = new Dictionary<string, OpDef>(); | _registered_ops = new Dictionary<string, OpDef>(); | ||||
| var handle = c_api.TF_GetAllOpList(); | |||||
| var buffer = new Buffer(handle); | |||||
| var op_list = OpList.Parser.ParseFrom(buffer); | |||||
| foreach (var op_def in op_list.Op) | |||||
| _registered_ops[op_def.Name] = op_def; | |||||
| using (var buffer = new Buffer(c_api.TF_GetAllOpList())) | |||||
| { | |||||
| var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| foreach (var op_def in op_list.Op) | |||||
| _registered_ops[op_def.Name] = op_def; | |||||
| } | |||||
| } | } | ||||
| return _registered_ops; | return _registered_ops; | ||||
| @@ -14,49 +14,62 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class DefaultGraphStack | |||||
| /// <summary> | |||||
| /// Serves as a stack for determining current default graph. | |||||
| /// </summary> | |||||
| public class DefaultGraphStack | |||||
| { | { | ||||
| List<StackModel> stack = new List<StackModel>(); | |||||
| private readonly List<StackModel> _stack = new List<StackModel>(); | |||||
| public void set_controller(Graph @default) | public void set_controller(Graph @default) | ||||
| { | { | ||||
| if (!stack.Exists(x => x.Graph == @default)) | |||||
| stack.Add(new StackModel { Graph = @default, IsDefault = true }); | |||||
| if (!_stack.Exists(x => x.Graph == @default)) | |||||
| _stack.Add(new StackModel {Graph = @default, IsDefault = true}); | |||||
| foreach (var s in stack) | |||||
| foreach (var s in _stack) | |||||
| s.IsDefault = s.Graph == @default; | s.IsDefault = s.Graph == @default; | ||||
| } | } | ||||
| public Graph get_controller() | public Graph get_controller() | ||||
| { | { | ||||
| if (stack.Count(x => x.IsDefault) == 0) | |||||
| stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); | |||||
| if (_stack.Count(x => x.IsDefault) == 0) | |||||
| _stack.Add(new StackModel {Graph = tf.Graph(), IsDefault = true}); | |||||
| for (var i = _stack.Count - 1; i >= 0; i--) | |||||
| { | |||||
| var x = _stack[i]; | |||||
| if (x.IsDefault) | |||||
| return x.Graph; | |||||
| } | |||||
| return stack.Last(x => x.IsDefault).Graph; | |||||
| throw new TensorflowException("Unable to find a default graph"); | |||||
| } | } | ||||
| public bool remove(Graph g) | public bool remove(Graph g) | ||||
| { | { | ||||
| var sm = stack.FirstOrDefault(x => x.Graph == g); | |||||
| if (sm == null) return false; | |||||
| return stack.Remove(sm); | |||||
| if (_stack.Count == 0) | |||||
| return false; | |||||
| var sm = _stack.Find(model => model.Graph == g); | |||||
| return sm != null && _stack.Remove(sm); | |||||
| } | } | ||||
| public void reset() | public void reset() | ||||
| { | { | ||||
| stack.Clear(); | |||||
| _stack.Clear(); | |||||
| } | } | ||||
| } | |||||
| public class StackModel | |||||
| { | |||||
| public Graph Graph { get; set; } | |||||
| public bool IsDefault { get; set; } | |||||
| private class StackModel | |||||
| { | |||||
| public Graph Graph { get; set; } | |||||
| public bool IsDefault { get; set; } | |||||
| } | |||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| @@ -66,8 +67,9 @@ namespace Tensorflow | |||||
| /// within the context should have control dependencies on | /// within the context should have control dependencies on | ||||
| /// `control_inputs`. | /// `control_inputs`. | ||||
| /// </summary> | /// </summary> | ||||
| [SuppressMessage("ReSharper", "CoVariantArrayConversion")] | |||||
| public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | ||||
| => control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray()); | |||||
| => control_dependencies((object[])control_inputs); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns a context manager that specifies control dependencies. | /// Returns a context manager that specifies control dependencies. | ||||
| @@ -14,6 +14,9 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.IO; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class Graph | public partial class Graph | ||||
| @@ -23,21 +26,19 @@ namespace Tensorflow | |||||
| var buffer = new Buffer(); | var buffer = new Buffer(); | ||||
| c_api.TF_GraphToGraphDef(_handle, buffer, s); | c_api.TF_GraphToGraphDef(_handle, buffer, s); | ||||
| s.Check(true); | s.Check(true); | ||||
| // var def = GraphDef.Parser.ParseFrom(buffer); | |||||
| // buffer.Dispose(); | |||||
| return buffer; | return buffer; | ||||
| } | } | ||||
| private GraphDef _as_graph_def(bool add_shapes = false) | private GraphDef _as_graph_def(bool add_shapes = false) | ||||
| { | { | ||||
| var status = new Status(); | |||||
| var buffer = ToGraphDef(status); | |||||
| status.Check(true); | |||||
| status.Dispose(); | |||||
| var def = GraphDef.Parser.ParseFrom(buffer); | |||||
| buffer.Dispose(); | |||||
| GraphDef def; | |||||
| using (var status = new Status()) | |||||
| using (var buffer = ToGraphDef(status)) | |||||
| { | |||||
| status.Check(true); | |||||
| def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | |||||
| // Strip the experimental library field iff it's empty. | // Strip the experimental library field iff it's empty. | ||||
| // if(def.Library.Function.Count == 0) | // if(def.Library.Function.Count == 0) | ||||
| @@ -45,7 +46,7 @@ namespace Tensorflow | |||||
| return def; | return def; | ||||
| } | } | ||||
| public GraphDef as_graph_def(bool add_shapes = false) | |||||
| public GraphDef as_graph_def(bool add_shapes = false) | |||||
| => _as_graph_def(add_shapes); | => _as_graph_def(add_shapes); | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -30,11 +30,10 @@ namespace Tensorflow | |||||
| var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | ||||
| c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | ||||
| for (int i = 0; i < num_return_outputs; i++) | |||||
| { | |||||
| var handle = return_output_handle + i * size; | |||||
| return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle); | |||||
| } | |||||
| var tf_output_ptr = (TF_Output*) return_output_handle; | |||||
| for (int i = 0; i < num_return_outputs; i++) | |||||
| return_outputs[i] = *(tf_output_ptr + i); | |||||
| Marshal.FreeHGlobal(return_output_handle); | Marshal.FreeHGlobal(return_output_handle); | ||||
| @@ -18,6 +18,7 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -30,7 +31,7 @@ namespace Tensorflow | |||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| { | { | ||||
| c_api.TF_GraphGetOpDef(_handle, type, buffer, status); | c_api.TF_GraphGetOpDef(_handle, type, buffer, status); | ||||
| return OpDef.Parser.ParseFrom(buffer.Data); | |||||
| return OpDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | } | ||||
| } | } | ||||
| @@ -39,16 +40,20 @@ namespace Tensorflow | |||||
| return c_api.TF_NewOperation(_handle, opType, opName); | return c_api.TF_NewOperation(_handle, opType, opName); | ||||
| } | } | ||||
| public unsafe Operation[] ReturnOperations(IntPtr results) | |||||
| public Operation[] ReturnOperations(IntPtr results) | |||||
| { | { | ||||
| TF_Operation return_oper_handle = new TF_Operation(); | TF_Operation return_oper_handle = new TF_Operation(); | ||||
| int num_return_opers = 0; | int num_return_opers = 0; | ||||
| c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); | c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); | ||||
| Operation[] return_opers = new Operation[num_return_opers]; | Operation[] return_opers = new Operation[num_return_opers]; | ||||
| var tf_op_size = Marshal.SizeOf<TF_Operation>(); | |||||
| for (int i = 0; i < num_return_opers; i++) | for (int i = 0; i < num_return_opers; i++) | ||||
| { | { | ||||
| var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i; | |||||
| return_opers[i] = new Operation(*(IntPtr*)handle); | |||||
| unsafe | |||||
| { | |||||
| var handle = return_oper_handle.node + tf_op_size * i; | |||||
| return_opers[i] = new Operation(*(IntPtr*)handle); | |||||
| } | |||||
| } | } | ||||
| return return_opers; | return return_opers; | ||||
| @@ -67,7 +72,7 @@ namespace Tensorflow | |||||
| public ITensorOrOperation[] get_operations() | public ITensorOrOperation[] get_operations() | ||||
| { | { | ||||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||||
| return _nodes_by_name.Values.ToArray(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -81,7 +86,7 @@ namespace Tensorflow | |||||
| public ITensorOrOperation _get_operation_by_name_unsafe(string name) | public ITensorOrOperation _get_operation_by_name_unsafe(string name) | ||||
| { | { | ||||
| return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null; | |||||
| return _nodes_by_name.TryGetValue(name, out var val) ? val : null; | |||||
| } | } | ||||
| public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) | public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) | ||||
| @@ -369,7 +369,7 @@ namespace Tensorflow | |||||
| var name_key = name.ToLower(); | var name_key = name.ToLower(); | ||||
| int i = 0; | int i = 0; | ||||
| if (_names_in_use.ContainsKey(name_key)) | if (_names_in_use.ContainsKey(name_key)) | ||||
| i = _names_in_use[name_key]; | |||||
| i = _names_in_use[name_key]; | |||||
| // Increment the number for "name_key". | // Increment the number for "name_key". | ||||
| if (mark_as_used) | if (mark_as_used) | ||||
| _names_in_use[name_key] = i + 1; | _names_in_use[name_key] = i + 1; | ||||
| @@ -399,13 +399,13 @@ namespace Tensorflow | |||||
| int num_return_outputs = 0; | int num_return_outputs = 0; | ||||
| c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); | c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); | ||||
| TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | ||||
| for (int i = 0; i < num_return_outputs; i++) | |||||
| unsafe | |||||
| { | { | ||||
| var handle = return_output_handle + (Marshal.SizeOf<TF_Output>() * i); | |||||
| return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle); | |||||
| var tf_output_ptr = (TF_Output*) return_output_handle; | |||||
| for (int i = 0; i < num_return_outputs; i++) | |||||
| return_outputs[i] = *(tf_output_ptr + i); | |||||
| return return_outputs; | |||||
| } | } | ||||
| return return_outputs; | |||||
| } | } | ||||
| public string[] get_all_collection_keys() | public string[] get_all_collection_keys() | ||||
| @@ -497,11 +497,9 @@ namespace Tensorflow | |||||
| IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | ||||
| => GetEnumerable().GetEnumerator(); | => GetEnumerable().GetEnumerator(); | ||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| => throw new NotImplementedException(); | |||||
| public static implicit operator IntPtr(Graph graph) | public static implicit operator IntPtr(Graph graph) | ||||
| { | { | ||||
| return graph._handle; | return graph._handle; | ||||
| @@ -20,7 +20,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class ImportGraphDefOptions : DisposableObject | public class ImportGraphDefOptions : DisposableObject | ||||
| { | { | ||||
| public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||||
| public int NumReturnOutputs | |||||
| => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||||
| public ImportGraphDefOptions() | public ImportGraphDefOptions() | ||||
| { | { | ||||
| @@ -50,14 +50,12 @@ namespace Tensorflow | |||||
| public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | ||||
| { | { | ||||
| int size = Marshal.SizeOf<TF_Input>(); | |||||
| var handle = Marshal.AllocHGlobal(size); | |||||
| var handle = Marshal.AllocHGlobal(Marshal.SizeOf<TF_Input>()); | |||||
| int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | ||||
| var consumers = new TF_Input[num]; | var consumers = new TF_Input[num]; | ||||
| var inputptr = (TF_Input*) handle; | |||||
| for (int i = 0; i < num; i++) | for (int i = 0; i < num; i++) | ||||
| { | |||||
| consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||||
| } | |||||
| consumers[i] = *(inputptr + i); | |||||
| return consumers; | return consumers; | ||||
| } | } | ||||
| @@ -17,7 +17,9 @@ | |||||
| using Google.Protobuf.Collections; | using Google.Protobuf.Collections; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -226,9 +228,12 @@ namespace Tensorflow | |||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| using (var buf = new Buffer()) | using (var buf = new Buffer()) | ||||
| { | { | ||||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||||
| status.Check(true); | |||||
| x = AttrValue.Parser.ParseFrom(buf); | |||||
| unsafe | |||||
| { | |||||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||||
| status.Check(true); | |||||
| x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | |||||
| } | |||||
| } | } | ||||
| string oneof_value = x.ValueCase.ToString(); | string oneof_value = x.ValueCase.ToString(); | ||||
| @@ -259,7 +264,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| c_api.TF_OperationToNodeDef(_handle, buffer, s); | c_api.TF_OperationToNodeDef(_handle, buffer, s); | ||||
| s.Check(); | s.Check(); | ||||
| return NodeDef.Parser.ParseFrom(buffer); | |||||
| return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | } | ||||
| } | } | ||||
| @@ -299,8 +304,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public TF_Output _tf_output(int output_idx) | public TF_Output _tf_output(int output_idx) | ||||
| { | { | ||||
| var tf_output = new TF_Output(op, output_idx); | |||||
| return tf_output; | |||||
| return new TF_Output(op, output_idx); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -308,8 +312,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public TF_Input _tf_input(int input_idx) | public TF_Input _tf_input(int input_idx) | ||||
| { | { | ||||
| var tf_input = new TF_Input(op, input_idx); | |||||
| return tf_input; | |||||
| return new TF_Input(op, input_idx); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,413 +1,413 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Numerics; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class BaseSession : DisposableObject | |||||
| { | |||||
| protected Graph _graph; | |||||
| protected bool _opened; | |||||
| protected bool _closed; | |||||
| protected int _current_version; | |||||
| protected byte[] _target; | |||||
| public Graph graph => _graph; | |||||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||||
| { | |||||
| _graph = g is null ? ops.get_default_graph() : g; | |||||
| _graph.as_default(); | |||||
| _target = UTF8Encoding.UTF8.GetBytes(target); | |||||
| SessionOptions newOpts = null; | |||||
| if (opts == null) | |||||
| newOpts = new SessionOptions(); | |||||
| var status = new Status(); | |||||
| _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | |||||
| // dispose newOpts | |||||
| if (opts == null) | |||||
| newOpts.Dispose(); | |||||
| status.Check(true); | |||||
| } | |||||
| public virtual void run(Operation op, params FeedItem[] feed_dict) | |||||
| { | |||||
| _run(op, feed_dict); | |||||
| } | |||||
| public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||||
| { | |||||
| return _run(fetche, feed_dict)[0]; | |||||
| } | |||||
| public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) | |||||
| { | |||||
| return _run(fetche, feed_dict)[0]; | |||||
| } | |||||
| public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||||
| return (results[0], results[1], results[2], results[3]); | |||||
| } | |||||
| public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||||
| return (results[0], results[1], results[2]); | |||||
| } | |||||
| public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||||
| return (results[0], results[1]); | |||||
| } | |||||
| public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| return _run(fetches, feed_dict); | |||||
| } | |||||
| public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||||
| { | |||||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||||
| feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||||
| return _run(fetches, feed_items); | |||||
| } | |||||
| private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | |||||
| { | |||||
| var feed_dict_tensor = new Dictionary<object, object>(); | |||||
| var feed_map = new Dictionary<object, object>(); | |||||
| Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) => | |||||
| { | |||||
| return new (object, object)[] { (item.Key, item.Value) }; | |||||
| }; | |||||
| // Validate and process feed_dict. | |||||
| if (feed_dict != null) | |||||
| { | |||||
| foreach (var feed in feed_dict) | |||||
| { | |||||
| foreach (var (subfeed, subfeed_val) in feed_fn(feed)) | |||||
| { | |||||
| var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | |||||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||||
| feed_dict_tensor[subfeed_t] = subfeed_val; | |||||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Create a fetch handler to take care of the structure of fetches. | |||||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||||
| // Run request and get response. | |||||
| // We need to keep the returned movers alive for the following _do_run(). | |||||
| // These movers are no longer needed when _do_run() completes, and | |||||
| // are deleted when `movers` goes out of scope when this _run() ends. | |||||
| var _ = _update_with_movers(); | |||||
| var final_fetches = fetch_handler.fetches(); | |||||
| var final_targets = fetch_handler.targets(); | |||||
| // We only want to really perform the run if fetches or targets are provided, | |||||
| // or if the call is a partial run that specifies feeds. | |||||
| var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||||
| return fetch_handler.build_results(this, results); | |||||
| } | |||||
| /// <summary> | |||||
| /// Runs a step based on the given fetches and feeds. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||||
| /// <param name="fetch_list"></param> | |||||
| /// <param name="feed_dict"></param> | |||||
| /// <returns> | |||||
| /// A list of numpy ndarrays, corresponding to the elements of | |||||
| /// `fetch_list`. If the ith element of `fetch_list` contains the | |||||
| /// name of an operation, the first Tensor output of that operation | |||||
| /// will be returned for that element. | |||||
| /// </returns> | |||||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||||
| { | |||||
| var feeds = feed_dict.Select(x => | |||||
| { | |||||
| if (x.Key is Tensor tensor) | |||||
| { | |||||
| switch (x.Value) | |||||
| { | |||||
| #if _REGEN | |||||
| %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||||
| %foreach types% | |||||
| case #1 v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case #1[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| % | |||||
| #else | |||||
| case sbyte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case sbyte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| #endif | |||||
| case bool v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); | |||||
| case string v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case IntPtr v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Tensor v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||||
| case NDArray v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||||
| default: | |||||
| throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}"); | |||||
| } | |||||
| } | |||||
| throw new NotImplementedException("_do_run.feed_dict"); | |||||
| }).ToArray(); | |||||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||||
| var targets = target_list; | |||||
| return _call_tf_sessionrun(feeds, fetches, target_list); | |||||
| } | |||||
| private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||||
| { | |||||
| // Ensure any changes to the graph are reflected in the runtime. | |||||
| _extend_graph(); | |||||
| var status = new Status(); | |||||
| var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||||
| c_api.TF_SessionRun(_handle, | |||||
| run_options: null, | |||||
| inputs: feed_dict.Select(f => f.Key).ToArray(), | |||||
| input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), | |||||
| ninputs: feed_dict.Length, | |||||
| outputs: fetch_list, | |||||
| output_values: output_values, | |||||
| noutputs: fetch_list.Length, | |||||
| target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||||
| ntargets: target_list.Count, | |||||
| run_metadata: IntPtr.Zero, | |||||
| status: status); | |||||
| status.Check(true); | |||||
| var result = new NDArray[fetch_list.Length]; | |||||
| for (int i = 0; i < fetch_list.Length; i++) | |||||
| result[i] = fetchValue(output_values[i]); | |||||
| for (int i = 0; i < feed_dict.Length; i++) | |||||
| feed_dict[i].Value.Dispose(); | |||||
| return result; | |||||
| } | |||||
| private unsafe NDArray fetchValue(IntPtr output) | |||||
| { | |||||
| var tensor = new Tensor(output); | |||||
| NDArray nd = null; | |||||
| Type type = tensor.dtype.as_numpy_dtype(); | |||||
| var ndims = tensor.shape; | |||||
| var offset = c_api.TF_TensorData(output); | |||||
| if(ndims.Length == 0) | |||||
| { | |||||
| switch (tensor.dtype) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| nd = NDArray.Scalar(*(bool*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_STRING: | |||||
| var bytes = tensor.BufferToArray(); | |||||
| // wired, don't know why we have to start from offset 9. | |||||
| // length in the begin | |||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||||
| nd = NDArray.FromString(str); | |||||
| break; | |||||
| case TF_DataType.TF_UINT8: | |||||
| nd = NDArray.Scalar(*(byte*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT16: | |||||
| nd = NDArray.Scalar(*(short*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT32: | |||||
| nd = NDArray.Scalar(*(int*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT64: | |||||
| nd = NDArray.Scalar(*(long*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| nd = NDArray.Scalar(*(float*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| nd = NDArray.Scalar(*(double*)offset); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("can't fetch output"); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| switch (tensor.dtype) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| var bools = new bool[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(bools).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_STRING: | |||||
| var bytes = tensor.BufferToArray(); | |||||
| // wired, don't know why we have to start from offset 9. | |||||
| // length in the begin | |||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||||
| nd = np.array(str); | |||||
| break; | |||||
| case TF_DataType.TF_UINT8: | |||||
| var _bytes = new byte[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(_bytes).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT16: | |||||
| var shorts = new short[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(shorts).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT32: | |||||
| var ints = new int[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(ints).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT64: | |||||
| var longs = new long[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(longs).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| var floats = new float[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(floats).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| var doubles = new double[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(doubles).reshape(ndims); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("can't fetch output"); | |||||
| } | |||||
| } | |||||
| tensor.Dispose(); | |||||
| return nd; | |||||
| } | |||||
| /// <summary> | |||||
| /// If a tensor handle that is fed to a device incompatible placeholder, | |||||
| /// we move the tensor to the right device, generate a new tensor handle, | |||||
| /// and update feed_dict to use the new handle. | |||||
| /// </summary> | |||||
| private List<object> _update_with_movers() | |||||
| { | |||||
| return new List<object> { }; | |||||
| } | |||||
| private void _extend_graph() | |||||
| { | |||||
| } | |||||
| public void close() | |||||
| { | |||||
| Dispose(); | |||||
| } | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| { | |||||
| using (var status = new Status()) | |||||
| { | |||||
| c_api.TF_DeleteSession(handle, status); | |||||
| status.Check(true); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Numerics; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class BaseSession : DisposableObject | |||||
| { | |||||
| protected Graph _graph; | |||||
| protected bool _opened; | |||||
| protected bool _closed; | |||||
| protected int _current_version; | |||||
| protected byte[] _target; | |||||
| public Graph graph => _graph; | |||||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||||
| { | |||||
| _graph = g is null ? ops.get_default_graph() : g; | |||||
| _graph.as_default(); | |||||
| _target = UTF8Encoding.UTF8.GetBytes(target); | |||||
| SessionOptions newOpts = null; | |||||
| if (opts == null) | |||||
| newOpts = new SessionOptions(); | |||||
| var status = new Status(); | |||||
| _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | |||||
| // dispose newOpts | |||||
| if (opts == null) | |||||
| newOpts.Dispose(); | |||||
| status.Check(true); | |||||
| } | |||||
| public virtual void run(Operation op, params FeedItem[] feed_dict) | |||||
| { | |||||
| _run(op, feed_dict); | |||||
| } | |||||
| public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||||
| { | |||||
| return _run(fetche, feed_dict)[0]; | |||||
| } | |||||
| public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) | |||||
| { | |||||
| return _run(fetche, feed_dict)[0]; | |||||
| } | |||||
| public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||||
| return (results[0], results[1], results[2], results[3]); | |||||
| } | |||||
| public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||||
| return (results[0], results[1], results[2]); | |||||
| } | |||||
| public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||||
| return (results[0], results[1]); | |||||
| } | |||||
| public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||||
| { | |||||
| return _run(fetches, feed_dict); | |||||
| } | |||||
| public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||||
| { | |||||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||||
| feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||||
| return _run(fetches, feed_items); | |||||
| } | |||||
| private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | |||||
| { | |||||
| var feed_dict_tensor = new Dictionary<object, object>(); | |||||
| var feed_map = new Dictionary<object, object>(); | |||||
| Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) => | |||||
| { | |||||
| return new (object, object)[] { (item.Key, item.Value) }; | |||||
| }; | |||||
| // Validate and process feed_dict. | |||||
| if (feed_dict != null) | |||||
| { | |||||
| foreach (var feed in feed_dict) | |||||
| { | |||||
| foreach (var (subfeed, subfeed_val) in feed_fn(feed)) | |||||
| { | |||||
| var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | |||||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||||
| feed_dict_tensor[subfeed_t] = subfeed_val; | |||||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Create a fetch handler to take care of the structure of fetches. | |||||
| var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||||
| // Run request and get response. | |||||
| // We need to keep the returned movers alive for the following _do_run(). | |||||
| // These movers are no longer needed when _do_run() completes, and | |||||
| // are deleted when `movers` goes out of scope when this _run() ends. | |||||
| var _ = _update_with_movers(); | |||||
| var final_fetches = fetch_handler.fetches(); | |||||
| var final_targets = fetch_handler.targets(); | |||||
| // We only want to really perform the run if fetches or targets are provided, | |||||
| // or if the call is a partial run that specifies feeds. | |||||
| var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||||
| return fetch_handler.build_results(this, results); | |||||
| } | |||||
| /// <summary> | |||||
| /// Runs a step based on the given fetches and feeds. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||||
| /// <param name="fetch_list"></param> | |||||
| /// <param name="feed_dict"></param> | |||||
| /// <returns> | |||||
| /// A list of numpy ndarrays, corresponding to the elements of | |||||
| /// `fetch_list`. If the ith element of `fetch_list` contains the | |||||
| /// name of an operation, the first Tensor output of that operation | |||||
| /// will be returned for that element. | |||||
| /// </returns> | |||||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||||
| { | |||||
| var feeds = feed_dict.Select(x => | |||||
| { | |||||
| if (x.Key is Tensor tensor) | |||||
| { | |||||
| switch (x.Value) | |||||
| { | |||||
| #if _REGEN | |||||
| %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||||
| %foreach types% | |||||
| case #1 v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case #1[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| % | |||||
| #else | |||||
| case sbyte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case sbyte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| #endif | |||||
| case bool v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); | |||||
| case string v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case IntPtr v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Tensor v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||||
| case NDArray v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||||
| default: | |||||
| throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}"); | |||||
| } | |||||
| } | |||||
| throw new NotImplementedException("_do_run.feed_dict"); | |||||
| }).ToArray(); | |||||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||||
| var targets = target_list; | |||||
| return _call_tf_sessionrun(feeds, fetches, target_list); | |||||
| } | |||||
| private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||||
| { | |||||
| // Ensure any changes to the graph are reflected in the runtime. | |||||
| _extend_graph(); | |||||
| var status = new Status(); | |||||
| var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||||
| c_api.TF_SessionRun(_handle, | |||||
| run_options: null, | |||||
| inputs: feed_dict.Select(f => f.Key).ToArray(), | |||||
| input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), | |||||
| ninputs: feed_dict.Length, | |||||
| outputs: fetch_list, | |||||
| output_values: output_values, | |||||
| noutputs: fetch_list.Length, | |||||
| target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||||
| ntargets: target_list.Count, | |||||
| run_metadata: IntPtr.Zero, | |||||
| status: status); | |||||
| status.Check(true); | |||||
| var result = new NDArray[fetch_list.Length]; | |||||
| for (int i = 0; i < fetch_list.Length; i++) | |||||
| result[i] = fetchValue(output_values[i]); | |||||
| for (int i = 0; i < feed_dict.Length; i++) | |||||
| feed_dict[i].Value.Dispose(); | |||||
| return result; | |||||
| } | |||||
| private unsafe NDArray fetchValue(IntPtr output) | |||||
| { | |||||
| var tensor = new Tensor(output); | |||||
| NDArray nd = null; | |||||
| Type type = tensor.dtype.as_numpy_dtype(); | |||||
| var ndims = tensor.shape; | |||||
| var offset = c_api.TF_TensorData(output); | |||||
| if(ndims.Length == 0) | |||||
| { | |||||
| switch (tensor.dtype) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| nd = NDArray.Scalar(*(bool*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_STRING: | |||||
| var bytes = tensor.BufferToArray(); | |||||
| // wired, don't know why we have to start from offset 9. | |||||
| // length in the begin | |||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||||
| nd = NDArray.FromString(str); | |||||
| break; | |||||
| case TF_DataType.TF_UINT8: | |||||
| nd = NDArray.Scalar(*(byte*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT16: | |||||
| nd = NDArray.Scalar(*(short*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT32: | |||||
| nd = NDArray.Scalar(*(int*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_INT64: | |||||
| nd = NDArray.Scalar(*(long*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| nd = NDArray.Scalar(*(float*)offset); | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| nd = NDArray.Scalar(*(double*)offset); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("can't fetch output"); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| switch (tensor.dtype) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| var bools = new bool[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(bools).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_STRING: | |||||
| var bytes = tensor.BufferToArray(); | |||||
| // wired, don't know why we have to start from offset 9. | |||||
| // length in the begin | |||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||||
| nd = np.array(str); | |||||
| break; | |||||
| case TF_DataType.TF_UINT8: | |||||
| var _bytes = new byte[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(_bytes).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT16: | |||||
| var shorts = new short[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(shorts).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT32: | |||||
| var ints = new int[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(ints).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_INT64: | |||||
| var longs = new long[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(longs).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| var floats = new float[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(floats).reshape(ndims); | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| var doubles = new double[tensor.size]; | |||||
| for (ulong i = 0; i < tensor.size; i++) | |||||
| doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); | |||||
| nd = np.array(doubles).reshape(ndims); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("can't fetch output"); | |||||
| } | |||||
| } | |||||
| tensor.Dispose(); | |||||
| return nd; | |||||
| } | |||||
| /// <summary> | |||||
| /// If a tensor handle that is fed to a device incompatible placeholder, | |||||
| /// we move the tensor to the right device, generate a new tensor handle, | |||||
| /// and update feed_dict to use the new handle. | |||||
| /// </summary> | |||||
| private List<object> _update_with_movers() | |||||
| { | |||||
| return new List<object> { }; | |||||
| } | |||||
| private void _extend_graph() | |||||
| { | |||||
| } | |||||
| public void close() | |||||
| { | |||||
| Dispose(); | |||||
| } | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| { | |||||
| using (var status = new Status()) | |||||
| { | |||||
| c_api.TF_DeleteSession(handle, status); | |||||
| status.Check(true); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -16,5 +16,11 @@ | |||||
| public static implicit operator FeedItem((object, object) feed) | public static implicit operator FeedItem((object, object) feed) | ||||
| => new FeedItem(feed.Item1, feed.Item2); | => new FeedItem(feed.Item1, feed.Item2); | ||||
| public void Deconstruct(out object key, out object value) | |||||
| { | |||||
| key = Key; | |||||
| value = Value; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -37,8 +37,8 @@ namespace Tensorflow | |||||
| public void SetConfig(ConfigProto config) | public void SetConfig(ConfigProto config) | ||||
| { | { | ||||
| var bytes = config.ToByteArray(); | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||||
| var bytes = config.ToByteArray(); //TODO! we can use WriteTo | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak | |||||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | Marshal.Copy(bytes, 0, proto, bytes.Length); | ||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| @@ -27,13 +27,17 @@ namespace Tensorflow | |||||
| var handle = Marshal.AllocHGlobal(size * num_consumers); | var handle = Marshal.AllocHGlobal(size * num_consumers); | ||||
| int num = TF_OperationOutputConsumers(oper_out, handle, num_consumers); | int num = TF_OperationOutputConsumers(oper_out, handle, num_consumers); | ||||
| var consumers = new string[num_consumers]; | var consumers = new string[num_consumers]; | ||||
| for (int i = 0; i < num; i++) | |||||
| unsafe | |||||
| { | { | ||||
| TF_Input input = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||||
| consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(input.oper)); | |||||
| var inputptr = (TF_Input*) handle; | |||||
| for (int i = 0; i < num; i++) | |||||
| { | |||||
| var oper = (inputptr + i)->oper; | |||||
| consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(oper)); | |||||
| } | |||||
| } | } | ||||
| return consumers; | return consumers; | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -15,6 +15,8 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Runtime.CompilerServices; | |||||
| using static Tensorflow.c_api; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -27,36 +29,36 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Error message | /// Error message | ||||
| /// </summary> | /// </summary> | ||||
| public string Message => c_api.StringPiece(c_api.TF_Message(_handle)); | |||||
| public string Message => c_api.StringPiece(TF_Message(_handle)); | |||||
| /// <summary> | /// <summary> | ||||
| /// Error code | /// Error code | ||||
| /// </summary> | /// </summary> | ||||
| public TF_Code Code => c_api.TF_GetCode(_handle); | |||||
| public TF_Code Code => TF_GetCode(_handle); | |||||
| public Status() | public Status() | ||||
| { | { | ||||
| _handle = c_api.TF_NewStatus(); | |||||
| _handle = TF_NewStatus(); | |||||
| } | } | ||||
| public void SetStatus(TF_Code code, string msg) | public void SetStatus(TF_Code code, string msg) | ||||
| { | { | ||||
| c_api.TF_SetStatus(_handle, code, msg); | |||||
| TF_SetStatus(_handle, code, msg); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Check status | /// Check status | ||||
| /// Throw exception with error message if code != TF_OK | /// Throw exception with error message if code != TF_OK | ||||
| /// </summary> | /// </summary> | ||||
| /// <exception cref="TensorflowException">When the returned check is not TF_Code.TF_OK</exception> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public void Check(bool throwException = false) | public void Check(bool throwException = false) | ||||
| { | { | ||||
| if (Code != TF_Code.TF_OK) | if (Code != TF_Code.TF_OK) | ||||
| { | { | ||||
| Console.WriteLine(Message); | Console.WriteLine(Message); | ||||
| if (throwException) | if (throwException) | ||||
| { | |||||
| throw new Exception(Message); | |||||
| } | |||||
| throw new TensorflowException(Message); | |||||
| } | } | ||||
| } | } | ||||
| @@ -66,6 +68,6 @@ namespace Tensorflow | |||||
| } | } | ||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
| => c_api.TF_DeleteStatus(handle); | |||||
| => TF_DeleteStatus(handle); | |||||
| } | } | ||||
| } | } | ||||
| @@ -51,7 +51,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern IntPtr TF_NewStatus(); | |||||
| public static extern IntPtr TF_NewStatus(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Record <code, msg> in *s. Any previous information is lost. | /// Record <code, msg> in *s. Any previous information is lost. | ||||
| @@ -52,9 +52,9 @@ namespace Tensorflow | |||||
| private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero }; | private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero }; | ||||
| // note: they must be assigned to a static variable in order to work as unmanaged callbacks | // note: they must be assigned to a static variable in order to work as unmanaged callbacks | ||||
| static Deallocator _hGlobalDeallocator = FreeHGlobalMemory; | |||||
| static Deallocator _gcHandleDeallocator = FreeGCHandle; | |||||
| private static Deallocator _nothingDeallocator = FreeNothing; | |||||
| private static readonly Deallocator _hGlobalDeallocator = FreeHGlobalMemory; | |||||
| private static readonly Deallocator _gcHandleDeallocator = FreeGCHandle; | |||||
| private static readonly Deallocator _nothingDeallocator = FreeNothing; | |||||
| /// <summary> | /// <summary> | ||||
| /// Create a Tensor object from an existing TF handle | /// Create a Tensor object from an existing TF handle | ||||
| @@ -528,7 +528,6 @@ namespace Tensorflow | |||||
| } | } | ||||
| _handle = CreateTensorFromNDArray(nd, tensorDType); | _handle = CreateTensorFromNDArray(nd, tensorDType); | ||||
| IsMemoryOwner = true; | |||||
| } | } | ||||
| private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) | private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) | ||||
| @@ -624,7 +623,7 @@ namespace Tensorflow | |||||
| Marshal.WriteInt64(tensor, 0); | Marshal.WriteInt64(tensor, 0); | ||||
| var status = new Status(); | var status = new Status(); | ||||
| fixed (byte* src = &buffer[0]) | |||||
| fixed (byte* src = buffer) | |||||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| @@ -667,8 +666,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| if (args.deallocator_called) | if (args.deallocator_called) | ||||
| return; | return; | ||||
| // NumSharp will dispose | // NumSharp will dispose | ||||
| // Marshal.FreeHGlobal(dataPtr); | |||||
| Marshal.FreeHGlobal(dataPtr); | |||||
| args.deallocator_called = true; | args.deallocator_called = true; | ||||
| } | } | ||||
| @@ -221,15 +221,6 @@ namespace Tensorflow | |||||
| /// <exception cref="ArgumentException">When <typeparam name="T"> is string </typeparam></exception> | /// <exception cref="ArgumentException">When <typeparam name="T"> is string </typeparam></exception> | ||||
| public T[] ToArray<T>() where T : unmanaged | public T[] ToArray<T>() where T : unmanaged | ||||
| { | { | ||||
| //when T is string | |||||
| if (typeof(T) == typeof(string)) | |||||
| { | |||||
| if (dtype != TF_DataType.TF_STRING) | |||||
| throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string."); | |||||
| return (T[]) (object) StringData(); | |||||
| } | |||||
| //Are the types matching? | //Are the types matching? | ||||
| if (typeof(T).as_dtype() == dtype) | if (typeof(T).as_dtype() == dtype) | ||||
| { | { | ||||
| @@ -246,20 +237,12 @@ namespace Tensorflow | |||||
| unsafe | unsafe | ||||
| { | { | ||||
| var len = (long) size; | var len = (long) size; | ||||
| fixed (T* dstRet = ret) | |||||
| fixed (T* dst = ret) | |||||
| { | { | ||||
| T* dst = dstRet; //local stack copy | |||||
| if (typeof(T).IsPrimitive) | |||||
| { | |||||
| var src = (T*) buffer; | |||||
| len *= ((long) itemsize); | |||||
| System.Buffer.MemoryCopy(src, dst, len, len); | |||||
| } else | |||||
| { | |||||
| var itemsize = (long) this.itemsize; | |||||
| var buffer = this.buffer.ToInt64(); | |||||
| Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure<T>(new IntPtr(buffer + i * itemsize))); | |||||
| } | |||||
| //T can only be unmanaged, I believe it is safe to say that MemoryCopy is valid for all cases this method can be called. | |||||
| var src = (T*) buffer; | |||||
| len *= ((long) itemsize); | |||||
| System.Buffer.MemoryCopy(src, dst, len, len); | |||||
| } | } | ||||
| } | } | ||||
| @@ -384,9 +367,15 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| /// Used internally in ToArray<T> | |||||
| private unsafe string[] StringData() | |||||
| /// <summary> | |||||
| /// Extracts string array from current Tensor. | |||||
| /// </summary> | |||||
| /// <exception cref="InvalidOperationException">When <see cref="dtype"/> != TF_DataType.TF_STRING</exception> | |||||
| public unsafe string[] StringData() | |||||
| { | { | ||||
| if (dtype != TF_DataType.TF_STRING) | |||||
| throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); | |||||
| // | // | ||||
| // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | ||||
| // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] | // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] | ||||
| @@ -442,7 +431,7 @@ namespace Tensorflow | |||||
| /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | ||||
| /// <param name="session">The `Session` to be used to evaluate this tensor.</param> | /// <param name="session">The `Session` to be used to evaluate this tensor.</param> | ||||
| /// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns> | /// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns> | ||||
| public NDArray eval(Session session, FeedItem[] feed_dict = null) | |||||
| public NDArray eval(Session session, params FeedItem[] feed_dict) | |||||
| { | { | ||||
| return ops._eval_using_default_session(this, feed_dict, graph, session); | return ops._eval_using_default_session(this, feed_dict, graph, session); | ||||
| } | } | ||||
| @@ -568,23 +557,10 @@ namespace Tensorflow | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
| { | { | ||||
| if (handle != IntPtr.Zero) | |||||
| { | |||||
| c_api.TF_DeleteTensor(handle); | |||||
| _handle = IntPtr.Zero; | |||||
| } | |||||
| c_api.TF_DeleteTensor(handle); | |||||
| } | } | ||||
| public bool IsDisposed | |||||
| { | |||||
| get | |||||
| { | |||||
| lock (this) | |||||
| { | |||||
| return _handle == IntPtr.Zero; | |||||
| } | |||||
| } | |||||
| } | |||||
| public bool IsDisposed => _disposed; | |||||
| public int tensor_int_val { get; set; } | public int tensor_int_val { get; set; } | ||||
| } | } | ||||
| @@ -83,6 +83,12 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("MakeNdarray"); | throw new NotImplementedException("MakeNdarray"); | ||||
| } | } | ||||
| private static readonly TF_DataType[] quantized_types = new TF_DataType[] | |||||
| { | |||||
| TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, | |||||
| TF_DataType.TF_QINT32 | |||||
| }; | |||||
| /// <summary> | /// <summary> | ||||
| /// Create a TensorProto. | /// Create a TensorProto. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -99,15 +105,6 @@ namespace Tensorflow | |||||
| if (values is TensorProto tp) | if (values is TensorProto tp) | ||||
| return tp; | return tp; | ||||
| if (dtype != TF_DataType.DtInvalid) | |||||
| ; | |||||
| bool is_quantized = new TF_DataType[] | |||||
| { | |||||
| TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, | |||||
| TF_DataType.TF_QINT32 | |||||
| }.Contains(dtype); | |||||
| // We first convert value to a numpy array or scalar. | // We first convert value to a numpy array or scalar. | ||||
| NDArray nparray = null; | NDArray nparray = null; | ||||
| var np_dt = dtype.as_numpy_dtype(); | var np_dt = dtype.as_numpy_dtype(); | ||||
| @@ -227,13 +224,13 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| var numpy_dtype = dtypes.as_dtype(nparray.dtype, dtype: dtype); | |||||
| var numpy_dtype = nparray.dtype.as_dtype(dtype: dtype); | |||||
| if (numpy_dtype == TF_DataType.DtInvalid) | if (numpy_dtype == TF_DataType.DtInvalid) | ||||
| throw new TypeError($"Unrecognized data type: {nparray.dtype}"); | throw new TypeError($"Unrecognized data type: {nparray.dtype}"); | ||||
| // If dtype was specified and is a quantized type, we convert | // If dtype was specified and is a quantized type, we convert | ||||
| // numpy_dtype back into the quantized version. | // numpy_dtype back into the quantized version. | ||||
| if (is_quantized) | |||||
| if (quantized_types.Contains(dtype)) | |||||
| numpy_dtype = dtype; | numpy_dtype = dtype; | ||||
| bool is_same_size = false; | bool is_same_size = false; | ||||
| @@ -0,0 +1,94 @@ | |||||
| using System; | |||||
| using System.IO; | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Runtime.InteropServices; | |||||
| using NumSharp.Backends.Unmanaged; | |||||
| namespace Tensorflow.Util | |||||
| { | |||||
| public static class UnmanagedExtensions | |||||
| { | |||||
| //internally UnmanagedMemoryStream can't construct with null address. | |||||
| private static readonly unsafe byte* _empty = (byte*) Marshal.AllocHGlobal(1); | |||||
| /// <summary> | |||||
| /// Creates a memory stream based on given <paramref name="block"/>. | |||||
| /// </summary> | |||||
| /// <param name="block">The block to stream. Can be default/null.</param> | |||||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock<byte> block) | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| if (block.Address == null) | |||||
| return new UnmanagedMemoryStream(_empty, 0); | |||||
| return new UnmanagedMemoryStream(block.Address, block.BytesCount); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates a memory stream based on given <paramref name="block"/>. | |||||
| /// </summary> | |||||
| /// <param name="block">The block to stream. Can be default/null.</param> | |||||
| /// <param name="offset">Offset from the start of the block.</param> | |||||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock<byte> block, long offset) | |||||
| { | |||||
| if (block.BytesCount - offset <= 0) | |||||
| throw new ArgumentOutOfRangeException(nameof(offset)); | |||||
| unsafe | |||||
| { | |||||
| if (block.Address == null) | |||||
| return new UnmanagedMemoryStream(_empty, 0); | |||||
| return new UnmanagedMemoryStream(block.Address + offset, block.BytesCount - offset); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates a memory stream based on given <paramref name="address"/>. | |||||
| /// </summary> | |||||
| /// <param name="address">The block to stream. Can be IntPtr.Zero.</param> | |||||
| /// <param name="length">The length of the block in bytes.</param> | |||||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public static UnmanagedMemoryStream Stream(this IntPtr address, long length) | |||||
| { | |||||
| if (length <= 0) | |||||
| throw new ArgumentOutOfRangeException(nameof(length)); | |||||
| unsafe | |||||
| { | |||||
| if (address == IntPtr.Zero) | |||||
| return new UnmanagedMemoryStream(_empty, 0); | |||||
| // ReSharper disable once AssignNullToNotNullAttribute | |||||
| return new UnmanagedMemoryStream((byte*) address, length); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates a memory stream based on given <paramref name="address"/>. | |||||
| /// </summary> | |||||
| /// <param name="address">The block to stream. Can be IntPtr.Zero.</param> | |||||
| /// <param name="offset">Offset from the start of the block.</param> | |||||
| /// <param name="length">The length of the block in bytes.</param> | |||||
| /// <remarks>There is no need to dispose the returned <see cref="UnmanagedMemoryStream"/></remarks> | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public static UnmanagedMemoryStream Stream(this IntPtr address, long offset, long length) | |||||
| { | |||||
| if (length <= 0) | |||||
| throw new ArgumentOutOfRangeException(nameof(length)); | |||||
| unsafe | |||||
| { | |||||
| if (address == IntPtr.Zero) | |||||
| return new UnmanagedMemoryStream(_empty, 0); | |||||
| return new UnmanagedMemoryStream((byte*) address + offset, length); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -8,7 +8,8 @@ | |||||
| %supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"] | %supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"] | ||||
| %supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | %supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | ||||
| %supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"] | %supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"] | ||||
| %supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||||
| %supported_numericals_TF_DataType = ["TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] | |||||
| %supported_numericals_TF_DataType_full = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||||
| //this is the type we use in summerizing/reducting: | //this is the type we use in summerizing/reducting: | ||||
| %supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] | %supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] | ||||
| @@ -25,7 +26,8 @@ | |||||
| %supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"] | %supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"] | ||||
| %supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] | %supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] | ||||
| %supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||||
| %supported_dtypes_TF_DataType = ["TF_BOOL","TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] | |||||
| %supported_dtypes_TF_DataType_full = ["TF_DataType.TF_BOOL","TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||||
| %supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"] | %supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"] | ||||
| %supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | %supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | ||||
| @@ -230,8 +230,8 @@ namespace Tensorflow | |||||
| // Add attrs | // Add attrs | ||||
| foreach (var attr in node_def.Attr) | foreach (var attr in node_def.Attr) | ||||
| { | { | ||||
| var bytes = attr.Value.ToByteArray(); | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | Marshal.Copy(bytes, 0, proto, bytes.Length); | ||||
| uint len = (uint)bytes.Length; | uint len = (uint)bytes.Length; | ||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | ||||
| @@ -64,8 +64,7 @@ namespace Tensorflow | |||||
| public Session Session() | public Session Session() | ||||
| { | { | ||||
| defaultSession = new Session(); | |||||
| return defaultSession; | |||||
| return new Session(); | |||||
| } | } | ||||
| public Session Session(Graph graph, SessionOptions opts = null) | public Session Session(Graph graph, SessionOptions opts = null) | ||||
| @@ -9,24 +9,18 @@ namespace TensorFlowBenchmark | |||||
| { | { | ||||
| static void Main(string[] args) | static void Main(string[] args) | ||||
| { | { | ||||
| #if DEBUG | |||||
| IConfig config = new DebugInProcessConfig(); | |||||
| #else | |||||
| IConfig config = null; | |||||
| #endif | |||||
| if (args?.Length > 0) | if (args?.Length > 0) | ||||
| { | { | ||||
| for (int i = 0; i < args.Length; i++) | for (int i = 0; i < args.Length; i++) | ||||
| { | { | ||||
| string name = $"TensorFlowBenchmark.{args[i]}"; | string name = $"TensorFlowBenchmark.{args[i]}"; | ||||
| var type = Type.GetType(name); | var type = Type.GetType(name); | ||||
| BenchmarkRunner.Run(type, config); | |||||
| BenchmarkRunner.Run(type); | |||||
| } | } | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, config); | |||||
| BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, ManualConfig.Create(DefaultConfig.Instance).With(ConfigOptions.DisableOptimizationsValidator)); | |||||
| } | } | ||||
| Console.ReadLine(); | Console.ReadLine(); | ||||
| @@ -6,6 +6,7 @@ | |||||
| <NoWin32Manifest>true</NoWin32Manifest> | <NoWin32Manifest>true</NoWin32Manifest> | ||||
| <AssemblyName>TensorFlowBenchmark</AssemblyName> | <AssemblyName>TensorFlowBenchmark</AssemblyName> | ||||
| <RootNamespace>TensorFlowBenchmark</RootNamespace> | <RootNamespace>TensorFlowBenchmark</RootNamespace> | ||||
| <LangVersion>7.3</LangVersion> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | ||||
| @@ -0,0 +1,76 @@ | |||||
| using System; | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Runtime.InteropServices; | |||||
| using BenchmarkDotNet.Attributes; | |||||
| using Google.Protobuf.WellKnownTypes; | |||||
| using NumSharp; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowBenchmark.Unmanaged | |||||
| { | |||||
| public struct UnmanagedStruct | |||||
| { | |||||
| public int a; | |||||
| public long b; | |||||
| public UnmanagedStruct(int _) | |||||
| { | |||||
| a = 2; | |||||
| b = 3; | |||||
| } | |||||
| } | |||||
| [SimpleJob(launchCount: 1, warmupCount: 2, targetCount: 10)] | |||||
| [MinColumn, MaxColumn, MeanColumn, MedianColumn] | |||||
| public unsafe class StructCastBenchmark | |||||
| { | |||||
| private static void EnsureIsUnmanaged<T>(T _) where T : unmanaged | |||||
| { } | |||||
| static StructCastBenchmark() //if UnmanagedStruct is not unmanaged struct then this will fail to compile. | |||||
| => EnsureIsUnmanaged(new UnmanagedStruct()); | |||||
| private IntPtr data; | |||||
| private void* dataptr; | |||||
| [GlobalSetup] | |||||
| public void Setup() | |||||
| { | |||||
| data = Marshal.AllocHGlobal(Marshal.SizeOf<UnmanagedStruct>()); | |||||
| dataptr = data.ToPointer(); | |||||
| } | |||||
| [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] | |||||
| public void Marshal_PtrToStructure() | |||||
| { | |||||
| UnmanagedStruct _; | |||||
| for (int i = 0; i < 10000; i++) | |||||
| { | |||||
| _ = Marshal.PtrToStructure<UnmanagedStruct>(data); | |||||
| } | |||||
| } | |||||
| [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] | |||||
| public void PointerCast() | |||||
| { | |||||
| var dptr = dataptr; | |||||
| UnmanagedStruct _; | |||||
| for (int i = 0; i < 10000; i++) | |||||
| { | |||||
| _ = *(UnmanagedStruct*) dptr; | |||||
| } | |||||
| } | |||||
| [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] | |||||
| public void Unsafe_Read() | |||||
| { | |||||
| var dptr = dataptr; | |||||
| UnmanagedStruct _; | |||||
| for (int i = 0; i < 10000; i++) | |||||
| { | |||||
| _ = Unsafe.Read<UnmanagedStruct>(dptr); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -102,7 +102,7 @@ namespace TensorFlowNET.Examples | |||||
| // Display logs per epoch step | // Display logs per epoch step | ||||
| if ((epoch + 1) % display_step == 0) | if ((epoch + 1) % display_step == 0) | ||||
| print($"Epoch: {(epoch + 1).ToString("D4")} Cost: {avg_cost.ToString("G9")} Elapse: {sw.ElapsedMilliseconds}ms"); | |||||
| print($"Epoch: {(epoch + 1):D4} Cost: {avg_cost:G9} Elapse: {sw.ElapsedMilliseconds}ms"); | |||||
| sw.Reset(); | sw.Reset(); | ||||
| } | } | ||||
| @@ -114,8 +114,8 @@ namespace TensorFlowNET.Examples | |||||
| var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); | var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); | ||||
| // Calculate accuracy | // Calculate accuracy | ||||
| var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); | var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); | ||||
| float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||||
| print($"Accuracy: {acc.ToString("F4")}"); | |||||
| float acc = accuracy.eval(sess, (x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||||
| print($"Accuracy: {acc:F4}"); | |||||
| return acc > 0.9; | return acc > 0.9; | ||||
| } | } | ||||
| @@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| // get model file | // get model file | ||||
| string url = "http://download.tf.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz"; | |||||
| string url = "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz"; | |||||
| Web.Download(url, modelDir, "ssd_mobilenet_v1_coco.tar.gz"); | Web.Download(url, modelDir, "ssd_mobilenet_v1_coco.tar.gz"); | ||||
| Compress.ExtractTGZ(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz"), "./"); | Compress.ExtractTGZ(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz"), "./"); | ||||
| @@ -21,6 +21,7 @@ using System.Collections.Generic; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Threading.Tasks; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using TensorFlowNET.Examples.Utility; | using TensorFlowNET.Examples.Utility; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -381,10 +382,15 @@ namespace TensorFlowNET.Examples | |||||
| Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name) | Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name) | ||||
| { | { | ||||
| int how_many_bottlenecks = 0; | int how_many_bottlenecks = 0; | ||||
| foreach (var (label_name, label_lists) in image_lists) | |||||
| var kvs = image_lists.ToArray(); | |||||
| var categories = new string[] {"training", "testing", "validation"}; | |||||
| Parallel.For(0, kvs.Length, i => | |||||
| { | { | ||||
| foreach (var category in new string[] { "training", "testing", "validation" }) | |||||
| var (label_name, label_lists) = kvs[i]; | |||||
| Parallel.For(0, categories.Length, j => | |||||
| { | { | ||||
| var category = categories[j]; | |||||
| var category_list = label_lists[category]; | var category_list = label_lists[category]; | ||||
| foreach (var (index, unused_base_name) in enumerate(category_list)) | foreach (var (index, unused_base_name) in enumerate(category_list)) | ||||
| { | { | ||||
| @@ -395,8 +401,8 @@ namespace TensorFlowNET.Examples | |||||
| if (how_many_bottlenecks % 300 == 0) | if (how_many_bottlenecks % 300 == 0) | ||||
| print($"{how_many_bottlenecks} bottleneck files created."); | print($"{how_many_bottlenecks} bottleneck files created."); | ||||
| } | } | ||||
| } | |||||
| } | |||||
| }); | |||||
| }); | |||||
| } | } | ||||
| private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists, | private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists, | ||||
| @@ -508,7 +514,7 @@ namespace TensorFlowNET.Examples | |||||
| { | { | ||||
| // get a set of images to teach the network about the new classes | // get a set of images to teach the network about the new classes | ||||
| string fileName = "flower_photos.tgz"; | string fileName = "flower_photos.tgz"; | ||||
| string url = $"http://download.tf.org/example_images/{fileName}"; | |||||
| string url = $"http://download.tensorflow.org/example_images/{fileName}"; | |||||
| Web.Download(url, data_dir, fileName); | Web.Download(url, data_dir, fileName); | ||||
| Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir); | Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir); | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.Basics | namespace TensorFlowNET.UnitTest.Basics | ||||
| @@ -14,21 +15,22 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| var expected = new[] { false, true, false, false, true, false, true }; | var expected = new[] { false, true, false, false, true, false, true }; | ||||
| var spike = tf.Variable(false); | var spike = tf.Variable(false); | ||||
| spike.initializer.run(); | |||||
| foreach (var i in range(1, 2)) | |||||
| using (var sess = new Session()) | |||||
| { | { | ||||
| if (raw_data[i] - raw_data[i - 1] > 5d) | |||||
| { | |||||
| var updater = tf.assign(spike, tf.constant(true)); | |||||
| updater.eval(); | |||||
| } | |||||
| else | |||||
| spike.initializer.run(session: sess); | |||||
| foreach (var i in range(1, 2)) | |||||
| { | { | ||||
| tf.assign(spike, tf.constant(true)).eval(); | |||||
| } | |||||
| if (raw_data[i] - raw_data[i - 1] > 5d) | |||||
| { | |||||
| var updater = tf.assign(spike, tf.constant(true)); | |||||
| updater.eval(sess); | |||||
| } else | |||||
| { | |||||
| tf.assign(spike, tf.constant(true)).eval(sess); | |||||
| } | |||||
| Assert.AreEqual((bool)spike.eval(), expected[i - 1]); | |||||
| Assert.AreEqual((bool) spike.eval(), expected[i - 1]); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -2,6 +2,7 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Util; | |||||
| using Buffer = Tensorflow.Buffer; | using Buffer = Tensorflow.Buffer; | ||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| @@ -45,15 +46,18 @@ namespace TensorFlowNET.UnitTest | |||||
| private bool GetGraphDef(Graph graph, out GraphDef graph_def) | private bool GetGraphDef(Graph graph, out GraphDef graph_def) | ||||
| { | { | ||||
| graph_def = null; | graph_def = null; | ||||
| var s = new Status(); | |||||
| var buffer = new Buffer(); | |||||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | |||||
| bool ret = TF_GetCode(s) == TF_OK; | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||||
| if (ret) graph_def = GraphDef.Parser.ParseFrom(buffer.Data); | |||||
| buffer.Dispose(); | |||||
| s.Dispose(); | |||||
| return ret; | |||||
| using (var s = new Status()) | |||||
| { | |||||
| using (var buffer = new Buffer()) | |||||
| { | |||||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | |||||
| bool ret = TF_GetCode(s) == TF_OK; | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||||
| if (ret) | |||||
| graph_def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) | private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) | ||||
| @@ -40,10 +40,7 @@ namespace TensorFlowNET.UnitTest | |||||
| private void DeleteInputValues() | private void DeleteInputValues() | ||||
| { | { | ||||
| for (var i = 0; i < input_values_.Count; ++i) | |||||
| { | |||||
| input_values_[i].Dispose(); | |||||
| } | |||||
| //clearing is enough as they will be disposed by the GC unless they are referenced else-where. | |||||
| input_values_.Clear(); | input_values_.Clear(); | ||||
| } | } | ||||
| @@ -60,11 +57,7 @@ namespace TensorFlowNET.UnitTest | |||||
| private void ResetOutputValues() | private void ResetOutputValues() | ||||
| { | { | ||||
| for (var i = 0; i < output_values_.Count; ++i) | |||||
| { | |||||
| if (output_values_[i] != IntPtr.Zero) | |||||
| output_values_[i].Dispose(); | |||||
| } | |||||
| //clearing is enough as they will be disposed by the GC unless they are referenced else-where. | |||||
| output_values_.Clear(); | output_values_.Clear(); | ||||
| } | } | ||||
| @@ -322,7 +322,6 @@ namespace TensorFlowNET.UnitTest | |||||
| EXPECT_EQ(feed2, control_inputs[1]); | EXPECT_EQ(feed2, control_inputs[1]); | ||||
| // Export to a graph def so we can import a graph with control dependencies | // Export to a graph def so we can import a graph with control dependencies | ||||
| graph_def.Dispose(); | |||||
| graph_def = new Buffer(); | graph_def = new Buffer(); | ||||
| c_api.TF_GraphToGraphDef(graph, graph_def, s); | c_api.TF_GraphToGraphDef(graph, graph_def, s); | ||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | EXPECT_EQ(TF_Code.TF_OK, s.Code); | ||||
| @@ -346,14 +345,10 @@ namespace TensorFlowNET.UnitTest | |||||
| EXPECT_EQ(feed4, control_inputs[1]); | EXPECT_EQ(feed4, control_inputs[1]); | ||||
| c_api.TF_DeleteImportGraphDefOptions(opts); | c_api.TF_DeleteImportGraphDefOptions(opts); | ||||
| c_api.TF_DeleteBuffer(graph_def); | |||||
| // Can add nodes to the imported graph without trouble. | // Can add nodes to the imported graph without trouble. | ||||
| c_test_util.Add(feed, scalar, graph, s); | c_test_util.Add(feed, scalar, graph, s); | ||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | ASSERT_EQ(TF_Code.TF_OK, s.Code); | ||||
| graph.Dispose(); | |||||
| s.Dispose(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -42,5 +43,39 @@ namespace TensorFlowNET.UnitTest | |||||
| Assert.AreEqual("", g._name_stack); | Assert.AreEqual("", g._name_stack); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void NestedNameScope_Using() | |||||
| { | |||||
| Graph g = tf.Graph().as_default(); | |||||
| using (var name = new ops.NameScope("scope1")) | |||||
| { | |||||
| Assert.AreEqual("scope1", g._name_stack); | |||||
| Assert.AreEqual("scope1/", name); | |||||
| var const1 = tf.constant(1.0); | |||||
| Assert.AreEqual("scope1/Const:0", const1.name); | |||||
| using (var name2 = new ops.NameScope("scope2")) | |||||
| { | |||||
| Assert.AreEqual("scope1/scope2", g._name_stack); | |||||
| Assert.AreEqual("scope1/scope2/", name); | |||||
| var const2 = tf.constant(2.0); | |||||
| Assert.AreEqual("scope1/scope2/Const:0", const2.name); | |||||
| } | |||||
| Assert.AreEqual("scope1", g._name_stack); | |||||
| var const3 = tf.constant(2.0); | |||||
| Assert.AreEqual("scope1/Const_1:0", const3.name); | |||||
| } | |||||
| ; | |||||
| g.Dispose(); | |||||
| Assert.AreEqual("", g._name_stack); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,6 +4,7 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using NumSharp; | using NumSharp; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Util; | |||||
| using Buffer = Tensorflow.Buffer; | using Buffer = Tensorflow.Buffer; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -21,7 +22,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| var handle = c_api.TF_GetAllOpList(); | var handle = c_api.TF_GetAllOpList(); | ||||
| var buffer = new Buffer(handle); | var buffer = new Buffer(handle); | ||||
| var op_list = OpList.Parser.ParseFrom(buffer); | |||||
| var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| var _registered_ops = new Dictionary<string, OpDef>(); | var _registered_ops = new Dictionary<string, OpDef>(); | ||||
| foreach (var op_def in op_list.Op) | foreach (var op_def in op_list.Op) | ||||
| @@ -165,7 +165,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
| { | { | ||||
| var ndarray=tensor.eval(); | |||||
| var ndarray=tensor.eval(sess); | |||||
| if (typeof(T) == typeof(double)) | if (typeof(T) == typeof(double)) | ||||
| { | { | ||||
| double x = ndarray; | double x = ndarray; | ||||
| @@ -72,8 +72,6 @@ namespace TensorFlowNET.UnitTest | |||||
| // Clean up | // Clean up | ||||
| csession.CloseAndDelete(s); | csession.CloseAndDelete(s); | ||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | ASSERT_EQ(TF_Code.TF_OK, s.Code); | ||||
| graph.Dispose(); | |||||
| s.Dispose(); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -84,7 +82,7 @@ namespace TensorFlowNET.UnitTest | |||||
| var c = math_ops.matmul(a, b, name: "matmul"); | var c = math_ops.matmul(a, b, name: "matmul"); | ||||
| using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
| { | { | ||||
| var result = c.eval(); | |||||
| var result = c.eval(sess); | |||||
| Assert.AreEqual(6, result.Data<double>()[0]); | Assert.AreEqual(6, result.Data<double>()[0]); | ||||
| } | } | ||||
| } | } | ||||
| @@ -4,6 +4,12 @@ | |||||
| <TargetFramework>netcoreapp2.2</TargetFramework> | <TargetFramework>netcoreapp2.2</TargetFramework> | ||||
| <IsPackable>false</IsPackable> | <IsPackable>false</IsPackable> | ||||
| <SignAssembly>true</SignAssembly> | |||||
| <DelaySign>false</DelaySign> | |||||
| <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
| @@ -119,7 +119,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| sess.run(init_op); | sess.run(init_op); | ||||
| // o some work with the model. | // o some work with the model. | ||||
| inc_v1.op.run(); | |||||
| inc_v1.op.run(session: sess); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,6 @@ | |||||
| using Tensorflow; | |||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Util; | |||||
| using Buffer = Tensorflow.Buffer; | using Buffer = Tensorflow.Buffer; | ||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| @@ -26,12 +28,15 @@ namespace TensorFlowNET.UnitTest | |||||
| return op; | return op; | ||||
| } | } | ||||
| [SuppressMessage("ReSharper", "RedundantAssignment")] | |||||
| public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | ||||
| { | { | ||||
| var buffer = new Buffer(); | |||||
| c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||||
| attr_value = AttrValue.Parser.ParseFrom(buffer); | |||||
| buffer.Dispose(); | |||||
| using (var buffer = new Buffer()) | |||||
| { | |||||
| c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||||
| attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | |||||
| return s.Code == TF_Code.TF_OK; | return s.Code == TF_Code.TF_OK; | ||||
| } | } | ||||
| @@ -42,7 +47,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | c_api.TF_GraphToGraphDef(graph, buffer, s); | ||||
| s.Check(); | s.Check(); | ||||
| return GraphDef.Parser.ParseFrom(buffer); | |||||
| return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | } | ||||
| } | } | ||||
| @@ -24,16 +24,17 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
| [TestMethod] | [TestMethod] | ||||
| public void TestShape() | public void TestShape() | ||||
| { | { | ||||
| var g = tf.Graph().as_default(); | |||||
| var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }); | |||||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]); | |||||
| var op = g._create_op_from_tf_operation(c_op); | |||||
| Assert.AreEqual("myop", op.name); | |||||
| Assert.AreEqual("Identity", op.type); | |||||
| Assert.AreEqual(1, len(op.outputs)); | |||||
| assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape); | |||||
| using (var g = tf.Graph().as_default()) | |||||
| { | |||||
| var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); | |||||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||||
| var op = g._create_op_from_tf_operation(c_op); | |||||
| Assert.AreEqual("myop", op.name); | |||||
| Assert.AreEqual("Identity", op.type); | |||||
| Assert.AreEqual(1, len(op.outputs)); | |||||
| assertItemsEqual(new[] {2, 3}, op.outputs[0].shape); | |||||
| } | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||