diff --git a/src/TensorFlowNET.Core/APIs/tf.graph.cs b/src/TensorFlowNET.Core/APIs/tf.graph.cs index cee941ed..1648cb70 100644 --- a/src/TensorFlowNET.Core/APIs/tf.graph.cs +++ b/src/TensorFlowNET.Core/APIs/tf.graph.cs @@ -29,7 +29,19 @@ namespace Tensorflow return ops.get_default_graph(); } - public Graph Graph() + /// + /// Equivalent to but does not create a new graph if it there is none. + /// + public Graph peak_default_graph() + { + return ops.default_graph_stack.peak_controller(); + } + + /// + /// Creates a new graph. + /// + ///Has no interaction with graph defaulting. Equivalent to new Graph(); + public Graph Graph() => new Graph(); } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs index 66419b3e..4d843d6b 100644 --- a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs +++ b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs @@ -21,11 +21,10 @@ using static Tensorflow.Binding; namespace Tensorflow { - /// /// Serves as a stack for determining current default graph. /// - public class DefaultGraphStack + public class DefaultGraphStack { private readonly List _stack = new List(); @@ -52,6 +51,20 @@ namespace Tensorflow throw new TensorflowException("Unable to find a default graph"); } + public Graph peak_controller() + { + if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0) + return null; + for (var i = _stack.Count - 1; i >= 0; i--) + { + var x = _stack[i]; + if (x.IsDefault) + return x.Graph; + } + + return null; + } + public bool remove(Graph g) { if (_stack.Count == 0)