From 6623162244efb703040f8c7f6960a0e083014434 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 1 Aug 2019 09:18:05 -0500 Subject: [PATCH] fix default graph and operation issue when import model. --- src/TensorFlowNET.Core/Buffers/Buffer.cs | 4 +- .../Framework/c_api_util.cs | 2 +- .../Graphs/Graph.Operation.cs | 25 +++++++ src/TensorFlowNET.Core/Graphs/Graph.cs | 71 +++++++++++-------- src/TensorFlowNET.Core/Graphs/c_api.graph.cs | 4 ++ .../Operations/Operation.Implicit.cs | 5 +- .../Operations/Operation.Output.cs | 2 + src/TensorFlowNET.Core/Sessions/Session.cs | 18 ++--- .../BasicModels/LogisticRegression.cs | 2 +- 9 files changed, 89 insertions(+), 44 deletions(-) diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index 0b73265d..378c7c85 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -30,8 +30,8 @@ namespace Tensorflow get { var data = new byte[buffer.length]; - if (buffer.length > 0) - Marshal.Copy(buffer.data, data, 0, (int)buffer.length); + if (data.Length > 0) + Marshal.Copy(buffer.data, data, 0, data.Length); return data; } } diff --git a/src/TensorFlowNET.Core/Framework/c_api_util.cs b/src/TensorFlowNET.Core/Framework/c_api_util.cs index 440cbf44..5d5cb9b3 100644 --- a/src/TensorFlowNET.Core/Framework/c_api_util.cs +++ b/src/TensorFlowNET.Core/Framework/c_api_util.cs @@ -128,7 +128,7 @@ namespace Tensorflow IntPtr c_op; while ((c_op = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) { - yield return c_op; + yield return new Operation(c_op, graph); } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index 06b65f03..09e09573 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -38,6 +38,31 @@ namespace Tensorflow return c_api.TF_NewOperation(_handle, opType, opName); } + public unsafe Operation[] ReturnOperations(IntPtr results) + { + TF_Operation return_oper_handle = new TF_Operation(); + int num_return_opers = 0; + c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); + Operation[] return_opers = new Operation[num_return_opers]; + for (int i = 0; i < num_return_opers; i++) + { + var handle = return_oper_handle.node + Marshal.SizeOf() * i; + return_opers[i] = new Operation(*(IntPtr*)handle); + } + + return return_opers; + } + + public Operation OperationByName(string operName) + { + return c_api.TF_GraphOperationByName(_handle, operName); + } + + public ITensorOrOperation[] get_operations() + { + return _nodes_by_name.Values.Select(x => x).ToArray(); + } + /// /// Returns the `Operation` with the given `name`. /// diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 82e83df1..08ed95af 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; @@ -72,7 +73,7 @@ namespace Tensorflow all variables that are created during the construction of a graph. The caller may define additional collections by specifying a new name. */ - public partial class Graph : IPython, IDisposable + public partial class Graph : IPython, IDisposable, IEnumerable { private IntPtr _handle; private Dictionary _nodes_by_id; @@ -121,6 +122,10 @@ namespace Tensorflow _nodes_by_name = new Dictionary(); _names_in_use = new Dictionary(); _graph_key = $"grap-key-{ops.uid()}/"; + } + + public void __enter__() + { } public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) @@ -409,31 +414,6 @@ namespace Tensorflow return return_outputs; } - public unsafe Operation[] ReturnOperations(IntPtr results) - { - TF_Operation return_oper_handle = new TF_Operation(); - int num_return_opers = 0; - c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); - Operation[] return_opers = new Operation[num_return_opers]; - for (int i = 0; i < num_return_opers; i++) - { - var handle = return_oper_handle.node + Marshal.SizeOf() * i; - return_opers[i] = new Operation(*(IntPtr*)handle); - } - - return return_opers; - } - - public Operation OperationByName(string operName) - { - return c_api.TF_GraphOperationByName(_handle, operName); - } - - public ITensorOrOperation[] get_operations() - { - return _nodes_by_name.Values.Select(x => x).ToArray(); - } - public string[] get_all_collection_keys() { return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); @@ -481,17 +461,46 @@ namespace Tensorflow public Tensor get_tensor_by_name(string name) { return (Tensor)this.as_graph_element(name, allow_tensor: true, allow_operation: false); - } - - public void __enter__() - { + } + + public TensorShape GetTensorShape(TF_Output output) + { + var status = new Status(); + var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status); + status.Check(); + + if (ndim == -1) + return new TensorShape(); + + var dims = new long[ndim]; + c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status); + status.Check(); + + return new TensorShape(dims.Select(x => (int)x).ToArray()); + } + + public override string ToString() + { + int len = 0; + return c_api.TF_GraphDebugString(_handle, out len); } public void __exit__() { - } + } + + private IEnumerable GetEnumerable() + => c_api_util.tf_operations(this); + IEnumerator IEnumerable.GetEnumerator() + => GetEnumerable().GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() + { + throw new NotImplementedException(); + } + public static implicit operator IntPtr(Graph graph) { return graph._handle; diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 05cd5940..889949ef 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -43,6 +43,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_DeleteImportGraphDefResults(IntPtr results); + [DllImport(TensorFlowLibName)] + public static extern string TF_GraphDebugString(IntPtr graph, out int len); + [DllImport(TensorFlowLibName)] public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); @@ -100,6 +103,7 @@ namespace Tensorflow /// TF_Status* [DllImport(TensorFlowLibName)] public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); + /// /// Iterate through the operations of a graph. /// diff --git a/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs index 1b99dcc8..8de412c8 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs @@ -23,7 +23,10 @@ namespace Tensorflow /// public partial class Operation { - public static implicit operator Operation(IntPtr handle) => new Operation(handle); + // make sure the new op is in the same graph instance + public static implicit operator Operation(IntPtr handle) + => new Operation(handle); + public static implicit operator IntPtr(Operation op) => op._handle; public static implicit operator Tensor(Operation op) => op.output; diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index cefb76cf..41f4a332 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -35,6 +35,8 @@ namespace Tensorflow public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); + public TF_Output this[int index] => _tf_output(index); + public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) { int size = Marshal.SizeOf(); diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index 374a57ad..191730d0 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using System.Runtime.InteropServices; namespace Tensorflow { @@ -26,8 +27,8 @@ namespace Tensorflow } - public Session(IntPtr handle) - : base("", null, null) + public Session(IntPtr handle, Graph g = null) + : base("", g, null) { _session = handle; } @@ -50,8 +51,10 @@ namespace Tensorflow var graph = c_api.TF_NewGraph(); var status = new Status(); var opt = c_api.TF_NewSessionOptions(); + var tags = new string[] { "serve" }; var buffer = new TF_Buffer(); + var sess = c_api.TF_LoadSessionFromSavedModel(opt, IntPtr.Zero, path, @@ -61,14 +64,13 @@ namespace Tensorflow ref buffer, status); - //var bytes = new Buffer(buffer.data).Data; - //var meta_graph = MetaGraphDef.Parser.ParseFrom(bytes); - + // load graph bytes + // var data = new byte[buffer.length]; + // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); + // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ status.Check(); - new Graph(graph).as_default(); - - return sess; + return new Session(sess, g: new Graph(graph).as_default()); } public static implicit operator IntPtr(Session session) => session._session; diff --git a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs index a627c517..1d7808b7 100644 --- a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs @@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels)); print($"Accuracy: {acc.ToString("F4")}"); - return acc > 0.88; + return acc > 0.9; }); }