| @@ -30,8 +30,8 @@ namespace Tensorflow | |||
| get | |||
| { | |||
| var data = new byte[buffer.length]; | |||
| if (buffer.length > 0) | |||
| Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||
| if (data.Length > 0) | |||
| Marshal.Copy(buffer.data, data, 0, data.Length); | |||
| return data; | |||
| } | |||
| } | |||
| @@ -128,7 +128,7 @@ namespace Tensorflow | |||
| IntPtr c_op; | |||
| while ((c_op = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) | |||
| { | |||
| yield return c_op; | |||
| yield return new Operation(c_op, graph); | |||
| } | |||
| } | |||
| } | |||
| @@ -38,6 +38,31 @@ namespace Tensorflow | |||
| return c_api.TF_NewOperation(_handle, opType, opName); | |||
| } | |||
| public unsafe Operation[] ReturnOperations(IntPtr results) | |||
| { | |||
| TF_Operation return_oper_handle = new TF_Operation(); | |||
| int num_return_opers = 0; | |||
| c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); | |||
| Operation[] return_opers = new Operation[num_return_opers]; | |||
| for (int i = 0; i < num_return_opers; i++) | |||
| { | |||
| var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i; | |||
| return_opers[i] = new Operation(*(IntPtr*)handle); | |||
| } | |||
| return return_opers; | |||
| } | |||
| public Operation OperationByName(string operName) | |||
| { | |||
| return c_api.TF_GraphOperationByName(_handle, operName); | |||
| } | |||
| public ITensorOrOperation[] get_operations() | |||
| { | |||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
| } | |||
| /// <summary> | |||
| /// Returns the `Operation` with the given `name`. | |||
| /// | |||
| @@ -15,6 +15,7 @@ | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| @@ -72,7 +73,7 @@ namespace Tensorflow | |||
| all variables that are created during the construction of a graph. The caller | |||
| may define additional collections by specifying a new name. | |||
| */ | |||
| public partial class Graph : IPython, IDisposable | |||
| public partial class Graph : IPython, IDisposable, IEnumerable<Operation> | |||
| { | |||
| private IntPtr _handle; | |||
| private Dictionary<int, ITensorOrOperation> _nodes_by_id; | |||
| @@ -121,6 +122,10 @@ namespace Tensorflow | |||
| _nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | |||
| _names_in_use = new Dictionary<string, int>(); | |||
| _graph_key = $"grap-key-{ops.uid()}/"; | |||
| } | |||
| public void __enter__() | |||
| { | |||
| } | |||
| public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | |||
| @@ -409,31 +414,6 @@ namespace Tensorflow | |||
| return return_outputs; | |||
| } | |||
| public unsafe Operation[] ReturnOperations(IntPtr results) | |||
| { | |||
| TF_Operation return_oper_handle = new TF_Operation(); | |||
| int num_return_opers = 0; | |||
| c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); | |||
| Operation[] return_opers = new Operation[num_return_opers]; | |||
| for (int i = 0; i < num_return_opers; i++) | |||
| { | |||
| var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i; | |||
| return_opers[i] = new Operation(*(IntPtr*)handle); | |||
| } | |||
| return return_opers; | |||
| } | |||
| public Operation OperationByName(string operName) | |||
| { | |||
| return c_api.TF_GraphOperationByName(_handle, operName); | |||
| } | |||
| public ITensorOrOperation[] get_operations() | |||
| { | |||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
| } | |||
| public string[] get_all_collection_keys() | |||
| { | |||
| return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); | |||
| @@ -481,17 +461,46 @@ namespace Tensorflow | |||
| public Tensor get_tensor_by_name(string name) | |||
| { | |||
| return (Tensor)this.as_graph_element(name, allow_tensor: true, allow_operation: false); | |||
| } | |||
| public void __enter__() | |||
| { | |||
| } | |||
| public TensorShape GetTensorShape(TF_Output output) | |||
| { | |||
| var status = new Status(); | |||
| var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status); | |||
| status.Check(); | |||
| if (ndim == -1) | |||
| return new TensorShape(); | |||
| var dims = new long[ndim]; | |||
| c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status); | |||
| status.Check(); | |||
| return new TensorShape(dims.Select(x => (int)x).ToArray()); | |||
| } | |||
| public override string ToString() | |||
| { | |||
| int len = 0; | |||
| return c_api.TF_GraphDebugString(_handle, out len); | |||
| } | |||
| public void __exit__() | |||
| { | |||
| } | |||
| } | |||
| private IEnumerable<Operation> GetEnumerable() | |||
| => c_api_util.tf_operations(this); | |||
| IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | |||
| => GetEnumerable().GetEnumerator(); | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public static implicit operator IntPtr(Graph graph) | |||
| { | |||
| return graph._handle; | |||
| @@ -43,6 +43,9 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_DeleteImportGraphDefResults(IntPtr results); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern string TF_GraphDebugString(IntPtr graph, out int len); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | |||
| @@ -100,6 +103,7 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); | |||
| /// <summary> | |||
| /// Iterate through the operations of a graph. | |||
| /// </summary> | |||
| @@ -23,7 +23,10 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public partial class Operation | |||
| { | |||
| public static implicit operator Operation(IntPtr handle) => new Operation(handle); | |||
| // make sure the new op is in the same graph instance | |||
| public static implicit operator Operation(IntPtr handle) | |||
| => new Operation(handle); | |||
| public static implicit operator IntPtr(Operation op) => op._handle; | |||
| public static implicit operator Tensor(Operation op) => op.output; | |||
| @@ -35,6 +35,8 @@ namespace Tensorflow | |||
| public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | |||
| public TF_Output this[int index] => _tf_output(index); | |||
| public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | |||
| { | |||
| int size = Marshal.SizeOf<TF_Input>(); | |||
| @@ -15,6 +15,7 @@ | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -26,8 +27,8 @@ namespace Tensorflow | |||
| } | |||
| public Session(IntPtr handle) | |||
| : base("", null, null) | |||
| public Session(IntPtr handle, Graph g = null) | |||
| : base("", g, null) | |||
| { | |||
| _session = handle; | |||
| } | |||
| @@ -50,8 +51,10 @@ namespace Tensorflow | |||
| var graph = c_api.TF_NewGraph(); | |||
| var status = new Status(); | |||
| var opt = c_api.TF_NewSessionOptions(); | |||
| var tags = new string[] { "serve" }; | |||
| var buffer = new TF_Buffer(); | |||
| var sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
| IntPtr.Zero, | |||
| path, | |||
| @@ -61,14 +64,13 @@ namespace Tensorflow | |||
| ref buffer, | |||
| status); | |||
| //var bytes = new Buffer(buffer.data).Data; | |||
| //var meta_graph = MetaGraphDef.Parser.ParseFrom(bytes); | |||
| // load graph bytes | |||
| // var data = new byte[buffer.length]; | |||
| // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||
| // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||
| status.Check(); | |||
| new Graph(graph).as_default(); | |||
| return sess; | |||
| return new Session(sess, g: new Graph(graph).as_default()); | |||
| } | |||
| public static implicit operator IntPtr(Session session) => session._session; | |||
| @@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples | |||
| float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels)); | |||
| print($"Accuracy: {acc.ToString("F4")}"); | |||
| return acc > 0.88; | |||
| return acc > 0.9; | |||
| }); | |||
| } | |||