From c4a585c320c2c757953a5283a4f6199ba35ea983 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 3 Jun 2019 23:19:22 -0500 Subject: [PATCH] remove global static Graph instance. --- .../Graphs/DefaultGraphStack.cs | 40 +++++++++++++++++++ src/TensorFlowNET.Core/Graphs/Graph.cs | 11 ++++- src/TensorFlowNET.Core/Layers/Layer.cs | 11 +++-- .../Variables/variable_scope.py.cs | 12 ++++++ src/TensorFlowNET.Core/ops.name_scope.cs | 3 +- src/TensorFlowNET.Core/ops.py.cs | 16 +++----- .../TextProcess/CnnTextClassification.cs | 5 +-- 7 files changed, 76 insertions(+), 22 deletions(-) create mode 100644 src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs diff --git a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs new file mode 100644 index 00000000..fa9e0312 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs @@ -0,0 +1,40 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow +{ + public class DefaultGraphStack + { + List stack = new List(); + + public void set_controller(Graph @default) + { + if (!stack.Exists(x => x.Graph == @default)) + stack.Add(new StackModel { Graph = @default, IsDefault = true }); + + foreach (var s in stack) + s.IsDefault = s.Graph == @default; + } + + public Graph get_controller() + { + if (stack.Count == 0) + stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); + + return stack.First(x => x.IsDefault).Graph; + } + + public void reset() + { + stack.Clear(); + } + } + + public class StackModel + { + public Graph Graph { get; set; } + public bool IsDefault { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index a24c4648..fccc924b 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -87,7 +87,7 @@ namespace Tensorflow private Dictionary _collections = new Dictionary(); public bool building_function; - + public Graph() { _handle = c_api.TF_NewGraph(); @@ -113,7 +113,14 @@ namespace Tensorflow return _as_graph_element_locked(obj, allow_tensor, allow_operation); } - public Graph as_default() => ops.set_default_graph(this); + /// + /// Returns a context manager that makes this `Graph` the default graph. + /// + /// + public Graph as_default() + { + return ops.set_default_graph(this); + } private Tensor _as_graph_element(object obj) { diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index e569d9c0..6ed47fe8 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -172,13 +172,12 @@ namespace Tensorflow.Layers } else { - with(tf.variable_scope(scope, default_name: _base_name), - captured_scope => - { - _scope = captured_scope; - }); + with(tf.variable_scope(scope, default_name: _base_name), captured_scope => + { + // convert variable_scope to VariableScope + _scope = captured_scope; + }); } - } } } diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index ca183d19..4b66dd3b 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -26,6 +26,7 @@ namespace Tensorflow private bool? _reuse; bool _in_graph_mode; protected Graph _graph; + bool _building_function; public variable_scope(string name, string default_name = "", @@ -70,6 +71,17 @@ namespace Tensorflow public void __enter__() { + // If the default graph is building a function, then we should not replace it + // with the cached graph. + if (ops.get_default_graph().building_function) + _building_function = true; + else + _building_function = false; + if (_in_graph_mode && !_building_function) + { + _graph.as_default(); + } + _scope = _enter_scope_uncached(); } diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs index abc8bd80..2b1bd021 100644 --- a/src/TensorFlowNET.Core/ops.name_scope.cs +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -54,12 +54,13 @@ namespace Tensorflow public void Dispose() { var g = get_default_graph(); - // Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}"); g._name_stack = old_stack; + // Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}"); } public void __exit__() { + } /// diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 60011e58..8faf8841 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -50,7 +50,8 @@ namespace Tensorflow return get_default_graph().get_collection_ref(key); } - private static Graph default_graph; + public static DefaultGraphStack default_graph_stack = new DefaultGraphStack(); + /// /// Returns the default graph for the current thread. /// @@ -68,15 +69,13 @@ namespace Tensorflow { //TODO: original source indicates there should be a _default_graph_stack! //return _default_graph_stack.get_default() - if (default_graph == null) - default_graph = tf.Graph(); - return default_graph; + return default_graph_stack.get_controller(); } public static Graph set_default_graph(Graph graph) { //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! - default_graph = graph; - return default_graph; + default_graph_stack.set_controller(graph); + return default_graph_stack.get_controller(); } /// @@ -96,10 +95,7 @@ namespace Tensorflow // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + // "nested graphs. If you need a cleared graph, " + // "exit the nesting and create a new graph."); - //_default_graph_stack.reset(); - if (default_graph!=null) - default_graph.Dispose(); - default_graph = tf.Graph(); + default_graph_stack.reset(); } public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs index d6cd059f..bba2ed96 100644 --- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs +++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs @@ -195,7 +195,7 @@ namespace TensorFlowNET.Examples return graph; } - private bool RunWithImportedGraph(Session sess, Graph graph) + private bool Train(Session sess, Graph graph) { var stopwatch = Stopwatch.StartNew(); @@ -274,8 +274,7 @@ namespace TensorFlowNET.Examples { var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); - return with(tf.Session(graph), sess - => RunWithImportedGraph(sess, graph)); + return with(tf.Session(graph), sess => Train(sess, graph)); } public bool Predict()