| @@ -30,8 +30,8 @@ namespace Tensorflow | |||||
| get | get | ||||
| { | { | ||||
| var data = new byte[buffer.length]; | var data = new byte[buffer.length]; | ||||
| if (buffer.length > 0) | |||||
| Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||||
| if (data.Length > 0) | |||||
| Marshal.Copy(buffer.data, data, 0, data.Length); | |||||
| return data; | return data; | ||||
| } | } | ||||
| } | } | ||||
| @@ -128,7 +128,7 @@ namespace Tensorflow | |||||
| IntPtr c_op; | IntPtr c_op; | ||||
| while ((c_op = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) | while ((c_op = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) | ||||
| { | { | ||||
| yield return c_op; | |||||
| yield return new Operation(c_op, graph); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -38,6 +38,31 @@ namespace Tensorflow | |||||
| return c_api.TF_NewOperation(_handle, opType, opName); | return c_api.TF_NewOperation(_handle, opType, opName); | ||||
| } | } | ||||
| public unsafe Operation[] ReturnOperations(IntPtr results) | |||||
| { | |||||
| TF_Operation return_oper_handle = new TF_Operation(); | |||||
| int num_return_opers = 0; | |||||
| c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); | |||||
| Operation[] return_opers = new Operation[num_return_opers]; | |||||
| for (int i = 0; i < num_return_opers; i++) | |||||
| { | |||||
| var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i; | |||||
| return_opers[i] = new Operation(*(IntPtr*)handle); | |||||
| } | |||||
| return return_opers; | |||||
| } | |||||
| public Operation OperationByName(string operName) | |||||
| { | |||||
| return c_api.TF_GraphOperationByName(_handle, operName); | |||||
| } | |||||
| public ITensorOrOperation[] get_operations() | |||||
| { | |||||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the `Operation` with the given `name`. | /// Returns the `Operation` with the given `name`. | ||||
| /// | /// | ||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| @@ -72,7 +73,7 @@ namespace Tensorflow | |||||
| all variables that are created during the construction of a graph. The caller | all variables that are created during the construction of a graph. The caller | ||||
| may define additional collections by specifying a new name. | may define additional collections by specifying a new name. | ||||
| */ | */ | ||||
| public partial class Graph : IPython, IDisposable | |||||
| public partial class Graph : IPython, IDisposable, IEnumerable<Operation> | |||||
| { | { | ||||
| private IntPtr _handle; | private IntPtr _handle; | ||||
| private Dictionary<int, ITensorOrOperation> _nodes_by_id; | private Dictionary<int, ITensorOrOperation> _nodes_by_id; | ||||
| @@ -121,6 +122,10 @@ namespace Tensorflow | |||||
| _nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | _nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | ||||
| _names_in_use = new Dictionary<string, int>(); | _names_in_use = new Dictionary<string, int>(); | ||||
| _graph_key = $"grap-key-{ops.uid()}/"; | _graph_key = $"grap-key-{ops.uid()}/"; | ||||
| } | |||||
| public void __enter__() | |||||
| { | |||||
| } | } | ||||
| public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | ||||
| @@ -409,31 +414,6 @@ namespace Tensorflow | |||||
| return return_outputs; | return return_outputs; | ||||
| } | } | ||||
| public unsafe Operation[] ReturnOperations(IntPtr results) | |||||
| { | |||||
| TF_Operation return_oper_handle = new TF_Operation(); | |||||
| int num_return_opers = 0; | |||||
| c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); | |||||
| Operation[] return_opers = new Operation[num_return_opers]; | |||||
| for (int i = 0; i < num_return_opers; i++) | |||||
| { | |||||
| var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i; | |||||
| return_opers[i] = new Operation(*(IntPtr*)handle); | |||||
| } | |||||
| return return_opers; | |||||
| } | |||||
| public Operation OperationByName(string operName) | |||||
| { | |||||
| return c_api.TF_GraphOperationByName(_handle, operName); | |||||
| } | |||||
| public ITensorOrOperation[] get_operations() | |||||
| { | |||||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||||
| } | |||||
| public string[] get_all_collection_keys() | public string[] get_all_collection_keys() | ||||
| { | { | ||||
| return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); | return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); | ||||
| @@ -481,17 +461,46 @@ namespace Tensorflow | |||||
| public Tensor get_tensor_by_name(string name) | public Tensor get_tensor_by_name(string name) | ||||
| { | { | ||||
| return (Tensor)this.as_graph_element(name, allow_tensor: true, allow_operation: false); | return (Tensor)this.as_graph_element(name, allow_tensor: true, allow_operation: false); | ||||
| } | |||||
| public void __enter__() | |||||
| { | |||||
| } | |||||
| public TensorShape GetTensorShape(TF_Output output) | |||||
| { | |||||
| var status = new Status(); | |||||
| var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status); | |||||
| status.Check(); | |||||
| if (ndim == -1) | |||||
| return new TensorShape(); | |||||
| var dims = new long[ndim]; | |||||
| c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status); | |||||
| status.Check(); | |||||
| return new TensorShape(dims.Select(x => (int)x).ToArray()); | |||||
| } | |||||
| public override string ToString() | |||||
| { | |||||
| int len = 0; | |||||
| return c_api.TF_GraphDebugString(_handle, out len); | |||||
| } | } | ||||
| public void __exit__() | public void __exit__() | ||||
| { | { | ||||
| } | |||||
| } | |||||
| private IEnumerable<Operation> GetEnumerable() | |||||
| => c_api_util.tf_operations(this); | |||||
| IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | |||||
| => GetEnumerable().GetEnumerator(); | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public static implicit operator IntPtr(Graph graph) | public static implicit operator IntPtr(Graph graph) | ||||
| { | { | ||||
| return graph._handle; | return graph._handle; | ||||
| @@ -43,6 +43,9 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_DeleteImportGraphDefResults(IntPtr results); | public static extern void TF_DeleteImportGraphDefResults(IntPtr results); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern string TF_GraphDebugString(IntPtr graph, out int len); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | ||||
| @@ -100,6 +103,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); | public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); | ||||
| /// <summary> | /// <summary> | ||||
| /// Iterate through the operations of a graph. | /// Iterate through the operations of a graph. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -23,7 +23,10 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public partial class Operation | public partial class Operation | ||||
| { | { | ||||
| public static implicit operator Operation(IntPtr handle) => new Operation(handle); | |||||
| // make sure the new op is in the same graph instance | |||||
| public static implicit operator Operation(IntPtr handle) | |||||
| => new Operation(handle); | |||||
| public static implicit operator IntPtr(Operation op) => op._handle; | public static implicit operator IntPtr(Operation op) => op._handle; | ||||
| public static implicit operator Tensor(Operation op) => op.output; | public static implicit operator Tensor(Operation op) => op.output; | ||||
| @@ -35,6 +35,8 @@ namespace Tensorflow | |||||
| public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | ||||
| public TF_Output this[int index] => _tf_output(index); | |||||
| public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | ||||
| { | { | ||||
| int size = Marshal.SizeOf<TF_Input>(); | int size = Marshal.SizeOf<TF_Input>(); | ||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Runtime.InteropServices; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -26,8 +27,8 @@ namespace Tensorflow | |||||
| } | } | ||||
| public Session(IntPtr handle) | |||||
| : base("", null, null) | |||||
| public Session(IntPtr handle, Graph g = null) | |||||
| : base("", g, null) | |||||
| { | { | ||||
| _session = handle; | _session = handle; | ||||
| } | } | ||||
| @@ -50,8 +51,10 @@ namespace Tensorflow | |||||
| var graph = c_api.TF_NewGraph(); | var graph = c_api.TF_NewGraph(); | ||||
| var status = new Status(); | var status = new Status(); | ||||
| var opt = c_api.TF_NewSessionOptions(); | var opt = c_api.TF_NewSessionOptions(); | ||||
| var tags = new string[] { "serve" }; | var tags = new string[] { "serve" }; | ||||
| var buffer = new TF_Buffer(); | var buffer = new TF_Buffer(); | ||||
| var sess = c_api.TF_LoadSessionFromSavedModel(opt, | var sess = c_api.TF_LoadSessionFromSavedModel(opt, | ||||
| IntPtr.Zero, | IntPtr.Zero, | ||||
| path, | path, | ||||
| @@ -61,14 +64,13 @@ namespace Tensorflow | |||||
| ref buffer, | ref buffer, | ||||
| status); | status); | ||||
| //var bytes = new Buffer(buffer.data).Data; | |||||
| //var meta_graph = MetaGraphDef.Parser.ParseFrom(bytes); | |||||
| // load graph bytes | |||||
| // var data = new byte[buffer.length]; | |||||
| // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||||
| // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||||
| status.Check(); | status.Check(); | ||||
| new Graph(graph).as_default(); | |||||
| return sess; | |||||
| return new Session(sess, g: new Graph(graph).as_default()); | |||||
| } | } | ||||
| public static implicit operator IntPtr(Session session) => session._session; | public static implicit operator IntPtr(Session session) => session._session; | ||||
| @@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples | |||||
| float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels)); | float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels)); | ||||
| print($"Accuracy: {acc.ToString("F4")}"); | print($"Accuracy: {acc.ToString("F4")}"); | ||||
| return acc > 0.88; | |||||
| return acc > 0.9; | |||||
| }); | }); | ||||
| } | } | ||||